diff --git a/README.md b/README.md index 094e00ba..72641550 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ AccessToken: PUBLIC_ZHENXUN_TEST | [插件库](https://github.com/zhenxun-org/zhenxun_bot_plugins) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 原 plugins 文件夹插件 | | [插件索引库](https://github.com/zhenxun-org/zhenxun_bot_plugins_index) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 扩展插件索引库 | | [一键安装](https://github.com/soloxiaoye2022/zhenxun_bot-deploy) | 安装 | [soloxiaoye2022](https://github.com/soloxiaoye2022) | 第三方 | -| [WebUi](https://github.com/HibiKier/zhenxun_bot_webui) | 管理 | [hibikier](https://github.com/HibiKier) | 基于真寻 WebApi 的 webui 实现 [预览](#-webui界面展示) | +| [WebUi](https://github.com/zhenxun-org/zhenxun_bot) | 管理 | [hibikier](https://github.com/HibiKier) | 基于真寻 WebApi 的 webui 实现 [预览](#-webui界面展示) | | [安卓 app(WebUi)](https://github.com/YuS1aN/zhenxun_bot_android_ui) | 安装 | [YuS1aN](https://github.com/YuS1aN) | 第三方 | @@ -126,6 +126,28 @@ AccessToken: PUBLIC_ZHENXUN_TEST - 提供了 cd,阻塞,每日次数等限制,仅仅通过简单的属性就可以生成一个限制,例如:`PluginCdBlock` 等 - **更多详细请通过 [传送门](https://zhenxun-org.github.io/zhenxun_bot/) 查看文档!** +## 🐣 小白整合 + +如果你系统是 **Windows** 且不想下载 Python +可以使用整合包(Python3.10+zhenxun+webui) + +文档地址:[整合包文档](https://hibikier.github.io/zhenxun_bot/beginner/) + +
+下载地址 + +- **百度云:** + https://pan.baidu.com/s/1ph4yzx1vdNbkxm9VBKDdgQ?pwd=971j + +- **天翼云:** + https://cloud.189.cn/web/share?code=jq67r2i2E7Fb + 访问码:8wxm + +- **Google Drive:** + https://drive.google.com/file/d/1cc3Dqjk0x5hWGLNeMkrFwWl8BvsK6KfD/view?usp=drive_link + +
+ ## 🛠️ 简单部署 ```bash @@ -150,7 +172,7 @@ poetry run python bot.py 1.在 .env.dev 文件中填写你的机器人配置项 -2.在 configs/config.yaml 文件中修改你需要修改的插件配置项 +2.在 data/config.yaml 文件中修改你需要修改的插件配置项
数据库地址(DB_URL)配置说明 @@ -272,12 +294,12 @@ DB_URL 是基于 Tortoise ORM 的数据库连接字符串,用于指定项目 ## ❔ 需要帮助? > [!TIP] -> 发起 [issue](https://github.com/HibiKier/zhenxun_bot/issues/new/choose) 前,我们希望你能够阅读过或者了解 [提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md) +> 发起 [issue](https://github.com/zhenxun-org/zhenxun_bot/issues/new/choose) 前,我们希望你能够阅读过或者了解 [提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md) > > - 善用[搜索引擎](https://www.google.com/) > - 查阅 issue 中是否有类似问题,如果没有请按照模板发起 issue -欢迎前往 [issue](https://github.com/HibiKier/zhenxun_bot/issues/new/choose) 中提出你遇到的问题,或者加入我们的 [用户群](https://qm.qq.com/q/mRNtLSl6uc) 或 [技术群](https://qm.qq.com/q/YYYt5rkMYc)与我们联系 +欢迎前往 [issue](https://github.com/zhenxun-org/zhenxun_bot/issues/new/choose) 中提出你遇到的问题,或者加入我们的 [用户群](https://qm.qq.com/q/mRNtLSl6uc) 或 [技术群](https://qm.qq.com/q/YYYt5rkMYc)与我们联系 ## 🛠️ 进度追踪 @@ -287,6 +309,8 @@ Project [zhenxun_bot](https://github.com/users/HibiKier/projects/2) 首席设计师:[酥酥/coldly-ss](https://github.com/coldly-ss) +LOGO 设计:[FrostN0v0](https://github.com/FrostN0v0) + ## 🙏 感谢 [botuniverse / onebot](https://github.com/botuniverse/onebot) :超棒的机器人协议 @@ -326,34 +350,68 @@ Project [zhenxun_bot](https://github.com/users/HibiKier/projects/2) contributors -## 📸 WebUI 界面展示 +## 📸 WebUI 界面展示(仅展示默认主题下的 pc 端)
-
- webui00 -
-
- webui01 -
-
- webui02 -
-
- webui03 -
+#### 登录界面 -
- webui04 -
-
- webui05 -
+![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-login.jpg) + +#### API 设置 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-api.jpg) + +#### 仪表盘 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-dashboard.jpg) + +#### 仪表盘(展开) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-dashboard1.jpg) + +#### 控制台 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-command.jpg) + +#### 插件列表 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-plugin.jpg) + +#### 插件列表(配置项) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-plugin1.jpg) + +#### 插件商店 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-store.jpg) + +#### 好友/群组管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-manage.jpg) + +#### 请求管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-manage1.jpg) + +#### 数据库管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-database.jpg) + +### 文件管理 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system.jpg) + +### 文件管理(文本查看) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system1.jpg) + +### 文件管理(图片查看) + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system2.jpg) + +### 关于 + +![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-about.jpg) -
- webui06 -
-
- webui07 -
diff --git a/bot.py b/bot.py index 52cd29fc..aa047a71 100644 --- a/bot.py +++ b/bot.py @@ -14,9 +14,9 @@ driver.register_adapter(OneBotV11Adapter) # driver.register_adapter(DoDoAdapter) # driver.register_adapter(DiscordAdapter) -from zhenxun.services.db_context import disconnect, init +from zhenxun.services.db_context import disconnect -driver.on_startup(init) +# driver.on_startup(init) driver.on_shutdown(disconnect) # nonebot.load_builtin_plugins("echo") diff --git a/docs_image/pc-about.jpg b/docs_image/pc-about.jpg new file mode 100644 index 00000000..0bef7a9e Binary files /dev/null and b/docs_image/pc-about.jpg differ diff --git a/docs_image/pc-api.jpg b/docs_image/pc-api.jpg new file mode 100644 index 00000000..59cee887 Binary files /dev/null and b/docs_image/pc-api.jpg differ diff --git a/docs_image/pc-command.jpg b/docs_image/pc-command.jpg new file mode 100644 index 00000000..0e310e29 Binary files /dev/null and b/docs_image/pc-command.jpg differ diff --git a/docs_image/pc-dashboard.jpg b/docs_image/pc-dashboard.jpg new file mode 100644 index 00000000..0478a850 Binary files /dev/null and b/docs_image/pc-dashboard.jpg differ diff --git a/docs_image/pc-dashboard1.jpg b/docs_image/pc-dashboard1.jpg new file mode 100644 index 00000000..3a0bc958 Binary files /dev/null and b/docs_image/pc-dashboard1.jpg differ diff --git a/docs_image/pc-database.jpg b/docs_image/pc-database.jpg new file mode 100644 index 00000000..68c60aa3 Binary files /dev/null and b/docs_image/pc-database.jpg differ diff --git a/docs_image/pc-login.jpg b/docs_image/pc-login.jpg new file mode 100644 index 00000000..65fe8b46 Binary files /dev/null and b/docs_image/pc-login.jpg differ diff --git a/docs_image/pc-manage.jpg b/docs_image/pc-manage.jpg new file mode 100644 index 00000000..e5f8902a Binary files /dev/null and b/docs_image/pc-manage.jpg differ diff --git a/docs_image/pc-manage1.jpg b/docs_image/pc-manage1.jpg new file mode 100644 index 00000000..4756c629 Binary files /dev/null and b/docs_image/pc-manage1.jpg differ diff --git a/docs_image/pc-plugin.jpg b/docs_image/pc-plugin.jpg new file mode 100644 index 00000000..147e26eb Binary files /dev/null and b/docs_image/pc-plugin.jpg differ diff --git a/docs_image/pc-plugin1.jpg b/docs_image/pc-plugin1.jpg new file mode 100644 index 00000000..58694e6d Binary files /dev/null and b/docs_image/pc-plugin1.jpg differ diff --git a/docs_image/pc-store.jpg b/docs_image/pc-store.jpg new file mode 100644 index 00000000..4c9b68e4 Binary files /dev/null and b/docs_image/pc-store.jpg differ diff --git a/docs_image/pc-system.jpg b/docs_image/pc-system.jpg new file mode 100644 index 00000000..9908a2bd Binary files /dev/null and b/docs_image/pc-system.jpg differ diff --git a/docs_image/pc-system1.jpg b/docs_image/pc-system1.jpg new file mode 100644 index 00000000..3333a1b5 Binary files /dev/null and b/docs_image/pc-system1.jpg differ diff --git a/docs_image/pc-system2.jpg b/docs_image/pc-system2.jpg new file mode 100644 index 00000000..649a5bc9 Binary files /dev/null and b/docs_image/pc-system2.jpg differ diff --git a/docs_image/webui00.png b/docs_image/webui00.png deleted file mode 100644 index 71f7d368..00000000 Binary files a/docs_image/webui00.png and /dev/null differ diff --git a/docs_image/webui01.png b/docs_image/webui01.png deleted file mode 100644 index cd415685..00000000 Binary files a/docs_image/webui01.png and /dev/null differ diff --git a/docs_image/webui02.png b/docs_image/webui02.png deleted file mode 100644 index 0fcc4f05..00000000 Binary files a/docs_image/webui02.png and /dev/null differ diff --git a/docs_image/webui03.png b/docs_image/webui03.png deleted file mode 100644 index 2e7426e3..00000000 Binary files a/docs_image/webui03.png and /dev/null differ diff --git a/docs_image/webui04.png b/docs_image/webui04.png deleted file mode 100644 index 5810f71b..00000000 Binary files a/docs_image/webui04.png and /dev/null differ diff --git a/docs_image/webui05.png b/docs_image/webui05.png deleted file mode 100644 index d5f5e304..00000000 Binary files a/docs_image/webui05.png and /dev/null differ diff --git a/docs_image/webui06.png b/docs_image/webui06.png deleted file mode 100644 index 7541f679..00000000 Binary files a/docs_image/webui06.png and /dev/null differ diff --git a/docs_image/webui07.png b/docs_image/webui07.png deleted file mode 100644 index 1628ade7..00000000 Binary files a/docs_image/webui07.png and /dev/null differ diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index fbaeb280..f2688905 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -16,6 +16,7 @@ from zhenxun.models.sign_user import SignUser from zhenxun.models.user_console import UserConsole from zhenxun.services.log import logger from zhenxun.utils.decorator.shop import shop_register +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.manager.resource_manager import ResourceManager from zhenxun.utils.platform import PlatformUtils @@ -70,7 +71,7 @@ from public.bag_users t1 """ -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): await ResourceManager.init_resources() """签到与用户的数据迁移""" diff --git a/zhenxun/builtin_plugins/admin/welcome_message/data_source.py b/zhenxun/builtin_plugins/admin/welcome_message/data_source.py index 2ccb33ee..c8e486ed 100644 --- a/zhenxun/builtin_plugins/admin/welcome_message/data_source.py +++ b/zhenxun/builtin_plugins/admin/welcome_message/data_source.py @@ -14,6 +14,7 @@ from zhenxun.services.log import logger from zhenxun.utils._build_image import BuildImage from zhenxun.utils._image_template import ImageTemplate from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils BASE_PATH = DATA_PATH / "welcome_message" @@ -91,7 +92,7 @@ def migrate(path: Path): json.dump(new_data, f, ensure_ascii=False, indent=4) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) def _(): """数据迁移 diff --git a/zhenxun/builtin_plugins/help/__init__.py b/zhenxun/builtin_plugins/help/__init__.py index 726d4d1e..17002f0c 100644 --- a/zhenxun/builtin_plugins/help/__init__.py +++ b/zhenxun/builtin_plugins/help/__init__.py @@ -37,8 +37,8 @@ __plugin_meta__ = PluginMetadata( configs=[ RegisterConfig( key="type", - value="normal", - help="帮助图片样式 ['normal', 'HTML', 'zhenxun']", + value="zhenxun", + help="帮助图片样式 [normal, HTML, zhenxun]", default_value="zhenxun", ) ], diff --git a/zhenxun/builtin_plugins/hooks/__init__.py b/zhenxun/builtin_plugins/hooks/__init__.py index 3ad29d71..2f8c79de 100644 --- a/zhenxun/builtin_plugins/hooks/__init__.py +++ b/zhenxun/builtin_plugins/hooks/__init__.py @@ -49,4 +49,14 @@ Config.add_plugin_config( type=bool, ) +Config.add_plugin_config( + "hook", + "RECORD_BOT_SENT_MESSAGES", + True, + help="记录bot消息发送", + default_value=True, + type=bool, +) + + nonebot.load_plugins(str(Path(__file__).parent.resolve())) diff --git a/zhenxun/builtin_plugins/hooks/call_hook.py b/zhenxun/builtin_plugins/hooks/call_hook.py index 2ff4d39c..1893754d 100644 --- a/zhenxun/builtin_plugins/hooks/call_hook.py +++ b/zhenxun/builtin_plugins/hooks/call_hook.py @@ -1,23 +1,85 @@ from typing import Any -from nonebot.adapters import Bot +from nonebot.adapters import Bot, Message +from zhenxun.configs.config import Config +from zhenxun.models.bot_message_store import BotMessageStore from zhenxun.services.log import logger +from zhenxun.utils.enum import BotSentType from zhenxun.utils.manager.message_manager import MessageManager +from zhenxun.utils.platform import PlatformUtils + + +def replace_message(message: Message) -> str: + """将消息中的at、image、record、face替换为字符串 + + 参数: + message: Message + + 返回: + str: 文本消息 + """ + result = "" + for msg in message: + if isinstance(msg, str): + result += msg + elif msg.type == "at": + result += f"@{msg.data['qq']}" + elif msg.type == "image": + result += "[image]" + elif msg.type == "record": + result += "[record]" + elif msg.type == "face": + result += f"[face:{msg.data['id']}]" + elif msg.type == "reply": + result += "" + else: + result += str(msg) + return result @Bot.on_called_api async def handle_api_result( bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any ): - if not exception and api == "send_msg": - try: - if (uid := data.get("user_id")) and (msg_id := result.get("message_id")): - MessageManager.add(str(uid), str(msg_id)) - logger.debug( - f"收集消息id,user_id: {uid}, msg_id: {msg_id}", "msg_hook" - ) - except Exception as e: - logger.warning( - f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e + if exception or api != "send_msg": + return + user_id = data.get("user_id") + group_id = data.get("group_id") + message_id = result.get("message_id") + message: Message = data.get("message", "") + message_type = data.get("message_type") + try: + # 记录消息id + if user_id and message_id: + MessageManager.add(str(user_id), str(message_id)) + logger.debug( + f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook" ) + except Exception as e: + logger.warning( + f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e + ) + if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"): + return + try: + await BotMessageStore.create( + bot_id=bot.self_id, + user_id=user_id, + group_id=group_id, + sent_type=BotSentType.GROUP + if message_type == "group" + else BotSentType.PRIVATE, + text=replace_message(message), + plain_text=message.extract_plain_text() + if isinstance(message, Message) + else replace_message(message), + platform=PlatformUtils.get_platform(bot), + ) + logger.debug(f"消息发送记录,message: {message}") + except Exception as e: + logger.warning( + f"消息发送记录发生错误...data: {data}, result: {result}", + "msg_hook", + e=e, + ) diff --git a/zhenxun/builtin_plugins/init/init_config.py b/zhenxun/builtin_plugins/init/init_config.py index 112d29de..eef63635 100644 --- a/zhenxun/builtin_plugins/init/init_config.py +++ b/zhenxun/builtin_plugins/init/init_config.py @@ -11,6 +11,7 @@ from zhenxun.configs.config import Config from zhenxun.configs.path_config import DATA_PATH from zhenxun.configs.utils import RegisterConfig from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle _yaml = YAML(pure=True) _yaml.allow_unicode = True @@ -102,7 +103,7 @@ def _generate_simple_config(exists_module: list[str]): temp_file.unlink() -@driver.on_startup +@PriorityLifecycle.on_startup(priority=0) def _(): """ 初始化插件数据配置 @@ -125,3 +126,4 @@ def _(): with plugins2config_file.open("w", encoding="utf8") as wf: _yaml.dump(_data, wf) _generate_simple_config(exists_module) + Config.reload() diff --git a/zhenxun/builtin_plugins/init/init_plugin.py b/zhenxun/builtin_plugins/init/init_plugin.py index dbeddb54..5bf50409 100644 --- a/zhenxun/builtin_plugins/init/init_plugin.py +++ b/zhenxun/builtin_plugins/init/init_plugin.py @@ -20,6 +20,7 @@ from zhenxun.utils.enum import ( PluginLimitType, PluginType, ) +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from .manager import manager @@ -95,7 +96,7 @@ async def _handle_setting( ) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """ 初始化插件数据配置 diff --git a/zhenxun/builtin_plugins/init/init_task.py b/zhenxun/builtin_plugins/init/init_task.py index cead7d72..b9bab56d 100644 --- a/zhenxun/builtin_plugins/init/init_task.py +++ b/zhenxun/builtin_plugins/init/init_task.py @@ -10,6 +10,7 @@ from zhenxun.models.group_console import GroupConsole from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.manager.priority_manager import PriorityLifecycle driver: Driver = nonebot.get_driver() @@ -132,7 +133,7 @@ async def create_schedule(task: Task): logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """ 初始化插件数据配置 diff --git a/zhenxun/builtin_plugins/plugin_store/data_source.py b/zhenxun/builtin_plugins/plugin_store/data_source.py index 54b087a6..41977689 100644 --- a/zhenxun/builtin_plugins/plugin_store/data_source.py +++ b/zhenxun/builtin_plugins/plugin_store/data_source.py @@ -23,6 +23,12 @@ from .config import ( LOG_COMMAND, ) +BAT_FILE = Path() / "win启动.bat" + +WIN_COMMAND = ["./Python310/python.exe", "-m", "pip", "install", "-r"] + +DEFAULT_COMMAND = ["poetry", "run", "pip", "install", "-r"] + def row_style(column: str, text: str) -> RowStyle: """被动技能文本风格 @@ -50,6 +56,33 @@ def install_requirement(plugin_path: Path): VirtualEnvPackageManager.install_requirement(existing_requirements) + if not existing_requirements: + logger.debug( + f"No requirement.txt found for plugin: {plugin_path.name}", "插件管理" + ) + return + + try: + command = WIN_COMMAND if BAT_FILE.exists() else DEFAULT_COMMAND + command.append(str(existing_requirements)) + result = subprocess.run( + command, + check=True, + capture_output=True, + text=True, + ) + logger.debug( + "Successfully installed dependencies for" + f" plugin: {plugin_path.name}. Output:\n{result.stdout}", + "插件管理", + ) + except subprocess.CalledProcessError: + logger.error( + f"Failed to install dependencies for plugin: {plugin_path.name}. " + " Error:\n{e.stderr}" + ) + + class StoreManager: @classmethod async def get_github_plugins(cls) -> list[StorePluginInfo]: diff --git a/zhenxun/builtin_plugins/scripts.py b/zhenxun/builtin_plugins/scripts.py index 0be7527c..b5fca300 100644 --- a/zhenxun/builtin_plugins/scripts.py +++ b/zhenxun/builtin_plugins/scripts.py @@ -1,12 +1,8 @@ -import nonebot -from nonebot.drivers import Driver - from zhenxun.models.group_console import GroupConsole - -driver: Driver = nonebot.get_driver() +from zhenxun.utils.manager.priority_manager import PriorityLifecycle -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """开启/禁用插件格式修改""" _, is_create = await GroupConsole.get_or_create(group_id=133133133) diff --git a/zhenxun/builtin_plugins/shop/__init__.py b/zhenxun/builtin_plugins/shop/__init__.py index 432b9b92..120d2198 100644 --- a/zhenxun/builtin_plugins/shop/__init__.py +++ b/zhenxun/builtin_plugins/shop/__init__.py @@ -5,7 +5,9 @@ from nonebot_plugin_alconna import ( AlconnaQuery, Args, Arparma, + At, Match, + MultiVar, Option, Query, Subcommand, @@ -47,6 +49,7 @@ __plugin_meta__ = PluginMetadata( plugin_type=PluginType.NORMAL, menu_type="商店", commands=[ + Command(command="商店"), Command(command="我的金币"), Command(command="我的道具"), Command(command="购买道具"), @@ -75,13 +78,21 @@ _matcher = on_alconna( Subcommand("my-cost", help_text="我的金币"), Subcommand("my-props", help_text="我的道具"), Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"), - Subcommand("use", Args["name?", str]["num?", int], help_text="使用道具"), Subcommand("gold-list", Args["num?", int], help_text="金币排行"), ), priority=5, block=True, ) +_use_matcher = on_alconna( + Alconna( + "使用道具", + Args["name?", str]["num?", int]["at_users?", MultiVar(At)], + ), + priority=5, + block=True, +) + _matcher.shortcut( "我的金币", command="商店", @@ -103,13 +114,6 @@ _matcher.shortcut( prefix=True, ) -_matcher.shortcut( - "使用道具(?P.*?)", - command="商店", - arguments=["use", "{name}"], - prefix=True, -) - _matcher.shortcut( "金币排行", command="商店", @@ -173,7 +177,7 @@ async def _( await MessageUtils.build_message(result).send(reply_to=True) -@_matcher.assign("use") +@_use_matcher.handle() async def _( bot: Bot, event: Event, @@ -182,6 +186,7 @@ async def _( arparma: Arparma, name: Match[str], num: Query[int] = AlconnaQuery("num", 1), + at_users: Query[list[At]] = AlconnaQuery("at_users", []), ): if not name.available: await MessageUtils.build_message( @@ -189,7 +194,7 @@ async def _( ).finish(reply_to=True) try: result = await ShopManage.use( - bot, event, session, message, name.result, num.result, "" + bot, event, session, message, name.result, num.result, "", at_users.result ) logger.info( f"使用道具 {name.result}, 数量: {num.result}", diff --git a/zhenxun/builtin_plugins/shop/_data_source.py b/zhenxun/builtin_plugins/shop/_data_source.py index 4c35c6ff..682bd85e 100644 --- a/zhenxun/builtin_plugins/shop/_data_source.py +++ b/zhenxun/builtin_plugins/shop/_data_source.py @@ -8,7 +8,7 @@ from typing import Any, Literal from nonebot.adapters import Bot, Event from nonebot.compat import model_dump -from nonebot_plugin_alconna import UniMessage, UniMsg +from nonebot_plugin_alconna import At, UniMessage, UniMsg from nonebot_plugin_uninfo import Uninfo from pydantic import BaseModel, Field, create_model from tortoise.expressions import Q @@ -48,6 +48,10 @@ class Goods(BaseModel): """model""" session: Uninfo | None = None """Uninfo""" + at_user: str | None = None + """At对象""" + at_users: list[str] = [] + """At对象列表""" class ShopParam(BaseModel): @@ -73,6 +77,10 @@ class ShopParam(BaseModel): """Uninfo""" message: UniMsg """UniMessage""" + at_user: str | None = None + """At对象""" + at_users: list[str] = [] + """At对象列表""" extra_data: dict[str, Any] = Field(default_factory=dict) """额外数据""" @@ -156,6 +164,7 @@ class ShopManage: goods: Goods, num: int, text: str, + at_users: list[str] = [], ) -> tuple[ShopParam, dict[str, Any]]: """构造参数 @@ -165,6 +174,7 @@ class ShopManage: goods_name: 商品名称 num: 数量 text: 其他信息 + at_users: at用户 """ group_id = None if session.group: @@ -172,6 +182,7 @@ class ShopManage: session.group.parent.id if session.group.parent else session.group.id ) _kwargs = goods.params + at_user = at_users[0] if at_users else None model = goods.model( **{ "goods_name": goods.name, @@ -183,6 +194,8 @@ class ShopManage: "text": text, "session": session, "message": message, + "at_user": at_user, + "at_users": at_users, } ) return model, { @@ -194,6 +207,8 @@ class ShopManage: "num": num, "text": text, "goods_name": goods.name, + "at_user": at_user, + "at_users": at_users, } @classmethod @@ -223,6 +238,7 @@ class ShopManage: **param.extra_data, "session": session, "message": message, + "shop_param": ShopParam, } for key in list(param_json.keys()): if key not in args: @@ -308,6 +324,7 @@ class ShopManage: goods_name: str, num: int, text: str, + at_users: list[At] = [], ) -> str | UniMessage | None: """使用道具 @@ -319,6 +336,7 @@ class ShopManage: goods_name: 商品名称 num: 使用数量 text: 其他信息 + at_users: at用户 返回: str | MessageFactory | None: 使用完成后返回信息 @@ -339,8 +357,9 @@ class ShopManage: goods = cls.uuid2goods.get(goods_info.uuid) if not goods or not goods.func: return f"{goods_info.goods_name} 未注册使用函数, 无法使用..." + at_user_ids = [at.target for at in at_users] param, kwargs = cls.__build_params( - bot, event, session, message, goods, num, text + bot, event, session, message, goods, num, text, at_user_ids ) if num > param.max_num_limit: return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..." @@ -480,10 +499,13 @@ class ShopManage: if not user.props: return None - user.props = {uuid: count for uuid, count in user.props.items() if count > 0} - goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all() goods_by_uuid = {item.uuid: item for item in goods_list} + user.props = { + uuid: count + for uuid, count in user.props.items() + if count > 0 and goods_by_uuid.get(uuid) + } table_rows = [] for i, prop_uuid in enumerate(user.props): diff --git a/zhenxun/builtin_plugins/sign_in/__init__.py b/zhenxun/builtin_plugins/sign_in/__init__.py index 0b48a0e7..0986e476 100644 --- a/zhenxun/builtin_plugins/sign_in/__init__.py +++ b/zhenxun/builtin_plugins/sign_in/__init__.py @@ -10,7 +10,6 @@ from nonebot_plugin_alconna import ( store_true, ) from nonebot_plugin_apscheduler import scheduler -from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.utils import ( Command, @@ -23,7 +22,7 @@ from zhenxun.utils.depends import UserName from zhenxun.utils.message import MessageUtils from ._data_source import SignManage -from .goods_register import driver # noqa: F401 +from .goods_register import Uninfo from .utils import clear_sign_data_pic __plugin_meta__ = PluginMetadata( diff --git a/zhenxun/builtin_plugins/sign_in/goods_register.py b/zhenxun/builtin_plugins/sign_in/goods_register.py index f7a65359..6c8e39bb 100644 --- a/zhenxun/builtin_plugins/sign_in/goods_register.py +++ b/zhenxun/builtin_plugins/sign_in/goods_register.py @@ -1,7 +1,6 @@ from decimal import Decimal import nonebot -from nonebot.drivers import Driver from nonebot_plugin_uninfo import Uninfo from zhenxun.models.sign_user import SignUser @@ -9,14 +8,7 @@ from zhenxun.models.user_console import UserConsole from zhenxun.utils.decorator.shop import shop_register from zhenxun.utils.platform import PlatformUtils -driver: Driver = nonebot.get_driver() - - -# @driver.on_startup -# async def _(): -# """ -# 导入内置的三个商品 -# """ +driver = nonebot.get_driver() @shop_register( diff --git a/zhenxun/builtin_plugins/sign_in/utils.py b/zhenxun/builtin_plugins/sign_in/utils.py index 9faf1120..910b90d8 100644 --- a/zhenxun/builtin_plugins/sign_in/utils.py +++ b/zhenxun/builtin_plugins/sign_in/utils.py @@ -16,6 +16,7 @@ from zhenxun.models.sign_log import SignLog from zhenxun.models.sign_user import SignUser from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.image_utils import BuildImage +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils from .config import ( @@ -54,7 +55,7 @@ LG_MESSAGE = [ ] -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def init_image(): SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True) SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True) diff --git a/zhenxun/builtin_plugins/statistics/statistics_hook.py b/zhenxun/builtin_plugins/statistics/statistics_hook.py index f3776ece..3ac15e2a 100644 --- a/zhenxun/builtin_plugins/statistics/statistics_hook.py +++ b/zhenxun/builtin_plugins/statistics/statistics_hook.py @@ -53,10 +53,7 @@ async def _( ) -@scheduler.scheduled_job( - "interval", - minutes=1, -) +@scheduler.scheduled_job("interval", minutes=1, max_instances=5) async def _(): try: call_list = TEMP_LIST.copy() diff --git a/zhenxun/builtin_plugins/web_ui/__init__.py b/zhenxun/builtin_plugins/web_ui/__init__.py index 90772bc5..619d56bf 100644 --- a/zhenxun/builtin_plugins/web_ui/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/__init__.py @@ -10,7 +10,9 @@ from zhenxun.configs.config import Config as gConfig from zhenxun.configs.utils import PluginExtraData, RegisterConfig from zhenxun.services.log import logger, logger_ from zhenxun.utils.enum import PluginType +from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from .api.configure import router as configure_router from .api.logs import router as ws_log_routes from .api.logs.log_manager import LOG_STORAGE from .api.menu import router as menu_router @@ -81,6 +83,7 @@ BaseApiRouter.include_router(database_router) BaseApiRouter.include_router(plugin_router) BaseApiRouter.include_router(system_router) BaseApiRouter.include_router(menu_router) +BaseApiRouter.include_router(configure_router) WsApiRouter = APIRouter(prefix="/zhenxun/socket") @@ -89,7 +92,7 @@ WsApiRouter.include_router(status_routes) WsApiRouter.include_router(chat_routes) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=0) async def _(): try: # 存储任务引用的列表,防止任务被垃圾回收 diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py new file mode 100644 index 00000000..0ecde197 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py @@ -0,0 +1,133 @@ +import asyncio +import os +from pathlib import Path +import re +import subprocess +import sys +import time + +from fastapi import APIRouter +from fastapi.responses import JSONResponse +import nonebot + +from zhenxun.configs.config import BotConfig, Config + +from ...base_model import Result +from .data_source import test_db_connection +from .model import Setting + +router = APIRouter(prefix="/configure") + +driver = nonebot.get_driver() + +port = driver.config.port + +BAT_FILE = Path() / "win启动.bat" + +FILE_NAME = ".configure_restart" + + +@router.post( + "/set_configure", + response_model=Result, + response_class=JSONResponse, + description="设置基础配置", +) +async def _(setting: Setting) -> Result: + global port + password = Config.get_config("web-ui", "password") + if password or BotConfig.db_url: + return Result.fail("配置已存在,请先删除DB_URL内容和前端密码再进行设置。") + env_file = Path() / ".env.dev" + if not env_file.exists(): + return Result.fail("配置文件.env.dev不存在。") + env_text = env_file.read_text(encoding="utf-8") + if setting.db_url: + if setting.db_url.startswith("sqlite"): + base_dir = Path().resolve() + # 清理和验证数据库路径 + db_path_str = setting.db_url.split(":")[-1].strip() + # 移除任何可能的路径遍历尝试 + db_path_str = re.sub(r"[\\/]\.\.[\\/]", "", db_path_str) + # 规范化路径 + db_path = Path(db_path_str).resolve() + parent_path = db_path.parent + + # 验证路径是否在项目根目录内 + try: + if not parent_path.absolute().is_relative_to(base_dir): + return Result.fail("数据库路径不在项目根目录内。") + except ValueError: + return Result.fail("无效的数据库路径。") + + # 创建目录 + try: + parent_path.mkdir(parents=True, exist_ok=True) + except Exception as e: + return Result.fail(f"创建数据库目录失败: {e!s}") + + env_text = env_text.replace('DB_URL = ""', f'DB_URL = "{setting.db_url}"') + if setting.superusers: + superusers = ", ".join([f'"{s}"' for s in setting.superusers]) + env_text = re.sub(r"SUPERUSERS=\[.*?\]", f"SUPERUSERS=[{superusers}]", env_text) + if setting.host: + env_text = env_text.replace("HOST = 127.0.0.1", f"HOST = {setting.host}") + if setting.port: + env_text = env_text.replace("PORT = 8080", f"PORT = {setting.port}") + port = setting.port + if setting.username: + Config.set_config("web-ui", "username", setting.username) + Config.set_config("web-ui", "password", setting.password, True) + env_file.write_text(env_text, encoding="utf-8") + if BAT_FILE.exists(): + for file in os.listdir(Path()): + if file.startswith(FILE_NAME): + Path(file).unlink() + flag_file = Path() / f"{FILE_NAME}_{int(time.time())}" + flag_file.touch() + return Result.ok(BAT_FILE.exists(), info="设置成功,请重启真寻以完成配置!") + + +@router.get( + "/test_db", + response_model=Result, + response_class=JSONResponse, + description="设置基础配置", +) +async def _(db_url: str) -> Result: + result = await test_db_connection(db_url) + if isinstance(result, str): + return Result.fail(result) + return Result.ok(info="数据库连接成功!") + + +async def run_restart_command(bat_path: Path, port: int): + """在后台执行重启命令""" + await asyncio.sleep(1) # 确保 FastAPI 已返回响应 + subprocess.Popen([bat_path, str(port)], shell=True) # noqa: ASYNC220 + sys.exit(0) # 退出当前进程 + + +@router.post( + "/restart", + response_model=Result, + response_class=JSONResponse, + description="重启", +) +async def _() -> Result: + if not BAT_FILE.exists(): + return Result.fail("自动重启仅支持意见整合包,请尝试手动重启") + flag_file = next( + (Path() / file for file in os.listdir(Path()) if file.startswith(FILE_NAME)), + None, + ) + if not flag_file or not flag_file.exists(): + return Result.fail("重启标志文件不存在...") + set_time = flag_file.name.split("_")[-1] + if time.time() - float(set_time) > 10 * 60: + return Result.fail("重启标志文件已过期,请重新设置配置。") + flag_file.unlink() + try: + return Result.ok(info="执行重启命令成功") + finally: + asyncio.create_task(run_restart_command(BAT_FILE, port)) # noqa: RUF006 diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py b/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py new file mode 100644 index 00000000..ad8c73c9 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py @@ -0,0 +1,18 @@ +from tortoise import Tortoise + + +async def test_db_connection(db_url: str) -> bool | str: + try: + # 初始化 Tortoise ORM + await Tortoise.init( + db_url=db_url, + modules={"models": ["__main__"]}, # 这里不需要实际模型 + ) + # 测试连接 + await Tortoise.get_connection("default").execute_query("SELECT 1") + return True + except Exception as e: + return str(e) + finally: + # 关闭连接 + await Tortoise.close_connections() diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/model.py b/zhenxun/builtin_plugins/web_ui/api/configure/model.py new file mode 100644 index 00000000..4a6b3486 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/model.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel + + +class Setting(BaseModel): + superusers: list[str] + """超级用户列表""" + db_url: str + """数据库地址""" + host: str + """主机地址""" + port: int + """端口""" + username: str + """前端用户名""" + password: str + """前端密码""" diff --git a/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py b/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py index 14f5c928..e54bf9e5 100644 --- a/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py @@ -5,54 +5,63 @@ from zhenxun.services.log import logger from .model import MenuData, MenuItem +default_menus = [ + MenuItem( + name="仪表盘", + module="dashboard", + router="/dashboard", + icon="dashboard", + default=True, + ), + MenuItem( + name="真寻控制台", + module="command", + router="/command", + icon="command", + ), + MenuItem(name="插件列表", module="plugin", router="/plugin", icon="plugin"), + MenuItem(name="插件商店", module="store", router="/store", icon="store"), + MenuItem(name="好友/群组", module="manage", router="/manage", icon="user"), + MenuItem( + name="数据库管理", + module="database", + router="/database", + icon="database", + ), + MenuItem(name="系统信息", module="system", router="/system", icon="system"), + MenuItem(name="关于我们", module="about", router="/about", icon="about"), +] -class MenuManage: + +class MenuManager: def __init__(self) -> None: self.file = DATA_PATH / "web_ui" / "menu.json" self.menu = [] if self.file.exists(): try: + temp_menu = [] self.menu = json.load(self.file.open(encoding="utf8")) + self_menu_name = [menu["name"] for menu in self.menu] + for module in [m.module for m in default_menus]: + if module in self_menu_name: + temp_menu.append( + MenuItem( + **next(m for m in self.menu if m["module"] == module) + ) + ) + else: + temp_menu.append(self.__get_menu_model(module)) + self.menu = temp_menu except Exception as e: logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e) if not self.menu: - self.menu = [ - MenuItem( - name="仪表盘", - module="dashboard", - router="/dashboard", - icon="dashboard", - default=True, - ), - MenuItem( - name="真寻控制台", - module="command", - router="/command", - icon="command", - ), - MenuItem( - name="插件列表", module="plugin", router="/plugin", icon="plugin" - ), - MenuItem( - name="插件商店", module="store", router="/store", icon="store" - ), - MenuItem( - name="好友/群组", module="manage", router="/manage", icon="user" - ), - MenuItem( - name="数据库管理", - module="database", - router="/database", - icon="database", - ), - MenuItem( - name="文件管理", module="system", router="/system", icon="system" - ), - MenuItem( - name="关于我们", module="about", router="/about", icon="about" - ), - ] - self.save() + self.menu = default_menus + self.save() + + def __get_menu_model(self, module: str): + return default_menus[ + next(i for i, m in enumerate(default_menus) if m.module == module) + ] def get_menus(self): return MenuData(menus=self.menu) @@ -64,4 +73,4 @@ class MenuManage: json.dump(temp, f, ensure_ascii=False, indent=4) -menu_manage = MenuManage() +menu_manage = MenuManager() diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py index 6c312db3..87011c93 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py @@ -13,6 +13,7 @@ from zhenxun.models.bot_connect_log import BotConnectLog from zhenxun.models.chat_history import ChatHistory from zhenxun.models.statistics import Statistics from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils from ....base_model import BaseResultModel, QueryModel @@ -31,7 +32,7 @@ driver: Driver = nonebot.get_driver() CONNECT_TIME = 0 -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): global CONNECT_TIME CONNECT_TIME = int(time.time()) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py index b963e291..91fbc5c0 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py @@ -8,6 +8,7 @@ from zhenxun.configs.config import BotConfig from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from ....base_model import BaseResultModel, QueryModel, Result from ....utils import authentication @@ -21,7 +22,7 @@ router = APIRouter(prefix="/database") driver: Driver = nonebot.get_driver() -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): for plugin in nonebot.get_loaded_plugins(): module = plugin.name diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py index ffcd05be..949a69de 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from zhenxun.utils._build_image import BuildImage from ....base_model import Result, SystemFolderSize -from ....utils import authentication, get_system_disk +from ....utils import authentication, get_system_disk, validate_path from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile router = APIRouter(prefix="/system") @@ -25,22 +25,29 @@ IMAGE_TYPE = ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg"] description="获取文件列表", ) async def _(path: str | None = None) -> Result[list[DirFile]]: - base_path = Path(path) if path else Path() - data_list = [] - for file in os.listdir(base_path): - file_path = base_path / file - is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE) - data_list.append( - DirFile( - is_file=not file_path.is_dir(), - is_image=is_image, - name=file, - parent=path, - size=None if file_path.is_dir() else file_path.stat().st_size, - mtime=file_path.stat().st_mtime, + try: + base_path, error = validate_path(path) + if error: + return Result.fail(error) + if not base_path: + return Result.fail("无效的路径") + data_list = [] + for file in os.listdir(base_path): + file_path = base_path / file + is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE) + data_list.append( + DirFile( + is_file=not file_path.is_dir(), + is_image=is_image, + name=file, + parent=path, + size=None if file_path.is_dir() else file_path.stat().st_size, + mtime=file_path.stat().st_mtime, + ) ) - ) - return Result.ok(data_list) + return Result.ok(data_list) + except Exception as e: + return Result.fail(f"获取文件列表失败: {e!s}") @router.get( @@ -62,8 +69,12 @@ async def _(full_path: str | None = None) -> Result[list[SystemFolderSize]]: description="删除文件", ) async def _(param: DeleteFile) -> Result: - path = Path(param.full_path) - if not path or not path.exists(): + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") + if not path.exists(): return Result.warning_("文件不存在...") try: path.unlink() @@ -80,8 +91,12 @@ async def _(param: DeleteFile) -> Result: description="删除文件夹", ) async def _(param: DeleteFile) -> Result: - path = Path(param.full_path) - if not path or not path.exists() or path.is_file(): + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") + if not path.exists() or path.is_file(): return Result.warning_("文件夹不存在...") try: shutil.rmtree(path.absolute()) @@ -98,10 +113,14 @@ async def _(param: DeleteFile) -> Result: description="重命名文件", ) async def _(param: RenameFile) -> Result: - path = ( - (Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) - ) - if not path or not path.exists(): + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.old_name) if param.parent else Path(param.old_name) + if not path.exists(): return Result.warning_("文件不存在...") try: path.rename(path.parent / param.name) @@ -118,10 +137,14 @@ async def _(param: RenameFile) -> Result: description="重命名文件夹", ) async def _(param: RenameFile) -> Result: - path = ( - (Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) - ) - if not path or not path.exists() or path.is_file(): + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.old_name) if param.parent else Path(param.old_name) + if not path.exists() or path.is_file(): return Result.warning_("文件夹不存在...") try: new_path = path.parent / param.name @@ -139,7 +162,13 @@ async def _(param: RenameFile) -> Result: description="新建文件", ) async def _(param: AddFile) -> Result: - path = (Path(param.parent) / param.name) if param.parent else Path(param.name) + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.name) if param.parent else Path(param.name) if path.exists(): return Result.warning_("文件已存在...") try: @@ -157,7 +186,13 @@ async def _(param: AddFile) -> Result: description="新建文件夹", ) async def _(param: AddFile) -> Result: - path = (Path(param.parent) / param.name) if param.parent else Path(param.name) + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.name) if param.parent else Path(param.name) if path.exists(): return Result.warning_("文件夹已存在...") try: @@ -175,7 +210,11 @@ async def _(param: AddFile) -> Result: description="读取文件", ) async def _(full_path: str) -> Result: - path = Path(full_path) + path, error = validate_path(full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") if not path.exists(): return Result.warning_("文件不存在...") try: @@ -193,9 +232,13 @@ async def _(full_path: str) -> Result: description="读取文件", ) async def _(param: SaveFile) -> Result[str]: - path = Path(param.full_path) + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") try: - async with aiofiles.open(path, "w", encoding="utf-8") as f: + async with aiofiles.open(str(path), "w", encoding="utf-8") as f: await f.write(param.content) return Result.ok("更新成功!") except Exception as e: @@ -210,7 +253,11 @@ async def _(param: SaveFile) -> Result[str]: description="读取图片base64", ) async def _(full_path: str) -> Result[str]: - path = Path(full_path) + path, error = validate_path(full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") if not path.exists(): return Result.warning_("文件不存在...") try: diff --git a/zhenxun/builtin_plugins/web_ui/config.py b/zhenxun/builtin_plugins/web_ui/config.py index bddcb062..4a88aad9 100644 --- a/zhenxun/builtin_plugins/web_ui/config.py +++ b/zhenxun/builtin_plugins/web_ui/config.py @@ -1,6 +1,12 @@ +import sys + from fastapi.middleware.cors import CORSMiddleware import nonebot -from strenum import StrEnum + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH diff --git a/zhenxun/builtin_plugins/web_ui/public/data_source.py b/zhenxun/builtin_plugins/web_ui/public/data_source.py index 9f5a657e..51b29533 100644 --- a/zhenxun/builtin_plugins/web_ui/public/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/public/data_source.py @@ -18,6 +18,7 @@ async def update_webui_assets(): download_url = await GithubUtils.parse_github_url( WEBUI_DIST_GITHUB_URL ).get_archive_download_urls() + logger.info("开始下载 webui_assets 资源...", COMMAND_NAME) if await AsyncHttpx.download_file( download_url, webui_assets_path, follow_redirects=True ): diff --git a/zhenxun/builtin_plugins/web_ui/utils.py b/zhenxun/builtin_plugins/web_ui/utils.py index a7e22a07..84459114 100644 --- a/zhenxun/builtin_plugins/web_ui/utils.py +++ b/zhenxun/builtin_plugins/web_ui/utils.py @@ -2,6 +2,7 @@ import contextlib from datetime import datetime, timedelta, timezone import os from pathlib import Path +import re from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer @@ -28,6 +29,45 @@ if token_file.exists(): token_data = json.load(open(token_file, encoding="utf8")) +def validate_path(path_str: str | None) -> tuple[Path | None, str | None]: + """验证路径是否安全 + + 参数: + path_str: 用户输入的路径 + + 返回: + tuple[Path | None, str | None]: (验证后的路径, 错误信息) + """ + try: + if not path_str: + return Path().resolve(), None + + # 1. 移除任何可能的路径遍历尝试 + path_str = re.sub(r"[\\/]\.\.[\\/]", "", path_str) + + # 2. 规范化路径并转换为绝对路径 + path = Path(path_str).resolve() + + # 3. 获取项目根目录 + root_dir = Path().resolve() + + # 4. 验证路径是否在项目根目录内 + try: + if not path.is_relative_to(root_dir): + return None, "访问路径超出允许范围" + except ValueError: + return None, "无效的路径格式" + + # 5. 验证路径是否包含任何危险字符 + if any(c in str(path) for c in ["..", "~", "*", "?", ">", "<", "|", '"']): + return None, "路径包含非法字符" + + # 6. 验证路径长度是否合理 + return (None, "路径长度超出限制") if len(str(path)) > 4096 else (path, None) + except Exception as e: + return None, f"路径验证失败: {e!s}" + + GROUP_HELP_PATH = DATA_PATH / "group_help" SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png" SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png" diff --git a/zhenxun/models/bot_message_store.py b/zhenxun/models/bot_message_store.py new file mode 100644 index 00000000..fa1244f9 --- /dev/null +++ b/zhenxun/models/bot_message_store.py @@ -0,0 +1,29 @@ +from tortoise import fields + +from zhenxun.services.db_context import Model +from zhenxun.utils.enum import BotSentType + + +class BotMessageStore(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + """自增id""" + bot_id = fields.CharField(255, null=True) + """bot id""" + user_id = fields.CharField(255, null=True) + """目标id""" + group_id = fields.CharField(255, null=True) + """群组id""" + sent_type = fields.CharEnumField(BotSentType) + """类型""" + text = fields.TextField(null=True) + """文本内容""" + plain_text = fields.TextField(null=True) + """纯文本""" + platform = fields.CharField(255, null=True) + """平台""" + create_time = fields.DatetimeField(auto_now_add=True) + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + table = "bot_message_store" + table_description = "Bot发送消息列表" diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 9a44fa74..33678965 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -1,9 +1,12 @@ +import nonebot from nonebot.utils import is_coroutine_callable from tortoise import Tortoise from tortoise.connection import connections from tortoise.models import Model as Model_ from zhenxun.configs.config import BotConfig +from zhenxun.utils.exception import HookPriorityException +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from .log import logger @@ -11,6 +14,9 @@ SCRIPT_METHOD = [] MODELS: list[str] = [] +driver = nonebot.get_driver() + + class Model(Model_): """ 自动添加模块 @@ -26,7 +32,7 @@ class Model(Model_): SCRIPT_METHOD.append((cls.__module__, func)) -class DbUrlIsNode(Exception): +class DbUrlIsNode(HookPriorityException): """ 数据库链接地址为空 """ @@ -42,9 +48,19 @@ class DbConnectError(Exception): pass +@PriorityLifecycle.on_startup(priority=1) async def init(): if not BotConfig.db_url: - raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") + # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") + error = f""" +********************************************************************** +🌟 **************************** 配置为空 ************************* 🌟 +🚀 请打开 WebUi 进行基础配置 🚀 +🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐 +*********************************************************************** +*********************************************************************** + """ + raise DbUrlIsNode("\n" + error.strip()) try: await Tortoise.init( db_url=BotConfig.db_url, diff --git a/zhenxun/services/log.py b/zhenxun/services/log.py index 96a45bce..beb2b9c0 100644 --- a/zhenxun/services/log.py +++ b/zhenxun/services/log.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import timedelta from typing import Any, overload import nonebot @@ -17,7 +17,7 @@ driver = nonebot.get_driver() log_level = driver.config.log_level or "INFO" logger_.add( - LOG_PATH / f"{datetime.now().date()}.log", + LOG_PATH / "{time:YYYY-MM-DD}.log", level=log_level, rotation="00:00", format=default_format, @@ -26,7 +26,7 @@ logger_.add( ) logger_.add( - LOG_PATH / f"error_{datetime.now().date()}.log", + LOG_PATH / "error_{time:YYYY-MM-DD}.log", level="ERROR", rotation="00:00", format=default_format, @@ -36,26 +36,92 @@ logger_.add( class logger: - TEMPLATE_A = "Adapter[{}] {}" - TEMPLATE_B = "Adapter[{}] [{}]: {}" - TEMPLATE_C = "Adapter[{}] 用户[{}] 触发 [{}]: {}" - TEMPLATE_D = "Adapter[{}] 群聊[{}] 用户[{}] 触发" - " [{}]: {}" - TEMPLATE_E = "Adapter[{}] 群聊[{}] 用户[{}] 触发" - " [{}] [Target]({}): {}" - - TEMPLATE_ADAPTER = "Adapter[{}] " - TEMPLATE_USER = "用户[{}] " - TEMPLATE_GROUP = "群聊[{}] " - TEMPLATE_COMMAND = "CMD[{}] " - TEMPLATE_PLATFORM = "平台[{}] " - TEMPLATE_TARGET = "[Target]([{}]) " + """ + 一个经过优化的、支持多种上下文和格式的日志记录器。 + """ + TEMPLATE_ADAPTER = "Adapter[{}]" + TEMPLATE_USER = "用户[{}]" + TEMPLATE_GROUP = "群聊[{}]" + TEMPLATE_COMMAND = "CMD[{}]" + TEMPLATE_PLATFORM = "平台[{}]" + TEMPLATE_TARGET = "[Target]([{}])" SUCCESS_TEMPLATE = "[{}]: {} | 参数[{}] 返回: [{}]" - WARNING_TEMPLATE = "[{}]: {}" + @classmethod + def __parser_template( + cls, + info: str, + command: str | None = None, + user_id: int | str | None = None, + group_id: int | str | None = None, + adapter: str | None = None, + target: Any = None, + platform: str | None = None, + ) -> str: + """ + 优化后的模板解析器,构建并连接日志信息片段。 + """ + parts = [] + if adapter: + parts.append(cls.TEMPLATE_ADAPTER.format(adapter)) + if platform: + parts.append(cls.TEMPLATE_PLATFORM.format(platform)) + if group_id: + parts.append(cls.TEMPLATE_GROUP.format(group_id)) + if user_id: + parts.append(cls.TEMPLATE_USER.format(user_id)) + if command: + parts.append(cls.TEMPLATE_COMMAND.format(command)) + if target: + parts.append(cls.TEMPLATE_TARGET.format(target)) - ERROR_TEMPLATE = "[{}]: {}" + parts.append(info) + return " ".join(parts) + + @classmethod + def _log( + cls, + level: str, + info: str, + command: str | None = None, + session: int | str | Session | uninfoSession | None = None, + group_id: int | str | None = None, + adapter: str | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): + """ + 核心日志处理方法,处理所有日志级别的通用逻辑。 + """ + user_id: str | None = str(session) if isinstance(session, int | str) else None + + if isinstance(session, Session): + user_id = session.id1 + adapter = session.bot_type + group_id = f"{session.id3}:{session.id2}" if session.id3 else session.id2 + platform = platform or session.platform + elif isinstance(session, uninfoSession): + user_id = session.user.id + adapter = session.adapter + if session.group: + group_id = session.group.id + platform = session.basic.get("scope") + + template = cls.__parser_template( + info, command, user_id, group_id, adapter, target, platform + ) + + if e: + template += f" || 错误 {type(e).__name__}: {e}" + + try: + log_func = getattr(logger_.opt(colors=True), level) + log_func(template) + except Exception: + log_func_fallback = getattr(logger_, level) + log_func_fallback(template) @overload @classmethod @@ -70,7 +136,6 @@ class logger: target: Any = None, platform: str | None = None, ): ... - @overload @classmethod def info( @@ -82,7 +147,6 @@ class logger: target: Any = None, platform: str | None = None, ): ... - @overload @classmethod def info( @@ -107,28 +171,16 @@ class logger: target: Any = None, platform: str | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "info", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, ) - try: - logger_.opt(colors=True).info(template) - except Exception: - logger_.info(template) @classmethod def success( @@ -138,9 +190,11 @@ class logger: param: dict[str, Any] | None = None, result: str = "", ): - param_str = "" - if param: - param_str = ",".join([f"{k}:{v}" for k, v in param.items()]) + param_str = ( + ",".join([f"{k}:{v}" for k, v in param.items()]) + if param + else "" + ) logger_.opt(colors=True).success( cls.SUCCESS_TEMPLATE.format(command, info, param_str, result) ) @@ -159,7 +213,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def warning( @@ -168,12 +221,10 @@ class logger: command: str | None = None, *, session: Session | None = None, - adapter: str | None = None, target: Any = None, platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def warning( @@ -182,7 +233,6 @@ class logger: command: str | None = None, *, session: uninfoSession | None = None, - adapter: str | None = None, target: Any = None, platform: str | None = None, e: Exception | None = None, @@ -201,30 +251,17 @@ class logger: platform: str | None = None, e: Exception | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "warning", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, ) - if e: - template += f" || 错误{type(e)}: {e}" - try: - logger_.opt(colors=True).warning(template) - except Exception as e: - logger_.warning(template) @overload @classmethod @@ -240,7 +277,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def error( @@ -253,7 +289,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def error( @@ -280,30 +315,17 @@ class logger: platform: str | None = None, e: Exception | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "error", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, ) - if e: - template += f" || 错误 {type(e)}: {e}" - try: - logger_.opt(colors=True).error(template) - except Exception as e: - logger_.error(template) @overload @classmethod @@ -319,7 +341,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def debug( @@ -332,7 +353,6 @@ class logger: platform: str | None = None, e: Exception | None = None, ): ... - @overload @classmethod def debug( @@ -359,62 +379,78 @@ class logger: platform: str | None = None, e: Exception | None = None, ): - user_id: str | None = session # type: ignore - if isinstance(session, Session): - user_id = session.id1 - adapter = session.bot_type - if session.id3: - group_id = f"{session.id3}:{session.id2}" - elif session.id2: - group_id = f"{session.id2}" - platform = platform or session.platform - elif isinstance(session, uninfoSession): - user_id = session.user.id - adapter = session.adapter - if session.group: - group_id = session.group.id - platform = session.basic["scope"] - template = cls.__parser_template( - info, command, user_id, group_id, adapter, target, platform + cls._log( + "debug", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, ) - if e: - template += f" || 错误 {type(e)}: {e}" - try: - logger_.opt(colors=True).debug(template) - except Exception as e: - logger_.debug(template) + @overload @classmethod - def __parser_template( + def trace( cls, info: str, command: str | None = None, - user_id: int | str | None = None, + *, + session: int | str | None = None, group_id: int | str | None = None, adapter: str | None = None, target: Any = None, platform: str | None = None, - ) -> str: - arg_list = [] - template = "" - if adapter is not None: - template += cls.TEMPLATE_ADAPTER - arg_list.append(adapter) - if platform is not None: - template += cls.TEMPLATE_PLATFORM - arg_list.append(platform) - if group_id is not None: - template += cls.TEMPLATE_GROUP - arg_list.append(group_id) - if user_id is not None: - template += cls.TEMPLATE_USER - arg_list.append(user_id) - if command is not None: - template += cls.TEMPLATE_COMMAND - arg_list.append(command) - if target is not None: - template += cls.TEMPLATE_TARGET - arg_list.append(target) - arg_list.append(info) - template += "{}" - return template.format(*arg_list) + e: Exception | None = None, + ): ... + @overload + @classmethod + def trace( + cls, + info: str, + command: str | None = None, + *, + session: Session | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): ... + @overload + @classmethod + def trace( + cls, + info: str, + command: str | None = None, + *, + session: uninfoSession | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): ... + + @classmethod + def trace( + cls, + info: str, + command: str | None = None, + *, + session: int | str | Session | uninfoSession | None = None, + group_id: int | str | None = None, + adapter: str | None = None, + target: Any = None, + platform: str | None = None, + e: Exception | None = None, + ): + cls._log( + "trace", + info=info, + command=command, + session=session, + group_id=group_id, + adapter=adapter, + target=target, + platform=platform, + e=e, + ) diff --git a/zhenxun/services/plugin_init.py b/zhenxun/services/plugin_init.py index 159e042c..a622a9e8 100644 --- a/zhenxun/services/plugin_init.py +++ b/zhenxun/services/plugin_init.py @@ -6,6 +6,7 @@ from nonebot.utils import is_coroutine_callable from pydantic import BaseModel from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle driver = nonebot.get_driver() @@ -100,6 +101,6 @@ class PluginInitManager: logger.error(f"执行: {module_path}:remove 失败", e=e) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): await PluginInitManager.install_all() diff --git a/zhenxun/utils/_build_mat.py b/zhenxun/utils/_build_mat.py index de73e69d..a3de3087 100644 --- a/zhenxun/utils/_build_mat.py +++ b/zhenxun/utils/_build_mat.py @@ -1,12 +1,17 @@ from io import BytesIO from pathlib import Path import random +import sys from pydantic import BaseModel, Field -from strenum import StrEnum from ._build_image import BuildImage +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + class MatType(StrEnum): LINE = "LINE" diff --git a/zhenxun/utils/enum.py b/zhenxun/utils/enum.py index 2ddf5297..db527fc3 100644 --- a/zhenxun/utils/enum.py +++ b/zhenxun/utils/enum.py @@ -1,4 +1,21 @@ -from strenum import StrEnum +import sys + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + + +class PriorityLifecycleType(StrEnum): + STARTUP = "STARTUP" + """启动""" + SHUTDOWN = "SHUTDOWN" + """关闭""" + + +class BotSentType(StrEnum): + GROUP = "GROUP" + PRIVATE = "PRIVATE" class BankHandleType(StrEnum): diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py index db8c0656..8ec925ec 100644 --- a/zhenxun/utils/exception.py +++ b/zhenxun/utils/exception.py @@ -1,3 +1,15 @@ +class HookPriorityException(BaseException): + """ + 钩子优先级异常 + """ + + def __init__(self, info: str = "") -> None: + self.info = info + + def __str__(self) -> str: + return self.info + + class NotFoundError(Exception): """ 未发现 diff --git a/zhenxun/utils/github_utils/models.py b/zhenxun/utils/github_utils/models.py index e3e5dfe3..fb690616 100644 --- a/zhenxun/utils/github_utils/models.py +++ b/zhenxun/utils/github_utils/models.py @@ -1,13 +1,18 @@ import contextlib +import sys from typing import Protocol from aiocache import cached from nonebot.compat import model_dump from pydantic import BaseModel, Field -from strenum import StrEnum from zhenxun.utils.http_utils import AsyncHttpx +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + from .const import ( CACHED_API_TTL, GIT_API_COMMIT_FORMAT, diff --git a/zhenxun/utils/http_utils.py b/zhenxun/utils/http_utils.py index c6766bda..aecf3154 100644 --- a/zhenxun/utils/http_utils.py +++ b/zhenxun/utils/http_utils.py @@ -1,217 +1,223 @@ import asyncio -from asyncio.exceptions import TimeoutError -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from contextlib import asynccontextmanager from pathlib import Path import time -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, cast import aiofiles -from anyio import EndOfStream import httpx -from httpx import ConnectTimeout, HTTPStatusError, Response +from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response from nonebot_plugin_alconna import UniMessage from nonebot_plugin_htmlrender import get_browser from playwright.async_api import Page -import rich +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + TextColumn, + TransferSpeedColumn, +) from zhenxun.configs.config import BotConfig from zhenxun.services.log import logger from zhenxun.utils.message import MessageUtils from zhenxun.utils.user_agent import get_user_agent -# from .browser import get_browser +CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"] + + +def get_async_client( + proxies: dict[str, str] | None = None, + proxy: str | None = None, + verify: bool = False, + **kwargs, +) -> httpx.AsyncClient: + transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify) + if proxies: + http_proxy = proxies.get("http://") + https_proxy = proxies.get("https://") + return httpx.AsyncClient( + mounts={ + "http://": AsyncHTTPTransport( + proxy=Proxy(http_proxy) if http_proxy else None + ), + "https://": AsyncHTTPTransport( + proxy=Proxy(https_proxy) if https_proxy else None + ), + }, + transport=transport, + **kwargs, + ) + elif proxy: + return httpx.AsyncClient( + mounts={ + "http://": AsyncHTTPTransport(proxy=Proxy(proxy)), + "https://": AsyncHTTPTransport(proxy=Proxy(proxy)), + }, + transport=transport, + **kwargs, + ) + return httpx.AsyncClient(transport=transport, **kwargs) class AsyncHttpx: - proxy: ClassVar[dict[str, str | None]] = { - "http://": BotConfig.system_proxy, - "https://": BotConfig.system_proxy, - } + default_proxy: ClassVar[dict[str, str] | None] = ( + { + "http://": BotConfig.system_proxy, + "https://": BotConfig.system_proxy, + } + if BotConfig.system_proxy + else None + ) + + @classmethod + @asynccontextmanager + async def _create_client( + cls, + *, + use_proxy: bool = True, + proxies: dict[str, str] | None = None, + proxy: str | None = None, + headers: dict[str, str] | None = None, + verify: bool = False, + **kwargs, + ) -> AsyncGenerator[httpx.AsyncClient, None]: + """创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。 + + 说明: + 此方法用于内部统一创建客户端,处理代理和请求头逻辑,减少代码重复。 + + 参数: + use_proxy: 是否使用在类中定义的默认代理。 + proxies: 手动指定的代理,会覆盖默认代理。 + proxy: 单个代理,用于兼容旧版本,不再使用 + headers: 需要合并到客户端的自定义请求头。 + verify: 是否验证 SSL 证书。 + **kwargs: 其他所有传递给 httpx.AsyncClient 的参数。 + + 返回: + AsyncGenerator[httpx.AsyncClient, None]: 生成器。 + """ + proxies_to_use = proxies or (cls.default_proxy if use_proxy else None) + + final_headers = get_user_agent() + if headers: + final_headers.update(headers) + + async with get_async_client( + proxies=proxies_to_use, + proxy=proxy, + verify=verify, + headers=final_headers, + **kwargs, + ) as client: + yield client @classmethod async def get( cls, url: str | list[str], *, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 check_status_code: int | None = None, **kwargs, - ) -> Response: - """Get + ) -> Response: # sourcery skip: use-assigned-variable + """发送 GET 请求,并返回第一个成功的响应。 + + 说明: + 本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的代理管理。 + 如果提供 URL 列表,它将依次尝试直到成功为止。 参数: - url: url - params: params - headers: 请求头 - cookies: cookies - verify: verify - use_proxy: 使用默认代理 - proxy: 指定代理 - timeout: 超时时间 - check_status_code: 检查状态码 + url: 单个请求 URL 或一个 URL 列表。 + check_status_code: (可选) 若提供,将检查响应状态码是否匹配,否则抛出异常。 + **kwargs: 其他所有传递给 httpx.get 的参数 + (如 `params`, `headers`, `timeout`等)。 + + 返回: + Response: Response """ urls = [url] if isinstance(url, str) else url - return await cls._get_first_successful( - urls, - params=params, - headers=headers, - cookies=cookies, - verify=verify, - use_proxy=use_proxy, - proxy=proxy, - timeout=timeout, - check_status_code=check_status_code, - **kwargs, - ) - - @classmethod - async def _get_first_successful( - cls, - urls: list[str], - check_status_code: int | None = None, - **kwargs, - ) -> Response: last_exception = None - for url in urls: + for current_url in urls: try: - logger.info(f"开始获取 {url}..") - response = await cls._get_single(url, **kwargs) + logger.info(f"开始获取 {current_url}..") + client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY} + for key in CLIENT_KEY: + kwargs.pop(key, None) + async with cls._create_client(**client_kwargs) as client: + response = await client.get(current_url, **kwargs) + if check_status_code and response.status_code != check_status_code: - status_code = response.status_code - raise Exception(f"状态码错误:{status_code}!={check_status_code}") + raise HTTPStatusError( + f"状态码错误: {response.status_code}!={check_status_code}", + request=response.request, + response=response, + ) return response except Exception as e: last_exception = e - if url != urls[-1]: - logger.warning(f"获取 {url} 失败, 尝试下一个") - raise last_exception or Exception("All URLs failed") + if current_url != urls[-1]: + logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e) + + raise last_exception or Exception("所有URL都获取失败") @classmethod - async def _get_single( - cls, - url: str, - *, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 - **kwargs, - ) -> Response: - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.get( - url, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, - ) + async def head(cls, url: str, **kwargs) -> Response: + """发送 HEAD 请求。 - @classmethod - async def head( - cls, - url: str, - *, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 - **kwargs, - ) -> Response: - """Get - - 参数: - url: url - params: params - headers: 请求头 - cookies: cookies - verify: verify - use_proxy: 使用默认代理 - proxy: 指定代理 - timeout: 超时时间 - """ - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.head( - url, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, - ) - - @classmethod - async def post( - cls, - url: str, - *, - data: dict[str, Any] | None = None, - content: Any = None, - files: Any = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - json: dict[str, Any] | None = None, - params: dict[str, str] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 - **kwargs, - ) -> Response: - """ 说明: - Post + 本方法是对 httpx.head 的封装,通常用于检查资源的元信息(如大小、类型)。 + 参数: - url: url - data: data - content: content - files: files - use_proxy: 是否默认代理 - proxy: 指定代理 - json: json - params: params - headers: 请求头 - cookies: cookies - timeout: 超时时间 + url: 请求的 URL。 + **kwargs: 其他所有传递给 httpx.head 的参数 + (如 `headers`, `timeout`, `allow_redirects`)。 + + 返回: + Response: Response """ - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.post( - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, - ) + client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY} + for key in CLIENT_KEY: + kwargs.pop(key, None) + async with cls._create_client(**client_kwargs) as client: + return await client.head(url, **kwargs) + + @classmethod + async def post(cls, url: str, **kwargs) -> Response: + """发送 POST 请求。 + + 说明: + 本方法是对 httpx.post 的封装,提供了统一的代理和客户端管理。 + + 参数: + url: 请求的 URL。 + **kwargs: 其他所有传递给 httpx.post 的参数 + (如 `data`, `json`, `content` 等)。 + + 返回: + Response: Response。 + """ + client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY} + for key in CLIENT_KEY: + kwargs.pop(key, None) + async with cls._create_client(**client_kwargs) as client: + return await client.post(url, **kwargs) @classmethod async def get_content(cls, url: str, **kwargs) -> bytes: + """获取指定 URL 的二进制内容。 + + 说明: + 这是一个便捷方法,等同于调用 get() 后再访问 .content 属性。 + + 参数: + url: 请求的 URL。 + **kwargs: 所有传递给 get() 方法的参数。 + + 返回: + bytes: 响应内容的二进制字节流 (bytes)。 + """ res = await cls.get(url, **kwargs) return res.content @@ -221,195 +227,143 @@ class AsyncHttpx: url: str | list[str], path: str | Path, *, - params: dict[str, str] | None = None, - verify: bool = True, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 stream: bool = False, - follow_redirects: bool = True, **kwargs, ) -> bool: - """下载文件 + """下载文件到指定路径。 + + 说明: + 支持多链接尝试和流式下载(带进度条)。 参数: - url: url - path: 存储路径 - params: params - verify: verify - use_proxy: 使用代理 - proxy: 指定代理 - headers: 请求头 - cookies: cookies - timeout: 超时时间 - stream: 是否使用流式下载(流式写入+进度条,适用于下载大文件) + url: 单个文件 URL 或一个备用 URL 列表。 + path: 文件保存的本地路径。 + stream: (可选) 是否使用流式下载,适用于大文件,默认为 False。 + **kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数。 + + 返回: + bool: 是否下载成功。 """ - if isinstance(path, str): - path = Path(path) + path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) - try: - for _ in range(3): - if not isinstance(url, list): - url = [url] - for u in url: - try: - if not stream: - response = await cls.get( - u, - params=params, - headers=headers, - cookies=cookies, - use_proxy=use_proxy, - proxy=proxy, - timeout=timeout, - follow_redirects=follow_redirects, - **kwargs, - ) + + urls = [url] if isinstance(url, str) else url + + for current_url in urls: + try: + if not stream: + response = await cls.get(current_url, **kwargs) + response.raise_for_status() + async with aiofiles.open(path, "wb") as f: + await f.write(response.content) + else: + async with cls._create_client(**kwargs) as client: + stream_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ["use_proxy", "proxy", "verify"] + } + async with client.stream( + "GET", current_url, **stream_kwargs + ) as response: response.raise_for_status() - content = response.content - async with aiofiles.open(path, "wb") as wf: - await wf.write(content) - logger.info(f"下载 {u} 成功.. Path:{path.absolute()}") - else: - if not headers: - headers = get_user_agent() - _proxy = proxy or (cls.proxy if use_proxy else None) - async with httpx.AsyncClient( - proxies=_proxy, # type: ignore - verify=verify, - ) as client: - async with client.stream( - "GET", - u, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - follow_redirects=True, - **kwargs, - ) as response: - response.raise_for_status() - logger.info( - f"开始下载 {path.name}.. " - f"Url: {u}.. " - f"Path: {path.absolute()}" - ) - async with aiofiles.open(path, "wb") as wf: - total = int( - response.headers.get("Content-Length", 0) - ) - with rich.progress.Progress( # type: ignore - rich.progress.TextColumn(path.name), # type: ignore - "[progress.percentage]{task.percentage:>3.0f}%", # type: ignore - rich.progress.BarColumn(bar_width=None), # type: ignore - rich.progress.DownloadColumn(), # type: ignore - rich.progress.TransferSpeedColumn(), # type: ignore - ) as progress: - download_task = progress.add_task( - "Download", - total=total or None, - ) - async for chunk in response.aiter_bytes(): - await wf.write(chunk) - await wf.flush() - progress.update( - download_task, - completed=response.num_bytes_downloaded, - ) - logger.info( - f"下载 {u} 成功.. Path:{path.absolute()}" - ) - return True - except (TimeoutError, ConnectTimeout, HTTPStatusError): - logger.warning(f"下载 {u} 失败.. 尝试下一个地址..") - except EndOfStream as e: - logger.warning( - f"下载 {url} EndOfStream 异常 Path:{path.absolute()}", e=e - ) - if path.exists(): - return True - logger.error(f"下载 {url} 下载超时.. Path:{path.absolute()}") - except Exception as e: - logger.error(f"下载 {url} 错误 Path:{path.absolute()}", e=e) + total = int(response.headers.get("Content-Length", 0)) + + with Progress( + TextColumn(path.name), + "[progress.percentage]{task.percentage:>3.0f}%", + BarColumn(bar_width=None), + DownloadColumn(), + TransferSpeedColumn(), + ) as progress: + task_id = progress.add_task("Download", total=total) + async with aiofiles.open(path, "wb") as f: + async for chunk in response.aiter_bytes(): + await f.write(chunk) + progress.update(task_id, advance=len(chunk)) + + logger.info(f"下载 {current_url} 成功 -> {path.absolute()}") + return True + + except Exception as e: + logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}") + + logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}") return False @classmethod async def gather_download_file( cls, - url_list: list[str] | list[list[str]], - path_list: list[str | Path], + url_list: Sequence[list[str] | str], + path_list: Sequence[str | Path], *, - limit_async_number: int | None = None, - params: dict[str, str] | None = None, - use_proxy: bool = True, - proxy: dict[str, str] | None = None, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, - timeout: int = 30, # noqa: ASYNC109 + limit_async_number: int = 5, **kwargs, ) -> list[bool]: - """分组同时下载文件 + """并发下载多个文件,支持为每个文件提供备用镜像链接。 + + 说明: + 使用 asyncio.Semaphore 来控制并发请求的数量。 + 对于 url_list 中的每个元素,如果它是一个列表,则会依次尝试直到下载成功。 参数: - url_list: url列表 - path_list: 存储路径列表 - limit_async_number: 限制同时请求数量 - params: params - use_proxy: 使用代理 - proxy: 指定代理 - headers: 请求头 - cookies: cookies - timeout: 超时时间 + url_list: 包含所有文件下载任务的列表。每个元素可以是: + - 一个字符串 (str): 代表该任务的唯一URL。 + - 一个字符串列表 (list[str]): 代表该任务的多个备用/镜像URL。 + path_list: 与 url_list 对应的文件保存路径列表。 + limit_async_number: (可选) 最大并发下载数,默认为 5。 + **kwargs: 其他所有传递给 download_file() 方法的参数。 + + 返回: + list[bool]: 对应每个下载任务是否成功。 """ - if n := len(url_list) != len(path_list): - raise UrlPathNumberNotEqual( - f"Url数量与Path数量不对等,Url:{len(url_list)},Path:{len(path_list)}" - ) - if limit_async_number and n > limit_async_number: - m = float(n) / limit_async_number - x = 0 - j = limit_async_number - _split_url_list = [] - _split_path_list = [] - for _ in range(int(m)): - _split_url_list.append(url_list[x:j]) - _split_path_list.append(path_list[x:j]) - x += limit_async_number - j += limit_async_number - if int(m) < m: - _split_url_list.append(url_list[j:]) - _split_path_list.append(path_list[j:]) - else: - _split_url_list = [url_list] - _split_path_list = [path_list] - tasks = [] - result_ = [] - for x, y in zip(_split_url_list, _split_path_list): - tasks.extend( - asyncio.create_task( - cls.download_file( - url, - path, - params=params, - headers=headers, - cookies=cookies, - use_proxy=use_proxy, - timeout=timeout, - proxy=proxy, - **kwargs, - ) + if len(url_list) != len(path_list): + raise ValueError("URL 列表和路径列表的长度必须相等") + + semaphore = asyncio.Semaphore(limit_async_number) + + async def _download_with_semaphore( + urls_for_one_path: str | list[str], path: str | Path + ): + async with semaphore: + return await cls.download_file(urls_for_one_path, path, **kwargs) + + tasks = [ + _download_with_semaphore(url_group, path) + for url_group, path in zip(url_list, path_list) + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + final_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + url_info = ( + url_list[i] + if isinstance(url_list[i], str) + else ", ".join(url_list[i]) ) - for url, path in zip(x, y) - ) - _x = await asyncio.gather(*tasks) - result_ = result_ + list(_x) - tasks.clear() - return result_ + logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result) + final_results.append(False) + else: + # download_file 返回的是 bool,可以直接附加 + final_results.append(cast(bool, result)) + + return final_results @classmethod async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]: + """测试并返回最快的镜像地址。 + + 说明: + 通过并发发送 HEAD 请求来测试每个 URL 的响应时间和可用性,并按响应速度排序。 + + 参数: + url_list: 需要测试的镜像 URL 列表。 + + 返回: + list[str]: 按从快到慢的顺序包含了所有可用的 URL。 + """ assert url_list async def head_mirror(client: type[AsyncHttpx], url: str) -> dict[str, Any]: @@ -478,7 +432,7 @@ class AsyncPlaywright: wait_until: ( Literal["domcontentloaded", "load", "networkidle"] | None ) = "networkidle", - timeout: float | None = None, # noqa: ASYNC109 + timeout: float | None = None, type_: Literal["jpeg", "png"] | None = None, user_agent: str | None = None, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, @@ -522,9 +476,5 @@ class AsyncPlaywright: return None -class UrlPathNumberNotEqual(Exception): - pass - - class BrowserIsNone(Exception): - pass + pass \ No newline at end of file diff --git a/zhenxun/utils/manager/message_manager.py b/zhenxun/utils/manager/message_manager.py index e714c8d8..ee34369d 100644 --- a/zhenxun/utils/manager/message_manager.py +++ b/zhenxun/utils/manager/message_manager.py @@ -22,6 +22,4 @@ class MessageManager: @classmethod def get(cls, uid: str) -> list[str]: - if uid in cls.data: - return cls.data[uid] - return [] + return cls.data[uid] if uid in cls.data else [] diff --git a/zhenxun/utils/manager/priority_manager.py b/zhenxun/utils/manager/priority_manager.py new file mode 100644 index 00000000..1c59635c --- /dev/null +++ b/zhenxun/utils/manager/priority_manager.py @@ -0,0 +1,57 @@ +from collections.abc import Callable +from typing import ClassVar + +import nonebot +from nonebot.utils import is_coroutine_callable + +from zhenxun.services.log import logger +from zhenxun.utils.enum import PriorityLifecycleType +from zhenxun.utils.exception import HookPriorityException + +driver = nonebot.get_driver() + + +class PriorityLifecycle: + _data: ClassVar[dict[PriorityLifecycleType, dict[int, list[Callable]]]] = {} + + @classmethod + def add(cls, hook_type: PriorityLifecycleType, func: Callable, priority: int): + if hook_type not in cls._data: + cls._data[hook_type] = {} + if priority not in cls._data[hook_type]: + cls._data[hook_type][priority] = [] + cls._data[hook_type][priority].append(func) + + @classmethod + def on_startup(cls, *, priority: int): + def wrapper(func): + cls.add(PriorityLifecycleType.STARTUP, func, priority) + return func + + return wrapper + + @classmethod + def on_shutdown(cls, *, priority: int): + def wrapper(func): + cls.add(PriorityLifecycleType.SHUTDOWN, func, priority) + return func + + return wrapper + + +@driver.on_startup +async def _(): + priority_data = PriorityLifecycle._data.get(PriorityLifecycleType.STARTUP) + if not priority_data: + return + priority_list = sorted(priority_data.keys()) + priority = 0 + try: + for priority in priority_list: + for func in priority_data[priority]: + if is_coroutine_callable(func): + await func() + else: + func() + except HookPriorityException as e: + logger.error(f"打断优先级 [{priority}] on_startup 方法. {type(e)}: {e}") diff --git a/zhenxun/utils/platform.py b/zhenxun/utils/platform.py index 6a13293a..f01aec3f 100644 --- a/zhenxun/utils/platform.py +++ b/zhenxun/utils/platform.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import Awaitable, Callable import random -from typing import Literal +from typing import cast import httpx import nonebot @@ -486,15 +486,134 @@ class PlatformUtils: return target +class BroadcastEngine: + def __init__( + self, + message: str | UniMessage, + bot: Bot | list[Bot] | None = None, + bot_id: str | set[str] | None = None, + ignore_group: list[str] | None = None, + check_func: Callable[[Bot, str], Awaitable] | None = None, + log_cmd: str | None = None, + platform: str | None = None, + ): + """广播引擎 + + 参数: + message: 广播消息内容 + bot: 指定bot对象. + bot_id: 指定bot id. + ignore_group: 忽略群聊列表. + check_func: 发送前对群聊检测方法,判断是否发送. + log_cmd: 日志标记. + platform: 指定平台. + + 异常: + ValueError: 没有可用的Bot对象 + """ + if ignore_group is None: + ignore_group = [] + self.message = MessageUtils.build_message(message) + self.ignore_group = ignore_group + self.check_func = check_func + self.log_cmd = log_cmd + self.platform = platform + self.bot_list = [] + self.count = 0 + if bot: + self.bot_list = [bot] if isinstance(bot, Bot) else bot + if isinstance(bot_id, str): + bot_id = set(bot_id) + if bot_id: + for i in bot_id: + try: + self.bot_list.append(nonebot.get_bot(i)) + except KeyError: + logger.warning(f"Bot:{i} 对象未连接或不存在") + if not self.bot_list: + raise ValueError("当前没有可用的Bot对象...", log_cmd) + + async def call_check(self, bot: Bot, group_id: str) -> bool: + """运行发送检测函数 + + 参数: + bot: Bot + group_id: 群组id + + 返回: + bool: 是否发送 + """ + if not self.check_func: + return True + if is_coroutine_callable(self.check_func): + is_run = await self.check_func(bot, group_id) + else: + is_run = self.check_func(bot, group_id) + return cast(bool, is_run) + + async def __send_message(self, bot: Bot, group: GroupConsole): + """群组发送消息 + + 参数: + bot: Bot + group: GroupConsole + """ + key = f"{group.group_id}:{group.channel_id}" + if not await self.call_check(bot, group.group_id): + logger.debug( + "广播方法检测运行方法为 False, 已跳过该群组...", + self.log_cmd, + group_id=group.group_id, + ) + return + if target := PlatformUtils.get_target( + group_id=group.group_id, + channel_id=group.channel_id, + ): + self.ignore_group.append(key) + await MessageUtils.build_message(self.message).send(target, bot) + logger.debug("广播消息发送成功...", self.log_cmd, target=key) + else: + logger.warning("广播消息获取Target失败...", self.log_cmd, target=key) + + async def broadcast(self) -> int: + """广播消息 + + 返回: + int: 成功发送次数 + """ + for bot in self.bot_list: + if self.platform and self.platform != PlatformUtils.get_platform(bot): + continue + group_list, _ = await PlatformUtils.get_group_list(bot) + if not group_list: + continue + for group in group_list: + if ( + group.group_id in self.ignore_group + or group.channel_id in self.ignore_group + ): + continue + try: + await self.__send_message(bot, group) + await asyncio.sleep(random.randint(1, 3)) + self.count += 1 + except Exception as e: + logger.warning( + "广播消息发送失败", self.log_cmd, target=group.group_id, e=e + ) + return self.count + + async def broadcast_group( message: str | UniMessage, bot: Bot | list[Bot] | None = None, bot_id: str | set[str] | None = None, - ignore_group: set[int] | None = None, + ignore_group: list[str] = [], check_func: Callable[[Bot, str], Awaitable] | None = None, log_cmd: str | None = None, - platform: Literal["qq", "dodo", "kaiheila"] | None = None, -): + platform: str | None = None, +) -> int: """获取所有Bot或指定Bot对象广播群聊 参数: @@ -505,81 +624,18 @@ async def broadcast_group( check_func: 发送前对群聊检测方法,判断是否发送. log_cmd: 日志标记. platform: 指定平台 + + 返回: + int: 成功发送次数 """ - if platform and platform not in ["qq", "dodo", "kaiheila"]: - raise ValueError("指定平台不支持") - if not message: - raise ValueError("群聊广播消息不能为空") - bot_dict = nonebot.get_bots() - bot_list: list[Bot] = [] - if bot: - if isinstance(bot, list): - bot_list = bot - else: - bot_list.append(bot) - elif bot_id: - _bot_id_list = bot_id - if isinstance(bot_id, str): - _bot_id_list = [bot_id] - for id_ in _bot_id_list: - if bot_id in bot_dict: - bot_list.append(bot_dict[bot_id]) - else: - logger.warning(f"Bot:{id_} 对象未连接或不存在") - else: - bot_list = list(bot_dict.values()) - _used_group = [] - for _bot in bot_list: - try: - if platform and platform != PlatformUtils.get_platform(_bot): - continue - group_list, _ = await PlatformUtils.get_group_list(_bot) - if group_list: - for group in group_list: - key = f"{group.group_id}:{group.channel_id}" - try: - if ( - ignore_group - and ( - group.group_id in ignore_group - or group.channel_id in ignore_group - ) - ) or key in _used_group: - logger.debug( - "广播方法群组重复, 已跳过...", - log_cmd, - group_id=group.group_id, - ) - continue - is_run = False - if check_func: - if is_coroutine_callable(check_func): - is_run = await check_func(_bot, group.group_id) - else: - is_run = check_func(_bot, group.group_id) - if not is_run: - logger.debug( - "广播方法检测运行方法为 False, 已跳过...", - log_cmd, - group_id=group.group_id, - ) - continue - target = PlatformUtils.get_target( - user_id=None, - group_id=group.group_id, - channel_id=group.channel_id, - ) - if target: - _used_group.append(key) - message_list = message - await MessageUtils.build_message(message_list).send( - target, _bot - ) - logger.debug("发送成功", log_cmd, target=key) - await asyncio.sleep(random.randint(1, 3)) - else: - logger.warning("target为空", log_cmd, target=key) - except Exception as e: - logger.error("发送失败", log_cmd, target=key, e=e) - except Exception as e: - logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e) + if not message.strip(): + raise ValueError("群聊广播消息不能为空...") + return await BroadcastEngine( + message=message, + bot=bot, + bot_id=bot_id, + ignore_group=ignore_group, + check_func=check_func, + log_cmd=log_cmd, + platform=platform, + ).broadcast()