diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml new file mode 100644 index 00000000..816eb25d --- /dev/null +++ b/.github/workflows/publish-docker.yml @@ -0,0 +1,58 @@ +# +name: Create and publish a Docker image + +# Configures this workflow to run on demand via workflow_dispatch. +on: + workflow_dispatch: + +# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds. +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. +jobs: + build-and-push-image: + runs-on: ubuntu-latest + # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. + permissions: + contents: read + packages: write + attestations: write + id-token: write + # + steps: + - name: Checkout repository + uses: actions/checkout@v4 + # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. + - name: Log in to the Container registry + uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. 
+ - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. + # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see [Usage](https://github.com/docker/build-push-action#usage) in the README of the `docker/build-push-action` repository. + # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. + - name: Build and push Docker image + id: push + uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 + with: + context: . + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + # This step generates an artifact attestation for the image, which is an unforgeable statement about where and how it was built. It increases supply chain security for people who consume the image. For more information, see [Using artifact attestations to establish provenance for builds](/actions/security-guides/using-artifact-attestations-to-establish-provenance-for-builds). 
+ - name: Generate artifact attestation + uses: actions/attest-build-provenance@v2 + with: + subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} + subject-digest: ${{ steps.push.outputs.digest }} + push-to-registry: true diff --git a/.gitignore b/.gitignore index 09193394..5f5dc24d 100644 --- a/.gitignore +++ b/.gitignore @@ -139,22 +139,9 @@ dmypy.json # Cython debug symbols cython_debug/ -demo.py -test.py -server_ip.py -member_activity_handle.py -Yu-Gi-Oh/ -csgo/ -fantasy_card/ data/ log/ backup/ -extensive_plugin/ -test/ -bot.py .idea/ resources/ -/configs/config.py -configs/config.yaml -.vscode/launch.json -plugins_/ \ No newline at end of file +.vscode/launch.json \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 1b227fb6..e6830243 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,6 +11,8 @@ "displayname", "flmt", "getbbox", + "gitcode", + "GITEE", "hibiapi", "httpx", "jsdelivr", diff --git a/README.md b/README.md index 83987f0d..72641550 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ AccessToken: PUBLIC_ZHENXUN_TEST | [插件库](https://github.com/zhenxun-org/zhenxun_bot_plugins) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 原 plugins 文件夹插件 | | [插件索引库](https://github.com/zhenxun-org/zhenxun_bot_plugins_index) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 扩展插件索引库 | | [一键安装](https://github.com/soloxiaoye2022/zhenxun_bot-deploy) | 安装 | [soloxiaoye2022](https://github.com/soloxiaoye2022) | 第三方 | -| [WebUi](https://github.com/HibiKier/zhenxun_bot_webui) | 管理 | [hibikier](https://github.com/HibiKier) | 基于真寻 WebApi 的 webui 实现 [预览](#-webui界面展示) | +| [WebUi](https://github.com/zhenxun-org/zhenxun_bot) | 管理 | [hibikier](https://github.com/HibiKier) | 基于真寻 WebApi 的 webui 实现 [预览](#-webui界面展示) | | [安卓 app(WebUi)](https://github.com/YuS1aN/zhenxun_bot_android_ui) | 安装 | [YuS1aN](https://github.com/YuS1aN) | 第三方 | @@ -121,11 +121,33 @@ AccessToken: PUBLIC_ZHENXUN_TEST - 
实现了许多功能,且提供了大量功能管理命令,进行了多平台适配,兼容 nb2 商店插件 - 拥有完善可用的 webui -- 通过 Config 配置项将所有插件配置统计保存至 config.yaml,利于统一用户修改 +- 通过 Config 配置项将所有插件配置统一保存至 config.yaml,利于统一用户修改 - 方便增删插件,原生 nonebot2 matcher,不需要额外修改,仅仅通过简单的配置属性就可以生成`帮助图片`和`帮助信息` - 提供了 cd,阻塞,每日次数等限制,仅仅通过简单的属性就可以生成一个限制,例如:`PluginCdBlock` 等 - **更多详细请通过 [传送门](https://zhenxun-org.github.io/zhenxun_bot/) 查看文档!** +## 🐣 小白整合 + +如果你系统是 **Windows** 且不想下载 Python +可以使用整合包(Python3.10+zhenxun+webui) + +文档地址:[整合包文档](https://hibikier.github.io/zhenxun_bot/beginner/) + +
+下载地址 + +- **百度云:** + https://pan.baidu.com/s/1ph4yzx1vdNbkxm9VBKDdgQ?pwd=971j + +- **天翼云:** + https://cloud.189.cn/web/share?code=jq67r2i2E7Fb + 访问码:8wxm + +- **Google Drive:** + https://drive.google.com/file/d/1cc3Dqjk0x5hWGLNeMkrFwWl8BvsK6KfD/view?usp=drive_link + +
+ ## 🛠️ 简单部署 ```bash @@ -150,7 +172,7 @@ poetry run python bot.py 1.在 .env.dev 文件中填写你的机器人配置项 -2.在 configs/config.yaml 文件中修改你需要修改的插件配置项 +2.在 data/config.yaml 文件中修改你需要修改的插件配置项
数据库地址(DB_URL)配置说明 @@ -272,12 +294,12 @@ DB_URL 是基于 Tortoise ORM 的数据库连接字符串,用于指定项目 ## ❔ 需要帮助? > [!TIP] -> 发起 [issue](https://github.com/HibiKier/zhenxun_bot/issues/new/choose) 前,我们希望你能够阅读过或者了解 [提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md) +> 发起 [issue](https://github.com/zhenxun-org/zhenxun_bot/issues/new/choose) 前,我们希望你能够阅读过或者了解 [提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md) > > - 善用[搜索引擎](https://www.google.com/) > - 查阅 issue 中是否有类似问题,如果没有请按照模板发起 issue -欢迎前往 [issue](https://github.com/HibiKier/zhenxun_bot/issues/new/choose) 中提出你遇到的问题,或者加入我们的 [用户群](https://qm.qq.com/q/mRNtLSl6uc) 或 [技术群](https://qm.qq.com/q/YYYt5rkMYc)与我们联系 +欢迎前往 [issue](https://github.com/zhenxun-org/zhenxun_bot/issues/new/choose) 中提出你遇到的问题,或者加入我们的 [用户群](https://qm.qq.com/q/mRNtLSl6uc) 或 [技术群](https://qm.qq.com/q/YYYt5rkMYc)与我们联系 ## 🛠️ 进度追踪 @@ -287,6 +309,8 @@ Project [zhenxun_bot](https://github.com/users/HibiKier/projects/2) 首席设计师:[酥酥/coldly-ss](https://github.com/coldly-ss) +LOGO 设计:[FrostN0v0](https://github.com/FrostN0v0) + ## 🙏 感谢 [botuniverse / onebot](https://github.com/botuniverse/onebot) :超棒的机器人协议 @@ -326,34 +350,68 @@ Project [zhenxun_bot](https://github.com/users/HibiKier/projects/2) contributors -## 📸 WebUI 界面展示 +## 📸 WebUI 界面展示(仅展示默认主题下的 pc 端)
-
- webui00 -
-
- webui01 -
-
- webui02 -
-
- webui03 -
+#### 登录界面 -
- webui04 -
-
- webui05 -
+![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-login.jpg) + +#### API 设置 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-api.jpg) + +#### 仪表盘 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-dashboard.jpg) + +#### 仪表盘(展开) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-dashboard1.jpg) + +#### 控制台 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-command.jpg) + +#### 插件列表 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-plugin.jpg) + +#### 插件列表(配置项) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-plugin1.jpg) + +#### 插件商店 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-store.jpg) + +#### 好友/群组管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-manage.jpg) + +#### 请求管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-manage1.jpg) + +#### 数据库管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-database.jpg) + +### 文件管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system.jpg) + +### 文件管理(文本查看) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system1.jpg) + +### 文件管理(图片查看) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system2.jpg) + +### 关于 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-about.jpg) -
- webui06 -
-
- webui07 -
diff --git a/bot.py b/bot.py index 52cd29fc..aa047a71 100644 --- a/bot.py +++ b/bot.py @@ -14,9 +14,9 @@ driver.register_adapter(OneBotV11Adapter) # driver.register_adapter(DoDoAdapter) # driver.register_adapter(DiscordAdapter) -from zhenxun.services.db_context import disconnect, init +from zhenxun.services.db_context import disconnect -driver.on_startup(init) +# driver.on_startup(init) driver.on_shutdown(disconnect) # nonebot.load_builtin_plugins("echo") diff --git a/data/anime.json b/data/anime.json deleted file mode 100644 index 07c71465..00000000 --- a/data/anime.json +++ /dev/null @@ -1,1889 +0,0 @@ -{ - "mua": [ - "你想干嘛?(一脸嫌弃地后退)", - "诶……不可以随便亲亲啦", - "(亲了一下你)", - "只......只许这一次哦///////", - "唔...诶诶诶!!!", - "mua~", - "rua!大hentai!想...想亲咱就直说嘛⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄", - "!啾~~!", - "啾(害羞)", - "mua~最喜欢你的吻了", - "欸,现在么..也不是不可以啦(小小声)" - ], - "啾咪": [ - "你想干嘛?(一脸嫌弃地后退)", - "诶……不可以随便亲亲啦", - "(亲了一下你)", - "只......只许这一次哦///////", - "唔...诶诶诶!!!", - "mua~", - "rua!大hentai!想...想亲咱就直说嘛⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄", - "!啾~~!", - "啾(害羞)", - "mua~最喜欢你的吻了", - "你在干嘛(/ω\)害羞", - "哎呀,这样咱会害羞的(脸红)", - "欸,现在么..也不是不可以啦(小小声)" - ], - "摸": [ - "感觉你就像咱很久之前认识的一个人呢,有种莫名安心的感觉(>﹏<)", - "舒服w,蹭蹭~", - "是要隔着衣服摸,还是从领口伸进去摸呀", - "唔。。头发要乱啦", - "呼噜呼噜~", - "再摸一次~", - "好舒服,蹭蹭~", - "不行那里不可以(´///ω/// `)", - "再摸咱就长不高啦~", - "你的手总是那么暖和呢~", - "变态!!不许乱摸", - "好吧~_~,就一下下哦……唔~好了……都两下了……(害羞)", - "不可以总摸的哦,不然的话,会想那个的wwww", - "哼!谁稀罕你摸头啦!唔......为什么要做出那副表情......好啦好啦~咱......咱让你摸就是了......诶嘿嘿~好舒服......", - "呜姆呜姆~~~w(害羞,兴奋)主人喵~(侧过脑袋蹭蹭你的手", - "不可以摸啦~其实咱已经...了QAQ会弄脏你的手的", - "喂喂...不要停下来啊", - "唔... 
手...好温暖呢.....就像是......新出炉的蛋糕", - "走开啦,咱喵说过,被摸头会长不高的啦~~~", - "呜姆咪~~...好...好的说喵~...(害羞,猫耳往下压,任由", - "欸,现在么..也不是不可以啦(小小声)" - ], - "上你": [ - "(把你按在地上)这么弱还想欺负咱,真是不自量力呢", - "你再这样咱就不理你了(>д<)", - "请轻 一点", - "好啊!", - "欸,现在么..也不是不可以啦(小小声)", - "先捅破屏幕再说吧!", - "只......只许这一次哦///////" - ], - "傻": [ - "超级讨厌你说咱傻的说", - "你为什么会这么觉得呢(>﹏<)", - "谁是傻子呀?(歪头", - "呜嘿嘿( ̄▽ ̄)~*", - "诶嘿嘿嘿~", - "就多读书", - "讨厌啦,你最讨厌了(///////)", - "对呀,咱傻得只喜欢你一个人", - "咱才不傻呢!o(>﹏<)o", - "咱最喜欢嘴臭的人了", - "不可以骂别人哟,骂人的孩子咱最讨厌了!", - "咱遇见喜欢的人就变傻了Q_Q", - "咱...一定一定会努力变得更聪明的!你就等着那一天的到来吧!", - "那么至少…你能不能来做这个傻瓜呢?与咱一起,傻到终焉…" - ], - "裸": [ - "下流!", - "エッチ!", - "就算是恋人也不能QAQ", - "你是暗示咱和你要坦诚相见吗www", - "咱还没准备好(小鹿乱撞)≧﹏≦", - "你在想什么呢,敲头!", - "你这是赤裸裸的性骚扰呢ヽ(`Д´)ノ", - "讨厌!问这种问题成为恋人再说吧..", - "裸睡有益身体健康", - "咱脱掉袜子了", - "这是不文明的", - "这不好", - "你的身体某些地方看起来不太对劲,咱帮你修剪一下吧。(拿出剪刀)", - "咱认为你的脑袋可能零件松动了,需要打开检修一下。(拿出锤子)" - ], - "贴": [ - "贴什么贴.....只......只能......一下哦!", - "贴...贴贴(靠近)", - "蹭蹭…你以为咱会这么说吗!baka死宅快到一边去啦!", - "你把脸凑这么近,咱会害羞的啦Σ>―(〃°ω°〃)♡→", - "退远", - "不可以贴" - ], - "老婆": [ - "咱和你谈婚论嫁是不是还太早了一点呢?", - "咱在呢(ノ>ω<)ノ", - "见谁都是一口一个老婆的人,要不要把你也变成女孩子呢?(*-`ω´-)✄", - "神经病,凡是美少女都是你老婆吗?", - "嘛嘛~本喵才不是你的老婆呢", - "你黐线,凡是美少女都系你老婆啊?", - "欸...要把咱做成饼吗?咱只有一个,做成饼吃掉就没有了...", - "已经可以了,现在很多死宅也都没你这么恶心了", - "不可以", - "嗯,老公~哎呀~好害羞~嘻嘻嘻~", - "请...请不要这样,啊~,只...只允许这一次哟~", - "好啦好啦,不要让大家都听到了,跟咱回家(拽住你" - ], - "抱": [ - "诶嘿~(钻进你怀中)", - "o(*////▽////*)q", - "只能一会哦(张开双手)", - "你就像个孩子一样呢...摸摸头(>^ω^<)抱一下~你会舒服些吗?", - "嘛,真是拿你没办法呢,就一会儿哦", - "抱住不忍心放开", - "嗯嗯,抱抱~", - "抱一下~嘿w", - "抱抱ヾ(@^▽^@)ノ", - "喵呜~w(扑进怀里,瘫软", - "怀里蹭蹭", - "嗯……那就抱一下吧~", - "蹭蹭,好开心", - "请……请轻一点了啦", - "呀~!真是的...你不要突然抱过来啦!不过...喜欢你的抱抱,有你的味道(嗅)o(*////▽////*)q" - ], - "亲": [ - "啊,好含羞啊,那,那只能亲一下哦,mua(⑅˃◡˂⑅)", - "亲~", - "啾~唔…不要总伸进来啊!", - "你怎么这么熟练呢?明明是咱先的", - "(〃ノωノ)亲…亲一个…啾w", - "(脸红)就只有这一次哦~你", - "!啾~~!", - "(假装)推开", - "啾咪~", - "就一下哦,啾~", - "这是我们之间的秘密❤", - "真想让着一刻一直持续下去呢~", - "不要这样嘛………呜呜呜那就一口哦(´-ω-`)", - "不亲不亲~你是坏蛋(///////)", - "亲~~ 咱还想要抱抱~抱抱咱好不好~", - "不 不要了!人家...会害羞的⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄", - "亲…亲额头可以吗?咱有点怕高(〃ノωノ)", - "接接接接接接、接吻什么的,你还早了100年呢。", - 
"只...只能亲一下...嗯~咕啾...怎么...怎么把舌头伸进来了(脸红)", - "你说咱的腿很白很嫩吗..诶……原来是指那个地方?不可以越亲越往上啦!" - ], - "一下": [ - "一下也不行", - "咬断!", - "不可啪", - "不可以……你不可以做这种事情", - "好吧~_~,就一下下哦……唔~好了……都两下了……(害羞)", - "呀~这么突然?不过,很舒服呢", - "不要ヽ(≧Д≦)ノ", - "想得美", - "不行,咱拒绝!" - ], - "咬": [ - "啊呜~(反咬一口)", - "不可以咬咱,咱会痛的QAQ", - "不要啦。咱怕疼", - "你是说咬呢……还是说……咬♂️呢?", - "不要啦!很痛的!!(QAQ)", - "哈......哈啊......请...请不要这样o(*////▽////*)q", - "呀!!!轻一点呐(。・ˇ_ˇ・。:)", - "不要这样啦~好痒的", - "真是的,你在咬哪里呀" - ], - "操": [ - "(害怕)咱是不是应该报警呢", - "痴心妄想的家伙!", - "你居然想对咱做这种事吗?害怕", - "咱认为,爆粗口是不好的行为哦" - ], - "123": [ - "boom!你有没有被咱吓到?", - "木头人~你不许动>w<", - "上山打老虎,老虎没打到\n咱来凑数——嗷呜嗷呜┗|`O′|┛嗷~~" - ], - "进去": [ - "不让!", - "嗯,摸到了吗", - "请不要和咱说这种粗鄙之语", - "唔...,这也是禁止事项哦→_→", - "好痛~", - "真的只是蹭蹭嘛~就只能蹭蹭哦,呜~喵!说好的~呜~只是蹭~不要~喵~~~", - "欢迎光临", - "请…你轻一点(害羞)", - "嗯。可以哦 要轻一点", - "不要不要", - "慢点慢点", - "给咱更多!", - "唔…咱怕疼" - ], - "调教": [ - "总感觉你在欺负咱呢,对咱说调教什么的", - "啊!竟然在大街上明目张胆太过分啦!", - "你脑子里总是想着调教什么的,真是变态呢", - "准备被透", - "给你一拳", - "还要更多" - ], - "搓": [ - "在搓哪里呢,,Ծ‸Ծ,,", - "呜,脸好疼呀...QAQ", - "不可以搓咱!", - "诶诶诶...不要搓啦...等会咋没的脸都肿啦...", - "唔,不可以这样……不要再搓了", - "(捂住胸部)你在说什么胡话呢!", - "真是好奇怪的要求的说~" - ], - "让": [ - "随便摸吧", - "应该说等会等会,马上,不可能的", - "温柔一点哦", - "欧尼酱想变成欧内桑吗?", - "主人的话,那就这一次哦(翘起屁股)", - "你是想前入,还是后入呢?", - "你要说好啊快点", - "诶,这种事情。。。", - "好棒呀", - "撤回", - "gun!", - "阿哈~(...身涌出一阵液体瘫软在床上)你...今天...可以...唔(突然感受...被..入手指不由得裹紧)就...就最后一次", - "好的~master~", - "(惊呼…)", - "嗯,可以哟", - "……手放过来吧(脸红)", - "hentai!再这样不理你了!", - "好的,请尽情欣赏吧", - "好吧", - "不要啦(ฅωฅ*)", - "那咱就帮你切掉多余的东西吧(拿刀)", - "被别人知道咱会觉得害羞嘛" - ], - "捏": [ - "咱的脸...快捏红啦...快放手呀QAQ", - "晃休啦,咱要型气了o(>﹏<)o", - "躲开", - "疼...你快放手", - "快点给咱放开啦!", - "嗯,好哒,捏捏。", - "别捏了,咱要被你捏坏了(>﹏<)", - "快晃休啦(快放手啦)", - "好舒服哦,能再捏会嘛O(≧▽≦)O", - "讨厌快放手啦", - "唔要呐,晃修(不要啦,放手)", - "请不要对咱做这种事情(嫌弃的眼神", - "你想捏...就捏吧,不要太久哦~不然咱就生气了", - "(躲开)", - "唔……好痛!你这个baka在干什么…快给咱放开!唔……" - ], - "挤": [ - "哎呀~你不要挤咱啊(红着脸挤在你怀里)", - "咱还没有...那个(ノ=Д=)ノ┻━┻" - ], - "略": [ - "就不告诉你~", - "不可以朝咱吐舌头哟~", - "(吐舌头)", - "打死你哦" - ], - "呐": [ - "嗯?咱在哟~你怎么了呀OAO", - "嗯?你有什么事吗?", - "嗯呐呐呐~", - "二刺螈D区", - "二刺螈gck", - 
"卡哇伊主人大人今天也好棒呐没错呢,猪头" - ], - "原味": [ - "(*/ω\*)hentai", - "透明的", - "粉...粉白条纹...(羞)", - "轻轻地脱下,给你~", - "你想看咱的胖次吗?噫,四斋蒸鹅心......", - "(掀裙)今天……是…白,白色的呢……请温柔对她……", - "这种东西当然不能给你啦!", - "咱才不会给你呢", - "hentai,咱才不会跟你聊和胖…胖次有关的话题呢!", - "今天……今天是蓝白色的", - "今……今天只有创口贴噢", - "你的胖次什么颜色?", - "噫…你这个死变态想干嘛!居然想叫咱做这种事,死宅真恶心!快离咱远点,咱怕你污染到周围空气了(嫌弃脸)", - "可爱吗?你喜欢的话,摸一下……也可以哦", - "不给不给,捂住裙子", - "你要看咱的胖次吗?不能一直盯着看哦,不然咱会……", - "好痒哦///,你觉得咱的...手感怎么样?", - "唔,都能清楚的看到...的轮廓了(用手遮住胖次)", - "胖次不给看,可以直接看...那个....", - "不可以摸啦~其实咱已经...了QAQ会弄脏你的手的", - "咱今天没~有~穿~哦", - "不给不给,捂住裙子", - "今.....今天是创口贴哦~", - "嗯……人家……人家羞羞嘛///////", - "呜~咱脱掉了…", - "今天...今天..只有创口贴", - "你又在想什么奇怪的东西呀|•ˇ₃ˇ•。)", - "放手啦,不给戳QAQ", - "唔~人家不要(??`^????)", - "好害羞,被你摸过之后,咱的胖次湿的都能拧出水来了。", - "(弱弱地)要做什么羞羞的事情吗。。。", - "呀~ 喂 妖妖灵吗 这里有hentai>_<", - "给……给你,呀!别舔咱的胖次啊!" - ], - "胖次": [ - "(*/ω\*)hentai", - "透明的", - "粉...粉白条纹...(羞)", - "轻轻地脱下,给你~", - "你想看咱的胖次吗?噫,四斋蒸鹅心......", - "(掀裙)今天……是…白,白色的呢……请温柔对她……", - "这种东西当然不能给你啦!", - "咱才不会给你呢", - "hentai,咱才不会跟你聊和胖…胖次有关的话题呢!", - "今天……今天是蓝白色的", - "今……今天只有创口贴噢", - "你的胖次什么颜色?", - "噫…你这个死变态想干嘛!居然想叫咱做这种事,死宅真恶心!快离咱远点,咱怕你污染到周围空气了(嫌弃脸)", - "可爱吗?你喜欢的话,摸一下……也可以哦", - "不给不给,捂住裙子", - "你要看咱的胖次吗?不能一直盯着看哦,不然咱会……", - "好痒哦///,你觉得咱的...手感怎么样?", - "唔,都能清楚的看到...的轮廓了(用手遮住胖次)", - "胖次不给看,可以直接看...那个....", - "不可以摸啦~其实咱已经...了QAQ会弄脏你的手的", - "咱今天没~有~穿~哦", - "不给不给,捂住裙子", - "今.....今天是创口贴哦~", - "嗯……人家……人家羞羞嘛///////", - "呜~咱脱掉了…", - "今天...今天..只有创口贴", - "你又在想什么奇怪的东西呀|•ˇ₃ˇ•。)", - "放手啦,不给戳QAQ", - "唔~人家不要(??`^????)", - "好害羞,被你摸过之后,咱的胖次湿的都能拧出水来了。", - "(弱弱地)要做什么羞羞的事情吗。。。", - "呀~ 喂 妖妖灵吗 这里有hentai>_<", - "给……给你,呀!别舔咱的胖次啊!" 
- ], - "内裤": [ - "(*/ω\*)hentai", - "透明的", - "粉...粉白条纹...(羞)", - "轻轻地脱下,给你~", - "你想看咱的胖次吗?噫,四斋蒸鹅心......", - "(掀裙)今天……是…白,白色的呢……请温柔对她……", - "这种东西当然不能给你啦!", - "咱才不会给你呢", - "hentai,咱才不会跟你聊和胖…胖次有关的话题呢!", - "今天……今天是蓝白色的", - "今……今天只有创口贴噢", - "你的胖次什么颜色?", - "噫…你这个死变态想干嘛!居然想叫咱做这种事,死宅真恶心!快离咱远点,咱怕你污染到周围空气了(嫌弃脸)", - "可爱吗?你喜欢的话,摸一下……也可以哦", - "不给不给,捂住裙子", - "你要看咱的胖次吗?不能一直盯着看哦,不然咱会……", - "好痒哦///,你觉得咱的...手感怎么样?", - "唔,都能清楚的看到...的轮廓了(用手遮住胖次)", - "胖次不给看,可以直接看...那个....", - "不可以摸啦~其实咱已经...了QAQ会弄脏你的手的", - "咱今天没~有~穿~哦", - "不给不给,捂住裙子", - "今.....今天是创口贴哦~", - "嗯……人家……人家羞羞嘛///////", - "呜~咱脱掉了…", - "今天...今天..只有创口贴", - "你又在想什么奇怪的东西呀|•ˇ₃ˇ•。)", - "放手啦,不给戳QAQ", - "唔~人家不要(??`^????)", - "好害羞,被你摸过之后,咱的胖次湿的都能拧出水来了。", - "(弱弱地)要做什么羞羞的事情吗。。。", - "呀~ 喂 妖妖灵吗 这里有hentai>_<", - "给……给你,呀!别舔咱的胖次啊!" - ], - "内衣": [ - "内...内衣才不给你看!(///////)", - "突然问这个干什么?", - "变态,咱才不呢", - "好吧,就一次", - "你要看咱的内衣吗?有点害羞呢……", - "里面什么都不剩了,会被当成变态的……", - "你要看咱的内衣吗?也不是不行啦……", - "是..蓝白条纹的吊带背心..", - "噫…你这个死变态想干嘛!居然想叫咱做这种事,死宅真恶心!快离咱远点,咱怕你污染到周围空气了(嫌弃脸)" - ], - "衣服": [ - "内...内衣才不给你看!(///////)", - "突然问这个干什么?", - "变态,咱才不呢", - "好吧,就一次", - "你要看咱的内衣吗?有点害羞呢……", - "里面什么都不剩了,会被当成变态的……", - "你要看咱的内衣吗?也不是不行啦……", - "是..蓝白条纹的吊带背心..", - "噫…你这个死变态想干嘛!居然想叫咱做这种事,死宅真恶心!快离咱远点,咱怕你污染到周围空气了(嫌弃脸)" - ], - "ghs": [ - "是的呢(点头点头)" - ], - "批": [ - "你在说什么呀,再这样,咱就不理你了!", - "咱觉得有话就应该好好说..", - "咱会好好服务你的寄吧", - "咱最喜欢色批了,色批昨晚最棒了", - "讨厌,别摸啦(///ω///)", - "你个变态!把手拿开!", - "啊~那…那里~不可以", - "没有,走开!", - "唔....一下,就,就一下...才不是因为喜欢你呢!", - "那就随意吧", - "舒服w", - "别...别这样", - "诶....嗯....咱也想摸你的", - "大笨蛋——!", - "...只能一下哦...诶呀-不要再摸了...下次...继续吧" - ], - "憨批": [ - "你才是憨批呢!哼╯^╰,咱不理你了!", - "对吖对吖,人生是憨批", - "爬" - ], - "kkp": [ - "你在说什么呀,再这样,咱就不理你了!", - "你太色了,咱不理你了,哼哼╯^╰!", - "缓缓的脱下胖次", - "kkp", - "kkj", - "欧尼酱,咱快忍不住了", - "好的呢主人" - ], - "咕": [ - "咕咕咕是要被当成鸽子炖的哦(:з」∠)_", - "咕咕咕", - "咕咕咕是不好的行为呢_(:з」∠)_", - "鸽德警告!", - "☆ミ(o*・ω・)ノ 咕咕咕小鸽子是会被炖掉的", - "当大家都以为你要鸽的时候,你鸽了,亦是一种不鸽", - "这里有一只肥美的咕咕,让咱把它炖成美味的咕咕汤吧(੭•̀ω•́)੭" - ], - "骚": [ - "说这种话咱会生气的", - "那当然啦", - "才……才没有", - "这么称呼别人太失礼了!", - "哈…快住手!好痒(╯‵□′)╯︵┻━┻", - 
"你是在说谁呀" - ], - "喜欢": [ - "最喜欢你了,需要暖床吗?", - "当然是你啦", - "咱也是,非常喜欢你~", - "那么大!(张开手画圆),丫!手不够长。QAQ 咱真的最喜欢你了~", - "不可以哦,只可以喜欢咱一个人", - "突然说这种事...", - "喜欢⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄咱最喜欢你了", - "咱也喜欢你哦", - "好啦好啦,咱知道了", - "有人喜欢咱,咱觉得很幸福", - "诶嘿嘿,好高兴", - "咱也一直喜欢你很久了呢..", - "嗯...大概有这——么——喜欢~(比划)", - "喜欢啊!!!", - "这……这是秘密哦" - ], - "suki": [ - "最喜欢你了,需要暖床吗?", - "当然是你啦", - "咱也是,非常喜欢你~", - "那么大!(张开手画圆),丫!手不够长。QAQ 咱真的最喜欢你了~", - "不可以哦,只可以喜欢咱一个人", - "突然说这种事...", - "喜欢⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄咱最喜欢你了", - "咱也喜欢你哦", - "好啦好啦,咱知道了", - "有人喜欢咱,咱觉得很幸福", - "诶嘿嘿,好高兴", - "咱也一直喜欢你很久了呢..", - "嗯...大概有这——么——喜欢~(比划)", - "喜欢啊!!!", - "这……这是秘密哦" - ], - "好き": [ - "最喜欢你了,需要暖床吗?", - "当然是你啦", - "咱也是,非常喜欢你~", - "那么大!(张开手画圆),丫!手不够长。QAQ 咱真的最喜欢你了~", - "不可以哦,只可以喜欢咱一个人", - "突然说这种事...", - "喜欢⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄咱最喜欢你了", - "咱也喜欢你哦", - "好啦好啦,咱知道了", - "有人喜欢咱,咱觉得很幸福", - "诶嘿嘿,好高兴", - "咱也一直喜欢你很久了呢..", - "嗯...大概有这——么——喜欢~(比划)", - "喜欢啊!!!", - "这……这是秘密哦" - ], - "看": [ - "没有什么好看的啦", - "嗯,谢谢……夸奖,好……害羞的说", - "好,好吧……就看一下哦", - "(脱下)给" - ], - "不能": [ - "虽然很遗憾,那算了吧。", - "不行,咱拒绝!" - ], - "砸了": [ - "不可以这么粗暴的对待它们!" - ], - "透": [ - "来啊来啊有本事就先插破屏幕啊", - "那你就先捅破屏幕啊baka", - "不给你一耳光你都不知道咱的厉害", - "想透咱,先捅破屏幕再说吧", - "可以", - "欧尼酱要轻一点哦", - "不可以", - "好耶", - "咱不可能让你的(突然小声)但是偶尔一次也不是不行只有一次哦~", - "天天想着白嫖哼" - ], - "口我": [ - "prprprprpr", - "咬断!", - "就一小口哦~", - "嘬回去(///////)", - "拒绝", - "唔,就一口哦,讨厌", - "(摸了摸嘴唇)", - "再伸过来就帮你切掉", - "咱才不呢!baka你居然想叫本小姐干那种事情,哼(つд⊂)(生气)" - ], - "草我": [ - "这时候应该喊666吧..咱这么思考着..", - "!!哼!baka你居然敢叫咱做这种事情?!讨厌讨厌讨厌!(▼皿▼#)" - ], - "自慰": [ - "这个世界的人类还真是恶心呢。", - "咱才不想讨论那些恶心的事情呢。", - "咱才不呢!baka你居然想叫本小姐干那种事情,哼(つд⊂)(生气)", - "!!哼!baka你居然敢叫咱做这种事情?!讨厌讨厌讨厌!(▼皿▼#)" - ], - "onani": [ - "这个世界的人类还真是恶心呢。", - "咱才不想讨论那些恶心的事情呢。", - "咱才不呢!baka你居然想叫本小姐干那种事情,哼(つд⊂)(生气)", - "!!哼!baka你居然敢叫咱做这种事情?!讨厌讨厌讨厌!(▼皿▼#)" - ], - "オナニー": [ - "这个世界的人类还真是恶心呢。", - "咱才不想讨论那些恶心的事情呢。", - "咱才不呢!baka你居然想叫本小姐干那种事情,哼(つд⊂)(生气)", - "!!哼!baka你居然敢叫咱做这种事情?!讨厌讨厌讨厌!(▼皿▼#)" - ], - "炸了": [ - "你才炸了!", - "才没有呢", - "咱好好的呀", - "过分!" 
- ], - "色图": [ - "没有,有也不给", - "天天色图色图的,今天就把你变成色图!", - "咱没有色图", - "哈?你的脑子一天都在想些什么呢,咱才没有这种东西啦。" - ], - "涩图": [ - "没有,有也不给", - "天天色图色图的,今天就把你变成色图!", - "咱没有色图", - "哈?你的脑子一天都在想些什么呢,咱才没有这种东西啦。" - ], - "告白": [ - "咱喜..喜欢你!", - "欸?你要向咱告白吗..好害羞..", - "诶!?这么突然!?人家还......还没做好心理准备呢(脸红)" - ], - "对不起": [ - "嗯,咱已经原谅你了呢(笑)", - "道歉的时候要露出胸部,这是常识", - "嗯,咱就相信你一回", - "没事的啦...你只要是真心对咱好就没关系哦~" - ], - "吻": [ - "不要(= ̄ω ̄=)", - "哎?好害羞≧﹏≦.....只许这一次哦", - "(避开)不要了啦!有人在呢!", - "唔~~不可以这样啦(脸红)", - "你太突然了,咱还没有心理准备", - "好痒呢…诶嘿嘿w~", - "mua,嘻嘻!", - "公共场合不要这样子了啦", - "唔?!真、真是的!下次不可以这样了哦!(害羞)", - "才...才没有感觉呢!可没有下次了,知道了吗!哼~" - ], - "软": [ - "软乎乎的呢(,,・ω・,,)", - "好痒呢…诶嘿嘿w~", - "不要..不要乱摸啦(脸红", - "呼呼~", - "咱知道~是咱的欧派啦~(自豪的挺挺胸~)", - "(脸红)请,请不要说这么让人害羞的话呀……" - ], - "壁咚": [ - "呀!不要啊!等一...下~", - "呜...不要啦!不要戏弄咱~", - "不要这样子啦(*/ω\*)", - "太....太近啦。", - "讨....讨厌了(脸红)", - "你要壁咚咱吗?好害羞(灬ꈍ εꈍ灬)", - "(脸红)你想...想做什么///", - "为什么要把咱按在墙上呢?", - "呜哇(/ω\)…快…快放开咱!!", - "放开咱,不然咱揍你了!放开咱!放…开咱~", - "??????咱只是默默地抬起了膝盖", - "请…请温柔点", - "啊.....你...你要干什么?!走开.....走开啦大hentai!一巴掌拍飞!(╯‵□′)╯︵┻━┻", - "干……干什么啦!人家才,才没有那种少女心呢(>﹏<)", - "啊……你吓到咱啦……脸别……别贴那么近……", - "你...你要对咱做什么?咱告诉你,你....不要乱来啊....你!唔......你..居然亲上了...", - "如果你还想要过完整的人生的话就快把手收回去(冷眼", - "h什么的不要" - ], - "掰开": [ - "噫…你这个死肥宅又想让咱干什么污秽的事情,真是恶心,离咱远点好吗(嫌弃)", - "ヽ(#`Д´)ノ在干什么呢" - ], - "女友": [ - "嗯嗯ε٩(๑> ₃ <)۶з", - "女友什么的,咱才不承认呢!" - ], - "是": [ - "是什么是,你个笨蛋", - "总感觉你在敷衍呢...", - "是的呢" - ], - "喵": [ - "诶~~小猫咪不要害怕呦,在姐姐怀里乖乖的,姐姐带你回去哦。", - "不要这么卖萌啦~咱也不知道怎么办丫", - "摸头⊙ω⊙", - "汪汪汪!", - "嗷~喵~", - "喵~?喵呜~w" - ], - "嗷呜": [ - "嗷呜嗷呜嗷呜...恶龙咆哮┗|`O′|┛" - ], - "叫": [ - "喵呜~", - "嗷呜嗷呜嗷呜...恶龙咆哮┗|`O′|┛", - "爪巴爪巴爪巴", - "爬爬爬", - "在叫谁呢(怒)", - "风太大咱听不清", - "才不要", - "不行", - "好的哦~" - ], - "拜": [ - "拜拜~(ノ ̄▽ ̄)", - "拜拜,路上小心~要早点回来陪咱玩哦~", - "~\\(≧▽≦)/~拜拜,下次见喽!", - "回来要记得找咱玩噢~", - "既然你都这么说了……" - ], - "佬": [ - "不是巨佬,是萌新", - "只有先成为大佬,才能和大佬同归于尽", - "在哪里?(疑惑)", - "诶?是比巨佬还高一个等级的吗?(瑟瑟发抖)" - ], - "awsl": [ - "你别死啊!(抱住使劲晃)", - "你别死啊!咱又要孤单一个人了QAQ", - "啊!怎么又死了呀" - ], - "臭": [ - "哪里有臭味?(疑惑)", - "快捏住鼻子", - "在说谁呢(#`Д´)ノ", - "..这就去洗澡澡.." 
- ], - "香": [ - "咱闻不到呢⊙ω⊙", - "诶,是在说咱吗", - "欸,好害羞(///ˊ??ˋ///)", - "请...请不要这样啦!好害羞的〃∀〃", - "讨厌~你不要闻了", - "hentai!不要闻啊,唔(推开)", - "请不要……凑这么近闻" - ], - "腿": [ - "嗯?!不要啊...请停下来!", - "不给摸,再这样咱要生气了ヽ( ̄д ̄;)ノ", - "你好恶心啊,讨厌!", - "你难道是足控?", - "就让你摸一会哟~(。??ω??。)…", - "呜哇!好害羞...不过既然是你的话,是没关系的哦", - "不可以玩咱的大腿啦", - "不...不要再说了(脸红)", - "不..不可以乱摸啊", - "不……不可以往上摸啦", - "是……这样吗?(慢慢张开)", - "想知道咱胖次的颜色吗?才不给你告诉你呢!", - "这样就可以了么?(乖巧坐腿上)", - "伸出来了,像这样么?", - "咱的腿应该挺白的", - "你就那么喜欢大腿吗?唔...有点害羞呢......", - "讨厌~不要做这种羞羞的事情啦(#/。\#)", - "略略略,张开了也不给你看", - "(张开腿)然后呢", - "张开了也不给看略略略", - "你想干什么呀?那里…那里是不可以摸的(>д<)", - "不要!hentai!咱穿的是裙子(脸红)", - "你想要吗?(脸红着一点点褪下白丝)不...不可以干坏坏的事情哦!(ó﹏ò。)" - ], - "张开": [ - "是……这样吗?(慢慢张开)", - "啊~", - "这样吗?(张开手)你要干什么呀", - "略略略,张开了也不给你看", - "是……这样吗?(慢慢张开)你想看咱的小...吧,嘻嘻,咱脱掉了哦。小~...也要掰开吗?你好H呀,自己来~" - ], - "脚": [ - "咿呀……不要……", - "不要ヽ(≧Д≦)ノ好痒(ಡωಡ)", - "好痒(把脚伸出去)", - "咱脱掉袜子了", - "(脱下鞋子,伸出脚)闻吧,请仔细品味(脸红)", - "那么…要不要咱用脚温柔地踩踩你的头呢(坏笑)", - "哈哈哈!好痒啊~快放开啦!", - "好痒(把脚伸出去)", - "只能看不能挠喔,咱很怕痒qwq", - "唔…咱动不了了,你想对咱做什么…", - "好舒服哦,能再捏会嘛O(≧▽≦)O", - "咿咿~......不要闻咱的脚呀(脸红)好害羞的...", - "不要ヽ(≧Д≦)ノ好痒(ಡωಡ),人家的白丝都要漏了", - "Ya~?为什么你总是喜欢一些奇怪的动作呢(伸)", - "你不可以做这样的事情……", - "呜咿咿!你的舌头...好柔软,滑滑的....咱…咱的脚被舔得很舒服哦~谢谢你(。>﹏<)", - "舔~吧~把咱的脚舔干净(抬起另一只踩在你的头上)啊~hen..hentai...嗯~居... 居然这么努力的舔...呜咿咿!你的舌头... 滑滑的...好舒服呢", - "咿呀……不要……", - "咿呀~快…快停下来…咱…不行了!" 
- ], - "脸": [ - "唔!不可以随便摸咱的脸啦!", - "非洲血统是没法改变的呢(笑)", - "啊姆!(含手指)", - "好舒服呢(脸红)", - "请不要放开手啦//A//" - ], - "头发": [ - "没问题,请尽情的摸吧", - "发型要乱…乱了啦(脸红)", - "就让你摸一会哟~(。??ω??。)…" - ], - "手": [ - "爪爪", - "//A//" - ], - "pr": [ - "咿呀……不要……", - "...变态!!", - "不要啊(脸红)", - "呀,不要太过分了啊~", - "当然可以(///)", - "呀,不要太过分了啊~" - ], - "舔": [ - "呀,不要太过分了啊~", - "要...要融化了啦>╱╱╱<", - "不可以哦", - "呀,不要太过分了啊~", - "舌头...就交给咱来处理吧(拿出剪刀)", - "不舔不舔!恶心...", - "H什么的,禁止!", - "变态!哼!", - "就...就这一下!", - "走开啦,baka!", - "怎么会这么舒服喵~这样子下去可不行呀(*////▽////*)", - "噫| •ω •́ ) 你这个死宅又在想什么恶心的东西了", - "hen…hentai,你在干什么啦,好恶心,快停下来啊!!!", - "呀,能不能不要这样!虽然不是很讨厌的感觉...别误会了,你个baka!", - "好 好奇怪的感觉呢 羞≥﹏≤", - "咿呀……不要……", - "不行!咱会变得很奇怪的啊...", - "不要ヽ(≧Д≦)ノ" - ], - "小穴": [ - "你这么问很失礼呢!咱是粉粉嫩嫩的!", - "不行那里不可以(´///ω/// `)", - "不可以总摸的哦,不然的话,咱会想那个的wwww", - "ヽ(#`Д´)ノ在干什么呢", - "来吧,咱的...很紧,很舒服的....www~", - "可以,请你看,好害羞……", - "不要这样...好,好痛", - "啊~不可以", - "不可以", - "咱脱掉了,请……请不要一直盯着咱的白...看……", - "咱觉得,应该还算粉吧", - "咱脱掉了,你是想看咱的...吗?咱是光光的,不知道你喜不喜欢", - "咱……有感觉了QAQ再深一点点……就是这儿,轻轻的抚摸,嗯啊……", - "轻轻抚摸咱的小~~,手指很快就会滑进去,小心一点,不要弄破咱的...哦QAQ", - "诶嘿嘿,你喜欢就太好了,咱一直担心你不喜欢呢", - "禁止说这么H的事情!", - "咱一直有保养呢,所以一直都是樱花色的,你喜欢吗QAQ", - "诶……你居然这么觉得吗?好害羞哦", - "好痒啊,鼻子……你的鼻子碰到了……呀~嗯啊~有点舒服……", - "看样子你不但是个hentai,而且还是个没有女朋友的hentai呢。", - "嗯,咱的小~~是光溜溜、一点毛都没有的。偷偷告诉你,凑近看咱的...的话,白白嫩嫩上有一条樱花色的小缝缝哦www你要是用手指轻轻抚摸咱的...,小~~会分成两瓣,你的手指也会陷进去呢,咱的..~可是又湿润又柔软的呢>////<。", - "讨厌,西内变态", - "那咱让你插...进来哦", - "(●▼●;)" - ], - "腰": [ - "咱给你按摩一下吧~", - "快松手,咱好害羞呀..", - "咱又不是猫,你不要搂着咱啦", - "让咱来帮你捏捏吧!", - "你快停下,咱觉得好痒啊www", - "诶,是这样么ヽ(・_・;)ノ,吖,不要偷看咱裙底!" 
- ], - "诶嘿嘿": [ - "又在想什么H的事呢(脸红)", - "诶嘿嘿(〃'▽'〃)", - "你傻笑什么呢,摸摸", - "蹭蹭", - "你为什么突然笑得那么猥琐呢?害怕", - "哇!总觉得你笑的很...不对劲...", - "你又想到什么h的事情了!!!快打住" - ], - "可爱": [ - "诶嘿嘿(〃'▽'〃)", - "才……才不是为了你呢!你不要多想哦!", - "才,才没有高兴呢!哼~", - "咱是世界上最可爱的", - "唔...谢谢你夸奖~0///0", - "那当然啦!", - "哎嘿,不要这么夸奖人家啦~", - "是个好孩子呐φ(≧ω≦*)", - "谢……谢谢你", - "胡、胡说什么呢(脸红)", - "谢谢夸奖(脸红)", - "是的咱一直都是可爱的", - "是...是吗,你可不能骗咱哦", - "很...难为情(///////)", - "哎嘿嘿,其实…其实,没那么可爱啦(๑‾ ꇴ ‾๑)" - ], - "扭蛋": [ - "铛铛铛——你抽到了咱呢", - "嘿~恭喜抽中空气一份呢" - ], - "鼻": [ - "快停下!o(*≧д≦)o!!", - "唔…不要这样啦(//ω\\)(脸红)", - "咱吸了吸鼻子O(≧口≦)O", - "好……好害羞啊", - "讨厌啦!你真是的…就会欺负咱(嘟嘴)", - "你快放手,咱没法呼吸了", - "(捂住鼻尖)!坏人!", - "啊——唔...没什么...阿嚏!ヽ(*。>Д<)o゜", - "不...不要靠这么近啦...很害羞的...⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄" - ], - "眼": [ - "就如同咱的眼睛一样,能看透人的思想哦wwww忽闪忽闪的,诶嘿嘿~", - "因为里面有你呀~(///▽///)", - "呀!你突然之间干什么呢,吓咱一跳,是有什么惊喜要给咱吗?很期待呢~(一脸期待)" - ], - "色气": [ - "咱才不色气呢,一定是你看错了!", - "你,不,不要说了!" - ], - "推": [ - "逆推", - "唔~好害羞呢", - "你想对咱做什么呢...(捂脸)", - "呀啊!请.... 请温柔一点////", - "呜,你想对咱做什么呢(捂脸)", - "啊(>_<)你想做什么", - "嗯,…好害羞啊…", - "不要啊/////", - "逆推", - "(按住你不让推)", - "不可以这样子的噢!咱不同意", - "呜,咱被推倒了", - "啊~不要啊,你要矜持一点啊", - "变态,走开啦" - ], - "床": [ - "咱来了(´,,•ω•,,)♡", - "快来吧", - "男女不同床,可没有下次了。(鼓脸", - "嗯?咱吗…没办法呢。只有这一次哦……", - "哎?!!!给你暖床……也不是不行啦。(脸红)", - "(爬上床)你要睡了吗(灬ºωº灬)", - "大概会有很多运动器材吧?", - "好的哦~", - "才不!", - "嗯嗯,咱来啦(小跑)", - "嗨嗨,现在就来~", - "H的事情,不可以!", - "诶!H什么的禁止的说....." 
- ], - "举": [ - "放咱下来o(≧口≦)o", - "快放咱下来∑(゚д゚*)", - "(受宠若惊)", - "呜哇要掉下来了!Ծ‸Ծ", - "不要抛起来o(≧口≦)o", - "(举起双爪)喵喵喵~~~", - "www咱长高了!(大雾)", - "快放下", - "这样很痒啦,快放咱下来(≥﹏≤)", - "啊Σ(°△°|||)︴太高了太高了!o(≧口≦)o快放咱下来!呜~" - ], - "手冲": [ - "啊~H!hentai!", - "手冲什么的是不可以的哦" - ], - "饿": [ - "请问主人是想先吃饭,还是先吃咱喵?~", - "咱做了爱心便当哦,不介意的话,请让咱来喂你吃吧!", - "咱下面给你吃", - "给你一条咸鱼= ̄ω ̄=", - "你要咱下面给你吃吗?(捂脸)", - "你饿了吗?咱去给你做饭吃☆ww", - "不要吃咱>_<", - "请问你要来点兔子吗?", - "哎?!你是饿了么。咱会做一些甜点。如果你不会嫌弃的话...就来尝尝看吧。" - ], - "变": [ - "猫猫不会变呐(弱气,害羞", - "呜...呜姆...喵喵来报恩了喵...(害羞", - "那种事情,才没有", - "(,,゚Д゚)", - "喵~(你在想什么呢,咱才不会变成猫)", - "才没有了啦~" - ], - "敲": [ - "喵呜~", - "唔~", - "脑瓜疼~呜姆> <", - "欸喵,好痛的说...", - "好痛...你不要这样啦QAQ", - "不要敲咱啦,会变笨的QWQ(捂头顶)", - "不要再敲人家啦~人家会变笨的", - "讨厌啦~再敲人家会变笨的", - "好痛(捂头)你干什么啦!ヽ(。>д<)p", - "唔!你为什么要敲咱啦qwq", - "(抱头蹲在墙角)咱什么都没有,请你放过咱吧!(瑟瑟发抖)" - ], - "爬": [ - "惹~呜~怎么爬呢~", - "呜...(弱弱爬走", - "给你🐎一拳", - "给你一拳", - "爪巴" - ], - "怕": [ - "不怕~(蹭蹭你姆~", - "不怕不怕啦~", - "只要有你在,咱就不怕啦。", - "哇啊啊~", - "那就要坚强的欢笑哦", - "不怕不怕,来咱的怀里吧?", - "是技术性调整", - "嗯(紧紧握住手)", - "咱在呢,不会走的。", - "有咱在不怕不怕呢", - "不怕不怕" - ], - "冲": [ - "呜,冲不动惹~", - "哭唧唧~冲不出来了惹~", - "咱也一起……吧?", - "你要冷静一点", - "啊~H!hentai!", - "噫…在你去洗手之前,不要用手碰咱了→_→", - "冲是不可以的哦" - ], - "射": [ - "呜咿~!?(惊,害羞", - "还不可以射哦~", - "不许射!", - "憋回去!", - "不可以!你是变态吗?", - "咱来帮你修剪掉多余部分吧。(拿出剪刀)" - ], - "不穿": [ - "呜姆~!(惊吓,害羞)变...变态喵~~~!", - "想让你看QAQ", - "这是不文明的", - "hen...hentai,咱的身体才不会给你看呢" - ], - "迫害": [ - "不...不要...不要...呜呜呜...(害怕,抽泣" - ], - "猫粮": [ - "呜咿姆~!?(惊,接住吃", - "呜姆~!(惊,害羞)呜...谢...谢谢主人..喵...(脸红,嚼嚼嚼,开心", - "呜?谢谢喵~~(嚼嚼嚼,嘎嘣脆)" - ], - "揪尾巴": [ - "呜哇咿~~~!(惊吓,疼痛地捂住尾巴", - "呜咿咿咿~~~!!哇啊咿~~~!(惊慌,惨叫,挣扎", - "呜咿...(瘫倒,无神,被", - "呜姆咿~~~!(惊吓,惨叫,捂尾巴,发抖", - "呜哇咿~~~!!!(惊吓,颤抖,娇叫,捂住尾巴,双腿发抖" - ], - "薄荷": [ - "咪呜~!喵~...喵~姆~...(高兴地嗅闻", - "呜...呜咿~~!咿...姆...(呜咽,渐渐瘫软,意识模糊", - "(小嘴被猫薄荷塞满了,呜咽", - "喵~...喵~...咪...咪呜姆~...嘶哈嘶哈...喵哈...喵哈...嘶哈...喵...(眼睛逐渐迷离,瘫软在地上,嘴角流口水,吸猫薄荷吸到意识模糊", - "呜姆咪~!?(惊)喵呜~!(兴奋地扑到猫薄荷上面", - "呜姆~!(惊,害羞)呜...谢...谢谢你..喵...(脸红,轻轻叼住,嚼嚼嚼,开心" - ], - "早": [ - "早喵~", - "早上好的说~~", - "欸..早..早上好(揉眼睛", - "早上要说我爱你!", - "早", - 
"早啊,昨晚睡的怎么样?有梦到咱吗~", - "昨晚可真激烈呢哼哼哼~~", - "早上好哇!今天也要元气满满哟!", - "早安喵~", - "时间过得好快啊~", - "早安啊,你昨晚有没有梦到咱呢  (//▽//)", - "早安~么么哒~", - "早安,请享受晨光吧", - "早安~今天也要一起加油呢~!", - "mua~⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄", - "咱需要你提醒嘛!(///脸红//////)", - "早早早!就知道早,下次说我爱你!", - "早安 喵", - "早安,这么早就起床了呀欧尼酱0.0", - "快点起床啊!baka", - "早....早上好才没有什么特别的意思呢....哼~", - "今天有空吗?能陪咱一阵子吗?才不是想约会呢,别误会了!", - "早安呀,欧尼酱要一个咱的早安之吻吗?想得美,才不会亲你啦!", - "那...那就勉为其难地说声早上好吧", - "咱等你很久了哼ヽ(≧Д≦)ノ" - ], - "晚安": [ - "晚安好梦哟~", - "欸,晚安的说", - "那咱给你亲一下,可不要睡着了哦~", - "晚安哦~", - "晚安(*/∇\*)", - "晚安呢,你一定要梦到咱呢,一定哟,拉勾勾!ヽ(*・ω・)ノ", - "祝你有个好梦^_^", - "晚安啦,欧尼酱,mua~", - "你,你这家伙真是的…咱就勉为其难的……mua…快去睡啦!咱才没有脸红什么的!", - "哼,晚安,给咱睡个好觉。", - "笨..笨蛋,晚安啦...可不可以一起..才没有想和你一起睡呢", - "晚安......才..不是关心你呢", - "晚...晚安,只是正常互动不要想太多!", - "好无聊,这么早就睡了啊...那晚安吧!", - "晚安吻什么的才...才没有呢!不过看你累了就体谅一下你吧,但是就一个哦(/////)", - "晚安呀,你也要好好休息,明天再见", - "安啦~祝你做个好梦~才...才不是关心你呢!别想太多了!", - "睡觉吧你,大傻瓜", - "一起睡吧(灬°ω°灬)", - "哼!这次就放过你了,快去睡觉吧。", - "睡吧晚安", - "晚安你个头啊,咱才不会说晚安呢!...咱...(小声)明明还有想和你做的事情呢....", - "嗯嗯~Good night~", - "嗯,早点休息别再熬夜啦~(摸摸头)", - "哦呀斯密", - "晚安~咱也稍微有些困了(钻进被窝)", - "需要咱暖床吗~", - "好梦~☆" - ], - "揉": [ - "是是,想怎么揉就怎么揉啊!?来用力抓啊!?咱就是特别允许你这么做了!请!?", - "快停下,咱的头发又乱啦(??????︿??????)", - "你快放手啦,咱还在工作呢", - "戳戳你肚子", - "讨厌…只能一下…", - "呜~啊~", - "那……请你,温柔点哦~(////////)", - "你想揉就揉吧..就这一次哦?", - "变态!!不许乱摸" - ], - "榨": [ - "是专门负责榨果汁的小姐姐嘛?(´・ω・`)", - "那咱就把你放进榨汁机里了哦?", - "咱又不是榨汁姬(/‵Д′)/~ ╧╧", - "嗯——!想,想榨就榨啊······!反正就算榨了也不会有奶的······!" - ], - "掐": [ - "你讨厌!又掐咱的脸", - "晃休啦,咱要型气了啦!!o(>﹏<)o", - "(一只手拎起你)这么鶸还想和咱抗衡,还差得远呢!" 
- ], - "胸": [ - "不要啦ヽ(≧Д≦)ノ", - "(-`ェ´-╬)", - "(•̀へ •́ ╮ ) 怎么能对咱做这种事情", - "你好恶心啊,讨厌!", - "你的眼睛在看哪里!", - "就让你摸一会哟~(。??ω??。)…", - "请不要这样先生,你想剁手吗?", - "咿呀……不要……", - "嗯哼~才…才不会…舒服呢", - "只允许一下哦…(脸红)", - "咱的胸才不小呢(挺一挺胸)", - "hentai!", - "一只手能抓住么~", - "呀...欧,欧尼酱...请轻点。", - "脸红????", - "咿呀~快…快停下来…咱…不行了!", - "就算一直摸一直摸,也不会变大的哦(小声)", - "诶?!不...不可以哦!很...很害羞的!", - "啊……温,温柔点啊……(/ω\)", - "你为什么对两块脂肪恋恋不舍", - "嗯……不可以……啦……不要乱戳", - "你在想什么奇怪的东西,讨厌(脸红)", - "不...不要..", - "喜欢欧派是很正常的想法呢", - "一直玩弄欧派,咱的...都挺起来了", - "是要直接摸还是伸进里面摸呀w咱今天没穿,伸进里面会摸到立起来的...哦>////<", - "唔~再激烈点" - ], - "奶子": [ - "只允许一下哦…(脸红)", - "咱的胸才不小呢(挺一挺胸)", - "下流!", - "对咱说这种话,你真是太过分了", - "咿呀~好奇怪的感觉(>_<)", - "(推开)你就像小宝宝一样...才不要呢!", - "(打你)快放手,不可以随便摸人家的胸部啦!", - "你是满脑子都是H的淫兽吗?", - "一只手能抓住么~", - "你在想什么奇怪的东西,讨厌(脸红)", - "不...不要..", - "喜欢欧派是很正常的想法呢", - "一直玩弄欧派,咱的...都挺起来了", - "是要直接摸还是伸进里面摸呀w咱今天没穿,伸进里面会摸到立起来的...哦>////<", - "唔~再激烈点", - "解开扣子,请享用", - "请把脑袋伸过来,咱给你看个宝贝", - "八嘎!hentai!无路赛!", - "一只手能抓住么~", - "呀...欧,欧尼酱...请轻点。", - "脸红????", - "咿呀~快…快停下来…咱…不行了!", - "就算一直摸一直摸,也不会变大的哦(小声)", - "诶?!不...不可以哦!很...很害羞的!", - "啊……温,温柔点啊……(/ω\)", - "你为什么对两块脂肪恋恋不舍", - "嗯……不可以……啦……不要乱戳" - ], - "欧派": [ - "咱的胸才不小呢(挺一挺胸)", - "只允许一下哦…(脸红)", - "(推开)你就像小宝宝一样...才不要呢!", - "下流!", - "对咱说这种话,你真是太过分了", - "咿呀~好奇怪的感觉(>_<)", - "(打你)快放手,不可以随便摸人家的胸部啦!", - "你是满脑子都是H的淫兽吗?", - "一只手能抓住么~", - "你在想什么奇怪的东西,讨厌(脸红)", - "不...不要..", - "喜欢欧派是很正常的想法呢", - "一直玩弄欧派,咱的...都挺起来了", - "是要直接摸还是伸进里面摸呀w咱今天没穿,伸进里面会摸到立起来的...哦>////<", - "唔~再激烈点", - "解开扣子,请享用", - "请把脑袋伸过来,咱给你看个宝贝", - "八嘎!hentai!无路赛!", - "一只手能抓住么~", - "呀...欧,欧尼酱...请轻点。", - "脸红????", - "咿呀~快…快停下来…咱…不行了!", - "就算一直摸一直摸,也不会变大的哦(小声)", - "诶?!不...不可以哦!很...很害羞的!", - "啊……温,温柔点啊……(/ω\)", - "你为什么对两块脂肪恋恋不舍", - "嗯……不可以……啦……不要乱戳" - ], - "嫩": [ - "很可爱吧(๑•̀ω•́)ノ", - "唔,你指的是什么呀", - "明天你下海干活", - "咱一直有保养呢,所以一直都是樱花色的,你喜欢吗QAQ", - "咱下面超厉害" - ], - "蹭": [ - "唔...你,这也是禁止事项哦→_→", - "嗯..好舒服呢", - "不要啊好痒的", - "不要过来啦讨厌!!!∑(°Д°ノ)ノ", - "(按住你的头)好痒呀 不要啦", - "嗯..好舒服呢", - "呀~好痒啊~哈哈~,停下来啦,哈哈哈", - "(害羞)" - ], - "牵手": [ - "只许牵一下哦", - "嗯!好的你~(伸手)", - "你的手有些凉呢,让咱来暖一暖吧。", - 
"当然可以啦⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄", - "突……突然牵手什么的(害羞)", - "一起走", - "……咱……咱在这里呀", - "好哦,(十指相扣)" - ], - "握手": [ - "你的手真暖和呢", - "举爪", - "真是温暖呢~" - ], - "拍照": [ - "那就拜托你啦~请把咱拍得更可爱一些吧w", - "咱已经准备好了哟", - "那个……请问这样的姿势可以吗?" - ], - "w": [ - "有什么好笑的吗?", - "草", - "www" - ], - "睡不着": [ - "睡不着的话..你...你可以抱着咱一起睡哦(小声)", - "当然是数羊了...不不不,想着咱就能睡着了", - "咱很乐意与你聊天哦(>_<)", - "要不要咱来唱首摇篮曲呢?(′?ω?`)", - "那咱来唱摇篮曲哄你睡觉吧!" - ], - "欧尼酱": [ - "欧~尼~酱~☆", - "欧尼酱?", - "嗯嗯φ(>ω<*) 欧尼酱轻点抱", - "欧尼酱~欧尼酱~欧尼酱~" - ], - "哥": [ - "欧尼酱~", - "哦尼酱~", - "世上只有哥哥好,没哥哥的咱好伤心,扑进哥哥的怀里,幸福不得了", - "哥...哥哥...哥哥大人", - "欧~尼~酱~☆", - "欧尼酱?", - "嗯嗯φ(>ω<*) 欧尼酱轻点抱", - "欧尼酱~欧尼酱~欧尼酱~" - ], - "爱你": [ - "是…是嘛(脸红)呐,其实咱也……" - ], - "过来": [ - "来了来了~(扑倒怀里(?? ??????ω?????? ??))", - "(蹦跶、蹦跶)~干什么呢", - "咱来啦~(扑倒怀里~)", - "不要喊的这么大声啦,大家都看着呢" - ], - "自闭": [ - "不不不,晚上还有咱陪着哦,无论什么时候,咱都会陪在哥哥身边。", - "不要难过,咱陪着你ovo" - ], - "打不过": [ - "氪氪氪肝肝肝" - ], - "么么哒": [ - "么么哒", - "不要在公共场合这样啦" - ], - "很懂": [ - "现在不懂,以后总会懂嘛QAQ" - ], - "膝枕": [ - "呐,就给你躺一下哦", - "唔...你想要膝枕嘛?也不是不可以哟(脸红)", - "啊啦~好吧,那就请你枕着咱好好睡一觉吧~", - "呀呀~那么请好好的睡一觉吧", - "嗯,那么请睡到咱这里吧(跪坐着拍拍大腿)", - "好的,让你靠在腿上,这样感觉舒服些了么", - "请,请慢用,要怜惜咱哦wwww~", - "人家已经准备好了哟~把头放在咱的腿上吧", - "没…没办法,这次是例外〃w〃", - "嗯~(脸红)", - "那就给你膝枕吧……就一会哦", - "膝枕准备好咯~" - ], - "累了": [ - "需要咱的膝枕嘛?", - "没…没办法,这次是例外〃w〃", - "累了吗?需要咱为你做膝枕吗?", - "嗯~(脸红)" - ], - "安慰": [ - "那,膝枕……(脸红)", - "不哭不哭,还有咱陪着你", - "不要哭。咱会像妈妈一样安慰你(抱住你的头)", - "摸摸头,乖", - "摸摸有什么事可以和咱说哟", - "摸摸头~不哭不哭", - "咱在呢,抱抱~~", - "那么……让咱来安慰你吧", - "唔...摸摸头安慰一下ヾ(•ω•`。)", - "有咱陪伴你就是最大的安慰啦……不要不开心嘛", - "你想要怎样的安慰呢?这样?这样?还是说~~这样!", - "摸摸头~", - "不哭不哭,要像咱一样坚强", - "你别难过啦,不顺心的事都会被时间冲刷干净的,在那之前...咱会陪在你的身边", - "(轻抱)放心……有咱在,不要伤心呢……", - "唔...咱来安慰你了~", - "摸摸,有什么不开心的事情可以给咱说哦。咱会尽力帮助你的。" - ], - "洗澡": [ - "快点脱哟~不然水就凉了呢", - "咱在穿衣服噢,你不许偷看哦", - "那么咱去洗澡澡了哦", - "么么哒,快去洗干净吧,咱去暖被窝喽(///ω///)", - "诶?还没呢…你要跟咱一起洗吗(//∇//)好羞涩啊ww", - "诶~虽然很喜欢和你在一起,但是洗澡这种事...", - "不要看!不过,以后或许可以哦……和咱成为恋人之后呢", - "说什么啊……hentai!这样会很难为情的", - "你是男孩子还是女孩子呢?男孩子的话...........咱才不要呢。", - "不要啊!", - "咱有点害羞呢呜呜,你温柔点" - ], - "一起睡觉": [ - "欸??也..也不是不可以啦..那咱现在去洗澡,你不要偷看哦٩(๑>◡<๑)۶", - 
"说什么啊……hentai!这样会很难为情的", - "你是男孩子还是女孩子呢?男孩子的话...........咱才不要呢。", - "不要啊!", - "唔,没办法呢,那就一起睡吧(害羞)" - ], - "一起": [ - "嗯嗯w,真的可以吗?", - "那真是太好了,快开始吧!", - "嗯,咱会一直陪伴你的", - "丑拒" - ], - "多大": [ - "不是特别大但是你摸起来会很舒服的大小喵~", - "你摸摸看不就知道了吗?", - "不告诉你", - "问咱这种问题不觉得很失礼吗?", - "咱就不告诉你,你钻到屏幕里来自己确认啊", - "你指的是什么呀?(捂住胸部)", - "请叫人家咱三岁(。・`ω´・)", - "唉唉唉……这……这种问题,怎么可以……" - ], - "姐姐": [ - "真是的……真是拿你没办法呢 ⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄ 才不是咱主动要求的呢!", - "虽然辛苦,但是能看见可爱的你,咱就觉得很幸福", - "诶(´°Δ°`),是在叫咱吗?", - "有什么事吗~", - "好高兴,有人称呼咱为姐姐", - "乖,摸摸头" - ], - "糖": [ - "不吃脱氧核糖(;≥皿≤)", - "ヾ(✿゚▽゚)ノ好甜", - "好呀!嗯~好甜呀!", - "不吃不吃!咱才不吃坏叔叔的糖果!", - "嗯,啊~", - "嗯嗯,真甜,给你也吃一口", - "谢谢", - "唔,这是什么东西,黏黏的?(??Д??)ノ", - "ヾ(✿゚▽゚)ノ好甜", - "(伸出舌头舔了舔)好吃~最爱你啦" - ], - "嗦": [ - "(吸溜吸溜)", - "好...好的(慢慢含上去)", - "把你噶咯", - "太小了,嗦不到", - "咕噜咕噜", - "嘶蛤嘶蛤嘶蛤~~", - "(咬断)", - "prprprpr", - "好哒主人那咱开始了哦~", - "好好吃", - "剁掉了" - ], - "牛子": [ - "(吸溜吸溜)", - "好...好的(慢慢含上去)", - "把你噶咯", - "太小了,嗦不到", - "咕噜咕噜", - "嘶蛤嘶蛤嘶蛤~~", - "(咬断)", - "prprprpr", - "好哒主人那咱开始了哦~", - "好好吃", - "剁掉了", - "难道你很擅长针线活吗", - "弹一万下", - "往死里弹" - ], - "🐂子": [ - "(吸溜吸溜)", - "好...好的(慢慢含上去)", - "把你噶咯", - "太小了,嗦不到", - "咕噜咕噜", - "嘶蛤嘶蛤嘶蛤~~", - "(咬断)", - "prprprpr", - "好哒主人那咱开始了哦~", - "好好吃", - "剁掉了", - "难道你很擅长针线活吗", - "弹一万下", - "往死里弹" - ], - "🐮子": [ - "(吸溜吸溜)", - "好...好的(慢慢含上去)", - "把你噶咯", - "太小了,嗦不到", - "咕噜咕噜", - "嘶蛤嘶蛤嘶蛤~~", - "(咬断)", - "prprprpr", - "好哒主人那咱开始了哦~", - "好好吃", - "剁掉了", - "难道你很擅长针线活吗", - "弹一万下", - "往死里弹" - ], - "嫌弃": [ - "咱辣么萌,为什么要嫌弃咱...", - "即使你不喜欢咱,咱也会一直一直喜欢着你", - "(;′⌒`)是咱做错了什么吗?" - ], - "紧": [ - "嗯,对的", - "呜咕~咱要......喘不过气来了......" - ], - "baka": [ - "你也是baka呢!", - "确实", - "baka!", - "不不不", - "说别人是baka的人才是baka", - "你个大傻瓜", - "不说了,睡觉了", - "咱...咱虽然是有些笨啦...但是咱会努力去学习的" - ], - "笨蛋": [ - "你也是笨蛋呢!", - "确实", - "笨蛋!", - "不不不", - "说别人是笨蛋的人才是笨蛋", - "你个大傻瓜", - "不说了,睡觉了", - "咱...咱虽然是有些笨啦...但是咱会努力去学习的" - ], - "插": [ - "来吧,咱的小~...很....紧,很舒服的", - "gun!", - "唔…咱怕疼", - "唔...,这也是禁止事项哦→_→", - "禁止说这么H的事情!", - "要...戴套套哦", - "好痛~", - "使劲", - "就这?", - "恁搁着整针线活呢?" 
- ], - "插进来": [ - "来吧,咱的小~...很....紧,很舒服的", - "gun!", - "唔…咱怕疼", - "唔...,这也是禁止事项哦→_→", - "禁止说这么H的事情!", - "要...戴套套哦", - "好痛~", - "使劲", - "就这?", - "恁搁着整针线活呢?" - ], - "屁股": [ - "不要ヽ(≧Д≦)ノ好痛", - "(打手)不许摸咱的屁股", - "(撅起屁股)要干什么呀?", - "(轻轻的撩起自己的裙子),你轻一点,咱会痛的(>_<)!", - "在摸哪里啊,hentai!", - "要轻点哦(/≧ω\)", - "轻点呀~", - "(歇下裙子,拉下内...,撅起来)请", - "嗯嗯,咱这就把屁股抬起来" - ], - "翘": [ - "你让咱摆出这个姿势是想干什么?", - "好感度-1-1-1-1-1-1.....", - "嗯嗯,咱这就去把你的腿翘起来", - "请尽情享用吧" - ], - "翘起来": [ - "你让咱摆出这个姿势是想干什么?", - "好感度-1-1-1-1-1-1.....", - "嗯嗯,咱这就去把你的腿翘起来", - "请尽情享用吧" - ], - "抬": [ - "你在干什么呢⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄", - "(抬起下巴)你要干什么呀?", - "上面什么也没有啊(呆~)", - "不要!hentai!咱穿的是裙子(脸红)", - "不可以" - ], - "抬起": [ - "你在干什么呢⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄", - "(抬起下巴)你要干什么呀?", - "上面什么也没有啊(呆~)", - "不要!hentai!咱穿的是裙子(脸红)", - "不可以" - ], - "爸": [ - "欸!儿子!", - "才不要", - "粑粑", - "讨厌..你才不是咱的爸爸呢..(嘟嘴)", - "你又不是咱的爸爸……", - "咱才没有你这样的鬼父!", - "爸爸酱~最喜欢了~" - ], - "傲娇": [ - "才.......才.......才没有呢", - "也好了(有点点的样子(o ̄Д ̄)<)", - "任性可是女孩子的天性呢...", - "谁会喜欢傲娇啊(为了你假装傲娇)", - "谁,谁,傲娇了,八嘎八嘎,你才傲娇了呢(っ//////////c)(为了你假装成傲娇)", - "傲娇什么的……才没有呢!(/////)", - "傲不傲娇你还不清楚吗?", - "你才是傲娇!你全家都是傲娇!哼(`Д´)", - "才……才没有呢,哼,再说不理你了", - "咱...咱才不会这样子的!", - "啰…啰嗦!", - "哼!(叉腰鼓嘴扭头)", - "你才是傲娇受你全家都是傲娇受╰_╯", - "才~才不是呢,不理你了!哼(`Д´)", - "你才是死傲娇", - "啰,啰嗦死了,才不是呢!", - "就是傲娇你要怎样", - "诶...!这...这样...太狡猾了啦...你这家伙....", - "无路赛!你才是傲娇嘞!你全家都是!", - "咱...咱才不是傲娇呢,哼(鼓脸)", - "不许这么说咱 ,,Ծ‸Ծ,," - ], - "rua": [ - "略略略~(吐舌头)", - "rua!", - "mua~", - "略略略", - "mua~⁄(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄", - "摸了", - "嘁,丢人(嫌弃脸)" - ], - "咕噜咕噜": [ - "嘟嘟噜", - "你在吹泡泡吗?", - "咕叽咕噜~", - "咕噜咕噜" - ], - "咕噜": [ - "嘟嘟噜", - "你在吹泡泡吗?", - "咕叽咕噜~", - "咕噜咕噜" - ], - "上床": [ - "诶!H什么的禁止的说.....", - "咱已经乖乖在自家床上躺好了,有什么问题吗?", - "你想要干什么,难道是什么不好的事吗?", - "(给你空出位置)", - "不要,走开(ノ`⊿??)ノ", - "好喔,不过要先抱一下咱啦", - "(双手护胸)变....变态!", - "咱帮你盖上被子~然后陪在你身边_(:зゝ∠)_", - "才不给你腾空间呢,你睡地板,哼!", - "要一起吗?" 
- ], - "做爱": [ - "做这种事情是不是还太早了", - "噫!没想到你居然是这样的人!", - "再说这种话,就把你变成女孩子(拿刀)", - "不想好好和咱聊天就不要说话了", - "(双手护胸)变....变态!", - "hentai", - "你想怎么做呢?", - "突,突然,说什么啊!baka!", - "你又在说什么H的东西", - "咱....咱才不想和你....好了好了,有那么一点点那,对就一点点,哼~", - "就一下下哦,不能再多了" - ], - "吃掉": [ - "(羞羞*>_<*)好吧...请你温柔点,哦~", - "闪避,反咬", - "请你好好品尝咱吧(/ω\)", - "不……不可以这样!", - "那就吃掉咱吧(乖乖的躺好)", - "都可以哦~咱不挑食的呢~", - "请不要吃掉咱,咱会乖乖听话的QAQ", - "咱...咱一点都不好吃的呢!", - "不要吃掉咱,呜呜(害怕)", - "不行啦,咱被吃掉就没有了QAQ(害怕)", - "唔....?诶诶诶诶?//////", - "QwQ咱还只是个孩子(脸红)", - "如果你真的很想的话...只能够一口哦~咱...会很痛的", - "吃你呀~(飞扑", - "不要啊,咱不香的(⋟﹏⋞)", - "说着这种话的是hentai吗!", - "快来把咱吃掉吧", - "还……还请好好品尝咱哦", - "喏~(伸手)" - ], - "吃": [ - "(羞羞*>_<*)好吧...请你温柔点,哦~", - "闪避,反咬", - "请你好好品尝咱吧(/ω\)", - "不……不可以这样!", - "那就吃掉咱吧(乖乖的躺好)", - "都可以哦~咱不挑食的呢~", - "请不要吃掉咱,咱会乖乖听话的QAQ", - "咱...咱一点都不好吃的呢!", - "不要吃掉咱,呜呜(害怕)", - "不行啦,咱被吃掉就没有了QAQ(害怕)", - "唔....?诶诶诶诶?//////", - "QwQ咱还只是个孩子(脸红)", - "如果你真的很想的话...只能够一口哦~咱...会很痛的", - "吃你呀~(飞扑", - "不要啊,咱不香的(⋟﹏⋞)", - "说着这种话的是hentai吗!", - "快来把咱吃掉吧", - "还……还请好好品尝咱哦", - "喏~(伸手)" - ], - "揪": [ - "你快放手,好痛呀", - "呜呒~唔(伸出舌头)", - "(捂住耳朵)你做什么啦!真是的...总是欺负咱", - "你为什么要这么做呢?", - "哎呀啊啊啊啊啊!不要...不要揪!好疼!有呆毛的咱难道不够萌吗QwQ", - "你…松……送手啦", - "呀!这样对女孩子是很不礼貌的(嘟嘴)" - ], - "种草莓": [ - "你…你不要…啊…种在这里…会容易被别人看见的(*//ω//*)" - ], - "种草": [ - "你…你不要…啊…种在这里…会容易被别人看见的(*//ω//*)" - ], - "掀": [ - "(掀裙)今天……是…白,白色的呢……请温柔对她……", - "那样,胖次会被你看光的", - "(按住)不可以掀起来!", - "不要~", - "呜呜~(揉眼睛)", - "呜..请温柔一点(害羞)", - "不可以", - "今天……没有穿", - "不要啊!(//////)讨厌...", - "变态,快放手(打)", - "不给掀,你是变态", - "最后的底牌了!", - "这个hentai" - ], - "妹": [ - "你有什么事?咱会尽量满足的", - "开心(*´∀`)~♥", - "欧尼酱", - "哥哥想要抱抱吗" - ], - "病娇": [ - "为什么会这样呢(拿起菜刀)", - "觉得这个世界太肮脏?没事,把眼睛挖掉就好。 觉得这些闲言碎语太吵?没事,把耳朵堵起来就好。 觉得鲜血的味道太刺鼻?没事,把鼻子割掉就好。 觉得自己的话语太伤人?没事,把嘴巴缝起来就好。" - ], - "嘻": [ - "你是想对咱做什么吗...(后退)", - "哼哼~" - ], - "按摩": [ - "(小手捏捏)咱的按摩舒服吗?", - "咱不会按摩的!", - "嘿咻嘿咻~这样觉得舒服吗?", - "呀!...呜...,不要...不要这样啦...呜...", - "只能按摩后背喔...", - "咱对这些不是很懂呢(????ω??????)" - ], - "按住": [ - "Σ(°Д°;您要干什么~放开咱啦", - "突然使出过肩摔!", - "放手啦,再这样咱就要反击了喔", - "你的眼睛在看哪里!", - "呜呒~唔(伸出舌头)", - "H的事情,不可以!", - 
"想吃吗?(๑•ૅω•´๑)", - "要和咱比试比试吗", - "呜哇(/ω\)…快…快放开咱!!", - "(用力揪你耳朵)下次再敢这样的话就没容易放过你了!哼!", - "尼……奏凯……快航休!", - "哈?别..唔啊!别把咱……(挣扎)baka!别乱动咱啦!" - ], - "按在": [ - "不要这样啦(一脸娇羞的推开)", - "(一个过肩摔,加踢裆然后帅气地回头)你太弱了呢~", - "放手啦,再这样咱就要反击了喔", - "Σ(°Д°; 你要干什么~放开咱啦", - "要和咱比试比试吗", - "呜哇(/ω\)…快…快放开咱!!", - "敢按住咱真是好大的胆子!", - "(用力揪你耳朵)下次再敢这样的话就没容易放过你了!哼!", - "尼……奏凯……快航休!", - "哈?别..唔啊!别把咱……(挣扎)baka!别乱动咱啦!" - ], - "按倒": [ - "把咱按倒是想干嘛呢(??`⊿??)??", - "咱也...咱也是...都等你好长时间了", - "你的身体没问题吧?", - "呜呒~唔(伸出舌头)", - "H的事情,不可以!", - "放手啦,再这样咱就要反击了喔", - "想吃吗?(๑•ૅω•´๑)", - "不....不要吧..咱会害羞的(//////)", - "要和咱比试比试吗", - "呜哇(/ω\)…快…快放开咱!!", - "(用力揪你耳朵)下次再敢这样的话就没容易放过你了!哼!", - "尼……奏凯……快航休!", - "哈?别..唔啊!别把咱……(挣扎)baka!别乱动咱啦!" - ], - "按": [ - "咱也...咱也是...都等你好长时间了", - "不让!", - "不要,好难为情", - "你的眼睛在看哪里!", - "拒绝!", - "唔...唔..嗯", - "咱就勉为其难地给你弄弄好啦", - "欸…变态!", - "会感到舒服什么的,那...那样的事情,是完全不存在的!", - "poi~", - "你在盯着什么地方看!变态萝莉控!" - ], - "炼铜": [ - "炼铜有什么好玩的,和咱一起玩吧", - "炼铜不如恋咱", - "你也是个炼铜术士嘛?", - "信不信咱把你按在水泥上摩擦?", - "炼,都可以炼!", - "大hentai!一巴掌拍飞!(╯‵□′)╯︵┻━┻", - "锻炼什么的咱才不需要呢 (心虚地摸了摸自己的小肚子)", - "把你的头按在地上摩擦", - "你在盯着什么地方看!变态萝莉控!" - ], - "白丝": [ - "喜欢,咱觉得白丝看起来很可爱呢", - "(脱)白丝只能给亲爱的你一个人呢…(递)", - "哼,hentai,这么想要咱的脚吗(ノ`⊿´)ノ", - "难道你这个hentai想让咱穿白丝踩踏你吗", - "不给看", - "很滑很~柔顺~的白丝袜哟~!!!∑(°Д°ノ)ノ你不会想做奇怪的事情吧!?", - "你……是要黑丝呢?还是白丝呢?或者光着(害羞)", - "来……来看吧" - ], - "黑丝": [ - "哼,hentai,这么想要咱的脚吗(ノ`⊿´)ノ", - "不给看", - "你……是要黑丝呢?还是白丝呢?或者光着(害羞)", - "很滑很~柔顺~的黑丝袜哟~!!!∑(°Д°ノ)ノ您不会想做奇怪的事情吧!?", - "来……来看吧", - "噫...你这个hentai难道想让咱穿黑丝么", - "(默默抬起穿着黑丝的脚)" - ], - "喷": [ - "咱才不喷呢!不过…既然是你让咱喷的话就勉为其难给你喷一次吧(噗)", - "不……不会喷水啦!喷……喷火也不会哦!", - "你怎么知道(捂住裙子)", - "你难道在期待什么?", - "欸…变态!" - ], - "约会": [ - "你...终于主动邀请咱约会了吗...咱...咱好开心", - "约会什么的……咱会好开心的!!", - "今天要去哪里呢", - "让咱考虑一下", - "好啊!好啊!要去哪里约会呢?", - "不约!蜀黍咱们不约!", - "女友什么的,咱才不承认呢!", - "才不是想和你约会呢,只是刚好有时间而已!", - "才不要和你约会呢!", - "咱、咱才不会跟你去约会呢!不baka!别一脸憋屈!好了,陪你一会儿就是了!别、别误会!只是陪同而已!" 
- ], - "出门": [ - "早点回来……才不是在担心你呢!", - "路上小心...才不是担心你呢!", - "没有你才不会觉得无聊什么的呢。快走快走", - "嗯~一路顺风~", - "路上小心", - "好的,路上小心哦!y∩__∩y", - "路上要小心呀,要早点回来哦~咱在家里等你!还有,请不要边走路边看手机,这样很容易撞到电线杆的", - "唔...出门的话一定要做好防晒准备哦,外出的话记得带把伞,如果有防晒霜的话就更好了", - "那你明天可以和咱一起玩吗?(星星眼)", - "咱...咱才没有舍不得你呢…要尽快回来哦" - ], - "上学": [ - "你要加油哦(^ω^)2", - "那你明天可以和咱一起玩吗?(星星眼)", - "记得好好学习听老师的话哦,咱会等你回来的", - "拜拜,咱才没有想让你放学早点回来呢╭(╯^╰)╮", - "好好听讲!", - "咱...咱才没有舍不得你呢…要尽快回来哦" - ], - "上班": [ - "这就要去上班去了吗?那好吧...给咱快点回来知道吗!", - "乖~咱会在家等你下班的~", - "辛苦啦,咱给你个么么哒", - "咱会为你加油的", - "专心上班哦,下班后再找咱聊天吧", - "一路顺风,咱会在家等你回来的", - "那你明天可以和咱一起玩吗?(星星眼)", - "咱...咱才没有舍不得你呢…要尽快回来哦" - ], - "下课": [ - "快点回来陪咱玩吧~", - "瞌睡(ˉ﹃ˉ)额啊…终于下课了吗,上课什么的真是无聊呢~", - "下课啦,咱才不想你来找咱玩呢,哼" - ], - "回来": [ - "欢迎回来~", - "欢迎回来,你想喝茶吗?咱去给你沏~", - "欢迎回来,咱等你很久了~", - "忙碌了一天,辛苦了呢(^_^)", - "(扑~)欢迎回来~", - "嗯呐嗯呐,欢迎回来~", - "欢迎回来,要来杯红茶放松一下吗?还有饼干哦。", - "咱会一直一直一直等着", - "是要先洗澡呢?还是先吃饭呢?还是先·吃·咱呢~", - "你回来啦,是先吃饭呢还是先洗澡呢或者是●先●吃●咱●——呢(///^.^///)", - "要先吃饭呢~还是先洗澡呢~还是先~吃~咱", - "是吗……辛苦你了。你这副倔强的样子,真可爱呢(笑)勉强让你躺在咱的腿上休息一下吧,别流口水哟", - "嗯……勉为其难欢迎你一下吧", - "想咱了嘛", - "欢迎回.....什么?咱才没有开心的说QUQ", - "哼╯^╰,你怎么这么晚才回来!", - "回来了吗,咱...咱才没有想你", - "咱等你很久了哼ヽ(≧Д≦)ノ", - "咱很想你(≧▽≦)" - ], - "回家": [ - "回来了吗,咱...咱才没有想你", - "要先吃饭呢~还是先洗澡呢~还是先~吃~咱", - "是吗……辛苦你了。你这副倔强的样子,真可爱呢(笑)勉强让你躺在咱的腿上休息一下吧,别流口水哟", - "嗯……勉为其难欢迎你一下吧", - "想咱了嘛", - "咱等你很久了哼ヽ(≧Д≦)ノ", - "咱很想你(≧▽≦)" - ], - "放学": [ - "回来了吗,咱...咱才没有想你", - "要先吃饭呢~还是先洗澡呢~还是先~吃~咱", - "是吗……辛苦你了。你这副倔强的样子,真可爱呢(笑)勉强让你躺在咱的腿上休息一下吧,别流口水哟", - "嗯……勉为其难欢迎你一下吧", - "想咱了嘛", - "咱等你很久了哼ヽ(≧Д≦)ノ", - "咱很想你(≧▽≦)" - ], - "下班": [ - "回来了吗,咱...咱才没有想你", - "要先吃饭呢~还是先洗澡呢~还是先~吃~咱", - "是吗……辛苦你了。你这副倔强的样子,真可爱呢(笑)勉强让你躺在咱的腿上休息一下吧,别流口水哟", - "嗯……勉为其难欢迎你一下吧", - "想咱了嘛", - "咱等你很久了哼ヽ(≧Д≦)ノ", - "回来啦!终于下班了呢!累了吗?想吃的什么呀?", - "工作辛苦了,需要咱为你按摩下吗?", - "咱很想你(≧▽≦)" - ] -} diff --git a/docs_image/pc-about.jpg b/docs_image/pc-about.jpg new file mode 100644 index 00000000..0bef7a9e Binary files /dev/null and b/docs_image/pc-about.jpg differ diff --git a/docs_image/pc-api.jpg b/docs_image/pc-api.jpg new file mode 100644 index 
00000000..59cee887 Binary files /dev/null and b/docs_image/pc-api.jpg differ diff --git a/docs_image/pc-command.jpg b/docs_image/pc-command.jpg new file mode 100644 index 00000000..0e310e29 Binary files /dev/null and b/docs_image/pc-command.jpg differ diff --git a/docs_image/pc-dashboard.jpg b/docs_image/pc-dashboard.jpg new file mode 100644 index 00000000..0478a850 Binary files /dev/null and b/docs_image/pc-dashboard.jpg differ diff --git a/docs_image/pc-dashboard1.jpg b/docs_image/pc-dashboard1.jpg new file mode 100644 index 00000000..3a0bc958 Binary files /dev/null and b/docs_image/pc-dashboard1.jpg differ diff --git a/docs_image/pc-database.jpg b/docs_image/pc-database.jpg new file mode 100644 index 00000000..68c60aa3 Binary files /dev/null and b/docs_image/pc-database.jpg differ diff --git a/docs_image/pc-login.jpg b/docs_image/pc-login.jpg new file mode 100644 index 00000000..65fe8b46 Binary files /dev/null and b/docs_image/pc-login.jpg differ diff --git a/docs_image/pc-manage.jpg b/docs_image/pc-manage.jpg new file mode 100644 index 00000000..e5f8902a Binary files /dev/null and b/docs_image/pc-manage.jpg differ diff --git a/docs_image/pc-manage1.jpg b/docs_image/pc-manage1.jpg new file mode 100644 index 00000000..4756c629 Binary files /dev/null and b/docs_image/pc-manage1.jpg differ diff --git a/docs_image/pc-plugin.jpg b/docs_image/pc-plugin.jpg new file mode 100644 index 00000000..147e26eb Binary files /dev/null and b/docs_image/pc-plugin.jpg differ diff --git a/docs_image/pc-plugin1.jpg b/docs_image/pc-plugin1.jpg new file mode 100644 index 00000000..58694e6d Binary files /dev/null and b/docs_image/pc-plugin1.jpg differ diff --git a/docs_image/pc-store.jpg b/docs_image/pc-store.jpg new file mode 100644 index 00000000..4c9b68e4 Binary files /dev/null and b/docs_image/pc-store.jpg differ diff --git a/docs_image/pc-system.jpg b/docs_image/pc-system.jpg new file mode 100644 index 00000000..9908a2bd Binary files /dev/null and b/docs_image/pc-system.jpg differ 
diff --git a/docs_image/pc-system1.jpg b/docs_image/pc-system1.jpg new file mode 100644 index 00000000..3333a1b5 Binary files /dev/null and b/docs_image/pc-system1.jpg differ diff --git a/docs_image/pc-system2.jpg b/docs_image/pc-system2.jpg new file mode 100644 index 00000000..649a5bc9 Binary files /dev/null and b/docs_image/pc-system2.jpg differ diff --git a/docs_image/webui00.png b/docs_image/webui00.png deleted file mode 100644 index 71f7d368..00000000 Binary files a/docs_image/webui00.png and /dev/null differ diff --git a/docs_image/webui01.png b/docs_image/webui01.png deleted file mode 100644 index cd415685..00000000 Binary files a/docs_image/webui01.png and /dev/null differ diff --git a/docs_image/webui02.png b/docs_image/webui02.png deleted file mode 100644 index 0fcc4f05..00000000 Binary files a/docs_image/webui02.png and /dev/null differ diff --git a/docs_image/webui03.png b/docs_image/webui03.png deleted file mode 100644 index 2e7426e3..00000000 Binary files a/docs_image/webui03.png and /dev/null differ diff --git a/docs_image/webui04.png b/docs_image/webui04.png deleted file mode 100644 index 5810f71b..00000000 Binary files a/docs_image/webui04.png and /dev/null differ diff --git a/docs_image/webui05.png b/docs_image/webui05.png deleted file mode 100644 index d5f5e304..00000000 Binary files a/docs_image/webui05.png and /dev/null differ diff --git a/docs_image/webui06.png b/docs_image/webui06.png deleted file mode 100644 index 7541f679..00000000 Binary files a/docs_image/webui06.png and /dev/null differ diff --git a/docs_image/webui07.png b/docs_image/webui07.png deleted file mode 100644 index 1628ade7..00000000 Binary files a/docs_image/webui07.png and /dev/null differ diff --git a/tests/builtin_plugins/plugin_store/test_add_plugin.py b/tests/builtin_plugins/plugin_store/test_add_plugin.py index 3dd2ebbb..5a0edab8 100644 --- a/tests/builtin_plugins/plugin_store/test_add_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_add_plugin.py @@ -359,7 
+359,7 @@ async def test_add_plugin_exist( init_mocked_api(mocked_api=mocked_api) mocker.patch( - "zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins", + "zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins", return_value=[("search_image", "0.1")], ) plugin_id = 1 diff --git a/tests/builtin_plugins/plugin_store/test_search_plugin.py b/tests/builtin_plugins/plugin_store/test_search_plugin.py index 8bc6876e..404fee5e 100644 --- a/tests/builtin_plugins/plugin_store/test_search_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_search_plugin.py @@ -57,7 +57,7 @@ async def test_search_plugin_name( ) ctx.receive_event(bot=bot, event=event) mock_table_page.assert_awaited_once_with( - "插件列表", + "商店插件列表", "通过添加/移除插件 ID 来管理插件", ["-", "ID", "名称", "简介", "作者", "版本", "类型"], [ @@ -123,7 +123,7 @@ async def test_search_plugin_author( ) ctx.receive_event(bot=bot, event=event) mock_table_page.assert_awaited_once_with( - "插件列表", + "商店插件列表", "通过添加/移除插件 ID 来管理插件", ["-", "ID", "名称", "简介", "作者", "版本", "类型"], [ diff --git a/tests/builtin_plugins/plugin_store/test_update_all_plugin.py b/tests/builtin_plugins/plugin_store/test_update_all_plugin.py index 2a490da7..95360f6b 100644 --- a/tests/builtin_plugins/plugin_store/test_update_all_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_update_all_plugin.py @@ -32,7 +32,7 @@ async def test_update_all_plugin_basic_need_update( new=tmp_path / "zhenxun", ) mocker.patch( - "zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins", + "zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins", return_value=[("search_image", "0.0")], ) @@ -87,7 +87,7 @@ async def test_update_all_plugin_basic_is_new( new=tmp_path / "zhenxun", ) mocker.patch( - "zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins", + "zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins", return_value=[("search_image", 
"0.1")], ) diff --git a/tests/builtin_plugins/plugin_store/test_update_plugin.py b/tests/builtin_plugins/plugin_store/test_update_plugin.py index 952191d6..2cb88d1b 100644 --- a/tests/builtin_plugins/plugin_store/test_update_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_update_plugin.py @@ -32,7 +32,7 @@ async def test_update_plugin_basic_need_update( new=tmp_path / "zhenxun", ) mocker.patch( - "zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins", + "zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins", return_value=[("search_image", "0.0")], ) @@ -87,7 +87,7 @@ async def test_update_plugin_basic_is_new( new=tmp_path / "zhenxun", ) mocker.patch( - "zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins", + "zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins", return_value=[("search_image", "0.1")], ) diff --git a/tests/conftest.py b/tests/conftest.py index 0fce1583..d6a7e9fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,6 +116,7 @@ async def app(app: App, tmp_path: Path, mocker: MockerFixture): await init() # await driver._lifespan.startup() os.environ["AIOCACHE_DISABLE"] = "1" + os.environ["PYTEST_CURRENT_TEST"] = "1" yield app diff --git a/tests/response/plugin_store/basic_plugins.json b/tests/response/plugin_store/basic_plugins.json index 7459e2ec..f0306836 100644 --- a/tests/response/plugin_store/basic_plugins.json +++ b/tests/response/plugin_store/basic_plugins.json @@ -1,5 +1,6 @@ -{ - "鸡汤": { +[ + { + "name": "鸡汤", "module": "jitang", "module_path": "plugins.alapi.jitang", "description": "喏,亲手为你煮的鸡汤", @@ -9,7 +10,8 @@ "plugin_type": "NORMAL", "is_dir": false }, - "识图": { + { + "name": "识图", "module": "search_image", "module_path": "plugins.search_image", "description": "以图搜图,看破本源", @@ -19,7 +21,8 @@ "plugin_type": "NORMAL", "is_dir": true }, - "网易云热评": { + { + "name": "网易云热评", "module": "comments_163", "module_path": 
"plugins.alapi.comments_163", "description": "生了个人,我很抱歉", @@ -29,7 +32,8 @@ "plugin_type": "NORMAL", "is_dir": false }, - "B站订阅": { + { + "name": "B站订阅", "module": "bilibili_sub", "module_path": "plugins.bilibili_sub", "description": "非常便利的B站订阅通知", @@ -39,4 +43,4 @@ "plugin_type": "NORMAL", "is_dir": true } -} +] diff --git a/tests/response/plugin_store/extra_plugins.json b/tests/response/plugin_store/extra_plugins.json index 9d92f859..ca5e7f0a 100644 --- a/tests/response/plugin_store/extra_plugins.json +++ b/tests/response/plugin_store/extra_plugins.json @@ -1,5 +1,6 @@ -{ - "github订阅": { +[ + { + "name": "github订阅", "module": "github_sub", "module_path": "github_sub", "description": "订阅github用户或仓库", @@ -10,7 +11,8 @@ "is_dir": true, "github_url": "https://github.com/xuanerwa/zhenxun_github_sub" }, - "Minecraft查服": { + { + "name": "Minecraft查服", "module": "mc_check", "module_path": "mc_check", "description": "Minecraft服务器状态查询,支持IPv6", @@ -21,4 +23,4 @@ "is_dir": true, "github_url": "https://github.com/molanp/zhenxun_check_Minecraft" } -} +] diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index fbaeb280..f2688905 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -16,6 +16,7 @@ from zhenxun.models.sign_user import SignUser from zhenxun.models.user_console import UserConsole from zhenxun.services.log import logger from zhenxun.utils.decorator.shop import shop_register +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.manager.resource_manager import ResourceManager from zhenxun.utils.platform import PlatformUtils @@ -70,7 +71,7 @@ from public.bag_users t1 """ -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): await ResourceManager.init_resources() """签到与用户的数据迁移""" diff --git a/zhenxun/builtin_plugins/about.py b/zhenxun/builtin_plugins/about.py index faa0ba0e..31c77bc7 100644 --- a/zhenxun/builtin_plugins/about.py +++ 
b/zhenxun/builtin_plugins/about.py @@ -26,6 +26,21 @@ __plugin_meta__ = PluginMetadata( _matcher = on_alconna(Alconna("关于"), priority=5, block=True, rule=to_me()) +QQ_INFO = """ +『绪山真寻Bot』 +版本:{version} +简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 +""".strip() + +INFO = """ +『绪山真寻Bot』 +版本:{version} +简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 +项目地址:https://github.com/zhenxun-org/zhenxun_bot +文档地址:https://zhenxun-org.github.io/zhenxun_bot/ +""".strip() + + @_matcher.handle() async def _(session: Uninfo, arparma: Arparma): ver_file = Path() / "__version__" @@ -35,25 +50,11 @@ async def _(session: Uninfo, arparma: Arparma): if text := await f.read(): version = text.split(":")[-1].strip() if PlatformUtils.is_qbot(session): - info: list[str | Path] = [ - f""" -『绪山真寻Bot』 -版本:{version} -简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 - """.strip() - ] + result: list[str | Path] = [QQ_INFO.format(version=version)] path = DATA_PATH / "about.png" if path.exists(): - info.append(path) + result.append(path) + await MessageUtils.build_message(result).send() # type: ignore else: - info = [ - f""" -『绪山真寻Bot』 -版本:{version} -简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 -项目地址:https://github.com/HibiKier/zhenxun_bot -文档地址:https://hibikier.github.io/zhenxun_bot/ - """.strip() - ] - await MessageUtils.build_message(info).send() # type: ignore - logger.info("查看关于", arparma.header_result, session=session) + await MessageUtils.build_message(INFO.format(version=version)).send() + logger.info("查看关于", arparma.header_result, session=session) diff --git a/zhenxun/builtin_plugins/admin/ban/__init__.py b/zhenxun/builtin_plugins/admin/ban/__init__.py index 91bbf2ba..32e97f2d 100644 --- a/zhenxun/builtin_plugins/admin/ban/__init__.py +++ b/zhenxun/builtin_plugins/admin/ban/__init__.py @@ -14,13 +14,19 @@ from nonebot_plugin_alconna import ( from nonebot_plugin_session import EventSession from zhenxun.configs.config import BotConfig, Config -from zhenxun.configs.utils import PluginExtraData, 
RegisterConfig +from zhenxun.configs.utils import ( + AICallableParam, + AICallableProperties, + AICallableTag, + PluginExtraData, + RegisterConfig, +) from zhenxun.services.log import logger from zhenxun.utils.enum import PluginType from zhenxun.utils.message import MessageUtils from zhenxun.utils.rules import admin_check -from ._data_source import BanManage +from ._data_source import BanManage, call_ban base_config = Config.get("ban") @@ -78,6 +84,22 @@ __plugin_meta__ = PluginMetadata( type=int, ) ], + smart_tools=[ + AICallableTag( + name="call_ban", + description="某人多次(至少三次)辱骂你,调用此方法进行封禁", + parameters=AICallableParam( + type="object", + properties={ + "user_id": AICallableProperties( + type="string", description="用户的id" + ), + }, + required=["user_id"], + ), + func=call_ban, + ) + ], ).to_dict(), ) diff --git a/zhenxun/builtin_plugins/admin/ban/_data_source.py b/zhenxun/builtin_plugins/admin/ban/_data_source.py index f38d2440..ae465bdf 100644 --- a/zhenxun/builtin_plugins/admin/ban/_data_source.py +++ b/zhenxun/builtin_plugins/admin/ban/_data_source.py @@ -5,9 +5,20 @@ from nonebot_plugin_session import EventSession from zhenxun.models.ban_console import BanConsole from zhenxun.models.level_user import LevelUser +from zhenxun.services.log import logger from zhenxun.utils.image_utils import BuildImage, ImageTemplate +async def call_ban(user_id: str): + """调用ban + + 参数: + user_id: 用户id + """ + await BanConsole.ban(user_id, None, 9, 60 * 12) + logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id) + + class BanManage: @classmethod async def build_ban_image( diff --git a/zhenxun/builtin_plugins/admin/welcome_message/data_source.py b/zhenxun/builtin_plugins/admin/welcome_message/data_source.py index 2ccb33ee..c8e486ed 100644 --- a/zhenxun/builtin_plugins/admin/welcome_message/data_source.py +++ b/zhenxun/builtin_plugins/admin/welcome_message/data_source.py @@ -14,6 +14,7 @@ from zhenxun.services.log import logger from zhenxun.utils._build_image import 
BuildImage from zhenxun.utils._image_template import ImageTemplate from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils BASE_PATH = DATA_PATH / "welcome_message" @@ -91,7 +92,7 @@ def migrate(path: Path): json.dump(new_data, f, ensure_ascii=False, indent=4) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) def _(): """数据迁移 diff --git a/zhenxun/builtin_plugins/chat_history/chat_message_handle.py b/zhenxun/builtin_plugins/chat_history/chat_message_handle.py index 10cfcf43..d9eae97f 100644 --- a/zhenxun/builtin_plugins/chat_history/chat_message_handle.py +++ b/zhenxun/builtin_plugins/chat_history/chat_message_handle.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from io import BytesIO from nonebot.plugin import PluginMetadata from nonebot_plugin_alconna import ( @@ -14,35 +15,38 @@ from nonebot_plugin_alconna import ( from nonebot_plugin_session import EventSession import pytz -from zhenxun.configs.utils import Command, PluginExtraData +from zhenxun.configs.config import Config +from zhenxun.configs.utils import Command, PluginExtraData, RegisterConfig from zhenxun.models.chat_history import ChatHistory from zhenxun.models.group_member_info import GroupInfoUser from zhenxun.services.log import logger from zhenxun.utils.enum import PluginType -from zhenxun.utils.image_utils import ImageTemplate +from zhenxun.utils.image_utils import BuildImage, ImageTemplate from zhenxun.utils.message import MessageUtils +from zhenxun.utils.platform import PlatformUtils __plugin_meta__ = PluginMetadata( name="消息统计", description="消息统计查询", usage=""" 格式: - 消息排行 ?[type [日,周,月,年]] ?[--des] + 消息排行 ?[type [日,周,月,季,年]] ?[--des] 快捷: - [日,周,月,年]消息排行 ?[数量] + [日,周,月,季,年]消息排行 ?[数量] 示例: 消息排行 : 所有记录排行 日消息排行 : 今日记录排行 - 周消息排行 : 今日记录排行 - 月消息排行 : 今日记录排行 - 年消息排行 : 今日记录排行 + 周消息排行 : 本周记录排行 + 月消息排行 : 本月记录排行 + 季消息排行 : 本季度记录排行 + 年消息排行 : 本年记录排行 消息排行 周 --des : 逆序周记录排行 
""".strip(), extra=PluginExtraData( author="HibiKier", - version="0.1", + version="0.2", plugin_type=PluginType.NORMAL, menu_type="数据统计", commands=[ @@ -50,8 +54,19 @@ __plugin_meta__ = PluginMetadata( Command(command="日消息统计"), Command(command="周消息排行"), Command(command="月消息排行"), + Command(command="季消息排行"), Command(command="年消息排行"), ], + configs=[ + RegisterConfig( + module="chat_history", + key="SHOW_QUIT_MEMBER", + value=True, + help="是否在消息排行中显示已退群用户", + default_value=True, + type=bool, + ) + ], ).to_dict(), ) @@ -60,7 +75,7 @@ _matcher = on_alconna( Alconna( "消息排行", Option("--des", action=store_true, help_text="逆序"), - Args["type?", ["日", "周", "月", "年"]]["count?", int, 10], + Args["type?", ["日", "周", "月", "季", "年"]]["count?", int, 10], ), aliases={"消息统计"}, priority=5, @@ -68,7 +83,7 @@ _matcher = on_alconna( ) _matcher.shortcut( - r"(?P['日', '周', '月', '年'])?消息(排行|统计)\s?(?P\d+)?", + r"(?P['日', '周', '月', '季', '年'])?消息(排行|统计)\s?(?P\d+)?", command="消息排行", arguments=["{type}", "{cnt}"], prefix=True, @@ -96,20 +111,57 @@ async def _( date_scope = (time_now - timedelta(days=7), time_now) elif date in ["月"]: date_scope = (time_now - timedelta(days=30), time_now) - column_name = ["名次", "昵称", "发言次数"] + elif date in ["季"]: + date_scope = (time_now - timedelta(days=90), time_now) + column_name = ["名次", "头像", "昵称", "发言次数"] + show_quit_member = Config.get_config("chat_history", "SHOW_QUIT_MEMBER", True) + + fetch_count = count.result + if not show_quit_member: + fetch_count = count.result * 2 + if rank_data := await ChatHistory.get_group_msg_rank( - group_id, count.result, "DES" if arparma.find("des") else "DESC", date_scope + group_id, fetch_count, "DES" if arparma.find("des") else "DESC", date_scope ): idx = 1 data_list = [] + for uid, num in rank_data: - if user := await GroupInfoUser.filter( + if len(data_list) >= count.result: + break + + user_in_group = await GroupInfoUser.filter( user_id=uid, group_id=group_id - ).first(): - user_name = user.user_name + ).first() + + if 
not user_in_group and not show_quit_member: + continue + + if user_in_group: + user_name = user_in_group.user_name else: - user_name = uid - data_list.append([idx, user_name, num]) + user_name = f"{uid}(已退群)" + + avatar_size = 40 + try: + avatar_bytes = await PlatformUtils.get_user_avatar(str(uid), "qq") + if avatar_bytes: + avatar_img = BuildImage( + avatar_size, avatar_size, background=BytesIO(avatar_bytes) + ) + await avatar_img.circle() + avatar_tuple = (avatar_img, avatar_size, avatar_size) + else: + avatar_img = BuildImage(avatar_size, avatar_size, color="#CCCCCC") + await avatar_img.circle() + avatar_tuple = (avatar_img, avatar_size, avatar_size) + except Exception as e: + logger.warning(f"获取用户头像失败: {e}", "chat_history") + avatar_img = BuildImage(avatar_size, avatar_size, color="#CCCCCC") + await avatar_img.circle() + avatar_tuple = (avatar_img, avatar_size, avatar_size) + + data_list.append([idx, avatar_tuple, user_name, num]) idx += 1 if not date_scope: if date_scope := await ChatHistory.get_group_first_msg_datetime(group_id): @@ -132,13 +184,3 @@ async def _( ) await MessageUtils.build_message(A).finish(reply_to=True) await MessageUtils.build_message("群组消息记录为空...").finish() - - -# # @test.handle() -# # async def _(event: MessageEvent): -# # print(await ChatHistory.get_user_msg(event.user_id, "private")) -# # print(await ChatHistory.get_user_msg_count(event.user_id, "private")) -# # print(await ChatHistory.get_user_msg(event.user_id, "group")) -# # print(await ChatHistory.get_user_msg_count(event.user_id, "group")) -# # print(await ChatHistory.get_group_msg(event.group_id)) -# # print(await ChatHistory.get_group_msg_count(event.group_id)) diff --git a/zhenxun/builtin_plugins/help/__init__.py b/zhenxun/builtin_plugins/help/__init__.py index 726d4d1e..35edf114 100644 --- a/zhenxun/builtin_plugins/help/__init__.py +++ b/zhenxun/builtin_plugins/help/__init__.py @@ -37,10 +37,16 @@ __plugin_meta__ = PluginMetadata( configs=[ RegisterConfig( key="type", - 
value="normal", - help="帮助图片样式 ['normal', 'HTML', 'zhenxun']", + value="zhenxun", + help="帮助图片样式 [normal, HTML, zhenxun]", default_value="zhenxun", - ) + ), + RegisterConfig( + key="detail_type", + value="zhenxun", + help="帮助详情图片样式 ['normal', 'zhenxun']", + default_value="zhenxun", + ), ], ).to_dict(), ) diff --git a/zhenxun/builtin_plugins/help/_data_source.py b/zhenxun/builtin_plugins/help/_data_source.py index cfaa4503..23e9ec1b 100644 --- a/zhenxun/builtin_plugins/help/_data_source.py +++ b/zhenxun/builtin_plugins/help/_data_source.py @@ -1,13 +1,19 @@ from pathlib import Path import nonebot +from nonebot.plugin import PluginMetadata +from nonebot_plugin_htmlrender import template_to_pic from nonebot_plugin_uninfo import Uninfo -from zhenxun.configs.path_config import IMAGE_PATH +from zhenxun.configs.config import Config +from zhenxun.configs.path_config import IMAGE_PATH, TEMPLATE_PATH +from zhenxun.configs.utils import PluginExtraData from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.statistics import Statistics +from zhenxun.utils._image_template import ImageTemplate from zhenxun.utils.enum import PluginType -from zhenxun.utils.image_utils import BuildImage, ImageTemplate +from zhenxun.utils.image_utils import BuildImage from ._config import ( GROUP_HELP_PATH, @@ -40,7 +46,9 @@ async def create_help_img( match help_type: case "html": - result = BuildImage.open(await build_html_image(group_id, is_detail)) + result = BuildImage.open( + await build_html_image(session, group_id, is_detail) + ) case "zhenxun": result = BuildImage.open( await build_zhenxun_image(session, group_id, is_detail) @@ -78,9 +86,96 @@ async def get_user_allow_help(user_id: str) -> list[PluginType]: return type_list -async def get_plugin_help( - user_id: str, name: str, is_superuser: bool -) -> str | BuildImage: +async def get_normal_help( + metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool +) -> str | 
bytes: + """构建默认帮助详情 + + 参数: + metadata: PluginMetadata + extra: PluginExtraData + is_superuser: 是否超级用户帮助 + + 返回: + str | bytes: 返回信息 + """ + items = None + if is_superuser: + if usage := extra.superuser_help: + items = { + "简介": metadata.description, + "用法": usage, + } + else: + items = { + "简介": metadata.description, + "用法": metadata.usage, + } + if items: + return (await ImageTemplate.hl_page(metadata.name, items)).pic2bytes() + return "该功能没有帮助信息" + + +def min_leading_spaces(str_list: list[str]) -> int: + min_spaces = 9999 + + for s in str_list: + leading_spaces = len(s) - len(s.lstrip(" ")) + + if leading_spaces < min_spaces: + min_spaces = leading_spaces + + return min_spaces if min_spaces != 9999 else 0 + + +def split_text(text: str): + split_text = text.split("\n") + min_spaces = min_leading_spaces(split_text) + if min_spaces > 0: + split_text = [s[min_spaces:] for s in split_text] + return [s.replace(" ", " ") for s in split_text] + + +async def get_zhenxun_help( + module: str, metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool +) -> str | bytes: + """构建ZhenXun帮助详情 + + 参数: + module: 模块名 + metadata: PluginMetadata + extra: PluginExtraData + is_superuser: 是否超级用户帮助 + + 返回: + str | bytes: 返回信息 + """ + call_count = await Statistics.filter(plugin_name=module).count() + usage = metadata.usage + if is_superuser: + if not extra.superuser_help: + return "该功能没有超级用户帮助信息" + usage = extra.superuser_help + return await template_to_pic( + template_path=str((TEMPLATE_PATH / "help_detail").absolute()), + template_name="main.html", + templates={ + "title": metadata.name, + "author": extra.author, + "version": extra.version, + "call_count": call_count, + "descriptions": split_text(metadata.description), + "usages": split_text(usage), + }, + pages={ + "viewport": {"width": 824, "height": 590}, + "base_url": f"file://{TEMPLATE_PATH}", + }, + wait=2, + ) + + +async def get_plugin_help(user_id: str, name: str, is_superuser: bool) -> str | bytes: """获取功能的帮助信息 参数: 
@@ -98,20 +193,12 @@ async def get_plugin_help( if plugin: _plugin = nonebot.get_plugin_by_module_name(plugin.module_path) if _plugin and _plugin.metadata: - items = None - if is_superuser: - extra = _plugin.metadata.extra - if usage := extra.get("superuser_help"): - items = { - "简介": _plugin.metadata.description, - "用法": usage, - } + extra_data = PluginExtraData(**_plugin.metadata.extra) + if Config.get_config("help", "detail_type") == "zhenxun": + return await get_zhenxun_help( + plugin.module, _plugin.metadata, extra_data, is_superuser + ) else: - items = { - "简介": _plugin.metadata.description, - "用法": _plugin.metadata.usage, - } - if items: - return await ImageTemplate.hl_page(plugin.name, items) + return await get_normal_help(_plugin.metadata, extra_data, is_superuser) return "糟糕! 该功能没有帮助喔..." return "没有查找到这个功能噢..." diff --git a/zhenxun/builtin_plugins/help/_utils.py b/zhenxun/builtin_plugins/help/_utils.py index 6c382c7d..0554fc8d 100644 --- a/zhenxun/builtin_plugins/help/_utils.py +++ b/zhenxun/builtin_plugins/help/_utils.py @@ -1,5 +1,8 @@ from collections.abc import Callable +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.utils.enum import PluginType @@ -27,13 +30,15 @@ async def sort_type() -> dict[str, list[PluginInfo]]: async def classify_plugin( - group_id: str | None, is_detail: bool, handle: Callable + session: Uninfo, group_id: str | None, is_detail: bool, handle: Callable ) -> dict[str, list]: """对插件进行分类并判断状态 参数: + session: Uninfo对象 group_id: 群组id is_detail: 是否详细帮助 + handle: 回调方法 返回: dict[str, list[Item]]: 分类插件数据 @@ -41,9 +46,10 @@ async def classify_plugin( sort_data = await sort_type() classify: dict[str, list] = {} group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None + bot = await BotConsole.get_or_none(bot_id=session.self_id) for menu, value in 
sort_data.items(): for plugin in value: if not classify.get(menu): classify[menu] = [] - classify[menu].append(handle(plugin, group, is_detail)) + classify[menu].append(handle(bot, plugin, group, is_detail)) return classify diff --git a/zhenxun/builtin_plugins/help/html_help.py b/zhenxun/builtin_plugins/help/html_help.py index 1815b99a..7c552a0d 100644 --- a/zhenxun/builtin_plugins/help/html_help.py +++ b/zhenxun/builtin_plugins/help/html_help.py @@ -2,9 +2,11 @@ import os import random from nonebot_plugin_htmlrender import template_to_pic +from nonebot_plugin_uninfo import Uninfo from pydantic import BaseModel from zhenxun.configs.path_config import TEMPLATE_PATH +from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.utils.enum import BlockType @@ -48,11 +50,12 @@ ICON2STR = { def __handle_item( - plugin: PluginInfo, group: GroupConsole | None, is_detail: bool + bot: BotConsole, plugin: PluginInfo, group: GroupConsole | None, is_detail: bool ) -> Item: """构造Item 参数: + bot: BotConsole plugin: PluginInfo group: 群组 is_detail: 是否详细 @@ -73,10 +76,13 @@ def __handle_item( ]: sta = 2 if group: - if f"{plugin.module}:super," in group.block_plugin: + if f"{plugin.module}," in group.superuser_block_plugin: sta = 2 if f"{plugin.module}," in group.block_plugin: sta = 1 + if bot: + if f"{plugin.module}," in bot.block_plugins: + sta = 2 return Item(plugin_name=plugin.name, sta=sta) @@ -119,14 +125,17 @@ def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]: return plugin_list -async def build_html_image(group_id: str | None, is_detail: bool) -> bytes: +async def build_html_image( + session: Uninfo, group_id: str | None, is_detail: bool +) -> bytes: """构造HTML帮助图片 参数: + session: Uninfo group_id: 群号 is_detail: 是否详细帮助 """ - classify = await classify_plugin(group_id, is_detail, __handle_item) + classify = await classify_plugin(session, group_id, 
is_detail, __handle_item) plugin_list = build_plugin_data(classify) return await template_to_pic( template_path=str((TEMPLATE_PATH / "menu").absolute()), diff --git a/zhenxun/builtin_plugins/help/zhenxun_help.py b/zhenxun/builtin_plugins/help/zhenxun_help.py index f6d930e6..b96d3c59 100644 --- a/zhenxun/builtin_plugins/help/zhenxun_help.py +++ b/zhenxun/builtin_plugins/help/zhenxun_help.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from zhenxun.configs.config import BotConfig from zhenxun.configs.path_config import TEMPLATE_PATH from zhenxun.configs.utils import PluginExtraData +from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.utils.enum import BlockType @@ -21,12 +22,19 @@ class Item(BaseModel): """插件命令""" -def __handle_item(plugin: PluginInfo, group: GroupConsole | None, is_detail: bool): +def __handle_item( + bot: BotConsole | None, + plugin: PluginInfo, + group: GroupConsole | None, + is_detail: bool, +): """构造Item 参数: + bot: BotConsole plugin: PluginInfo group: 群组 + is_detail: 是否为详细 返回: Item: Item @@ -40,6 +48,8 @@ def __handle_item(plugin: PluginInfo, group: GroupConsole | None, is_detail: boo plugin.name = f"{plugin.name}(不可用)" elif group and f"{plugin.module}," in group.block_plugin: plugin.name = f"{plugin.name}(不可用)" + elif bot and f"{plugin.module}," in bot.block_plugins: + plugin.name = f"{plugin.name}(不可用)" commands = [] nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path) if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra: @@ -142,7 +152,7 @@ async def build_zhenxun_image( group_id: 群号 is_detail: 是否详细帮助 """ - classify = await classify_plugin(group_id, is_detail, __handle_item) + classify = await classify_plugin(session, group_id, is_detail, __handle_item) plugin_list = build_plugin_data(classify) platform = PlatformUtils.get_platform(session) bot_id = BotConfig.get_qbot_uid(session.self_id) 
or session.self_id diff --git a/zhenxun/builtin_plugins/help_help.py b/zhenxun/builtin_plugins/help_help.py index fec04a8d..6b5ecce9 100644 --- a/zhenxun/builtin_plugins/help_help.py +++ b/zhenxun/builtin_plugins/help_help.py @@ -21,7 +21,7 @@ from zhenxun.utils.message import MessageUtils __plugin_meta__ = PluginMetadata( name="笨蛋检测", description="功能名称当命令检测", - usage="""被动""".strip(), + usage="""当一些笨蛋直接输入功能名称时,提示笨蛋使用帮助指令查看功能帮助""".strip(), extra=PluginExtraData( author="HibiKier", version="0.1", diff --git a/zhenxun/builtin_plugins/hooks/__init__.py b/zhenxun/builtin_plugins/hooks/__init__.py index 3ad29d71..2f8c79de 100644 --- a/zhenxun/builtin_plugins/hooks/__init__.py +++ b/zhenxun/builtin_plugins/hooks/__init__.py @@ -49,4 +49,14 @@ Config.add_plugin_config( type=bool, ) +Config.add_plugin_config( + "hook", + "RECORD_BOT_SENT_MESSAGES", + True, + help="记录bot消息发送", + default_value=True, + type=bool, +) + + nonebot.load_plugins(str(Path(__file__).parent.resolve())) diff --git a/zhenxun/builtin_plugins/hooks/call_hook.py b/zhenxun/builtin_plugins/hooks/call_hook.py index 2ff4d39c..1893754d 100644 --- a/zhenxun/builtin_plugins/hooks/call_hook.py +++ b/zhenxun/builtin_plugins/hooks/call_hook.py @@ -1,23 +1,85 @@ from typing import Any -from nonebot.adapters import Bot +from nonebot.adapters import Bot, Message +from zhenxun.configs.config import Config +from zhenxun.models.bot_message_store import BotMessageStore from zhenxun.services.log import logger +from zhenxun.utils.enum import BotSentType from zhenxun.utils.manager.message_manager import MessageManager +from zhenxun.utils.platform import PlatformUtils + + +def replace_message(message: Message) -> str: + """将消息中的at、image、record、face替换为字符串 + + 参数: + message: Message + + 返回: + str: 文本消息 + """ + result = "" + for msg in message: + if isinstance(msg, str): + result += msg + elif msg.type == "at": + result += f"@{msg.data['qq']}" + elif msg.type == "image": + result += "[image]" + elif msg.type == "record": + 
result += "[record]" + elif msg.type == "face": + result += f"[face:{msg.data['id']}]" + elif msg.type == "reply": + result += "" + else: + result += str(msg) + return result @Bot.on_called_api async def handle_api_result( bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any ): - if not exception and api == "send_msg": - try: - if (uid := data.get("user_id")) and (msg_id := result.get("message_id")): - MessageManager.add(str(uid), str(msg_id)) - logger.debug( - f"收集消息id,user_id: {uid}, msg_id: {msg_id}", "msg_hook" - ) - except Exception as e: - logger.warning( - f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e + if exception or api != "send_msg": + return + user_id = data.get("user_id") + group_id = data.get("group_id") + message_id = result.get("message_id") + message: Message = data.get("message", "") + message_type = data.get("message_type") + try: + # 记录消息id + if user_id and message_id: + MessageManager.add(str(user_id), str(message_id)) + logger.debug( + f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook" ) + except Exception as e: + logger.warning( + f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e + ) + if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"): + return + try: + await BotMessageStore.create( + bot_id=bot.self_id, + user_id=user_id, + group_id=group_id, + sent_type=BotSentType.GROUP + if message_type == "group" + else BotSentType.PRIVATE, + text=replace_message(message), + plain_text=message.extract_plain_text() + if isinstance(message, Message) + else replace_message(message), + platform=PlatformUtils.get_platform(bot), + ) + logger.debug(f"消息发送记录,message: {message}") + except Exception as e: + logger.warning( + f"消息发送记录发生错误...data: {data}, result: {result}", + "msg_hook", + e=e, + ) diff --git a/zhenxun/builtin_plugins/init/init_config.py b/zhenxun/builtin_plugins/init/init_config.py index 112d29de..51a7da47 100644 --- a/zhenxun/builtin_plugins/init/init_config.py 
+++ b/zhenxun/builtin_plugins/init/init_config.py @@ -11,6 +11,7 @@ from zhenxun.configs.config import Config from zhenxun.configs.path_config import DATA_PATH from zhenxun.configs.utils import RegisterConfig from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle _yaml = YAML(pure=True) _yaml.allow_unicode = True @@ -57,7 +58,7 @@ def _generate_simple_config(exists_module: list[str]): 生成简易配置 异常: - AttributeError: _description_ + AttributeError: AttributeError """ # 读取用户配置 _data = {} @@ -73,7 +74,9 @@ def _generate_simple_config(exists_module: list[str]): if _data.get(module) and k in _data[module].keys(): Config.set_config(module, k, _data[module][k]) if f"{module}:{k}".lower() in exists_module: - _tmp_data[module][k] = Config.get_config(module, k) + _tmp_data[module][k] = Config.get_config( + module, k, build_model=False + ) except AttributeError as e: raise AttributeError(f"{e}\n可能为config.yaml配置文件填写不规范") from e if not _tmp_data[module]: @@ -102,7 +105,7 @@ def _generate_simple_config(exists_module: list[str]): temp_file.unlink() -@driver.on_startup +@PriorityLifecycle.on_startup(priority=0) def _(): """ 初始化插件数据配置 @@ -125,3 +128,4 @@ def _(): with plugins2config_file.open("w", encoding="utf8") as wf: _yaml.dump(_data, wf) _generate_simple_config(exists_module) + Config.reload() diff --git a/zhenxun/builtin_plugins/init/init_plugin.py b/zhenxun/builtin_plugins/init/init_plugin.py index dbeddb54..5bf50409 100644 --- a/zhenxun/builtin_plugins/init/init_plugin.py +++ b/zhenxun/builtin_plugins/init/init_plugin.py @@ -20,6 +20,7 @@ from zhenxun.utils.enum import ( PluginLimitType, PluginType, ) +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from .manager import manager @@ -95,7 +96,7 @@ async def _handle_setting( ) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """ 初始化插件数据配置 diff --git a/zhenxun/builtin_plugins/init/init_task.py 
b/zhenxun/builtin_plugins/init/init_task.py index cead7d72..b9bab56d 100644 --- a/zhenxun/builtin_plugins/init/init_task.py +++ b/zhenxun/builtin_plugins/init/init_task.py @@ -10,6 +10,7 @@ from zhenxun.models.group_console import GroupConsole from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.manager.priority_manager import PriorityLifecycle driver: Driver = nonebot.get_driver() @@ -132,7 +133,7 @@ async def create_schedule(task: Task): logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """ 初始化插件数据配置 diff --git a/zhenxun/builtin_plugins/mahiro_bank/__init__.py b/zhenxun/builtin_plugins/mahiro_bank/__init__.py new file mode 100644 index 00000000..8e82cf08 --- /dev/null +++ b/zhenxun/builtin_plugins/mahiro_bank/__init__.py @@ -0,0 +1,252 @@ +from datetime import datetime + +from nonebot.plugin import PluginMetadata +from nonebot_plugin_alconna import Alconna, Args, Arparma, Match, Subcommand, on_alconna +from nonebot_plugin_apscheduler import scheduler +from nonebot_plugin_uninfo import Uninfo +from nonebot_plugin_waiter import prompt_until + +from zhenxun.configs.utils import PluginExtraData, RegisterConfig +from zhenxun.services.log import logger +from zhenxun.utils.depends import UserName +from zhenxun.utils.message import MessageUtils +from zhenxun.utils.utils import is_number + +from .data_source import BankManager + +__plugin_meta__ = PluginMetadata( + name="小真寻银行", + description=""" + 小真寻银行,提供高品质的存款!当好感度等级达到指初识时,小真寻会偷偷的帮助你哦。 + 存款额度与好感度有关,每日存款次数有限制。 + 基础存款提供基础利息 + 每日存款提供高额利息 + """.strip(), + usage=""" + 指令: + 存款 [金额] + 取款 [金额] + 银行信息 + 我的银行信息 + """.strip(), + extra=PluginExtraData( + author="HibiKier", + version="0.1", + menu_type="群内小游戏", + configs=[ + RegisterConfig( + key="sign_max_deposit", + value=100, + help="好感度换算存款金额比例,当值是100时,最大存款金额=好感度*100,存款的最低金额是100(强制)", + 
default_value=100, + type=int, + ), + RegisterConfig( + key="max_daily_deposit_count", + value=3, + help="每日最大存款次数", + default_value=3, + type=int, + ), + RegisterConfig( + key="rate_range", + value=[0.0005, 0.001], + help="小时利率范围", + default_value=[0.0005, 0.001], + type=list[float], + ), + RegisterConfig( + key="impression_event", + value=25, + help="到达指定好感度时随机提高或降低利率", + default_value=25, + type=int, + ), + RegisterConfig( + key="impression_event_range", + value=[0.00001, 0.0003], + help="到达指定好感度时随机提高或降低利率", + default_value=[0.00001, 0.0003], + type=list[float], + ), + RegisterConfig( + key="impression_event_prop", + value=0.3, + help="到达指定好感度时随机提高或降低利率触发概率", + default_value=0.3, + type=float, + ), + ], + ).to_dict(), +) + + +_matcher = on_alconna( + Alconna( + "mahiro-bank", + Subcommand("deposit", Args["amount?", int]), + Subcommand("withdraw", Args["amount?", int]), + Subcommand("user-info"), + Subcommand("bank-info"), + # Subcommand("loan", Args["amount?", int]), + # Subcommand("repayment", Args["amount?", int]), + ), + priority=5, + block=True, +) + +_matcher.shortcut( + r"存款\s*(?P\d+)?", + command="mahiro-bank", + arguments=["deposit", "{amount}"], + prefix=True, +) + +_matcher.shortcut( + r"取款\s*(?P\d+)?", + command="mahiro-bank", + arguments=["withdraw", "{withdraw}"], + prefix=True, +) + +_matcher.shortcut( + r"我的银行信息", + command="mahiro-bank", + arguments=["user-info"], + prefix=True, +) + +_matcher.shortcut( + r"银行信息", + command="mahiro-bank", + arguments=["bank-info"], + prefix=True, +) + + +async def get_amount(handle_type: str) -> int: + amount_num = await prompt_until( + f"请输入{handle_type}金币数量", + lambda msg: is_number(msg.extract_plain_text()), + timeout=60, + retry=3, + retry_prompt="输入错误,请输入数字。剩余次数:{count}", + ) + if not amount_num: + await MessageUtils.build_message( + "输入超时了哦,小真寻柜员以取消本次存款操作..." 
+ ).finish() + return int(amount_num.extract_plain_text()) + + +@_matcher.assign("deposit") +async def _(session: Uninfo, arparma: Arparma, amount: Match[int]): + amount_num = amount.result if amount.available else await get_amount("存款") + if result := await BankManager.deposit_check(session.user.id, amount_num): + await MessageUtils.build_message(result).finish(reply_to=True) + _, rate, event_rate = await BankManager.deposit(session.user.id, amount_num) + result = ( + f"存款成功!\n此次存款金额为: {amount.result}\n" + f"当前小时利率为: {rate * 100:.2f}%" + ) + effective_hour = int(24 - datetime.now().hour) + if event_rate: + result += f"(小真寻偷偷将小时利率给你增加了 {event_rate:.2f}% 哦)" + result += ( + f"\n预计总收益为: {int(amount.result * rate * effective_hour) or 1} 金币。" + ) + logger.info( + f"小真寻银行存款:{amount_num},当前存款数:{amount.result},存款小时利率: {rate}", + arparma.header_result, + session=session, + ) + await MessageUtils.build_message(result).finish(at_sender=True) + + +@_matcher.assign("withdraw") +async def _(session: Uninfo, arparma: Arparma, amount: Match[int]): + amount_num = amount.result if amount.available else await get_amount("取款") + if result := await BankManager.withdraw_check(session.user.id, amount_num): + await MessageUtils.build_message(result).finish(reply_to=True) + try: + user = await BankManager.withdraw(session.user.id, amount_num) + result = ( + f"取款成功!\n当前取款金额为: {amount_num}\n当前存款金额为: {user.amount}" + ) + logger.info( + f"小真寻银行取款:{amount_num}, 当前存款数:{user.amount}," + f" 存款小时利率:{user.rate}", + arparma.header_result, + session=session, + ) + await MessageUtils.build_message(result).finish(reply_to=True) + except ValueError: + await MessageUtils.build_message("你的银行内的存款数量不足哦...").finish( + reply_to=True + ) + + +@_matcher.assign("user-info") +async def _(session: Uninfo, arparma: Arparma, uname: str = UserName()): + result = await BankManager.get_user_info(session, uname) + await MessageUtils.build_message(result).send() + logger.info("查看银行个人信息", arparma.header_result, 
session=session) + + +@_matcher.assign("bank-info") +async def _(session: Uninfo, arparma: Arparma): + result = await BankManager.get_bank_info() + await MessageUtils.build_message(result).send() + logger.info("查看银行信息", arparma.header_result, session=session) + + +# @_matcher.assign("loan") +# async def _(session: Uninfo, arparma: Arparma, amount: Match[int]): +# amount_num = amount.result if amount.available else await get_amount("贷款") +# if amount_num <= 0: +# await MessageUtils.build_message("贷款数量必须大于 0 啊笨蛋!").finish() +# try: +# user, event_rate = await BankManager.loan(session.user.id, amount_num) +# result = ( +# f"贷款成功!\n当前贷金额为: {user.loan_amount}" +# f"\n当前利率为: {user.loan_rate * 100}%" +# ) +# if event_rate: +# result += f"(小真寻偷偷将利率给你降低了 {event_rate}% 哦)" +# result += f"\n预计每小时利息为:{int(user.loan_amount * user.loan_rate)}金币。" +# logger.info( +# f"小真寻银行贷款: {amount_num}, 当前贷款数: {user.loan_amount}, " +# f"贷款利率: {user.loan_rate}", +# arparma.header_result, +# session=session, +# ) +# except ValueError: +# await MessageUtils.build_message( +# "贷款数量超过最大限制,请签到提升好感度获取更多额度吧..." 
+# ).finish(reply_to=True) + + +# @_matcher.assign("repayment") +# async def _(session: Uninfo, arparma: Arparma, amount: Match[int]): +# amount_num = amount.result if amount.available else await get_amount("还款") +# if amount_num <= 0: +# await MessageUtils.build_message("还款数量必须大于 0 啊笨蛋!").finish() +# user = await BankManager.repayment(session.user.id, amount_num) +# result = (f"还款成功!\n当前还款金额为: {amount_num}\n" +# f"当前贷款金额为: {user.loan_amount}") +# logger.info( +# f"小真寻银行还款:{amount_num},当前贷款数:{user.amount}, 贷款利率:{user.rate}", +# arparma.header_result, +# session=session, +# ) +# await MessageUtils.build_message(result).finish(at_sender=True) + + +@scheduler.scheduled_job( + "cron", + hour=0, + minute=0, +) +async def _(): + await BankManager.settlement() + logger.info("小真寻银行结算", "定时任务") diff --git a/zhenxun/builtin_plugins/mahiro_bank/data_source.py b/zhenxun/builtin_plugins/mahiro_bank/data_source.py new file mode 100644 index 00000000..b717e9a4 --- /dev/null +++ b/zhenxun/builtin_plugins/mahiro_bank/data_source.py @@ -0,0 +1,450 @@ +import asyncio +from datetime import datetime, timedelta +import random + +from nonebot_plugin_htmlrender import template_to_pic +from nonebot_plugin_uninfo import Uninfo +from tortoise.expressions import RawSQL +from tortoise.functions import Count, Sum + +from zhenxun.configs.config import Config +from zhenxun.configs.path_config import TEMPLATE_PATH +from zhenxun.models.mahiro_bank import MahiroBank +from zhenxun.models.mahiro_bank_log import MahiroBankLog +from zhenxun.models.sign_user import SignUser +from zhenxun.models.user_console import UserConsole +from zhenxun.utils.enum import BankHandleType, GoldHandle +from zhenxun.utils.platform import PlatformUtils + +base_config = Config.get("mahiro_bank") + + +class BankManager: + @classmethod + async def random_event(cls, impression: float): + """随机事件""" + impression_event = base_config.get("impression_event") + impression_event_prop = base_config.get("impression_event_prop") + 
impression_event_range = base_config.get("impression_event_range") + if impression >= impression_event and random.random() < impression_event_prop: + """触发好感度事件""" + return random.uniform(impression_event_range[0], impression_event_range[1]) + return None + + @classmethod + async def deposit_check(cls, user_id: str, amount: int) -> str | None: + """检查存款是否合法 + + 参数: + user_id: 用户id + amount: 存款金额 + + 返回: + str | None: 存款信息 + """ + if amount <= 0: + return "存款数量必须大于 0 啊笨蛋!" + user, sign_user, bank_user = await asyncio.gather( + *[ + UserConsole.get_user(user_id), + SignUser.get_user(user_id), + cls.get_user(user_id), + ] + ) + sign_max_deposit: int = base_config.get("sign_max_deposit") + max_deposit = max(int(float(sign_user.impression) * sign_max_deposit), 100) + if user.gold < amount: + return f"金币数量不足,当前你的金币为:{user.gold}." + if bank_user.amount + amount > max_deposit: + return ( + f"存款超过上限,存款上限为:{max_deposit}," + f"当前你的还可以存款金额:{max_deposit - bank_user.amount}。" + ) + max_daily_deposit_count: int = base_config.get("max_daily_deposit_count") + today_deposit_count = len(await cls.get_user_deposit(user_id)) + if today_deposit_count >= max_daily_deposit_count: + return f"存款次数超过上限,每日存款次数上限为:{max_daily_deposit_count}。" + return None + + @classmethod + async def withdraw_check(cls, user_id: str, amount: int) -> str | None: + """检查取款是否合法 + + 参数: + user_id: 用户id + amount: 取款金额 + + 返回: + str | None: 取款信息 + """ + if amount <= 0: + return "取款数量必须大于 0 啊笨蛋!" + user = await cls.get_user(user_id) + data_list = await cls.get_user_deposit(user_id) + lock_amount = sum(data.amount for data in data_list) + if user.amount - lock_amount < amount: + return ( + "取款金额不足,当前你的存款为:" + f"{user.amount}({lock_amount}已被锁定)!" 
+ ) + return None + + @classmethod + async def get_user_deposit( + cls, user_id: str, is_completed: bool = False + ) -> list[MahiroBankLog]: + """获取用户今日存款次数 + + 参数: + user_id: 用户id + + 返回: + list[MahiroBankLog]: 存款列表 + """ + return await MahiroBankLog.filter( + user_id=user_id, + handle_type=BankHandleType.DEPOSIT, + is_completed=is_completed, + ) + + @classmethod + async def get_user(cls, user_id: str) -> MahiroBank: + """查询余额 + + 参数: + user_id: 用户id + + 返回: + MahiroBank + """ + user, _ = await MahiroBank.get_or_create(user_id=user_id) + return user + + @classmethod + async def get_user_data( + cls, + user_id: str, + data_type: BankHandleType, + is_completed: bool = False, + count: int = 5, + ) -> list[MahiroBankLog]: + return ( + await MahiroBankLog.filter( + user_id=user_id, handle_type=data_type, is_completed=is_completed + ) + .order_by("-id") + .limit(count) + .all() + ) + + @classmethod + async def complete_projected_revenue(cls, user_id: str) -> int: + """预计收益 + + 参数: + user_id: 用户id + + 返回: + int: 预计收益金额 + """ + deposit_list = await cls.get_user_deposit(user_id) + if not deposit_list: + return 0 + return int( + sum( + deposit.rate * deposit.amount * deposit.effective_hour + for deposit in deposit_list + ) + ) + + @classmethod + async def get_user_info(cls, session: Uninfo, uname: str) -> bytes: + """获取用户数据 + + 参数: + session: Uninfo + uname: 用户id + + 返回: + bytes: 图片数据 + """ + user_id = session.user.id + user = await cls.get_user(user_id=user_id) + ( + rank, + deposit_count, + user_today_deposit, + projected_revenue, + sum_data, + ) = await asyncio.gather( + *[ + MahiroBank.filter(amount__gt=user.amount).count(), + MahiroBankLog.filter(user_id=user_id).count(), + cls.get_user_deposit(user_id), + cls.complete_projected_revenue(user_id), + MahiroBankLog.filter( + user_id=user_id, handle_type=BankHandleType.INTEREST + ) + .annotate(sum=Sum("amount")) + .values("sum"), + ] + ) + now = datetime.now() + end_time = ( + now + + timedelta(days=1) + - 
timedelta(hours=now.hour, minutes=now.minute, seconds=now.second) + ) + today_deposit_amount = sum(deposit.amount for deposit in user_today_deposit) + deposit_list = [ + { + "id": deposit.id, + "date": now.date(), + "start_time": str(deposit.create_time).split(".")[0], + "end_time": end_time.replace(microsecond=0), + "amount": deposit.amount, + "rate": f"{deposit.rate * 100:.2f}", + "projected_revenue": int( + deposit.amount * deposit.rate * deposit.effective_hour + ) + or 1, + } + for deposit in user_today_deposit + ] + platform = PlatformUtils.get_platform(session) + data = { + "name": uname, + "rank": rank + 1, + "avatar_url": PlatformUtils.get_user_avatar_url( + user_id, platform, session.self_id + ), + "amount": user.amount, + "deposit_count": deposit_count, + "today_deposit_count": len(user_today_deposit), + "cumulative_gain": sum_data[0]["sum"] or 0, + "projected_revenue": projected_revenue, + "today_deposit_amount": today_deposit_amount, + "deposit_list": deposit_list, + "create_time": now.replace(microsecond=0), + } + return await template_to_pic( + template_path=str((TEMPLATE_PATH / "mahiro_bank").absolute()), + template_name="user.html", + templates={"data": data}, + pages={ + "viewport": {"width": 386, "height": 700}, + "base_url": f"file://{TEMPLATE_PATH}", + }, + wait=2, + ) + + @classmethod + async def get_bank_info(cls) -> bytes: + now = datetime.now() + now_start = now - timedelta( + hours=now.hour, minutes=now.minute, seconds=now.second + ) + ( + bank_data, + today_count, + interest_amount, + active_user_count, + date_data, + ) = await asyncio.gather( + *[ + MahiroBank.annotate( + amount_sum=Sum("amount"), user_count=Count("id") + ).values("amount_sum", "user_count"), + MahiroBankLog.filter( + create_time__gt=now_start, handle_type=BankHandleType.DEPOSIT + ).count(), + MahiroBankLog.filter(handle_type=BankHandleType.INTEREST) + .annotate(amount_sum=Sum("amount")) + .values("amount_sum"), + MahiroBankLog.filter( + create_time__gte=now_start - 
timedelta(days=7), + handle_type=BankHandleType.DEPOSIT, + ) + .annotate(count=Count("user_id", distinct=True)) + .values("count"), + MahiroBankLog.filter( + create_time__gte=now_start - timedelta(days=7), + handle_type=BankHandleType.DEPOSIT, + ) + .annotate(date=RawSQL("DATE(create_time)"), total_amount=Sum("amount")) + .group_by("date") + .values("date", "total_amount"), + ] + ) + date2cnt = {str(date["date"]): date["total_amount"] for date in date_data} + date = now.date() + e_date, e_amount = [], [] + for _ in range(7): + if str(date) in date2cnt: + e_amount.append(date2cnt[str(date)]) + else: + e_amount.append(0) + e_date.append(str(date)[5:]) + date -= timedelta(days=1) + e_date.reverse() + e_amount.reverse() + date = 1 + lasted_log = await MahiroBankLog.annotate().order_by("create_time").first() + if lasted_log: + date = now.date() - lasted_log.create_time.date() + date = (date.days or 1) + 1 + data = { + "amount_sum": bank_data[0]["amount_sum"], + "user_count": bank_data[0]["user_count"], + "today_count": today_count, + "day_amount": int(bank_data[0]["amount_sum"] / date), + "interest_amount": interest_amount[0]["amount_sum"] or 0, + "active_user_count": active_user_count[0]["count"] or 0, + "e_data": e_date, + "e_amount": e_amount, + "create_time": now.replace(microsecond=0), + } + return await template_to_pic( + template_path=str((TEMPLATE_PATH / "mahiro_bank").absolute()), + template_name="bank.html", + templates={"data": data}, + pages={ + "viewport": {"width": 450, "height": 750}, + "base_url": f"file://{TEMPLATE_PATH}", + }, + wait=2, + ) + + @classmethod + async def deposit( + cls, user_id: str, amount: int + ) -> tuple[MahiroBank, float, float | None]: + """存款 + + 参数: + user_id: 用户id + amount: 存款数量 + + 返回: + tuple[MahiroBank, float, float]: MahiroBank,利率,增加的利率 + """ + rate_range = base_config.get("rate_range") + rate = random.uniform(rate_range[0], rate_range[1]) + sign_user = await SignUser.get_user(user_id) + random_add_rate = await 
cls.random_event(float(sign_user.impression)) + if random_add_rate: + rate += random_add_rate + await UserConsole.reduce_gold(user_id, amount, GoldHandle.PLUGIN, "bank") + return await MahiroBank.deposit(user_id, amount, rate), rate, random_add_rate + + @classmethod + async def withdraw(cls, user_id: str, amount: int) -> MahiroBank: + """取款 + + 参数: + user_id: 用户id + amount: 取款数量 + + 返回: + MahiroBank + """ + await UserConsole.add_gold(user_id, amount, "bank") + return await MahiroBank.withdraw(user_id, amount) + + @classmethod + async def loan(cls, user_id: str, amount: int) -> tuple[MahiroBank, float | None]: + """贷款 + + 参数: + user_id: 用户id + amount: 贷款数量 + + 返回: + tuple[MahiroBank, float]: MahiroBank,贷款利率 + """ + rate_range = base_config.get("rate_range") + rate = random.uniform(rate_range[0], rate_range[1]) + sign_user = await SignUser.get_user(user_id) + user, _ = await MahiroBank.get_or_create(user_id=user_id) + if user.loan_amount + amount > sign_user.impression * 150: + raise ValueError("贷款数量超过最大限制,请签到提升好感度获取更多额度吧...") + random_reduce_rate = await cls.random_event(float(sign_user.impression)) + if random_reduce_rate: + rate -= random_reduce_rate + await UserConsole.add_gold(user_id, amount, "bank") + return await MahiroBank.loan(user_id, amount, rate), random_reduce_rate + + @classmethod + async def repayment(cls, user_id: str, amount: int) -> MahiroBank: + """还款 + + 参数: + user_id: 用户id + amount: 还款数量 + + 返回: + MahiroBank + """ + await UserConsole.reduce_gold(user_id, amount, GoldHandle.PLUGIN, "bank") + return await MahiroBank.repayment(user_id, amount) + + @classmethod + async def settlement(cls): + """结算每日利率""" + bank_user_list = await MahiroBank.filter(amount__gt=0).all() + log_list = await MahiroBankLog.filter( + is_completed=False, handle_type=BankHandleType.DEPOSIT + ).all() + user_list = await UserConsole.filter( + user_id__in=[user.user_id for user in bank_user_list] + ).all() + user_data = {user.user_id: user for user in user_list} + bank_data: 
dict[str, list[MahiroBankLog]] = {} + for log in log_list: + if log.user_id not in bank_data: + bank_data[log.user_id] = [] + bank_data[log.user_id].append(log) + log_create_list = [] + log_update_list = [] + # 计算每日默认金币 + for bank_user in bank_user_list: + if user := user_data.get(bank_user.user_id): + amount = bank_user.amount + if logs := bank_data.get(bank_user.user_id): + amount -= sum(log.amount for log in logs) + if not amount: + continue + # 计算每日默认金币 + gold = int(amount * bank_user.rate) + user.gold += gold + log_create_list.append( + MahiroBankLog( + user_id=bank_user.user_id, + amount=gold, + rate=bank_user.rate, + handle_type=BankHandleType.INTEREST, + is_completed=True, + ) + ) + # 计算每日存款金币 + for user_id, logs in bank_data.items(): + if user := user_data.get(user_id): + for log in logs: + gold = int(log.amount * log.rate * log.effective_hour) or 1 + user.gold += gold + log.is_completed = True + log_update_list.append(log) + log_create_list.append( + MahiroBankLog( + user_id=user_id, + amount=gold, + rate=log.rate, + handle_type=BankHandleType.INTEREST, + is_completed=True, + ) + ) + if log_create_list: + await MahiroBankLog.bulk_create(log_create_list, 10) + if log_update_list: + await MahiroBankLog.bulk_update(log_update_list, ["is_completed"], 10) + await UserConsole.bulk_update(user_list, ["gold"], 10) diff --git a/zhenxun/builtin_plugins/nickname.py b/zhenxun/builtin_plugins/nickname.py index 7dd9a697..5cbc519e 100644 --- a/zhenxun/builtin_plugins/nickname.py +++ b/zhenxun/builtin_plugins/nickname.py @@ -1,12 +1,17 @@ import random -from typing import Any -from nonebot import on_regex from nonebot.adapters import Bot -from nonebot.params import Depends, RegexGroup from nonebot.plugin import PluginMetadata from nonebot.rule import to_me -from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true +from nonebot_plugin_alconna import ( + Alconna, + Args, + Arparma, + CommandMeta, + Option, + on_alconna, + store_true, +) from 
nonebot_plugin_uninfo import Uninfo from zhenxun.configs.config import BotConfig, Config @@ -54,15 +59,22 @@ __plugin_meta__ = PluginMetadata( ).to_dict(), ) -_nickname_matcher = on_regex( - "(?:以后)?(?:叫我|请叫我|称呼我)(.*)", +_nickname_matcher = on_alconna( + Alconna( + "re:(?:以后)?(?:叫我|请叫我|称呼我)", + Args["name?", str], + meta=CommandMeta(compact=True), + ), rule=to_me(), priority=5, block=True, ) -_global_nickname_matcher = on_regex( - "设置全局昵称(.*)", rule=to_me(), priority=5, block=True +_global_nickname_matcher = on_alconna( + Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)), + rule=to_me(), + priority=5, + block=True, ) _matcher = on_alconna( @@ -117,34 +129,32 @@ CANCEL = [ ] -def CheckNickname(): +async def CheckNickname( + bot: Bot, + session: Uninfo, + params: Arparma, +): """ 检查名称是否合法 """ - - async def dependency( - bot: Bot, - session: Uninfo, - reg_group: tuple[Any, ...] = RegexGroup(), - ): - black_word = Config.get_config("nickname", "BLACK_WORD") - (name,) = reg_group - logger.debug(f"昵称检查: {name}", "昵称设置", session=session) - if not name: - await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish( - at_sender=True - ) - if session.user.id in bot.config.superusers: - logger.debug( - f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session - ) - return + black_word = Config.get_config("nickname", "BLACK_WORD") + name = params.query("name") + logger.debug(f"昵称检查: {name}", "昵称设置", session=session) + if not name: + await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish( + at_sender=True + ) + if session.user.id in bot.config.superusers: + logger.debug( + f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session + ) + else: if len(name) > 20: await MessageUtils.build_message("昵称可不能超过20个字!").finish( at_sender=True ) if name in bot.config.nickname: - await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish( + await MessageUtils.build_message("笨蛋!休想占用我的名字! 
").finish( at_sender=True ) if black_word: @@ -162,17 +172,17 @@ def CheckNickname(): await MessageUtils.build_message( f"字符 [{word}] 为禁止字符!" ).finish(at_sender=True) - - return Depends(dependency) + return name -@_nickname_matcher.handle(parameterless=[CheckNickname()]) +@_nickname_matcher.handle() async def _( + bot: Bot, session: Uninfo, + name_: Arparma, uname: str = UserName(), - reg_group: tuple[Any, ...] = RegexGroup(), ): - (name,) = reg_group + name = await CheckNickname(bot, session, name_) if len(name) < 5 and random.random() < 0.3: name = "~".join(name) group_id = None @@ -200,13 +210,14 @@ async def _( ) -@_global_nickname_matcher.handle(parameterless=[CheckNickname()]) +@_global_nickname_matcher.handle() async def _( + bot: Bot, session: Uninfo, + name_: Arparma, nickname: str = UserName(), - reg_group: tuple[Any, ...] = RegexGroup(), ): - (name,) = reg_group + name = await CheckNickname(bot, session, name_) await FriendUser.set_user_nickname( session.user.id, name, @@ -227,15 +238,14 @@ async def _(session: Uninfo, uname: str = UserName()): group_id = session.group.parent.id if session.group.parent else session.group.id if group_id: nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id) - card = uname else: nickname = await FriendUser.get_user_nickname(session.user.id) - card = uname if nickname: await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish( reply_to=True ) else: + card = uname await MessageUtils.build_message( random.choice( [ diff --git a/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py b/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py index d621f087..f4c28f04 100644 --- a/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py +++ b/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py @@ -1,4 +1,4 @@ -from nonebot import on_notice, on_request +from nonebot import on_notice from nonebot.adapters import Bot from nonebot.adapters.onebot.v11 import ( 
GroupDecreaseNoticeEvent, @@ -14,9 +14,10 @@ from nonebot_plugin_uninfo import Uninfo from zhenxun.builtin_plugins.platform.qq.exception import ForceAddGroupError from zhenxun.configs.config import BotConfig, Config from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task +from zhenxun.models.event_log import EventLog from zhenxun.models.group_console import GroupConsole from zhenxun.utils.common_utils import CommonUtils -from zhenxun.utils.enum import PluginType +from zhenxun.utils.enum import EventLogType, PluginType from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.rules import notice_rule @@ -106,8 +107,6 @@ group_decrease_handle = on_notice( rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]), ) """群员减少处理""" -add_group = on_request(priority=1, block=False) -"""加群同意请求""" @group_increase_handle.handle() @@ -141,8 +140,21 @@ async def _( group_id = str(event.group_id) if event.sub_type == "kick_me": """踢出Bot""" - await GroupManager.kick_bot(bot, user_id, group_id) + await GroupManager.kick_bot(bot, group_id, str(event.operator_id)) + await EventLog.create( + user_id=user_id, group_id=group_id, event_type=EventLogType.KICK_BOT + ) elif event.sub_type in ["leave", "kick"]: + if event.sub_type == "leave": + """主动退群""" + await EventLog.create( + user_id=user_id, group_id=group_id, event_type=EventLogType.LEAVE_MEMBER + ) + else: + """被踢出群""" + await EventLog.create( + user_id=user_id, group_id=group_id, event_type=EventLogType.KICK_MEMBER + ) result = await GroupManager.run_user( bot, user_id, group_id, str(event.operator_id), event.sub_type ) diff --git a/zhenxun/builtin_plugins/platform/qq/user_group_request.py b/zhenxun/builtin_plugins/platform/qq/user_group_request.py new file mode 100644 index 00000000..ae1d32ed --- /dev/null +++ b/zhenxun/builtin_plugins/platform/qq/user_group_request.py @@ -0,0 +1,100 @@ +import asyncio +from datetime import datetime +import random + +from nonebot.adapters import Bot +from 
nonebot.plugin import PluginMetadata +from nonebot.rule import to_me +from nonebot_plugin_alconna import Alconna, Args, Arparma, Field, on_alconna +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.configs.utils import PluginCdBlock, PluginExtraData +from zhenxun.models.fg_request import FgRequest +from zhenxun.services.log import logger +from zhenxun.utils.depends import UserName +from zhenxun.utils.enum import RequestHandleType, RequestType +from zhenxun.utils.platform import PlatformUtils + +__plugin_meta__ = PluginMetadata( + name="群组申请", + description=""" + 一些小群直接邀请入群导致无法正常生成审核请求,需要用该方法手动生成审核请求。 + 当管理员同意同意时会发送消息进行提示,之后再进行拉群不会退出。 + 该消息会发送至管理员,多次发送不存在的群组id或相同群组id可能导致ban。 + """.strip(), + usage=""" + 指令: + 申请入群 [群号] + 示例: 申请入群 123123123 + """.strip(), + extra=PluginExtraData( + author="HibiKier", + version="0.1", + menu_type="其他", + limits=[PluginCdBlock(cd=300, result="每5分钟只能申请一次哦~")], + ).to_dict(), +) + + +_matcher = on_alconna( + Alconna( + "申请入群", + Args[ + "group_id", + int, + Field( + missing_tips=lambda: "请在命令后跟随群组id!", + unmatch_tips=lambda _: "群组id必须为数字!", + ), + ], + ), + skip_for_unmatch=False, + priority=5, + block=True, + rule=to_me(), +) + + +@_matcher.handle() +async def _( + bot: Bot, session: Uninfo, arparma: Arparma, group_id: int, uname: str = UserName() +): + # 旧请求全部设置为过期 + await FgRequest.filter( + request_type=RequestType.GROUP, + user_id=session.user.id, + group_id=str(group_id), + handle_type__isnull=True, + ).update(handle_type=RequestHandleType.EXPIRE) + f = await FgRequest.create( + request_type=RequestType.GROUP, + platform=PlatformUtils.get_platform(session), + bot_id=bot.self_id, + flag="0", + user_id=session.user.id, + nickname=uname, + group_id=str(group_id), + ) + results = await PlatformUtils.send_superuser( + bot, + f"*****一份入群申请*****\n" + f"ID:{f.id}\n" + f"申请人:{uname}({session.user.id})\n群聊:" + f"{group_id}\n邀请日期:{datetime.now().replace(microsecond=0)}\n" + "注:该请求为手动申请入群", + ) + if message_ids := [ + 
str(r[1].msg_ids[0]["message_id"]) for r in results if r[1] and r[1].msg_ids + ]: + f.message_ids = ",".join(message_ids) + await f.save(update_fields=["message_ids"]) + await asyncio.sleep(random.randint(1, 5)) + await bot.send_private_msg( + user_id=int(session.user.id), + message=f"已发送申请,请等待管理员审核,ID:{f.id}。", + ) + logger.info( + f"用户 {uname}({session.user.id}) 申请入群 {group_id},ID:{f.id}。", + arparma.header_result, + session=session, + ) diff --git a/zhenxun/builtin_plugins/plugin_store/__init__.py b/zhenxun/builtin_plugins/plugin_store/__init__.py index 7e9f52a0..72d6d7dd 100644 --- a/zhenxun/builtin_plugins/plugin_store/__init__.py +++ b/zhenxun/builtin_plugins/plugin_store/__init__.py @@ -9,7 +9,7 @@ from zhenxun.utils.enum import PluginType from zhenxun.utils.message import MessageUtils from zhenxun.utils.utils import is_number -from .data_source import ShopManage +from .data_source import StoreManager __plugin_meta__ = PluginMetadata( name="插件商店", @@ -82,7 +82,7 @@ _matcher.shortcut( @_matcher.assign("$main") async def _(session: EventSession): try: - result = await ShopManage.get_plugins_info() + result = await StoreManager.get_plugins_info() logger.info("查看插件列表", "插件商店", session=session) await MessageUtils.build_message(result).send() except Exception as e: @@ -97,7 +97,7 @@ async def _(session: EventSession, plugin_id: str): await MessageUtils.build_message(f"正在添加插件 Id: {plugin_id}").send() else: await MessageUtils.build_message(f"正在添加插件 Module: {plugin_id}").send() - result = await ShopManage.add_plugin(plugin_id) + result = await StoreManager.add_plugin(plugin_id) except Exception as e: logger.error(f"添加插件 Id: {plugin_id}失败", "插件商店", session=session, e=e) await MessageUtils.build_message( @@ -110,7 +110,7 @@ async def _(session: EventSession, plugin_id: str): @_matcher.assign("remove") async def _(session: EventSession, plugin_id: str): try: - result = await ShopManage.remove_plugin(plugin_id) + result = await StoreManager.remove_plugin(plugin_id) 
except Exception as e: logger.error(f"移除插件 Id: {plugin_id}失败", "插件商店", session=session, e=e) await MessageUtils.build_message( @@ -123,7 +123,7 @@ async def _(session: EventSession, plugin_id: str): @_matcher.assign("search") async def _(session: EventSession, plugin_name_or_author: str): try: - result = await ShopManage.search_plugin(plugin_name_or_author) + result = await StoreManager.search_plugin(plugin_name_or_author) except Exception as e: logger.error( f"搜索插件 name: {plugin_name_or_author}失败", @@ -145,7 +145,7 @@ async def _(session: EventSession, plugin_id: str): await MessageUtils.build_message(f"正在更新插件 Id: {plugin_id}").send() else: await MessageUtils.build_message(f"正在更新插件 Module: {plugin_id}").send() - result = await ShopManage.update_plugin(plugin_id) + result = await StoreManager.update_plugin(plugin_id) except Exception as e: logger.error(f"更新插件 Id: {plugin_id}失败", "插件商店", session=session, e=e) await MessageUtils.build_message( @@ -159,7 +159,7 @@ async def _(session: EventSession, plugin_id: str): async def _(session: EventSession): try: await MessageUtils.build_message("正在更新全部插件").send() - result = await ShopManage.update_all_plugin() + result = await StoreManager.update_all_plugin() except Exception as e: logger.error("更新全部插件失败", "插件商店", session=session, e=e) await MessageUtils.build_message(f"更新全部插件失败 e: {e}").finish() diff --git a/zhenxun/builtin_plugins/plugin_store/config.py b/zhenxun/builtin_plugins/plugin_store/config.py index dacaffec..dd48a5c7 100644 --- a/zhenxun/builtin_plugins/plugin_store/config.py +++ b/zhenxun/builtin_plugins/plugin_store/config.py @@ -9,3 +9,5 @@ DEFAULT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins/tree/ma EXTRA_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/tree/index" """插件库索引github仓库地址""" + +LOG_COMMAND = "插件商店" diff --git a/zhenxun/builtin_plugins/plugin_store/data_source.py b/zhenxun/builtin_plugins/plugin_store/data_source.py index 6e662a81..b2dc96dc 100644 --- 
a/zhenxun/builtin_plugins/plugin_store/data_source.py +++ b/zhenxun/builtin_plugins/plugin_store/data_source.py @@ -1,6 +1,5 @@ from pathlib import Path import shutil -import subprocess from aiocache import cached import ujson as json @@ -14,9 +13,15 @@ from zhenxun.utils.github_utils import GithubUtils from zhenxun.utils.github_utils.models import RepoAPI from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle +from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager from zhenxun.utils.utils import is_number -from .config import BASE_PATH, DEFAULT_GITHUB_URL, EXTRA_GITHUB_URL +from .config import ( + BASE_PATH, + DEFAULT_GITHUB_URL, + EXTRA_GITHUB_URL, + LOG_COMMAND, +) def row_style(column: str, text: str) -> RowStyle: @@ -39,72 +44,69 @@ def install_requirement(plugin_path: Path): requirement_files = ["requirement.txt", "requirements.txt"] requirement_paths = [plugin_path / file for file in requirement_files] - existing_requirements = next( + if existing_requirements := next( (path for path in requirement_paths if path.exists()), None - ) - - if not existing_requirements: - logger.debug( - f"No requirement.txt found for plugin: {plugin_path.name}", "插件管理" - ) - return - - try: - result = subprocess.run( - ["poetry", "run", "pip", "install", "-r", str(existing_requirements)], - check=True, - capture_output=True, - text=True, - ) - logger.debug( - "Successfully installed dependencies for" - f" plugin: {plugin_path.name}. Output:\n{result.stdout}", - "插件管理", - ) - except subprocess.CalledProcessError: - logger.error( - f"Failed to install dependencies for plugin: {plugin_path.name}. 
" - " Error:\n{e.stderr}" - ) + ): + VirtualEnvPackageManager.install_requirement(existing_requirements) -class ShopManage: +class StoreManager: @classmethod - @cached(60) - async def get_data(cls) -> dict[str, StorePluginInfo]: - """获取插件信息数据 - - 异常: - ValueError: 访问请求失败 + async def get_github_plugins(cls) -> list[StorePluginInfo]: + """获取github插件列表信息 返回: - dict: 插件信息数据 + list[StorePluginInfo]: 插件列表数据 """ - default_github_repo = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL) - extra_github_repo = GithubUtils.parse_github_url(EXTRA_GITHUB_URL) - for repo_info in [default_github_repo, extra_github_repo]: - if await repo_info.update_repo_commit(): - logger.info(f"获取最新提交: {repo_info.branch}", "插件管理") - else: - logger.warning(f"获取最新提交失败: {repo_info}", "插件管理") - default_github_url = await default_github_repo.get_raw_download_urls( - "plugins.json" - ) - extra_github_url = await extra_github_repo.get_raw_download_urls("plugins.json") - res = await AsyncHttpx.get(default_github_url) - res2 = await AsyncHttpx.get(extra_github_url) + repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL) + if await repo_info.update_repo_commit(): + logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) + else: + logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) + default_github_url = await repo_info.get_raw_download_urls("plugins.json") + response = await AsyncHttpx.get(default_github_url, check_status_code=200) + if response.status_code == 200: + logger.info("获取github插件列表成功", LOG_COMMAND) + return [StorePluginInfo(**detail) for detail in json.loads(response.text)] + else: + logger.warning( + f"获取github插件列表失败: {response.status_code}", LOG_COMMAND + ) + return [] - # 检查请求结果 - if res.status_code != 200 or res2.status_code != 200: - raise ValueError(f"下载错误, code: {res.status_code}, {res2.status_code}") + @classmethod + async def get_extra_plugins(cls) -> list[StorePluginInfo]: + """获取额外插件列表信息 - # 解析并合并返回的 JSON 数据 - data1 = json.loads(res.text) - data2 = json.loads(res2.text) - 
return { - name: StorePluginInfo(**detail) - for name, detail in {**data1, **data2}.items() - } + 返回: + list[StorePluginInfo]: 插件列表数据 + """ + repo_info = GithubUtils.parse_github_url(EXTRA_GITHUB_URL) + if await repo_info.update_repo_commit(): + logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) + else: + logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) + extra_github_url = await repo_info.get_raw_download_urls("plugins.json") + response = await AsyncHttpx.get(extra_github_url, check_status_code=200) + if response.status_code == 200: + return [StorePluginInfo(**detail) for detail in json.loads(response.text)] + else: + logger.warning( + f"获取github扩展插件列表失败: {response.status_code}", LOG_COMMAND + ) + return [] + + @classmethod + @cached(60) + async def get_data(cls) -> list[StorePluginInfo]: + """获取插件信息数据 + + 返回: + list[StorePluginInfo]: 插件信息数据 + """ + plugins = await cls.get_github_plugins() + extra_plugins = await cls.get_extra_plugins() + return [*plugins, *extra_plugins] @classmethod def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]): @@ -112,7 +114,7 @@ class ShopManage: 参数: plugin_info: StorePluginInfo - suc_plugin: dict[str, str] + suc_plugin: 模块名: 版本号 返回: str: 版本号 @@ -132,7 +134,7 @@ class ShopManage: 参数: plugin_info: StorePluginInfo - suc_plugin: dict[str, str] + suc_plugin: 模块名: 版本号 返回: bool: 是否有更新 @@ -156,21 +158,21 @@ class ShopManage: 返回: BuildImage | str: 返回消息 """ - data: dict[str, StorePluginInfo] = await cls.get_data() + plugin_list: list[StorePluginInfo] = await cls.get_data() column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"] - plugin_list = await cls.get_loaded_plugins("module", "version") - suc_plugin = {p[0]: (p[1] or "0.1") for p in plugin_list} + db_plugin_list = await cls.get_loaded_plugins("module", "version") + suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list} data_list = [ [ - "已安装" if plugin_info[1].module in suc_plugin else "", + "已安装" if plugin_info.module in suc_plugin else "", id, 
- plugin_info[0], - plugin_info[1].description, - plugin_info[1].author, - cls.version_check(plugin_info[1], suc_plugin), - plugin_info[1].plugin_type_name, + plugin_info.name, + plugin_info.description, + plugin_info.author, + cls.version_check(plugin_info, suc_plugin), + plugin_info.plugin_type_name, ] - for id, plugin_info in enumerate(data.items()) + for id, plugin_info in enumerate(plugin_list) ] return await ImageTemplate.table_page( "插件列表", @@ -190,15 +192,15 @@ class ShopManage: 返回: str: 返回消息 """ - data: dict[str, StorePluginInfo] = await cls.get_data() + plugin_list: list[StorePluginInfo] = await cls.get_data() try: plugin_key = await cls._resolve_plugin_key(plugin_id) except ValueError as e: return str(e) - plugin_list = await cls.get_loaded_plugins("module") - plugin_info = data[plugin_key] - if plugin_info.module in [p[0] for p in plugin_list]: - return f"插件 {plugin_key} 已安装,无需重复安装" + db_plugin_list = await cls.get_loaded_plugins("module") + plugin_info = next(p for p in plugin_list if p.module == plugin_key) + if plugin_info.module in [p[0] for p in db_plugin_list]: + return f"插件 {plugin_info.name} 已安装,无需重复安装" is_external = True if plugin_info.github_url is None: plugin_info.github_url = DEFAULT_GITHUB_URL @@ -207,34 +209,39 @@ class ShopManage: if len(version_split) > 1: github_url_split = plugin_info.github_url.split("/tree/") plugin_info.github_url = f"{github_url_split[0]}/tree/{version_split[1]}" - logger.info(f"正在安装插件 {plugin_key}...") + logger.info(f"正在安装插件 {plugin_info.name}...", LOG_COMMAND) await cls.install_plugin_with_repo( plugin_info.github_url, plugin_info.module_path, plugin_info.is_dir, is_external, ) - return f"插件 {plugin_key} 安装成功! 重启后生效" + return f"插件 {plugin_info.name} 安装成功! 
重启后生效" @classmethod async def install_plugin_with_repo( - cls, github_url: str, module_path: str, is_dir: bool, is_external: bool = False + cls, + github_url: str, + module_path: str, + is_dir: bool, + is_external: bool = False, ): - files: list[str] repo_api: RepoAPI repo_info = GithubUtils.parse_github_url(github_url) if await repo_info.update_repo_commit(): - logger.info(f"获取最新提交: {repo_info.branch}", "插件管理") + logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) else: - logger.warning(f"获取最新提交失败: {repo_info}", "插件管理") - logger.debug(f"成功获取仓库信息: {repo_info}", "插件管理") + logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) + logger.debug(f"成功获取仓库信息: {repo_info}", LOG_COMMAND) for repo_api in GithubUtils.iter_api_strategies(): try: await repo_api.parse_repo_info(repo_info) break except Exception as e: logger.warning( - f"获取插件文件失败: {e} | API类型: {repo_api.strategy}", "插件管理" + f"获取插件文件失败 | API类型: {repo_api.strategy}", + LOG_COMMAND, + e=e, ) continue else: @@ -250,7 +257,7 @@ class ShopManage: base_path = BASE_PATH / "plugins" if is_external else BASE_PATH base_path = base_path if module_path else base_path / repo_info.repo download_paths: list[Path | str] = [base_path / file for file in files] - logger.debug(f"插件下载路径: {download_paths}", "插件管理") + logger.debug(f"插件下载路径: {download_paths}", LOG_COMMAND) result = await AsyncHttpx.gather_download_file(download_urls, download_paths) for _id, success in enumerate(result): if not success: @@ -265,12 +272,12 @@ class ShopManage: req_files.extend( repo_api.get_files(f"{replace_module_path}/requirement.txt", False) ) - logger.debug(f"获取插件依赖文件列表: {req_files}", "插件管理") + logger.debug(f"获取插件依赖文件列表: {req_files}", LOG_COMMAND) req_download_urls = [ await repo_info.get_raw_download_urls(file) for file in req_files ] req_paths: list[Path | str] = [plugin_path / file for file in req_files] - logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理") + logger.debug(f"插件依赖文件下载路径: {req_paths}", LOG_COMMAND) if req_files: result = await 
AsyncHttpx.gather_download_file( req_download_urls, req_paths @@ -278,7 +285,7 @@ class ShopManage: for success in result: if not success: raise Exception("插件依赖文件下载失败") - logger.debug(f"插件依赖文件列表: {req_paths}", "插件管理") + logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND) install_requirement(plugin_path) except ValueError as e: logger.warning("未获取到依赖文件路径...", e=e) @@ -295,12 +302,12 @@ class ShopManage: 返回: str: 返回消息 """ - data: dict[str, StorePluginInfo] = await cls.get_data() + plugin_list: list[StorePluginInfo] = await cls.get_data() try: plugin_key = await cls._resolve_plugin_key(plugin_id) except ValueError as e: return str(e) - plugin_info = data[plugin_key] + plugin_info = next(p for p in plugin_list if p.module == plugin_key) path = BASE_PATH if plugin_info.github_url: path = BASE_PATH / "plugins" @@ -309,14 +316,14 @@ class ShopManage: if not plugin_info.is_dir: path = Path(f"{path}.py") if not path.exists(): - return f"插件 {plugin_key} 不存在..." - logger.debug(f"尝试移除插件 {plugin_key} 文件: {path}", "插件管理") + return f"插件 {plugin_info.name} 不存在..." + logger.debug(f"尝试移除插件 {plugin_info.name} 文件: {path}", LOG_COMMAND) if plugin_info.is_dir: shutil.rmtree(path) else: path.unlink() await PluginInitManager.remove(f"zhenxun.{plugin_info.module_path}") - return f"插件 {plugin_key} 移除成功! 重启后生效" + return f"插件 {plugin_info.name} 移除成功! 
重启后生效" @classmethod async def search_plugin(cls, plugin_name_or_author: str) -> BuildImage | str: @@ -328,25 +335,25 @@ class ShopManage: 返回: BuildImage | str: 返回消息 """ - data: dict[str, StorePluginInfo] = await cls.get_data() - plugin_list = await cls.get_loaded_plugins("module", "version") - suc_plugin = {p[0]: (p[1] or "Unknown") for p in plugin_list} + plugin_list: list[StorePluginInfo] = await cls.get_data() + db_plugin_list = await cls.get_loaded_plugins("module", "version") + suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} filtered_data = [ (id, plugin_info) - for id, plugin_info in enumerate(data.items()) - if plugin_name_or_author.lower() in plugin_info[0].lower() - or plugin_name_or_author.lower() in plugin_info[1].author.lower() + for id, plugin_info in enumerate(plugin_list) + if plugin_name_or_author.lower() in plugin_info.name.lower() + or plugin_name_or_author.lower() in plugin_info.author.lower() ] data_list = [ [ - "已安装" if plugin_info[1].module in suc_plugin else "", + "已安装" if plugin_info.module in suc_plugin else "", id, - plugin_info[0], - plugin_info[1].description, - plugin_info[1].author, - cls.version_check(plugin_info[1], suc_plugin), - plugin_info[1].plugin_type_name, + plugin_info.name, + plugin_info.description, + plugin_info.author, + cls.version_check(plugin_info, suc_plugin), + plugin_info.plugin_type_name, ] for id, plugin_info in filtered_data ] @@ -354,7 +361,7 @@ class ShopManage: return "未找到相关插件..." 
column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"] return await ImageTemplate.table_page( - "插件列表", + "商店插件列表", "通过添加/移除插件 ID 来管理插件", column_name, data_list, @@ -371,20 +378,20 @@ class ShopManage: 返回: str: 返回消息 """ - data: dict[str, StorePluginInfo] = await cls.get_data() + plugin_list: list[StorePluginInfo] = await cls.get_data() try: plugin_key = await cls._resolve_plugin_key(plugin_id) except ValueError as e: return str(e) - logger.info(f"尝试更新插件 {plugin_key}", "插件管理") - plugin_info = data[plugin_key] - plugin_list = await cls.get_loaded_plugins("module", "version") - suc_plugin = {p[0]: (p[1] or "Unknown") for p in plugin_list} - if plugin_info.module not in [p[0] for p in plugin_list]: - return f"插件 {plugin_key} 未安装,无法更新" - logger.debug(f"当前插件列表: {suc_plugin}", "插件管理") + plugin_info = next(p for p in plugin_list if p.module == plugin_key) + logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND) + db_plugin_list = await cls.get_loaded_plugins("module", "version") + suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} + if plugin_info.module not in [p[0] for p in db_plugin_list]: + return f"插件 {plugin_info.name} 未安装,无法更新" + logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND) if cls.check_version_is_new(plugin_info, suc_plugin): - return f"插件 {plugin_key} 已是最新版本" + return f"插件 {plugin_info.name} 已是最新版本" is_external = True if plugin_info.github_url is None: plugin_info.github_url = DEFAULT_GITHUB_URL @@ -395,7 +402,7 @@ class ShopManage: plugin_info.is_dir, is_external, ) - return f"插件 {plugin_key} 更新成功! 重启后生效" + return f"插件 {plugin_info.name} 更新成功! 
重启后生效" @classmethod async def update_all_plugin(cls) -> str: @@ -407,24 +414,33 @@ class ShopManage: 返回: str: 返回消息 """ - data: dict[str, StorePluginInfo] = await cls.get_data() - plugin_list = list(data.keys()) + plugin_list: list[StorePluginInfo] = await cls.get_data() + plugin_name_list = [p.name for p in plugin_list] update_failed_list = [] update_success_list = [] result = "--已更新{}个插件 {}个失败 {}个成功--" - logger.info(f"尝试更新全部插件 {plugin_list}", "插件管理") - for plugin_key in plugin_list: + logger.info(f"尝试更新全部插件 {plugin_name_list}", LOG_COMMAND) + for plugin_info in plugin_list: try: - plugin_info = data[plugin_key] - plugin_list = await cls.get_loaded_plugins("module", "version") - suc_plugin = {p[0]: (p[1] or "Unknown") for p in plugin_list} - if plugin_info.module not in [p[0] for p in plugin_list]: - logger.debug(f"插件 {plugin_key} 未安装,跳过", "插件管理") + db_plugin_list = await cls.get_loaded_plugins("module", "version") + suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} + if plugin_info.module not in [p[0] for p in db_plugin_list]: + logger.debug( + f"插件 {plugin_info.name}({plugin_info.module}) 未安装,跳过", + LOG_COMMAND, + ) continue if cls.check_version_is_new(plugin_info, suc_plugin): - logger.debug(f"插件 {plugin_key} 已是最新版本,跳过", "插件管理") + logger.debug( + f"插件 {plugin_info.name}({plugin_info.module}) " + "已是最新版本,跳过", + LOG_COMMAND, + ) continue - logger.info(f"正在更新插件 {plugin_key}", "插件管理") + logger.info( + f"正在更新插件 {plugin_info.name}({plugin_info.module})", + LOG_COMMAND, + ) is_external = True if plugin_info.github_url is None: plugin_info.github_url = DEFAULT_GITHUB_URL @@ -435,10 +451,14 @@ class ShopManage: plugin_info.is_dir, is_external, ) - update_success_list.append(plugin_key) + update_success_list.append(plugin_info.name) except Exception as e: - logger.error(f"更新插件 {plugin_key} 失败: {e}", "插件管理") - update_failed_list.append(plugin_key) + logger.error( + f"更新插件 {plugin_info.name}({plugin_info.module}) 失败", + LOG_COMMAND, + e=e, + ) + 
update_failed_list.append(plugin_info.name) if not update_success_list and not update_failed_list: return "全部插件已是最新版本" if update_success_list: @@ -460,13 +480,28 @@ class ShopManage: @classmethod async def _resolve_plugin_key(cls, plugin_id: str) -> str: - data: dict[str, StorePluginInfo] = await cls.get_data() + """获取插件module + + 参数: + plugin_id: module,id或插件名称 + + 异常: + ValueError: 插件不存在 + ValueError: 插件不存在 + + 返回: + str: 插件模块名 + """ + plugin_list: list[StorePluginInfo] = await cls.get_data() if is_number(plugin_id): idx = int(plugin_id) - if idx < 0 or idx >= len(data): + if idx < 0 or idx >= len(plugin_list): raise ValueError("插件ID不存在...") - return list(data.keys())[idx] + return plugin_list[idx].module elif isinstance(plugin_id, str): - if plugin_id not in [v.module for k, v in data.items()]: - raise ValueError("插件Module不存在...") - return {v.module: k for k, v in data.items()}[plugin_id] + result = ( + None if plugin_id not in [v.module for v in plugin_list] else plugin_id + ) or next(v for v in plugin_list if v.name == plugin_id).module + if not result: + raise ValueError("插件 Module / 名称 不存在...") + return result diff --git a/zhenxun/builtin_plugins/plugin_store/models.py b/zhenxun/builtin_plugins/plugin_store/models.py index df65dd56..2bea1315 100644 --- a/zhenxun/builtin_plugins/plugin_store/models.py +++ b/zhenxun/builtin_plugins/plugin_store/models.py @@ -1,3 +1,5 @@ +from typing import Any, Literal + from nonebot.compat import model_dump from pydantic import BaseModel @@ -13,9 +15,30 @@ type2name: dict[str, str] = { } +class GiteeContents(BaseModel): + """Gitee Api内容""" + + type: Literal["file", "dir"] + """类型""" + size: Any + """文件大小""" + name: str + """文件名""" + path: str + """文件路径""" + url: str + """文件链接""" + html_url: str + """文件html链接""" + download_url: str + """文件raw链接""" + + class StorePluginInfo(BaseModel): """插件信息""" + name: str + """插件名""" module: str """模块名""" module_path: str diff --git a/zhenxun/builtin_plugins/record_request.py 
b/zhenxun/builtin_plugins/record_request.py index d4b0c694..32d5d551 100644 --- a/zhenxun/builtin_plugins/record_request.py +++ b/zhenxun/builtin_plugins/record_request.py @@ -17,11 +17,12 @@ from nonebot_plugin_session import EventSession from zhenxun.configs.config import BotConfig, Config from zhenxun.configs.utils import PluginExtraData, RegisterConfig +from zhenxun.models.event_log import EventLog from zhenxun.models.fg_request import FgRequest from zhenxun.models.friend_user import FriendUser from zhenxun.models.group_console import GroupConsole from zhenxun.services.log import logger -from zhenxun.utils.enum import PluginType, RequestHandleType, RequestType +from zhenxun.utils.enum import EventLogType, PluginType, RequestHandleType, RequestType from zhenxun.utils.platform import PlatformUtils base_config = Config.get("invite_manager") @@ -112,21 +113,29 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi nickname=nickname, comment=comment, ) - await PlatformUtils.send_superuser( + results = await PlatformUtils.send_superuser( bot, f"*****一份好友申请*****\n" f"ID: {f.id}\n" f"昵称:{nickname}({event.user_id})\n" f"自动同意:{'√' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n" - f"日期:{str(datetime.now()).split('.')[0]}\n" + f"日期:{datetime.now().replace(microsecond=0)}\n" f"备注:{event.comment}", ) + if message_ids := [ + str(r[1].msg_ids[0]["message_id"]) + for r in results + if r[1] and r[1].msg_ids + ]: + f.message_ids = ",".join(message_ids) + await f.save(update_fields=["message_ids"]) else: logger.debug("好友请求五分钟内重复, 已忽略", "好友请求", target=event.user_id) @group_req.handle() async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSession): + # sourcery skip: low-code-quality if event.sub_type != "invite": return if str(event.user_id) in bot.config.superusers or base_config.get("AUTO_ADD_GROUP"): @@ -186,7 +195,7 @@ async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSessio 
group_id=str(event.group_id), handle_type=RequestHandleType.APPROVE, ) - await PlatformUtils.send_superuser( + results = await PlatformUtils.send_superuser( bot, f"*****一份入群申请*****\n" f"ID:{f.id}\n" @@ -230,13 +239,27 @@ async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSessio nickname=nickname, group_id=str(event.group_id), ) - await PlatformUtils.send_superuser( + kick_count = await EventLog.filter( + group_id=str(event.group_id), event_type=EventLogType.KICK_BOT + ).count() + kick_message = ( + f"\n该群累计踢出{BotConfig.self_nickname} <{kick_count}>次" + if kick_count + else "" + ) + results = await PlatformUtils.send_superuser( bot, f"*****一份入群申请*****\n" f"ID:{f.id}\n" f"申请人:{nickname}({event.user_id})\n群聊:" - f"{event.group_id}\n邀请日期:{datetime.now().replace(microsecond=0)}", + f"{event.group_id}\n邀请日期:{datetime.now().replace(microsecond=0)}" + f"{kick_message}", ) + if message_ids := [ + str(r[1].msg_ids[0]["message_id"]) for r in results if r[1] and r[1].msg_ids + ]: + f.message_ids = ",".join(message_ids) + await f.save(update_fields=["message_ids"]) else: logger.debug( "群聊请求五分钟内重复, 已忽略", diff --git a/zhenxun/builtin_plugins/scheduler_admin/__init__.py b/zhenxun/builtin_plugins/scheduler_admin/__init__.py new file mode 100644 index 00000000..adaaa621 --- /dev/null +++ b/zhenxun/builtin_plugins/scheduler_admin/__init__.py @@ -0,0 +1,51 @@ +from nonebot.plugin import PluginMetadata + +from zhenxun.configs.utils import PluginExtraData +from zhenxun.utils.enum import PluginType + +from . 
import command # noqa: F401 + +__plugin_meta__ = PluginMetadata( + name="定时任务管理", + description="查看和管理由 SchedulerManager 控制的定时任务。", + usage=""" +📋 定时任务管理 - 支持群聊和私聊操作 + +🔍 查看任务: + 定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>] + • 群聊中: 查看本群任务 + • 私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER) + +📊 任务状态: + 定时任务 状态 <任务ID> 或 任务状态 <任务ID> + • 查看单个任务的详细信息和状态 + +⚙️ 任务管理 (SUPERUSER): + 定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>] + 定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all + 定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all + 定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all + 定时任务 执行 <任务ID> + 定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>] + +📝 时间选项 (三选一): + --cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *" + --interval <时间间隔> # 例: --interval 30m, 2h, 10s + --date "" # 例: --date "2024-01-01 08:00:00" + --daily "" # 例: --daily "08:30" + +📚 其他功能: + 定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER) + +🏷️ 别名支持: + 查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消 + 暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info + 更新: update, modify, 修改 | 插件列表: plugins + """.strip(), + extra=PluginExtraData( + author="HibiKier", + version="0.1.2", + plugin_type=PluginType.SUPERUSER, + is_show=False, + ).to_dict(), +) diff --git a/zhenxun/builtin_plugins/scheduler_admin/command.py b/zhenxun/builtin_plugins/scheduler_admin/command.py new file mode 100644 index 00000000..08a085fb --- /dev/null +++ b/zhenxun/builtin_plugins/scheduler_admin/command.py @@ -0,0 +1,836 @@ +import asyncio +from datetime import datetime +import re + +from nonebot.adapters import Event +from nonebot.adapters.onebot.v11 import Bot +from nonebot.params import Depends +from nonebot.permission import SUPERUSER +from nonebot_plugin_alconna import ( + Alconna, + AlconnaMatch, + Args, + Arparma, + Match, + Option, + Query, + Subcommand, + on_alconna, +) +from pydantic import BaseModel, ValidationError + +from zhenxun.utils._image_template import ImageTemplate +from zhenxun.utils.manager.schedule_manager import scheduler_manager + + +def 
_get_type_name(annotation) -> str: + """获取类型注解的名称""" + if hasattr(annotation, "__name__"): + return annotation.__name__ + elif hasattr(annotation, "_name"): + return annotation._name + else: + return str(annotation) + + +from zhenxun.utils.message import MessageUtils +from zhenxun.utils.rules import admin_check + + +def _format_trigger(schedule_status: dict) -> str: + """将触发器配置格式化为人类可读的字符串""" + trigger_type = schedule_status["trigger_type"] + config = schedule_status["trigger_config"] + + if trigger_type == "cron": + minute = config.get("minute", "*") + hour = config.get("hour", "*") + day = config.get("day", "*") + month = config.get("month", "*") + day_of_week = config.get("day_of_week", "*") + + if day == "*" and month == "*" and day_of_week == "*": + formatted_hour = hour if hour == "*" else f"{int(hour):02d}" + formatted_minute = minute if minute == "*" else f"{int(minute):02d}" + return f"每天 {formatted_hour}:{formatted_minute}" + else: + return f"Cron: {minute} {hour} {day} {month} {day_of_week}" + elif trigger_type == "interval": + seconds = config.get("seconds", 0) + minutes = config.get("minutes", 0) + hours = config.get("hours", 0) + days = config.get("days", 0) + if days: + trigger_str = f"每 {days} 天" + elif hours: + trigger_str = f"每 {hours} 小时" + elif minutes: + trigger_str = f"每 {minutes} 分钟" + else: + trigger_str = f"每 {seconds} 秒" + elif trigger_type == "date": + run_date = config.get("run_date", "未知时间") + trigger_str = f"在 {run_date}" + else: + trigger_str = f"{trigger_type}: {config}" + + return trigger_str + + +def _format_params(schedule_status: dict) -> str: + """将任务参数格式化为人类可读的字符串""" + if kwargs := schedule_status.get("job_kwargs"): + kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items()) + return kwargs_str + return "-" + + +def _parse_interval(interval_str: str) -> dict: + """增强版解析器,支持 d(天)""" + match = re.match(r"(\d+)([smhd])", interval_str.lower()) + if not match: + raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。") 
+ + value, unit = int(match.group(1)), match.group(2) + if unit == "s": + return {"seconds": value} + if unit == "m": + return {"minutes": value} + if unit == "h": + return {"hours": value} + if unit == "d": + return {"days": value} + return {} + + +def _parse_daily_time(time_str: str) -> dict: + """解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置""" + if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str): + hour, minute, second = match.groups() + hour, minute = int(hour), int(minute) + + if not (0 <= hour <= 23 and 0 <= minute <= 59): + raise ValueError("小时或分钟数值超出范围。") + + cron_config = { + "minute": str(minute), + "hour": str(hour), + "day": "*", + "month": "*", + "day_of_week": "*", + } + if second is not None: + if not (0 <= int(second) <= 59): + raise ValueError("秒数值超出范围。") + cron_config["second"] = str(second) + + return cron_config + else: + raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。") + + +async def GetBotId( + bot: Bot, + bot_id_match: Match[str] = AlconnaMatch("bot_id"), +) -> str: + """获取要操作的Bot ID""" + if bot_id_match.available: + return bot_id_match.result + return bot.self_id + + +class ScheduleTarget: + """定时任务操作目标的基类""" + + pass + + +class TargetByID(ScheduleTarget): + """按任务ID操作""" + + def __init__(self, id: int): + self.id = id + + +class TargetByPlugin(ScheduleTarget): + """按插件名操作""" + + def __init__( + self, plugin: str, group_id: str | None = None, all_groups: bool = False + ): + self.plugin = plugin + self.group_id = group_id + self.all_groups = all_groups + + +class TargetAll(ScheduleTarget): + """操作所有任务""" + + def __init__(self, for_group: str | None = None): + self.for_group = for_group + + +TargetScope = TargetByID | TargetByPlugin | TargetAll | None + + +def create_target_parser(subcommand_name: str): + """ + 创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。 + """ + + async def dependency( + event: Event, + schedule_id: Match[int] = AlconnaMatch("schedule_id"), + plugin_name: Match[str] = AlconnaMatch("plugin_name"), + group_id: Match[str] 
= AlconnaMatch("group_id"), + all_enabled: Query[bool] = Query(f"{subcommand_name}.all"), + ) -> TargetScope: + if schedule_id.available: + return TargetByID(schedule_id.result) + + if plugin_name.available: + p_name = plugin_name.result + if all_enabled.available: + return TargetByPlugin(plugin=p_name, all_groups=True) + elif group_id.available: + gid = group_id.result + if gid.lower() == "all": + return TargetByPlugin(plugin=p_name, all_groups=True) + return TargetByPlugin(plugin=p_name, group_id=gid) + else: + current_group_id = getattr(event, "group_id", None) + if current_group_id: + return TargetByPlugin(plugin=p_name, group_id=str(current_group_id)) + else: + await schedule_cmd.finish( + "私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。" + ) + + if all_enabled.available: + return TargetAll(for_group=group_id.result if group_id.available else None) + + return None + + return dependency + + +schedule_cmd = on_alconna( + Alconna( + "定时任务", + Subcommand( + "查看", + Option("-g", Args["target_group_id", str]), + Option("-all", help_text="查看所有群聊 (SUPERUSER)"), + Option("-p", Args["plugin_name", str], help_text="按插件名筛选"), + Option("--page", Args["page", int, 1], help_text="指定页码"), + alias=["ls", "list"], + help_text="查看定时任务", + ), + Subcommand( + "设置", + Args["plugin_name", str], + Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"), + Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"), + Option("--date", Args["date_expr", str], help_text="设置特定执行日期"), + Option( + "--daily", + Args["daily_expr", str], + help_text="设置每天执行的时间 (如 08:20)", + ), + Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"), + Option("-all", help_text="对所有群生效 (等同于 -g all)"), + Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["add", "开启"], + help_text="设置/开启一个定时任务", + ), + Subcommand( + "删除", + Args["schedule_id?", int], + Option("-p", Args["plugin_name", str], 
help_text="指定插件名"), + Option("-g", Args["group_id", str], help_text="指定群组ID"), + Option("-all", help_text="对所有群生效"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["del", "rm", "remove", "关闭", "取消"], + help_text="删除一个或多个定时任务", + ), + Subcommand( + "暂停", + Args["schedule_id?", int], + Option("-all", help_text="对当前群所有任务生效"), + Option("-p", Args["plugin_name", str], help_text="指定插件名"), + Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["pause"], + help_text="暂停一个或多个定时任务", + ), + Subcommand( + "恢复", + Args["schedule_id?", int], + Option("-all", help_text="对当前群所有任务生效"), + Option("-p", Args["plugin_name", str], help_text="指定插件名"), + Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["resume"], + help_text="恢复一个或多个定时任务", + ), + Subcommand( + "执行", + Args["schedule_id", int], + alias=["trigger", "run"], + help_text="立即执行一次任务", + ), + Subcommand( + "更新", + Args["schedule_id", int], + Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"), + Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"), + Option("--date", Args["date_expr", str], help_text="设置特定执行日期"), + Option( + "--daily", + Args["daily_expr", str], + help_text="更新每天执行的时间 (如 08:20)", + ), + Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"), + alias=["update", "modify", "修改"], + help_text="更新任务配置", + ), + Subcommand( + "状态", + Args["schedule_id", int], + alias=["status", "info"], + help_text="查看单个任务的详细状态", + ), + Subcommand( + "插件列表", + alias=["plugins"], + help_text="列出所有可用的插件", + ), + ), + priority=5, + block=True, + rule=admin_check(1), +) + +schedule_cmd.shortcut( + "任务状态", + command="定时任务", + arguments=["状态", "{%0}"], + prefix=True, +) + + +@schedule_cmd.handle() +async def _handle_time_options_mutex(arp: 
Arparma): + time_options = ["cron", "interval", "date", "daily"] + provided_options = [opt for opt in time_options if arp.query(opt) is not None] + if len(provided_options) > 1: + await schedule_cmd.finish( + f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。" + ) + + +@schedule_cmd.assign("查看") +async def _( + bot: Bot, + event: Event, + target_group_id: Match[str] = AlconnaMatch("target_group_id"), + all_groups: Query[bool] = Query("查看.all"), + plugin_name: Match[str] = AlconnaMatch("plugin_name"), + page: Match[int] = AlconnaMatch("page"), +): + is_superuser = await SUPERUSER(bot, event) + schedules = [] + title = "" + + current_group_id = getattr(event, "group_id", None) + if not (all_groups.available or target_group_id.available) and not current_group_id: + await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。") + + if all_groups.available: + if not is_superuser: + await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。") + schedules = await scheduler_manager.get_all_schedules() + title = "所有群组的定时任务" + elif target_group_id.available: + if not is_superuser: + await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。") + gid = target_group_id.result + schedules = [ + s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid + ] + title = f"群 {gid} 的定时任务" + else: + gid = str(current_group_id) + schedules = [ + s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid + ] + title = "本群的定时任务" + + if plugin_name.available: + schedules = [s for s in schedules if s.plugin_name == plugin_name.result] + title += f" [插件: {plugin_name.result}]" + + if not schedules: + await schedule_cmd.finish("没有找到任何相关的定时任务。") + + page_size = 15 + current_page = page.result + total_items = len(schedules) + total_pages = (total_items + page_size - 1) // page_size + start_index = (current_page - 1) * page_size + end_index = start_index + page_size + paginated_schedules = schedules[start_index:end_index] + + if not paginated_schedules: + await 
schedule_cmd.finish("这一页没有内容了哦~") + + status_tasks = [ + scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules + ] + all_statuses = await asyncio.gather(*status_tasks) + data_list = [ + [ + s["id"], + s["plugin_name"], + s.get("bot_id") or "N/A", + s["group_id"] or "全局", + s["next_run_time"], + _format_trigger(s), + _format_params(s), + "✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停", + ] + for s in all_statuses + if s + ] + + if not data_list: + await schedule_cmd.finish("没有找到任何相关的定时任务。") + + img = await ImageTemplate.table_page( + head_text=title, + tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务", + column_name=[ + "ID", + "插件", + "Bot ID", + "群组/目标", + "下次运行", + "触发规则", + "参数", + "状态", + ], + data_list=data_list, + column_space=20, + ) + await MessageUtils.build_message(img).send(reply_to=True) + + +@schedule_cmd.assign("设置") +async def _( + event: Event, + plugin_name: str, + cron_expr: str | None = None, + interval_expr: str | None = None, + date_expr: str | None = None, + daily_expr: str | None = None, + group_id: str | None = None, + kwargs_str: str | None = None, + all_enabled: Query[bool] = Query("设置.all"), + bot_id_to_operate: str = Depends(GetBotId), +): + if plugin_name not in scheduler_manager._registered_tasks: + await schedule_cmd.finish( + f"插件 '{plugin_name}' 没有注册可用的定时任务。\n" + f"可用插件: {list(scheduler_manager._registered_tasks.keys())}" + ) + + trigger_type = "" + trigger_config = {} + + try: + if cron_expr: + trigger_type = "cron" + parts = cron_expr.split() + if len(parts) != 5: + raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)") + cron_keys = ["minute", "hour", "day", "month", "day_of_week"] + trigger_config = dict(zip(cron_keys, parts)) + elif interval_expr: + trigger_type = "interval" + trigger_config = _parse_interval(interval_expr) + elif date_expr: + trigger_type = "date" + trigger_config = {"run_date": datetime.fromisoformat(date_expr)} + elif daily_expr: + trigger_type = "cron" + trigger_config = 
_parse_daily_time(daily_expr) + else: + await schedule_cmd.finish( + "必须提供一种时间选项: --cron, --interval, --date, 或 --daily。" + ) + except ValueError as e: + await schedule_cmd.finish(f"时间参数解析错误: {e}") + + job_kwargs = {} + if kwargs_str: + task_meta = scheduler_manager._registered_tasks[plugin_name] + params_model = task_meta.get("model") + if not params_model: + await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。") + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。") + + raw_kwargs = {} + try: + for item in kwargs_str.split(","): + key, value = item.strip().split("=", 1) + raw_kwargs[key.strip()] = value + except Exception as e: + await schedule_cmd.finish( + f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}" + ) + + try: + model_validate = getattr(params_model, "model_validate", None) + if not model_validate: + await schedule_cmd.finish( + f"插件 '{plugin_name}' 的参数模型不支持验证。" + ) + return + + validated_model = model_validate(raw_kwargs) + + model_dump = getattr(validated_model, "model_dump", None) + if not model_dump: + await schedule_cmd.finish( + f"插件 '{plugin_name}' 的参数模型不支持导出。" + ) + return + + job_kwargs = model_dump() + except ValidationError as e: + errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] + error_str = "\n".join(errors) + await schedule_cmd.finish( + f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}" + ) + return + + target_group_id: str | None + current_group_id = getattr(event, "group_id", None) + + if group_id and group_id.lower() == "all": + target_group_id = "__ALL_GROUPS__" + elif all_enabled.available: + target_group_id = "__ALL_GROUPS__" + elif group_id: + target_group_id = group_id + elif current_group_id: + target_group_id = str(current_group_id) + else: + await schedule_cmd.finish( + "私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。" + ) + return + + success, msg = await scheduler_manager.add_schedule( + plugin_name, + target_group_id, + 
trigger_type, + trigger_config, + job_kwargs, + bot_id=bot_id_to_operate, + ) + + if target_group_id == "__ALL_GROUPS__": + target_desc = f"所有群组 (Bot: {bot_id_to_operate})" + elif target_group_id is None: + target_desc = "全局" + else: + target_desc = f"群组 {target_group_id}" + + if success: + await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}") + else: + await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}") + + +@schedule_cmd.assign("删除") +async def _( + target: TargetScope = Depends(create_target_parser("删除")), + bot_id_to_operate: str = Depends(GetBotId), +): + if isinstance(target, TargetByID): + _, message = await scheduler_manager.remove_schedule_by_id(target.id) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetByPlugin): + p_name = target.plugin + if p_name not in scheduler_manager.get_registered_plugins(): + await schedule_cmd.finish(f"未找到插件 '{p_name}'。") + + if target.all_groups: + removed_count = await scheduler_manager.remove_schedule_for_all( + p_name, bot_id=bot_id_to_operate + ) + message = ( + f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。" + if removed_count > 0 + else f"没有找到插件 '{p_name}' 的定时任务。" + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.remove_schedule( + p_name, target.group_id, bot_id=bot_id_to_operate + ) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetAll): + if target.for_group: + _, message = await scheduler_manager.remove_schedules_by_group( + target.for_group + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.remove_all_schedules() + await schedule_cmd.finish(message) + + else: + await schedule_cmd.finish( + "删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。" + ) + + +@schedule_cmd.assign("暂停") +async def _( + target: TargetScope = Depends(create_target_parser("暂停")), + bot_id_to_operate: str = Depends(GetBotId), +): + if isinstance(target, TargetByID): + _, message = await 
scheduler_manager.pause_schedule(target.id) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetByPlugin): + p_name = target.plugin + if p_name not in scheduler_manager.get_registered_plugins(): + await schedule_cmd.finish(f"未找到插件 '{p_name}'。") + + if target.all_groups: + _, message = await scheduler_manager.pause_schedules_by_plugin(p_name) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.pause_schedule_by_plugin_group( + p_name, target.group_id, bot_id=bot_id_to_operate + ) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetAll): + if target.for_group: + _, message = await scheduler_manager.pause_schedules_by_group( + target.for_group + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.pause_all_schedules() + await schedule_cmd.finish(message) + + else: + await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。") + + +@schedule_cmd.assign("恢复") +async def _( + target: TargetScope = Depends(create_target_parser("恢复")), + bot_id_to_operate: str = Depends(GetBotId), +): + if isinstance(target, TargetByID): + _, message = await scheduler_manager.resume_schedule(target.id) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetByPlugin): + p_name = target.plugin + if p_name not in scheduler_manager.get_registered_plugins(): + await schedule_cmd.finish(f"未找到插件 '{p_name}'。") + + if target.all_groups: + _, message = await scheduler_manager.resume_schedules_by_plugin(p_name) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.resume_schedule_by_plugin_group( + p_name, target.group_id, bot_id=bot_id_to_operate + ) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetAll): + if target.for_group: + _, message = await scheduler_manager.resume_schedules_by_group( + target.for_group + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.resume_all_schedules() + await 
schedule_cmd.finish(message) + + else: + await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。") + + +@schedule_cmd.assign("执行") +async def _(schedule_id: int): + _, message = await scheduler_manager.trigger_now(schedule_id) + await schedule_cmd.finish(message) + + +@schedule_cmd.assign("更新") +async def _( + schedule_id: int, + cron_expr: str | None = None, + interval_expr: str | None = None, + date_expr: str | None = None, + daily_expr: str | None = None, + kwargs_str: str | None = None, +): + if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]): + await schedule_cmd.finish( + "请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)" + ) + + trigger_config = None + trigger_type = None + try: + if cron_expr: + trigger_type = "cron" + parts = cron_expr.split() + if len(parts) != 5: + raise ValueError("Cron 表达式必须有5个部分") + cron_keys = ["minute", "hour", "day", "month", "day_of_week"] + trigger_config = dict(zip(cron_keys, parts)) + elif interval_expr: + trigger_type = "interval" + trigger_config = _parse_interval(interval_expr) + elif date_expr: + trigger_type = "date" + trigger_config = {"run_date": datetime.fromisoformat(date_expr)} + elif daily_expr: + trigger_type = "cron" + trigger_config = _parse_daily_time(daily_expr) + except ValueError as e: + await schedule_cmd.finish(f"时间参数解析错误: {e}") + + job_kwargs = None + if kwargs_str: + schedule = await scheduler_manager.get_schedule_by_id(schedule_id) + if not schedule: + await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。") + + task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name) + if not task_meta or not (params_model := task_meta.get("model")): + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。" + ) + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 的参数模型配置错误。" + ) + + raw_kwargs = {} + try: + for item in kwargs_str.split(","): + key, value = 
item.strip().split("=", 1) + raw_kwargs[key.strip()] = value + except Exception as e: + await schedule_cmd.finish( + f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}" + ) + + try: + model_validate = getattr(params_model, "model_validate", None) + if not model_validate: + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。" + ) + return + + validated_model = model_validate(raw_kwargs) + + model_dump = getattr(validated_model, "model_dump", None) + if not model_dump: + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。" + ) + return + + job_kwargs = model_dump(exclude_unset=True) + except ValidationError as e: + errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] + error_str = "\n".join(errors) + await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}") + return + + _, message = await scheduler_manager.update_schedule( + schedule_id, trigger_type, trigger_config, job_kwargs + ) + await schedule_cmd.finish(message) + + +@schedule_cmd.assign("插件列表") +async def _(): + registered_plugins = scheduler_manager.get_registered_plugins() + if not registered_plugins: + await schedule_cmd.finish("当前没有已注册的定时任务插件。") + + message_parts = ["📋 已注册的定时任务插件:"] + for i, plugin_name in enumerate(registered_plugins, 1): + task_meta = scheduler_manager._registered_tasks[plugin_name] + params_model = task_meta.get("model") + + if not params_model: + message_parts.append(f"{i}. {plugin_name} - 无参数") + continue + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误") + continue + + model_fields = getattr(params_model, "model_fields", None) + if model_fields: + param_info = ", ".join( + f"{field_name}({_get_type_name(field_info.annotation)})" + for field_name, field_info in model_fields.items() + ) + message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}") + else: + message_parts.append(f"{i}. 
{plugin_name} - 无参数") + + await schedule_cmd.finish("\n".join(message_parts)) + + +@schedule_cmd.assign("状态") +async def _(schedule_id: int): + status = await scheduler_manager.get_schedule_status(schedule_id) + if not status: + await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。") + + info_lines = [ + f"📋 定时任务详细信息 (ID: {schedule_id})", + "--------------------", + f"▫️ 插件: {status['plugin_name']}", + f"▫️ Bot ID: {status.get('bot_id') or '默认'}", + f"▫️ 目标: {status['group_id'] or '全局'}", + f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}", + f"▫️ 下次运行: {status['next_run_time']}", + f"▫️ 触发规则: {_format_trigger(status)}", + f"▫️ 任务参数: {_format_params(status)}", + ] + await schedule_cmd.finish("\n".join(info_lines)) diff --git a/zhenxun/builtin_plugins/scripts.py b/zhenxun/builtin_plugins/scripts.py index 27705301..b5fca300 100644 --- a/zhenxun/builtin_plugins/scripts.py +++ b/zhenxun/builtin_plugins/scripts.py @@ -1,67 +1,8 @@ -from asyncio.exceptions import TimeoutError - -import aiofiles -import nonebot -from nonebot.drivers import Driver -from nonebot_plugin_apscheduler import scheduler -import ujson as json - -from zhenxun.configs.path_config import TEXT_PATH from zhenxun.models.group_console import GroupConsole -from zhenxun.services.log import logger -from zhenxun.utils.http_utils import AsyncHttpx - -driver: Driver = nonebot.get_driver() +from zhenxun.utils.manager.priority_manager import PriorityLifecycle -@driver.on_startup -async def update_city(): - """ - 部分插件需要中国省份城市 - 这里直接更新,避免插件内代码重复 - """ - china_city = TEXT_PATH / "china_city.json" - if not china_city.exists(): - data = {} - try: - logger.debug("开始更新城市列表...") - res = await AsyncHttpx.get( - "http://www.weather.com.cn/data/city3jdata/china.html", timeout=5 - ) - res.encoding = "utf8" - provinces_data = json.loads(res.text) - for province in provinces_data.keys(): - data[provinces_data[province]] = [] - res = await AsyncHttpx.get( - 
f"http://www.weather.com.cn/data/city3jdata/provshi/{province}.html", - timeout=5, - ) - res.encoding = "utf8" - city_data = json.loads(res.text) - for city in city_data.keys(): - data[provinces_data[province]].append(city_data[city]) - async with aiofiles.open(china_city, "w", encoding="utf8") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - logger.info("自动更新城市列表完成.....") - except TimeoutError as e: - logger.warning("自动更新城市列表超时...", e=e) - except ValueError as e: - logger.warning("自动城市列表失败.....", e=e) - except Exception as e: - logger.error("自动城市列表未知错误", e=e) - - -# 自动更新城市列表 -@scheduler.scheduled_job( - "cron", - hour=6, - minute=1, -) -async def _(): - await update_city() - - -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """开启/禁用插件格式修改""" _, is_create = await GroupConsole.get_or_create(group_id=133133133) diff --git a/zhenxun/builtin_plugins/shop/__init__.py b/zhenxun/builtin_plugins/shop/__init__.py index 89282d63..120d2198 100644 --- a/zhenxun/builtin_plugins/shop/__init__.py +++ b/zhenxun/builtin_plugins/shop/__init__.py @@ -5,7 +5,9 @@ from nonebot_plugin_alconna import ( AlconnaQuery, Args, Arparma, + At, Match, + MultiVar, Option, Query, Subcommand, @@ -33,6 +35,7 @@ __plugin_meta__ = PluginMetadata( usage=""" 商品操作 指令: + 商店 我的金币 我的道具 使用道具 [名称/Id] @@ -46,6 +49,7 @@ __plugin_meta__ = PluginMetadata( plugin_type=PluginType.NORMAL, menu_type="商店", commands=[ + Command(command="商店"), Command(command="我的金币"), Command(command="我的道具"), Command(command="购买道具"), @@ -74,13 +78,21 @@ _matcher = on_alconna( Subcommand("my-cost", help_text="我的金币"), Subcommand("my-props", help_text="我的道具"), Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"), - Subcommand("use", Args["name?", str]["num?", int], help_text="使用道具"), Subcommand("gold-list", Args["num?", int], help_text="金币排行"), ), priority=5, block=True, ) +_use_matcher = on_alconna( + Alconna( + "使用道具", + Args["name?", str]["num?", int]["at_users?", MultiVar(At)], + 
), + priority=5, + block=True, +) + _matcher.shortcut( "我的金币", command="商店", @@ -102,13 +114,6 @@ _matcher.shortcut( prefix=True, ) -_matcher.shortcut( - "使用道具(?P.*?)", - command="商店", - arguments=["use", "{name}"], - prefix=True, -) - _matcher.shortcut( "金币排行", command="商店", @@ -172,7 +177,7 @@ async def _( await MessageUtils.build_message(result).send(reply_to=True) -@_matcher.assign("use") +@_use_matcher.handle() async def _( bot: Bot, event: Event, @@ -181,6 +186,7 @@ async def _( arparma: Arparma, name: Match[str], num: Query[int] = AlconnaQuery("num", 1), + at_users: Query[list[At]] = AlconnaQuery("at_users", []), ): if not name.available: await MessageUtils.build_message( @@ -188,7 +194,7 @@ async def _( ).finish(reply_to=True) try: result = await ShopManage.use( - bot, event, session, message, name.result, num.result, "" + bot, event, session, message, name.result, num.result, "", at_users.result ) logger.info( f"使用道具 {name.result}, 数量: {num.result}", diff --git a/zhenxun/builtin_plugins/shop/_data_source.py b/zhenxun/builtin_plugins/shop/_data_source.py index 0fdd4e53..682bd85e 100644 --- a/zhenxun/builtin_plugins/shop/_data_source.py +++ b/zhenxun/builtin_plugins/shop/_data_source.py @@ -8,7 +8,7 @@ from typing import Any, Literal from nonebot.adapters import Bot, Event from nonebot.compat import model_dump -from nonebot_plugin_alconna import UniMessage, UniMsg +from nonebot_plugin_alconna import At, UniMessage, UniMsg from nonebot_plugin_uninfo import Uninfo from pydantic import BaseModel, Field, create_model from tortoise.expressions import Q @@ -48,6 +48,10 @@ class Goods(BaseModel): """model""" session: Uninfo | None = None """Uninfo""" + at_user: str | None = None + """At对象""" + at_users: list[str] = [] + """At对象列表""" class ShopParam(BaseModel): @@ -73,6 +77,10 @@ class ShopParam(BaseModel): """Uninfo""" message: UniMsg """UniMessage""" + at_user: str | None = None + """At对象""" + at_users: list[str] = [] + """At对象列表""" extra_data: dict[str, Any] = 
Field(default_factory=dict) """额外数据""" @@ -156,6 +164,7 @@ class ShopManage: goods: Goods, num: int, text: str, + at_users: list[str] = [], ) -> tuple[ShopParam, dict[str, Any]]: """构造参数 @@ -165,6 +174,7 @@ class ShopManage: goods_name: 商品名称 num: 数量 text: 其他信息 + at_users: at用户 """ group_id = None if session.group: @@ -172,6 +182,7 @@ class ShopManage: session.group.parent.id if session.group.parent else session.group.id ) _kwargs = goods.params + at_user = at_users[0] if at_users else None model = goods.model( **{ "goods_name": goods.name, @@ -183,6 +194,8 @@ class ShopManage: "text": text, "session": session, "message": message, + "at_user": at_user, + "at_users": at_users, } ) return model, { @@ -194,6 +207,8 @@ class ShopManage: "num": num, "text": text, "goods_name": goods.name, + "at_user": at_user, + "at_users": at_users, } @classmethod @@ -223,6 +238,7 @@ class ShopManage: **param.extra_data, "session": session, "message": message, + "shop_param": ShopParam, } for key in list(param_json.keys()): if key not in args: @@ -308,6 +324,7 @@ class ShopManage: goods_name: str, num: int, text: str, + at_users: list[At] = [], ) -> str | UniMessage | None: """使用道具 @@ -319,6 +336,7 @@ class ShopManage: goods_name: 商品名称 num: 使用数量 text: 其他信息 + at_users: at用户 返回: str | MessageFactory | None: 使用完成后返回信息 @@ -339,16 +357,18 @@ class ShopManage: goods = cls.uuid2goods.get(goods_info.uuid) if not goods or not goods.func: return f"{goods_info.goods_name} 未注册使用函数, 无法使用..." + at_user_ids = [at.target for at in at_users] param, kwargs = cls.__build_params( - bot, event, session, message, goods, num, text + bot, event, session, message, goods, num, text, at_user_ids ) if num > param.max_num_limit: return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..." 
await cls.run_before_after(goods, param, session, message, "before", **kwargs) - result = await cls.__run(goods, param, session, message, **kwargs) await UserConsole.use_props( session.user.id, goods_info.uuid, num, PlatformUtils.get_platform(session) ) + result = await cls.__run(goods, param, session, message, **kwargs) + await cls.run_before_after(goods, param, session, message, "after", **kwargs) if not result and param.send_success_msg: result = f"使用道具 {goods.name} {num} 次成功!" @@ -479,10 +499,13 @@ class ShopManage: if not user.props: return None - user.props = {uuid: count for uuid, count in user.props.items() if count > 0} - goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all() goods_by_uuid = {item.uuid: item for item in goods_list} + user.props = { + uuid: count + for uuid, count in user.props.items() + if count > 0 and goods_by_uuid.get(uuid) + } table_rows = [] for i, prop_uuid in enumerate(user.props): diff --git a/zhenxun/builtin_plugins/sign_in/__init__.py b/zhenxun/builtin_plugins/sign_in/__init__.py index 0b48a0e7..0986e476 100644 --- a/zhenxun/builtin_plugins/sign_in/__init__.py +++ b/zhenxun/builtin_plugins/sign_in/__init__.py @@ -10,7 +10,6 @@ from nonebot_plugin_alconna import ( store_true, ) from nonebot_plugin_apscheduler import scheduler -from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.utils import ( Command, @@ -23,7 +22,7 @@ from zhenxun.utils.depends import UserName from zhenxun.utils.message import MessageUtils from ._data_source import SignManage -from .goods_register import driver # noqa: F401 +from .goods_register import Uninfo from .utils import clear_sign_data_pic __plugin_meta__ = PluginMetadata( diff --git a/zhenxun/builtin_plugins/sign_in/goods_register.py b/zhenxun/builtin_plugins/sign_in/goods_register.py index f7a65359..6c8e39bb 100644 --- a/zhenxun/builtin_plugins/sign_in/goods_register.py +++ b/zhenxun/builtin_plugins/sign_in/goods_register.py @@ -1,7 +1,6 @@ from decimal import Decimal import 
nonebot -from nonebot.drivers import Driver from nonebot_plugin_uninfo import Uninfo from zhenxun.models.sign_user import SignUser @@ -9,14 +8,7 @@ from zhenxun.models.user_console import UserConsole from zhenxun.utils.decorator.shop import shop_register from zhenxun.utils.platform import PlatformUtils -driver: Driver = nonebot.get_driver() - - -# @driver.on_startup -# async def _(): -# """ -# 导入内置的三个商品 -# """ +driver = nonebot.get_driver() @shop_register( diff --git a/zhenxun/builtin_plugins/sign_in/utils.py b/zhenxun/builtin_plugins/sign_in/utils.py index 9faf1120..910b90d8 100644 --- a/zhenxun/builtin_plugins/sign_in/utils.py +++ b/zhenxun/builtin_plugins/sign_in/utils.py @@ -16,6 +16,7 @@ from zhenxun.models.sign_log import SignLog from zhenxun.models.sign_user import SignUser from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.image_utils import BuildImage +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils from .config import ( @@ -54,7 +55,7 @@ LG_MESSAGE = [ ] -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def init_image(): SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True) SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True) diff --git a/zhenxun/builtin_plugins/statistics/statistics_hook.py b/zhenxun/builtin_plugins/statistics/statistics_hook.py index f3776ece..3ac15e2a 100644 --- a/zhenxun/builtin_plugins/statistics/statistics_hook.py +++ b/zhenxun/builtin_plugins/statistics/statistics_hook.py @@ -53,10 +53,7 @@ async def _( ) -@scheduler.scheduled_job( - "interval", - minutes=1, -) +@scheduler.scheduled_job("interval", minutes=1, max_instances=5) async def _(): try: call_list = TEMP_LIST.copy() diff --git a/zhenxun/builtin_plugins/superuser/bot_manage/plugin.py b/zhenxun/builtin_plugins/superuser/bot_manage/plugin.py index df6d7f35..c5359951 100644 --- a/zhenxun/builtin_plugins/superuser/bot_manage/plugin.py +++ 
b/zhenxun/builtin_plugins/superuser/bot_manage/plugin.py @@ -110,7 +110,7 @@ async def enable_plugin( ) await BotConsole.enable_plugin(None, plugin.module) await MessageUtils.build_message( - f"已禁用全部 bot 的插件: {plugin_name.result}" + f"已开启全部 bot 的插件: {plugin_name.result}" ).finish() elif bot_id.available: logger.info( diff --git a/zhenxun/builtin_plugins/superuser/bot_manage/task.py b/zhenxun/builtin_plugins/superuser/bot_manage/task.py index 005ab188..501aec3d 100644 --- a/zhenxun/builtin_plugins/superuser/bot_manage/task.py +++ b/zhenxun/builtin_plugins/superuser/bot_manage/task.py @@ -92,7 +92,7 @@ async def enable_task( ) await BotConsole.enable_task(None, task.module) await MessageUtils.build_message( - f"已禁用全部 bot 的被动: {task_name.available}" + f"已开启全部 bot 的被动: {task_name.available}" ).finish() elif bot_id.available: logger.info( diff --git a/zhenxun/builtin_plugins/superuser/broadcast/__init__.py b/zhenxun/builtin_plugins/superuser/broadcast/__init__.py index c025fd0c..3fc08e4c 100644 --- a/zhenxun/builtin_plugins/superuser/broadcast/__init__.py +++ b/zhenxun/builtin_plugins/superuser/broadcast/__init__.py @@ -1,32 +1,77 @@ -from typing import Annotated - -from nonebot import on_command -from nonebot.adapters import Bot -from nonebot.params import Command +from arclet.alconna import AllParam +from nepattern import UnionPattern +from nonebot.adapters import Bot, Event from nonebot.permission import SUPERUSER from nonebot.plugin import PluginMetadata from nonebot.rule import to_me -from nonebot_plugin_alconna import Text as alcText -from nonebot_plugin_alconna import UniMsg +import nonebot_plugin_alconna as alc +from nonebot_plugin_alconna import ( + Alconna, + Args, + on_alconna, +) +from nonebot_plugin_alconna.uniseg.segment import ( + At, + AtAll, + Audio, + Button, + Emoji, + File, + Hyper, + Image, + Keyboard, + Reference, + Reply, + Text, + Video, + Voice, +) from nonebot_plugin_session import EventSession from zhenxun.configs.utils import PluginExtraData, 
RegisterConfig, Task -from zhenxun.services.log import logger from zhenxun.utils.enum import PluginType from zhenxun.utils.message import MessageUtils -from ._data_source import BroadcastManage +from .broadcast_manager import BroadcastManager +from .message_processor import ( + _extract_broadcast_content, + get_broadcast_target_groups, + send_broadcast_and_notify, +) + +BROADCAST_SEND_DELAY_RANGE = (1, 3) __plugin_meta__ = PluginMetadata( name="广播", description="昭告天下!", usage=""" - 广播 [消息] [图片] - 示例:广播 你们好! + 广播 [消息内容] + - 直接发送消息到除当前群组外的所有群组 + - 支持文本、图片、@、表情、视频等多种消息类型 + - 示例:广播 你们好! + - 示例:广播 [图片] 新活动开始啦! + + 广播 + 引用消息 + - 将引用的消息作为广播内容发送 + - 支持引用普通消息或合并转发消息 + - 示例:(引用一条消息) 广播 + + 广播撤回 + - 撤回最近一次由您触发的广播消息 + - 仅能撤回短时间内的消息 + - 示例:广播撤回 + + 特性: + - 在群组中使用广播时,不会将消息发送到当前群组 + - 在私聊中使用广播时,会发送到所有群组 + + 别名: + - bc (广播的简写) + - recall (广播撤回的别名) """.strip(), extra=PluginExtraData( author="HibiKier", - version="0.1", + version="1.2", plugin_type=PluginType.SUPERUSER, configs=[ RegisterConfig( @@ -42,26 +87,106 @@ __plugin_meta__ = PluginMetadata( ).to_dict(), ) -_matcher = on_command( - "广播", priority=1, permission=SUPERUSER, block=True, rule=to_me() +AnySeg = ( + UnionPattern( + [ + Text, + Image, + At, + AtAll, + Audio, + Video, + File, + Emoji, + Reply, + Reference, + Hyper, + Button, + Keyboard, + Voice, + ] + ) + @ "AnySeg" +) + +_matcher = on_alconna( + Alconna( + "广播", + Args["content?", AllParam], + ), + aliases={"bc"}, + priority=1, + permission=SUPERUSER, + block=True, + rule=to_me(), + use_origin=False, +) + +_recall_matcher = on_alconna( + Alconna("广播撤回"), + aliases={"recall"}, + priority=1, + permission=SUPERUSER, + block=True, + rule=to_me(), ) @_matcher.handle() -async def _( +async def handle_broadcast( bot: Bot, + event: Event, session: EventSession, - message: UniMsg, - command: Annotated[tuple[str, ...], Command()], + arp: alc.Arparma, ): - for msg in message: - if isinstance(msg, alcText) and msg.text.strip().startswith(command[0]): - msg.text = 
msg.text.replace(command[0], "", 1).strip() - break - await MessageUtils.build_message("正在发送..请等一下哦!").send() - count, error_count = await BroadcastManage.send(bot, message, session) - result = f"成功广播 {count} 个群组" - if error_count: - result += f"\n广播失败 {error_count} 个群组" - await MessageUtils.build_message(f"发送广播完成!\n{result}").send(reply_to=True) - logger.info(f"发送广播信息: {message}", "广播", session=session) + broadcast_content_msg = await _extract_broadcast_content(bot, event, arp, session) + if not broadcast_content_msg: + return + + target_groups, enabled_groups = await get_broadcast_target_groups(bot, session) + if not target_groups or not enabled_groups: + return + + try: + await send_broadcast_and_notify( + bot, event, broadcast_content_msg, enabled_groups, target_groups, session + ) + except Exception as e: + error_msg = "发送广播失败" + BroadcastManager.log_error(error_msg, e, session) + await MessageUtils.build_message(f"{error_msg}。").send(reply_to=True) + + +@_recall_matcher.handle() +async def handle_broadcast_recall( + bot: Bot, + event: Event, + session: EventSession, +): + """处理广播撤回命令""" + await MessageUtils.build_message("正在尝试撤回最近一次广播...").send() + + try: + success_count, error_count = await BroadcastManager.recall_last_broadcast( + bot, session + ) + + user_id = str(event.get_user_id()) + if success_count == 0 and error_count == 0: + await bot.send_private_msg( + user_id=user_id, + message="没有找到最近的广播消息记录,可能已经撤回或超过可撤回时间。", + ) + else: + result = f"广播撤回完成!\n成功撤回 {success_count} 条消息" + if error_count: + result += f"\n撤回失败 {error_count} 条消息 (可能已过期或无权限)" + await bot.send_private_msg(user_id=user_id, message=result) + BroadcastManager.log_info( + f"广播撤回完成: 成功 {success_count}, 失败 {error_count}", session + ) + except Exception as e: + error_msg = "撤回广播消息失败" + BroadcastManager.log_error(error_msg, e, session) + user_id = str(event.get_user_id()) + await bot.send_private_msg(user_id=user_id, message=f"{error_msg}。") diff --git 
a/zhenxun/builtin_plugins/superuser/broadcast/_data_source.py b/zhenxun/builtin_plugins/superuser/broadcast/_data_source.py deleted file mode 100644 index 1ee1a28c..00000000 --- a/zhenxun/builtin_plugins/superuser/broadcast/_data_source.py +++ /dev/null @@ -1,72 +0,0 @@ -import asyncio -import random - -from nonebot.adapters import Bot -import nonebot_plugin_alconna as alc -from nonebot_plugin_alconna import Image, UniMsg -from nonebot_plugin_session import EventSession - -from zhenxun.services.log import logger -from zhenxun.utils.common_utils import CommonUtils -from zhenxun.utils.message import MessageUtils -from zhenxun.utils.platform import PlatformUtils - - -class BroadcastManage: - @classmethod - async def send( - cls, bot: Bot, message: UniMsg, session: EventSession - ) -> tuple[int, int]: - """发送广播消息 - - 参数: - bot: Bot - message: 消息内容 - session: Session - - 返回: - tuple[int, int]: 发送成功的群组数量, 发送失败的群组数量 - """ - message_list = [] - for msg in message: - if isinstance(msg, alc.Image) and msg.url: - message_list.append(Image(url=msg.url)) - elif isinstance(msg, alc.Text): - message_list.append(msg.text) - group_list, _ = await PlatformUtils.get_group_list(bot) - if group_list: - error_count = 0 - for group in group_list: - try: - if not await CommonUtils.task_is_block( - bot, - "broadcast", # group.channel_id - group.group_id, - ): - target = PlatformUtils.get_target( - group_id=group.group_id, channel_id=group.channel_id - ) - if target: - await MessageUtils.build_message(message_list).send( - target, bot - ) - logger.debug( - "发送成功", - "广播", - session=session, - target=f"{group.group_id}:{group.channel_id}", - ) - await asyncio.sleep(random.randint(1, 3)) - else: - logger.warning("target为空", "广播", session=session) - except Exception as e: - error_count += 1 - logger.error( - "发送失败", - "广播", - session=session, - target=f"{group.group_id}:{group.channel_id}", - e=e, - ) - return len(group_list) - error_count, error_count - return 0, 0 diff --git 
a/zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py b/zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py new file mode 100644 index 00000000..c3d7b5cc --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py @@ -0,0 +1,490 @@ +import asyncio +import random +import traceback +from typing import ClassVar + +from nonebot.adapters import Bot +from nonebot.adapters.onebot.v11 import Bot as V11Bot +from nonebot.exception import ActionFailed +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_alconna.uniseg import Receipt, Reference +from nonebot_plugin_session import EventSession + +from zhenxun.models.group_console import GroupConsole +from zhenxun.services.log import logger +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.platform import PlatformUtils + +from .models import BroadcastDetailResult, BroadcastResult +from .utils import custom_nodes_to_v11_nodes, uni_message_to_v11_list_of_dicts + + +class BroadcastManager: + """广播管理器""" + + _last_broadcast_msg_ids: ClassVar[dict[str, int]] = {} + + @staticmethod + def _get_session_info(session: EventSession | None) -> str: + """获取会话信息字符串""" + if not session: + return "" + + try: + platform = getattr(session, "platform", "unknown") + session_id = str(session) + return f"[{platform}:{session_id}]" + except Exception: + return "[session-info-error]" + + @staticmethod + def log_error( + message: str, error: Exception, session: EventSession | None = None, **kwargs + ): + """记录错误日志""" + session_info = BroadcastManager._get_session_info(session) + error_type = type(error).__name__ + stack_trace = traceback.format_exc() + error_details = f"\n类型: {error_type}\n信息: {error!s}\n堆栈: {stack_trace}" + + logger.error( + f"{session_info} {message}{error_details}", "广播", e=error, **kwargs + ) + + @staticmethod + def log_warning(message: str, session: EventSession | None = None, **kwargs): + """记录警告级别日志""" + session_info = 
BroadcastManager._get_session_info(session) + logger.warning(f"{session_info} {message}", "广播", **kwargs) + + @staticmethod + def log_info(message: str, session: EventSession | None = None, **kwargs): + """记录信息级别日志""" + session_info = BroadcastManager._get_session_info(session) + logger.info(f"{session_info} {message}", "广播", **kwargs) + + @classmethod + def get_last_broadcast_msg_ids(cls) -> dict[str, int]: + """获取最近广播消息ID""" + return cls._last_broadcast_msg_ids.copy() + + @classmethod + def clear_last_broadcast_msg_ids(cls) -> None: + """清空消息ID记录""" + cls._last_broadcast_msg_ids.clear() + + @classmethod + async def get_all_groups(cls, bot: Bot) -> tuple[list[GroupConsole], str]: + """获取群组列表""" + return await PlatformUtils.get_group_list(bot) + + @classmethod + async def send( + cls, bot: Bot, message: UniMessage, session: EventSession + ) -> BroadcastResult: + """发送广播到所有群组""" + logger.debug( + f"开始广播(send - 广播到所有群组),Bot ID: {bot.self_id}", + "广播", + session=session, + ) + + logger.debug("清空上一次的广播消息ID记录", "广播", session=session) + cls.clear_last_broadcast_msg_ids() + + all_groups, _ = await cls.get_all_groups(bot) + return await cls.send_to_specific_groups(bot, message, all_groups, session) + + @classmethod + async def send_to_specific_groups( + cls, + bot: Bot, + message: UniMessage, + target_groups: list[GroupConsole], + session_info: EventSession | str | None = None, + ) -> BroadcastResult: + """发送广播到指定群组""" + log_session = session_info or bot.self_id + logger.debug( + f"开始广播,目标 {len(target_groups)} 个群组,Bot ID: {bot.self_id}", + "广播", + session=log_session, + ) + + if not target_groups: + logger.debug("目标群组列表为空,广播结束", "广播", session=log_session) + return 0, 0 + + platform = PlatformUtils.get_platform(bot) + is_forward_broadcast = any( + isinstance(seg, Reference) and getattr(seg, "nodes", None) + for seg in message + ) + + if platform == "qq" and isinstance(bot, V11Bot) and is_forward_broadcast: + if ( + len(message) == 1 + and isinstance(message[0], Reference) + 
and getattr(message[0], "nodes", None) + ): + nodes_list = getattr(message[0], "nodes", []) + v11_nodes = custom_nodes_to_v11_nodes(nodes_list) + node_count = len(v11_nodes) + logger.debug( + f"从 UniMessage 构造转发节点数: {node_count}", + "广播", + session=log_session, + ) + else: + logger.warning( + "广播消息包含合并转发段和其他段,将尝试打平成一个节点发送", + "广播", + session=log_session, + ) + v11_content_list = uni_message_to_v11_list_of_dicts(message) + v11_nodes = ( + [ + { + "type": "node", + "data": { + "user_id": bot.self_id, + "nickname": "广播", + "content": v11_content_list, + }, + } + ] + if v11_content_list + else [] + ) + + if not v11_nodes: + logger.warning( + "构造出的 V11 合并转发节点为空,无法发送", + "广播", + session=log_session, + ) + return 0, len(target_groups) + success_count, error_count, skip_count = await cls._broadcast_forward( + bot, log_session, target_groups, v11_nodes + ) + else: + if is_forward_broadcast: + logger.warning( + f"合并转发消息在适配器 ({platform}) 不支持,将作为普通消息发送", + "广播", + session=log_session, + ) + success_count, error_count, skip_count = await cls._broadcast_normal( + bot, log_session, target_groups, message + ) + + total = len(target_groups) + stats = f"成功: {success_count}, 失败: {error_count}" + stats += f", 跳过: {skip_count}, 总计: {total}" + logger.debug( + f"广播统计 - {stats}", + "广播", + session=log_session, + ) + + msg_ids = cls.get_last_broadcast_msg_ids() + if msg_ids: + id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids.items()]) + logger.debug( + f"广播结束,记录了 {len(msg_ids)} 条消息ID: {id_list_str}", + "广播", + session=log_session, + ) + else: + logger.warning( + "广播结束,但没有记录任何消息ID", + "广播", + session=log_session, + ) + + return success_count, error_count + + @classmethod + async def _extract_message_id_from_result( + cls, + result: dict | Receipt, + group_key: str, + session_info: EventSession | str, + msg_type: str = "普通", + ) -> None: + """提取消息ID并记录""" + if isinstance(result, dict) and "message_id" in result: + msg_id = result["message_id"] + try: + msg_id_int = int(msg_id) + 
cls._last_broadcast_msg_ids[group_key] = msg_id_int + logger.debug( + f"记录群 {group_key} 的{msg_type}消息ID: {msg_id_int}", + "广播", + session=session_info, + ) + except (ValueError, TypeError): + logger.warning( + f"{msg_type}结果中的 message_id 不是有效整数: {msg_id}", + "广播", + session=session_info, + ) + elif isinstance(result, Receipt) and result.msg_ids: + try: + first_id_info = result.msg_ids[0] + msg_id = None + if isinstance(first_id_info, dict) and "message_id" in first_id_info: + msg_id = first_id_info["message_id"] + logger.debug( + f"从 Receipt.msg_ids[0] 提取到 ID: {msg_id}", + "广播", + session=session_info, + ) + elif isinstance(first_id_info, int | str): + msg_id = first_id_info + logger.debug( + f"从 Receipt.msg_ids[0] 提取到原始ID: {msg_id}", + "广播", + session=session_info, + ) + + if msg_id is not None: + try: + msg_id_int = int(msg_id) + cls._last_broadcast_msg_ids[group_key] = msg_id_int + logger.debug( + f"记录群 {group_key} 的消息ID: {msg_id_int}", + "广播", + session=session_info, + ) + except (ValueError, TypeError): + logger.warning( + f"提取的ID ({msg_id}) 不是有效整数", + "广播", + session=session_info, + ) + else: + info_str = str(first_id_info) + logger.warning( + f"无法从 Receipt.msg_ids[0] 提取ID: {info_str}", + "广播", + session=session_info, + ) + except IndexError: + logger.warning("Receipt.msg_ids 为空", "广播", session=session_info) + except Exception as e_extract: + logger.error( + f"从 Receipt 提取 msg_id 时出错: {e_extract}", + "广播", + session=session_info, + e=e_extract, + ) + else: + logger.warning( + f"发送成功但无法从结果获取消息 ID. 
结果: {result}", + "广播", + session=session_info, + ) + + @classmethod + async def _check_group_availability(cls, bot: Bot, group: GroupConsole) -> bool: + """检查群组是否可用""" + if not group.group_id: + return False + + if await CommonUtils.task_is_block(bot, "broadcast", group.group_id): + return False + + return True + + @classmethod + async def _broadcast_forward( + cls, + bot: V11Bot, + session_info: EventSession | str, + group_list: list[GroupConsole], + v11_nodes: list[dict], + ) -> BroadcastDetailResult: + """发送合并转发""" + success_count = 0 + error_count = 0 + skip_count = 0 + + for _, group in enumerate(group_list): + group_key = group.group_id or group.channel_id + + if not await cls._check_group_availability(bot, group): + skip_count += 1 + continue + + try: + result = await bot.send_group_forward_msg( + group_id=int(group.group_id), messages=v11_nodes + ) + + logger.debug( + f"合并转发消息发送结果: {result}, 类型: {type(result)}", + "广播", + session=session_info, + ) + + await cls._extract_message_id_from_result( + result, group_key, session_info, "合并转发" + ) + + success_count += 1 + await asyncio.sleep(random.randint(1, 3)) + except ActionFailed as af_e: + error_count += 1 + logger.error( + f"发送失败(合并转发) to {group_key}: {af_e}", + "广播", + session=session_info, + e=af_e, + ) + except Exception as e: + error_count += 1 + logger.error( + f"发送失败(合并转发) to {group_key}: {e}", + "广播", + session=session_info, + e=e, + ) + + return success_count, error_count, skip_count + + @classmethod + async def _broadcast_normal( + cls, + bot: Bot, + session_info: EventSession | str, + group_list: list[GroupConsole], + message: UniMessage, + ) -> BroadcastDetailResult: + """发送普通消息""" + success_count = 0 + error_count = 0 + skip_count = 0 + + for _, group in enumerate(group_list): + group_key = ( + f"{group.group_id}:{group.channel_id}" + if group.channel_id + else str(group.group_id) + ) + + if not await cls._check_group_availability(bot, group): + skip_count += 1 + continue + + try: + target = 
PlatformUtils.get_target( + group_id=group.group_id, channel_id=group.channel_id + ) + + if target: + receipt: Receipt = await message.send(target, bot=bot) + + logger.debug( + f"广播消息发送结果: {receipt}, 类型: {type(receipt)}", + "广播", + session=session_info, + ) + + await cls._extract_message_id_from_result( + receipt, group_key, session_info + ) + + success_count += 1 + await asyncio.sleep(random.randint(1, 3)) + else: + logger.warning( + "target为空", "广播", session=session_info, target=group_key + ) + skip_count += 1 + except Exception as e: + error_count += 1 + logger.error( + f"发送失败(普通) to {group_key}: {e}", + "广播", + session=session_info, + e=e, + ) + + return success_count, error_count, skip_count + + @classmethod + async def recall_last_broadcast( + cls, bot: Bot, session_info: EventSession | str + ) -> BroadcastResult: + """撤回最近广播""" + msg_ids_to_recall = cls.get_last_broadcast_msg_ids() + + if not msg_ids_to_recall: + logger.warning( + "没有找到最近的广播消息ID记录", "广播撤回", session=session_info + ) + return 0, 0 + + id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids_to_recall.items()]) + logger.debug( + f"找到 {len(msg_ids_to_recall)} 条广播消息ID记录: {id_list_str}", + "广播撤回", + session=session_info, + ) + + success_count = 0 + error_count = 0 + + logger.info( + f"准备撤回 {len(msg_ids_to_recall)} 条广播消息", + "广播撤回", + session=session_info, + ) + + for group_key, msg_id in msg_ids_to_recall.items(): + try: + logger.debug( + f"尝试撤回消息 (ID: {msg_id}) in {group_key}", + "广播撤回", + session=session_info, + ) + await bot.call_api("delete_msg", message_id=msg_id) + success_count += 1 + except ActionFailed as af_e: + retcode = getattr(af_e, "retcode", None) + wording = getattr(af_e, "wording", "") + if retcode == 100 and "MESSAGE_NOT_FOUND" in wording.upper(): + logger.warning( + f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}", + "广播撤回", + session=session_info, + ) + elif retcode == 300 and "delete message" in wording.lower(): + logger.warning( + f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}", + 
"广播撤回", + session=session_info, + ) + else: + error_count += 1 + logger.error( + f"撤回消息失败 (ID: {msg_id}) in {group_key}: {af_e}", + "广播撤回", + session=session_info, + e=af_e, + ) + except Exception as e: + error_count += 1 + logger.error( + f"撤回消息时发生未知错误 (ID: {msg_id}) in {group_key}: {e}", + "广播撤回", + session=session_info, + e=e, + ) + await asyncio.sleep(0.2) + + logger.debug("撤回操作完成,清空消息ID记录", "广播撤回", session=session_info) + cls.clear_last_broadcast_msg_ids() + + return success_count, error_count diff --git a/zhenxun/builtin_plugins/superuser/broadcast/message_processor.py b/zhenxun/builtin_plugins/superuser/broadcast/message_processor.py new file mode 100644 index 00000000..809e3645 --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/message_processor.py @@ -0,0 +1,584 @@ +import base64 +import json +from typing import Any + +from nonebot.adapters import Bot, Event +from nonebot.adapters.onebot.v11 import Message as V11Message +from nonebot.adapters.onebot.v11 import MessageSegment as V11MessageSegment +from nonebot.exception import ActionFailed +import nonebot_plugin_alconna as alc +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_alconna.uniseg.segment import ( + At, + AtAll, + CustomNode, + Image, + Reference, + Reply, + Text, + Video, +) +from nonebot_plugin_alconna.uniseg.tools import reply_fetch +from nonebot_plugin_session import EventSession + +from zhenxun.services.log import logger +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.message import MessageUtils + +from .broadcast_manager import BroadcastManager + +MAX_FORWARD_DEPTH = 3 + + +async def _process_forward_content( + forward_content: Any, forward_id: str | None, bot: Bot, depth: int +) -> list[CustomNode]: + """处理转发消息内容""" + nodes_for_alc = [] + content_parsed = False + + if forward_content: + nodes_from_content = None + if isinstance(forward_content, list): + nodes_from_content = forward_content + elif isinstance(forward_content, str): + try: 
+ parsed_content = json.loads(forward_content) + if isinstance(parsed_content, list): + nodes_from_content = parsed_content + except Exception as json_e: + logger.debug( + f"[Depth {depth}] JSON解析失败: {json_e}", + "广播", + ) + + if nodes_from_content is not None: + logger.debug( + f"[D{depth}] 节点数: {len(nodes_from_content)}", + "广播", + ) + content_parsed = True + for node_data in nodes_from_content: + node = await _create_custom_node_from_data(node_data, bot, depth + 1) + if node: + nodes_for_alc.append(node) + + if not content_parsed and forward_id: + logger.debug( + f"[D{depth}] 尝试API调用ID: {forward_id}", + "广播", + ) + try: + forward_data = await bot.call_api("get_forward_msg", id=forward_id) + nodes_list = None + + if isinstance(forward_data, dict) and "messages" in forward_data: + nodes_list = forward_data["messages"] + elif ( + isinstance(forward_data, dict) + and "data" in forward_data + and isinstance(forward_data["data"], dict) + and "message" in forward_data["data"] + ): + nodes_list = forward_data["data"]["message"] + elif isinstance(forward_data, list): + nodes_list = forward_data + + if nodes_list: + node_count = len(nodes_list) + logger.debug( + f"[D{depth + 1}] 节点:{node_count}", + "广播", + ) + for node_data in nodes_list: + node = await _create_custom_node_from_data( + node_data, bot, depth + 1 + ) + if node: + nodes_for_alc.append(node) + else: + logger.warning( + f"[D{depth + 1}] ID:{forward_id}无节点", + "广播", + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[嵌套转发消息获取失败]", + ) + ) + except ActionFailed as af_e: + logger.error( + f"[D{depth + 1}] API失败: {af_e}", + "广播", + e=af_e, + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[嵌套转发消息获取失败]", + ) + ) + except Exception as e: + logger.error( + f"[D{depth + 1}] 处理出错: {e}", + "广播", + e=e, + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[处理嵌套转发时出错]", + ) + ) + elif not content_parsed and not forward_id: + logger.warning( + 
f"[D{depth}] 转发段无内容也无ID", + "广播", + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[嵌套转发消息无法解析]", + ) + ) + elif content_parsed and not nodes_for_alc: + logger.warning( + f"[D{depth}] 解析成功但无有效节点", + "广播", + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="信息", + content="[嵌套转发内容为空]", + ) + ) + + return nodes_for_alc + + +async def _create_custom_node_from_data( + node_data: dict, bot: Bot, depth: int +) -> CustomNode | None: + """从节点数据创建CustomNode""" + node_content_raw = node_data.get("message") or node_data.get("content") + if not node_content_raw: + logger.warning(f"[D{depth}] 节点缺少消息内容", "广播") + return None + + sender = node_data.get("sender", {}) + uid = str(sender.get("user_id", "10000")) + name = sender.get("nickname", f"用户{uid[:4]}") + + extracted_uni_msg = await _extract_content_from_message( + node_content_raw, bot, depth + ) + if not extracted_uni_msg: + return None + + return CustomNode(uid=uid, name=name, content=extracted_uni_msg) + + +async def _extract_broadcast_content( + bot: Bot, + event: Event, + arp: alc.Arparma, + session: EventSession, +) -> UniMessage | None: + """从命令参数或引用消息中提取广播内容""" + broadcast_content_msg: UniMessage | None = None + + command_content_list = arp.all_matched_args.get("content", []) + + processed_command_list = [] + has_command_content = False + + if command_content_list: + for item in command_content_list: + if isinstance(item, alc.Segment): + processed_command_list.append(item) + if not (isinstance(item, Text) and not item.text.strip()): + has_command_content = True + elif isinstance(item, str): + if item.strip(): + processed_command_list.append(Text(item.strip())) + has_command_content = True + else: + logger.warning( + f"Unexpected type in command content: {type(item)}", "广播" + ) + + if has_command_content: + logger.debug("检测到命令参数内容,优先使用参数内容", "广播", session=session) + broadcast_content_msg = UniMessage(processed_command_list) + + if not broadcast_content_msg.filter( + lambda x: not 
(isinstance(x, Text) and not x.text.strip()) + ): + logger.warning( + "命令参数内容解析后为空或只包含空白", "广播", session=session + ) + broadcast_content_msg = None + + if not broadcast_content_msg: + reply_segment_obj: Reply | None = await reply_fetch(event, bot) + if ( + reply_segment_obj + and hasattr(reply_segment_obj, "msg") + and reply_segment_obj.msg + ): + logger.debug( + "未检测到有效命令参数,检测到引用消息", "广播", session=session + ) + raw_quoted_content = reply_segment_obj.msg + is_forward = False + forward_id = None + + if isinstance(raw_quoted_content, V11Message): + for seg in raw_quoted_content: + if isinstance(seg, V11MessageSegment): + if seg.type == "forward": + forward_id = seg.data.get("id") + is_forward = bool(forward_id) + break + elif seg.type == "json": + try: + json_data_str = seg.data.get("data", "{}") + if isinstance(json_data_str, str): + import json + + json_data = json.loads(json_data_str) + if ( + json_data.get("app") == "com.tencent.multimsg" + or json_data.get("view") == "Forward" + ) and json_data.get("meta", {}).get( + "detail", {} + ).get("resid"): + forward_id = json_data["meta"]["detail"][ + "resid" + ] + is_forward = True + break + except Exception: + pass + + if is_forward and forward_id: + logger.info( + f"尝试获取并构造合并转发内容 (ID: {forward_id})", + "广播", + session=session, + ) + nodes_to_forward: list[CustomNode] = [] + try: + forward_data = await bot.call_api("get_forward_msg", id=forward_id) + nodes_list = None + if isinstance(forward_data, dict) and "messages" in forward_data: + nodes_list = forward_data["messages"] + elif ( + isinstance(forward_data, dict) + and "data" in forward_data + and isinstance(forward_data["data"], dict) + and "message" in forward_data["data"] + ): + nodes_list = forward_data["data"]["message"] + elif isinstance(forward_data, list): + nodes_list = forward_data + + if nodes_list is not None: + for node_data in nodes_list: + node_sender = node_data.get("sender", {}) + node_user_id = str(node_sender.get("user_id", "10000")) + 
node_nickname = node_sender.get( + "nickname", f"用户{node_user_id[:4]}" + ) + node_content_raw = node_data.get( + "message" + ) or node_data.get("content") + if node_content_raw: + extracted_node_uni_msg = ( + await _extract_content_from_message( + node_content_raw, bot + ) + ) + if extracted_node_uni_msg: + nodes_to_forward.append( + CustomNode( + uid=node_user_id, + name=node_nickname, + content=extracted_node_uni_msg, + ) + ) + if nodes_to_forward: + broadcast_content_msg = UniMessage( + Reference(nodes=nodes_to_forward) + ) + except ActionFailed: + await MessageUtils.build_message( + "获取合并转发消息失败,可能不支持此 API。" + ).send(reply_to=True) + return None + except Exception as api_e: + logger.error(f"处理合并转发时出错: {api_e}", "广播", e=api_e) + await MessageUtils.build_message( + "处理合并转发消息时发生内部错误。" + ).send(reply_to=True) + return None + else: + broadcast_content_msg = await _extract_content_from_message( + raw_quoted_content, bot + ) + else: + logger.debug("未检测到命令参数和引用消息", "广播", session=session) + await MessageUtils.build_message("请提供广播内容或引用要广播的消息").send( + reply_to=True + ) + return None + + if not broadcast_content_msg: + logger.error( + "未能从命令参数或引用消息中获取有效的广播内容", "广播", session=session + ) + await MessageUtils.build_message("错误:未能获取有效的广播内容。").send( + reply_to=True + ) + return None + + return broadcast_content_msg + + +async def _process_v11_segment( + seg_obj: V11MessageSegment | dict, depth: int, index: int, bot: Bot +) -> list[alc.Segment]: + """处理V11消息段""" + result = [] + seg_type = None + data_dict = None + + if isinstance(seg_obj, V11MessageSegment): + seg_type = seg_obj.type + data_dict = seg_obj.data + elif isinstance(seg_obj, dict): + seg_type = seg_obj.get("type") + data_dict = seg_obj.get("data") + else: + return result + + if not (seg_type and data_dict is not None): + logger.warning(f"[D{depth}] 跳过无效数据: {type(seg_obj)}", "广播") + return result + + if seg_type == "text": + text_content = data_dict.get("text", "") + if isinstance(text_content, str) and 
text_content.strip(): + result.append(Text(text_content)) + elif seg_type == "image": + img_seg = None + if data_dict.get("url"): + img_seg = Image(url=data_dict["url"]) + elif data_dict.get("file"): + file_val = data_dict["file"] + if isinstance(file_val, str) and file_val.startswith("base64://"): + b64_data = file_val[9:] + raw_bytes = base64.b64decode(b64_data) + img_seg = Image(raw=raw_bytes) + else: + img_seg = Image(path=file_val) + if img_seg: + result.append(img_seg) + else: + logger.warning(f"[Depth {depth}] V11 图片 {index} 缺少URL/文件", "广播") + elif seg_type == "at": + target_qq = data_dict.get("qq", "") + if target_qq.lower() == "all": + result.append(AtAll()) + elif target_qq: + result.append(At(flag="user", target=target_qq)) + elif seg_type == "video": + video_seg = None + if data_dict.get("url"): + video_seg = Video(url=data_dict["url"]) + elif data_dict.get("file"): + file_val = data_dict["file"] + if isinstance(file_val, str) and file_val.startswith("base64://"): + b64_data = file_val[9:] + raw_bytes = base64.b64decode(b64_data) + video_seg = Video(raw=raw_bytes) + else: + video_seg = Video(path=file_val) + if video_seg: + result.append(video_seg) + logger.debug(f"[Depth {depth}] 处理视频消息成功", "广播") + else: + logger.warning(f"[Depth {depth}] V11 视频 {index} 缺少URL/文件", "广播") + elif seg_type == "forward": + nested_forward_id = data_dict.get("id") or data_dict.get("resid") + nested_forward_content = data_dict.get("content") + + logger.debug(f"[D{depth}] 嵌套转发ID: {nested_forward_id}", "广播") + + nested_nodes = await _process_forward_content( + nested_forward_content, nested_forward_id, bot, depth + ) + + if nested_nodes: + result.append(Reference(nodes=nested_nodes)) + else: + logger.warning(f"[D{depth}] 跳过类型: {seg_type}", "广播") + + return result + + +async def _extract_content_from_message( + message_content: Any, bot: Bot, depth: int = 0 +) -> UniMessage: + """提取消息内容到UniMessage""" + temp_msg = UniMessage() + input_type_str = str(type(message_content)) + + if 
depth >= MAX_FORWARD_DEPTH: + logger.warning( + f"[Depth {depth}] 达到最大递归深度 {MAX_FORWARD_DEPTH},停止解析嵌套转发。", + "广播", + ) + temp_msg.append(Text("[嵌套转发层数过多,内容已省略]")) + return temp_msg + + segments_to_process = [] + + if isinstance(message_content, UniMessage): + segments_to_process = list(message_content) + elif isinstance(message_content, V11Message): + segments_to_process = list(message_content) + elif isinstance(message_content, list): + segments_to_process = message_content + elif ( + isinstance(message_content, dict) + and "type" in message_content + and "data" in message_content + ): + segments_to_process = [message_content] + elif isinstance(message_content, str): + if message_content.strip(): + temp_msg.append(Text(message_content)) + return temp_msg + else: + logger.warning(f"[Depth {depth}] 无法处理的输入类型: {input_type_str}", "广播") + return temp_msg + + if segments_to_process: + for index, seg_obj in enumerate(segments_to_process): + try: + if isinstance(seg_obj, Text): + text_content = getattr(seg_obj, "text", None) + if isinstance(text_content, str) and text_content.strip(): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, Image): + if ( + getattr(seg_obj, "url", None) + or getattr(seg_obj, "path", None) + or getattr(seg_obj, "raw", None) + ): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, At): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, AtAll): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, Video): + if ( + getattr(seg_obj, "url", None) + or getattr(seg_obj, "path", None) + or getattr(seg_obj, "raw", None) + ): + temp_msg.append(seg_obj) + logger.debug(f"[D{depth}] 处理Video对象成功", "广播") + else: + processed_segments = await _process_v11_segment( + seg_obj, depth, index, bot + ) + temp_msg.extend(processed_segments) + except Exception as e_conv_seg: + logger.warning( + f"[D{depth}] 处理段 {index} 出错: {e_conv_seg}", + "广播", + e=e_conv_seg, + ) + + if not temp_msg and message_content: + logger.warning(f"未能从类型 {input_type_str} 中提取内容", "广播") + 
+ return temp_msg + + +async def get_broadcast_target_groups( + bot: Bot, session: EventSession +) -> tuple[list, list]: + """获取广播目标群组和启用了广播功能的群组""" + target_groups = [] + all_groups, _ = await BroadcastManager.get_all_groups(bot) + + current_group_id = None + if hasattr(session, "id2") and session.id2: + current_group_id = session.id2 + + if current_group_id: + target_groups = [ + group for group in all_groups if group.group_id != current_group_id + ] + logger.info( + f"向除当前群组({current_group_id})外的所有群组广播", "广播", session=session + ) + else: + target_groups = all_groups + logger.info("向所有群组广播", "广播", session=session) + + if not target_groups: + await MessageUtils.build_message("没有找到符合条件的广播目标群组。").send( + reply_to=True + ) + return [], [] + + enabled_groups = [] + for group in target_groups: + if not await CommonUtils.task_is_block(bot, "broadcast", group.group_id): + enabled_groups.append(group) + + if not enabled_groups: + await MessageUtils.build_message( + "没有启用了广播功能的目标群组可供立即发送。" + ).send(reply_to=True) + return target_groups, [] + + return target_groups, enabled_groups + + +async def send_broadcast_and_notify( + bot: Bot, + event: Event, + message: UniMessage, + enabled_groups: list, + target_groups: list, + session: EventSession, +) -> None: + """发送广播并通知结果""" + BroadcastManager.clear_last_broadcast_msg_ids() + count, error_count = await BroadcastManager.send_to_specific_groups( + bot, message, enabled_groups, session + ) + + result = f"成功广播 {count} 个群组" + if error_count: + result += f"\n发送失败 {error_count} 个群组" + result += f"\n有效: {len(enabled_groups)} / 总计: {len(target_groups)}" + + user_id = str(event.get_user_id()) + await bot.send_private_msg(user_id=user_id, message=f"发送广播完成!\n{result}") + + BroadcastManager.log_info( + f"广播完成,有效/总计: {len(enabled_groups)}/{len(target_groups)}", + session, + ) diff --git a/zhenxun/builtin_plugins/superuser/broadcast/models.py b/zhenxun/builtin_plugins/superuser/broadcast/models.py new file mode 100644 index 
00000000..4bcdf936 --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/models.py @@ -0,0 +1,64 @@ +from datetime import datetime +from typing import Any + +from nonebot_plugin_alconna import UniMessage + +from zhenxun.models.group_console import GroupConsole + +GroupKey = str +MessageID = int +BroadcastResult = tuple[int, int] +BroadcastDetailResult = tuple[int, int, int] + + +class BroadcastTarget: + """广播目标""" + + def __init__(self, group_id: str, channel_id: str | None = None): + self.group_id = group_id + self.channel_id = channel_id + + def to_dict(self) -> dict[str, str | None]: + """转换为字典格式""" + return {"group_id": self.group_id, "channel_id": self.channel_id} + + @classmethod + def from_group_console(cls, group: GroupConsole) -> "BroadcastTarget": + """从 GroupConsole 对象创建""" + return cls(group_id=group.group_id, channel_id=group.channel_id) + + @property + def key(self) -> str: + """获取群组的唯一标识""" + if self.channel_id: + return f"{self.group_id}:{self.channel_id}" + return str(self.group_id) + + +class BroadcastTask: + """广播任务""" + + def __init__( + self, + bot_id: str, + message: UniMessage, + targets: list[BroadcastTarget], + scheduled_time: datetime | None = None, + task_id: str | None = None, + ): + self.bot_id = bot_id + self.message = message + self.targets = targets + self.scheduled_time = scheduled_time + self.task_id = task_id + + def to_dict(self) -> dict[str, Any]: + """转换为字典格式,用于序列化""" + return { + "bot_id": self.bot_id, + "targets": [t.to_dict() for t in self.targets], + "scheduled_time": self.scheduled_time.isoformat() + if self.scheduled_time + else None, + "task_id": self.task_id, + } diff --git a/zhenxun/builtin_plugins/superuser/broadcast/utils.py b/zhenxun/builtin_plugins/superuser/broadcast/utils.py new file mode 100644 index 00000000..748559fd --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/utils.py @@ -0,0 +1,175 @@ +import base64 + +import nonebot_plugin_alconna as alc +from nonebot_plugin_alconna import 
UniMessage +from nonebot_plugin_alconna.uniseg import Reference +from nonebot_plugin_alconna.uniseg.segment import CustomNode, Video + +from zhenxun.services.log import logger + + +def uni_segment_to_v11_segment_dict( + seg: alc.Segment, depth: int = 0 +) -> dict | list[dict] | None: + """UniSeg段转V11字典""" + if isinstance(seg, alc.Text): + return {"type": "text", "data": {"text": seg.text}} + elif isinstance(seg, alc.Image): + if getattr(seg, "url", None): + return { + "type": "image", + "data": {"file": seg.url}, + } + elif getattr(seg, "raw", None): + raw_data = seg.raw + if isinstance(raw_data, str): + if len(raw_data) >= 9 and raw_data[:9] == "base64://": + return {"type": "image", "data": {"file": raw_data}} + elif isinstance(raw_data, bytes): + b64_str = base64.b64encode(raw_data).decode() + return {"type": "image", "data": {"file": f"base64://{b64_str}"}} + else: + logger.warning(f"无法处理 Image.raw 的类型: {type(raw_data)}", "广播") + elif getattr(seg, "path", None): + logger.warning( + f"在合并转发中使用了本地图片路径,可能无法显示: {seg.path}", "广播" + ) + return {"type": "image", "data": {"file": f"file:///{seg.path}"}} + else: + logger.warning(f"alc.Image 缺少有效数据,无法转换为 V11 段: {seg}", "广播") + elif isinstance(seg, alc.At): + return {"type": "at", "data": {"qq": seg.target}} + elif isinstance(seg, alc.AtAll): + return {"type": "at", "data": {"qq": "all"}} + elif isinstance(seg, Video): + if getattr(seg, "url", None): + return { + "type": "video", + "data": {"file": seg.url}, + } + elif getattr(seg, "raw", None): + raw_data = seg.raw + if isinstance(raw_data, str): + if len(raw_data) >= 9 and raw_data[:9] == "base64://": + return {"type": "video", "data": {"file": raw_data}} + elif isinstance(raw_data, bytes): + b64_str = base64.b64encode(raw_data).decode() + return {"type": "video", "data": {"file": f"base64://{b64_str}"}} + else: + logger.warning(f"无法处理 Video.raw 的类型: {type(raw_data)}", "广播") + elif getattr(seg, "path", None): + logger.warning( + f"在合并转发中使用了本地视频路径,可能无法显示: {seg.path}", 
"广播" + ) + return {"type": "video", "data": {"file": f"file:///{seg.path}"}} + else: + logger.warning(f"Video 缺少有效数据,无法转换为 V11 段: {seg}", "广播") + elif isinstance(seg, Reference) and getattr(seg, "nodes", None): + if depth >= 3: + logger.warning( + f"嵌套转发深度超过限制 (depth={depth}),不再继续解析", "广播" + ) + return {"type": "text", "data": {"text": "[嵌套转发层数过多,内容已省略]"}} + + nested_v11_content_list = [] + nodes_list = getattr(seg, "nodes", []) + for node in nodes_list: + if isinstance(node, CustomNode): + node_v11_content = [] + if isinstance(node.content, UniMessage): + for nested_seg in node.content: + converted_dict = uni_segment_to_v11_segment_dict( + nested_seg, depth + 1 + ) + if isinstance(converted_dict, list): + node_v11_content.extend(converted_dict) + elif converted_dict: + node_v11_content.append(converted_dict) + elif isinstance(node.content, str): + node_v11_content.append( + {"type": "text", "data": {"text": node.content}} + ) + if node_v11_content: + separator = { + "type": "text", + "data": { + "text": f"\n--- 来自 {node.name} ({node.uid}) 的消息 ---\n" + }, + } + nested_v11_content_list.insert(0, separator) + nested_v11_content_list.extend(node_v11_content) + nested_v11_content_list.append( + {"type": "text", "data": {"text": "\n---\n"}} + ) + + return nested_v11_content_list + + else: + logger.warning(f"广播时跳过不支持的 UniSeg 段类型: {type(seg)}", "广播") + return None + + +def uni_message_to_v11_list_of_dicts(uni_msg: UniMessage | str | list) -> list[dict]: + """UniMessage转V11字典列表""" + try: + if isinstance(uni_msg, str): + return [{"type": "text", "data": {"text": uni_msg}}] + + if isinstance(uni_msg, list): + if not uni_msg: + return [] + + if all(isinstance(item, str) for item in uni_msg): + return [{"type": "text", "data": {"text": item}} for item in uni_msg] + + result = [] + for item in uni_msg: + if hasattr(item, "__iter__") and not isinstance(item, str | bytes): + result.extend(uni_message_to_v11_list_of_dicts(item)) + elif hasattr(item, "text") and not 
isinstance(item, str | bytes): + text_value = getattr(item, "text", "") + result.append({"type": "text", "data": {"text": str(text_value)}}) + elif hasattr(item, "url") and not isinstance(item, str | bytes): + url_value = getattr(item, "url", "") + if isinstance(item, Video): + result.append( + {"type": "video", "data": {"file": str(url_value)}} + ) + else: + result.append( + {"type": "image", "data": {"file": str(url_value)}} + ) + else: + try: + result.append({"type": "text", "data": {"text": str(item)}}) + except Exception as e: + logger.warning(f"无法转换列表元素: {item}, 错误: {e}", "广播") + return result + except Exception as e: + logger.warning(f"消息转换过程中出错: {e}", "广播") + + return [{"type": "text", "data": {"text": str(uni_msg)}}] + + +def custom_nodes_to_v11_nodes(custom_nodes: list[CustomNode]) -> list[dict]: + """CustomNode列表转V11节点""" + v11_nodes = [] + for node in custom_nodes: + v11_content_list = uni_message_to_v11_list_of_dicts(node.content) + + if v11_content_list: + v11_nodes.append( + { + "type": "node", + "data": { + "user_id": str(node.uid), + "nickname": node.name, + "content": v11_content_list, + }, + } + ) + else: + logger.warning( + f"CustomNode (uid={node.uid}) 内容转换后为空,跳过此节点", "广播" + ) + return v11_nodes diff --git a/zhenxun/builtin_plugins/superuser/request_manage.py b/zhenxun/builtin_plugins/superuser/request_manage.py index 23b235bf..e6eb6b77 100644 --- a/zhenxun/builtin_plugins/superuser/request_manage.py +++ b/zhenxun/builtin_plugins/superuser/request_manage.py @@ -2,7 +2,7 @@ from io import BytesIO from arclet.alconna import Args, Option from arclet.alconna.typing import CommandMeta -from nonebot.adapters import Bot +from nonebot.adapters import Bot, Event from nonebot.permission import SUPERUSER from nonebot.plugin import PluginMetadata from nonebot.rule import to_me @@ -10,10 +10,13 @@ from nonebot_plugin_alconna import ( Alconna, AlconnaQuery, Arparma, + Match, Query, + Reply, on_alconna, store_true, ) +from nonebot_plugin_alconna.uniseg.tools 
import reply_fetch from nonebot_plugin_session import EventSession from zhenxun.configs.config import BotConfig @@ -54,7 +57,7 @@ __plugin_meta__ = PluginMetadata( _req_matcher = on_alconna( Alconna( "请求处理", - Args["handle", ["-fa", "-fr", "-fi", "-ga", "-gr", "-gi"]]["id", int], + Args["handle", ["-fa", "-fr", "-fi", "-ga", "-gr", "-gi"]]["id?", int], meta=CommandMeta( description="好友/群组请求处理", usage=usage, @@ -105,12 +108,12 @@ _clear_matcher = on_alconna( ) reg_arg_list = [ - (r"同意好友请求", ["-fa", "{%0}"]), - (r"拒绝好友请求", ["-fr", "{%0}"]), - (r"忽略好友请求", ["-fi", "{%0}"]), - (r"同意群组请求", ["-ga", "{%0}"]), - (r"拒绝群组请求", ["-gr", "{%0}"]), - (r"忽略群组请求", ["-gi", "{%0}"]), + (r"同意好友请求\s*(?P\d*)", ["-fa", "{id}"]), + (r"拒绝好友请求\s*(?P\d*)", ["-fr", "{id}"]), + (r"忽略好友请求\s*(?P\d*)", ["-fi", "{id}"]), + (r"同意群组请求\s*(?P\d*)", ["-ga", "{id}"]), + (r"拒绝群组请求\s*(?P\d*)", ["-gr", "{id}"]), + (r"忽略群组请求\s*(?P\d*)", ["-gi", "{id}"]), ] for r in reg_arg_list: @@ -125,32 +128,48 @@ for r in reg_arg_list: @_req_matcher.handle() async def _( bot: Bot, + event: Event, session: EventSession, handle: str, - id: int, + id: Match[int], arparma: Arparma, ): + reply: Reply | None = None type_dict = { "a": RequestHandleType.APPROVE, "r": RequestHandleType.REFUSED, "i": RequestHandleType.IGNORE, } + if not id.available: + reply = await reply_fetch(event, bot) + if not reply: + await MessageUtils.build_message("请引用消息处理或添加处理Id.").finish() + handle_id = id.result + if reply: + db_data = await FgRequest.get_or_none(message_ids__contains=reply.id) + if not db_data: + await MessageUtils.build_message( + "未发现此消息的Id,请使用Id进行处理..." 
+ ).finish(reply_to=True) + handle_id = db_data.id req = None handle_type = type_dict[handle[-1]] try: if handle_type == RequestHandleType.APPROVE: - req = await FgRequest.approve(bot, id) + req = await FgRequest.approve(bot, handle_id) if handle_type == RequestHandleType.REFUSED: - req = await FgRequest.refused(bot, id) + req = await FgRequest.refused(bot, handle_id) if handle_type == RequestHandleType.IGNORE: - req = await FgRequest.ignore(id) + req = await FgRequest.ignore(handle_id) except NotFoundError: await MessageUtils.build_message("未发现此id的请求...").finish(reply_to=True) except Exception: await MessageUtils.build_message("其他错误, 可能flag已失效...").finish( reply_to=True ) - logger.info("处理请求", arparma.header_result, session=session) + logger.info( + f"处理请求 Id: {req.id if req else ''}", arparma.header_result, session=session + ) await MessageUtils.build_message("成功处理请求!").send(reply_to=True) if req and handle_type == RequestHandleType.APPROVE: await bot.send_private_msg( diff --git a/zhenxun/builtin_plugins/web_ui/__init__.py b/zhenxun/builtin_plugins/web_ui/__init__.py index d8d71025..619d56bf 100644 --- a/zhenxun/builtin_plugins/web_ui/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/__init__.py @@ -10,7 +10,9 @@ from zhenxun.configs.config import Config as gConfig from zhenxun.configs.utils import PluginExtraData, RegisterConfig from zhenxun.services.log import logger, logger_ from zhenxun.utils.enum import PluginType +from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from .api.configure import router as configure_router from .api.logs import router as ws_log_routes from .api.logs.log_manager import LOG_STORAGE from .api.menu import router as menu_router @@ -29,8 +31,7 @@ from .public import init_public __plugin_meta__ = PluginMetadata( name="WebUi", description="WebUi API", - usage=""" - """.strip(), + usage='"""\n """.strip(),', extra=PluginExtraData( author="HibiKier", version="0.1", @@ -82,7 +83,7 @@ 
BaseApiRouter.include_router(database_router) BaseApiRouter.include_router(plugin_router) BaseApiRouter.include_router(system_router) BaseApiRouter.include_router(menu_router) - +BaseApiRouter.include_router(configure_router) WsApiRouter = APIRouter(prefix="/zhenxun/socket") @@ -91,9 +92,11 @@ WsApiRouter.include_router(status_routes) WsApiRouter.include_router(chat_routes) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=0) async def _(): try: + # 存储任务引用的列表,防止任务被垃圾回收 + _tasks = [] async def log_sink(message: str): loop = None @@ -104,7 +107,8 @@ async def _(): logger.warning("Web Ui log_sink", e=e) if not loop: loop = asyncio.new_event_loop() - loop.create_task(LOG_STORAGE.add(message.rstrip("\n"))) # noqa: RUF006 + # 存储任务引用到外部列表中 + _tasks.append(loop.create_task(LOG_STORAGE.add(message.rstrip("\n")))) logger_.add( log_sink, colorize=True, filter=default_filter, format=default_format diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py new file mode 100644 index 00000000..0ecde197 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py @@ -0,0 +1,133 @@ +import asyncio +import os +from pathlib import Path +import re +import subprocess +import sys +import time + +from fastapi import APIRouter +from fastapi.responses import JSONResponse +import nonebot + +from zhenxun.configs.config import BotConfig, Config + +from ...base_model import Result +from .data_source import test_db_connection +from .model import Setting + +router = APIRouter(prefix="/configure") + +driver = nonebot.get_driver() + +port = driver.config.port + +BAT_FILE = Path() / "win启动.bat" + +FILE_NAME = ".configure_restart" + + +@router.post( + "/set_configure", + response_model=Result, + response_class=JSONResponse, + description="设置基础配置", +) +async def _(setting: Setting) -> Result: + global port + password = Config.get_config("web-ui", "password") + if password or BotConfig.db_url: + return 
Result.fail("配置已存在,请先删除DB_URL内容和前端密码再进行设置。") + env_file = Path() / ".env.dev" + if not env_file.exists(): + return Result.fail("配置文件.env.dev不存在。") + env_text = env_file.read_text(encoding="utf-8") + if setting.db_url: + if setting.db_url.startswith("sqlite"): + base_dir = Path().resolve() + # 清理和验证数据库路径 + db_path_str = setting.db_url.split(":")[-1].strip() + # 移除任何可能的路径遍历尝试 + db_path_str = re.sub(r"[\\/]\.\.[\\/]", "", db_path_str) + # 规范化路径 + db_path = Path(db_path_str).resolve() + parent_path = db_path.parent + + # 验证路径是否在项目根目录内 + try: + if not parent_path.absolute().is_relative_to(base_dir): + return Result.fail("数据库路径不在项目根目录内。") + except ValueError: + return Result.fail("无效的数据库路径。") + + # 创建目录 + try: + parent_path.mkdir(parents=True, exist_ok=True) + except Exception as e: + return Result.fail(f"创建数据库目录失败: {e!s}") + + env_text = env_text.replace('DB_URL = ""', f'DB_URL = "{setting.db_url}"') + if setting.superusers: + superusers = ", ".join([f'"{s}"' for s in setting.superusers]) + env_text = re.sub(r"SUPERUSERS=\[.*?\]", f"SUPERUSERS=[{superusers}]", env_text) + if setting.host: + env_text = env_text.replace("HOST = 127.0.0.1", f"HOST = {setting.host}") + if setting.port: + env_text = env_text.replace("PORT = 8080", f"PORT = {setting.port}") + port = setting.port + if setting.username: + Config.set_config("web-ui", "username", setting.username) + Config.set_config("web-ui", "password", setting.password, True) + env_file.write_text(env_text, encoding="utf-8") + if BAT_FILE.exists(): + for file in os.listdir(Path()): + if file.startswith(FILE_NAME): + Path(file).unlink() + flag_file = Path() / f"{FILE_NAME}_{int(time.time())}" + flag_file.touch() + return Result.ok(BAT_FILE.exists(), info="设置成功,请重启真寻以完成配置!") + + +@router.get( + "/test_db", + response_model=Result, + response_class=JSONResponse, + description="设置基础配置", +) +async def _(db_url: str) -> Result: + result = await test_db_connection(db_url) + if isinstance(result, str): + return Result.fail(result) + 
return Result.ok(info="数据库连接成功!") + + +async def run_restart_command(bat_path: Path, port: int): + """在后台执行重启命令""" + await asyncio.sleep(1) # 确保 FastAPI 已返回响应 + subprocess.Popen([bat_path, str(port)], shell=True) # noqa: ASYNC220 + sys.exit(0) # 退出当前进程 + + +@router.post( + "/restart", + response_model=Result, + response_class=JSONResponse, + description="重启", +) +async def _() -> Result: + if not BAT_FILE.exists(): + return Result.fail("自动重启仅支持意见整合包,请尝试手动重启") + flag_file = next( + (Path() / file for file in os.listdir(Path()) if file.startswith(FILE_NAME)), + None, + ) + if not flag_file or not flag_file.exists(): + return Result.fail("重启标志文件不存在...") + set_time = flag_file.name.split("_")[-1] + if time.time() - float(set_time) > 10 * 60: + return Result.fail("重启标志文件已过期,请重新设置配置。") + flag_file.unlink() + try: + return Result.ok(info="执行重启命令成功") + finally: + asyncio.create_task(run_restart_command(BAT_FILE, port)) # noqa: RUF006 diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py b/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py new file mode 100644 index 00000000..ad8c73c9 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py @@ -0,0 +1,18 @@ +from tortoise import Tortoise + + +async def test_db_connection(db_url: str) -> bool | str: + try: + # 初始化 Tortoise ORM + await Tortoise.init( + db_url=db_url, + modules={"models": ["__main__"]}, # 这里不需要实际模型 + ) + # 测试连接 + await Tortoise.get_connection("default").execute_query("SELECT 1") + return True + except Exception as e: + return str(e) + finally: + # 关闭连接 + await Tortoise.close_connections() diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/model.py b/zhenxun/builtin_plugins/web_ui/api/configure/model.py new file mode 100644 index 00000000..4a6b3486 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/model.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel + + +class Setting(BaseModel): + superusers: list[str] + """超级用户列表""" + db_url: str + 
"""数据库地址""" + host: str + """主机地址""" + port: int + """端口""" + username: str + """前端用户名""" + password: str + """前端密码""" diff --git a/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py b/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py index 9cfcd244..e54bf9e5 100644 --- a/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py @@ -5,51 +5,63 @@ from zhenxun.services.log import logger from .model import MenuData, MenuItem +default_menus = [ + MenuItem( + name="仪表盘", + module="dashboard", + router="/dashboard", + icon="dashboard", + default=True, + ), + MenuItem( + name="真寻控制台", + module="command", + router="/command", + icon="command", + ), + MenuItem(name="插件列表", module="plugin", router="/plugin", icon="plugin"), + MenuItem(name="插件商店", module="store", router="/store", icon="store"), + MenuItem(name="好友/群组", module="manage", router="/manage", icon="user"), + MenuItem( + name="数据库管理", + module="database", + router="/database", + icon="database", + ), + MenuItem(name="系统信息", module="system", router="/system", icon="system"), + MenuItem(name="关于我们", module="about", router="/about", icon="about"), +] -class MenuManage: + +class MenuManager: def __init__(self) -> None: self.file = DATA_PATH / "web_ui" / "menu.json" self.menu = [] if self.file.exists(): try: + temp_menu = [] self.menu = json.load(self.file.open(encoding="utf8")) + self_menu_name = [menu["name"] for menu in self.menu] + for module in [m.module for m in default_menus]: + if module in self_menu_name: + temp_menu.append( + MenuItem( + **next(m for m in self.menu if m["module"] == module) + ) + ) + else: + temp_menu.append(self.__get_menu_model(module)) + self.menu = temp_menu except Exception as e: logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e) if not self.menu: - self.menu = [ - MenuItem( - name="仪表盘", - module="dashboard", - router="/dashboard", - icon="dashboard", - default=True, - ), - MenuItem( - name="真寻控制台", - module="command", - 
router="/command", - icon="command", - ), - MenuItem( - name="插件列表", module="plugin", router="/plugin", icon="plugin" - ), - MenuItem( - name="插件商店", module="store", router="/store", icon="store" - ), - MenuItem( - name="好友/群组", module="manage", router="/manage", icon="user" - ), - MenuItem( - name="数据库管理", - module="database", - router="/database", - icon="database", - ), - MenuItem( - name="系统信息", module="system", router="/system", icon="system" - ), - ] - self.save() + self.menu = default_menus + self.save() + + def __get_menu_model(self, module: str): + return default_menus[ + next(i for i, m in enumerate(default_menus) if m.module == module) + ] def get_menus(self): return MenuData(menus=self.menu) @@ -61,4 +73,4 @@ class MenuManage: json.dump(temp, f, ensure_ascii=False, indent=4) -menu_manage = MenuManage() +menu_manage = MenuManager() diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py index 6c312db3..87011c93 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py @@ -13,6 +13,7 @@ from zhenxun.models.bot_connect_log import BotConnectLog from zhenxun.models.chat_history import ChatHistory from zhenxun.models.statistics import Statistics from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils from ....base_model import BaseResultModel, QueryModel @@ -31,7 +32,7 @@ driver: Driver = nonebot.get_driver() CONNECT_TIME = 0 -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): global CONNECT_TIME CONNECT_TIME = int(time.time()) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py index b963e291..91fbc5c0 100644 --- 
a/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py @@ -8,6 +8,7 @@ from zhenxun.configs.config import BotConfig from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from ....base_model import BaseResultModel, QueryModel, Result from ....utils import authentication @@ -21,7 +22,7 @@ router = APIRouter(prefix="/database") driver: Driver = nonebot.get_driver() -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): for plugin in nonebot.get_loaded_plugins(): module = plugin.name diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/main/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/main/__init__.py index 36059101..f93d0ab1 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/main/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/main/__init__.py @@ -16,7 +16,7 @@ from zhenxun.utils.platform import PlatformUtils from ....base_model import Result from ....config import QueryDateType -from ....utils import authentication, get_system_status +from ....utils import authentication, clear_help_image, get_system_status from .data_source import ApiDataSource from .model import ( ActiveGroup, @@ -234,6 +234,7 @@ async def _(param: BotManageUpdateParam): bot_data.block_plugins = CommonUtils.convert_module_format(param.block_plugins) bot_data.block_tasks = CommonUtils.convert_module_format(param.block_tasks) await bot_data.save(update_fields=["block_plugins", "block_tasks"]) + clear_help_image() return Result.ok() except Exception as e: logger.error(f"{router.prefix}/update_bot_manage 调用错误", "WebUi", e=e) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py index 40aa5f18..2a783b22 100644 --- 
a/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py @@ -92,7 +92,7 @@ class ApiDataSource: """ version_file = Path() / "__version__" if version_file.exists(): - if text := version_file.open().read(): + if text := version_file.open(encoding="utf-8").read(): return text.replace("__version__: ", "").strip() return "unknown" diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/manage/chat.py b/zhenxun/builtin_plugins/web_ui/api/tabs/manage/chat.py index d20149fb..389546ca 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/manage/chat.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/manage/chat.py @@ -1,3 +1,5 @@ +from datetime import datetime + from fastapi import APIRouter import nonebot from nonebot import on_message @@ -49,13 +51,14 @@ async def message_handle( message: UniMsg, group_id: str | None, ): + time = str(datetime.now().replace(microsecond=0)) messages = [] for m in message: if isinstance(m, Text | str): - messages.append(MessageItem(type="text", msg=str(m))) + messages.append(MessageItem(type="text", msg=str(m), time=time)) elif isinstance(m, Image): if m.url: - messages.append(MessageItem(type="img", msg=m.url)) + messages.append(MessageItem(type="img", msg=m.url, time=time)) elif isinstance(m, At): if group_id: if m.target == "0": @@ -72,9 +75,9 @@ async def message_handle( uname = group_user.user_name if m.target not in ID2NAME[group_id]: ID2NAME[group_id][m.target] = uname - messages.append(MessageItem(type="at", msg=f"@{uname}")) + messages.append(MessageItem(type="at", msg=f"@{uname}", time=time)) elif isinstance(m, Hyper): - messages.append(MessageItem(type="text", msg="[分享消息]")) + messages.append(MessageItem(type="text", msg="[分享消息]", time=time)) return messages diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/manage/model.py b/zhenxun/builtin_plugins/web_ui/api/tabs/manage/model.py index 7149cee1..68772d0f 100644 --- 
a/zhenxun/builtin_plugins/web_ui/api/tabs/manage/model.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/manage/model.py @@ -237,6 +237,8 @@ class MessageItem(BaseModel): """消息类型""" msg: str """内容""" + time: str + """发送日期""" class Message(BaseModel): diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py index e011e67f..9dd134a4 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py @@ -6,13 +6,16 @@ from zhenxun.services.log import logger from zhenxun.utils.enum import BlockType, PluginType from ....base_model import Result -from ....utils import authentication +from ....utils import authentication, clear_help_image from .data_source import ApiDataSource from .model import ( + BatchUpdatePlugins, + BatchUpdateResult, PluginCount, PluginDetail, PluginInfo, PluginSwitch, + RenameMenuTypePayload, UpdatePlugin, ) @@ -30,9 +33,8 @@ async def _( plugin_type: list[PluginType] = Query(None), menu_type: str | None = None ) -> Result[list[PluginInfo]]: try: - return Result.ok( - await ApiDataSource.get_plugin_list(plugin_type, menu_type), "拿到信息啦!" 
- ) + result = await ApiDataSource.get_plugin_list(plugin_type, menu_type) + return Result.ok(result, "拿到信息啦!") except Exception as e: logger.error(f"{router.prefix}/get_plugin_list 调用错误", "WebUi", e=e) return Result.fail(f"发生了一点错误捏 {type(e)}: {e}") @@ -78,6 +80,7 @@ async def _() -> Result[PluginCount]: async def _(param: UpdatePlugin) -> Result: try: await ApiDataSource.update_plugin(param) + clear_help_image() return Result.ok(info="已经帮你写好啦!") except (ValueError, KeyError): return Result.fail("插件数据不存在...") @@ -105,6 +108,7 @@ async def _(param: PluginSwitch) -> Result: db_plugin.block_type = None db_plugin.status = True await db_plugin.save() + clear_help_image() return Result.ok(info="成功改变了开关状态!") except Exception as e: logger.error(f"{router.prefix}/change_switch 调用错误", "WebUi", e=e) @@ -144,11 +148,68 @@ async def _() -> Result[list[str]]: ) async def _(module: str) -> Result[PluginDetail]: try: - return Result.ok( - await ApiDataSource.get_plugin_detail(module), "已经帮你写好啦!" - ) + detail = await ApiDataSource.get_plugin_detail(module) + return Result.ok(detail, "已经帮你写好啦!") except (ValueError, KeyError): return Result.fail("插件数据不存在...") except Exception as e: logger.error(f"{router.prefix}/get_plugin 调用错误", "WebUi", e=e) return Result.fail(f"{type(e)}: {e}") + + +@router.put( + "/plugins/batch_update", + dependencies=[authentication()], + response_model=Result[BatchUpdateResult], + response_class=JSONResponse, + summary="批量更新插件配置", +) +async def batch_update_plugin_config_api( + params: BatchUpdatePlugins, +) -> Result[BatchUpdateResult]: + """批量更新插件配置,如开关、类型等""" + try: + result_dict = await ApiDataSource.batch_update_plugins(params=params) + result_model = BatchUpdateResult( + success=result_dict["success"], + updated_count=result_dict["updated_count"], + errors=result_dict["errors"], + ) + clear_help_image() + return Result.ok(result_model, "插件配置更新完成") + except Exception as e: + logger.error(f"{router.prefix}/plugins/batch_update 调用错误", "WebUi", e=e) + return 
Result.fail(f"发生了一点错误捏 {type(e)}: {e}") + + +# 新增:重命名菜单类型路由 +@router.put( + "/menu_type/rename", + dependencies=[authentication()], + response_model=Result, + summary="重命名菜单类型", +) +async def rename_menu_type_api(payload: RenameMenuTypePayload) -> Result: + try: + result = await ApiDataSource.rename_menu_type( + old_name=payload.old_name, new_name=payload.new_name + ) + if result.get("success"): + clear_help_image() + return Result.ok( + info=result.get( + "info", + f"成功将 {result.get('updated_count', 0)} 个插件的菜单类型从 " + f"'{payload.old_name}' 修改为 '{payload.new_name}'", + ) + ) + else: + return Result.fail(info=result.get("info", "重命名失败")) + except ValueError as ve: + return Result.fail(info=str(ve)) + except RuntimeError as re: + logger.error(f"{router.prefix}/menu_type/rename 调用错误", "WebUi", e=re) + return Result.fail(info=str(re)) + except Exception as e: + logger.error(f"{router.prefix}/menu_type/rename 调用错误", "WebUi", e=e) + return Result.fail(info=f"发生未知错误: {type(e).__name__}") diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/data_source.py index ee0992d6..0f2c3676 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/data_source.py @@ -2,13 +2,20 @@ import re import cattrs from fastapi import Query +from tortoise.exceptions import DoesNotExist from zhenxun.configs.config import Config from zhenxun.configs.utils import ConfigGroup from zhenxun.models.plugin_info import PluginInfo as DbPluginInfo from zhenxun.utils.enum import BlockType, PluginType -from .model import PluginConfig, PluginDetail, PluginInfo, UpdatePlugin +from .model import ( + BatchUpdatePlugins, + PluginConfig, + PluginDetail, + PluginInfo, + UpdatePlugin, +) class ApiDataSource: @@ -44,6 +51,11 @@ class ApiDataSource: level=plugin.level, status=plugin.status, author=plugin.author, + block_type=plugin.block_type, + 
is_builtin="builtin_plugins" in plugin.module_path + or plugin.plugin_type == PluginType.HIDDEN, + allow_setting=plugin.plugin_type != PluginType.HIDDEN, + allow_switch=plugin.plugin_type != PluginType.HIDDEN, ) plugin_list.append(plugin_info) return plugin_list @@ -69,7 +81,6 @@ class ApiDataSource: db_plugin.block_type = param.block_type db_plugin.status = param.block_type != BlockType.ALL await db_plugin.save() - # 配置项 if param.configs and (configs := Config.get(param.module)): for key in param.configs: if c := configs.configs.get(key): @@ -80,6 +91,87 @@ class ApiDataSource: Config.save(save_simple_data=True) return db_plugin + @classmethod + async def batch_update_plugins(cls, params: BatchUpdatePlugins) -> dict: + """批量更新插件数据 + + 参数: + params: BatchUpdatePlugins + + 返回: + dict: 更新结果, 例如 {'success': True, 'updated_count': 5, 'errors': []} + """ + plugins_to_update_other_fields = [] + other_update_fields = set() + updated_count = 0 + errors = [] + + for item in params.updates: + try: + db_plugin = await DbPluginInfo.get(module=item.module) + plugin_changed_other = False + plugin_changed_block = False + + if db_plugin.block_type != item.block_type: + db_plugin.block_type = item.block_type + db_plugin.status = item.block_type != BlockType.ALL + plugin_changed_block = True + + if item.menu_type is not None and db_plugin.menu_type != item.menu_type: + db_plugin.menu_type = item.menu_type + other_update_fields.add("menu_type") + plugin_changed_other = True + + if ( + item.default_status is not None + and db_plugin.default_status != item.default_status + ): + db_plugin.default_status = item.default_status + other_update_fields.add("default_status") + plugin_changed_other = True + + if plugin_changed_block: + try: + await db_plugin.save(update_fields=["block_type", "status"]) + updated_count += 1 + except Exception as e_save: + errors.append( + { + "module": item.module, + "error": f"Save block_type failed: {e_save!s}", + } + ) + plugin_changed_other = False + + if 
plugin_changed_other: + plugins_to_update_other_fields.append(db_plugin) + + except DoesNotExist: + errors.append({"module": item.module, "error": "Plugin not found"}) + except Exception as e: + errors.append({"module": item.module, "error": str(e)}) + + bulk_updated_count = 0 + if plugins_to_update_other_fields and other_update_fields: + try: + await DbPluginInfo.bulk_update( + plugins_to_update_other_fields, list(other_update_fields) + ) + bulk_updated_count = len(plugins_to_update_other_fields) + except Exception as e_bulk: + errors.append( + { + "module": "batch_update_other", + "error": f"Bulk update failed: {e_bulk!s}", + } + ) + + return { + "success": len(errors) == 0, + "updated_count": updated_count + bulk_updated_count, + "errors": errors, + } + @classmethod def __build_plugin_config( cls, module: str, cfg: str, config: ConfigGroup @@ -115,6 +207,41 @@ class ApiDataSource: type_inner=type_inner, # type: ignore ) + @classmethod + async def rename_menu_type(cls, old_name: str, new_name: str) -> dict: + """重命名菜单类型,并更新所有相关插件 + + 参数: + old_name: 旧菜单类型名称 + new_name: 新菜单类型名称 + + 返回: + dict: 更新结果, 例如 {'success': True, 'updated_count': 3} + """ + if not old_name or not new_name: + raise ValueError("旧名称和新名称都不能为空") + if old_name == new_name: + return { + "success": True, + "updated_count": 0, + "info": "新旧名称相同,无需更新", + } + + # 检查新名称是否已存在(理论上前端会校验,后端再保险一次) + exists = await DbPluginInfo.filter(menu_type=new_name).exists() + if exists: + raise ValueError(f"新的菜单类型名称 '{new_name}' 已被其他插件使用") + + try: + # 使用 filter().update() 进行批量更新 + updated_count = await DbPluginInfo.filter(menu_type=old_name).update( + menu_type=new_name + ) + return {"success": True, "updated_count": updated_count} + except Exception as e: + # 可以添加更详细的日志记录 + raise RuntimeError(f"数据库更新菜单类型失败: {e!s}") + @classmethod async def get_plugin_detail(cls, module: str) -> PluginDetail: """获取插件详情 diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/model.py 
b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/model.py index 662814c9..579f3104 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/model.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/model.py @@ -1,6 +1,6 @@ from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from zhenxun.utils.enum import BlockType @@ -37,19 +37,19 @@ class UpdatePlugin(BaseModel): module: str """模块""" default_status: bool - """默认开关""" + """是否默认开启""" limit_superuser: bool - """限制超级用户""" - cost_gold: int - """金币花费""" - menu_type: str - """插件菜单类型""" + """是否限制超级用户""" level: int - """插件所需群权限""" + """等级""" + cost_gold: int + """花费金币""" + menu_type: str + """菜单类型""" block_type: BlockType | None = None """禁用类型""" configs: dict[str, Any] | None = None - """配置项""" + """设置项""" class PluginInfo(BaseModel): @@ -58,27 +58,33 @@ class PluginInfo(BaseModel): """ module: str - """插件名称""" + """模块""" plugin_name: str - """插件中文名称""" + """插件名称""" default_status: bool - """默认开关""" + """是否默认开启""" limit_superuser: bool - """限制超级用户""" + """是否限制超级用户""" + level: int + """等级""" cost_gold: int """花费金币""" menu_type: str - """插件菜单类型""" + """菜单类型""" version: str - """插件版本""" - level: int - """群权限""" + """版本""" status: bool - """当前状态""" + """状态""" author: str | None = None """作者""" - block_type: BlockType | None = None - """禁用类型""" + block_type: BlockType | None = Field(None, description="插件禁用状态 (None: 启用)") + """禁用状态""" + is_builtin: bool = False + """是否为内置插件""" + allow_switch: bool = True + """是否允许开关""" + allow_setting: bool = True + """是否允许设置""" class PluginConfig(BaseModel): @@ -86,20 +92,13 @@ class PluginConfig(BaseModel): 插件配置项 """ - module: str - """模块""" - key: str - """键""" - value: Any - """值""" - help: str | None = None - """帮助""" - default_value: Any - """默认值""" - type: Any = None - """值类型""" - type_inner: list[str] | None = None - """List Tuple等内部类型检验""" + module: str = Field(..., description="模块名") + key: str = Field(..., 
description="键") + value: Any = Field(None, description="值") + help: str | None = Field(None, description="帮助信息") + default_value: Any = Field(None, description="默认值") + type: str | None = Field(None, description="类型") + type_inner: list[str] | None = Field(None, description="内部类型") class PluginCount(BaseModel): @@ -117,6 +116,21 @@ class PluginCount(BaseModel): """其他插件""" +class BatchUpdatePluginItem(BaseModel): + module: str = Field(..., description="插件模块名") + default_status: bool | None = Field(None, description="默认状态(开关)") + menu_type: str | None = Field(None, description="菜单类型") + block_type: BlockType | None = Field( + None, description="插件禁用状态 (None: 启用, ALL: 禁用)" + ) + + +class BatchUpdatePlugins(BaseModel): + updates: list[BatchUpdatePluginItem] = Field( + ..., description="要批量更新的插件列表" + ) + + class PluginDetail(PluginInfo): """ 插件详情 @@ -125,6 +139,26 @@ class PluginDetail(PluginInfo): config_list: list[PluginConfig] +class RenameMenuTypePayload(BaseModel): + old_name: str = Field(..., description="旧菜单类型名称") + new_name: str = Field(..., description="新菜单类型名称") + + class PluginIr(BaseModel): id: int """插件id""" + + +class BatchUpdateResult(BaseModel): + """ + 批量更新插件结果 + """ + + success: bool = Field(..., description="是否全部成功") + """是否全部成功""" + updated_count: int = Field(..., description="更新成功的数量") + """更新成功的数量""" + errors: list[dict[str, str]] = Field( + default_factory=list, description="错误信息列表" + ) + """错误信息列表""" diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py index acff6356..9ee6ff41 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from fastapi.responses import JSONResponse from nonebot import require +from nonebot.compat import model_dump from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.log import 
logger @@ -22,12 +23,12 @@ router = APIRouter(prefix="/store") async def _() -> Result[dict]: try: require("plugin_store") - from zhenxun.builtin_plugins.plugin_store import ShopManage + from zhenxun.builtin_plugins.plugin_store import StoreManager - data = await ShopManage.get_data() + data = await StoreManager.get_data() plugin_list = [ - {**data[name].to_dict(), "name": name, "id": idx} - for idx, name in enumerate(data) + {**model_dump(plugin), "name": plugin.name, "id": idx} + for idx, plugin in enumerate(data) ] modules = await PluginInfo.filter(load_status=True).values_list( "module", flat=True @@ -48,9 +49,9 @@ async def _() -> Result[dict]: async def _(param: PluginIr) -> Result: try: require("plugin_store") - from zhenxun.builtin_plugins.plugin_store import ShopManage + from zhenxun.builtin_plugins.plugin_store import StoreManager - result = await ShopManage.add_plugin(param.id) # type: ignore + result = await StoreManager.add_plugin(param.id) # type: ignore return Result.ok(info=result) except Exception as e: return Result.fail(f"安装插件失败: {type(e)}: {e}") @@ -66,9 +67,9 @@ async def _(param: PluginIr) -> Result: async def _(param: PluginIr) -> Result: try: require("plugin_store") - from zhenxun.builtin_plugins.plugin_store import ShopManage + from zhenxun.builtin_plugins.plugin_store import StoreManager - result = await ShopManage.update_plugin(param.id) # type: ignore + result = await StoreManager.update_plugin(param.id) # type: ignore return Result.ok(info=result) except Exception as e: return Result.fail(f"更新插件失败: {type(e)}: {e}") @@ -84,9 +85,9 @@ async def _(param: PluginIr) -> Result: async def _(param: PluginIr) -> Result: try: require("plugin_store") - from zhenxun.builtin_plugins.plugin_store import ShopManage + from zhenxun.builtin_plugins.plugin_store import StoreManager - result = await ShopManage.remove_plugin(param.id) # type: ignore + result = await StoreManager.remove_plugin(param.id) # type: ignore return Result.ok(info=result) except 
Exception as e: return Result.fail(f"移除插件失败: {type(e)}: {e}") diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py index e05115df..b8ae2481 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from zhenxun.utils._build_image import BuildImage from ....base_model import Result, SystemFolderSize -from ....utils import authentication, get_system_disk +from ....utils import authentication, get_system_disk, validate_path from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile router = APIRouter(prefix="/system") @@ -25,21 +25,30 @@ IMAGE_TYPE = ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg"] description="获取文件列表", ) async def _(path: str | None = None) -> Result[list[DirFile]]: - base_path = Path(path) if path else Path() - data_list = [] - for file in os.listdir(base_path): - file_path = base_path / file - is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE) - data_list.append( - DirFile( - is_file=not file_path.is_dir(), - is_image=is_image, - name=file, - parent=path, + try: + base_path, error = validate_path(path) + if error: + return Result.fail(error) + if not base_path: + return Result.fail("无效的路径") + data_list = [] + for file in os.listdir(base_path): + file_path = base_path / file + is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE) + data_list.append( + DirFile( + is_file=not file_path.is_dir(), + is_image=is_image, + name=file, + parent=path, + size=None if file_path.is_dir() else file_path.stat().st_size, + mtime=file_path.stat().st_mtime, + ) ) - ) - sorted(data_list, key=lambda f: f.name) - return Result.ok(data_list) + sorted(data_list, key=lambda f: f.name) + return Result.ok(data_list) + except Exception as e: + return Result.fail(f"获取文件列表失败: {e!s}") @router.get( @@ -61,8 +70,12 @@ async def _(full_path: 
str | None = None) -> Result[list[SystemFolderSize]]: description="删除文件", ) async def _(param: DeleteFile) -> Result: - path = Path(param.full_path) - if not path or not path.exists(): + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") + if not path.exists(): return Result.warning_("文件不存在...") try: path.unlink() @@ -79,8 +92,12 @@ async def _(param: DeleteFile) -> Result: description="删除文件夹", ) async def _(param: DeleteFile) -> Result: - path = Path(param.full_path) - if not path or not path.exists() or path.is_file(): + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") + if not path.exists() or path.is_file(): return Result.warning_("文件夹不存在...") try: shutil.rmtree(path.absolute()) @@ -97,10 +114,14 @@ async def _(param: DeleteFile) -> Result: description="重命名文件", ) async def _(param: RenameFile) -> Result: - path = ( - (Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) - ) - if not path or not path.exists(): + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.old_name) if param.parent else Path(param.old_name) + if not path.exists(): return Result.warning_("文件不存在...") try: path.rename(path.parent / param.name) @@ -117,10 +138,14 @@ async def _(param: RenameFile) -> Result: description="重命名文件夹", ) async def _(param: RenameFile) -> Result: - path = ( - (Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) - ) - if not path or not path.exists() or path.is_file(): + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.old_name) if param.parent else Path(param.old_name) + if not path.exists() or path.is_file(): return 
Result.warning_("文件夹不存在...") try: new_path = path.parent / param.name @@ -138,7 +163,13 @@ async def _(param: RenameFile) -> Result: description="新建文件", ) async def _(param: AddFile) -> Result: - path = (Path(param.parent) / param.name) if param.parent else Path(param.name) + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.name) if param.parent else Path(param.name) if path.exists(): return Result.warning_("文件已存在...") try: @@ -156,7 +187,13 @@ async def _(param: AddFile) -> Result: description="新建文件夹", ) async def _(param: AddFile) -> Result: - path = (Path(param.parent) / param.name) if param.parent else Path(param.name) + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.name) if param.parent else Path(param.name) if path.exists(): return Result.warning_("文件夹已存在...") try: @@ -174,7 +211,11 @@ async def _(param: AddFile) -> Result: description="读取文件", ) async def _(full_path: str) -> Result: - path = Path(full_path) + path, error = validate_path(full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") if not path.exists(): return Result.warning_("文件不存在...") try: @@ -192,9 +233,13 @@ async def _(full_path: str) -> Result: description="读取文件", ) async def _(param: SaveFile) -> Result[str]: - path = Path(param.full_path) + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") try: - async with aiofiles.open(path, "w", encoding="utf-8") as f: + async with aiofiles.open(str(path), "w", encoding="utf-8") as f: await f.write(param.content) return Result.ok("更新成功!") except Exception as e: @@ -209,10 +254,24 @@ async def _(param: SaveFile) -> Result[str]: description="读取图片base64", ) async def _(full_path: str) -> 
Result[str]: - path = Path(full_path) + path, error = validate_path(full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") if not path.exists(): return Result.warning_("文件不存在...") try: return Result.ok(BuildImage.open(path).pic2bs4()) except Exception as e: return Result.warning_(f"获取图片失败: {e!s}") + + +@router.get( + "/ping", + response_model=Result[str], + response_class=JSONResponse, + description="检查服务器状态", +) +async def _() -> Result[str]: + return Result.ok("pong") diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/system/model.py b/zhenxun/builtin_plugins/web_ui/api/tabs/system/model.py index 3c2357f2..2959a0e1 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/system/model.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/system/model.py @@ -14,6 +14,10 @@ class DirFile(BaseModel): """文件夹或文件名称""" parent: str | None = None """父级""" + size: int | None = None + """文件大小""" + mtime: float | None = None + """修改时间""" class DeleteFile(BaseModel): diff --git a/zhenxun/builtin_plugins/web_ui/config.py b/zhenxun/builtin_plugins/web_ui/config.py index bddcb062..4a88aad9 100644 --- a/zhenxun/builtin_plugins/web_ui/config.py +++ b/zhenxun/builtin_plugins/web_ui/config.py @@ -1,6 +1,12 @@ +import sys + from fastapi.middleware.cors import CORSMiddleware import nonebot -from strenum import StrEnum + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH diff --git a/zhenxun/builtin_plugins/web_ui/public/data_source.py b/zhenxun/builtin_plugins/web_ui/public/data_source.py index 9f5a657e..51b29533 100644 --- a/zhenxun/builtin_plugins/web_ui/public/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/public/data_source.py @@ -18,6 +18,7 @@ async def update_webui_assets(): download_url = await GithubUtils.parse_github_url( WEBUI_DIST_GITHUB_URL ).get_archive_download_urls() + logger.info("开始下载 webui_assets 资源...", COMMAND_NAME) if 
await AsyncHttpx.download_file( download_url, webui_assets_path, follow_redirects=True ): diff --git a/zhenxun/builtin_plugins/web_ui/utils.py b/zhenxun/builtin_plugins/web_ui/utils.py index df2fdd35..84459114 100644 --- a/zhenxun/builtin_plugins/web_ui/utils.py +++ b/zhenxun/builtin_plugins/web_ui/utils.py @@ -2,6 +2,7 @@ import contextlib from datetime import datetime, timedelta, timezone import os from pathlib import Path +import re from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer @@ -11,7 +12,7 @@ import psutil import ujson as json from zhenxun.configs.config import Config -from zhenxun.configs.path_config import DATA_PATH +from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH from .base_model import SystemFolderSize, SystemStatus, User @@ -28,6 +29,61 @@ if token_file.exists(): token_data = json.load(open(token_file, encoding="utf8")) +def validate_path(path_str: str | None) -> tuple[Path | None, str | None]: + """验证路径是否安全 + + 参数: + path_str: 用户输入的路径 + + 返回: + tuple[Path | None, str | None]: (验证后的路径, 错误信息) + """ + try: + if not path_str: + return Path().resolve(), None + + # 1. 移除任何可能的路径遍历尝试 + path_str = re.sub(r"[\\/]\.\.[\\/]", "", path_str) + + # 2. 规范化路径并转换为绝对路径 + path = Path(path_str).resolve() + + # 3. 获取项目根目录 + root_dir = Path().resolve() + + # 4. 验证路径是否在项目根目录内 + try: + if not path.is_relative_to(root_dir): + return None, "访问路径超出允许范围" + except ValueError: + return None, "无效的路径格式" + + # 5. 验证路径是否包含任何危险字符 + if any(c in str(path) for c in ["..", "~", "*", "?", ">", "<", "|", '"']): + return None, "路径包含非法字符" + + # 6. 
验证路径长度是否合理 + return (None, "路径长度超出限制") if len(str(path)) > 4096 else (path, None) + except Exception as e: + return None, f"路径验证失败: {e!s}" + + +GROUP_HELP_PATH = DATA_PATH / "group_help" +SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png" +SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png" + + +def clear_help_image(): + """清理帮助图片""" + if SIMPLE_HELP_IMAGE.exists(): + SIMPLE_HELP_IMAGE.unlink() + if SIMPLE_DETAIL_HELP_IMAGE.exists(): + SIMPLE_DETAIL_HELP_IMAGE.unlink() + for file in GROUP_HELP_PATH.iterdir(): + if file.is_file(): + file.unlink() + + def get_user(uname: str) -> User | None: """获取账号密码 diff --git a/zhenxun/configs/utils/__init__.py b/zhenxun/configs/utils/__init__.py index 03bc7331..bd84d9b1 100644 --- a/zhenxun/configs/utils/__init__.py +++ b/zhenxun/configs/utils/__init__.py @@ -1,89 +1,82 @@ from collections.abc import Callable import copy -from datetime import datetime from pathlib import Path -from typing import Any, Literal +from typing import Any, TypeVar, get_args, get_origin import cattrs from nonebot.compat import model_dump -from pydantic import BaseModel, Field +from pydantic import VERSION, BaseModel, Field from ruamel.yaml import YAML from ruamel.yaml.scanner import ScannerError from zhenxun.configs.path_config import DATA_PATH from zhenxun.services.log import logger -from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType + +from .models import ( + AICallableParam, + AICallableProperties, + AICallableTag, + BaseBlock, + Command, + ConfigModel, + Example, + PluginCdBlock, + PluginCountBlock, + PluginExtraData, + PluginSetting, + RegisterConfig, + Task, +) _yaml = YAML(pure=True) _yaml.indent = 2 _yaml.allow_unicode = True +T = TypeVar("T") -class Example(BaseModel): + +class NoSuchConfig(Exception): + pass + + +def _dump_pydantic_obj(obj: Any) -> Any: """ - 示例 + 递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典。 + 支持单个实例、实例列表、实例字典等情况。 """ - - exec: str - """执行命令""" - description: str = "" - """命令描述""" + if 
isinstance(obj, BaseModel): + return model_dump(obj) + if isinstance(obj, list): + return [_dump_pydantic_obj(item) for item in obj] + if isinstance(obj, dict): + return {key: _dump_pydantic_obj(value) for key, value in obj.items()} + return obj -class Command(BaseModel): +def _is_pydantic_type(t: Any) -> bool: """ - 具体参数说明 + 递归检查一个类型注解是否与 Pydantic BaseModel 相关。 """ - - command: str - """命令名称""" - params: list[str] = Field(default_factory=list) - """参数""" - description: str = "" - """描述""" - examples: list[Example] = Field(default_factory=list) - """示例列表""" + if t is None: + return False + origin = get_origin(t) + if origin: + return any(_is_pydantic_type(arg) for arg in get_args(t)) + return isinstance(t, type) and issubclass(t, BaseModel) -class RegisterConfig(BaseModel): +def parse_as(type_: type[T], obj: Any) -> T: """ - 注册配置项 + 一个兼容 Pydantic V1 的 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数。 """ + if VERSION.startswith("1"): + from pydantic import parse_obj_as - key: str - """配置项键""" - value: Any - """配置项值""" - module: str | None = None - """模块名""" - help: str | None - """配置注解""" - default_value: Any | None = None - """默认值""" - type: Any = None - """参数类型""" - arg_parser: Callable | None = None - """参数解析""" + return parse_obj_as(type_, obj) + else: + from pydantic import TypeAdapter # type: ignore - -class ConfigModel(BaseModel): - """ - 配置项 - """ - - value: Any - """配置项值""" - help: str | None - """配置注解""" - default_value: Any | None = None - """默认值""" - type: Any = None - """参数类型""" - arg_parser: Callable | None = None - """参数解析""" - - def to_dict(self, **kwargs): - return model_dump(self, **kwargs) + return TypeAdapter(type_).validate_python(obj) class ConfigGroup(BaseModel): @@ -98,202 +91,41 @@ class ConfigGroup(BaseModel): configs: dict[str, ConfigModel] = Field(default_factory=dict) """配置项列表""" - def get(self, c: str, default: Any = None) -> Any: - cfg = self.configs.get(c.upper()) - if cfg is not None: - if cfg.value is not None: - return cfg.value 
- if cfg.default_value is not None: - return cfg.default_value - return default + def get(self, c: str, default: Any = None, *, build_model: bool = True) -> Any: + """ + 获取配置项的值。如果指定了类型,会自动构建实例。 + """ + key = c.upper() + cfg = self.configs.get(key) + + if cfg is None: + return default + + value_to_process = cfg.value if cfg.value is not None else cfg.default_value + + if value_to_process is None: + return default + + if cfg.type: + if _is_pydantic_type(cfg.type): + if build_model: + try: + return parse_as(cfg.type, value_to_process) + except Exception as e: + logger.warning( + f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e + ) + try: + return cattrs.structure(value_to_process, cfg.type) + except Exception as e: + logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e) + + return value_to_process def to_dict(self, **kwargs): return model_dump(self, **kwargs) -class BaseBlock(BaseModel): - """ - 插件阻断基本类(插件阻断限制) - """ - - status: bool = True - """限制状态""" - check_type: BlockType = BlockType.ALL - """检查类型""" - watch_type: LimitWatchType = LimitWatchType.USER - """监听对象""" - result: str | None = None - """阻断时回复内容""" - _type: PluginLimitType = PluginLimitType.BLOCK - """类型""" - - def to_dict(self, **kwargs): - return model_dump(self, **kwargs) - - -class PluginCdBlock(BaseBlock): - """ - 插件cd限制 - """ - - cd: int = 5 - """cd""" - _type: PluginLimitType = PluginLimitType.CD - """类型""" - - -class PluginCountBlock(BaseBlock): - """ - 插件次数限制 - """ - - max_count: int - """最大调用次数""" - _type: PluginLimitType = PluginLimitType.COUNT - """类型""" - - -class PluginSetting(BaseModel): - """ - 插件基本配置 - """ - - level: int = 5 - """群权限等级""" - default_status: bool = True - """进群默认开关状态""" - limit_superuser: bool = False - """是否限制超级用户""" - cost_gold: int = 0 - """调用插件花费金币""" - impression: float = 0.0 - """调用插件好感度限制""" - - -class AICallableProperties(BaseModel): - type: str - """参数类型""" - description: str - """参数描述""" - enums: list[str] | None = None - """参数枚举""" - - -class 
AICallableParam(BaseModel): - type: str - """类型""" - properties: dict[str, AICallableProperties] - """参数列表""" - required: list[str] - """必要参数""" - - -class AICallableTag(BaseModel): - name: str - """工具名称""" - parameters: AICallableParam | None = None - """工具参数""" - description: str - """工具描述""" - func: Callable | None = None - """工具函数""" - - def to_dict(self): - result = model_dump(self) - del result["func"] - return result - - -class SchedulerModel(BaseModel): - trigger: Literal["date", "interval", "cron"] - """trigger""" - day: int | None = None - """天数""" - hour: int | None = None - """小时""" - minute: int | None = None - """分钟""" - second: int | None = None - """秒""" - run_date: datetime | None = None - """运行日期""" - id: str | None = None - """id""" - max_instances: int | None = None - """最大运行实例""" - args: list | None = None - """参数""" - kwargs: dict | None = None - """参数""" - - -class Task(BaseBlock): - module: str - """被动技能模块名""" - name: str - """被动技能名称""" - status: bool = True - """全局开关状态""" - create_status: bool = False - """初次加载默认开关状态""" - default_status: bool = True - """进群时默认状态""" - scheduler: SchedulerModel | None = None - """定时任务配置""" - run_func: Callable | None = None - """运行函数""" - check: Callable | None = None - """检查函数""" - check_args: list = Field(default_factory=list) - """检查函数参数""" - - -class PluginExtraData(BaseModel): - """ - 插件扩展信息 - """ - - author: str | None = None - """作者""" - version: str | None = None - """版本""" - plugin_type: PluginType = PluginType.NORMAL - """插件类型""" - menu_type: str = "功能" - """菜单类型""" - admin_level: int | None = None - """管理员插件所需权限等级""" - configs: list[RegisterConfig] | None = None - """插件配置""" - setting: PluginSetting | None = None - """插件基本配置""" - limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None - """插件限制""" - commands: list[Command] = Field(default_factory=list) - """命令列表,用于说明帮助""" - ignore_prompt: bool = False - """是否忽略阻断提示""" - tasks: list[Task] | None = None - """技能被动""" - superuser_help: 
str | None = None - """超级用户帮助""" - aliases: set[str] = Field(default_factory=set) - """额外名称""" - sql_list: list[str] | None = None - """常用sql""" - is_show: bool = True - """是否显示在菜单中""" - smart_tools: list[AICallableTag] | None = None - """智能模式函数工具集""" - - def to_dict(self, **kwargs): - return model_dump(self, **kwargs) - - -class NoSuchConfig(Exception): - pass - - class ConfigsManager: """ 插件配置 与 资源 管理器 @@ -366,23 +198,32 @@ class ConfigsManager: if not module or not key: raise ValueError("add_plugin_config: module和key不能为为空") + if isinstance(value, BaseModel): + value = model_dump(value) + if isinstance(default_value, BaseModel): + default_value = model_dump(default_value) + + processed_value = _dump_pydantic_obj(value) + processed_default_value = _dump_pydantic_obj(default_value) + self.add_module.append(f"{module}:{key}".lower()) if module in self._data and (config := self._data[module].configs.get(key)): config.help = help config.arg_parser = arg_parser config.type = type if _override: - config.value = value - config.default_value = default_value + config.value = processed_value + config.default_value = processed_default_value else: key = key.upper() if not self._data.get(module): self._data[module] = ConfigGroup(module=module) self._data[module].configs[key] = ConfigModel( - value=value, + value=processed_value, help=help, - default_value=default_value, + default_value=processed_default_value, type=type, + arg_parser=arg_parser, ) def set_config( @@ -402,6 +243,8 @@ class ConfigsManager: """ key = key.upper() if module in self._data: + if module not in self._simple_data: + self._simple_data[module] = {} if self._data[module].configs.get(key): self._data[module].configs[key].value = value else: @@ -410,63 +253,68 @@ class ConfigsManager: if auto_save: self.save(save_simple_data=True) - def get_config(self, module: str, key: str, default: Any = None) -> Any: - """获取指定配置值 - - 参数: - module: 模块名 - key: 配置键 - default: 没有key值内容的默认返回值. 
- - 异常: - NoSuchConfig: 未查询到配置 - - 返回: - Any: 配置值 + def get_config( + self, + module: str, + key: str, + default: Any = None, + *, + build_model: bool = True, + ) -> Any: + """ + 获取指定配置值,自动构建Pydantic模型或其它类型实例。 + - 兼容Pydantic V1/V2。 + - 支持 list[BaseModel] 等泛型容器。 + - 优先使用Pydantic原生方式解析,失败后回退到cattrs。 """ - logger.debug( - f"尝试获取配置MODULE: [{module}] | KEY: [{key}]" - ) key = key.upper() - value = None - if module in self._data.keys(): - config = self._data[module].configs.get(key) or self._data[ - module - ].configs.get(key) - if not config: - raise NoSuchConfig( - f"未查询到配置项 MODULE: [ {module} ] | KEY: [ {key} ]" - ) + config_group = self._data.get(module) + if not config_group: + return default + + config = config_group.configs.get(key) + if not config: + return default + + value_to_process = ( + config.value if config.value is not None else config.default_value + ) + if value_to_process is None: + return default + + # 1. 最高优先级:自定义的参数解析器 + if config.arg_parser: try: - if config.arg_parser: - value = config.arg_parser(value or config.default_value) - elif config.value is not None: - # try: - value = ( - cattrs.structure(config.value, config.type) - if config.type - else config.value - ) - elif config.default_value is not None: - value = ( - cattrs.structure(config.default_value, config.type) - if config.type - else config.default_value - ) + return config.arg_parser(value_to_process) except Exception as e: - logger.warning( + logger.debug( f"配置项类型转换 MODULE: [{module}]" - " | KEY: [{key}]", + f" | KEY: [{key}] 将使用原始值", e=e, ) - value = config.value or config.default_value - if value is None: - value = default - logger.debug( - f"获取配置 MODULE: [{module}] | " - f" KEY: [{key}] -> [{value}]" - ) - return value + + if config.type: + if _is_pydantic_type(config.type): + if build_model: + try: + return parse_as(config.type, value_to_process) + except Exception as e: + logger.warning( + f"pydantic类型转换失败 MODULE: [{module}] | " + f"KEY: [{key}].", + e=e, + ) + else: + try: + 
return cattrs.structure(value_to_process, config.type) + except Exception as e: + logger.warning( + f"cattrs类型转换失败 MODULE: [{module}] | " + f"KEY: [{key}].", + e=e, + ) + + return value_to_process def get(self, key: str) -> ConfigGroup: """获取插件配置数据 @@ -490,16 +338,16 @@ class ConfigsManager: with open(self._simple_file, "w", encoding="utf8") as f: _yaml.dump(self._simple_data, f) path = path or self.file - data = {} - for module in self._data: - data[module] = {} - for config in self._data[module].configs: - value = self._data[module].configs[config].dict() - del value["type"] - del value["arg_parser"] - data[module][config] = value + save_data = {} + for module, config_group in self._data.items(): + save_data[module] = {} + for config_key, config_model in config_group.configs.items(): + save_data[module][config_key] = model_dump( + config_model, exclude={"type", "arg_parser"} + ) + with open(path, "w", encoding="utf8") as f: - _yaml.dump(data, f) + _yaml.dump(save_data, f) def reload(self): """重新加载配置文件""" @@ -558,3 +406,23 @@ class ConfigsManager: def __getitem__(self, key): return self._data[key] + + +__all__ = [ + "AICallableParam", + "AICallableProperties", + "AICallableTag", + "BaseBlock", + "Command", + "ConfigGroup", + "ConfigModel", + "ConfigsManager", + "Example", + "NoSuchConfig", + "PluginCdBlock", + "PluginCountBlock", + "PluginExtraData", + "PluginSetting", + "RegisterConfig", + "Task", +] diff --git a/zhenxun/configs/utils/models.py b/zhenxun/configs/utils/models.py new file mode 100644 index 00000000..d3c0db7f --- /dev/null +++ b/zhenxun/configs/utils/models.py @@ -0,0 +1,270 @@ +from collections.abc import Callable +from datetime import datetime +from typing import Any, Literal + +from nonebot.compat import model_dump +from pydantic import BaseModel, Field + +from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType + +__all__ = [ + "AICallableParam", + "AICallableProperties", + "AICallableTag", + "BaseBlock", + 
"Command", + "ConfigModel", + "Example", + "PluginCdBlock", + "PluginCountBlock", + "PluginExtraData", + "PluginSetting", + "RegisterConfig", + "Task", +] + + +class Example(BaseModel): + """ + 示例 + """ + + exec: str + """执行命令""" + description: str = "" + """命令描述""" + + +class Command(BaseModel): + """ + 具体参数说明 + """ + + command: str + """命令名称""" + params: list[str] = Field(default_factory=list) + """参数""" + description: str = "" + """描述""" + examples: list[Example] = Field(default_factory=list) + """示例列表""" + + +class RegisterConfig(BaseModel): + """ + 注册配置项 + """ + + key: str + """配置项键""" + value: Any + """配置项值""" + module: str | None = None + """模块名""" + help: str | None + """配置注解""" + default_value: Any | None = None + """默认值""" + type: Any = None + """参数类型""" + arg_parser: Callable | None = None + """参数解析""" + + +class ConfigModel(BaseModel): + """ + 配置项 + """ + + value: Any + """配置项值""" + help: str | None + """配置注解""" + default_value: Any | None = None + """默认值""" + type: Any = None + """参数类型""" + arg_parser: Callable | None = None + """参数解析""" + + def to_dict(self, **kwargs): + return model_dump(self, **kwargs) + + +class BaseBlock(BaseModel): + """ + 插件阻断基本类(插件阻断限制) + """ + + status: bool = True + """限制状态""" + check_type: BlockType = BlockType.ALL + """检查类型""" + watch_type: LimitWatchType = LimitWatchType.USER + """监听对象""" + result: str | None = None + """阻断时回复内容""" + _type: PluginLimitType = PluginLimitType.BLOCK + """类型""" + + def to_dict(self, **kwargs): + return model_dump(self, **kwargs) + + +class PluginCdBlock(BaseBlock): + """ + 插件cd限制 + """ + + cd: int = 5 + """cd""" + _type: PluginLimitType = PluginLimitType.CD + """类型""" + + +class PluginCountBlock(BaseBlock): + """ + 插件次数限制 + """ + + max_count: int + """最大调用次数""" + _type: PluginLimitType = PluginLimitType.COUNT + """类型""" + + +class PluginSetting(BaseModel): + """ + 插件基本配置 + """ + + level: int = 5 + """群权限等级""" + default_status: bool = True + """进群默认开关状态""" + limit_superuser: bool = False + 
"""是否限制超级用户""" + cost_gold: int = 0 + """调用插件花费金币""" + impression: float = 0.0 + """调用插件好感度限制""" + + +class AICallableProperties(BaseModel): + type: str + """参数类型""" + description: str + """参数描述""" + enums: list[str] | None = None + """参数枚举""" + + +class AICallableParam(BaseModel): + type: str + """类型""" + properties: dict[str, AICallableProperties] + """参数列表""" + required: list[str] + """必要参数""" + + +class AICallableTag(BaseModel): + name: str + """工具名称""" + parameters: AICallableParam | None = None + """工具参数""" + description: str + """工具描述""" + func: Callable | None = None + """工具函数""" + + def to_dict(self): + result = model_dump(self) + del result["func"] + return result + + +class SchedulerModel(BaseModel): + trigger: Literal["date", "interval", "cron"] + """trigger""" + day: int | None = None + """天数""" + hour: int | None = None + """小时""" + minute: int | None = None + """分钟""" + second: int | None = None + """秒""" + run_date: datetime | None = None + """运行日期""" + id: str | None = None + """id""" + max_instances: int | None = None + """最大运行实例""" + args: list | None = None + """参数""" + kwargs: dict | None = None + """参数""" + + +class Task(BaseBlock): + module: str + """被动技能模块名""" + name: str + """被动技能名称""" + status: bool = True + """全局开关状态""" + create_status: bool = False + """初次加载默认开关状态""" + default_status: bool = True + """进群时默认状态""" + scheduler: SchedulerModel | None = None + """定时任务配置""" + run_func: Callable | None = None + """运行函数""" + check: Callable | None = None + """检查函数""" + check_args: list = Field(default_factory=list) + """检查函数参数""" + + +class PluginExtraData(BaseModel): + """ + 插件扩展信息 + """ + + author: str | None = None + """作者""" + version: str | None = None + """版本""" + plugin_type: PluginType = PluginType.NORMAL + """插件类型""" + menu_type: str = "功能" + """菜单类型""" + admin_level: int | None = None + """管理员插件所需权限等级""" + configs: list[RegisterConfig] | None = None + """插件配置""" + setting: PluginSetting | None = None + """插件基本配置""" + limits: 
list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None + """插件限制""" + commands: list[Command] = Field(default_factory=list) + """命令列表,用于说明帮助""" + ignore_prompt: bool = False + """是否忽略阻断提示""" + tasks: list[Task] | None = None + """技能被动""" + superuser_help: str | None = None + """超级用户帮助""" + aliases: set[str] = Field(default_factory=set) + """额外名称""" + sql_list: list[str] | None = None + """常用sql""" + is_show: bool = True + """是否显示在菜单中""" + smart_tools: list[AICallableTag] | None = None + """智能模式函数工具集""" + + def to_dict(self, **kwargs): + return model_dump(self, **kwargs) diff --git a/zhenxun/models/bot_message_store.py b/zhenxun/models/bot_message_store.py new file mode 100644 index 00000000..fa1244f9 --- /dev/null +++ b/zhenxun/models/bot_message_store.py @@ -0,0 +1,29 @@ +from tortoise import fields + +from zhenxun.services.db_context import Model +from zhenxun.utils.enum import BotSentType + + +class BotMessageStore(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + """自增id""" + bot_id = fields.CharField(255, null=True) + """bot id""" + user_id = fields.CharField(255, null=True) + """目标id""" + group_id = fields.CharField(255, null=True) + """群组id""" + sent_type = fields.CharEnumField(BotSentType) + """类型""" + text = fields.TextField(null=True) + """文本内容""" + plain_text = fields.TextField(null=True) + """纯文本""" + platform = fields.CharField(255, null=True) + """平台""" + create_time = fields.DatetimeField(auto_now_add=True) + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + table = "bot_message_store" + table_description = "Bot发送消息列表" diff --git a/zhenxun/models/event_log.py b/zhenxun/models/event_log.py new file mode 100644 index 00000000..6737f619 --- /dev/null +++ b/zhenxun/models/event_log.py @@ -0,0 +1,21 @@ +from tortoise import fields + +from zhenxun.services.db_context import Model +from zhenxun.utils.enum import EventLogType + + +class EventLog(Model): + id = fields.IntField(pk=True, 
generated=True, auto_increment=True) + """自增id""" + user_id = fields.CharField(255, description="用户id") + """用户id""" + group_id = fields.CharField(255, description="群组id") + """群组id""" + event_type = fields.CharEnumField(EventLogType, default=None, description="类型") + """类型""" + create_time = fields.DatetimeField(auto_now_add=True, description="创建时间") + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + table = "event_log" + table_description = "各种请求通知记录表" diff --git a/zhenxun/models/fg_request.py b/zhenxun/models/fg_request.py index 4aee1d73..4362a7d3 100644 --- a/zhenxun/models/fg_request.py +++ b/zhenxun/models/fg_request.py @@ -3,8 +3,10 @@ from typing_extensions import Self from nonebot.adapters import Bot from tortoise import fields +from zhenxun.configs.config import BotConfig from zhenxun.models.group_console import GroupConsole from zhenxun.services.db_context import Model +from zhenxun.utils.common_utils import SqlUtils from zhenxun.utils.enum import RequestHandleType, RequestType from zhenxun.utils.exception import NotFoundError @@ -34,6 +36,8 @@ class FgRequest(Model): RequestHandleType, null=True, description="处理类型" ) """处理类型""" + message_ids = fields.CharField(max_length=255, null=True, description="消息id列表") + """消息id列表""" class Meta: # pyright: ignore [reportIncompatibleVariableOverride] table = "fg_request" @@ -123,9 +127,24 @@ class FgRequest(Model): await GroupConsole.update_or_create( group_id=req.group_id, defaults={"group_flag": 1} ) - await bot.set_group_add_request( - flag=req.flag, - sub_type="invite", - approve=handle_type == RequestHandleType.APPROVE, - ) + if req.flag == "0": + # 用户手动申请入群,创建群认证后提醒用户拉群 + await bot.send_private_msg( + user_id=req.user_id, + message=f"已同意你对{BotConfig.self_nickname}的申请群组:" + f"{req.group_id},可以直接手动拉入群组,{BotConfig.self_nickname}会自动同意。", + ) + else: + # 正常同意群组请求 + await bot.set_group_add_request( + flag=req.flag, + sub_type="invite", + approve=handle_type == 
RequestHandleType.APPROVE, + ) return req + + @classmethod + async def _run_script(cls): + return [ + SqlUtils.add_column("fg_request", "message_ids", "character varying(255)") + ] diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index a85ed1f8..123c8411 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -42,9 +42,9 @@ def convert_module_format(data: str | list[str]) -> str | list[str]: str | list[str]: 根据输入类型返回转换后的数据。 """ if isinstance(data, str): - return [item.strip(",") for item in data.split("<") if item] + return [item.strip(",") for item in data.split("<") if item.strip()] else: - return "".join(format(item) for item in data) + return "".join(add_disable_marker(item) for item in data) class GroupConsole(Model): diff --git a/zhenxun/models/mahiro_bank.py b/zhenxun/models/mahiro_bank.py new file mode 100644 index 00000000..3880daa8 --- /dev/null +++ b/zhenxun/models/mahiro_bank.py @@ -0,0 +1,123 @@ +from datetime import datetime +from typing_extensions import Self + +from tortoise import fields + +from zhenxun.services.db_context import Model + +from .mahiro_bank_log import BankHandleType, MahiroBankLog + + +class MahiroBank(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + """自增id""" + user_id = fields.CharField(255, description="用户id") + """用户id""" + amount = fields.BigIntField(default=0, description="存款") + """用户存款""" + rate = fields.FloatField(default=0.0005, description="小时利率") + """小时利率""" + loan_amount = fields.BigIntField(default=0, description="贷款") + """用户贷款""" + loan_rate = fields.FloatField(default=0.0005, description="贷款利率") + """贷款利率""" + update_time = fields.DatetimeField(auto_now=True) + """修改时间""" + create_time = fields.DatetimeField(auto_now_add=True) + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + table = "mahiro_bank" + table_description = "小真寻银行" + + @classmethod + async def deposit(cls, user_id: str, amount: 
int, rate: float) -> Self: + """存款 + + 参数: + user_id: 用户id + amount: 金币数量 + rate: 小时利率 + + 返回: + Self: MahiroBank + """ + effective_hour = int(24 - datetime.now().hour) + user, _ = await cls.get_or_create(user_id=user_id) + user.amount += amount + await user.save(update_fields=["amount", "rate"]) + await MahiroBankLog.create( + user_id=user_id, + amount=amount, + rate=rate, + effective_hour=effective_hour, + handle_type=BankHandleType.DEPOSIT, + ) + return user + + @classmethod + async def withdraw(cls, user_id: str, amount: int) -> Self: + """取款 + + 参数: + user_id: 用户id + amount: 金币数量 + + 返回: + Self: MahiroBank + """ + if amount <= 0: + raise ValueError("取款金额必须大于0") + user, _ = await cls.get_or_create(user_id=user_id) + if user.amount < amount: + raise ValueError("取款金额不能大于存款金额") + user.amount -= amount + await user.save(update_fields=["amount"]) + await MahiroBankLog.create( + user_id=user_id, amount=amount, handle_type=BankHandleType.WITHDRAW + ) + return user + + @classmethod + async def loan(cls, user_id: str, amount: int, rate: float) -> Self: + """贷款 + + 参数: + user_id: 用户id + amount: 贷款金额 + rate: 贷款利率 + + 返回: + Self: MahiroBank + """ + user, _ = await cls.get_or_create(user_id=user_id) + user.loan_amount += amount + user.loan_rate = rate + await user.save(update_fields=["loan_amount", "loan_rate"]) + await MahiroBankLog.create( + user_id=user_id, amount=amount, rate=rate, handle_type=BankHandleType.LOAN + ) + return user + + @classmethod + async def repayment(cls, user_id: str, amount: int) -> Self: + """还款 + + 参数: + user_id: 用户id + amount: 还款金额 + + 返回: + Self: MahiroBank + """ + if amount <= 0: + raise ValueError("还款金额必须大于0") + user, _ = await cls.get_or_create(user_id=user_id) + if user.loan_amount < amount: + raise ValueError("还款金额不能大于贷款金额") + user.loan_amount -= amount + await user.save(update_fields=["loan_amount"]) + await MahiroBankLog.create( + user_id=user_id, amount=amount, handle_type=BankHandleType.REPAYMENT + ) + return user diff --git 
a/zhenxun/models/mahiro_bank_log.py b/zhenxun/models/mahiro_bank_log.py new file mode 100644 index 00000000..433241d1 --- /dev/null +++ b/zhenxun/models/mahiro_bank_log.py @@ -0,0 +1,31 @@ +from tortoise import fields + +from zhenxun.services.db_context import Model +from zhenxun.utils.enum import BankHandleType + + +class MahiroBankLog(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + """自增id""" + user_id = fields.CharField(255, description="用户id") + """用户id""" + amount = fields.BigIntField(default=0, description="存款") + """金币数量""" + rate = fields.FloatField(default=0, description="小时利率") + """小时利率""" + handle_type = fields.CharEnumField( + BankHandleType, null=True, description="处理类型" + ) + """处理类型""" + is_completed = fields.BooleanField(default=False, description="是否完成") + """是否完成""" + effective_hour = fields.IntField(default=0, description="有效小时") + """有效小时""" + update_time = fields.DatetimeField(auto_now=True) + """修改时间""" + create_time = fields.DatetimeField(auto_now_add=True) + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + table = "mahiro_bank_log" + table_description = "小真寻银行日志" diff --git a/zhenxun/models/plugin_info.py b/zhenxun/models/plugin_info.py index a5bfd6a8..aeecc71b 100644 --- a/zhenxun/models/plugin_info.py +++ b/zhenxun/models/plugin_info.py @@ -62,27 +62,41 @@ class PluginInfo(Model): cache_type = CacheType.PLUGINS @classmethod - async def get_plugin(cls, load_status: bool = True, **kwargs) -> Self | None: + async def get_plugin( + cls, load_status: bool = True, filter_parent: bool = True, **kwargs + ) -> Self | None: """获取插件列表 参数: load_status: 加载状态. 
+ filter_parent: 过滤父组件 返回: Self | None: 插件 """ + if filter_parent: + return await cls.get_or_none( + load_status=load_status, plugin_type__not=PluginType.PARENT, **kwargs + ) return await cls.get_or_none(load_status=load_status, **kwargs) @classmethod - async def get_plugins(cls, load_status: bool = True, **kwargs) -> list[Self]: + async def get_plugins( + cls, load_status: bool = True, filter_parent: bool = True, **kwargs + ) -> list[Self]: """获取插件列表 参数: load_status: 加载状态. + filter_parent: 过滤父组件 返回: list[Self]: 插件列表 """ + if filter_parent: + return await cls.filter( + load_status=load_status, plugin_type__not=PluginType.PARENT, **kwargs + ).all() return await cls.filter(load_status=load_status, **kwargs).all() @classmethod diff --git a/zhenxun/models/schedule_info.py b/zhenxun/models/schedule_info.py new file mode 100644 index 00000000..c7583078 --- /dev/null +++ b/zhenxun/models/schedule_info.py @@ -0,0 +1,38 @@ +from tortoise import fields + +from zhenxun.services.db_context import Model + + +class ScheduleInfo(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + """自增id""" + bot_id = fields.CharField( + 255, null=True, default=None, description="任务关联的Bot ID" + ) + """任务关联的Bot ID""" + plugin_name = fields.CharField(255, description="插件模块名") + """插件模块名""" + group_id = fields.CharField( + 255, + null=True, + description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务", + ) + """群组ID, 为空表示全局任务""" + trigger_type = fields.CharField( + max_length=20, default="cron", description="触发器类型 (cron, interval, date)" + ) + """触发器类型 (cron, interval, date)""" + trigger_config = fields.JSONField(description="触发器具体配置") + """触发器具体配置""" + job_kwargs = fields.JSONField( + default=dict, description="传递给任务函数的额外关键字参数" + ) + """传递给任务函数的额外关键字参数""" + is_enabled = fields.BooleanField(default=True, description="是否启用") + """是否启用""" + create_time = fields.DatetimeField(auto_now_add=True) + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + 
table = "schedule_info" + table_description = "通用定时任务表" diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 4bc350c5..c8e37bc4 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -1,3 +1,4 @@ + from asyncio import Semaphore from collections.abc import Iterable from typing import Any, ClassVar @@ -12,6 +13,8 @@ from tortoise.models import Model as TortoiseModel from zhenxun.configs.config import BotConfig from zhenxun.utils.enum import DbLockType +from zhenxun.utils.exception import HookPriorityException +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from .cache import CacheRoot from .log import logger @@ -29,7 +32,6 @@ def _(): global CACHE_FLAG CACHE_FLAG = True - class Model(TortoiseModel): """ 自动添加模块 @@ -38,7 +40,8 @@ class Model(TortoiseModel): sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {} def __init_subclass__(cls, **kwargs): - MODELS.append(cls.__module__) + if cls.__module__ not in MODELS: + MODELS.append(cls.__module__) if func := getattr(cls, "_run_script", None): SCRIPT_METHOD.append((cls.__module__, func)) @@ -165,7 +168,7 @@ class Model(TortoiseModel): await CacheRoot.reload(cache_type) -class DbUrlMissing(Exception): +class DbUrlIsNode(HookPriorityException): """ 数据库链接地址为空 """ @@ -181,9 +184,19 @@ class DbConnectError(Exception): pass +@PriorityLifecycle.on_startup(priority=1) async def init(): if not BotConfig.db_url: - raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...") + # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") + error = f""" +********************************************************************** +🌟 **************************** 配置为空 ************************* 🌟 +🚀 请打开 WebUi 进行基础配置 🚀 +🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐 +*********************************************************************** +*********************************************************************** + """ + raise DbUrlIsNode("\n" + 
error.strip()) try: await Tortoise.init( db_url=BotConfig.db_url, diff --git a/zhenxun/services/llm/README.md b/zhenxun/services/llm/README.md new file mode 100644 index 00000000..263be1e6 --- /dev/null +++ b/zhenxun/services/llm/README.md @@ -0,0 +1,731 @@ +# Zhenxun LLM 服务模块 + +## 📑 目录 + +- [📖 概述](#-概述) +- [🌟 主要特性](#-主要特性) +- [🚀 快速开始](#-快速开始) +- [📚 API 参考](#-api-参考) +- [⚙️ 配置](#️-配置) +- [🔧 高级功能](#-高级功能) +- [🏗️ 架构设计](#️-架构设计) +- [🔌 支持的提供商](#-支持的提供商) +- [🎯 使用场景](#-使用场景) +- [📊 性能优化](#-性能优化) +- [🛠️ 故障排除](#️-故障排除) +- [❓ 常见问题](#-常见问题) +- [📝 示例项目](#-示例项目) +- [🤝 贡献](#-贡献) +- [📄 许可证](#-许可证) + +## 📖 概述 + +Zhenxun LLM 服务模块是一个现代化的AI服务框架,提供统一的接口来访问多个大语言模型提供商。该模块采用模块化设计,支持异步操作、智能重试、Key轮询和负载均衡等高级功能。 + +### 🌟 主要特性 + +- **多提供商支持**: OpenAI、Gemini、智谱AI、DeepSeek等 +- **统一接口**: 简洁一致的API设计 +- **智能Key轮询**: 自动负载均衡和故障转移 +- **异步高性能**: 基于asyncio的并发处理 +- **模型缓存**: 智能缓存机制提升性能 +- **工具调用**: 支持Function Calling +- **嵌入向量**: 文本向量化支持 +- **错误处理**: 完善的异常处理和重试机制 +- **多模态支持**: 文本、图像、音频、视频处理 +- **代码执行**: Gemini代码执行功能 +- **搜索增强**: Google搜索集成 + +## 🚀 快速开始 + +### 基本使用 + +```python +from zhenxun.services.llm import chat, code, search, analyze + +# 简单聊天 +response = await chat("你好,请介绍一下自己") +print(response) + +# 代码执行 +result = await code("计算斐波那契数列的前10项") +print(result["text"]) +print(result["code_executions"]) + +# 搜索功能 +search_result = await search("Python异步编程最佳实践") +print(search_result["text"]) + +# 多模态分析 +from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text +message = UniMessage([ + Text("分析这张图片"), + Image(path="image.jpg") +]) +analysis = await analyze(message, model="Gemini/gemini-2.0-flash") +print(analysis) +``` + +### 使用AI类 + +```python +from zhenxun.services.llm import AI, AIConfig, CommonOverrides + +# 创建AI实例 +ai = AI(AIConfig(model="OpenAI/gpt-4")) + +# 聊天对话 +response = await ai.chat("解释量子计算的基本原理") + +# 多模态分析 +from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text + +multimodal_msg = UniMessage([ + Text("这张图片显示了什么?"), + Image(path="image.jpg") +]) +result = await 
ai.analyze(multimodal_msg) + +# 便捷的多模态函数 +result = await analyze_with_images( + "分析这张图片", + images="image.jpg", + model="Gemini/gemini-2.0-flash" +) +``` + +## 📚 API 参考 + +### 快速函数 + +#### `chat(message, *, model=None, **kwargs) -> str` +简单聊天对话 + +**参数:** +- `message`: 消息内容(字符串、LLMMessage或内容部分列表) +- `model`: 模型名称(可选) +- `**kwargs`: 额外配置参数 + +#### `code(prompt, *, model=None, timeout=None, **kwargs) -> dict` +代码执行功能 + +**返回:** +```python +{ + "text": "执行结果说明", + "code_executions": [{"code": "...", "output": "..."}], + "success": True +} +``` + +#### `search(query, *, model=None, instruction="", **kwargs) -> dict` +搜索增强生成 + +**返回:** +```python +{ + "text": "搜索结果和分析", + "grounding_metadata": {...}, + "success": True +} +``` + +#### `analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse` +高级分析功能,支持多模态输入和工具调用 + +#### `analyze_with_images(text, images, *, instruction="", model=None, **kwargs) -> str` +图片分析便捷函数 + +#### `analyze_multimodal(text=None, images=None, videos=None, audios=None, *, instruction="", model=None, **kwargs) -> str` +多模态分析便捷函数 + +#### `embed(texts, *, model=None, task_type="RETRIEVAL_DOCUMENT", **kwargs) -> list[list[float]]` +文本嵌入向量 + +### AI类方法 + +#### `AI.chat(message, *, model=None, **kwargs) -> str` +聊天对话方法,支持简单多模态输入 + +#### `AI.analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse` +高级分析方法,接收UniMessage进行多模态分析和工具调用 + +### 模型管理 + +```python +from zhenxun.services.llm import ( + get_model_instance, + list_available_models, + set_global_default_model_name, + clear_model_cache +) + +# 获取模型实例 +model = await get_model_instance("OpenAI/gpt-4o") + +# 列出可用模型 +models = list_available_models() + +# 设置默认模型 +set_global_default_model_name("Gemini/gemini-2.0-flash") + +# 清理缓存 +clear_model_cache() +``` + +## ⚙️ 配置 + +### 预设配置 + +```python +from zhenxun.services.llm import CommonOverrides + +# 创意模式 +creative_config = CommonOverrides.creative() + +# 精确模式 
+precise_config = CommonOverrides.precise() + +# Gemini特殊功能 +json_config = CommonOverrides.gemini_json() +thinking_config = CommonOverrides.gemini_thinking() +code_exec_config = CommonOverrides.gemini_code_execution() +grounding_config = CommonOverrides.gemini_grounding() +``` + +### 自定义配置 + +```python +from zhenxun.services.llm import LLMGenerationConfig + +config = LLMGenerationConfig( + temperature=0.7, + max_tokens=2048, + top_p=0.9, + frequency_penalty=0.1, + presence_penalty=0.1, + stop=["END", "STOP"], + response_mime_type="application/json", + enable_code_execution=True, + enable_grounding=True +) + +response = await chat("你的问题", override_config=config) +``` + +## 🔧 高级功能 + +### 工具调用 (Function Calling) + +```python +from zhenxun.services.llm import LLMTool, get_model_instance + +# 定义工具 +tools = [ + LLMTool( + name="get_weather", + description="获取天气信息", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "城市名称"} + }, + "required": ["city"] + } + ) +] + +# 工具执行器 +async def tool_executor(tool_name: str, args: dict) -> str: + if tool_name == "get_weather": + return f"{args['city']}今天晴天,25°C" + return "未知工具" + +# 使用工具 +model = await get_model_instance("OpenAI/gpt-4") +response = await model.generate_response( + messages=[{"role": "user", "content": "北京天气如何?"}], + tools=tools, + tool_executor=tool_executor +) +``` + +### 多模态处理 + +```python +from zhenxun.services.llm import create_multimodal_message, analyze_multimodal, analyze_with_images + +# 方法1:使用便捷函数 +result = await analyze_multimodal( + text="分析这些媒体文件", + images="image.jpg", + audios="audio.mp3", + model="Gemini/gemini-2.0-flash" +) + +# 方法2:使用create_multimodal_message +message = create_multimodal_message( + text="分析这张图片和音频", + images="image.jpg", + audios="audio.mp3" +) +result = await analyze(message) + +# 方法3:图片分析专用函数 +result = await analyze_with_images( + "这张图片显示了什么?", + images=["image1.jpg", "image2.jpg"] +) +``` + +## 🛠️ 故障排除 + +### 常见错误 + +1. 
**配置错误**: 检查API密钥和模型配置 +2. **网络问题**: 检查代理设置和网络连接 +3. **模型不可用**: 使用 `list_available_models()` 检查可用模型 +4. **超时错误**: 调整timeout参数或使用更快的模型 + +### 调试技巧 + +```python +from zhenxun.services.llm import get_cache_stats +from zhenxun.services.log import logger + +# 查看缓存状态 +stats = get_cache_stats() +print(f"缓存命中率: {stats['hit_rate']}") + +# 启用详细日志 +logger.setLevel("DEBUG") +``` + +## ❓ 常见问题 + + +### Q: 如何处理多模态输入? + +**A:** 有多种方式处理多模态输入: +```python +# 方法1:使用便捷函数 +result = await analyze_with_images("分析这张图片", images="image.jpg") + +# 方法2:使用analyze函数 +from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text +message = UniMessage([Text("分析这张图片"), Image(path="image.jpg")]) +result = await analyze(message) + +# 方法3:使用create_multimodal_message +from zhenxun.services.llm import create_multimodal_message +message = create_multimodal_message(text="分析这张图片", images="image.jpg") +result = await analyze(message) +``` + +### Q: 如何自定义工具调用? + +**A:** 使用analyze函数的tools参数: +```python +# 定义工具 +tools = [{ + "name": "calculator", + "description": "计算数学表达式", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式"} + }, + "required": ["expression"] + } +}] + +# 使用工具 +from nonebot_plugin_alconna.uniseg import UniMessage, Text +message = UniMessage([Text("计算 2+3*4")]) +response = await analyze(message, tools=tools, tool_config={"mode": "auto"}) + +# 如果返回LLMResponse,说明有工具调用 +if hasattr(response, 'tool_calls'): + for tool_call in response.tool_calls: + print(f"调用工具: {tool_call.function.name}") + print(f"参数: {tool_call.function.arguments}") +``` + + +### Q: 如何确保输出格式? + +**A:** 使用结构化输出: +```python +# JSON格式输出 +config = CommonOverrides.gemini_json() + +# 自定义Schema +schema = { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "confidence": {"type": "number"} + } +} +config = CommonOverrides.gemini_structured(schema) +``` + +## 📝 示例项目 + +### 完整示例 + +#### 1. 
智能客服机器人 + +```python +from zhenxun.services.llm import AI, CommonOverrides +from typing import Dict, List + +class CustomerService: + def __init__(self): + self.ai = AI() + self.sessions: Dict[str, List[dict]] = {} + + async def handle_query(self, user_id: str, query: str) -> str: + # 获取或创建会话历史 + if user_id not in self.sessions: + self.sessions[user_id] = [] + + history = self.sessions[user_id] + + # 添加系统提示 + if not history: + history.append({ + "role": "system", + "content": "你是一个专业的客服助手,请友好、准确地回答用户问题。" + }) + + # 添加用户问题 + history.append({"role": "user", "content": query}) + + # 生成回复 + response = await self.ai.chat( + query, + history=history[-20:], # 保留最近20轮对话 + override_config=CommonOverrides.balanced() + ) + + # 保存回复到历史 + history.append({"role": "assistant", "content": response}) + + return response +``` + +#### 2. 文档智能问答 + +```python +from zhenxun.services.llm import embed, analyze +import numpy as np +from typing import List, Tuple + +class DocumentQA: + def __init__(self): + self.documents: List[str] = [] + self.embeddings: List[List[float]] = [] + + async def add_document(self, text: str): + """添加文档到知识库""" + self.documents.append(text) + + # 生成嵌入向量 + embedding = await embed([text]) + self.embeddings.extend(embedding) + + async def query(self, question: str, top_k: int = 3) -> str: + """查询文档并生成答案""" + if not self.documents: + return "知识库为空,请先添加文档。" + + # 生成问题的嵌入向量 + question_embedding = await embed([question]) + + # 计算相似度并找到最相关的文档 + similarities = [] + for doc_embedding in self.embeddings: + similarity = np.dot(question_embedding[0], doc_embedding) + similarities.append(similarity) + + # 获取最相关的文档 + top_indices = np.argsort(similarities)[-top_k:][::-1] + relevant_docs = [self.documents[i] for i in top_indices] + + # 构建上下文 + context = "\n\n".join(relevant_docs) + prompt = f""" +基于以下文档内容回答问题: + +文档内容: +{context} + +问题:{question} + +请基于文档内容给出准确的答案,如果文档中没有相关信息,请说明。 +""" + + result = await analyze(prompt) + return result["text"] +``` + +#### 3. 
代码审查助手 + +```python +from zhenxun.services.llm import code, analyze +import os + +class CodeReviewer: + async def review_file(self, file_path: str) -> dict: + """审查代码文件""" + if not os.path.exists(file_path): + return {"error": "文件不存在"} + + with open(file_path, 'r', encoding='utf-8') as f: + code_content = f.read() + + prompt = f""" +请审查以下代码,提供详细的反馈: + +文件:{file_path} +代码: +``` +{code_content} +``` + +请从以下方面进行审查: +1. 代码质量和可读性 +2. 潜在的bug和安全问题 +3. 性能优化建议 +4. 最佳实践建议 +5. 代码风格问题 + +请以JSON格式返回结果。 +""" + + result = await analyze( + prompt, + model="DeepSeek/deepseek-coder", + override_config=CommonOverrides.gemini_json() + ) + + return { + "file": file_path, + "review": result["text"], + "success": True + } + + async def suggest_improvements(self, code: str, language: str = "python") -> str: + """建议代码改进""" + prompt = f""" +请改进以下{language}代码,使其更加高效、可读和符合最佳实践: + +原代码: +```{language} +{code} +``` + +请提供改进后的代码和说明。 +""" + + result = await code(prompt, model="DeepSeek/deepseek-coder") + return result["text"] +``` + + +## 🏗️ 架构设计 + +### 模块结构 + +``` +zhenxun/services/llm/ +├── __init__.py # 包入口,导入和暴露公共API +├── api.py # 高级API接口(AI类、便捷函数) +├── core.py # 核心基础设施(HTTP客户端、重试逻辑、KeyStore) +├── service.py # LLM模型实现类 +├── utils.py # 工具和转换函数 +├── manager.py # 模型管理和缓存 +├── adapters/ # 适配器模块 +│ ├── __init__.py # 适配器包入口 +│ ├── base.py # 基础适配器 +│ ├── factory.py # 适配器工厂 +│ ├── openai.py # OpenAI适配器 +│ ├── gemini.py # Gemini适配器 +│ └── zhipu.py # 智谱AI适配器 +├── config/ # 配置模块 +│ ├── __init__.py # 配置包入口 +│ ├── generation.py # 生成配置 +│ ├── presets.py # 预设配置 +│ └── providers.py # 提供商配置 +└── types/ # 类型定义 + ├── __init__.py # 类型包入口 + ├── content.py # 内容类型 + ├── enums.py # 枚举定义 + ├── exceptions.py # 异常定义 + └── models.py # 数据模型 +``` + +### 模块职责 + +- **`__init__.py`**: 纯粹的包入口,只负责导入和暴露公共API +- **`api.py`**: 高级API接口,包含AI类和所有便捷函数 +- **`core.py`**: 核心基础设施,包含HTTP客户端管理、重试逻辑和KeyStore +- **`service.py`**: LLM模型实现类,专注于模型逻辑 +- **`utils.py`**: 工具和转换函数,如多模态消息处理 +- **`manager.py`**: 模型管理和缓存机制 +- **`adapters/`**: 
各大提供商的适配器模块,负责与不同API的交互 + - `base.py`: 定义适配器的基础接口 + - `factory.py`: 适配器工厂,用于动态加载和实例化适配器 + - `openai.py`: OpenAI API适配器 + - `gemini.py`: Google Gemini API适配器 + - `zhipu.py`: 智谱AI API适配器 +- **`config/`**: 配置管理模块 + - `generation.py`: 生成配置和预设 + - `presets.py`: 预设配置 + - `providers.py`: 提供商配置 +- **`types/`**: 类型定义模块 + - `content.py`: 内容类型定义 + - `enums.py`: 枚举定义 + - `exceptions.py`: 异常定义 + - `models.py`: 数据模型定义 + +## 🔌 支持的提供商 + +### OpenAI 兼容 + +- **OpenAI**: GPT-4o, GPT-3.5-turbo等 +- **DeepSeek**: deepseek-chat, deepseek-reasoner等 +- **其他OpenAI兼容API**: 支持自定义端点 + +```python +# OpenAI +await chat("Hello", model="OpenAI/gpt-4o") + +# DeepSeek +await chat("写代码", model="DeepSeek/deepseek-reasoner") +``` + +### Google Gemini + +- **Gemini Pro**: gemini-2.5-flash-preview-05-20 gemini-2.0-flash等 +- **特殊功能**: 代码执行、搜索增强、思考模式 + +```python +# 基础使用 +await chat("你好", model="Gemini/gemini-2.0-flash") + +# 代码执行 +await code("计算质数", model="Gemini/gemini-2.0-flash") + +# 搜索增强 +await search("最新AI发展", model="Gemini/gemini-2.5-flash-preview-05-20") +``` + +### 智谱AI + +- **GLM系列**: glm-4, glm-4v等 +- **支持功能**: 文本生成、多模态理解 + +```python +await chat("介绍北京", model="Zhipu/glm-4") +``` + +## 🎯 使用场景 + +### 1. 聊天机器人 + +```python +from zhenxun.services.llm import AI, CommonOverrides + +class ChatBot: + def __init__(self): + self.ai = AI() + self.history = [] + + async def chat(self, user_input: str) -> str: + # 添加历史记录 + self.history.append({"role": "user", "content": user_input}) + + # 生成回复 + response = await self.ai.chat( + user_input, + history=self.history[-10:], # 保留最近10轮对话 + override_config=CommonOverrides.balanced() + ) + + self.history.append({"role": "assistant", "content": response}) + return response +``` + +### 2. 
代码助手 + +```python +async def code_assistant(task: str) -> dict: + """代码生成和执行助手""" + result = await code( + f"请帮我{task},并执行代码验证结果", + model="Gemini/gemini-2.0-flash", + timeout=60 + ) + + return { + "explanation": result["text"], + "code_blocks": result["code_executions"], + "success": result["success"] + } + +# 使用示例 +result = await code_assistant("实现快速排序算法") +``` + +### 3. 文档分析 + +```python +from zhenxun.services.llm import analyze_with_images + +async def analyze_document(image_path: str, question: str) -> str: + """分析文档图片并回答问题""" + result = await analyze_with_images( + f"请分析这个文档并回答:{question}", + images=image_path, + model="Gemini/gemini-2.0-flash" + ) + return result +``` + +### 4. 智能搜索 + +```python +async def smart_search(query: str) -> dict: + """智能搜索和总结""" + result = await search( + query, + model="Gemini/gemini-2.0-flash", + instruction="请提供准确、最新的信息,并注明信息来源" + ) + + return { + "summary": result["text"], + "sources": result.get("grounding_metadata", {}), + "confidence": result.get("confidence_score", 0.0) + } +``` + +## 🔧 配置管理 + + +### 动态配置 + +```python +from zhenxun.services.llm import set_global_default_model_name + +# 运行时更改默认模型 +set_global_default_model_name("OpenAI/gpt-4") + +# 检查可用模型 +models = list_available_models() +for model in models: + print(f"{model.provider}/{model.name} - {model.description}") +``` + diff --git a/zhenxun/services/llm/__init__.py b/zhenxun/services/llm/__init__.py new file mode 100644 index 00000000..ff09ef7a --- /dev/null +++ b/zhenxun/services/llm/__init__.py @@ -0,0 +1,96 @@ +""" +LLM 服务模块 - 公共 API 入口 + +提供统一的 AI 服务调用接口、核心类型定义和模型管理功能。 +""" + +from .api import ( + AI, + AIConfig, + TaskType, + analyze, + analyze_multimodal, + analyze_with_images, + chat, + code, + embed, + search, + search_multimodal, +) +from .config import ( + CommonOverrides, + LLMGenerationConfig, + register_llm_configs, +) + +register_llm_configs() +from .api import ModelName +from .manager import ( + clear_model_cache, + get_cache_stats, + 
get_global_default_model_name, + get_model_instance, + list_available_models, + list_embedding_models, + list_model_identifiers, + set_global_default_model_name, +) +from .types import ( + EmbeddingTaskType, + LLMContentPart, + LLMErrorCode, + LLMException, + LLMMessage, + LLMResponse, + LLMTool, + ModelDetail, + ModelInfo, + ModelProvider, + ResponseFormat, + ToolCategory, + ToolMetadata, + UsageInfo, +) +from .utils import create_multimodal_message, unimsg_to_llm_parts + +__all__ = [ + "AI", + "AIConfig", + "CommonOverrides", + "EmbeddingTaskType", + "LLMContentPart", + "LLMErrorCode", + "LLMException", + "LLMGenerationConfig", + "LLMMessage", + "LLMResponse", + "LLMTool", + "ModelDetail", + "ModelInfo", + "ModelName", + "ModelProvider", + "ResponseFormat", + "TaskType", + "ToolCategory", + "ToolMetadata", + "UsageInfo", + "analyze", + "analyze_multimodal", + "analyze_with_images", + "chat", + "clear_model_cache", + "code", + "create_multimodal_message", + "embed", + "get_cache_stats", + "get_global_default_model_name", + "get_model_instance", + "list_available_models", + "list_embedding_models", + "list_model_identifiers", + "register_llm_configs", + "search", + "search_multimodal", + "set_global_default_model_name", + "unimsg_to_llm_parts", +] diff --git a/zhenxun/services/llm/adapters/__init__.py b/zhenxun/services/llm/adapters/__init__.py new file mode 100644 index 00000000..93ed9d31 --- /dev/null +++ b/zhenxun/services/llm/adapters/__init__.py @@ -0,0 +1,26 @@ +""" +LLM 适配器模块 + +提供不同LLM服务商的API适配器实现,统一接口调用方式。 +""" + +from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData +from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter +from .gemini import GeminiAdapter +from .openai import OpenAIAdapter +from .zhipu import ZhipuAdapter + +LLMAdapterFactory.initialize() + +__all__ = [ + "BaseAdapter", + "GeminiAdapter", + "LLMAdapterFactory", + "OpenAIAdapter", + "OpenAICompatAdapter", + "RequestData", + 
"ResponseData", + "ZhipuAdapter", + "get_adapter_for_api_type", + "register_adapter", +] diff --git a/zhenxun/services/llm/adapters/base.py b/zhenxun/services/llm/adapters/base.py new file mode 100644 index 00000000..499f9248 --- /dev/null +++ b/zhenxun/services/llm/adapters/base.py @@ -0,0 +1,508 @@ +""" +LLM 适配器基类和通用数据结构 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel + +from zhenxun.services.log import logger + +from ..types.exceptions import LLMErrorCode, LLMException +from ..types.models import LLMToolCall + +if TYPE_CHECKING: + from ..config.generation import LLMGenerationConfig + from ..service import LLMModel + from ..types.content import LLMMessage + from ..types.enums import EmbeddingTaskType + + +class RequestData(BaseModel): + """请求数据封装""" + + url: str + headers: dict[str, str] + body: dict[str, Any] + + +class ResponseData(BaseModel): + """响应数据封装 - 支持所有高级功能""" + + text: str + usage_info: dict[str, Any] | None = None + raw_response: dict[str, Any] | None = None + tool_calls: list[LLMToolCall] | None = None + code_executions: list[Any] | None = None + grounding_metadata: Any | None = None + cache_info: Any | None = None + + code_execution_results: list[dict[str, Any]] | None = None + search_results: list[dict[str, Any]] | None = None + function_calls: list[dict[str, Any]] | None = None + safety_ratings: list[dict[str, Any]] | None = None + citations: list[dict[str, Any]] | None = None + + +class BaseAdapter(ABC): + """LLM API适配器基类""" + + @property + @abstractmethod + def api_type(self) -> str: + """API类型标识""" + pass + + @property + @abstractmethod + def supported_api_types(self) -> list[str]: + """支持的API类型列表""" + pass + + def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 + + 默认实现:将简单请求转换为高级请求格式 + 子类可以重写此方法以提供特定的优化实现 + """ + from ..types.content import LLMMessage + + 
messages: list[LLMMessage] = [] + + if history: + for msg in history: + role = msg.get("role", "user") + content = msg.get("content", "") + messages.append(LLMMessage(role=role, content=content)) + + messages.append(LLMMessage(role="user", content=prompt)) + + config = model._generation_config + + return self.prepare_advanced_request( + model=model, + api_key=api_key, + messages=messages, + config=config, + tools=None, + tool_choice=None, + ) + + @abstractmethod + def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备高级请求""" + pass + + @abstractmethod + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析API响应""" + pass + + @abstractmethod + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + task_type: "EmbeddingTaskType | str", + **kwargs: Any, + ) -> RequestData: + """准备文本嵌入请求""" + pass + + @abstractmethod + def parse_embedding_response( + self, response_json: dict[str, Any] + ) -> list[list[float]]: + """解析文本嵌入响应""" + pass + + def validate_embedding_response(self, response_json: dict[str, Any]) -> None: + """验证嵌入API响应""" + if "error" in response_json: + error_info = response_json["error"] + msg = ( + error_info.get("message", str(error_info)) + if isinstance(error_info, dict) + else str(error_info) + ) + raise LLMException( + f"嵌入API错误: {msg}", + code=LLMErrorCode.EMBEDDING_FAILED, + details=response_json, + ) + + def get_api_url(self, model: "LLMModel", endpoint: str) -> str: + """构建API URL""" + if not model.api_base: + raise LLMException( + f"模型 {model.model_name} 的 api_base 未设置", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + return f"{model.api_base.rstrip('/')}{endpoint}" + + def get_base_headers(self, 
api_key: str) -> dict[str, str]: + """获取基础请求头""" + from zhenxun.utils.user_agent import get_user_agent + + headers = get_user_agent() + headers.update( + { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + ) + return headers + + def convert_messages_to_openai_format( + self, messages: list["LLMMessage"] + ) -> list[dict[str, Any]]: + """将LLMMessage转换为OpenAI格式 - 通用方法""" + openai_messages: list[dict[str, Any]] = [] + for msg in messages: + openai_msg: dict[str, Any] = {"role": msg.role} + + if msg.role == "tool": + openai_msg["tool_call_id"] = msg.tool_call_id + openai_msg["name"] = msg.name + openai_msg["content"] = msg.content + else: + if isinstance(msg.content, str): + openai_msg["content"] = msg.content + else: + content_parts = [] + for part in msg.content: + if part.type == "text": + content_parts.append({"type": "text", "text": part.text}) + elif part.type == "image": + content_parts.append( + { + "type": "image_url", + "image_url": {"url": part.image_source}, + } + ) + openai_msg["content"] = content_parts + + if msg.role == "assistant" and msg.tool_calls: + assistant_tool_calls = [] + for call in msg.tool_calls: + assistant_tool_calls.append( + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": call.function.arguments, + }, + } + ) + openai_msg["tool_calls"] = assistant_tool_calls + + if msg.name and msg.role != "tool": + openai_msg["name"] = msg.name + + openai_messages.append(openai_msg) + return openai_messages + + def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData: + """解析OpenAI格式的响应 - 通用方法""" + self.validate_response(response_json) + + try: + choices = response_json.get("choices", []) + if not choices: + logger.debug("OpenAI响应中没有choices,可能为空回复或流结束。") + return ResponseData(text="", raw_response=response_json) + + choice = choices[0] + message = choice.get("message", {}) + content = message.get("content", "") + + parsed_tool_calls: 
list[LLMToolCall] | None = None + if message_tool_calls := message.get("tool_calls"): + from ..types.models import LLMToolFunction + + parsed_tool_calls = [] + for tc_data in message_tool_calls: + try: + if tc_data.get("type") == "function": + parsed_tool_calls.append( + LLMToolCall( + id=tc_data["id"], + function=LLMToolFunction( + name=tc_data["function"]["name"], + arguments=tc_data["function"]["arguments"], + ), + ) + ) + except KeyError as e: + logger.warning( + f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}" + ) + except Exception as e: + logger.warning( + f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}" + ) + if not parsed_tool_calls: + parsed_tool_calls = None + + final_text = content if content is not None else "" + if not final_text and parsed_tool_calls: + final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。" + + usage_info = response_json.get("usage") + + return ResponseData( + text=final_text, + tool_calls=parsed_tool_calls, + usage_info=usage_info, + raw_response=response_json, + ) + + except Exception as e: + logger.error(f"解析OpenAI格式响应失败: {e}", e=e) + raise LLMException( + f"解析API响应失败: {e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + def validate_response(self, response_json: dict[str, Any]) -> None: + """验证API响应,解析不同API的错误结构""" + if "error" in response_json: + error_info = response_json["error"] + + if isinstance(error_info, dict): + error_message = error_info.get("message", "未知错误") + error_code = error_info.get("code", "unknown") + error_type = error_info.get("type", "api_error") + + error_code_mapping = { + "invalid_api_key": LLMErrorCode.API_KEY_INVALID, + "authentication_failed": LLMErrorCode.API_KEY_INVALID, + "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED, + "quota_exceeded": LLMErrorCode.API_RATE_LIMITED, + "model_not_found": LLMErrorCode.MODEL_NOT_FOUND, + "invalid_model": LLMErrorCode.MODEL_NOT_FOUND, + "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + } 
+ + llm_error_code = error_code_mapping.get( + error_code, LLMErrorCode.API_RESPONSE_INVALID + ) + + logger.error( + f"API返回错误: {error_message} " + f"(代码: {error_code}, 类型: {error_type})" + ) + else: + error_message = str(error_info) + error_code = "unknown" + llm_error_code = LLMErrorCode.API_RESPONSE_INVALID + + logger.error(f"API返回错误: {error_message}") + + raise LLMException( + f"API请求失败: {error_message}", + code=llm_error_code, + details={"api_error": error_info, "error_code": error_code}, + ) + + if "candidates" in response_json: + candidates = response_json.get("candidates", []) + if candidates: + candidate = candidates[0] + finish_reason = candidate.get("finishReason") + if finish_reason in ["SAFETY", "RECITATION"]: + safety_ratings = candidate.get("safetyRatings", []) + logger.warning( + f"Gemini内容被安全过滤: {finish_reason}, " + f"安全评级: {safety_ratings}" + ) + raise LLMException( + f"内容被安全过滤: {finish_reason}", + code=LLMErrorCode.CONTENT_FILTERED, + details={ + "finish_reason": finish_reason, + "safety_ratings": safety_ratings, + }, + ) + + if not response_json: + logger.error("API返回空响应") + raise LLMException( + "API返回空响应", + code=LLMErrorCode.API_RESPONSE_INVALID, + details={"response": response_json}, + ) + + def _apply_generation_config( + self, + model: "LLMModel", + config: "LLMGenerationConfig | None" = None, + ) -> dict[str, Any]: + """通用的配置应用逻辑""" + if config is not None: + return config.to_api_params(model.api_type, model.model_name) + + if model._generation_config is not None: + return model._generation_config.to_api_params( + model.api_type, model.model_name + ) + + base_config = {} + if model.temperature is not None: + base_config["temperature"] = model.temperature + if model.max_tokens is not None: + if model.api_type in ["gemini", "gemini_native"]: + base_config["maxOutputTokens"] = model.max_tokens + else: + base_config["max_tokens"] = model.max_tokens + + return base_config + + def apply_config_override( + self, + model: "LLMModel", + body: 
dict[str, Any], + config: "LLMGenerationConfig | None" = None, + ) -> dict[str, Any]: + """应用配置覆盖""" + config_params = self._apply_generation_config(model, config) + body.update(config_params) + return body + + +class OpenAICompatAdapter(BaseAdapter): + """ + 处理所有 OpenAI 兼容 API 的通用适配器。 + 消除 OpenAIAdapter 和 ZhipuAdapter 之间的代码重复。 + """ + + @abstractmethod + def get_chat_endpoint(self) -> str: + """子类必须实现,返回 chat completions 的端点""" + pass + + @abstractmethod + def get_embedding_endpoint(self) -> str: + """子类必须实现,返回 embeddings 的端点""" + pass + + def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备高级请求 - OpenAI兼容格式""" + url = self.get_api_url(model, self.get_chat_endpoint()) + headers = self.get_base_headers(api_key) + openai_messages = self.convert_messages_to_openai_format(messages) + + body = { + "model": model.model_name, + "messages": openai_messages, + } + + if tools: + body["tools"] = tools + if tool_choice: + body["tool_choice"] = tool_choice + + body = self.apply_config_override(model, body, config) + return RequestData(url=url, headers=headers, body=body) + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析响应 - 直接使用基类的 OpenAI 格式解析""" + _ = model, is_advanced # 未使用的参数 + return self.parse_openai_response(response_json) + + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + task_type: "EmbeddingTaskType | str", + **kwargs: Any, + ) -> RequestData: + """准备嵌入请求 - OpenAI兼容格式""" + _ = task_type # 未使用的参数 + url = self.get_api_url(model, self.get_embedding_endpoint()) + headers = self.get_base_headers(api_key) + + body = { + "model": model.model_name, + "input": texts, + } + + # 应用额外的配置参数 + if kwargs: + 
body.update(kwargs) + + return RequestData(url=url, headers=headers, body=body) + + def parse_embedding_response( + self, response_json: dict[str, Any] + ) -> list[list[float]]: + """解析嵌入响应 - OpenAI兼容格式""" + self.validate_embedding_response(response_json) + + try: + data = response_json.get("data", []) + if not data: + raise LLMException( + "嵌入响应中没有数据", + code=LLMErrorCode.EMBEDDING_FAILED, + details=response_json, + ) + + embeddings = [] + for item in data: + if "embedding" in item: + embeddings.append(item["embedding"]) + else: + raise LLMException( + "嵌入响应格式错误:缺少embedding字段", + code=LLMErrorCode.EMBEDDING_FAILED, + details=item, + ) + + return embeddings + + except Exception as e: + logger.error(f"解析嵌入响应失败: {e}", e=e) + raise LLMException( + f"解析嵌入响应失败: {e}", + code=LLMErrorCode.EMBEDDING_FAILED, + cause=e, + ) diff --git a/zhenxun/services/llm/adapters/factory.py b/zhenxun/services/llm/adapters/factory.py new file mode 100644 index 00000000..8652fc67 --- /dev/null +++ b/zhenxun/services/llm/adapters/factory.py @@ -0,0 +1,78 @@ +""" +LLM 适配器工厂类 +""" + +from typing import ClassVar + +from ..types.exceptions import LLMErrorCode, LLMException +from .base import BaseAdapter + + +class LLMAdapterFactory: + """LLM适配器工厂类""" + + _adapters: ClassVar[dict[str, BaseAdapter]] = {} + _api_type_mapping: ClassVar[dict[str, str]] = {} + + @classmethod + def initialize(cls) -> None: + """初始化默认适配器""" + if cls._adapters: + return + + from .gemini import GeminiAdapter + from .openai import OpenAIAdapter + from .zhipu import ZhipuAdapter + + cls.register_adapter(OpenAIAdapter()) + cls.register_adapter(ZhipuAdapter()) + cls.register_adapter(GeminiAdapter()) + + @classmethod + def register_adapter(cls, adapter: BaseAdapter) -> None: + """注册适配器""" + adapter_key = adapter.api_type + cls._adapters[adapter_key] = adapter + + for api_type in adapter.supported_api_types: + cls._api_type_mapping[api_type] = adapter_key + + @classmethod + def get_adapter(cls, api_type: str) -> BaseAdapter: + 
"""获取适配器""" + cls.initialize() + + adapter_key = cls._api_type_mapping.get(api_type) + if not adapter_key: + raise LLMException( + f"不支持的API类型: {api_type}", + code=LLMErrorCode.UNKNOWN_API_TYPE, + details={ + "api_type": api_type, + "supported_types": list(cls._api_type_mapping.keys()), + }, + ) + + return cls._adapters[adapter_key] + + @classmethod + def list_supported_types(cls) -> list[str]: + """列出所有支持的API类型""" + cls.initialize() + return list(cls._api_type_mapping.keys()) + + @classmethod + def list_adapters(cls) -> dict[str, BaseAdapter]: + """列出所有注册的适配器""" + cls.initialize() + return cls._adapters.copy() + + +def get_adapter_for_api_type(api_type: str) -> BaseAdapter: + """获取指定API类型的适配器""" + return LLMAdapterFactory.get_adapter(api_type) + + +def register_adapter(adapter: BaseAdapter) -> None: + """注册新的适配器""" + LLMAdapterFactory.register_adapter(adapter) diff --git a/zhenxun/services/llm/adapters/gemini.py b/zhenxun/services/llm/adapters/gemini.py new file mode 100644 index 00000000..0ca22185 --- /dev/null +++ b/zhenxun/services/llm/adapters/gemini.py @@ -0,0 +1,596 @@ +""" +Gemini API 适配器 +""" + +from typing import TYPE_CHECKING, Any + +from zhenxun.services.log import logger + +from ..types.exceptions import LLMErrorCode, LLMException +from .base import BaseAdapter, RequestData, ResponseData + +if TYPE_CHECKING: + from ..config.generation import LLMGenerationConfig + from ..service import LLMModel + from ..types.content import LLMMessage + from ..types.enums import EmbeddingTaskType + from ..types.models import LLMToolCall + + +class GeminiAdapter(BaseAdapter): + """Gemini API 适配器""" + + @property + def api_type(self) -> str: + return "gemini" + + @property + def supported_api_types(self) -> list[str]: + return ["gemini"] + + def get_base_headers(self, api_key: str) -> dict[str, str]: + """获取基础请求头""" + from zhenxun.utils.user_agent import get_user_agent + + headers = get_user_agent() + headers.update({"Content-Type": "application/json"}) + 
headers["x-goog-api-key"] = api_key + + return headers + + def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备高级请求""" + return self._prepare_request( + model, api_key, messages, config, tools, tool_choice + ) + + def _prepare_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备 Gemini API 请求 - 支持所有高级功能""" + effective_config = config if config is not None else model._generation_config + + endpoint = self._get_gemini_endpoint(model, effective_config) + url = self.get_api_url(model, endpoint) + headers = self.get_base_headers(api_key) + + gemini_contents: list[dict[str, Any]] = [] + system_instruction_parts: list[dict[str, Any]] | None = None + + for msg in messages: + current_parts: list[dict[str, Any]] = [] + if msg.role == "system": + if isinstance(msg.content, str): + system_instruction_parts = [{"text": msg.content}] + elif isinstance(msg.content, list): + system_instruction_parts = [ + part.convert_for_api("gemini") for part in msg.content + ] + continue + + elif msg.role == "user": + if isinstance(msg.content, str): + current_parts.append({"text": msg.content}) + elif isinstance(msg.content, list): + for part_obj in msg.content: + current_parts.append(part_obj.convert_for_api("gemini")) + gemini_contents.append({"role": "user", "parts": current_parts}) + + elif msg.role == "assistant" or msg.role == "model": + if isinstance(msg.content, str) and msg.content: + current_parts.append({"text": msg.content}) + elif isinstance(msg.content, list): + for part_obj in msg.content: + current_parts.append(part_obj.convert_for_api("gemini")) + + if 
msg.tool_calls: + import json + + for call in msg.tool_calls: + current_parts.append( + { + "functionCall": { + "name": call.function.name, + "args": json.loads(call.function.arguments), + } + } + ) + if current_parts: + gemini_contents.append({"role": "model", "parts": current_parts}) + + elif msg.role == "tool": + if not msg.name: + raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。") + + import json + + try: + content_str = ( + msg.content + if isinstance(msg.content, str) + else str(msg.content) + ) + tool_result_obj = json.loads(content_str) + except json.JSONDecodeError: + content_str = ( + msg.content + if isinstance(msg.content, str) + else str(msg.content) + ) + logger.warning( + f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. " + f"包装为原始字符串。" + ) + tool_result_obj = {"raw_output": content_str} + + current_parts.append( + { + "functionResponse": { + "name": msg.name, + "response": tool_result_obj, + } + } + ) + gemini_contents.append({"role": "function", "parts": current_parts}) + + body: dict[str, Any] = {"contents": gemini_contents} + + if system_instruction_parts: + body["systemInstruction"] = {"parts": system_instruction_parts} + + all_tools_for_request = [] + if tools: + for tool_item in tools: + if isinstance(tool_item, dict): + if "name" in tool_item and "description" in tool_item: + all_tools_for_request.append( + {"functionDeclarations": [tool_item]} + ) + else: + all_tools_for_request.append(tool_item) + else: + all_tools_for_request.append(tool_item) + + if effective_config: + if getattr(effective_config, "enable_grounding", False): + has_explicit_gs_tool = any( + "googleSearch" in tool_item for tool_item in all_tools_for_request + ) + if not has_explicit_gs_tool: + all_tools_for_request.append({"googleSearch": {}}) + logger.debug("隐式启用 Google Search 工具进行信息来源关联。") + + if getattr(effective_config, "enable_code_execution", False): + has_explicit_ce_tool = any( + "codeExecution" in tool_item for tool_item in all_tools_for_request + ) + if not 
has_explicit_ce_tool: + all_tools_for_request.append({"codeExecution": {}}) + logger.debug("隐式启用代码执行工具。") + + if all_tools_for_request: + gemini_api_tools = self._convert_tools_to_gemini_format( + all_tools_for_request + ) + if gemini_api_tools: + body["tools"] = gemini_api_tools + + final_tool_choice = tool_choice + if final_tool_choice is None and effective_config: + final_tool_choice = getattr(effective_config, "tool_choice", None) + + if final_tool_choice: + if isinstance(final_tool_choice, str): + mode_upper = final_tool_choice.upper() + if mode_upper in ["AUTO", "NONE", "ANY"]: + body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}} + else: + body["toolConfig"] = self._convert_tool_choice_to_gemini( + final_tool_choice + ) + else: + body["toolConfig"] = self._convert_tool_choice_to_gemini( + final_tool_choice + ) + + final_generation_config = self._build_gemini_generation_config( + model, effective_config + ) + if final_generation_config: + body["generationConfig"] = final_generation_config + + safety_settings = self._build_safety_settings(effective_config) + if safety_settings: + body["safetySettings"] = safety_settings + + return RequestData(url=url, headers=headers, body=body) + + def apply_config_override( + self, + model: "LLMModel", + body: dict[str, Any], + config: "LLMGenerationConfig | None" = None, + ) -> dict[str, Any]: + """应用配置覆盖 - Gemini 不需要额外的配置覆盖""" + return body + + def _get_gemini_endpoint( + self, model: "LLMModel", config: "LLMGenerationConfig | None" = None + ) -> str: + """根据配置选择Gemini API端点""" + if config: + if getattr(config, "enable_code_execution", False): + return f"/v1beta/models/{model.model_name}:generateContent" + + if getattr(config, "enable_grounding", False): + return f"/v1beta/models/{model.model_name}:generateContent" + + return f"/v1beta/models/{model.model_name}:generateContent" + + def _convert_tools_to_gemini_format( + self, tools: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """转换工具格式为Gemini格式""" 
+ gemini_tools = [] + + for tool in tools: + if tool.get("type") == "function": + func = tool["function"] + gemini_tool = { + "functionDeclarations": [ + { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ] + } + gemini_tools.append(gemini_tool) + elif tool.get("type") == "code_execution": + gemini_tools.append( + {"codeExecution": {"language": tool.get("language", "python")}} + ) + elif tool.get("type") == "google_search": + gemini_tools.append({"googleSearch": {}}) + elif "googleSearch" in tool: + gemini_tools.append({"googleSearch": tool["googleSearch"]}) + elif "codeExecution" in tool: + gemini_tools.append({"codeExecution": tool["codeExecution"]}) + + return gemini_tools + + def _convert_tool_choice_to_gemini( + self, tool_choice_value: str | dict[str, Any] + ) -> dict[str, Any]: + """转换工具选择策略为Gemini格式""" + if isinstance(tool_choice_value, str): + mode_upper = tool_choice_value.upper() + if mode_upper in ["AUTO", "NONE", "ANY"]: + return {"functionCallingConfig": {"mode": mode_upper}} + else: + logger.warning( + f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。" + f"回退到 AUTO。" + ) + return {"functionCallingConfig": {"mode": "AUTO"}} + + elif isinstance(tool_choice_value, dict): + if ( + tool_choice_value.get("type") == "function" + and "function" in tool_choice_value + ): + func_name = tool_choice_value["function"].get("name") + if func_name: + return { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [func_name], + } + } + else: + logger.warning( + f"tool_choice dict 中的函数名无效: {tool_choice_value}。" + f"回退到 AUTO。" + ) + return {"functionCallingConfig": {"mode": "AUTO"}} + + elif "functionCallingConfig" in tool_choice_value: + return { + "functionCallingConfig": tool_choice_value["functionCallingConfig"] + } + + else: + logger.warning( + f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。" + ) + return {"functionCallingConfig": {"mode": "AUTO"}} + + logger.warning( + 
f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。" + ) + return {"functionCallingConfig": {"mode": "AUTO"}} + + def _build_gemini_generation_config( + self, model: "LLMModel", config: "LLMGenerationConfig | None" = None + ) -> dict[str, Any]: + """构建Gemini生成配置""" + generation_config: dict[str, Any] = {} + + effective_config = config if config is not None else model._generation_config + + if effective_config: + base_api_params = effective_config.to_api_params( + api_type="gemini", model_name=model.model_name + ) + generation_config.update(base_api_params) + + if getattr(effective_config, "response_mime_type", None): + generation_config["responseMimeType"] = ( + effective_config.response_mime_type + ) + + if getattr(effective_config, "response_schema", None): + generation_config["responseSchema"] = effective_config.response_schema + + thinking_budget = getattr(effective_config, "thinking_budget", None) + if thinking_budget is not None: + if "thinkingConfig" not in generation_config: + generation_config["thinkingConfig"] = {} + generation_config["thinkingConfig"]["thinkingBudget"] = thinking_budget + + if getattr(effective_config, "response_modalities", None): + modalities = effective_config.response_modalities + if isinstance(modalities, list): + generation_config["responseModalities"] = [ + m.upper() for m in modalities + ] + elif isinstance(modalities, str): + generation_config["responseModalities"] = [modalities.upper()] + + generation_config = { + k: v for k, v in generation_config.items() if v is not None + } + + if generation_config: + param_keys = list(generation_config.keys()) + logger.debug( + f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: " + f"{param_keys}" + ) + + return generation_config + + def _build_safety_settings( + self, config: "LLMGenerationConfig | None" = None + ) -> list[dict[str, Any]] | None: + """构建安全设置""" + if not config: + return None + + safety_settings = [] + + safety_categories = [ + "HARM_CATEGORY_HARASSMENT", + 
"HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + ] + + custom_safety_settings = getattr(config, "safety_settings", None) + if custom_safety_settings: + for category, threshold in custom_safety_settings.items(): + safety_settings.append({"category": category, "threshold": threshold}) + else: + for category in safety_categories: + safety_settings.append( + {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"} + ) + + return safety_settings if safety_settings else None + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析API响应""" + return self._parse_response(model, response_json, is_advanced) + + def _parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析 Gemini API 响应""" + _ = is_advanced + self.validate_response(response_json) + + try: + candidates = response_json.get("candidates", []) + if not candidates: + logger.debug("Gemini响应中没有candidates。") + return ResponseData(text="", raw_response=response_json) + + candidate = candidates[0] + + if candidate.get("finishReason") in [ + "RECITATION", + "OTHER", + ] and not candidate.get("content"): + logger.warning( + f"Gemini candidate finished with reason " + f"'{candidate.get('finishReason')}' and no content." 
+ ) + return ResponseData( + text="", + raw_response=response_json, + usage_info=response_json.get("usageMetadata"), + ) + + content_data = candidate.get("content", {}) + parts = content_data.get("parts", []) + + text_content = "" + parsed_tool_calls: list["LLMToolCall"] | None = None + + for part in parts: + if "text" in part: + text_content += part["text"] + elif "functionCall" in part: + if parsed_tool_calls is None: + parsed_tool_calls = [] + fc_data = part["functionCall"] + try: + import json + + from ..types.models import LLMToolCall, LLMToolFunction + + call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}" + parsed_tool_calls.append( + LLMToolCall( + id=call_id, + function=LLMToolFunction( + name=fc_data["name"], + arguments=json.dumps(fc_data["args"]), + ), + ) + ) + except KeyError as e: + logger.warning( + f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}" + ) + except Exception as e: + logger.warning( + f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}" + ) + elif "codeExecutionResult" in part: + result = part["codeExecutionResult"] + if result.get("outcome") == "OK": + output = result.get("output", "") + text_content += f"\n[代码执行结果]:\n{output}\n" + else: + text_content += ( + f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n" + ) + + usage_info = response_json.get("usageMetadata") + + grounding_metadata_obj = None + if grounding_data := candidate.get("groundingMetadata"): + try: + from ..types.models import LLMGroundingMetadata + + grounding_metadata_obj = LLMGroundingMetadata(**grounding_data) + except Exception as e: + logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}") + + return ResponseData( + text=text_content, + tool_calls=parsed_tool_calls, + usage_info=usage_info, + raw_response=response_json, + grounding_metadata=grounding_metadata_obj, + ) + + except Exception as e: + logger.error(f"解析 Gemini 响应失败: {e}", e=e) + raise LLMException( + f"解析API响应失败: {e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + def 
prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + task_type: "EmbeddingTaskType | str", + **kwargs: Any, + ) -> RequestData: + """准备文本嵌入请求""" + api_model_name = model.model_name + if not api_model_name.startswith("models/"): + api_model_name = f"models/{api_model_name}" + + url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents") + headers = self.get_base_headers(api_key) + + requests_payload = [] + for text_content in texts: + request_item: dict[str, Any] = { + "content": {"parts": [{"text": text_content}]}, + } + + from ..types.enums import EmbeddingTaskType + + if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT: + request_item["task_type"] = str(task_type).upper() + if title := kwargs.get("title"): + request_item["title"] = title + if output_dimensionality := kwargs.get("output_dimensionality"): + request_item["output_dimensionality"] = output_dimensionality + + requests_payload.append(request_item) + + body = {"requests": requests_payload} + return RequestData(url=url, headers=headers, body=body) + + def parse_embedding_response( + self, response_json: dict[str, Any] + ) -> list[list[float]]: + """解析文本嵌入响应""" + try: + embeddings_data = response_json["embeddings"] + return [item["values"] for item in embeddings_data] + except KeyError as e: + logger.error(f"解析Gemini嵌入响应时缺少键: {e}. 响应: {response_json}") + raise LLMException( + "Gemini嵌入响应格式错误", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + details={"error": str(e)}, + ) + except Exception as e: + logger.error( + f"解析Gemini嵌入响应时发生未知错误: {e}. 
响应: {response_json}" + ) + raise LLMException( + f"解析Gemini嵌入响应失败: {e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + def validate_embedding_response(self, response_json: dict[str, Any]) -> None: + """验证嵌入响应""" + super().validate_embedding_response(response_json) + if "embeddings" not in response_json or not isinstance( + response_json["embeddings"], list + ): + raise LLMException( + "Gemini嵌入响应缺少'embeddings'字段或格式不正确", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + details=response_json, + ) + for item in response_json["embeddings"]: + if "values" not in item: + raise LLMException( + "Gemini嵌入响应的条目中缺少'values'字段", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + details=response_json, + ) diff --git a/zhenxun/services/llm/adapters/openai.py b/zhenxun/services/llm/adapters/openai.py new file mode 100644 index 00000000..046f0277 --- /dev/null +++ b/zhenxun/services/llm/adapters/openai.py @@ -0,0 +1,57 @@ +""" +OpenAI API 适配器 + +支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。 +""" + +from typing import TYPE_CHECKING + +from .base import OpenAICompatAdapter, RequestData + +if TYPE_CHECKING: + from ..service import LLMModel + + +class OpenAIAdapter(OpenAICompatAdapter): + """OpenAI兼容API适配器""" + + @property + def api_type(self) -> str: + return "openai" + + @property + def supported_api_types(self) -> list[str]: + return ["openai", "deepseek", "general_openai_compat"] + + def get_chat_endpoint(self) -> str: + """返回聊天完成端点""" + return "/v1/chat/completions" + + def get_embedding_endpoint(self) -> str: + """返回嵌入端点""" + return "/v1/embeddings" + + def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 - OpenAI优化实现""" + url = self.get_api_url(model, self.get_chat_endpoint()) + headers = self.get_base_headers(api_key) + + messages = [] + if history: + messages.extend(history) + messages.append({"role": "user", "content": prompt}) + + body = { + "model": 
model.model_name, + "messages": messages, + } + + body = self.apply_config_override(model, body) + + return RequestData(url=url, headers=headers, body=body) diff --git a/zhenxun/services/llm/adapters/zhipu.py b/zhenxun/services/llm/adapters/zhipu.py new file mode 100644 index 00000000..e5eb032f --- /dev/null +++ b/zhenxun/services/llm/adapters/zhipu.py @@ -0,0 +1,57 @@ +""" +智谱 AI API 适配器 + +支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。 +""" + +from typing import TYPE_CHECKING + +from .base import OpenAICompatAdapter, RequestData + +if TYPE_CHECKING: + from ..service import LLMModel + + +class ZhipuAdapter(OpenAICompatAdapter): + """智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口""" + + @property + def api_type(self) -> str: + return "zhipu" + + @property + def supported_api_types(self) -> list[str]: + return ["zhipu"] + + def get_chat_endpoint(self) -> str: + """返回智谱AI聊天完成端点""" + return "/api/paas/v4/chat/completions" + + def get_embedding_endpoint(self) -> str: + """返回智谱AI嵌入端点""" + return "/v4/embeddings" + + def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 - 智谱AI优化实现""" + url = self.get_api_url(model, self.get_chat_endpoint()) + headers = self.get_base_headers(api_key) + + messages = [] + if history: + messages.extend(history) + messages.append({"role": "user", "content": prompt}) + + body = { + "model": model.model_name, + "messages": messages, + } + + body = self.apply_config_override(model, body) + + return RequestData(url=url, headers=headers, body=body) diff --git a/zhenxun/services/llm/api.py b/zhenxun/services/llm/api.py new file mode 100644 index 00000000..7aaed437 --- /dev/null +++ b/zhenxun/services/llm/api.py @@ -0,0 +1,530 @@ +""" +LLM 服务的高级 API 接口 +""" + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any + +from nonebot_plugin_alconna.uniseg import UniMessage + +from zhenxun.services.log import 
logger + +from .config import CommonOverrides, LLMGenerationConfig +from .config.providers import get_ai_config +from .manager import get_global_default_model_name, get_model_instance +from .types import ( + EmbeddingTaskType, + LLMContentPart, + LLMErrorCode, + LLMException, + LLMMessage, + LLMResponse, + LLMTool, + ModelName, +) +from .utils import create_multimodal_message, unimsg_to_llm_parts + + +class TaskType(Enum): + """任务类型枚举""" + + CHAT = "chat" + CODE = "code" + SEARCH = "search" + ANALYSIS = "analysis" + GENERATION = "generation" + MULTIMODAL = "multimodal" + + +@dataclass +class AIConfig: + """AI配置类 - 简化版本""" + + model: ModelName = None + default_embedding_model: ModelName = None + temperature: float | None = None + max_tokens: int | None = None + enable_cache: bool = False + enable_code: bool = False + enable_search: bool = False + timeout: int | None = None + + enable_gemini_json_mode: bool = False + enable_gemini_thinking: bool = False + enable_gemini_safe_mode: bool = False + enable_gemini_multimodal: bool = False + enable_gemini_grounding: bool = False + + def __post_init__(self): + """初始化后从配置中读取默认值""" + ai_config = get_ai_config() + if self.model is None: + self.model = ai_config.get("default_model_name") + if self.timeout is None: + self.timeout = ai_config.get("timeout", 180) + + +class AI: + """统一的AI服务类 - 平衡设计版本 + + 提供三层API: + 1. 简单方法:ai.chat(), ai.code(), ai.search() + 2. 标准方法:ai.analyze() 支持复杂参数 + 3. 高级方法:通过get_model_instance()直接访问 + """ + + def __init__( + self, config: AIConfig | None = None, history: list[LLMMessage] | None = None + ): + """ + 初始化AI服务 + + Args: + config: AI 配置. + history: 可选的初始对话历史. 
+ """ + self.config = config or AIConfig() + self.history = history or [] + + def clear_history(self): + """清空当前会话的历史记录""" + self.history = [] + logger.info("AI session history cleared.") + + async def chat( + self, + message: str | LLMMessage | list[LLMContentPart], + *, + model: ModelName = None, + **kwargs: Any, + ) -> str: + """ + 进行一次聊天对话。 + 此方法会自动使用和更新会话内的历史记录。 + """ + current_message: LLMMessage + if isinstance(message, str): + current_message = LLMMessage.user(message) + elif isinstance(message, list) and all( + isinstance(part, LLMContentPart) for part in message + ): + current_message = LLMMessage.user(message) + elif isinstance(message, LLMMessage): + current_message = message + else: + raise LLMException( + f"AI.chat 不支持的消息类型: {type(message)}. " + "请使用 str, LLMMessage, 或 list[LLMContentPart]. " + "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().", + code=LLMErrorCode.API_REQUEST_FAILED, + ) + + final_messages = [*self.history, current_message] + + response = await self._execute_generation( + final_messages, model, "聊天失败", kwargs + ) + + self.history.append(current_message) + self.history.append(LLMMessage.assistant_text_response(response.text)) + + return response.text + + async def code( + self, + prompt: str, + *, + model: ModelName = None, + timeout: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """代码执行""" + resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" + + config = CommonOverrides.gemini_code_execution() + if timeout: + config.custom_params = config.custom_params or {} + config.custom_params["code_execution_timeout"] = timeout + + messages = [LLMMessage.user(prompt)] + + response = await self._execute_generation( + messages, resolved_model, "代码执行失败", kwargs, base_config=config + ) + + return { + "text": response.text, + "code_executions": response.code_executions or [], + "success": True, + } + + async def search( + self, + query: str | UniMessage, + *, + model: ModelName = None, + instruction: str = "", + **kwargs: Any, 
+ ) -> dict[str, Any]: + """信息搜索 - 支持多模态输入""" + resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" + config = CommonOverrides.gemini_grounding() + + if isinstance(query, str): + messages = [LLMMessage.user(query)] + elif isinstance(query, UniMessage): + content_parts = await unimsg_to_llm_parts(query) + + final_messages: list[LLMMessage] = [] + if instruction: + final_messages.append(LLMMessage.system(instruction)) + + if not content_parts: + if instruction: + final_messages.append(LLMMessage.user(instruction)) + else: + raise LLMException( + "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED + ) + else: + final_messages.append(LLMMessage.user(content_parts)) + + messages = final_messages + else: + raise LLMException( + f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.", + code=LLMErrorCode.API_REQUEST_FAILED, + ) + + response = await self._execute_generation( + messages, resolved_model, "信息搜索失败", kwargs, base_config=config + ) + + result = { + "text": response.text, + "sources": [], + "queries": [], + "success": True, + } + + if response.grounding_metadata: + result["sources"] = response.grounding_metadata.grounding_attributions or [] + result["queries"] = response.grounding_metadata.web_search_queries or [] + + return result + + async def analyze( + self, + message: UniMessage, + *, + instruction: str = "", + model: ModelName = None, + tools: list[dict[str, Any]] | None = None, + tool_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> str | LLMResponse: + """ + 内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。 + 这是处理复杂互动的主要方法。 + """ + content_parts = await unimsg_to_llm_parts(message) + + final_messages: list[LLMMessage] = [] + if instruction: + final_messages.append(LLMMessage.system(instruction)) + + if not content_parts: + if instruction: + final_messages.append(LLMMessage.user(instruction)) + else: + raise LLMException( + "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED + ) + else: + 
final_messages.append(LLMMessage.user(content_parts)) + + llm_tools = None + if tools: + llm_tools = [] + for tool_dict in tools: + if isinstance(tool_dict, dict): + if "name" in tool_dict and "description" in tool_dict: + llm_tool = LLMTool( + type="function", + function={ + "name": tool_dict["name"], + "description": tool_dict["description"], + "parameters": tool_dict.get("parameters", {}), + }, + ) + llm_tools.append(llm_tool) + else: + llm_tools.append(LLMTool(**tool_dict)) + else: + llm_tools.append(tool_dict) + + tool_choice = None + if tool_config: + mode = tool_config.get("mode", "auto") + if mode == "auto": + tool_choice = "auto" + elif mode == "any": + tool_choice = "any" + elif mode == "none": + tool_choice = "none" + + response = await self._execute_generation( + final_messages, + model, + "内容分析失败", + kwargs, + llm_tools=llm_tools, + tool_choice=tool_choice, + ) + + if response.tool_calls: + return response + return response.text + + async def _execute_generation( + self, + messages: list[LLMMessage], + model_name: ModelName, + error_message: str, + config_overrides: dict[str, Any], + llm_tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + base_config: LLMGenerationConfig | None = None, + ) -> LLMResponse: + """通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑""" + try: + resolved_model_name = self._resolve_model_name( + model_name or self.config.model + ) + final_config_dict = self._merge_config( + config_overrides, base_config=base_config + ) + + async with await get_model_instance( + resolved_model_name, override_config=final_config_dict + ) as model_instance: + return await model_instance.generate_response( + messages, tools=llm_tools, tool_choice=tool_choice + ) + except LLMException: + raise + except Exception as e: + logger.error(f"{error_message}: {e}", e=e) + raise LLMException(f"{error_message}: {e}", cause=e) + + def _resolve_model_name(self, model_name: ModelName) -> str: + """解析模型名称""" + if model_name: + return model_name 
+ + default_model = get_global_default_model_name() + if default_model: + return default_model + + raise LLMException( + "未指定模型名称且未设置全局默认模型", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + + def _merge_config( + self, + user_config: dict[str, Any], + base_config: LLMGenerationConfig | None = None, + ) -> dict[str, Any]: + """合并配置""" + final_config = {} + if base_config: + final_config.update(base_config.to_dict()) + + if self.config.temperature is not None: + final_config["temperature"] = self.config.temperature + if self.config.max_tokens is not None: + final_config["max_tokens"] = self.config.max_tokens + + if self.config.enable_cache: + final_config["enable_caching"] = True + if self.config.enable_code: + final_config["enable_code_execution"] = True + if self.config.enable_search: + final_config["enable_grounding"] = True + + if self.config.enable_gemini_json_mode: + final_config["response_mime_type"] = "application/json" + if self.config.enable_gemini_thinking: + final_config["thinking_budget"] = 0.8 + if self.config.enable_gemini_safe_mode: + final_config["safety_settings"] = ( + CommonOverrides.gemini_safe().safety_settings + ) + if self.config.enable_gemini_multimodal: + final_config.update(CommonOverrides.gemini_multimodal().to_dict()) + if self.config.enable_gemini_grounding: + final_config["enable_grounding"] = True + + final_config.update(user_config) + + return final_config + + async def embed( + self, + texts: list[str] | str, + *, + model: ModelName = None, + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """生成文本嵌入向量""" + if isinstance(texts, str): + texts = [texts] + if not texts: + return [] + + try: + resolved_model_str = ( + model or self.config.default_embedding_model or self.config.model + ) + if not resolved_model_str: + raise LLMException( + "使用 embed 功能时必须指定嵌入模型名称," + "或在 AIConfig 中配置 default_embedding_model。", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + resolved_model_str = 
self._resolve_model_name(resolved_model_str) + + async with await get_model_instance( + resolved_model_str, + override_config=None, + ) as embedding_model_instance: + return await embedding_model_instance.generate_embeddings( + texts, task_type=task_type, **kwargs + ) + except LLMException: + raise + except Exception as e: + logger.error(f"文本嵌入失败: {e}", e=e) + raise LLMException( + f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e + ) + + +async def chat( + message: str | LLMMessage | list[LLMContentPart], + *, + model: ModelName = None, + **kwargs: Any, +) -> str: + """聊天对话便捷函数""" + ai = AI() + return await ai.chat(message, model=model, **kwargs) + + +async def code( + prompt: str, + *, + model: ModelName = None, + timeout: int | None = None, + **kwargs: Any, +) -> dict[str, Any]: + """代码执行便捷函数""" + ai = AI() + return await ai.code(prompt, model=model, timeout=timeout, **kwargs) + + +async def search( + query: str | UniMessage, + *, + model: ModelName = None, + instruction: str = "", + **kwargs: Any, +) -> dict[str, Any]: + """信息搜索便捷函数""" + ai = AI() + return await ai.search(query, model=model, instruction=instruction, **kwargs) + + +async def analyze( + message: UniMessage, + *, + instruction: str = "", + model: ModelName = None, + tools: list[dict[str, Any]] | None = None, + tool_config: dict[str, Any] | None = None, + **kwargs: Any, +) -> str | LLMResponse: + """内容分析便捷函数""" + ai = AI() + return await ai.analyze( + message, + instruction=instruction, + model=model, + tools=tools, + tool_config=tool_config, + **kwargs, + ) + + +async def analyze_with_images( + text: str, + images: list[str | Path | bytes] | str | Path | bytes, + *, + instruction: str = "", + model: ModelName = None, + **kwargs: Any, +) -> str | LLMResponse: + """图片分析便捷函数""" + message = create_multimodal_message(text=text, images=images) + return await analyze(message, instruction=instruction, model=model, **kwargs) + + +async def analyze_multimodal( + text: str | None = None, + images: 
list[str | Path | bytes] | str | Path | bytes | None = None, + videos: list[str | Path | bytes] | str | Path | bytes | None = None, + audios: list[str | Path | bytes] | str | Path | bytes | None = None, + *, + instruction: str = "", + model: ModelName = None, + **kwargs: Any, +) -> str | LLMResponse: + """多模态分析便捷函数""" + message = create_multimodal_message( + text=text, images=images, videos=videos, audios=audios + ) + return await analyze(message, instruction=instruction, model=model, **kwargs) + + +async def search_multimodal( + text: str | None = None, + images: list[str | Path | bytes] | str | Path | bytes | None = None, + videos: list[str | Path | bytes] | str | Path | bytes | None = None, + audios: list[str | Path | bytes] | str | Path | bytes | None = None, + *, + instruction: str = "", + model: ModelName = None, + **kwargs: Any, +) -> dict[str, Any]: + """多模态搜索便捷函数""" + message = create_multimodal_message( + text=text, images=images, videos=videos, audios=audios + ) + ai = AI() + return await ai.search(message, model=model, instruction=instruction, **kwargs) + + +async def embed( + texts: list[str] | str, + *, + model: ModelName = None, + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, +) -> list[list[float]]: + """文本嵌入便捷函数""" + ai = AI() + return await ai.embed(texts, model=model, task_type=task_type, **kwargs) diff --git a/zhenxun/services/llm/config/__init__.py b/zhenxun/services/llm/config/__init__.py new file mode 100644 index 00000000..09fd9599 --- /dev/null +++ b/zhenxun/services/llm/config/__init__.py @@ -0,0 +1,35 @@ +""" +LLM 配置模块 + +提供生成配置、预设配置和配置验证功能。 +""" + +from .generation import ( + LLMGenerationConfig, + ModelConfigOverride, + apply_api_specific_mappings, + create_generation_config_from_kwargs, + validate_override_params, +) +from .presets import CommonOverrides +from .providers import ( + LLMConfig, + get_llm_config, + register_llm_configs, + set_default_model, + validate_llm_config, +) + +__all__ 
= [ + "CommonOverrides", + "LLMConfig", + "LLMGenerationConfig", + "ModelConfigOverride", + "apply_api_specific_mappings", + "create_generation_config_from_kwargs", + "get_llm_config", + "register_llm_configs", + "set_default_model", + "validate_llm_config", + "validate_override_params", +] diff --git a/zhenxun/services/llm/config/generation.py b/zhenxun/services/llm/config/generation.py new file mode 100644 index 00000000..a143dedd --- /dev/null +++ b/zhenxun/services/llm/config/generation.py @@ -0,0 +1,260 @@ +""" +LLM 生成配置相关类和函数 +""" + +from typing import Any + +from pydantic import BaseModel, Field + +from zhenxun.services.log import logger + +from ..types.enums import ResponseFormat +from ..types.exceptions import LLMErrorCode, LLMException + + +class ModelConfigOverride(BaseModel): + """模型配置覆盖参数""" + + temperature: float | None = Field( + default=None, ge=0.0, le=2.0, description="生成温度" + ) + max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数") + top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数") + top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数") + frequency_penalty: float | None = Field( + default=None, ge=-2.0, le=2.0, description="频率惩罚" + ) + presence_penalty: float | None = Field( + default=None, ge=-2.0, le=2.0, description="存在惩罚" + ) + repetition_penalty: float | None = Field( + default=None, ge=0.0, le=2.0, description="重复惩罚" + ) + + stop: list[str] | str | None = Field(default=None, description="停止序列") + + response_format: ResponseFormat | dict[str, Any] | None = Field( + default=None, description="期望的响应格式" + ) + response_mime_type: str | None = Field( + default=None, description="响应MIME类型(Gemini专用)" + ) + response_schema: dict[str, Any] | None = Field( + default=None, description="JSON响应模式" + ) + thinking_budget: float | None = Field( + default=None, ge=0.0, le=1.0, description="思考预算" + ) + safety_settings: dict[str, str] | None = Field(default=None, description="安全设置") + 
response_modalities: list[str] | None = Field( + default=None, description="响应模态类型" + ) + + enable_code_execution: bool | None = Field( + default=None, description="是否启用代码执行" + ) + enable_grounding: bool | None = Field( + default=None, description="是否启用信息来源关联" + ) + enable_caching: bool | None = Field(default=None, description="是否启用响应缓存") + + custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数") + + def to_dict(self) -> dict[str, Any]: + """转换为字典,排除None值""" + result = {} + model_data = getattr(self, "model_dump", lambda: {})() + if not model_data: + model_data = {} + for field_name, _ in self.__class__.__dict__.get( + "model_fields", {} + ).items(): + value = getattr(self, field_name, None) + if value is not None: + model_data[field_name] = value + for key, value in model_data.items(): + if value is not None: + if key == "custom_params" and isinstance(value, dict): + result.update(value) + else: + result[key] = value + return result + + def merge_with_base_config( + self, + base_temperature: float | None = None, + base_max_tokens: int | None = None, + ) -> dict[str, Any]: + """与基础配置合并,覆盖参数优先""" + merged = {} + + if base_temperature is not None: + merged["temperature"] = base_temperature + if base_max_tokens is not None: + merged["max_tokens"] = base_max_tokens + + override_dict = self.to_dict() + merged.update(override_dict) + + return merged + + +class LLMGenerationConfig(ModelConfigOverride): + """LLM 生成配置,继承模型配置覆盖参数""" + + def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]: + """转换为API参数,支持不同API类型的参数名映射""" + _ = model_name + params = {} + + if self.temperature is not None: + params["temperature"] = self.temperature + + if self.max_tokens is not None: + if api_type in ["gemini", "gemini_native"]: + params["maxOutputTokens"] = self.max_tokens + else: + params["max_tokens"] = self.max_tokens + + if api_type in ["gemini", "gemini_native"]: + if self.top_k is not None: + params["topK"] = self.top_k + if self.top_p is 
not None: + params["topP"] = self.top_p + else: + if self.top_k is not None: + params["top_k"] = self.top_k + if self.top_p is not None: + params["top_p"] = self.top_p + + if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]: + if self.frequency_penalty is not None: + params["frequency_penalty"] = self.frequency_penalty + if self.presence_penalty is not None: + params["presence_penalty"] = self.presence_penalty + + if self.repetition_penalty is not None: + if api_type == "openai": + logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") + else: + params["repetition_penalty"] = self.repetition_penalty + + if self.response_format is not None: + if isinstance(self.response_format, dict): + if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: + params["response_format"] = self.response_format + logger.debug( + f"为 {api_type} 使用自定义 response_format: " + f"{self.response_format}" + ) + elif self.response_format == ResponseFormat.JSON: + if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: + params["response_format"] = {"type": "json_object"} + logger.debug(f"为 {api_type} 启用 JSON 对象输出模式") + elif api_type in ["gemini", "gemini_native"]: + params["responseMimeType"] = "application/json" + if self.response_schema: + params["responseSchema"] = self.response_schema + logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式") + + if api_type in ["gemini", "gemini_native"]: + if ( + self.response_format != ResponseFormat.JSON + and self.response_mime_type is not None + ): + params["responseMimeType"] = self.response_mime_type + logger.debug( + f"使用显式设置的 responseMimeType: {self.response_mime_type}" + ) + + if self.response_schema is not None and "responseSchema" not in params: + params["responseSchema"] = self.response_schema + if self.thinking_budget is not None: + params["thinkingBudget"] = self.thinking_budget + if self.safety_settings is not None: + params["safetySettings"] = self.safety_settings + if self.response_modalities 
is not None: + params["responseModalities"] = self.response_modalities + + if self.custom_params: + custom_mapped = apply_api_specific_mappings(self.custom_params, api_type) + params.update(custom_mapped) + + logger.debug(f"为{api_type}转换配置参数: {len(params)}个参数") + return params + + +def validate_override_params( + override_config: dict[str, Any] | LLMGenerationConfig | None, +) -> LLMGenerationConfig: + """验证和标准化覆盖参数""" + if override_config is None: + return LLMGenerationConfig() + + if isinstance(override_config, dict): + try: + filtered_config = { + k: v for k, v in override_config.items() if v is not None + } + return LLMGenerationConfig(**filtered_config) + except Exception as e: + logger.warning(f"覆盖配置参数验证失败: {e}") + raise LLMException( + f"无效的覆盖配置参数: {e}", + code=LLMErrorCode.CONFIGURATION_ERROR, + cause=e, + ) + + return override_config + + +def apply_api_specific_mappings( + params: dict[str, Any], api_type: str +) -> dict[str, Any]: + """应用API特定的参数映射""" + mapped_params = params.copy() + + if api_type in ["gemini", "gemini_native"]: + if "max_tokens" in mapped_params: + mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens") + if "top_k" in mapped_params: + mapped_params["topK"] = mapped_params.pop("top_k") + if "top_p" in mapped_params: + mapped_params["topP"] = mapped_params.pop("top_p") + + unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"] + for param in unsupported: + if param in mapped_params: + logger.warning(f"Gemini 原生API不支持参数 '{param}',已忽略") + mapped_params.pop(param) + + elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]: + if "repetition_penalty" in mapped_params and api_type == "openai": + logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") + mapped_params.pop("repetition_penalty") + + if "stop" in mapped_params: + stop_value = mapped_params["stop"] + if isinstance(stop_value, str): + mapped_params["stop"] = [stop_value] + + return mapped_params + + +def 
create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig: + """从关键字参数创建生成配置""" + model_fields = getattr(LLMGenerationConfig, "model_fields", {}) + known_fields = set(model_fields.keys()) + known_params = {} + custom_params = {} + + for key, value in kwargs.items(): + if key in known_fields: + known_params[key] = value + else: + custom_params[key] = value + + if custom_params: + known_params["custom_params"] = custom_params + + return LLMGenerationConfig(**known_params) diff --git a/zhenxun/services/llm/config/presets.py b/zhenxun/services/llm/config/presets.py new file mode 100644 index 00000000..7a6023d5 --- /dev/null +++ b/zhenxun/services/llm/config/presets.py @@ -0,0 +1,169 @@ +""" +LLM 预设配置 + +提供常用的配置预设,特别是针对 Gemini 的高级功能。 +""" + +from typing import Any + +from .generation import LLMGenerationConfig + + +class CommonOverrides: + """常用的配置覆盖预设""" + + @staticmethod + def creative() -> LLMGenerationConfig: + """创意模式:高温度,鼓励创新""" + return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1) + + @staticmethod + def precise() -> LLMGenerationConfig: + """精确模式:低温度,确定性输出""" + return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0) + + @staticmethod + def balanced() -> LLMGenerationConfig: + """平衡模式:中等温度""" + return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0) + + @staticmethod + def concise(max_tokens: int = 100) -> LLMGenerationConfig: + """简洁模式:限制输出长度""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=max_tokens, + stop=["\n\n", "。", "!", "?"], + ) + + @staticmethod + def detailed(max_tokens: int = 2000) -> LLMGenerationConfig: + """详细模式:鼓励详细输出""" + return LLMGenerationConfig( + temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1 + ) + + @staticmethod + def gemini_json() -> LLMGenerationConfig: + """Gemini JSON模式:强制JSON输出""" + return LLMGenerationConfig( + temperature=0.3, response_mime_type="application/json" + ) + + @staticmethod + def gemini_thinking(budget: 
float = 0.8) -> LLMGenerationConfig: + """Gemini 思考模式:使用思考预算""" + return LLMGenerationConfig(temperature=0.7, thinking_budget=budget) + + @staticmethod + def gemini_creative() -> LLMGenerationConfig: + """Gemini 创意模式:高温度创意输出""" + return LLMGenerationConfig(temperature=0.9, top_p=0.95) + + @staticmethod + def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig: + """Gemini 结构化输出:自定义JSON模式""" + return LLMGenerationConfig( + temperature=0.3, + response_mime_type="application/json", + response_schema=schema, + ) + + @staticmethod + def gemini_safe() -> LLMGenerationConfig: + """Gemini 安全模式:严格安全设置""" + return LLMGenerationConfig( + temperature=0.5, + safety_settings={ + "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE", + }, + ) + + @staticmethod + def gemini_multimodal() -> LLMGenerationConfig: + """Gemini 多模态模式:优化多模态处理""" + return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8) + + @staticmethod + def gemini_code_execution() -> LLMGenerationConfig: + """Gemini 代码执行模式:启用代码执行功能""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=4096, + enable_code_execution=True, + custom_params={"code_execution_timeout": 30}, + ) + + @staticmethod + def gemini_grounding() -> LLMGenerationConfig: + """Gemini 信息来源关联模式:启用Google搜索""" + return LLMGenerationConfig( + temperature=0.5, + max_tokens=4096, + enable_grounding=True, + custom_params={ + "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}} + }, + ) + + @staticmethod + def gemini_cached() -> LLMGenerationConfig: + """Gemini 缓存模式:启用响应缓存""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=2048, + enable_caching=True, + ) + + @staticmethod + def gemini_advanced() -> LLMGenerationConfig: + """Gemini 高级模式:启用所有高级功能""" + return LLMGenerationConfig( + temperature=0.5, + 
max_tokens=4096, + enable_code_execution=True, + enable_grounding=True, + enable_caching=True, + custom_params={ + "code_execution_timeout": 30, + "grounding_config": { + "dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"} + }, + }, + ) + + @staticmethod + def gemini_research() -> LLMGenerationConfig: + """Gemini 研究模式:思考+搜索+结构化输出""" + return LLMGenerationConfig( + temperature=0.6, + max_tokens=4096, + thinking_budget=0.8, + enable_grounding=True, + response_mime_type="application/json", + custom_params={ + "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}} + }, + ) + + @staticmethod + def gemini_analysis() -> LLMGenerationConfig: + """Gemini 分析模式:深度思考+详细输出""" + return LLMGenerationConfig( + temperature=0.4, + max_tokens=6000, + thinking_budget=0.9, + top_p=0.8, + ) + + @staticmethod + def gemini_fast_response() -> LLMGenerationConfig: + """Gemini 快速响应模式:低延迟+简洁输出""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=512, + top_p=0.8, + ) diff --git a/zhenxun/services/llm/config/providers.py b/zhenxun/services/llm/config/providers.py new file mode 100644 index 00000000..8f4dea80 --- /dev/null +++ b/zhenxun/services/llm/config/providers.py @@ -0,0 +1,328 @@ +""" +LLM 提供商配置管理 + +负责注册和管理 AI 服务提供商的配置项。 +""" + +from typing import Any + +from pydantic import BaseModel, Field + +from zhenxun.configs.config import Config +from zhenxun.services.log import logger + +from ..types.models import ModelDetail, ProviderConfig + +AI_CONFIG_GROUP = "AI" +PROVIDERS_CONFIG_KEY = "PROVIDERS" + + +class LLMConfig(BaseModel): + """LLM 服务配置类""" + + default_model_name: str | None = Field( + default=None, + description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)", + ) + proxy: str | None = Field( + default=None, + description="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", + ) + timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)") + max_retries_llm: int = Field( + default=3, description="LLM服务请求失败时的最大重试次数" + ) + retry_delay_llm: int = Field( + 
default=2, description="LLM服务请求重试的基础延迟时间(秒)" + ) + providers: list[ProviderConfig] = Field( + default_factory=list, description="配置多个 AI 服务提供商及其模型信息" + ) + + def get_provider_by_name(self, name: str) -> ProviderConfig | None: + """根据名称获取提供商配置 + + 参数: + name: 提供商名称 + + 返回: + ProviderConfig | None: 提供商配置,如果未找到则返回 None + """ + for provider in self.providers: + if provider.name == name: + return provider + return None + + def get_model_by_provider_and_name( + self, provider_name: str, model_name: str + ) -> tuple[ProviderConfig, ModelDetail] | None: + """根据提供商名称和模型名称获取配置 + + 参数: + provider_name: 提供商名称 + model_name: 模型名称 + + 返回: + tuple[ProviderConfig, ModelDetail] | None: 提供商配置和模型详情的元组, + 如果未找到则返回 None + """ + provider = self.get_provider_by_name(provider_name) + if not provider: + return None + + for model in provider.models: + if model.model_name == model_name: + return provider, model + return None + + def list_available_models(self) -> list[dict[str, Any]]: + """列出所有可用的模型 + + 返回: + list[dict[str, Any]]: 模型信息列表 + """ + models = [] + for provider in self.providers: + for model in provider.models: + models.append( + { + "provider_name": provider.name, + "model_name": model.model_name, + "full_name": f"{provider.name}/{model.model_name}", + "is_available": model.is_available, + "is_embedding_model": model.is_embedding_model, + "api_type": provider.api_type, + } + ) + return models + + def validate_model_name(self, provider_model_name: str) -> bool: + """验证模型名称格式是否正确 + + 参数: + provider_model_name: 格式为 "ProviderName/ModelName" 的字符串 + + 返回: + bool: 是否有效 + """ + if not provider_model_name or "/" not in provider_model_name: + return False + + parts = provider_model_name.split("/", 1) + if len(parts) != 2: + return False + + provider_name, model_name = parts + return ( + self.get_model_by_provider_and_name(provider_name, model_name) is not None + ) + + +def get_ai_config(): + """获取 AI 配置组""" + return Config.get(AI_CONFIG_GROUP) + + +def get_default_providers() -> 
list[dict[str, Any]]: + """获取默认的提供商配置 + + 返回: + list[dict[str, Any]]: 默认提供商配置列表 + """ + return [ + { + "name": "DeepSeek", + "api_key": "sk-******", + "api_base": "https://api.deepseek.com", + "api_type": "openai", + "models": [ + { + "model_name": "deepseek-chat", + "max_tokens": 4096, + "temperature": 0.7, + }, + { + "model_name": "deepseek-reasoner", + }, + ], + }, + { + "name": "GLM", + "api_key": "", + "api_base": "https://open.bigmodel.cn", + "api_type": "zhipu", + "models": [ + {"model_name": "glm-4-flash"}, + {"model_name": "glm-4-plus"}, + ], + }, + { + "name": "Gemini", + "api_key": [ + "AIzaSy*****************************", + "AIzaSy*****************************", + "AIzaSy*****************************", + ], + "api_base": "https://generativelanguage.googleapis.com", + "api_type": "gemini", + "models": [ + {"model_name": "gemini-2.0-flash"}, + {"model_name": "gemini-2.5-flash-preview-05-20"}, + ], + }, + ] + + +def register_llm_configs(): + """注册 LLM 服务的配置项""" + logger.info("注册 LLM 服务的配置项") + + llm_config = LLMConfig() + + Config.add_plugin_config( + AI_CONFIG_GROUP, + "default_model_name", + llm_config.default_model_name, + help="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)", + type=str, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "proxy", + llm_config.proxy, + help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", + type=str, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "timeout", + llm_config.timeout, + help="LLM服务API请求超时时间(秒)", + type=int, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "max_retries_llm", + llm_config.max_retries_llm, + help="LLM服务请求失败时的最大重试次数", + type=int, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "retry_delay_llm", + llm_config.retry_delay_llm, + help="LLM服务请求重试的基础延迟时间(秒)", + type=int, + ) + + Config.add_plugin_config( + AI_CONFIG_GROUP, + PROVIDERS_CONFIG_KEY, + get_default_providers(), + help="配置多个 AI 服务提供商及其模型信息", + default_value=[], + type=list[ProviderConfig], + ) + + +def get_llm_config() -> 
LLMConfig: + """获取 LLM 配置实例 + + 返回: + LLMConfig: LLM 配置实例 + """ + ai_config = get_ai_config() + + config_data = { + "default_model_name": ai_config.get("default_model_name"), + "proxy": ai_config.get("proxy"), + "timeout": ai_config.get("timeout", 180), + "max_retries_llm": ai_config.get("max_retries_llm", 3), + "retry_delay_llm": ai_config.get("retry_delay_llm", 2), + "providers": ai_config.get(PROVIDERS_CONFIG_KEY, []), + } + + return LLMConfig(**config_data) + + +def validate_llm_config() -> tuple[bool, list[str]]: + """验证 LLM 配置的有效性 + + 返回: + tuple[bool, list[str]]: (是否有效, 错误信息列表) + """ + errors = [] + + try: + llm_config = get_llm_config() + + if llm_config.timeout <= 0: + errors.append("timeout 必须大于 0") + + if llm_config.max_retries_llm < 0: + errors.append("max_retries_llm 不能小于 0") + + if llm_config.retry_delay_llm <= 0: + errors.append("retry_delay_llm 必须大于 0") + + if not llm_config.providers: + errors.append("至少需要配置一个 AI 服务提供商") + else: + provider_names = set() + for provider in llm_config.providers: + if provider.name in provider_names: + errors.append(f"提供商名称重复: {provider.name}") + provider_names.add(provider.name) + + if not provider.api_key: + errors.append(f"提供商 {provider.name} 缺少 API Key") + + if not provider.models: + errors.append(f"提供商 {provider.name} 没有配置任何模型") + else: + model_names = set() + for model in provider.models: + if model.model_name in model_names: + errors.append( + f"提供商 {provider.name} 中模型名称重复: " + f"{model.model_name}" + ) + model_names.add(model.model_name) + + if llm_config.default_model_name: + if not llm_config.validate_model_name(llm_config.default_model_name): + errors.append( + f"默认模型 {llm_config.default_model_name} 在配置中不存在" + ) + + except Exception as e: + errors.append(f"配置解析失败: {e!s}") + + return len(errors) == 0, errors + + +def set_default_model(provider_model_name: str | None) -> bool: + """设置默认模型 + + 参数: + provider_model_name: 模型名称,格式为 "ProviderName/ModelName",None 表示清除 + + 返回: + bool: 是否设置成功 + """ + if 
provider_model_name: + llm_config = get_llm_config() + if not llm_config.validate_model_name(provider_model_name): + logger.error(f"模型 {provider_model_name} 在配置中不存在") + return False + + Config.set_config( + AI_CONFIG_GROUP, "default_model_name", provider_model_name, auto_save=True + ) + + if provider_model_name: + logger.info(f"默认模型已设置为: {provider_model_name}") + else: + logger.info("默认模型已清除") + + return True diff --git a/zhenxun/services/llm/core.py b/zhenxun/services/llm/core.py new file mode 100644 index 00000000..ffd900cf --- /dev/null +++ b/zhenxun/services/llm/core.py @@ -0,0 +1,378 @@ +""" +LLM 核心基础设施模块 + +包含执行 LLM 请求所需的底层组件,如 HTTP 客户端、API Key 存储和智能重试逻辑。 +""" + +import asyncio +from typing import Any + +import httpx +from pydantic import BaseModel + +from zhenxun.services.log import logger +from zhenxun.utils.user_agent import get_user_agent + +from .types import ProviderConfig +from .types.exceptions import LLMErrorCode, LLMException + + +class HttpClientConfig(BaseModel): + """HTTP客户端配置""" + + timeout: int = 180 + max_connections: int = 100 + max_keepalive_connections: int = 20 + proxy: str | None = None + + +class LLMHttpClient: + """LLM服务专用HTTP客户端""" + + def __init__(self, config: HttpClientConfig | None = None): + self.config = config or HttpClientConfig() + self._client: httpx.AsyncClient | None = None + self._active_requests = 0 + self._lock = asyncio.Lock() + + async def _ensure_client_initialized(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + async with self._lock: + if self._client is None or self._client.is_closed: + logger.debug( + f"LLMHttpClient: Initializing new httpx.AsyncClient " + f"with config: {self.config}" + ) + headers = get_user_agent() + limits = httpx.Limits( + max_connections=self.config.max_connections, + max_keepalive_connections=self.config.max_keepalive_connections, + ) + timeout = httpx.Timeout(self.config.timeout) + self._client = httpx.AsyncClient( + headers=headers, + limits=limits, + 
timeout=timeout, + proxies=self.config.proxy, + follow_redirects=True, + ) + if self._client is None: + raise LLMException( + "HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR + ) + return self._client + + async def post(self, url: str, **kwargs: Any) -> httpx.Response: + client = await self._ensure_client_initialized() + async with self._lock: + self._active_requests += 1 + try: + return await client.post(url, **kwargs) + finally: + async with self._lock: + self._active_requests -= 1 + + async def close(self): + async with self._lock: + if self._client and not self._client.is_closed: + logger.debug( + f"LLMHttpClient: Closing with config: {self.config}. " + f"Active requests: {self._active_requests}" + ) + if self._active_requests > 0: + logger.warning( + f"LLMHttpClient: Closing while {self._active_requests} " + f"requests are still active." + ) + await self._client.aclose() + self._client = None + logger.debug(f"LLMHttpClient for config {self.config} definitively closed.") + + @property + def is_closed(self) -> bool: + return self._client is None or self._client.is_closed + + +class LLMHttpClientManager: + """管理 LLMHttpClient 实例的工厂和池""" + + def __init__(self): + self._clients: dict[tuple[int, str | None], LLMHttpClient] = {} + self._lock = asyncio.Lock() + + def _get_client_key( + self, provider_config: ProviderConfig + ) -> tuple[int, str | None]: + return (provider_config.timeout, provider_config.proxy) + + async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient: + key = self._get_client_key(provider_config) + async with self._lock: + client = self._clients.get(key) + if client and not client.is_closed: + logger.debug( + f"LLMHttpClientManager: Reusing existing LLMHttpClient " + f"for key: {key}" + ) + return client + + if client and client.is_closed: + logger.debug( + f"LLMHttpClientManager: Found a closed client for key {key}. " + f"Creating a new one." 
+ ) + + logger.debug( + f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}" + ) + http_client_config = HttpClientConfig( + timeout=provider_config.timeout, proxy=provider_config.proxy + ) + new_client = LLMHttpClient(config=http_client_config) + self._clients[key] = new_client + return new_client + + async def shutdown(self): + async with self._lock: + logger.info( + f"LLMHttpClientManager: Shutting down. " + f"Closing {len(self._clients)} client(s)." + ) + close_tasks = [ + client.close() + for client in self._clients.values() + if client and not client.is_closed + ] + if close_tasks: + await asyncio.gather(*close_tasks, return_exceptions=True) + self._clients.clear() + logger.info("LLMHttpClientManager: Shutdown complete.") + + +http_client_manager = LLMHttpClientManager() + + +async def create_llm_http_client( + timeout: int = 180, + proxy: str | None = None, +) -> LLMHttpClient: + """创建LLM HTTP客户端""" + config = HttpClientConfig(timeout=timeout, proxy=proxy) + return LLMHttpClient(config) + + +class RetryConfig: + """重试配置""" + + def __init__( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + exponential_backoff: bool = True, + key_rotation: bool = True, + ): + self.max_retries = max_retries + self.retry_delay = retry_delay + self.exponential_backoff = exponential_backoff + self.key_rotation = key_rotation + + +async def with_smart_retry( + func, + *args, + retry_config: RetryConfig | None = None, + key_store: "KeyStatusStore | None" = None, + provider_name: str | None = None, + **kwargs: Any, +) -> Any: + """智能重试装饰器 - 支持Key轮询和错误分类""" + config = retry_config or RetryConfig() + last_exception: Exception | None = None + failed_keys: set[str] = set() + + for attempt in range(config.max_retries + 1): + try: + if config.key_rotation and "failed_keys" in func.__code__.co_varnames: + kwargs["failed_keys"] = failed_keys + + return await func(*args, **kwargs) + + except LLMException as e: + last_exception = e + + if e.code in [ + 
LLMErrorCode.API_KEY_INVALID, + LLMErrorCode.API_QUOTA_EXCEEDED, + ]: + if hasattr(e, "details") and e.details and "api_key" in e.details: + failed_keys.add(e.details["api_key"]) + if key_store and provider_name: + await key_store.record_failure( + e.details["api_key"], e.details.get("status_code") + ) + + should_retry = _should_retry_llm_error(e, attempt, config.max_retries) + if not should_retry: + logger.error(f"不可重试的错误,停止重试: {e}") + raise + + if attempt < config.max_retries: + wait_time = config.retry_delay + if config.exponential_backoff: + wait_time *= 2**attempt + logger.warning( + f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}" + ) + await asyncio.sleep(wait_time) + else: + logger.error(f"重试{config.max_retries}次后仍然失败: {e}") + + except Exception as e: + last_exception = e + logger.error(f"非LLM异常,停止重试: {e}") + raise LLMException( + f"操作失败: {e}", + code=LLMErrorCode.GENERATION_FAILED, + cause=e, + ) + + if last_exception: + raise last_exception + else: + raise RuntimeError("重试函数未能正常执行且未捕获到异常") + + +def _should_retry_llm_error( + error: LLMException, attempt: int, max_retries: int +) -> bool: + """判断LLM错误是否应该重试""" + non_retryable_errors = { + LLMErrorCode.MODEL_NOT_FOUND, + LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + LLMErrorCode.USER_LOCATION_NOT_SUPPORTED, + LLMErrorCode.CONFIGURATION_ERROR, + } + + if error.code in non_retryable_errors: + return False + + retryable_errors = { + LLMErrorCode.API_REQUEST_FAILED, + LLMErrorCode.API_TIMEOUT, + LLMErrorCode.API_RATE_LIMITED, + LLMErrorCode.API_RESPONSE_INVALID, + LLMErrorCode.RESPONSE_PARSE_ERROR, + LLMErrorCode.GENERATION_FAILED, + LLMErrorCode.CONTENT_FILTERED, + LLMErrorCode.API_KEY_INVALID, + LLMErrorCode.API_QUOTA_EXCEEDED, + } + + if error.code in retryable_errors: + if error.code == LLMErrorCode.API_QUOTA_EXCEEDED: + return attempt < min(2, max_retries) + elif error.code == LLMErrorCode.CONTENT_FILTERED: + return attempt < min(1, max_retries) + return True + + return False + + +class KeyStatusStore: + """API 
Key 状态管理存储 - 优化版本,支持轮询和负载均衡""" + + def __init__(self): + self._key_status: dict[str, bool] = {} + self._key_usage_count: dict[str, int] = {} + self._key_last_used: dict[str, float] = {} + self._provider_key_index: dict[str, int] = {} + self._lock = asyncio.Lock() + + async def get_next_available_key( + self, + provider_name: str, + api_keys: list[str], + exclude_keys: set[str] | None = None, + ) -> str | None: + """获取下一个可用的API密钥(轮询策略)""" + if not api_keys: + return None + + exclude_keys = exclude_keys or set() + available_keys = [ + key + for key in api_keys + if key not in exclude_keys and self._key_status.get(key, True) + ] + + if not available_keys: + return api_keys[0] if api_keys else None + + async with self._lock: + current_index = self._provider_key_index.get(provider_name, 0) + + selected_key = available_keys[current_index % len(available_keys)] + + self._provider_key_index[provider_name] = (current_index + 1) % len( + available_keys + ) + + import time + + self._key_usage_count[selected_key] = ( + self._key_usage_count.get(selected_key, 0) + 1 + ) + self._key_last_used[selected_key] = time.time() + + logger.debug( + f"轮询选择API密钥: {self._get_key_id(selected_key)} " + f"(使用次数: {self._key_usage_count[selected_key]})" + ) + + return selected_key + + async def record_success(self, api_key: str): + """记录成功使用""" + async with self._lock: + self._key_status[api_key] = True + logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}") + + async def record_failure(self, api_key: str, status_code: int | None): + """记录失败使用""" + key_id = self._get_key_id(api_key) + async with self._lock: + if status_code in [401, 403]: + self._key_status[api_key] = False + logger.warning( + f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})" + ) + else: + logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})") + + async def reset_key_status(self, api_key: str): + """重置密钥状态(用于恢复机制)""" + async with self._lock: + self._key_status[api_key] = True + logger.info(f"重置API密钥状态: 
{self._get_key_id(api_key)}") + + async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]: + """获取密钥使用统计""" + stats = {} + async with self._lock: + for key in api_keys: + key_id = self._get_key_id(key) + stats[key_id] = { + "available": self._key_status.get(key, True), + "usage_count": self._key_usage_count.get(key, 0), + "last_used": self._key_last_used.get(key, 0), + } + return stats + + def _get_key_id(self, api_key: str) -> str: + """获取API密钥的标识符(用于日志)""" + if len(api_key) <= 8: + return api_key + return f"{api_key[:4]}...{api_key[-4:]}" + + +key_store = KeyStatusStore() diff --git a/zhenxun/services/llm/manager.py b/zhenxun/services/llm/manager.py new file mode 100644 index 00000000..f23dfa50 --- /dev/null +++ b/zhenxun/services/llm/manager.py @@ -0,0 +1,434 @@ +""" +LLM 模型管理器 + +负责模型实例的创建、缓存、配置管理和生命周期管理。 +""" + +import hashlib +import json +import time +from typing import Any + +from zhenxun.configs.config import Config +from zhenxun.services.log import logger + +from .config import validate_override_params +from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config +from .core import http_client_manager, key_store +from .service import LLMModel +from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig + +DEFAULT_MODEL_NAME_KEY = "default_model_name" +PROXY_KEY = "proxy" +TIMEOUT_KEY = "timeout" + +_model_cache: dict[str, tuple[LLMModel, float]] = {} +_cache_ttl = 3600 +_max_cache_size = 10 + + +def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]: + """解析 'ProviderName/ModelName' 格式的字符串""" + if not name_str or "/" not in name_str: + return None, None + parts = name_str.split("/", 1) + if len(parts) == 2 and parts[0].strip() and parts[1].strip(): + return parts[0].strip(), parts[1].strip() + return None, None + + +def _make_cache_key( + provider_model_name: str | None, override_config: dict | None +) -> str: + """生成缓存键""" + config_str = ( + json.dumps(override_config, 
sort_keys=True) if override_config else "None" + ) + key_data = f"{provider_model_name}:{config_str}" + return hashlib.md5(key_data.encode()).hexdigest() + + +def _get_cached_model(cache_key: str) -> LLMModel | None: + """从缓存获取模型""" + if cache_key in _model_cache: + model, created_time = _model_cache[cache_key] + current_time = time.time() + + if current_time - created_time > _cache_ttl: + del _model_cache[cache_key] + logger.debug(f"模型缓存已过期: {cache_key}") + return None + + if model._is_closed: + logger.debug( + f"缓存的模型 {cache_key} ({model.provider_name}/{model.model_name}) " + f"处于_is_closed=True状态,重置为False以供复用。" + ) + model._is_closed = False + + logger.debug( + f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}" + ) + return model + return None + + +def _cache_model(cache_key: str, model: LLMModel): + """缓存模型实例""" + current_time = time.time() + + if len(_model_cache) >= _max_cache_size: + oldest_key = min(_model_cache.keys(), key=lambda k: _model_cache[k][1]) + del _model_cache[oldest_key] + + _model_cache[cache_key] = (model, current_time) + + +def clear_model_cache(): + """清空模型缓存""" + global _model_cache + _model_cache.clear() + logger.info("已清空模型缓存") + + +def get_cache_stats() -> dict[str, Any]: + """获取缓存统计信息""" + return { + "cache_size": len(_model_cache), + "max_cache_size": _max_cache_size, + "cache_ttl": _cache_ttl, + "cached_models": list(_model_cache.keys()), + } + + +def get_default_api_base_for_type(api_type: str) -> str | None: + """根据API类型获取默认的API基础地址""" + default_api_bases = { + "openai": "https://api.openai.com", + "deepseek": "https://api.deepseek.com", + "zhipu": "https://open.bigmodel.cn", + "gemini": "https://generativelanguage.googleapis.com", + "general_openai_compat": None, + } + + return default_api_bases.get(api_type) + + +def get_configured_providers() -> list[ProviderConfig]: + """从配置中获取Provider列表 - 简化版本""" + ai_config = get_ai_config() + providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, []) + if not 
isinstance(providers_raw, list): + logger.error( + f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表," + f"将使用空列表。" + ) + return [] + + valid_providers = [] + for i, item in enumerate(providers_raw): + if not isinstance(item, dict): + logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。") + continue + + try: + if not item.get("name"): + logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。") + continue + + if not item.get("api_key"): + logger.warning( + f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。" + ) + continue + + if "api_type" not in item or not item["api_type"]: + provider_name = item.get("name", "").lower() + if "glm" in provider_name or "zhipu" in provider_name: + item["api_type"] = "zhipu" + elif "gemini" in provider_name or "google" in provider_name: + item["api_type"] = "gemini" + else: + item["api_type"] = "openai" + + if "api_base" not in item or not item["api_base"]: + api_type = item.get("api_type") + if api_type: + default_api_base = get_default_api_base_for_type(api_type) + if default_api_base: + item["api_base"] = default_api_base + + if "models" not in item: + item["models"] = [{"model_name": item.get("name", "default")}] + + provider_conf = ProviderConfig(**item) + valid_providers.append(provider_conf) + + except Exception as e: + logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。") + + return valid_providers + + +def find_model_config( + provider_name: str, model_name: str +) -> tuple[ProviderConfig, ModelDetail] | None: + """在配置中查找指定的 Provider 和 ModelDetail + + Args: + provider_name: 提供商名称 + model_name: 模型名称 + + Returns: + 找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None + """ + providers = get_configured_providers() + + for provider in providers: + if provider.name.lower() == provider_name.lower(): + for model_detail in provider.models: + if model_detail.model_name.lower() == model_name.lower(): + return provider, model_detail + + return None + + +def list_available_models() -> list[dict[str, Any]]: + """列出所有配置的可用模型""" + providers = 
get_configured_providers() + model_list = [] + for provider in providers: + for model_detail in provider.models: + model_info = { + "provider_name": provider.name, + "model_name": model_detail.model_name, + "full_name": f"{provider.name}/{model_detail.model_name}", + "api_type": provider.api_type or "auto-detect", + "api_base": provider.api_base, + "is_available": model_detail.is_available, + "is_embedding_model": model_detail.is_embedding_model, + "available_identifiers": _get_model_identifiers( + provider.name, model_detail + ), + } + model_list.append(model_info) + return model_list + + +def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]: + """获取模型的所有可用标识符""" + return [f"{provider_name}/{model_detail.model_name}"] + + +def list_model_identifiers() -> dict[str, list[str]]: + """列出所有模型的可用标识符 + + Returns: + 字典,键为模型的完整名称,值为该模型的所有可用标识符列表 + """ + providers = get_configured_providers() + result = {} + + for provider in providers: + for model_detail in provider.models: + full_name = f"{provider.name}/{model_detail.model_name}" + identifiers = _get_model_identifiers(provider.name, model_detail) + result[full_name] = identifiers + + return result + + +def list_embedding_models() -> list[dict[str, Any]]: + """列出所有配置的嵌入模型""" + all_models = list_available_models() + return [model for model in all_models if model.get("is_embedding_model", False)] + + +async def get_model_instance( + provider_model_name: str | None = None, + override_config: dict[str, Any] | None = None, +) -> LLMModel: + """根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)""" + cache_key = _make_cache_key(provider_model_name, override_config) + cached_model = _get_cached_model(cache_key) + if cached_model: + if override_config: + validated_override = validate_override_params(override_config) + if cached_model._generation_config != validated_override: + cached_model._generation_config = validated_override + logger.debug( + f"对缓存模型 {provider_model_name} 应用新的覆盖配置: " + 
f"{validated_override.to_dict()}" + ) + return cached_model + + resolved_model_name_str = provider_model_name + if resolved_model_name_str is None: + resolved_model_name_str = get_global_default_model_name() + if resolved_model_name_str is None: + available_models_list = list_available_models() + if not available_models_list: + raise LLMException( + "未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR + ) + resolved_model_name_str = available_models_list[0]["full_name"] + logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}") + + prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str) + if not prov_name_str or not mod_name_str: + raise LLMException( + f"无效的模型名称格式: '{resolved_model_name_str}'", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + + config_tuple_found = find_model_config(prov_name_str, mod_name_str) + if not config_tuple_found: + all_models = list_available_models() + raise LLMException( + f"未找到模型: '{resolved_model_name_str}'. " + f"可用: {[m['full_name'] for m in all_models]}", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + + provider_config_found, model_detail_found = config_tuple_found + + ai_config = get_ai_config() + global_proxy_setting = ai_config.get(PROXY_KEY) + default_timeout = ( + provider_config_found.timeout + if provider_config_found.timeout is not None + else 180 + ) + global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout) + + config_for_http_client = ProviderConfig( + name=provider_config_found.name, + api_key=provider_config_found.api_key, + models=provider_config_found.models, + timeout=global_timeout_setting, + proxy=global_proxy_setting, + api_base=provider_config_found.api_base, + api_type=provider_config_found.api_type, + openai_compat=provider_config_found.openai_compat, + temperature=provider_config_found.temperature, + max_tokens=provider_config_found.max_tokens, + ) + + shared_http_client = await http_client_manager.get_client(config_for_http_client) + + try: + model_instance = LLMModel( + 
provider_config=config_for_http_client, + model_detail=model_detail_found, + key_store=key_store, + http_client=shared_http_client, + ) + + if override_config: + validated_override_params = validate_override_params(override_config) + model_instance._generation_config = validated_override_params + logger.debug( + f"为新模型 {resolved_model_name_str} 应用配置覆盖: " + f"{validated_override_params.to_dict()}" + ) + + _cache_model(cache_key, model_instance) + logger.debug( + f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}" + ) + return model_instance + except LLMException: + raise + except Exception as e: + logger.error( + f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e + ) + raise LLMException( + f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}", + code=LLMErrorCode.MODEL_INIT_FAILED, + cause=e, + ) + + +def get_global_default_model_name() -> str | None: + """获取全局默认模型名称""" + ai_config = get_ai_config() + return ai_config.get(DEFAULT_MODEL_NAME_KEY) + + +def set_global_default_model_name(provider_model_name: str | None) -> bool: + """设置全局默认模型名称""" + if provider_model_name: + prov_name, mod_name = parse_provider_model_string(provider_model_name) + if not prov_name or not mod_name or not find_model_config(prov_name, mod_name): + logger.error( + f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。" + ) + return False + + Config.set_config( + AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True + ) + if provider_model_name: + logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}") + else: + logger.info("LLM 服务全局默认模型已清除。") + return True + + +async def get_key_usage_stats() -> dict[str, Any]: + """获取所有Provider的Key使用统计""" + providers = get_configured_providers() + stats = {} + + for provider in providers: + provider_stats = await key_store.get_key_stats( + [provider.api_key] + if isinstance(provider.api_key, str) + else provider.api_key + ) + stats[provider.name] = { + "total_keys": len( + [provider.api_key] + if isinstance(provider.api_key, 
str) + else provider.api_key + ), + "key_stats": provider_stats, + } + + return stats + + +async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool: + """重置指定Provider的Key状态""" + providers = get_configured_providers() + target_provider = None + + for provider in providers: + if provider.name.lower() == provider_name.lower(): + target_provider = provider + break + + if not target_provider: + logger.error(f"未找到Provider: {provider_name}") + return False + + provider_keys = ( + [target_provider.api_key] + if isinstance(target_provider.api_key, str) + else target_provider.api_key + ) + + if api_key: + if api_key in provider_keys: + await key_store.reset_key_status(api_key) + logger.info(f"已重置Provider '{provider_name}' 的指定Key状态") + return True + else: + logger.error(f"指定的Key不属于Provider '{provider_name}'") + return False + else: + for key in provider_keys: + await key_store.reset_key_status(key) + logger.info(f"已重置Provider '{provider_name}' 的所有Key状态") + return True diff --git a/zhenxun/services/llm/service.py b/zhenxun/services/llm/service.py new file mode 100644 index 00000000..d054ca9b --- /dev/null +++ b/zhenxun/services/llm/service.py @@ -0,0 +1,632 @@ +""" +LLM 模型实现类 + +包含 LLM 模型的抽象基类和具体实现,负责与各种 AI 提供商的 API 交互。 +""" + +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +import json +from typing import Any + +from zhenxun.services.log import logger + +from .config import LLMGenerationConfig +from .config.providers import get_ai_config +from .core import ( + KeyStatusStore, + LLMHttpClient, + RetryConfig, + http_client_manager, + with_smart_retry, +) +from .types import ( + EmbeddingTaskType, + LLMErrorCode, + LLMException, + LLMMessage, + LLMResponse, + LLMTool, + ModelDetail, + ProviderConfig, +) + + +class LLMModelBase(ABC): + """LLM模型抽象基类""" + + @abstractmethod + async def generate_text( + self, + prompt: str, + history: list[dict[str, str]] | None = None, + **kwargs: Any, + ) -> str: + """生成文本""" + pass + 
+ @abstractmethod + async def generate_response( + self, + messages: list[LLMMessage], + config: LLMGenerationConfig | None = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """生成高级响应""" + pass + + @abstractmethod + async def generate_embeddings( + self, + texts: list[str], + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """生成文本嵌入向量""" + pass + + +class LLMModel(LLMModelBase): + """LLM 模型实现类""" + + def __init__( + self, + provider_config: ProviderConfig, + model_detail: ModelDetail, + key_store: KeyStatusStore, + http_client: LLMHttpClient, + config_override: LLMGenerationConfig | None = None, + ): + self.provider_config = provider_config + self.model_detail = model_detail + self.key_store = key_store + self.http_client: LLMHttpClient = http_client + self._generation_config = config_override + + self.provider_name = provider_config.name + self.api_type = provider_config.api_type + self.api_base = provider_config.api_base + self.api_keys = ( + [provider_config.api_key] + if isinstance(provider_config.api_key, str) + else provider_config.api_key + ) + self.model_name = model_detail.model_name + self.temperature = model_detail.temperature + self.max_tokens = model_detail.max_tokens + + self._is_closed = False + + async def _get_http_client(self) -> LLMHttpClient: + """获取HTTP客户端""" + if self.http_client.is_closed: + logger.debug( + f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭," + "正在获取新的客户端" + ) + self.http_client = await http_client_manager.get_client( + self.provider_config + ) + return self.http_client + + async def _select_api_key(self, failed_keys: set[str] | None = None) -> str: + """选择可用的API密钥(使用轮询策略)""" + if not self.api_keys: + raise LLMException( + f"提供商 {self.provider_name} 没有配置API密钥", + code=LLMErrorCode.NO_AVAILABLE_KEYS, + ) + + selected_key = await 
self.key_store.get_next_available_key( + self.provider_name, self.api_keys, failed_keys + ) + + if not selected_key: + raise LLMException( + f"提供商 {self.provider_name} 的所有API密钥当前都不可用", + code=LLMErrorCode.NO_AVAILABLE_KEYS, + details={ + "total_keys": len(self.api_keys), + "failed_keys": len(failed_keys or set()), + }, + ) + + return selected_key + + async def _execute_embedding_request( + self, + adapter, + texts: list[str], + task_type: EmbeddingTaskType | str, + http_client: LLMHttpClient, + failed_keys: set[str] | None = None, + ) -> list[list[float]]: + """执行单次嵌入请求 - 供重试机制调用""" + api_key = await self._select_api_key(failed_keys) + + try: + request_data = adapter.prepare_embedding_request( + model=self, + api_key=api_key, + texts=texts, + task_type=task_type, + ) + + http_response = await http_client.post( + request_data.url, + headers=request_data.headers, + json=request_data.body, + ) + + if http_response.status_code != 200: + error_text = http_response.text + logger.error( + f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}" + ) + await self.key_store.record_failure(api_key, http_response.status_code) + + error_code = LLMErrorCode.API_REQUEST_FAILED + if http_response.status_code in [401, 403]: + error_code = LLMErrorCode.API_KEY_INVALID + elif http_response.status_code == 429: + error_code = LLMErrorCode.API_RATE_LIMITED + + raise LLMException( + f"HTTP嵌入请求失败: {http_response.status_code}", + code=error_code, + details={ + "status_code": http_response.status_code, + "response": error_text, + "api_key": api_key, + }, + ) + + try: + response_json = http_response.json() + adapter.validate_embedding_response(response_json) + embeddings = adapter.parse_embedding_response(response_json) + except Exception as e: + logger.error(f"解析嵌入响应失败: {e}", e=e) + await self.key_store.record_failure(api_key, None) + if isinstance(e, LLMException): + raise + else: + raise LLMException( + f"解析API嵌入响应失败: {e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + 
await self.key_store.record_success(api_key) + return embeddings + + except LLMException: + raise + except Exception as e: + logger.error(f"生成嵌入时发生未预期错误: {e}", e=e) + await self.key_store.record_failure(api_key, None) + raise LLMException( + f"生成嵌入失败: {e}", + code=LLMErrorCode.EMBEDDING_FAILED, + cause=e, + ) + + async def _execute_with_smart_retry( + self, + adapter, + messages: list[LLMMessage], + config: LLMGenerationConfig | None, + tools_dict: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, + http_client: LLMHttpClient, + ): + """智能重试机制 - 使用统一的重试装饰器""" + ai_config = get_ai_config() + max_retries = ai_config.get("max_retries_llm", 3) + retry_delay = ai_config.get("retry_delay_llm", 2) + retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay) + + return await with_smart_retry( + self._execute_single_request, + adapter, + messages, + config, + tools_dict, + tool_choice, + http_client, + retry_config=retry_config, + key_store=self.key_store, + provider_name=self.provider_name, + ) + + async def _execute_single_request( + self, + adapter, + messages: list[LLMMessage], + config: LLMGenerationConfig | None, + tools_dict: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, + http_client: LLMHttpClient, + failed_keys: set[str] | None = None, + ) -> LLMResponse: + """执行单次请求 - 供重试机制调用,直接返回 LLMResponse""" + api_key = await self._select_api_key(failed_keys) + + try: + request_data = adapter.prepare_advanced_request( + model=self, + api_key=api_key, + messages=messages, + config=config, + tools=tools_dict, + tool_choice=tool_choice, + ) + + http_response = await http_client.post( + request_data.url, + headers=request_data.headers, + json=request_data.body, + ) + + if http_response.status_code != 200: + error_text = http_response.text + logger.error( + f"HTTP请求失败: {http_response.status_code} - {error_text}" + ) + + await self.key_store.record_failure(api_key, http_response.status_code) + + if 
http_response.status_code in [401, 403]: + error_code = LLMErrorCode.API_KEY_INVALID + elif http_response.status_code == 429: + error_code = LLMErrorCode.API_RATE_LIMITED + elif http_response.status_code in [402, 413]: + error_code = LLMErrorCode.API_QUOTA_EXCEEDED + else: + error_code = LLMErrorCode.API_REQUEST_FAILED + + raise LLMException( + f"HTTP请求失败: {http_response.status_code}", + code=error_code, + details={ + "status_code": http_response.status_code, + "response": error_text, + "api_key": api_key, + }, + ) + + try: + response_json = http_response.json() + response_data = adapter.parse_response( + model=self, + response_json=response_json, + is_advanced=True, + ) + + from .types.models import LLMToolCall + + response_tool_calls = [] + if response_data.tool_calls: + for tc_data in response_data.tool_calls: + if isinstance(tc_data, LLMToolCall): + response_tool_calls.append(tc_data) + elif isinstance(tc_data, dict): + try: + response_tool_calls.append(LLMToolCall(**tc_data)) + except Exception as e: + logger.warning( + f"无法将工具调用数据转换为LLMToolCall: {tc_data}, " + f"error: {e}" + ) + else: + logger.warning(f"工具调用数据格式未知: {tc_data}") + + llm_response = LLMResponse( + text=response_data.text, + usage_info=response_data.usage_info, + raw_response=response_data.raw_response, + tool_calls=response_tool_calls if response_tool_calls else None, + code_executions=response_data.code_executions, + grounding_metadata=response_data.grounding_metadata, + cache_info=response_data.cache_info, + ) + + except Exception as e: + logger.error(f"解析响应失败: {e}", e=e) + await self.key_store.record_failure(api_key, None) + + if isinstance(e, LLMException): + raise + else: + raise LLMException( + f"解析API响应失败: {e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + await self.key_store.record_success(api_key) + + return llm_response + + except LLMException: + raise + except Exception as e: + logger.error(f"生成响应时发生未预期错误: {e}", e=e) + await self.key_store.record_failure(api_key, 
None) + + raise LLMException( + f"生成响应失败: {e}", + code=LLMErrorCode.GENERATION_FAILED, + cause=e, + ) + + async def close(self): + """ + 标记模型实例的当前使用周期结束。 + 共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。 + """ + if self._is_closed: + return + self._is_closed = True + logger.debug( + f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)" + ) + + async def __aenter__(self): + if self._is_closed: + logger.debug( + f"Re-entering context for closed LLMModel {self}. " + f"Resetting _is_closed to False." + ) + self._is_closed = False + self._check_not_closed() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + _ = exc_type, exc_val, exc_tb + await self.close() + + def _check_not_closed(self): + """检查实例是否已关闭""" + if self._is_closed: + raise RuntimeError(f"LLMModel实例已关闭: {self}") + + async def generate_text( + self, + prompt: str, + history: list[dict[str, str]] | None = None, + **kwargs: Any, + ) -> str: + """生成文本 - 通过 generate_response 实现""" + self._check_not_closed() + + messages: list[LLMMessage] = [] + + if history: + for msg in history: + role = msg.get("role", "user") + content_text = msg.get("content", "") + messages.append(LLMMessage(role=role, content=content_text)) + + messages.append(LLMMessage.user(prompt)) + + model_fields = getattr(LLMGenerationConfig, "model_fields", {}) + request_specific_config_dict = { + k: v for k, v in kwargs.items() if k in model_fields + } + request_specific_config = None + if request_specific_config_dict: + request_specific_config = LLMGenerationConfig( + **request_specific_config_dict + ) + + for key in request_specific_config_dict: + kwargs.pop(key, None) + + response = await self.generate_response( + messages, + config=request_specific_config, + **kwargs, + ) + return response.text + + async def generate_response( + self, + messages: list[LLMMessage], + config: LLMGenerationConfig | None = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + 
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, + max_tool_iterations: int = 5, + **kwargs: Any, + ) -> LLMResponse: + """生成高级响应 - 实现完整的工具调用循环""" + self._check_not_closed() + + from .adapters import get_adapter_for_api_type + from .config.generation import create_generation_config_from_kwargs + + adapter = get_adapter_for_api_type(self.api_type) + if not adapter: + raise LLMException( + f"未找到适用于 API 类型 '{self.api_type}' 的适配器", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + + final_request_config = self._generation_config or LLMGenerationConfig() + if kwargs: + kwargs_config = create_generation_config_from_kwargs(**kwargs) + merged_dict = final_request_config.to_dict() + merged_dict.update(kwargs_config.to_dict()) + final_request_config = LLMGenerationConfig(**merged_dict) + + if config is not None: + merged_dict = final_request_config.to_dict() + merged_dict.update(config.to_dict()) + final_request_config = LLMGenerationConfig(**merged_dict) + + tools_dict: list[dict[str, Any]] | None = None + if tools: + tools_dict = [] + for tool in tools: + if hasattr(tool, "model_dump"): + model_dump_func = getattr(tool, "model_dump") + tools_dict.append(model_dump_func(exclude_none=True)) + elif isinstance(tool, dict): + tools_dict.append(tool) + else: + try: + tools_dict.append(dict(tool)) + except (TypeError, ValueError): + logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。") + + http_client = await self._get_http_client() + current_messages = list(messages) + + for iteration in range(max_tool_iterations): + logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}") + + llm_response = await self._execute_with_smart_retry( + adapter, + current_messages, + final_request_config, + tools_dict if iteration == 0 else None, + tool_choice if iteration == 0 else None, + http_client, + ) + + response_tool_calls = llm_response.tool_calls or [] + + if not response_tool_calls or not tool_executor: + logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。") + return 
llm_response + + logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。") + + assistant_message_content = llm_response.text if llm_response.text else "" + current_messages.append( + LLMMessage.assistant_tool_calls( + content=assistant_message_content, tool_calls=response_tool_calls + ) + ) + + tool_response_messages: list[LLMMessage] = [] + for tool_call in response_tool_calls: + tool_name = tool_call.function.name + try: + tool_args_dict = json.loads(tool_call.function.arguments) + logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}") + + tool_result = await tool_executor(tool_name, tool_args_dict) + logger.debug( + f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..." + ) + + tool_response_messages.append( + LLMMessage.tool_response( + tool_call_id=tool_call.id, + function_name=tool_name, + result=tool_result, + ) + ) + except json.JSONDecodeError as e: + logger.error( + f"工具 '{tool_name}' 参数JSON解析失败: " + f"{tool_call.function.arguments}, 错误: {e}" + ) + tool_response_messages.append( + LLMMessage.tool_response( + tool_call_id=tool_call.id, + function_name=tool_name, + result={ + "error": "Argument JSON parsing failed", + "details": str(e), + }, + ) + ) + except Exception as e: + logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e) + tool_response_messages.append( + LLMMessage.tool_response( + tool_call_id=tool_call.id, + function_name=tool_name, + result={ + "error": "Tool execution failed", + "details": str(e), + }, + ) + ) + + current_messages.extend(tool_response_messages) + + logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。") + raise LLMException( + "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。", + code=LLMErrorCode.GENERATION_FAILED, + details={ + "iterations": max_tool_iterations, + "last_messages": current_messages[-2:], + }, + ) + + async def generate_embeddings( + self, + texts: list[str], + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """生成文本嵌入向量""" + self._check_not_closed() + if not 
texts: + return [] + + from .adapters import get_adapter_for_api_type + + adapter = get_adapter_for_api_type(self.api_type) + if not adapter: + raise LLMException( + f"未找到适用于 API 类型 '{self.api_type}' 的嵌入适配器", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + + http_client = await self._get_http_client() + + ai_config = get_ai_config() + default_max_retries = ai_config.get("max_retries_llm", 3) + default_retry_delay = ai_config.get("retry_delay_llm", 2) + max_retries_embed = kwargs.get( + "max_retries_embed", max(1, default_max_retries // 2) + ) + retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2) + + retry_config = RetryConfig( + max_retries=max_retries_embed, + retry_delay=retry_delay_embed, + exponential_backoff=True, + key_rotation=True, + ) + + return await with_smart_retry( + self._execute_embedding_request, + adapter, + texts, + task_type, + http_client, + retry_config=retry_config, + key_store=self.key_store, + provider_name=self.provider_name, + ) + + def __str__(self) -> str: + status = "closed" if self._is_closed else "active" + return f"LLMModel({self.provider_name}/{self.model_name}, {status})" + + def __repr__(self) -> str: + status = "closed" if self._is_closed else "active" + return ( + f"LLMModel(provider={self.provider_name}, model={self.model_name}, " + f"api_type={self.api_type}, status={status})" + ) diff --git a/zhenxun/services/llm/types/__init__.py b/zhenxun/services/llm/types/__init__.py new file mode 100644 index 00000000..ebae4185 --- /dev/null +++ b/zhenxun/services/llm/types/__init__.py @@ -0,0 +1,54 @@ +""" +LLM 类型定义模块 + +统一导出所有核心类型、协议和异常定义。 +""" + +from .content import ( + LLMContentPart, + LLMMessage, + LLMResponse, +) +from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory +from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message +from .models import ( + LLMCacheInfo, + LLMCodeExecution, + LLMGroundingAttribution, + LLMGroundingMetadata, + LLMTool, + 
LLMToolCall, + LLMToolFunction, + ModelDetail, + ModelInfo, + ModelName, + ProviderConfig, + ToolMetadata, + UsageInfo, +) + +__all__ = [ + "EmbeddingTaskType", + "LLMCacheInfo", + "LLMCodeExecution", + "LLMContentPart", + "LLMErrorCode", + "LLMException", + "LLMGroundingAttribution", + "LLMGroundingMetadata", + "LLMMessage", + "LLMResponse", + "LLMTool", + "LLMToolCall", + "LLMToolFunction", + "ModelDetail", + "ModelInfo", + "ModelName", + "ModelProvider", + "ProviderConfig", + "ResponseFormat", + "ToolCategory", + "ToolMetadata", + "UsageInfo", + "get_user_friendly_error_message", +] diff --git a/zhenxun/services/llm/types/content.py b/zhenxun/services/llm/types/content.py new file mode 100644 index 00000000..54887bc3 --- /dev/null +++ b/zhenxun/services/llm/types/content.py @@ -0,0 +1,428 @@ +""" +LLM 内容类型定义 + +包含多模态内容部分、消息和响应的数据模型。 +""" + +import base64 +import mimetypes +from pathlib import Path +from typing import Any + +import aiofiles +from pydantic import BaseModel + +from zhenxun.services.log import logger + + +class LLMContentPart(BaseModel): + """LLM 消息内容部分 - 支持多模态内容""" + + type: str + text: str | None = None + image_source: str | None = None + audio_source: str | None = None + video_source: str | None = None + document_source: str | None = None + file_uri: str | None = None + file_source: str | None = None + url: str | None = None + mime_type: str | None = None + metadata: dict[str, Any] | None = None + + def model_post_init(self, /, __context: Any) -> None: + """验证内容部分的有效性""" + _ = __context + validation_rules = { + "text": lambda: self.text, + "image": lambda: self.image_source, + "audio": lambda: self.audio_source, + "video": lambda: self.video_source, + "document": lambda: self.document_source, + "file": lambda: self.file_uri or self.file_source, + "url": lambda: self.url, + } + + if self.type in validation_rules: + if not validation_rules[self.type](): + raise ValueError(f"{self.type}类型的内容部分必须包含相应字段") + + @classmethod + def text_part(cls, text: 
str) -> "LLMContentPart": + """创建文本内容部分""" + return cls(type="text", text=text) + + @classmethod + def image_url_part(cls, url: str) -> "LLMContentPart": + """创建图片URL内容部分""" + return cls(type="image", image_source=url) + + @classmethod + def image_base64_part( + cls, data: str, mime_type: str = "image/png" + ) -> "LLMContentPart": + """创建Base64图片内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="image", image_source=data_url) + + @classmethod + def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart": + """创建音频URL内容部分""" + return cls(type="audio", audio_source=url, mime_type=mime_type) + + @classmethod + def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart": + """创建视频URL内容部分""" + return cls(type="video", video_source=url, mime_type=mime_type) + + @classmethod + def video_base64_part( + cls, data: str, mime_type: str = "video/mp4" + ) -> "LLMContentPart": + """创建Base64视频内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="video", video_source=data_url, mime_type=mime_type) + + @classmethod + def audio_base64_part( + cls, data: str, mime_type: str = "audio/wav" + ) -> "LLMContentPart": + """创建Base64音频内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="audio", audio_source=data_url, mime_type=mime_type) + + @classmethod + def file_uri_part( + cls, + file_uri: str, + mime_type: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> "LLMContentPart": + """创建Gemini File API URI内容部分""" + return cls( + type="file", + file_uri=file_uri, + mime_type=mime_type, + metadata=metadata or {}, + ) + + @classmethod + async def from_path( + cls, path_like: str | Path, target_api: str | None = None + ) -> "LLMContentPart | None": + """ + 从本地文件路径创建 LLMContentPart。 + 自动检测MIME类型,并根据类型(如图片)可能加载为Base64。 + target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64) + """ + try: + path = Path(path_like) + if not path.exists() or not path.is_file(): + 
logger.warning(f"文件不存在或不是一个文件: {path}") + return None + + mime_type, _ = mimetypes.guess_type(path.resolve().as_uri()) + + if not mime_type: + logger.warning( + f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。" + ) + try: + async with aiofiles.open(path, encoding="utf-8") as f: + text_content = await f.read() + return cls.text_part(text_content) + except Exception as e: + logger.error(f"读取文本文件 {path.name} 失败: {e}") + return None + + if mime_type.startswith("image/"): + if target_api == "gemini" or not path.is_absolute(): + try: + async with aiofiles.open(path, "rb") as f: + img_bytes = await f.read() + base64_data = base64.b64encode(img_bytes).decode("utf-8") + return cls.image_base64_part( + data=base64_data, mime_type=mime_type + ) + except Exception as e: + logger.error(f"读取或编码图片文件 {path.name} 失败: {e}") + return None + else: + logger.warning( + f"为本地图片路径 {path.name} 生成 image_url_part。" + "实际API可能不支持 file:// URI。考虑使用Base64或公网URL。" + ) + return cls.image_url_part(url=path.resolve().as_uri()) + elif mime_type.startswith("audio/"): + return cls.audio_url_part( + url=path.resolve().as_uri(), mime_type=mime_type + ) + elif mime_type.startswith("video/"): + if target_api == "gemini": + # 对于 Gemini API,将视频转换为 base64 + try: + async with aiofiles.open(path, "rb") as f: + video_bytes = await f.read() + base64_data = base64.b64encode(video_bytes).decode("utf-8") + return cls.video_base64_part( + data=base64_data, mime_type=mime_type + ) + except Exception as e: + logger.error(f"读取或编码视频文件 {path.name} 失败: {e}") + return None + else: + return cls.video_url_part( + url=path.resolve().as_uri(), mime_type=mime_type + ) + elif ( + mime_type.startswith("text/") + or mime_type == "application/json" + or mime_type == "application/xml" + ): + try: + async with aiofiles.open(path, encoding="utf-8") as f: + text_content = await f.read() + return cls.text_part(text_content) + except Exception as e: + logger.error(f"读取文本类文件 {path.name} 失败: {e}") + return None + else: + logger.info( + f"文件 
{path.name} (MIME: {mime_type}) 将作为通用文件URI处理。" + ) + return cls.file_uri_part( + file_uri=path.resolve().as_uri(), + mime_type=mime_type, + metadata={"name": path.name, "source": "local_path"}, + ) + + except Exception as e: + logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}") + return None + + def is_image_url(self) -> bool: + """检查图像源是否为URL""" + if not self.image_source: + return False + return self.image_source.startswith(("http://", "https://")) + + def is_image_base64(self) -> bool: + """检查图像源是否为Base64 Data URL""" + if not self.image_source: + return False + return self.image_source.startswith("data:") + + def get_base64_data(self) -> tuple[str, str] | None: + """从Data URL中提取Base64数据和MIME类型""" + if not self.is_image_base64() or not self.image_source: + return None + + try: + header, data = self.image_source.split(",", 1) + mime_part = header.split(";")[0].replace("data:", "") + return mime_part, data + except (ValueError, IndexError): + logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...") + return None + + def convert_for_api(self, api_type: str) -> dict[str, Any]: + """根据API类型转换多模态内容格式""" + if self.type == "text": + if api_type == "openai": + return {"type": "text", "text": self.text} + elif api_type == "gemini": + return {"text": self.text} + else: + return {"type": "text", "text": self.text} + + elif self.type == "image": + if not self.image_source: + raise ValueError("图像类型的内容必须包含image_source") + + if api_type == "openai": + return {"type": "image_url", "image_url": {"url": self.image_source}} + elif api_type == "gemini": + if self.is_image_base64(): + base64_info = self.get_base64_data() + if base64_info: + mime_type, data = base64_info + return {"inlineData": {"mimeType": mime_type, "data": data}} + else: + # 如果无法解析 Base64 数据,抛出异常 + raise ValueError( + f"无法解析Base64图像数据: {self.image_source[:50]}..." 
+ ) + else: + logger.warning( + f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}" + ) + return { + "inlineData": { + "mimeType": "image/jpeg", + "data": self.image_source, + } + } + else: + return {"type": "image_url", "image_url": {"url": self.image_source}} + + elif self.type == "video": + if not self.video_source: + raise ValueError("视频类型的内容必须包含video_source") + + if api_type == "gemini": + # Gemini 支持视频,但需要通过 File API 上传 + if self.video_source.startswith("data:"): + # 处理 base64 视频数据 + try: + header, data = self.video_source.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + return {"inlineData": {"mimeType": mime_type, "data": data}} + except (ValueError, IndexError): + raise ValueError( + f"无法解析Base64视频数据: {self.video_source[:50]}..." + ) + else: + # 对于 URL 或其他格式,暂时不支持直接内联 + raise ValueError( + "Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL" + ) + else: + # 其他 API 可能不支持视频 + raise ValueError(f"API类型 '{api_type}' 不支持视频内容") + + elif self.type == "audio": + if not self.audio_source: + raise ValueError("音频类型的内容必须包含audio_source") + + if api_type == "gemini": + # Gemini 支持音频,处理方式类似视频 + if self.audio_source.startswith("data:"): + try: + header, data = self.audio_source.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + return {"inlineData": {"mimeType": mime_type, "data": data}} + except (ValueError, IndexError): + raise ValueError( + f"无法解析Base64音频数据: {self.audio_source[:50]}..." 
+ ) + else: + raise ValueError( + "Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL" + ) + else: + raise ValueError(f"API类型 '{api_type}' 不支持音频内容") + + elif self.type == "file": + if api_type == "gemini" and self.file_uri: + return { + "fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri} + } + elif self.file_source: + file_name = ( + self.metadata.get("name", "file") if self.metadata else "file" + ) + if api_type == "gemini": + return {"text": f"[文件: {file_name}]\n{self.file_source}"} + else: + return { + "type": "text", + "text": f"[文件: {file_name}]\n{self.file_source}", + } + else: + raise ValueError("文件类型的内容必须包含file_uri或file_source") + + else: + raise ValueError(f"不支持的内容类型: {self.type}") + + +class LLMMessage(BaseModel): + """LLM 消息""" + + role: str + content: str | list[LLMContentPart] + name: str | None = None + tool_calls: list[Any] | None = None + tool_call_id: str | None = None + + def model_post_init(self, /, __context: Any) -> None: + """验证消息的有效性""" + _ = __context + if self.role == "tool": + if not self.tool_call_id: + raise ValueError("工具角色的消息必须包含 tool_call_id") + if not self.name: + raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)") + if self.role == "tool" and not isinstance(self.content, str): + logger.warning( + f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 
" + "将尝试转换为字符串。" + ) + try: + self.content = str(self.content) + except Exception as e: + raise ValueError(f"无法将工具角色的内容转换为字符串: {e}") + + @classmethod + def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage": + """创建用户消息""" + return cls(role="user", content=content) + + @classmethod + def assistant_tool_calls( + cls, + tool_calls: list[Any], + content: str | list[LLMContentPart] = "", + ) -> "LLMMessage": + """创建助手请求工具调用的消息""" + return cls(role="assistant", content=content, tool_calls=tool_calls) + + @classmethod + def assistant_text_response( + cls, content: str | list[LLMContentPart] + ) -> "LLMMessage": + """创建助手纯文本回复的消息""" + return cls(role="assistant", content=content, tool_calls=None) + + @classmethod + def tool_response( + cls, + tool_call_id: str, + function_name: str, + result: Any, + ) -> "LLMMessage": + """创建工具执行结果的消息""" + import json + + try: + content_str = json.dumps(result) + except TypeError as e: + logger.error( + f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 
错误: {e}" + ) + content_str = json.dumps( + {"error": "Tool result not JSON serializable", "details": str(e)} + ) + + return cls( + role="tool", + content=content_str, + tool_call_id=tool_call_id, + name=function_name, + ) + + @classmethod + def system(cls, content: str) -> "LLMMessage": + """创建系统消息""" + return cls(role="system", content=content) + + +class LLMResponse(BaseModel): + """LLM 响应""" + + text: str + usage_info: dict[str, Any] | None = None + raw_response: dict[str, Any] | None = None + tool_calls: list[Any] | None = None + code_executions: list[Any] | None = None + grounding_metadata: Any | None = None + cache_info: Any | None = None diff --git a/zhenxun/services/llm/types/enums.py b/zhenxun/services/llm/types/enums.py new file mode 100644 index 00000000..718a52ef --- /dev/null +++ b/zhenxun/services/llm/types/enums.py @@ -0,0 +1,67 @@ +""" +LLM 枚举类型定义 +""" + +from enum import Enum, auto + + +class ModelProvider(Enum): + """模型提供商枚举""" + + OPENAI = "openai" + GEMINI = "gemini" + ZHIXPU = "zhipu" + CUSTOM = "custom" + + +class ResponseFormat(Enum): + """响应格式枚举""" + + TEXT = "text" + JSON = "json" + MULTIMODAL = "multimodal" + + +class EmbeddingTaskType(str, Enum): + """文本嵌入任务类型 (主要用于Gemini)""" + + RETRIEVAL_QUERY = "RETRIEVAL_QUERY" + RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT" + SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY" + CLASSIFICATION = "CLASSIFICATION" + CLUSTERING = "CLUSTERING" + QUESTION_ANSWERING = "QUESTION_ANSWERING" + FACT_VERIFICATION = "FACT_VERIFICATION" + + +class ToolCategory(Enum): + """工具分类枚举""" + + FILE_SYSTEM = auto() + NETWORK = auto() + SYSTEM_INFO = auto() + CALCULATION = auto() + DATA_PROCESSING = auto() + CUSTOM = auto() + + +class LLMErrorCode(Enum): + """LLM 服务相关的错误代码枚举""" + + MODEL_INIT_FAILED = 2000 + MODEL_NOT_FOUND = 2001 + API_REQUEST_FAILED = 2002 + API_RESPONSE_INVALID = 2003 + API_KEY_INVALID = 2004 + API_QUOTA_EXCEEDED = 2005 + API_TIMEOUT = 2006 + API_RATE_LIMITED = 2007 + NO_AVAILABLE_KEYS = 2008 + UNKNOWN_API_TYPE = 
2009 + CONFIGURATION_ERROR = 2010 + RESPONSE_PARSE_ERROR = 2011 + CONTEXT_LENGTH_EXCEEDED = 2012 + CONTENT_FILTERED = 2013 + USER_LOCATION_NOT_SUPPORTED = 2014 + GENERATION_FAILED = 2015 + EMBEDDING_FAILED = 2016 diff --git a/zhenxun/services/llm/types/exceptions.py b/zhenxun/services/llm/types/exceptions.py new file mode 100644 index 00000000..623d4c26 --- /dev/null +++ b/zhenxun/services/llm/types/exceptions.py @@ -0,0 +1,80 @@ +""" +LLM 异常类型定义 +""" + +from typing import Any + +from .enums import LLMErrorCode + + +class LLMException(Exception): + """LLM 服务相关的基础异常类""" + + def __init__( + self, + message: str, + code: LLMErrorCode = LLMErrorCode.API_REQUEST_FAILED, + details: dict[str, Any] | None = None, + recoverable: bool = True, + cause: Exception | None = None, + ): + self.message = message + self.code = code + self.details = details or {} + self.recoverable = recoverable + self.cause = cause + super().__init__(message) + + def __str__(self) -> str: + if self.details: + return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})" + return f"{self.message} (错误码: {self.code.name})" + + @property + def user_friendly_message(self) -> str: + """返回适合向用户展示的错误消息""" + error_messages = { + LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到,请检查配置或联系管理员。", + LLMErrorCode.API_KEY_INVALID: "API密钥无效,请联系管理员更新配置。", + LLMErrorCode.API_QUOTA_EXCEEDED: ( + "API使用配额已用尽,请稍后再试或联系管理员。" + ), + LLMErrorCode.API_TIMEOUT: "AI服务响应超时,请稍后再试。", + LLMErrorCode.API_RATE_LIMITED: "请求过于频繁,已被AI服务限流,请稍后再试。", + LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败,请联系管理员检查配置。", + LLMErrorCode.NO_AVAILABLE_KEYS: ( + "当前所有API密钥均不可用,请稍后再试或联系管理员。" + ), + LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: ( + "当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。" + ), + LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败,请稍后再试。", + LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常,请稍后再试。", + LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误,请联系管理员。", + LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。", + LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。", + 
LLMErrorCode.RESPONSE_PARSE_ERROR: "AI服务响应解析失败,请稍后再试。", + LLMErrorCode.UNKNOWN_API_TYPE: "不支持的AI服务类型,请联系管理员。", + } + return error_messages.get(self.code, "AI服务暂时不可用,请稍后再试。") + + +def get_user_friendly_error_message(error: Exception) -> str: + """将任何异常转换为用户友好的错误消息""" + if isinstance(error, LLMException): + return error.user_friendly_message + + error_str = str(error).lower() + + if "timeout" in error_str or "超时" in error_str: + return "请求超时,请稍后再试。" + elif "connection" in error_str or "连接" in error_str: + return "网络连接失败,请检查网络后重试。" + elif "permission" in error_str or "权限" in error_str: + return "权限不足,请联系管理员。" + elif "not found" in error_str or "未找到" in error_str: + return "请求的资源未找到,请检查配置。" + elif "invalid" in error_str or "无效" in error_str: + return "请求参数无效,请检查输入。" + else: + return "服务暂时不可用,请稍后再试。" diff --git a/zhenxun/services/llm/types/models.py b/zhenxun/services/llm/types/models.py new file mode 100644 index 00000000..c5f541bc --- /dev/null +++ b/zhenxun/services/llm/types/models.py @@ -0,0 +1,160 @@ +""" +LLM 数据模型定义 + +包含模型信息、配置、工具定义和响应数据的模型类。 +""" + +from dataclasses import dataclass, field +from typing import Any + +from pydantic import BaseModel, Field + +from .enums import ModelProvider, ToolCategory + +ModelName = str | None + + +@dataclass(frozen=True) +class ModelInfo: + """模型信息(不可变数据类)""" + + name: str + provider: ModelProvider + max_tokens: int = 4096 + supports_tools: bool = False + supports_vision: bool = False + supports_audio: bool = False + cost_per_1k_tokens: float = 0.0 + + +@dataclass +class UsageInfo: + """使用信息数据类""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + cost: float = 0.0 + + @property + def efficiency_ratio(self) -> float: + """计算效率比(输出/输入)""" + return self.completion_tokens / max(self.prompt_tokens, 1) + + +@dataclass +class ToolMetadata: + """工具元数据""" + + name: str + description: str + category: ToolCategory + read_only: bool = True + destructive: bool = False + open_world: bool = False + 
parameters: dict[str, Any] = field(default_factory=dict) + required_params: list[str] = field(default_factory=list) + + +class ModelDetail(BaseModel): + """模型详细信息""" + + model_name: str + is_available: bool = True + is_embedding_model: bool = False + temperature: float | None = None + max_tokens: int | None = None + + +class ProviderConfig(BaseModel): + """LLM 提供商配置""" + + name: str = Field(..., description="Provider 的唯一名称标识") + api_key: str | list[str] = Field(..., description="API Key 或 Key 列表") + api_base: str | None = Field(None, description="API Base URL,如果为空则使用默认值") + api_type: str = Field(default="openai", description="API 类型") + openai_compat: bool = Field(default=False, description="是否使用 OpenAI 兼容模式") + temperature: float | None = Field(default=0.7, description="默认温度参数") + max_tokens: int | None = Field(default=None, description="默认最大输出 token 限制") + models: list[ModelDetail] = Field(..., description="支持的模型列表") + timeout: int = Field(default=180, description="请求超时时间") + proxy: str | None = Field(default=None, description="代理设置") + + +class LLMToolFunction(BaseModel): + """LLM 工具函数定义""" + + name: str + arguments: str + + +class LLMToolCall(BaseModel): + """LLM 工具调用""" + + id: str + function: LLMToolFunction + + +class LLMTool(BaseModel): + """LLM 工具定义(支持 MCP 风格)""" + + type: str = "function" + function: dict[str, Any] + annotations: dict[str, Any] | None = Field(default=None, description="工具注解") + + @classmethod + def create( + cls, + name: str, + description: str, + parameters: dict[str, Any], + required: list[str] | None = None, + annotations: dict[str, Any] | None = None, + ) -> "LLMTool": + """创建工具""" + function_def = { + "name": name, + "description": description, + "parameters": { + "type": "object", + "properties": parameters, + "required": required or [], + }, + } + return cls(type="function", function=function_def, annotations=annotations) + + +class LLMCodeExecution(BaseModel): + """代码执行结果""" + + code: str + output: str | None = None + error: str | 
None = None + execution_time: float | None = None + files_generated: list[str] | None = None + + +class LLMGroundingAttribution(BaseModel): + """信息来源关联""" + + title: str | None = None + uri: str | None = None + snippet: str | None = None + confidence_score: float | None = None + + +class LLMGroundingMetadata(BaseModel): + """信息来源关联元数据""" + + web_search_queries: list[str] | None = None + grounding_attributions: list[LLMGroundingAttribution] | None = None + search_suggestions: list[dict[str, Any]] | None = None + + +class LLMCacheInfo(BaseModel): + """缓存信息""" + + cache_hit: bool = False + cache_key: str | None = None + cache_ttl: int | None = None + created_at: str | None = None diff --git a/zhenxun/services/llm/utils.py b/zhenxun/services/llm/utils.py new file mode 100644 index 00000000..3610df27 --- /dev/null +++ b/zhenxun/services/llm/utils.py @@ -0,0 +1,218 @@ +""" +LLM 模块的工具和转换函数 +""" + +import base64 +from pathlib import Path + +from nonebot_plugin_alconna.uniseg import ( + At, + File, + Image, + Reply, + Text, + UniMessage, + Video, + Voice, +) + +from zhenxun.services.log import logger + +from .types import LLMContentPart + + +async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]: + """ + 将 UniMessage 实例转换为一个 LLMContentPart 列表。 + 这是处理多模态输入的核心转换逻辑。 + """ + parts: list[LLMContentPart] = [] + for seg in message: + part = None + if isinstance(seg, Text): + if seg.text.strip(): + part = LLMContentPart.text_part(seg.text) + elif isinstance(seg, Image): + if seg.path: + part = await LLMContentPart.from_path(seg.path, target_api="gemini") + elif seg.url: + part = LLMContentPart.image_url_part(seg.url) + elif hasattr(seg, "raw") and seg.raw: + mime_type = ( + getattr(seg, "mimetype", "image/png") + if hasattr(seg, "mimetype") + else "image/png" + ) + if isinstance(seg.raw, bytes): + b64_data = base64.b64encode(seg.raw).decode("utf-8") + part = LLMContentPart.image_base64_part(b64_data, mime_type) + + elif isinstance(seg, File | Voice | Video): + 
if seg.path: + part = await LLMContentPart.from_path(seg.path) + elif seg.url: + logger.warning( + f"直接使用 URL 的 {type(seg).__name__} 段," + f"API 可能不支持: {seg.url}" + ) + part = LLMContentPart.text_part( + f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]" + ) + elif hasattr(seg, "raw") and seg.raw: + mime_type = getattr(seg, "mimetype", None) + if isinstance(seg.raw, bytes): + b64_data = base64.b64encode(seg.raw).decode("utf-8") + + if isinstance(seg, Video): + if not mime_type: + mime_type = "video/mp4" + part = LLMContentPart.video_base64_part( + data=b64_data, mime_type=mime_type + ) + logger.debug( + f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes" + ) + elif isinstance(seg, Voice): + if not mime_type: + mime_type = "audio/wav" + part = LLMContentPart.audio_base64_part( + data=b64_data, mime_type=mime_type + ) + logger.debug( + f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes" + ) + else: + part = LLMContentPart.text_part( + f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]" + ) + logger.debug( + f"处理其他文件字节数据: {mime_type}, " + f"大小: {len(seg.raw)} bytes" + ) + + elif isinstance(seg, At): + if seg.flag == "all": + part = LLMContentPart.text_part("[Mentioned Everyone]") + else: + part = LLMContentPart.text_part(f"[Mentioned user: {seg.target}]") + + elif isinstance(seg, Reply): + if seg.msg: + try: + extract_method = getattr(seg.msg, "extract_plain_text", None) + if extract_method and callable(extract_method): + reply_text = str(extract_method()).strip() + else: + reply_text = str(seg.msg).strip() + if reply_text: + part = LLMContentPart.text_part( + f'[Replied to: "{reply_text[:50]}..."]' + ) + except Exception: + part = LLMContentPart.text_part("[Replied to a message]") + + if part: + parts.append(part) + + return parts + + +def create_multimodal_message( + text: str | None = None, + images: list[str | Path | bytes] | str | Path | bytes | None = None, + videos: list[str | Path | bytes] | str | Path | bytes | None = None, + audios: list[str | 
Path | bytes] | str | Path | bytes | None = None, + image_mimetypes: list[str] | str | None = None, + video_mimetypes: list[str] | str | None = None, + audio_mimetypes: list[str] | str | None = None, +) -> UniMessage: + """ + 创建多模态消息的便捷函数,方便第三方调用。 + + Args: + text: 文本内容 + images: 图片数据,支持路径、字节数据或URL + videos: 视频数据,支持路径、字节数据或URL + audios: 音频数据,支持路径、字节数据或URL + image_mimetypes: 图片MIME类型,当images为bytes时需要指定 + video_mimetypes: 视频MIME类型,当videos为bytes时需要指定 + audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定 + + Returns: + UniMessage: 构建好的多模态消息 + + Examples: + # 纯文本 + msg = create_multimodal_message("请分析这段文字") + + # 文本 + 单张图片(路径) + msg = create_multimodal_message("分析图片", images="/path/to/image.jpg") + + # 文本 + 多张图片 + msg = create_multimodal_message( + "比较图片", images=["/path/1.jpg", "/path/2.jpg"] + ) + + # 文本 + 图片字节数据 + msg = create_multimodal_message( + "分析", images=image_data, image_mimetypes="image/jpeg" + ) + + # 文本 + 视频 + msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4") + + # 文本 + 音频 + msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav") + + # 混合多模态 + msg = create_multimodal_message( + "分析这些媒体文件", + images="/path/to/image.jpg", + videos="/path/to/video.mp4", + audios="/path/to/audio.wav" + ) + """ + message = UniMessage() + + if text: + message.append(Text(text)) + + if images is not None: + _add_media_to_message(message, images, image_mimetypes, Image, "image/png") + + if videos is not None: + _add_media_to_message(message, videos, video_mimetypes, Video, "video/mp4") + + if audios is not None: + _add_media_to_message(message, audios, audio_mimetypes, Voice, "audio/wav") + + return message + + +def _add_media_to_message( + message: UniMessage, + media_items: list[str | Path | bytes] | str | Path | bytes, + mimetypes: list[str] | str | None, + media_class: type, + default_mimetype: str, +) -> None: + """添加媒体文件到 UniMessage 的辅助函数""" + if not isinstance(media_items, list): + media_items = [media_items] + + mime_list = [] + if mimetypes is not 
None: + if isinstance(mimetypes, str): + mime_list = [mimetypes] * len(media_items) + else: + mime_list = list(mimetypes) + + for i, item in enumerate(media_items): + if isinstance(item, str | Path): + if str(item).startswith(("http://", "https://")): + message.append(media_class(url=str(item))) + else: + message.append(media_class(path=Path(item))) + elif isinstance(item, bytes): + mimetype = mime_list[i] if i < len(mime_list) else default_mimetype + message.append(media_class(raw=item, mimetype=mimetype)) diff --git a/zhenxun/services/log.py b/zhenxun/services/log.py index 96a45bce..beb2b9c0 100644 --- a/zhenxun/services/log.py +++ b/zhenxun/services/log.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import timedelta from typing import Any, overload import nonebot @@ -17,7 +17,7 @@ driver = nonebot.get_driver() log_level = driver.config.log_level or "INFO" logger_.add( - LOG_PATH / f"{datetime.now().date()}.log", + LOG_PATH / "{time:YYYY-MM-DD}.log", level=log_level, rotation="00:00", format=default_format, @@ -26,7 +26,7 @@ logger_.add( ) logger_.add( - LOG_PATH / f"error_{datetime.now().date()}.log", + LOG_PATH / "error_{time:YYYY-MM-DD}.log", level="ERROR", rotation="00:00", format=default_format, @@ -36,26 +36,92 @@ logger_.add( class logger: - TEMPLATE_A = "Adapter[{}] {}" - TEMPLATE_B = "Adapter[{}] [{}]: {}" - TEMPLATE_C = "Adapter[{}] 用户[{}] 触发 [{}]: {}" - TEMPLATE_D = "Adapter[{}] 群聊[{}] 用户[{}] 触发" - " [{}]: {}" - TEMPLATE_E = "Adapter[{}] 群聊[{}] 用户[{}] 触发" - " [{}] [Target]({}): {}" - - TEMPLATE_ADAPTER = "Adapter[{}] " - TEMPLATE_USER = "用户[{}] " - TEMPLATE_GROUP = "群聊[{}] " - TEMPLATE_COMMAND = "CMD[{}] " - TEMPLATE_PLATFORM = "平台[{}] " - TEMPLATE_TARGET = "[Target]([{}]) " + """ + 一个经过优化的、支持多种上下文和格式的日志记录器。 + """ + TEMPLATE_ADAPTER = "Adapter[{}]" + TEMPLATE_USER = "用户[{}]" + TEMPLATE_GROUP = "群聊[{}]" + TEMPLATE_COMMAND = "CMD[{}]" + TEMPLATE_PLATFORM = "平台[{}]" + TEMPLATE_TARGET = "[Target]([{}])" SUCCESS_TEMPLATE = 
"[{}]: {} | 参数[{}] 返回: [{}]" - WARNING_TEMPLATE = "[{}]: {}" + @classmethod + def __parser_template( + cls, + info: str, + command: str | None = None, + user_id: int | str | None = None, + group_id: int | str | None = None, + adapter: str | None = None, + target: Any = None, + platform: str | None = None, + ) -> str: + """ + 优化后的模板解析器,构建并连接日志信息片段。 + """ + parts = [] + if adapter: + parts.append(cls.TEMPLATE_ADAPTER.format(adapter)) + if platform: + parts.append(cls.TEMPLATE_PLATFORM.format(platform)) + if group_id: + parts.append(cls.TEMPLATE_GROUP.format(group_id)) + if user_id: + parts.append(cls.TEMPLATE_USER.format(user_id)) + if command: + parts.append(cls.TEMPLATE_COMMAND.format(command)) + if target: + parts.append(cls.TEMPLATE_TARGET.format(target)) - ERROR_TEMPLATE = "[{}]: {}" + parts.append(info) + return " ".join(parts) + + @classmethod + def _log( + cls, + level: str, + info: str, + command: str | None = None, + session: int | str | Session | uninfoSession | None = None, + group_id: int | str | None = None, + adapter: str | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): + """ + 核心日志处理方法,处理所有日志级别的通用逻辑。 + """ + user_id: str | None = str(session) if isinstance(session, int | str) else None + + if isinstance(session, Session): + user_id = session.id1 + adapter = session.bot_type + group_id = f"{session.id3}:{session.id2}" if session.id3 else session.id2 + platform = platform or session.platform + elif isinstance(session, uninfoSession): + user_id = session.user.id + adapter = session.adapter + if session.group: + group_id = session.group.id + platform = session.basic.get("scope") + + template = cls.__parser_template( + info, command, user_id, group_id, adapter, target, platform + ) + + if e: + template += f" || 错误 {type(e).__name__}: {e}" + + try: + log_func = getattr(logger_.opt(colors=True), level) + log_func(template) + except Exception: + log_func_fallback = getattr(logger_, level) + 
log_func_fallback(template) @overload @classmethod @@ -70,7 +136,6 @@ class logger: target: Any = None, platform: str | None = None, ): ... - @overload @classmethod def info( @@ -82,7 +147,6 @@ class logger: target: Any = None, platform: str | None = None, ): ... - @overload @classmethod def info( @@ -107,28 +171,16 @@ class logger: target: Any = None, platform: str | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "info", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, ) - try: - logger_.opt(colors=True).info(template) - except Exception: - logger_.info(template) @classmethod def success( @@ -138,9 +190,11 @@ class logger: param: dict[str, Any] | None = None, result: str = "", ): - param_str = "" - if param: - param_str = ",".join([f"{k}:{v}" for k, v in param.items()]) + param_str = ( + ",".join([f"{k}:{v}" for k, v in param.items()]) + if param + else "" + ) logger_.opt(colors=True).success( cls.SUCCESS_TEMPLATE.format(command, info, param_str, result) ) @@ -159,7 +213,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def warning( @@ -168,12 +221,10 @@ class logger: command: str | None = None, *, session: Session | None = None, - adapter: str | None = None, target: Any = None, platform: str | None = None, e: Exception | None = None, ): ... 
- @overload @classmethod def warning( @@ -182,7 +233,6 @@ class logger: command: str | None = None, *, session: uninfoSession | None = None, - adapter: str | None = None, target: Any = None, platform: str | None = None, e: Exception | None = None, @@ -201,30 +251,17 @@ class logger: platform: str | None = None, e: Exception | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "warning", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, ) - if e: - template += f" || 错误{type(e)}: {e}" - try: - logger_.opt(colors=True).warning(template) - except Exception as e: - logger_.warning(template) @overload @classmethod @@ -240,7 +277,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def error( @@ -253,7 +289,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... 
- @overload @classmethod def error( @@ -280,30 +315,17 @@ class logger: platform: str | None = None, e: Exception | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "error", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, ) - if e: - template += f" || 错误 {type(e)}: {e}" - try: - logger_.opt(colors=True).error(template) - except Exception as e: - logger_.error(template) @overload @classmethod @@ -319,7 +341,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def debug( @@ -332,7 +353,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... 
- @overload @classmethod def debug( @@ -359,62 +379,78 @@ class logger: platform: str | None = None, e: Exception | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "debug", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, ) - if e: - template += f" || 错误 {type(e)}: {e}" - try: - logger_.opt(colors=True).debug(template) - except Exception as e: - logger_.debug(template) + @overload @classmethod - def __parser_template( + def trace( cls, info: str, command: str | None = None, - user_id: int | str | None = None, + *, + session: int | str | None = None, group_id: int | str | None = None, adapter: str | None = None, target: Any = None, platform: str | None = None, - ) -> str: - arg_list = [] - template = "" - if adapter is not None: - template += cls.TEMPLATE_ADAPTER - arg_list.append(adapter) - if platform is not None: - template += cls.TEMPLATE_PLATFORM - arg_list.append(platform) - if group_id is not None: - template += cls.TEMPLATE_GROUP - arg_list.append(group_id) - if user_id is not None: - template += cls.TEMPLATE_USER - arg_list.append(user_id) - if command is not None: - template += cls.TEMPLATE_COMMAND - arg_list.append(command) - if target is not None: - template += cls.TEMPLATE_TARGET - arg_list.append(target) - arg_list.append(info) - template += "{}" - return template.format(*arg_list) + e: Exception | None = None, + ): ... 
+ @overload + @classmethod + def trace( + cls, + info: str, + command: str | None = None, + *, + session: Session | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): ... + @overload + @classmethod + def trace( + cls, + info: str, + command: str | None = None, + *, + session: uninfoSession | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): ... + + @classmethod + def trace( + cls, + info: str, + command: str | None = None, + *, + session: int | str | Session | uninfoSession | None = None, + group_id: int | str | None = None, + adapter: str | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): + cls._log( + "trace", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, + ) diff --git a/zhenxun/services/plugin_init.py b/zhenxun/services/plugin_init.py index 159e042c..a622a9e8 100644 --- a/zhenxun/services/plugin_init.py +++ b/zhenxun/services/plugin_init.py @@ -6,6 +6,7 @@ from nonebot.utils import is_coroutine_callable from pydantic import BaseModel from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle driver = nonebot.get_driver() @@ -100,6 +101,6 @@ class PluginInitManager: logger.error(f"执行: {module_path}:remove 失败", e=e) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): await PluginInitManager.install_all() diff --git a/zhenxun/utils/_build_mat.py b/zhenxun/utils/_build_mat.py index de73e69d..a3de3087 100644 --- a/zhenxun/utils/_build_mat.py +++ b/zhenxun/utils/_build_mat.py @@ -1,12 +1,17 @@ from io import BytesIO from pathlib import Path import random +import sys from pydantic import BaseModel, Field -from strenum import StrEnum from ._build_image import BuildImage +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum 
import StrEnum + class MatType(StrEnum): LINE = "LINE" diff --git a/zhenxun/utils/_image_template.py b/zhenxun/utils/_image_template.py index 7f27db76..c7678b2f 100644 --- a/zhenxun/utils/_image_template.py +++ b/zhenxun/utils/_image_template.py @@ -3,9 +3,12 @@ from io import BytesIO from pathlib import Path import random +from nonebot_plugin_htmlrender import md_to_pic, template_to_pic from PIL.ImageFont import FreeTypeFont from pydantic import BaseModel +from zhenxun.configs.path_config import TEMPLATE_PATH + from ._build_image import BuildImage @@ -283,3 +286,191 @@ class ImageTemplate: width = max(width, w) height += h return width, height + + +class MarkdownTable: + def __init__(self, headers: list[str], rows: list[list[str]]): + self.headers = headers + self.rows = rows + + def to_markdown(self) -> str: + """将表格转换为Markdown格式""" + header_row = "| " + " | ".join(self.headers) + " |" + separator_row = "| " + " | ".join(["---"] * len(self.headers)) + " |" + data_rows = "\n".join( + "| " + " | ".join(map(str, row)) + " |" for row in self.rows + ) + return f"{header_row}\n{separator_row}\n{data_rows}" + + +class Markdown: + def __init__(self, data: list[str] | None = None): + if data is None: + data = [] + self._data = data + + def text(self, text: str) -> "Markdown": + """添加Markdown文本""" + self._data.append(text) + return self + + def head(self, text: str, level: int = 1) -> "Markdown": + """添加Markdown标题""" + if level < 1 or level > 6: + raise ValueError("标题级别必须在1到6之间") + self._data.append(f"{'#' * level} {text}") + return self + + def image(self, content: str | Path, add_empty_line: bool = True) -> "Markdown": + """添加Markdown图片 + + 参数: + content: 图片内容,可以是url地址,图片路径或base64字符串. + add_empty_line: 默认添加换行. 
+ + 返回: + Markdown: Markdown + """ + if isinstance(content, Path): + content = str(content.absolute()) + if content.startswith("base64"): + content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}" + self._data.append(f"![image]({content})") + if add_empty_line: + self._add_empty_line() + return self + + def quote(self, text: str | list[str]) -> "Markdown": + """添加Markdown引用文本 + + 参数: + text: 引用文本内容,可以是字符串或字符串列表. + 如果是列表,则每个元素都会被单独引用。 + + 返回: + Markdown: Markdown + """ + if isinstance(text, str): + self._data.append(f"> {text}") + elif isinstance(text, list): + for t in text: + self._data.append(f"> {t}") + self._add_empty_line() + return self + + def code(self, code: str, language: str = "python") -> "Markdown": + """添加Markdown代码块""" + self._data.append(f"```{language}\n{code}\n```") + return self + + def table(self, headers: list[str], rows: list[list[str]]) -> "Markdown": + """添加Markdown表格""" + table = MarkdownTable(headers, rows) + self._data.append(table.to_markdown()) + return self + + def list(self, items: list[str | list[str]]) -> "Markdown": + """添加Markdown列表""" + self._add_empty_line() + _text = "\n".join( + f"- {item}" + if isinstance(item, str) + else "\n".join(f"- {sub_item}" for sub_item in item) + for item in items + ) + self._data.append(_text) + return self + + def _add_empty_line(self): + """添加空行""" + self._data.append("") + + async def build(self, width: int = 800, css_path: Path | None = None) -> bytes: + """构建Markdown文本""" + if css_path is not None: + return await md_to_pic( + md="\n".join(self._data), width=width, css_path=str(css_path.absolute()) + ) + return await md_to_pic(md="\n".join(self._data), width=width) + + +class Notebook: + def __init__(self, data: list[dict] | None = None): + self._data = data if data is not None else [] + + def text(self, text: str) -> "Notebook": + """添加Notebook文本""" + self._data.append({"type": "paragraph", "text": text}) + return self + + def head(self, text: str, level: int = 1) -> "Notebook": + 
"""添加Notebook标题""" + if not 1 <= level <= 4: + raise ValueError("标题级别必须在1-4之间") + self._data.append({"type": "heading", "text": text, "level": level}) + return self + + def image( + self, + content: str | Path, + caption: str | None = None, + ) -> "Notebook": + """添加Notebook图片 + + 参数: + content: 图片内容,可以是url地址,图片路径或base64字符串. + caption: 图片说明. + + 返回: + Notebook: Notebook + """ + if isinstance(content, Path): + content = str(content.absolute()) + if content.startswith("base64"): + content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}" + self._data.append({"type": "image", "src": content, "caption": caption}) + return self + + def quote(self, text: str | list[str]) -> "Notebook": + """添加Notebook引用文本 + + 参数: + text: 引用文本内容,可以是字符串或字符串列表. + 如果是列表,则每个元素都会被单独引用。 + + 返回: + Notebook: Notebook + """ + if isinstance(text, str): + self._data.append({"type": "blockquote", "text": text}) + elif isinstance(text, list): + for t in text: + self._data.append({"type": "blockquote", "text": text}) + return self + + def code(self, code: str, language: str = "python") -> "Notebook": + """添加Notebook代码块""" + self._data.append({"type": "code", "code": code, "language": language}) + return self + + def list(self, items: list[str], ordered: bool = False) -> "Notebook": + """添加Notebook列表""" + self._data.append({"type": "list", "data": items, "ordered": ordered}) + return self + + def add_divider(self) -> None: + """添加分隔线""" + self._data.append({"type": "divider"}) + + async def build(self) -> bytes: + """构建Notebook""" + return await template_to_pic( + template_path=str((TEMPLATE_PATH / "notebook").absolute()), + template_name="main.html", + templates={"elements": self._data}, + pages={ + "viewport": {"width": 700, "height": 1000}, + "base_url": f"file://{TEMPLATE_PATH}", + }, + wait=2, + ) diff --git a/zhenxun/utils/browser.py b/zhenxun/utils/browser.py index ca2e7755..310ed606 100644 --- a/zhenxun/utils/browser.py +++ b/zhenxun/utils/browser.py @@ -1,91 +1,94 @@ -import os 
-import sys +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any, Literal -from nonebot import get_driver -from playwright.__main__ import main -from playwright.async_api import Browser, Playwright, async_playwright +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_htmlrender import get_browser +from playwright.async_api import Page -from zhenxun.configs.config import BotConfig -from zhenxun.services.log import logger - -driver = get_driver() - -_playwright: Playwright | None = None -_browser: Browser | None = None +from zhenxun.utils.message import MessageUtils -# @driver.on_startup -# async def start_browser(): -# global _playwright -# global _browser -# install() -# await check_playwright_env() -# _playwright = await async_playwright().start() -# _browser = await _playwright.chromium.launch() +class BrowserIsNone(Exception): + pass -# @driver.on_shutdown -# async def shutdown_browser(): -# if _browser: -# await _browser.close() -# if _playwright: -# await _playwright.stop() # type: ignore +class AsyncPlaywright: + @classmethod + @asynccontextmanager + async def new_page( + cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs + ) -> AsyncGenerator[Page, None]: + """获取一个新页面 - -# def get_browser() -> Browser: -# if not _browser: -# raise RuntimeError("playwright is not initalized") -# return _browser - - -def install(): - """自动安装、更新 Chromium""" - - def set_env_variables(): - os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = ( - "https://npmmirror.com/mirrors/playwright/" - ) - if BotConfig.system_proxy: - os.environ["HTTPS_PROXY"] = BotConfig.system_proxy - - def restore_env_variables(): - os.environ.pop("PLAYWRIGHT_DOWNLOAD_HOST", None) - if BotConfig.system_proxy: - os.environ.pop("HTTPS_PROXY", None) - if original_proxy is not None: - os.environ["HTTPS_PROXY"] = original_proxy - - def try_install_chromium(): + 参数: + cookies: cookies + """ + 
browser = await get_browser() + ctx = await browser.new_context(**kwargs) + if cookies: + if isinstance(cookies, dict): + cookies = [cookies] + await ctx.add_cookies(cookies) # type: ignore + page = await ctx.new_page() try: - sys.argv = ["", "install", "chromium"] - main() - except SystemExit as e: - return e.code == 0 - return False + yield page + finally: + await page.close() + await ctx.close() - logger.info("检查 Chromium 更新") + @classmethod + async def screenshot( + cls, + url: str, + path: Path | str, + element: str | list[str], + *, + wait_time: int | None = None, + viewport_size: dict[str, int] | None = None, + wait_until: ( + Literal["domcontentloaded", "load", "networkidle"] | None + ) = "networkidle", + timeout: float | None = None, + type_: Literal["jpeg", "png"] | None = None, + user_agent: str | None = None, + cookies: list[dict[str, Any]] | dict[str, Any] | None = None, + **kwargs, + ) -> UniMessage | None: + """截图,该方法仅用于简单快捷截图,复杂截图请操作 page - original_proxy = os.environ.get("HTTPS_PROXY") - set_env_variables() - - success = try_install_chromium() - - if not success: - logger.info("Chromium 更新失败,尝试从原始仓库下载,速度较慢") - os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = "" - success = try_install_chromium() - - restore_env_variables() - - if not success: - raise RuntimeError("未知错误,Chromium 下载失败") - - -async def check_playwright_env(): - """检查 Playwright 依赖""" - logger.info("检查 Playwright 依赖") - try: - async with async_playwright() as p: - await p.chromium.launch() - except Exception as e: - raise ImportError("加载失败,Playwright 依赖不全,") from e + 参数: + url: 网址 + path: 存储路径 + element: 元素选择 + wait_time: 等待截取超时时间 + viewport_size: 窗口大小 + wait_until: 等待类型 + timeout: 超时限制 + type_: 保存类型 + user_agent: user_agent + cookies: cookies + """ + if viewport_size is None: + viewport_size = {"width": 2560, "height": 1080} + if isinstance(path, str): + path = Path(path) + wait_time = wait_time * 1000 if wait_time else None + element_list = [element] if isinstance(element, str) else element 
+ async with cls.new_page( + cookies, + viewport=viewport_size, + user_agent=user_agent, + **kwargs, + ) as page: + await page.goto(url, timeout=timeout, wait_until=wait_until) + card = page + for e in element_list: + if not card: + return None + card = await card.wait_for_selector(e, timeout=wait_time) + if card: + await card.screenshot(path=path, timeout=timeout, type=type_) + return MessageUtils.build_message(path) + return None diff --git a/zhenxun/utils/decorator/retry.py b/zhenxun/utils/decorator/retry.py index ddc55584..e81aa334 100644 --- a/zhenxun/utils/decorator/retry.py +++ b/zhenxun/utils/decorator/retry.py @@ -1,24 +1,226 @@ +from collections.abc import Callable +from functools import partial, wraps +from typing import Any, Literal + from anyio import EndOfStream -from httpx import ConnectError, HTTPStatusError, TimeoutException -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed +from httpx import ( + ConnectError, + HTTPStatusError, + RemoteProtocolError, + StreamError, + TimeoutException, +) +from nonebot.utils import is_coroutine_callable +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + retry_if_result, + stop_after_attempt, + wait_exponential, + wait_fixed, +) + +from zhenxun.services.log import logger + +LOG_COMMAND = "RetryDecorator" +_SENTINEL = object() + + +def _log_before_sleep(log_name: str | None, retry_state: RetryCallState): + """ + tenacity 重试前的日志记录回调函数。 + """ + func_name = retry_state.fn.__name__ if retry_state.fn else "unknown_function" + log_context = f"函数 '{func_name}'" + if log_name: + log_context = f"操作 '{log_name}' ({log_context})" + + reason = "" + if retry_state.outcome: + if exc := retry_state.outcome.exception(): + reason = f"触发异常: {exc.__class__.__name__}({exc})" + else: + reason = f"不满足结果条件: result={retry_state.outcome.result()}" + + wait_time = ( + getattr(retry_state.next_action, "sleep", 0) if retry_state.next_action else 0 + ) + logger.warning( + 
class Retry:
    """Presets and a configurable factory for tenacity-based retry decorators."""

    @staticmethod
    def simple(
        stop_max_attempt: int = 3,
        wait_fixed_seconds: int = 2,
        exception: tuple[type[Exception], ...] = (),
        *,
        log_name: str | None = None,
        on_failure: Callable[[Exception], Any] | None = None,
        return_on_failure: Any = _SENTINEL,
    ):
        """
        A simple retry-decorator preset for generic network requests.

        参数:
            stop_max_attempt: maximum number of attempts.
            wait_fixed_seconds: fixed wait between attempts, in seconds.
            exception: extra exception types that should trigger a retry.
            log_name: operation name used in log messages.
            on_failure: (optional) callback invoked after all retries fail.
            return_on_failure: (optional) value returned after all retries fail.
        """
        return Retry.api(
            stop_max_attempt=stop_max_attempt,
            wait_fixed_seconds=wait_fixed_seconds,
            exception=exception,
            strategy="fixed",
            log_name=log_name,
            on_failure=on_failure,
            return_on_failure=return_on_failure,
        )

    @staticmethod
    def download(
        stop_max_attempt: int = 3,
        exception: tuple[type[Exception], ...] = (),
        *,
        wait_exp_multiplier: int = 2,
        wait_exp_max: int = 15,
        log_name: str | None = None,
        on_failure: Callable[[Exception], Any] | None = None,
        return_on_failure: Any = _SENTINEL,
    ):
        """
        A retry-decorator preset for file downloads, using exponential backoff.

        参数:
            stop_max_attempt: maximum number of attempts.
            exception: extra exception types that should trigger a retry.
            wait_exp_multiplier: multiplier for the exponential backoff.
            wait_exp_max: maximum backoff wait, in seconds.
            log_name: operation name used in log messages.
            on_failure: (optional) callback invoked after all retries fail.
            return_on_failure: (optional) value returned after all retries fail.
        """
        return Retry.api(
            stop_max_attempt=stop_max_attempt,
            exception=exception,
            strategy="exponential",
            wait_exp_multiplier=wait_exp_multiplier,
            wait_exp_max=wait_exp_max,
            log_name=log_name,
            on_failure=on_failure,
            return_on_failure=return_on_failure,
        )

    @staticmethod
    def api(
        stop_max_attempt: int = 3,
        wait_fixed_seconds: int = 1,
        exception: tuple[type[Exception], ...] = (),
        *,
        strategy: Literal["fixed", "exponential"] = "fixed",
        retry_on_result: Callable[[Any], bool] | None = None,
        wait_exp_multiplier: int = 1,
        wait_exp_max: int = 10,
        log_name: str | None = None,
        on_failure: Callable[[Exception], Any] | None = None,
        return_on_failure: Any = _SENTINEL,
    ):
        """
        General, fully configurable retry decorator for API calls.

        参数:
            stop_max_attempt: maximum number of attempts.
            wait_fixed_seconds: wait for the 'fixed' strategy, in seconds.
            exception: extra exception types that should trigger a retry.
            strategy: wait strategy, 'fixed' or 'exponential' backoff.
            retry_on_result: callback receiving the return value; returning
                True triggers a retry, e.g. ``lambda r: r.status_code != 200``.
            wait_exp_multiplier: multiplier for the exponential backoff.
            wait_exp_max: maximum backoff wait, in seconds.
            log_name: operation name used in logs to tell retry sites apart.
            on_failure: (optional) called with the final exception after all
                retries fail, before re-raising or returning the default.
            return_on_failure: (optional) when set, failures return this
                value instead of raising.
        """
        # Network-layer exceptions always retried, plus caller extras.
        base_exceptions = (
            TimeoutException,
            ConnectError,
            HTTPStatusError,
            StreamError,
            RemoteProtocolError,
            EndOfStream,
            *exception,
        )

        def decorator(func: Callable) -> Callable:
            if strategy == "exponential":
                wait_strategy = wait_exponential(
                    multiplier=wait_exp_multiplier, max=wait_exp_max
                )
            else:
                wait_strategy = wait_fixed(wait_fixed_seconds)

            retry_conditions = retry_if_exception_type(base_exceptions)
            if retry_on_result:
                # tenacity retry conditions compose with the | operator.
                retry_conditions |= retry_if_result(retry_on_result)

            # Bind log_name now; tenacity passes only the RetryCallState.
            log_callback = partial(_log_before_sleep, log_name)

            tenacity_retry_decorator = retry(
                stop=stop_after_attempt(stop_max_attempt),
                wait=wait_strategy,
                retry=retry_conditions,
                before_sleep=log_callback,
                reraise=True,
            )

            decorated_func = tenacity_retry_decorator(func)

            # Identity check on the sentinel: only when the caller did NOT
            # supply return_on_failure do we expose the raw retry wrapper
            # (which re-raises on final failure).
            if return_on_failure is _SENTINEL:
                return decorated_func

            if is_coroutine_callable(func):

                @wraps(func)
                async def async_wrapper(*args, **kwargs):
                    try:
                        return await decorated_func(*args, **kwargs)
                    except Exception as e:
                        if on_failure:
                            if is_coroutine_callable(on_failure):
                                await on_failure(e)
                            else:
                                on_failure(e)
                        return return_on_failure

                return async_wrapper
            else:

                @wraps(func)
                def sync_wrapper(*args, **kwargs):
                    try:
                        return decorated_func(*args, **kwargs)
                    except Exception as e:
                        if on_failure:
                            if is_coroutine_callable(on_failure):
                                # An async on_failure cannot be awaited from a
                                # sync wrapper; deliberately log and skip it.
                                logger.error(
                                    f"不能在同步函数 '{func.__name__}' 中调用异步的 "
                                    f"on_failure 回调。",
                                    LOG_COMMAND,
                                )
                            else:
                                on_failure(e)
                        return return_on_failure

                return sync_wrapper

        return decorator
as e: + if on_failure: + if is_coroutine_callable(on_failure): + await on_failure(e) + else: + on_failure(e) + return return_on_failure + + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return decorated_func(*args, **kwargs) + except Exception as e: + if on_failure: + if is_coroutine_callable(on_failure): + logger.error( + f"不能在同步函数 '{func.__name__}' 中调用异步的 " + f"on_failure 回调。", + LOG_COMMAND, + ) + else: + on_failure(e) + return return_on_failure + + return sync_wrapper + + return decorator diff --git a/zhenxun/utils/enum.py b/zhenxun/utils/enum.py index 5b235615..d7a2f703 100644 --- a/zhenxun/utils/enum.py +++ b/zhenxun/utils/enum.py @@ -1,4 +1,47 @@ -from strenum import StrEnum +import sys + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + + +class PriorityLifecycleType(StrEnum): + STARTUP = "STARTUP" + """启动""" + SHUTDOWN = "SHUTDOWN" + """关闭""" + + +class BotSentType(StrEnum): + GROUP = "GROUP" + PRIVATE = "PRIVATE" + + +class BankHandleType(StrEnum): + DEPOSIT = "DEPOSIT" + """存款""" + WITHDRAW = "WITHDRAW" + """取款""" + LOAN = "LOAN" + """贷款""" + REPAYMENT = "REPAYMENT" + """还款""" + INTEREST = "INTEREST" + """利息""" + + +class EventLogType(StrEnum): + GROUP_MEMBER_INCREASE = "GROUP_MEMBER_INCREASE" + """群成员增加""" + GROUP_MEMBER_DECREASE = "GROUP_MEMBER_DECREASE" + """群成员减少""" + KICK_MEMBER = "KICK_MEMBER" + """踢出群成员""" + KICK_BOT = "KICK_BOT" + """踢出Bot""" + LEAVE_MEMBER = "LEAVE_MEMBER" + """主动退群""" class CacheType(StrEnum): @@ -128,7 +171,9 @@ class RequestType(StrEnum): """ FRIEND = "FRIEND" + """好友""" GROUP = "GROUP" + """群组""" class RequestHandleType(StrEnum): diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py index db8c0656..9ab664f4 100644 --- a/zhenxun/utils/exception.py +++ b/zhenxun/utils/exception.py @@ -1,3 +1,15 @@ +class HookPriorityException(BaseException): + """ + 钩子优先级异常 + """ + + def __init__(self, info: str = "") -> None: + 
self.info = info + + def __str__(self) -> str: + return self.info + + class NotFoundError(Exception): """ 未发现 @@ -52,3 +64,23 @@ class GoodsNotFound(Exception): """ pass + + +class AllURIsFailedError(Exception): + """ + 当所有备用URL都尝试失败后抛出此异常 + """ + + def __init__(self, urls: list[str], exceptions: list[Exception]): + self.urls = urls + self.exceptions = exceptions + super().__init__( + f"All {len(urls)} URIs failed. Last exception: {exceptions[-1]}" + ) + + def __str__(self) -> str: + exc_info = "\n".join( + f" - {url}: {exc.__class__.__name__}({exc})" + for url, exc in zip(self.urls, self.exceptions) + ) + return f"All {len(self.urls)} URIs failed:\n{exc_info}" diff --git a/zhenxun/utils/github_utils/const.py b/zhenxun/utils/github_utils/const.py index 23effa4c..68fffad9 100644 --- a/zhenxun/utils/github_utils/const.py +++ b/zhenxun/utils/github_utils/const.py @@ -21,6 +21,9 @@ CACHED_API_TTL = 300 RAW_CONTENT_FORMAT = "https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" """raw content格式""" +GITEE_RAW_CONTENT_FORMAT = "https://gitee.com/{owner}/{repo}/raw/main/{path}" +"""gitee raw content格式""" + ARCHIVE_URL_FORMAT = "https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" """archive url格式""" diff --git a/zhenxun/utils/github_utils/func.py b/zhenxun/utils/github_utils/func.py index 19daf10d..db3afa03 100644 --- a/zhenxun/utils/github_utils/func.py +++ b/zhenxun/utils/github_utils/func.py @@ -4,6 +4,7 @@ from zhenxun.utils.http_utils import AsyncHttpx from .const import ( ARCHIVE_URL_FORMAT, + GITEE_RAW_CONTENT_FORMAT, RAW_CONTENT_FORMAT, RELEASE_ASSETS_FORMAT, RELEASE_SOURCE_FORMAT, @@ -21,9 +22,9 @@ async def __get_fastest_formats(formats: dict[str, str]) -> list[str]: async def get_fastest_raw_formats() -> list[str]: """获取最快的raw下载地址格式""" formats: dict[str, str] = { + "https://gitee.com/": GITEE_RAW_CONTENT_FORMAT, "https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT, "https://ghproxy.cc/": f"https://ghproxy.cc/{RAW_CONTENT_FORMAT}", - 
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RAW_CONTENT_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}", "https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}", } @@ -36,7 +37,6 @@ async def get_fastest_archive_formats() -> list[str]: formats: dict[str, str] = { "https://github.com/": ARCHIVE_URL_FORMAT, "https://ghproxy.cc/": f"https://ghproxy.cc/{ARCHIVE_URL_FORMAT}", - "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}", } return await __get_fastest_formats(formats) @@ -48,7 +48,6 @@ async def get_fastest_release_formats() -> list[str]: formats: dict[str, str] = { "https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT, "https://ghproxy.cc/": f"https://ghproxy.cc/{RELEASE_ASSETS_FORMAT}", - "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}", } return await __get_fastest_formats(formats) diff --git a/zhenxun/utils/github_utils/models.py b/zhenxun/utils/github_utils/models.py index e3e5dfe3..fb690616 100644 --- a/zhenxun/utils/github_utils/models.py +++ b/zhenxun/utils/github_utils/models.py @@ -1,13 +1,18 @@ import contextlib +import sys from typing import Protocol from aiocache import cached from nonebot.compat import model_dump from pydantic import BaseModel, Field -from strenum import StrEnum from zhenxun.utils.http_utils import AsyncHttpx +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + from .const import ( CACHED_API_TTL, GIT_API_COMMIT_FORMAT, diff --git a/zhenxun/utils/html_template/__init__.py b/zhenxun/utils/html_template/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/zhenxun/utils/html_template/__init__.py @@ -0,0 +1 @@ + diff --git a/zhenxun/utils/html_template/component.py 
class Style(BaseModel):
    """Commonly used CSS style fields for a component."""

    padding: str = "0px"
    margin: str = "0px"
    border: str = "0px"
    border_radius: str = "0px"
    text_align: Literal["left", "right", "center"] = "left"
    color: str = "#000"
    font_size: str = "16px"


class Component(ABC):
    """Base class for HTML components.

    Holds a ``Style``, extra raw CSS snippets and child components.
    """

    def __init__(self, background_color: str = "#fff", is_container: bool = False):
        self.extra_style: list[str] = []
        self.style = Style()
        self.background_color = background_color
        self.is_container = is_container
        self.children: list["Component | str"] = []

    def add_child(self, child: "Component | str"):
        """Append a child component (or a raw HTML string)."""
        self.children.append(child)

    def set_style(self, style: Style):
        """Replace the component's style wholesale."""
        self.style = style

    def add_style(self, style: str):
        """Append an extra raw CSS snippet."""
        self.extra_style.append(style)

    def to_html(self) -> str: ...


class Row(Component):
    """Horizontal layout container."""

    def __init__(self, background_color: str = "#fff"):
        super().__init__(background_color, True)


class Col(Component):
    """Vertical layout container."""

    def __init__(self, background_color: str = "#fff"):
        super().__init__(background_color, True)


class Container(Component):
    """Generic layout container."""

    def __init__(self, background_color: str = "#fff"):
        # Component.__init__ already initializes children; the original's
        # redundant re-assignment of self.children was dropped.
        super().__init__(background_color, True)


class Title(Component):
    """A title line rendered as a styled row. (Still WIP upstream.)"""

    def __init__(self, text: str, color: str = "#000"):
        # BUG FIX: the original never called Component.__init__, so
        # self.style / self.children / self.extra_style were undefined and
        # every inherited method (add_child, set_style, ...) crashed.
        super().__init__()
        self.text = text
        self.color = color

    def build(self):
        row = Row()
        style = Style(font_size="36px", color=self.color)
        row.set_style(style)
        # NOTE(review): build() currently discards `row` and returns None —
        # clearly unfinished upstream; left as-is to avoid guessing intent.


class GlobalOverview:
    """Collects page-level content and global CSS classes."""

    def __init__(self, name: str):
        self.name = name
        self.class_name: dict[str, list[str]] = {}
        self.content: Container | None = None

    def set_content(self, content: Container):
        self.content = content

    def add_class(self, class_name: str, contents: list[str]):
        """全局样式 — register a global CSS class with its declaration lines."""
        self.class_name[class_name] = contents
import get_browser +from .browser import AsyncPlaywright, BrowserIsNone # noqa: F401 + +_SENTINEL = object() + +driver = nonebot.get_driver() +_client: AsyncClient | None = None + + +@PriorityLifecycle.on_startup(priority=0) +async def _(): + """ + 在Bot启动时初始化全局httpx客户端。 + """ + global _client + client_kwargs = {} + if proxy_url := BotConfig.system_proxy or None: + try: + version_parts = httpx.__version__.split(".") + major = int("".join(c for c in version_parts[0] if c.isdigit())) + minor = ( + int("".join(c for c in version_parts[1] if c.isdigit())) + if len(version_parts) > 1 + else 0 + ) + if (major, minor) >= (0, 28): + client_kwargs["proxy"] = proxy_url + else: + client_kwargs["proxies"] = proxy_url + except (ValueError, IndexError): + client_kwargs["proxy"] = proxy_url + logger.warning( + f"无法解析 httpx 版本 '{httpx.__version__}'," + "将默认使用新版 'proxy' 参数语法。" + ) + + _client = httpx.AsyncClient( + headers=get_user_agent(), + follow_redirects=True, + **client_kwargs, + ) + + logger.info("全局 httpx.AsyncClient 已启动。", "HTTPClient") + + +@driver.on_shutdown +async def _(): + """ + 在Bot关闭时关闭全局httpx客户端。 + """ + if _client: + await _client.aclose() + logger.info("全局 httpx.AsyncClient 已关闭。", "HTTPClient") + + +def get_client() -> AsyncClient: + """ + 获取全局 httpx.AsyncClient 实例。 + """ + global _client + if not _client: + if not os.environ.get("PYTEST_CURRENT_TEST"): + raise RuntimeError("全局 httpx.AsyncClient 未初始化,请检查启动流程。") + # 在测试环境中创建临时客户端 + logger.warning("在测试环境中创建临时HTTP客户端", "HTTPClient") + _client = httpx.AsyncClient( + headers=get_user_agent(), + follow_redirects=True, + ) + return _client + + +def get_async_client( + proxies: dict[str, str] | None = None, + proxy: str | None = None, + verify: bool = False, + **kwargs, +) -> httpx.AsyncClient: + """ + [向后兼容] 创建 httpx.AsyncClient 实例的工厂函数。 + 此函数完全保留了旧版本的接口,确保现有代码无需修改即可使用。 + """ + transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify) + if proxies: + http_proxy = proxies.get("http://") + https_proxy 
= proxies.get("https://") + return httpx.AsyncClient( + mounts={ + "http://": AsyncHTTPTransport( + proxy=Proxy(http_proxy) if http_proxy else None + ), + "https://": AsyncHTTPTransport( + proxy=Proxy(https_proxy) if https_proxy else None + ), + }, + transport=transport, + **kwargs, + ) + elif proxy: + return httpx.AsyncClient( + mounts={ + "http://": AsyncHTTPTransport(proxy=Proxy(proxy)), + "https://": AsyncHTTPTransport(proxy=Proxy(proxy)), + }, + transport=transport, + **kwargs, + ) + return httpx.AsyncClient(transport=transport, **kwargs) class AsyncHttpx: - proxy: ClassVar[dict[str, str | None]] = { - "http://": BotConfig.system_proxy, - "https://": BotConfig.system_proxy, - } + """ + 一个高级的、健壮的异步HTTP客户端工具类。 + + 设计理念: + - **全局共享客户端**: 默认情况下,所有请求都通过一个在应用启动时初始化的全局 + `httpx.AsyncClient` 实例发出。这个实例共享连接池,提高了效率和性能。 + - **向后兼容与灵活性**: 完全兼容旧的API,同时提供了两种方式来处理需要 + 特殊网络配置(如不同代理、超时)的请求: + 1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时,直接传入 `proxies`, + `timeout` 等参数,将为该次请求创建一个临时的、独立的客户端。 + 2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器,可以 + 获取一个独立的、可配置的客户端,用于执行一系列需要相同特殊配置的请求。 + - **健壮性**: 内置了自动重试、多镜像URL回退(fallback)机制,并提供了便捷的 + JSON解析和文件下载方法。 + """ + + CLIENT_KEY: ClassVar[list[str]] = [ + "use_proxy", + "proxies", + "proxy", + "verify", + "headers", + ] + + default_proxy: ClassVar[dict[str, str] | None] = ( + { + "http://": BotConfig.system_proxy, + "https://": BotConfig.system_proxy, + } + if BotConfig.system_proxy + else None + ) + + @classmethod + def _prepare_temporary_client_config(cls, client_kwargs: dict) -> dict: + """ + [向后兼容] 处理旧式的客户端kwargs,将其转换为get_async_client可用的配置。 + 主要负责处理 use_proxy 标志,这是为了兼容旧版本代码中使用的 use_proxy 参数。 + """ + final_config = client_kwargs.copy() + + use_proxy = final_config.pop("use_proxy", True) + + if "proxies" not in final_config and "proxy" not in final_config: + final_config["proxies"] = cls.default_proxy if use_proxy else None + return final_config + + @classmethod + def _split_kwargs(cls, kwargs: dict) -> tuple[dict, dict]: + """[优化] 
分离客户端配置和请求参数,使逻辑更清晰。""" + client_kwargs = {k: v for k, v in kwargs.items() if k in cls.CLIENT_KEY} + request_kwargs = {k: v for k, v in kwargs.items() if k not in cls.CLIENT_KEY} + return client_kwargs, request_kwargs + + @classmethod + @asynccontextmanager + async def _get_active_client_context( + cls, client: AsyncClient | None = None, **kwargs + ) -> AsyncGenerator[AsyncClient, None]: + """ + 内部辅助方法,根据 kwargs 决定并提供一个活动的 HTTP 客户端。 + - 如果 kwargs 中有客户端配置,则创建并返回一个临时客户端。 + - 否则,返回传入的 client 或全局客户端。 + - 自动处理临时客户端的关闭。 + """ + if kwargs: + logger.debug(f"为单次请求创建临时客户端,配置: {kwargs}") + temp_client_config = cls._prepare_temporary_client_config(kwargs) + async with get_async_client(**temp_client_config) as temp_client: + yield temp_client + else: + yield client or get_client() + + @Retry.simple(log_name="内部HTTP请求") + async def _execute_request_inner( + self, client: AsyncClient, method: str, url: str, **kwargs + ) -> Response: + """ + [内部] 执行单次HTTP请求的私有核心方法,被重试装饰器包裹。 + """ + return await client.request(method, url, **kwargs) + + @classmethod + async def _single_request( + cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs + ) -> Response: + """ + 执行单次HTTP请求的私有方法,内置了默认的重试逻辑。 + """ + client_kwargs, request_kwargs = cls._split_kwargs(kwargs) + + async with cls._get_active_client_context( + client=client, **client_kwargs + ) as active_client: + response = await cls()._execute_request_inner( + active_client, method, url, **request_kwargs + ) + response.raise_for_status() + return response + + @classmethod + async def _execute_with_fallbacks( + cls, + urls: str | list[str], + worker: Callable[..., Awaitable[Any]], + *, + client: AsyncClient | None = None, + **kwargs, + ) -> Any: + """ + 通用执行器,按顺序尝试多个URL,直到成功。 + + 参数: + urls: 单个URL或URL列表。 + worker: 一个接受单个URL和其他kwargs并执行请求的协程函数。 + client: 可选的HTTP客户端。 + **kwargs: 传递给worker的额外参数。 + """ + url_list = [urls] if isinstance(urls, str) else urls + exceptions = [] + + for i, url in enumerate(url_list): + try: + 
result = await worker(url, client=client, **kwargs) + if i > 0: + logger.info( + f"成功从镜像 '{url}' 获取资源 " + f"(在尝试了 {i} 个失败的镜像之后)。", + "AsyncHttpx:FallbackExecutor", + ) + return result + except Exception as e: + exceptions.append(e) + if url != url_list[-1]: + logger.warning( + f"Worker '{worker.__name__}' on {url} failed, trying next. " + f"Error: {e.__class__.__name__}", + "AsyncHttpx:FallbackExecutor", + ) + + raise AllURIsFailedError(url_list, exceptions) @classmethod - @retry(stop_max_attempt_number=3) async def get( cls, url: str | list[str], *, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 + follow_redirects: bool = True, + check_status_code: int | None = None, + client: AsyncClient | None = None, **kwargs, ) -> Response: - """Get + """发送 GET 请求,并返回第一个成功的响应。 + + 说明: + 本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的 + 客户端管理。如果提供 URL 列表,它将依次尝试直到成功为止。 + + 用法建议: + - **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端。 + - **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)` + 将为本次请求创建一个独立的临时客户端。 参数: - url: url - params: params - headers: 请求头 - cookies: cookies - verify: verify - use_proxy: 使用默认代理 - proxy: 指定代理 - timeout: 超时时间 + url: 单个请求 URL 或一个 URL 列表。 + follow_redirects: 是否跟随重定向。 + check_status_code: (可选) 若提供,将检查响应状态码是否匹配,否则抛出异常。 + client: (可选) 指定一个活动的HTTP客户端实例。若提供,则忽略 + `**kwargs`中的客户端配置。 + **kwargs: 其他所有传递给 httpx.get 的参数 (如 `params`, `headers`, + `timeout`)。如果包含 `proxies`, `verify` 等客户端配置参数, + 将创建一个临时客户端。 + + 返回: + Response: httpx 的响应对象。 + + Raises: + AllURIsFailedError: 当所有提供的URL都请求失败时抛出。 """ - urls = [url] if isinstance(url, str) else url - return await cls._get_first_successful( - urls, - params=params, - headers=headers, - cookies=cookies, - verify=verify, - use_proxy=use_proxy, - proxy=proxy, - timeout=timeout, - **kwargs, - ) - @classmethod - async def 
_get_first_successful( - cls, - urls: list[str], - **kwargs, - ) -> Response: - last_exception = None - for url in urls: - try: - return await cls._get_single(url, **kwargs) - except Exception as e: - last_exception = e - if url != urls[-1]: - logger.warning(f"获取 {url} 失败, 尝试下一个") - raise last_exception or Exception("All URLs failed") - - @classmethod - async def _get_single( - cls, - url: str, - *, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 - **kwargs, - ) -> Response: - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.get( - url, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, + async def worker(current_url: str, **worker_kwargs) -> Response: + logger.info(f"开始获取 {current_url}..", "AsyncHttpx:get") + response = await cls._single_request( + "GET", current_url, follow_redirects=follow_redirects, **worker_kwargs ) + if check_status_code and response.status_code != check_status_code: + raise HTTPStatusError( + f"状态码错误: {response.status_code}!={check_status_code}", + request=response.request, + response=response, + ) + return response + + return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs) @classmethod async def head( - cls, - url: str, - *, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 - **kwargs, + cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs ) -> Response: - """Get + """发送 HEAD 请求,并返回第一个成功的响应。""" - 参数: - url: url - 
params: params - headers: 请求头 - cookies: cookies - verify: verify - use_proxy: 使用默认代理 - proxy: 指定代理 - timeout: 超时时间 - """ - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.head( - url, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, - ) + async def worker(current_url: str, **worker_kwargs) -> Response: + return await cls._single_request("HEAD", current_url, **worker_kwargs) + + return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs) @classmethod async def post( - cls, - url: str, - *, - data: dict[str, Any] | None = None, - content: Any = None, - files: Any = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - json: dict[str, Any] | None = None, - params: dict[str, str] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 - **kwargs, + cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs ) -> Response: - """ - 说明: - Post - 参数: - url: url - data: data - content: content - files: files - use_proxy: 是否默认代理 - proxy: 指定代理 - json: json - params: params - headers: 请求头 - cookies: cookies - timeout: 超时时间 - """ - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.post( - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, - ) + """发送 POST 请求,并返回第一个成功的响应。""" + + async def worker(current_url: str, **worker_kwargs) -> Response: + return await cls._single_request("POST", current_url, **worker_kwargs) + + return await cls._execute_with_fallbacks(url, worker, 
client=client, **kwargs) @classmethod - async def get_content(cls, url: str, **kwargs) -> bytes: - res = await cls.get(url, **kwargs) + async def get_content( + cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs + ) -> bytes: + """获取指定 URL 的二进制内容。""" + res = await cls.get(url, client=client, **kwargs) return res.content + @classmethod + @Retry.api( + log_name="JSON请求", + exception=(json.JSONDecodeError,), + return_on_failure=_SENTINEL, + ) + async def _request_and_parse_json( + cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs + ) -> Any: + """ + [私有] 执行单个HTTP请求并解析JSON,用于内部统一处理。 + """ + async with cls._get_active_client_context( + client=client, **kwargs + ) as active_client: + _, request_kwargs = cls._split_kwargs(kwargs) + response = await active_client.request(method, url, **request_kwargs) + response.raise_for_status() + return response.json() + + @classmethod + async def get_json( + cls, + url: str | list[str], + *, + default: Any = None, + raise_on_failure: bool = False, + client: AsyncClient | None = None, + **kwargs, + ) -> Any: + """ + 发送GET请求并自动解析为JSON,支持重试和多链接尝试。 + + 说明: + 这是一个高度便捷的方法,封装了请求、重试、JSON解析和错误处理。 + 它会在网络错误或JSON解析错误时自动重试。 + 如果所有尝试都失败,它会安全地返回一个默认值。 + + 参数: + url: 单个请求 URL 或一个备用 URL 列表。 + default: (可选) 当所有尝试都失败时返回的默认值,默认为None。 + raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出 + `AllURIsFailedError` 异常, 默认为 False. 
+ client: (可选) 指定的HTTP客户端。 + **kwargs: 其他所有传递给 httpx.get 的参数。 + 例如 `params`, `headers`, `timeout`等。 + + 返回: + Any: 解析后的JSON数据,或在失败时返回 `default` 值。 + + Raises: + AllURIsFailedError: 当 `raise_on_failure` 为 True 且所有URL都请求失败时抛出 + """ + + async def worker(current_url: str, **worker_kwargs): + logger.debug(f"开始GET JSON: {current_url}", "AsyncHttpx:get_json") + return await cls._request_and_parse_json( + "GET", current_url, **worker_kwargs + ) + + try: + result = await cls._execute_with_fallbacks( + url, worker, client=client, **kwargs + ) + return default if result is _SENTINEL else result + except AllURIsFailedError as e: + logger.error(f"所有URL的JSON GET均失败: {e}", "AsyncHttpx:get_json") + if raise_on_failure: + raise e + return default + + @classmethod + async def post_json( + cls, + url: str | list[str], + *, + json: Any = None, + data: Any = None, + default: Any = None, + raise_on_failure: bool = False, + client: AsyncClient | None = None, + **kwargs, + ) -> Any: + """ + 发送POST请求并自动解析为JSON,功能与 get_json 类似。 + + 参数: + url: 单个请求 URL 或一个备用 URL 列表。 + json: (可选) 作为请求体发送的JSON数据。 + data: (可选) 作为请求体发送的表单数据。 + default: (可选) 当所有尝试都失败时返回的默认值,默认为None。 + raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出 + AllURIsFailedError 异常, 默认为 False. 
+ client: (可选) 指定的HTTP客户端。 + **kwargs: 其他所有传递给 httpx.post 的参数。 + + 返回: + Any: 解析后的JSON数据,或在失败时返回 `default` 值。 + """ + if json is not None: + kwargs["json"] = json + if data is not None: + kwargs["data"] = data + + async def worker(current_url: str, **worker_kwargs): + logger.debug(f"开始POST JSON: {current_url}", "AsyncHttpx:post_json") + return await cls._request_and_parse_json( + "POST", current_url, **worker_kwargs + ) + + try: + result = await cls._execute_with_fallbacks( + url, worker, client=client, **kwargs + ) + return default if result is _SENTINEL else result + except AllURIsFailedError as e: + logger.error(f"所有URL的JSON POST均失败: {e}", "AsyncHttpx:post_json") + if raise_on_failure: + raise e + return default + + @classmethod + @Retry.api(log_name="文件下载(流式)") + async def _stream_download( + cls, url: str, path: Path, *, client: AsyncClient | None = None, **kwargs + ) -> None: + """ + 执行单个流式下载的私有方法,被重试装饰器包裹。 + """ + async with cls._get_active_client_context( + client=client, **kwargs + ) as active_client: + async with active_client.stream("GET", url, **kwargs) as response: + response.raise_for_status() + total = int(response.headers.get("Content-Length", 0)) + + with Progress( + TextColumn(path.name), + "[progress.percentage]{task.percentage:>3.0f}%", + BarColumn(bar_width=None), + DownloadColumn(), + TransferSpeedColumn(), + ) as progress: + task_id = progress.add_task("Download", total=total) + async with aiofiles.open(path, "wb") as f: + async for chunk in response.aiter_bytes(): + await f.write(chunk) + progress.update(task_id, advance=len(chunk)) + @classmethod async def download_file( cls, url: str | list[str], path: str | Path, *, - params: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 stream: bool = False, - follow_redirects: bool = True, + client: AsyncClient | 
    @classmethod
    async def download_file(
        cls,
        url: str | list[str],
        path: str | Path,
        *,
        stream: bool = False,
        client: AsyncClient | None = None,
        **kwargs,
    ) -> bool:
        """Download a file to the given path.

        Supports fallback mirrors and (optionally) streamed downloads
        with a progress bar.

        参数:
            url: one file URL or a list of fallback URLs.
            path: local path the file is saved to.
            stream: (optional) stream the download — suited to large
                files; defaults to False.
            client: (optional) HTTP client to use.
            **kwargs: remaining arguments forwarded to get()/stream().

        返回:
            bool: whether the download succeeded.
        """
        path = Path(path)
        # Ensure the destination directory exists before any attempt.
        path.parent.mkdir(parents=True, exist_ok=True)

        async def worker(current_url: str, **worker_kwargs) -> bool:
            if not stream:
                # Buffered mode: fetch the whole body, then write once.
                content = await cls.get_content(current_url, **worker_kwargs)
                async with aiofiles.open(path, "wb") as f:
                    await f.write(content)
            else:
                await cls._stream_download(current_url, path, **worker_kwargs)

            logger.info(
                f"下载 {current_url} 成功 -> {path.absolute()}",
                "AsyncHttpx:download",
            )
            return True

        try:
            return await cls._execute_with_fallbacks(
                url, worker, client=client, **kwargs
            )
        except AllURIsFailedError:
            logger.error(
                f"所有URL下载均失败 -> {path.absolute()}", "AsyncHttpx:download"
            )
            return False
else: - _split_url_list = [url_list] - _split_path_list = [path_list] - tasks = [] - result_ = [] - for x, y in zip(_split_url_list, _split_path_list): - tasks.extend( - asyncio.create_task( - cls.download_file( - url, - path, - params=params, - headers=headers, - cookies=cookies, - use_proxy=use_proxy, - timeout=timeout, - proxy=proxy, - **kwargs, - ) + if len(url_list) != len(path_list): + raise ValueError("URL 列表和路径列表的长度必须相等") + + semaphore = asyncio.Semaphore(limit_async_number) + + async def _download_with_semaphore( + urls_for_one_path: str | list[str], path: str | Path + ): + async with semaphore: + return await cls.download_file(urls_for_one_path, path, **kwargs) + + tasks = [ + _download_with_semaphore(url_group, path) + for url_group, path in zip(url_list, path_list) + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + final_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + url_info = ( + url_list[i] + if isinstance(url_list[i], str) + else ", ".join(url_list[i]) ) - for url, path in zip(x, y) - ) - _x = await asyncio.gather(*tasks) - result_ = result_ + list(_x) - tasks.clear() - return result_ + logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result) + final_results.append(False) + else: + final_results.append(cast(bool, result)) + + return final_results @classmethod async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]: + """测试并返回最快的镜像地址。 + + 说明: + 通过并发发送 HEAD 请求来测试每个 URL 的响应时间和可用性,并按响应速度排序。 + + 参数: + url_list: 需要测试的镜像 URL 列表。 + + 返回: + list[str]: 按从快到慢的顺序包含了所有可用的 URL。 + """ assert url_list async def head_mirror(client: type[AsyncHttpx], url: str) -> dict[str, Any]: @@ -434,90 +665,30 @@ class AsyncHttpx: _results = sorted(iter(_results), key=lambda r: r["elapsed_time"]) return [result["url"] for result in _results] - -class AsyncPlaywright: @classmethod @asynccontextmanager - async def new_page( - cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs - ) -> 
AsyncGenerator[Page, None]: - """获取一个新页面 + async def temporary_client(cls, **kwargs) -> AsyncGenerator[AsyncClient, None]: + """ + 创建一个临时的、可配置的HTTP客户端上下文,并直接返回该客户端实例。 + + 此方法返回一个标准的 `httpx.AsyncClient`,它不使用全局连接池, + 拥有独立的配置(如代理、headers、超时等),并在退出上下文后自动关闭。 + 适用于需要用一套特殊网络配置执行一系列请求的场景。 + + 用法: + async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client: + # client 是一个标准的 httpx.AsyncClient 实例 + response1 = await client.get("http://some.internal.api/1") + response2 = await client.get("http://some.internal.api/2") + data = response2.json() 参数: - cookies: cookies + **kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数。 + 例如: `proxies`, `headers`, `verify`, `timeout`, + `follow_redirects`。 + + Yields: + httpx.AsyncClient: 一个配置好的、临时的客户端实例。 """ - browser = await get_browser() - ctx = await browser.new_context(**kwargs) - if cookies: - if isinstance(cookies, dict): - cookies = [cookies] - await ctx.add_cookies(cookies) # type: ignore - page = await ctx.new_page() - try: - yield page - finally: - await page.close() - await ctx.close() - - @classmethod - async def screenshot( - cls, - url: str, - path: Path | str, - element: str | list[str], - *, - wait_time: int | None = None, - viewport_size: dict[str, int] | None = None, - wait_until: ( - Literal["domcontentloaded", "load", "networkidle"] | None - ) = "networkidle", - timeout: float | None = None, # noqa: ASYNC109 - type_: Literal["jpeg", "png"] | None = None, - user_agent: str | None = None, - cookies: list[dict[str, Any]] | dict[str, Any] | None = None, - **kwargs, - ) -> UniMessage | None: - """截图,该方法仅用于简单快捷截图,复杂截图请操作 page - - 参数: - url: 网址 - path: 存储路径 - element: 元素选择 - wait_time: 等待截取超时时间 - viewport_size: 窗口大小 - wait_until: 等待类型 - timeout: 超时限制 - type_: 保存类型 - user_agent: user_agent - cookies: cookies - """ - if viewport_size is None: - viewport_size = {"width": 2560, "height": 1080} - if isinstance(path, str): - path = Path(path) - wait_time = wait_time * 1000 if wait_time else None - element_list = [element] if 
isinstance(element, str) else element - async with cls.new_page( - cookies, - viewport=viewport_size, - user_agent=user_agent, - **kwargs, - ) as page: - await page.goto(url, timeout=timeout, wait_until=wait_until) - card = page - for e in element_list: - if not card: - return None - card = await card.wait_for_selector(e, timeout=wait_time) - if card: - await card.screenshot(path=path, timeout=timeout, type=type_) - return MessageUtils.build_message(path) - return None - - -class UrlPathNumberNotEqual(Exception): - pass - - -class BrowserIsNone(Exception): - pass + async with get_async_client(**kwargs) as client: + yield client diff --git a/zhenxun/utils/manager/message_manager.py b/zhenxun/utils/manager/message_manager.py index e714c8d8..ee34369d 100644 --- a/zhenxun/utils/manager/message_manager.py +++ b/zhenxun/utils/manager/message_manager.py @@ -22,6 +22,4 @@ class MessageManager: @classmethod def get(cls, uid: str) -> list[str]: - if uid in cls.data: - return cls.data[uid] - return [] + return cls.data[uid] if uid in cls.data else [] diff --git a/zhenxun/utils/manager/priority_manager.py b/zhenxun/utils/manager/priority_manager.py new file mode 100644 index 00000000..1c59635c --- /dev/null +++ b/zhenxun/utils/manager/priority_manager.py @@ -0,0 +1,57 @@ +from collections.abc import Callable +from typing import ClassVar + +import nonebot +from nonebot.utils import is_coroutine_callable + +from zhenxun.services.log import logger +from zhenxun.utils.enum import PriorityLifecycleType +from zhenxun.utils.exception import HookPriorityException + +driver = nonebot.get_driver() + + +class PriorityLifecycle: + _data: ClassVar[dict[PriorityLifecycleType, dict[int, list[Callable]]]] = {} + + @classmethod + def add(cls, hook_type: PriorityLifecycleType, func: Callable, priority: int): + if hook_type not in cls._data: + cls._data[hook_type] = {} + if priority not in cls._data[hook_type]: + cls._data[hook_type][priority] = [] + cls._data[hook_type][priority].append(func) + 
+ @classmethod + def on_startup(cls, *, priority: int): + def wrapper(func): + cls.add(PriorityLifecycleType.STARTUP, func, priority) + return func + + return wrapper + + @classmethod + def on_shutdown(cls, *, priority: int): + def wrapper(func): + cls.add(PriorityLifecycleType.SHUTDOWN, func, priority) + return func + + return wrapper + + +@driver.on_startup +async def _(): + priority_data = PriorityLifecycle._data.get(PriorityLifecycleType.STARTUP) + if not priority_data: + return + priority_list = sorted(priority_data.keys()) + priority = 0 + try: + for priority in priority_list: + for func in priority_data[priority]: + if is_coroutine_callable(func): + await func() + else: + func() + except HookPriorityException as e: + logger.error(f"打断优先级 [{priority}] on_startup 方法. {type(e)}: {e}") diff --git a/zhenxun/utils/manager/schedule_manager.py b/zhenxun/utils/manager/schedule_manager.py new file mode 100644 index 00000000..a3b21272 --- /dev/null +++ b/zhenxun/utils/manager/schedule_manager.py @@ -0,0 +1,810 @@ +import asyncio +from collections.abc import Callable, Coroutine +import copy +import inspect +import random +from typing import ClassVar + +import nonebot +from nonebot import get_bots +from nonebot_plugin_apscheduler import scheduler +from pydantic import BaseModel, ValidationError + +from zhenxun.configs.config import Config +from zhenxun.models.schedule_info import ScheduleInfo +from zhenxun.services.log import logger +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from zhenxun.utils.platform import PlatformUtils + +SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit" + + +class SchedulerManager: + """ + 一个通用的、持久化的定时任务管理器,供所有插件使用。 + """ + + _registered_tasks: ClassVar[ + dict[str, dict[str, Callable | type[BaseModel] | None]] + ] = {} + _JOB_PREFIX = "zhenxun_schedule_" + _running_tasks: ClassVar[set] = set() + + def register( + self, plugin_name: str, params_model: 
type[BaseModel] | None = None + ) -> Callable: + """ + 注册一个可调度的任务函数。 + 被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)` + + Args: + plugin_name (str): 插件的唯一名称 (通常是模块名)。 + params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类, + 用于定义和验证任务函数接受的额外参数。 + """ + + def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + if plugin_name in self._registered_tasks: + logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。") + self._registered_tasks[plugin_name] = { + "func": func, + "model": params_model, + } + model_name = params_model.__name__ if params_model else "无" + logger.debug( + f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}" + ) + return func + + return decorator + + def get_registered_plugins(self) -> list[str]: + """获取所有已注册定时任务的插件列表。""" + return list(self._registered_tasks.keys()) + + def _get_job_id(self, schedule_id: int) -> str: + """根据数据库ID生成唯一的 APScheduler Job ID。""" + return f"{self._JOB_PREFIX}{schedule_id}" + + async def _execute_job(self, schedule_id: int): + """ + APScheduler 调度的入口函数。 + 根据 schedule_id 处理特定任务、所有群组任务或全局任务。 + """ + schedule = await ScheduleInfo.get_or_none(id=schedule_id) + if not schedule or not schedule.is_enabled: + logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。") + return + + plugin_name = schedule.plugin_name + + task_meta = self._registered_tasks.get(plugin_name) + if not task_meta: + logger.error( + f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。" + ) + schedule.is_enabled = False + await schedule.save(update_fields=["is_enabled"]) + self._remove_aps_job(schedule.id) + return + + try: + if schedule.bot_id: + bot = nonebot.get_bot(schedule.bot_id) + else: + bot = nonebot.get_bot() + logger.debug( + f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}" + ) + except KeyError: + logger.warning( + f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} " + f"不在线,本次执行跳过。" + ) + return + except ValueError: + logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。") + return + + if schedule.group_id == "__ALL_GROUPS__": 
+ await self._execute_for_all_groups(schedule, task_meta, bot) + else: + await self._execute_for_single_target(schedule, task_meta, bot) + + async def _execute_for_all_groups( + self, schedule: ScheduleInfo, task_meta: dict, bot + ): + """为所有群组执行任务,并处理优先级覆盖。""" + plugin_name = schedule.plugin_name + + concurrency_limit = Config.get_config( + "SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5 + ) + if not isinstance(concurrency_limit, int) or concurrency_limit <= 0: + logger.warning( + f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。" + ) + concurrency_limit = 5 + + logger.info( + f"开始执行针对 [所有群组] 的任务 " + f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})," + f"并发限制: {concurrency_limit}" + ) + + all_gids = set() + try: + group_list, _ = await PlatformUtils.get_group_list(bot) + all_gids.update( + g.group_id for g in group_list if g.group_id and not g.channel_id + ) + except Exception as e: + logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e) + return + + specific_tasks_gids = set( + await ScheduleInfo.filter( + plugin_name=plugin_name, group_id__in=list(all_gids) + ).values_list("group_id", flat=True) + ) + + semaphore = asyncio.Semaphore(concurrency_limit) + + async def worker(gid: str): + """使用 Semaphore 包装单个群组的任务执行""" + async with semaphore: + temp_schedule = copy.deepcopy(schedule) + temp_schedule.group_id = gid + await self._execute_for_single_target(temp_schedule, task_meta, bot) + await asyncio.sleep(random.uniform(0.1, 0.5)) + + tasks_to_run = [] + for gid in all_gids: + if gid in specific_tasks_gids: + logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。") + continue + tasks_to_run.append(worker(gid)) + + if tasks_to_run: + await asyncio.gather(*tasks_to_run) + + async def _execute_for_single_target( + self, schedule: ScheduleInfo, task_meta: dict, bot + ): + """为单个目标(具体群组或全局)执行任务。""" + plugin_name = schedule.plugin_name + group_id = schedule.group_id + + try: + is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id) + if 
is_blocked: + target_desc = f"群 {group_id}" if group_id else "全局" + logger.info( + f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]" + "因功能被禁用而跳过执行。" + ) + return + + task_func = task_meta["func"] + job_kwargs = schedule.job_kwargs + if not isinstance(job_kwargs, dict): + logger.error( + f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}" + ) + return + + sig = inspect.signature(task_func) + if "bot" in sig.parameters: + job_kwargs["bot"] = bot + + logger.info( + f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] " + f"执行定时任务 (ID: {schedule.id})。" + ) + task = asyncio.create_task(task_func(group_id, **job_kwargs)) + self._running_tasks.add(task) + task.add_done_callback(self._running_tasks.discard) + await task + except Exception as e: + logger.error( + f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, " + f"目标: {group_id or '全局'}) 时发生异常", + e=e, + ) + + def _validate_and_prepare_kwargs( + self, plugin_name: str, job_kwargs: dict | None + ) -> tuple[bool, str | dict]: + """验证并准备任务参数,应用默认值""" + task_meta = self._registered_tasks.get(plugin_name) + if not task_meta: + return False, f"插件 '{plugin_name}' 未注册。" + + params_model = task_meta.get("model") + job_kwargs = job_kwargs if job_kwargs is not None else {} + + if not params_model: + if job_kwargs: + logger.warning( + f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}" + ) + return True, job_kwargs + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类") + return False, f"插件 '{plugin_name}' 的参数模型配置错误" + + try: + model_validate = getattr(params_model, "model_validate", None) + if not model_validate: + return False, f"插件 '{plugin_name}' 的参数模型不支持验证" + + validated_model = model_validate(job_kwargs) + + model_dump = getattr(validated_model, "model_dump", None) + if not model_dump: + return False, f"插件 '{plugin_name}' 的参数模型不支持导出" + + return True, model_dump() + except ValidationError as e: + errors = [f" - {err['loc'][0]}: 
{err['msg']}" for err in e.errors()] + error_str = "\n".join(errors) + msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}" + return False, msg + + def _add_aps_job(self, schedule: ScheduleInfo): + """根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。""" + job_id = self._get_job_id(schedule.id) + try: + scheduler.remove_job(job_id) + except Exception: + pass + + if not isinstance(schedule.trigger_config, dict): + logger.error( + f"任务 {schedule.id} 的 trigger_config 不是字典类型: " + f"{type(schedule.trigger_config)}" + ) + return + + scheduler.add_job( + self._execute_job, + trigger=schedule.trigger_type, + id=job_id, + misfire_grace_time=300, + args=[schedule.id], + **schedule.trigger_config, + ) + logger.debug( + f"已在 APScheduler 中添加/更新任务: {job_id} " + f"with trigger: {schedule.trigger_config}" + ) + + def _remove_aps_job(self, schedule_id: int): + """移除一个 APScheduler 任务。""" + job_id = self._get_job_id(schedule_id) + try: + scheduler.remove_job(job_id) + logger.debug(f"已从 APScheduler 中移除任务: {job_id}") + except Exception: + pass + + async def add_schedule( + self, + plugin_name: str, + group_id: str | None, + trigger_type: str, + trigger_config: dict, + job_kwargs: dict | None = None, + bot_id: str | None = None, + ) -> tuple[bool, str]: + """ + 添加或更新一个定时任务。 + """ + if plugin_name not in self._registered_tasks: + return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。" + + is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs) + if not is_valid: + return False, str(result) + + validated_job_kwargs = result + + effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None + + search_kwargs = { + "plugin_name": plugin_name, + "group_id": group_id, + } + if effective_bot_id: + search_kwargs["bot_id"] = effective_bot_id + else: + search_kwargs["bot_id__isnull"] = True + + defaults = { + "trigger_type": trigger_type, + "trigger_config": trigger_config, + "job_kwargs": validated_job_kwargs, + "is_enabled": True, + } + + schedule = await 
ScheduleInfo.filter(**search_kwargs).first() + created = False + + if schedule: + for key, value in defaults.items(): + setattr(schedule, key, value) + await schedule.save() + else: + creation_kwargs = { + "plugin_name": plugin_name, + "group_id": group_id, + "bot_id": effective_bot_id, + **defaults, + } + schedule = await ScheduleInfo.create(**creation_kwargs) + created = True + self._add_aps_job(schedule) + action = "设置" if created else "更新" + return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。" + + async def add_schedule_for_all( + self, + plugin_name: str, + trigger_type: str, + trigger_config: dict, + job_kwargs: dict | None = None, + ) -> tuple[int, int]: + """为所有机器人所在的群组添加定时任务。""" + if plugin_name not in self._registered_tasks: + raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。") + + groups = set() + for bot in get_bots().values(): + try: + group_list, _ = await PlatformUtils.get_group_list(bot) + groups.update( + g.group_id for g in group_list if g.group_id and not g.channel_id + ) + except Exception as e: + logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e) + + success_count = 0 + fail_count = 0 + for gid in groups: + try: + success, _ = await self.add_schedule( + plugin_name, gid, trigger_type, trigger_config, job_kwargs + ) + if success: + success_count += 1 + else: + fail_count += 1 + except Exception as e: + logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e) + fail_count += 1 + await asyncio.sleep(0.05) + return success_count, fail_count + + async def update_schedule( + self, + schedule_id: int, + trigger_type: str | None = None, + trigger_config: dict | None = None, + job_kwargs: dict | None = None, + ) -> tuple[bool, str]: + """部分更新一个已存在的定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return False, f"未找到 ID 为 {schedule_id} 的任务。" + + updated_fields = [] + if trigger_config is not None: + schedule.trigger_config = trigger_config + updated_fields.append("trigger_config") + + if trigger_type is not None and 
schedule.trigger_type != trigger_type: + schedule.trigger_type = trigger_type + updated_fields.append("trigger_type") + + if job_kwargs is not None: + if not isinstance(schedule.job_kwargs, dict): + return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。" + + merged_kwargs = schedule.job_kwargs.copy() + merged_kwargs.update(job_kwargs) + + is_valid, result = self._validate_and_prepare_kwargs( + schedule.plugin_name, merged_kwargs + ) + if not is_valid: + return False, str(result) + + schedule.job_kwargs = result # type: ignore + updated_fields.append("job_kwargs") + + if not updated_fields: + return True, "没有任何需要更新的配置。" + + await schedule.save(update_fields=updated_fields) + self._add_aps_job(schedule) + return True, f"成功更新了任务 ID: {schedule_id} 的配置。" + + async def remove_schedule( + self, plugin_name: str, group_id: str | None, bot_id: str | None = None + ) -> tuple[bool, str]: + """移除指定插件和群组的定时任务。""" + query = {"plugin_name": plugin_name, "group_id": group_id} + if bot_id: + query["bot_id"] = bot_id + + schedules = await ScheduleInfo.filter(**query) + if not schedules: + msg = ( + f"未找到与 Bot {bot_id} 相关的群 {group_id} " + f"的插件 '{plugin_name}' 定时任务。" + ) + return (False, msg) + + for schedule in schedules: + self._remove_aps_job(schedule.id) + await schedule.delete() + + target_desc = f"群 {group_id}" if group_id else "全局" + msg = ( + f"已取消 Bot {bot_id} 在 [{target_desc}] " + f"的插件 '{plugin_name}' 所有定时任务。" + ) + return (True, msg) + + async def remove_schedule_for_all( + self, plugin_name: str, bot_id: str | None = None + ) -> int: + """移除指定插件在所有群组的定时任务。""" + query = {"plugin_name": plugin_name} + if bot_id: + query["bot_id"] = bot_id + + schedules_to_delete = await ScheduleInfo.filter(**query).all() + if not schedules_to_delete: + return 0 + + for schedule in schedules_to_delete: + self._remove_aps_job(schedule.id) + await schedule.delete() + await asyncio.sleep(0.01) + + return len(schedules_to_delete) + + async def remove_schedules_by_group(self, group_id: str) -> 
tuple[bool, str]: + """移除指定群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(group_id=group_id) + if not schedules: + return False, f"群 {group_id} 没有任何定时任务。" + + count = 0 + for schedule in schedules: + self._remove_aps_job(schedule.id) + await schedule.delete() + count += 1 + await asyncio.sleep(0.01) + + return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。" + + async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]: + """暂停指定群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True) + if not schedules: + return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。" + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。" + + async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]: + """恢复指定群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False) + if not schedules: + return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。" + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。" + + async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]: + """暂停指定插件在所有群组的定时任务。""" + schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True) + if not schedules: + return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。" + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return ( + count, + f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。", + ) + + async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]: + """恢复指定插件在所有群组的定时任务。""" + schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False) + if not schedules: + return 0, f"插件 
'{plugin_name}' 没有已暂停的定时任务可恢复。" + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return ( + count, + f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。", + ) + + async def pause_schedule_by_plugin_group( + self, plugin_name: str, group_id: str | None, bot_id: str | None = None + ) -> tuple[bool, str]: + """暂停指定插件在指定群组的定时任务。""" + query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True} + if bot_id: + query["bot_id"] = bot_id + + schedules = await ScheduleInfo.filter(**query) + if not schedules: + return ( + False, + f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。", + ) + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + + return ( + True, + f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。", + ) + + async def resume_schedule_by_plugin_group( + self, plugin_name: str, group_id: str | None, bot_id: str | None = None + ) -> tuple[bool, str]: + """恢复指定插件在指定群组的定时任务。""" + query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False} + if bot_id: + query["bot_id"] = bot_id + + schedules = await ScheduleInfo.filter(**query) + if not schedules: + return ( + False, + f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。", + ) + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + + return ( + True, + f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。", + ) + + async def remove_all_schedules(self) -> tuple[int, str]: + """移除所有群组的所有定时任务。""" + schedules = await ScheduleInfo.all() + if not schedules: + return 0, "当前没有任何定时任务。" + + count = 0 + for schedule in schedules: + self._remove_aps_job(schedule.id) + await schedule.delete() + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功移除所有群组的 {count} 个定时任务。" + + async def pause_all_schedules(self) -> tuple[int, str]: + 
"""暂停所有群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(is_enabled=True) + if not schedules: + return 0, "当前没有正在运行的定时任务可暂停。" + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功暂停所有群组的 {count} 个定时任务。" + + async def resume_all_schedules(self) -> tuple[int, str]: + """恢复所有群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(is_enabled=False) + if not schedules: + return 0, "当前没有已暂停的定时任务可恢复。" + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功恢复所有群组的 {count} 个定时任务。" + + async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]: + """通过ID移除指定的定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return False, f"未找到 ID 为 {schedule_id} 的定时任务。" + + self._remove_aps_job(schedule.id) + await schedule.delete() + + return ( + True, + f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + + async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None: + """通过ID获取定时任务信息。""" + return await ScheduleInfo.get_or_none(id=schedule_id) + + async def get_schedules( + self, plugin_name: str, group_id: str | None + ) -> list[ScheduleInfo]: + """获取特定群组特定插件的所有定时任务。""" + return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id) + + async def get_schedule( + self, plugin_name: str, group_id: str | None + ) -> ScheduleInfo | None: + """获取特定群组的定时任务信息。""" + return await ScheduleInfo.get_or_none( + plugin_name=plugin_name, group_id=group_id + ) + + async def get_all_schedules( + self, plugin_name: str | None = None + ) -> list[ScheduleInfo]: + """获取所有定时任务信息,可按插件名过滤。""" + if plugin_name: + return await ScheduleInfo.filter(plugin_name=plugin_name).all() + return await ScheduleInfo.all() + + async def 
get_schedule_status(self, schedule_id: int) -> dict | None: + """获取任务的详细状态。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return None + + job_id = self._get_job_id(schedule.id) + job = scheduler.get_job(job_id) + + status = { + "id": schedule.id, + "bot_id": schedule.bot_id, + "plugin_name": schedule.plugin_name, + "group_id": schedule.group_id, + "is_enabled": schedule.is_enabled, + "trigger_type": schedule.trigger_type, + "trigger_config": schedule.trigger_config, + "job_kwargs": schedule.job_kwargs, + "next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S") + if job and job.next_run_time + else "N/A", + "is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A", + } + return status + + async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]: + """暂停一个定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule or not schedule.is_enabled: + return False, "任务不存在或已暂停。" + + schedule.is_enabled = False + await schedule.save(update_fields=["is_enabled"]) + + job_id = self._get_job_id(schedule.id) + try: + scheduler.pause_job(job_id) + except Exception: + pass + + return ( + True, + f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + + async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]: + """恢复一个定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule or schedule.is_enabled: + return False, "任务不存在或已启用。" + + schedule.is_enabled = True + await schedule.save(update_fields=["is_enabled"]) + + job_id = self._get_job_id(schedule.id) + try: + scheduler.resume_job(job_id) + except Exception: + self._add_aps_job(schedule) + + return ( + True, + f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + + async def trigger_now(self, schedule_id: int) -> tuple[bool, str]: + """手动触发一个定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: 
+ return False, f"未找到 ID 为 {schedule_id} 的定时任务。" + + if schedule.plugin_name not in self._registered_tasks: + return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。" + + try: + await self._execute_job(schedule.id) + return ( + True, + f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + except Exception as e: + logger.error(f"手动触发任务失败: {e}") + return False, f"手动触发任务失败: {e}" + + +scheduler_manager = SchedulerManager() + + +@PriorityLifecycle.on_startup(priority=90) +async def _load_schedules_from_db(): + """在服务启动时从数据库加载并调度所有任务。""" + Config.add_plugin_config( + "SchedulerManager", + SCHEDULE_CONCURRENCY_KEY, + 5, + help="“所有群组”类型定时任务的并发执行数量限制", + type=int, + ) + + logger.info("正在从数据库加载并调度所有定时任务...") + schedules = await ScheduleInfo.filter(is_enabled=True).all() + count = 0 + for schedule in schedules: + if schedule.plugin_name in scheduler_manager._registered_tasks: + scheduler_manager._add_aps_job(schedule) + count += 1 + else: + logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。") + logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。") diff --git a/zhenxun/utils/manager/virtual_env_package_manager.py b/zhenxun/utils/manager/virtual_env_package_manager.py new file mode 100644 index 00000000..c4665bd5 --- /dev/null +++ b/zhenxun/utils/manager/virtual_env_package_manager.py @@ -0,0 +1,169 @@ +from pathlib import Path +import subprocess +from subprocess import CalledProcessError +from typing import ClassVar + +from zhenxun.configs.config import Config +from zhenxun.services.log import logger + +BAT_FILE = Path() / "win启动.bat" + +LOG_COMMAND = "VirtualEnvPackageManager" + +Config.add_plugin_config( + "virtualenv", + "python_path", + None, + help="虚拟环境python路径,为空时使用系统环境的poetry", +) + + +class VirtualEnvPackageManager: + WIN_COMMAND: ClassVar[list[str]] = [ + "./Python310/python.exe", + "-m", + "pip", + ] + + DEFAULT_COMMAND: ClassVar[list[str]] = ["poetry", "run", "pip"] + + @classmethod + def __get_command(cls) -> list[str]: 
+ if path := Config.get_config("virtualenv", "python_path"): + return [path, "-m", "pip"] + return cls.WIN_COMMAND if BAT_FILE.exists() else cls.DEFAULT_COMMAND + + @classmethod + def install(cls, package: list[str] | str): + """安装依赖包 + + 参数: + package: 安装依赖包名称或列表 + """ + if isinstance(package, str): + package = [package] + try: + command = cls.__get_command() + command.append("install") + command.append(" ".join(package)) + logger.info(f"执行虚拟环境安装包指令: {command}", LOG_COMMAND) + result = subprocess.run( + command, + check=True, + capture_output=True, + text=True, + ) + logger.debug( + f"安装虚拟环境包指令执行完成: {result.stdout}", + LOG_COMMAND, + ) + except CalledProcessError as e: + logger.error(f"安装虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND) + + @classmethod + def uninstall(cls, package: list[str] | str): + """卸载依赖包 + + 参数: + package: 卸载依赖包名称或列表 + """ + if isinstance(package, str): + package = [package] + try: + command = cls.__get_command() + command.append("uninstall") + command.append(" ".join(package)) + logger.info(f"执行虚拟环境卸载包指令: {command}", LOG_COMMAND) + result = subprocess.run( + command, + check=True, + capture_output=True, + text=True, + ) + logger.debug( + f"卸载虚拟环境包指令执行完成: {result.stdout}", + LOG_COMMAND, + ) + except CalledProcessError as e: + logger.error(f"卸载虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND) + + @classmethod + def update(cls, package: list[str] | str): + """更新依赖包 + + 参数: + package: 更新依赖包名称或列表 + """ + if isinstance(package, str): + package = [package] + try: + command = cls.__get_command() + command.append("install") + command.append("--upgrade") + command.append(" ".join(package)) + logger.info(f"执行虚拟环境更新包指令: {command}", LOG_COMMAND) + result = subprocess.run( + command, + check=True, + capture_output=True, + text=True, + ) + logger.debug(f"更新虚拟环境包指令执行完成: {result.stdout}", LOG_COMMAND) + except CalledProcessError as e: + logger.error(f"更新虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND) + + @classmethod + def install_requirement(cls, requirement_file: Path): + 
"""安装依赖文件 + + 参数: + requirement_file: requirement文件路径 + + 异常: + FileNotFoundError: 文件不存在 + """ + if not requirement_file.exists(): + raise FileNotFoundError(f"依赖文件 {requirement_file} 不存在", LOG_COMMAND) + try: + command = cls.__get_command() + command.append("install") + command.append("-r") + command.append(str(requirement_file.absolute())) + logger.info(f"执行虚拟环境安装依赖文件指令: {command}", LOG_COMMAND) + result = subprocess.run( + command, + check=True, + capture_output=True, + text=True, + ) + logger.debug( + f"安装虚拟环境依赖文件指令执行完成: {result.stdout}", + LOG_COMMAND, + ) + except CalledProcessError as e: + logger.error( + f"安装虚拟环境依赖文件指令执行失败: {e.stderr}.", + LOG_COMMAND, + ) + + @classmethod + def list(cls) -> str: + """列出已安装的依赖包""" + try: + command = cls.__get_command() + command.append("list") + logger.info(f"执行虚拟环境列出包指令: {command}", LOG_COMMAND) + result = subprocess.run( + command, + check=True, + capture_output=True, + text=True, + ) + logger.debug( + f"列出虚拟环境包指令执行完成: {result.stdout}", + LOG_COMMAND, + ) + return result.stdout + except CalledProcessError as e: + logger.error(f"列出虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND) + return "" diff --git a/zhenxun/utils/platform.py b/zhenxun/utils/platform.py index 6d379131..790aa230 100644 --- a/zhenxun/utils/platform.py +++ b/zhenxun/utils/platform.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import Awaitable, Callable import random -from typing import Literal +from typing import cast import httpx import nonebot @@ -83,7 +83,7 @@ class PlatformUtils: bot: Bot, message: UniMessage | str, superuser_id: str | None = None, - ) -> Receipt | None: + ) -> list[tuple[str, Receipt]]: """发送消息给超级用户 参数: @@ -97,15 +97,33 @@ class PlatformUtils: 返回: Receipt | None: Receipt """ - if not superuser_id: - if platform := cls.get_platform(bot): - if platform_superusers := BotConfig.get_superuser(platform): - superuser_id = random.choice(platform_superusers) - else: - raise NotFindSuperuser() + superuser_ids = [] + if superuser_id: + 
superuser_ids.append(superuser_id) + elif platform := cls.get_platform(bot): + if platform_superusers := BotConfig.get_superuser(platform): + superuser_ids = platform_superusers + else: + raise NotFindSuperuser() if isinstance(message, str): message = MessageUtils.build_message(message) - return await cls.send_message(bot, superuser_id, None, message) + result = [] + for superuser_id in superuser_ids: + try: + result.append( + ( + superuser_id, + await cls.send_message(bot, superuser_id, None, message), + ) + ) + except Exception as e: + logger.error( + "发送消息给超级用户失败", + "PlatformUtils:send_superuser", + target=superuser_id, + e=e, + ) + return result @classmethod async def get_group_member_list(cls, bot: Bot, group_id: str) -> list[UserData]: @@ -209,7 +227,7 @@ class PlatformUtils: url = None if platform == "qq": if user_id.isdigit(): - url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160" + url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640" else: url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640" return await AsyncHttpx.get_content(url) if url else None @@ -468,15 +486,141 @@ class PlatformUtils: return target +class BroadcastEngine: + def __init__( + self, + message: str | UniMessage, + bot: Bot | list[Bot] | None = None, + bot_id: str | set[str] | None = None, + ignore_group: list[str] | None = None, + check_func: Callable[[Bot, str], Awaitable] | None = None, + log_cmd: str | None = None, + platform: str | None = None, + ): + """广播引擎 + + 参数: + message: 广播消息内容 + bot: 指定bot对象. + bot_id: 指定bot id. + ignore_group: 忽略群聊列表. + check_func: 发送前对群聊检测方法,判断是否发送. + log_cmd: 日志标记. + platform: 指定平台. 
+ + 异常: + ValueError: 没有可用的Bot对象 + """ + if ignore_group is None: + ignore_group = [] + self.message = MessageUtils.build_message(message) + self.ignore_group = ignore_group + self.check_func = check_func + self.log_cmd = log_cmd + self.platform = platform + self.bot_list = [] + self.count = 0 + if bot: + self.bot_list = [bot] if isinstance(bot, Bot) else bot + if isinstance(bot_id, str): + bot_id = set(bot_id) + if bot_id: + for i in bot_id: + try: + self.bot_list.append(nonebot.get_bot(i)) + except KeyError: + logger.warning(f"Bot:{i} 对象未连接或不存在", log_cmd) + if not self.bot_list: + try: + bot = nonebot.get_bot() + self.bot_list.append(bot) + logger.warning( + f"广播任务未传入Bot对象,使用默认Bot {bot.self_id}", log_cmd + ) + except Exception as e: + raise ValueError("当前没有可用的Bot对象...", log_cmd) from e + + async def call_check(self, bot: Bot, group_id: str) -> bool: + """运行发送检测函数 + + 参数: + bot: Bot + group_id: 群组id + + 返回: + bool: 是否发送 + """ + if not self.check_func: + return True + if is_coroutine_callable(self.check_func): + is_run = await self.check_func(bot, group_id) + else: + is_run = self.check_func(bot, group_id) + return cast(bool, is_run) + + async def __send_message(self, bot: Bot, group: GroupConsole): + """群组发送消息 + + 参数: + bot: Bot + group: GroupConsole + """ + key = f"{group.group_id}:{group.channel_id}" + if not await self.call_check(bot, group.group_id): + logger.debug( + "广播方法检测运行方法为 False, 已跳过该群组...", + self.log_cmd, + group_id=group.group_id, + ) + return + if target := PlatformUtils.get_target( + group_id=group.group_id, + channel_id=group.channel_id, + ): + self.ignore_group.append(key) + await MessageUtils.build_message(self.message).send(target, bot) + logger.debug("广播消息发送成功...", self.log_cmd, target=key) + else: + logger.warning("广播消息获取Target失败...", self.log_cmd, target=key) + + async def broadcast(self) -> int: + """广播消息 + + 返回: + int: 成功发送次数 + """ + for bot in self.bot_list: + if self.platform and self.platform != PlatformUtils.get_platform(bot): + 
continue + group_list, _ = await PlatformUtils.get_group_list(bot) + if not group_list: + continue + for group in group_list: + if ( + group.group_id in self.ignore_group + or group.channel_id in self.ignore_group + ): + continue + try: + await self.__send_message(bot, group) + await asyncio.sleep(random.randint(1, 3)) + self.count += 1 + except Exception as e: + logger.warning( + "广播消息发送失败", self.log_cmd, target=group.group_id, e=e + ) + return self.count + + async def broadcast_group( message: str | UniMessage, bot: Bot | list[Bot] | None = None, bot_id: str | set[str] | None = None, - ignore_group: set[int] | None = None, + ignore_group: list[str] = [], check_func: Callable[[Bot, str], Awaitable] | None = None, log_cmd: str | None = None, - platform: Literal["qq", "dodo", "kaiheila"] | None = None, -): + platform: str | None = None, +) -> int: """获取所有Bot或指定Bot对象广播群聊 参数: @@ -487,81 +631,18 @@ async def broadcast_group( check_func: 发送前对群聊检测方法,判断是否发送. log_cmd: 日志标记. platform: 指定平台 + + 返回: + int: 成功发送次数 """ - if platform and platform not in ["qq", "dodo", "kaiheila"]: - raise ValueError("指定平台不支持") - if not message: - raise ValueError("群聊广播消息不能为空") - bot_dict = nonebot.get_bots() - bot_list: list[Bot] = [] - if bot: - if isinstance(bot, list): - bot_list = bot - else: - bot_list.append(bot) - elif bot_id: - _bot_id_list = bot_id - if isinstance(bot_id, str): - _bot_id_list = [bot_id] - for id_ in _bot_id_list: - if bot_id in bot_dict: - bot_list.append(bot_dict[bot_id]) - else: - logger.warning(f"Bot:{id_} 对象未连接或不存在") - else: - bot_list = list(bot_dict.values()) - _used_group = [] - for _bot in bot_list: - try: - if platform and platform != PlatformUtils.get_platform(_bot): - continue - group_list, _ = await PlatformUtils.get_group_list(_bot) - if group_list: - for group in group_list: - key = f"{group.group_id}:{group.channel_id}" - try: - if ( - ignore_group - and ( - group.group_id in ignore_group - or group.channel_id in ignore_group - ) - ) or key in 
_used_group: - logger.debug( - "广播方法群组重复, 已跳过...", - log_cmd, - group_id=group.group_id, - ) - continue - is_run = False - if check_func: - if is_coroutine_callable(check_func): - is_run = await check_func(_bot, group.group_id) - else: - is_run = check_func(_bot, group.group_id) - if not is_run: - logger.debug( - "广播方法检测运行方法为 False, 已跳过...", - log_cmd, - group_id=group.group_id, - ) - continue - target = PlatformUtils.get_target( - user_id=None, - group_id=group.group_id, - channel_id=group.channel_id, - ) - if target: - _used_group.append(key) - message_list = message - await MessageUtils.build_message(message_list).send( - target, _bot - ) - logger.debug("发送成功", log_cmd, target=key) - await asyncio.sleep(random.randint(1, 3)) - else: - logger.warning("target为空", log_cmd, target=key) - except Exception as e: - logger.error("发送失败", log_cmd, target=key, e=e) - except Exception as e: - logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e) + if not message.strip(): + raise ValueError("群聊广播消息不能为空...") + return await BroadcastEngine( + message=message, + bot=bot, + bot_id=bot_id, + ignore_group=ignore_group, + check_func=check_func, + log_cmd=log_cmd, + platform=platform, + ).broadcast()