diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index fbaeb280..f2688905 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -16,6 +16,7 @@ from zhenxun.models.sign_user import SignUser from zhenxun.models.user_console import UserConsole from zhenxun.services.log import logger from zhenxun.utils.decorator.shop import shop_register +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.manager.resource_manager import ResourceManager from zhenxun.utils.platform import PlatformUtils @@ -70,7 +71,7 @@ from public.bag_users t1 """ -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): await ResourceManager.init_resources() """签到与用户的数据迁移""" diff --git a/zhenxun/builtin_plugins/about.py b/zhenxun/builtin_plugins/about.py index faa0ba0e..31c77bc7 100644 --- a/zhenxun/builtin_plugins/about.py +++ b/zhenxun/builtin_plugins/about.py @@ -26,6 +26,21 @@ __plugin_meta__ = PluginMetadata( _matcher = on_alconna(Alconna("关于"), priority=5, block=True, rule=to_me()) +QQ_INFO = """ +『绪山真寻Bot』 +版本:{version} +简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 +""".strip() + +INFO = """ +『绪山真寻Bot』 +版本:{version} +简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 +项目地址:https://github.com/zhenxun-org/zhenxun_bot +文档地址:https://zhenxun-org.github.io/zhenxun_bot/ +""".strip() + + @_matcher.handle() async def _(session: Uninfo, arparma: Arparma): ver_file = Path() / "__version__" @@ -35,25 +50,11 @@ async def _(session: Uninfo, arparma: Arparma): if text := await f.read(): version = text.split(":")[-1].strip() if PlatformUtils.is_qbot(session): - info: list[str | Path] = [ - f""" -『绪山真寻Bot』 -版本:{version} -简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 - """.strip() - ] + result: list[str | Path] = [QQ_INFO.format(version=version)] path = DATA_PATH / "about.png" if path.exists(): - info.append(path) + result.append(path) + await MessageUtils.build_message(result).send() # type: ignore else: - info = [ - f""" -『绪山真寻Bot』 -版本:{version} -简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处 -项目地址:https://github.com/HibiKier/zhenxun_bot -文档地址:https://hibikier.github.io/zhenxun_bot/ - """.strip() - ] - await MessageUtils.build_message(info).send() # type: ignore - logger.info("查看关于", arparma.header_result, session=session) + await MessageUtils.build_message(INFO.format(version=version)).send() + logger.info("查看关于", arparma.header_result, session=session) diff --git a/zhenxun/builtin_plugins/admin/ban/_data_source.py b/zhenxun/builtin_plugins/admin/ban/_data_source.py index 2d4dd6dc..ae465bdf 100644 --- a/zhenxun/builtin_plugins/admin/ban/_data_source.py +++ b/zhenxun/builtin_plugins/admin/ban/_data_source.py @@ -8,6 +8,7 @@ from zhenxun.models.level_user import LevelUser from zhenxun.services.log import logger from zhenxun.utils.image_utils import BuildImage, ImageTemplate + async def call_ban(user_id: str): """调用ban @@ -18,7 +19,6 @@ async def call_ban(user_id: str): logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id) - class BanManage: @classmethod async def build_ban_image( diff --git a/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py b/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py index 72862266..6c76f2a7 100644 --- a/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py +++ b/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py @@ -4,7 +4,8 @@ from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo -from zhenxun.utils.enum import BlockType, PluginType +from zhenxun.services.cache import Cache +from zhenxun.utils.enum import BlockType, CacheType, PluginType from zhenxun.utils.exception import GroupInfoNotFound from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle @@ -245,10 +246,12 @@ class PluginManage: 参数: group_id: 群组id """ - await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update( - status=False + group, _ = await GroupConsole.get_or_create( + group_id=group_id, channel_id__isnull=True ) - + group.status = False + await group.save(update_fields=["status"]) + @classmethod async def wake(cls, group_id: str): """醒来 @@ -256,9 +259,11 @@ class PluginManage: 参数: group_id: 群组id """ - await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update( - status=True + group, _ = await GroupConsole.get_or_create( + group_id=group_id, channel_id__isnull=True ) + group.status = True + await group.save(update_fields=["status"]) @classmethod async def block(cls, module: str): @@ -267,7 +272,9 @@ class PluginManage: 参数: module: 模块名 """ - await PluginInfo.filter(module=module).update(status=False) + if plugin := await PluginInfo.get_plugin(module=module): + plugin.status = False + await plugin.save(update_fields=["status"]) @classmethod async def unblock(cls, module: str): @@ -276,7 +283,9 @@ class PluginManage: 参数: module: 模块名 """ - await PluginInfo.filter(module=module).update(status=True) + if plugin := await PluginInfo.get_plugin(module=module): + plugin.status = True + await plugin.save(update_fields=["status"]) @classmethod async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str: diff --git a/zhenxun/builtin_plugins/admin/welcome_message/data_source.py b/zhenxun/builtin_plugins/admin/welcome_message/data_source.py index 2ccb33ee..c8e486ed 100644 --- a/zhenxun/builtin_plugins/admin/welcome_message/data_source.py +++ b/zhenxun/builtin_plugins/admin/welcome_message/data_source.py @@ -14,6 +14,7 @@ from zhenxun.services.log import logger from zhenxun.utils._build_image import BuildImage from zhenxun.utils._image_template import ImageTemplate from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils BASE_PATH = DATA_PATH / "welcome_message" @@ -91,7 +92,7 @@ def migrate(path: Path): json.dump(new_data, f, ensure_ascii=False, indent=4) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) def _(): """数据迁移 diff --git a/zhenxun/builtin_plugins/help/__init__.py b/zhenxun/builtin_plugins/help/__init__.py index 17002f0c..35edf114 100644 --- a/zhenxun/builtin_plugins/help/__init__.py +++ b/zhenxun/builtin_plugins/help/__init__.py @@ -40,7 +40,13 @@ __plugin_meta__ = PluginMetadata( value="zhenxun", help="帮助图片样式 [normal, HTML, zhenxun]", default_value="zhenxun", - ) + ), + RegisterConfig( + key="detail_type", + value="zhenxun", + help="帮助详情图片样式 ['normal', 'zhenxun']", + default_value="zhenxun", + ), ], ).to_dict(), ) diff --git a/zhenxun/builtin_plugins/help/_data_source.py b/zhenxun/builtin_plugins/help/_data_source.py index 86f42536..23e9ec1b 100644 --- a/zhenxun/builtin_plugins/help/_data_source.py +++ b/zhenxun/builtin_plugins/help/_data_source.py @@ -1,13 +1,19 @@ from pathlib import Path import nonebot +from nonebot.plugin import PluginMetadata +from nonebot_plugin_htmlrender import template_to_pic from nonebot_plugin_uninfo import Uninfo -from zhenxun.configs.path_config import IMAGE_PATH +from zhenxun.configs.config import Config +from zhenxun.configs.path_config import IMAGE_PATH, TEMPLATE_PATH +from zhenxun.configs.utils import PluginExtraData from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.statistics import Statistics +from zhenxun.utils._image_template import ImageTemplate from zhenxun.utils.enum import PluginType -from zhenxun.utils.image_utils import BuildImage, ImageTemplate +from zhenxun.utils.image_utils import BuildImage from ._config import ( GROUP_HELP_PATH, @@ -80,9 +86,96 @@ async def get_user_allow_help(user_id: str) -> list[PluginType]: return type_list -async def get_plugin_help( - user_id: str, name: str, is_superuser: bool -) -> str | BuildImage: +async def get_normal_help( + metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool +) -> str | bytes: + """构建默认帮助详情 + + 参数: + metadata: PluginMetadata + extra: PluginExtraData + is_superuser: 是否超级用户帮助 + + 返回: + str | bytes: 返回信息 + """ + items = None + if is_superuser: + if usage := extra.superuser_help: + items = { + "简介": metadata.description, + "用法": usage, + } + else: + items = { + "简介": metadata.description, + "用法": metadata.usage, + } + if items: + return (await ImageTemplate.hl_page(metadata.name, items)).pic2bytes() + return "该功能没有帮助信息" + + +def min_leading_spaces(str_list: list[str]) -> int: + min_spaces = 9999 + + for s in str_list: + leading_spaces = len(s) - len(s.lstrip(" ")) + + if leading_spaces < min_spaces: + min_spaces = leading_spaces + + return min_spaces if min_spaces != 9999 else 0 + + +def split_text(text: str): + split_text = text.split("\n") + min_spaces = min_leading_spaces(split_text) + if min_spaces > 0: + split_text = [s[min_spaces:] for s in split_text] + return [s.replace(" ", " ") for s in split_text] + + +async def get_zhenxun_help( + module: str, metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool +) -> str | bytes: + """构建ZhenXun帮助详情 + + 参数: + module: 模块名 + metadata: PluginMetadata + extra: PluginExtraData + is_superuser: 是否超级用户帮助 + + 返回: + str | bytes: 返回信息 + """ + call_count = await Statistics.filter(plugin_name=module).count() + usage = metadata.usage + if is_superuser: + if not extra.superuser_help: + return "该功能没有超级用户帮助信息" + usage = extra.superuser_help + return await template_to_pic( + template_path=str((TEMPLATE_PATH / "help_detail").absolute()), + template_name="main.html", + templates={ + "title": metadata.name, + "author": extra.author, + "version": extra.version, + "call_count": call_count, + "descriptions": split_text(metadata.description), + "usages": split_text(usage), + }, + pages={ + "viewport": {"width": 824, "height": 590}, + "base_url": f"file://{TEMPLATE_PATH}", + }, + wait=2, + ) + + +async def get_plugin_help(user_id: str, name: str, is_superuser: bool) -> str | bytes: """获取功能的帮助信息 参数: @@ -100,20 +193,12 @@ async def get_plugin_help( if plugin: _plugin = nonebot.get_plugin_by_module_name(plugin.module_path) if _plugin and _plugin.metadata: - items = None - if is_superuser: - extra = _plugin.metadata.extra - if usage := extra.get("superuser_help"): - items = { - "简介": _plugin.metadata.description, - "用法": usage, - } + extra_data = PluginExtraData(**_plugin.metadata.extra) + if Config.get_config("help", "detail_type") == "zhenxun": + return await get_zhenxun_help( + plugin.module, _plugin.metadata, extra_data, is_superuser + ) else: - items = { - "简介": _plugin.metadata.description, - "用法": _plugin.metadata.usage, - } - if items: - return await ImageTemplate.hl_page(plugin.name, items) + return await get_normal_help(_plugin.metadata, extra_data, is_superuser) return "糟糕! 该功能没有帮助喔..." return "没有查找到这个功能噢..." diff --git a/zhenxun/builtin_plugins/help_help.py b/zhenxun/builtin_plugins/help_help.py index fec04a8d..6b5ecce9 100644 --- a/zhenxun/builtin_plugins/help_help.py +++ b/zhenxun/builtin_plugins/help_help.py @@ -21,7 +21,7 @@ from zhenxun.utils.message import MessageUtils __plugin_meta__ = PluginMetadata( name="笨蛋检测", description="功能名称当命令检测", - usage="""被动""".strip(), + usage="""当一些笨蛋直接输入功能名称时,提示笨蛋使用帮助指令查看功能帮助""".strip(), extra=PluginExtraData( author="HibiKier", version="0.1", diff --git a/zhenxun/builtin_plugins/hooks/__init__.py b/zhenxun/builtin_plugins/hooks/__init__.py index 4136ca95..2f8c79de 100644 --- a/zhenxun/builtin_plugins/hooks/__init__.py +++ b/zhenxun/builtin_plugins/hooks/__init__.py @@ -53,7 +53,7 @@ Config.add_plugin_config( "hook", "RECORD_BOT_SENT_MESSAGES", True, - help="记录bot消息校内", + help="记录bot消息发送", default_value=True, type=bool, ) diff --git a/zhenxun/builtin_plugins/hooks/auth_hook.py b/zhenxun/builtin_plugins/hooks/auth_hook.py index 5c83cb75..34ea8018 100644 --- a/zhenxun/builtin_plugins/hooks/auth_hook.py +++ b/zhenxun/builtin_plugins/hooks/auth_hook.py @@ -14,9 +14,7 @@ from .auth_checker import LimitManager, auth # # 权限检测 @run_preprocessor -async def _( - matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg -): +async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg): start_time = time.time() await auth( matcher, diff --git a/zhenxun/builtin_plugins/init/init_plugin.py b/zhenxun/builtin_plugins/init/init_plugin.py index dbeddb54..5bf50409 100644 --- a/zhenxun/builtin_plugins/init/init_plugin.py +++ b/zhenxun/builtin_plugins/init/init_plugin.py @@ -20,6 +20,7 @@ from zhenxun.utils.enum import ( PluginLimitType, PluginType, ) +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from .manager import manager @@ -95,7 +96,7 @@ async def _handle_setting( ) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """ 初始化插件数据配置 diff --git a/zhenxun/builtin_plugins/init/init_task.py b/zhenxun/builtin_plugins/init/init_task.py index cead7d72..b9bab56d 100644 --- a/zhenxun/builtin_plugins/init/init_task.py +++ b/zhenxun/builtin_plugins/init/init_task.py @@ -10,6 +10,7 @@ from zhenxun.models.group_console import GroupConsole from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.manager.priority_manager import PriorityLifecycle driver: Driver = nonebot.get_driver() @@ -132,7 +133,7 @@ async def create_schedule(task: Task): logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """ 初始化插件数据配置 diff --git a/zhenxun/builtin_plugins/mahiro_bank/__init__.py b/zhenxun/builtin_plugins/mahiro_bank/__init__.py index 2f6fbf1f..8e82cf08 100644 --- a/zhenxun/builtin_plugins/mahiro_bank/__init__.py +++ b/zhenxun/builtin_plugins/mahiro_bank/__init__.py @@ -95,13 +95,6 @@ _matcher = on_alconna( block=True, ) -_matcher.shortcut( - r"1111", - command="mahiro-bank", - arguments=["test"], - prefix=True, -) - _matcher.shortcut( r"存款\s*(?P\d+)?", command="mahiro-bank", diff --git a/zhenxun/builtin_plugins/mahiro_bank/data_source.py b/zhenxun/builtin_plugins/mahiro_bank/data_source.py index dc64fa16..b717e9a4 100644 --- a/zhenxun/builtin_plugins/mahiro_bank/data_source.py +++ b/zhenxun/builtin_plugins/mahiro_bank/data_source.py @@ -241,7 +241,7 @@ class BankManager: @classmethod async def get_bank_info(cls) -> bytes: now = datetime.now() - now_start = datetime.now() - timedelta( + now_start = now - timedelta( hours=now.hour, minutes=now.minute, seconds=now.second ) ( @@ -255,7 +255,9 @@ class BankManager: MahiroBank.annotate( amount_sum=Sum("amount"), user_count=Count("id") ).values("amount_sum", "user_count"), - MahiroBankLog.filter(create_time__gt=now_start).count(), + MahiroBankLog.filter( + create_time__gt=now_start, handle_type=BankHandleType.DEPOSIT + ).count(), MahiroBankLog.filter(handle_type=BankHandleType.INTEREST) .annotate(amount_sum=Sum("amount")) .values("amount_sum"), diff --git a/zhenxun/builtin_plugins/nickname.py b/zhenxun/builtin_plugins/nickname.py index 5cbc519e..7dd9a697 100644 --- a/zhenxun/builtin_plugins/nickname.py +++ b/zhenxun/builtin_plugins/nickname.py @@ -1,17 +1,12 @@ import random +from typing import Any +from nonebot import on_regex from nonebot.adapters import Bot +from nonebot.params import Depends, RegexGroup from nonebot.plugin import PluginMetadata from nonebot.rule import to_me -from nonebot_plugin_alconna import ( - Alconna, - Args, - Arparma, - CommandMeta, - Option, - on_alconna, - store_true, -) +from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.config import BotConfig, Config @@ -59,22 +54,15 @@ __plugin_meta__ = PluginMetadata( ).to_dict(), ) -_nickname_matcher = on_alconna( - Alconna( - "re:(?:以后)?(?:叫我|请叫我|称呼我)", - Args["name?", str], - meta=CommandMeta(compact=True), - ), +_nickname_matcher = on_regex( + "(?:以后)?(?:叫我|请叫我|称呼我)(.*)", rule=to_me(), priority=5, block=True, ) -_global_nickname_matcher = on_alconna( - Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)), - rule=to_me(), - priority=5, - block=True, +_global_nickname_matcher = on_regex( + "设置全局昵称(.*)", rule=to_me(), priority=5, block=True ) _matcher = on_alconna( @@ -129,32 +117,34 @@ CANCEL = [ ] -async def CheckNickname( - bot: Bot, - session: Uninfo, - params: Arparma, -): +def CheckNickname(): """ 检查名称是否合法 """ - black_word = Config.get_config("nickname", "BLACK_WORD") - name = params.query("name") - logger.debug(f"昵称检查: {name}", "昵称设置", session=session) - if not name: - await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish( - at_sender=True - ) - if session.user.id in bot.config.superusers: - logger.debug( - f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session - ) - else: + + async def dependency( + bot: Bot, + session: Uninfo, + reg_group: tuple[Any, ...] = RegexGroup(), + ): + black_word = Config.get_config("nickname", "BLACK_WORD") + (name,) = reg_group + logger.debug(f"昵称检查: {name}", "昵称设置", session=session) + if not name: + await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish( + at_sender=True + ) + if session.user.id in bot.config.superusers: + logger.debug( + f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session + ) + return if len(name) > 20: await MessageUtils.build_message("昵称可不能超过20个字!").finish( at_sender=True ) if name in bot.config.nickname: - await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish( + await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish( at_sender=True ) if black_word: @@ -172,17 +162,17 @@ async def CheckNickname( await MessageUtils.build_message( f"字符 [{word}] 为禁止字符!" ).finish(at_sender=True) - return name + + return Depends(dependency) -@_nickname_matcher.handle() +@_nickname_matcher.handle(parameterless=[CheckNickname()]) async def _( - bot: Bot, session: Uninfo, - name_: Arparma, uname: str = UserName(), + reg_group: tuple[Any, ...] = RegexGroup(), ): - name = await CheckNickname(bot, session, name_) + (name,) = reg_group if len(name) < 5 and random.random() < 0.3: name = "~".join(name) group_id = None @@ -210,14 +200,13 @@ async def _( ) -@_global_nickname_matcher.handle() +@_global_nickname_matcher.handle(parameterless=[CheckNickname()]) async def _( - bot: Bot, session: Uninfo, - name_: Arparma, nickname: str = UserName(), + reg_group: tuple[Any, ...] = RegexGroup(), ): - name = await CheckNickname(bot, session, name_) + (name,) = reg_group await FriendUser.set_user_nickname( session.user.id, name, @@ -238,14 +227,15 @@ async def _(session: Uninfo, uname: str = UserName()): group_id = session.group.parent.id if session.group.parent else session.group.id if group_id: nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id) + card = uname else: nickname = await FriendUser.get_user_nickname(session.user.id) + card = uname if nickname: await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish( reply_to=True ) else: - card = uname await MessageUtils.build_message( random.choice( [ diff --git a/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py b/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py index 9e8d7ea2..d5be67d1 100644 --- a/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py +++ b/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py @@ -147,7 +147,7 @@ class GroupManager: e=e, ) raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e - # await GroupConsole.filter(group_id=group_id).delete() + #await GroupConsole.filter(group_id=group_id).delete() raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...") else: await cls.__handle_add_group(bot, group_id, group) diff --git a/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py b/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py index 4435e880..800a2363 100644 --- a/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py +++ b/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py @@ -31,4 +31,4 @@ async def _(session: Uninfo): await FriendUser.create( user_id=session.user.id, platform=PlatformUtils.get_platform(session) ) - logger.info("添加当前好友用户信息", "", session=session) + logger.info("添加当前好友用户信息", "", session=session) \ No newline at end of file diff --git a/zhenxun/builtin_plugins/plugin_store/config.py b/zhenxun/builtin_plugins/plugin_store/config.py index 7512d49e..dd48a5c7 100644 --- a/zhenxun/builtin_plugins/plugin_store/config.py +++ b/zhenxun/builtin_plugins/plugin_store/config.py @@ -10,13 +10,4 @@ DEFAULT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins/tree/ma EXTRA_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/tree/index" """插件库索引github仓库地址""" -GITEE_RAW_URL = "https://gitee.com/two_Dimension/zhenxun_bot_plugins/raw/main" -"""GITEE仓库文件内容""" - -GITEE_CONTENTS_URL = ( - "https://gitee.com/api/v5/repos/two_Dimension/zhenxun_bot_plugins/contents" -) -"""GITEE仓库文件列表获取""" - - LOG_COMMAND = "插件商店" diff --git a/zhenxun/builtin_plugins/scheduler_admin/__init__.py b/zhenxun/builtin_plugins/scheduler_admin/__init__.py deleted file mode 100644 index adaaa621..00000000 --- a/zhenxun/builtin_plugins/scheduler_admin/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -from nonebot.plugin import PluginMetadata - -from zhenxun.configs.utils import PluginExtraData -from zhenxun.utils.enum import PluginType - -from . import command # noqa: F401 - -__plugin_meta__ = PluginMetadata( - name="定时任务管理", - description="查看和管理由 SchedulerManager 控制的定时任务。", - usage=""" -📋 定时任务管理 - 支持群聊和私聊操作 - -🔍 查看任务: - 定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>] - • 群聊中: 查看本群任务 - • 私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER) - -📊 任务状态: - 定时任务 状态 <任务ID> 或 任务状态 <任务ID> - • 查看单个任务的详细信息和状态 - -⚙️ 任务管理 (SUPERUSER): - 定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>] - 定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all - 定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all - 定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all - 定时任务 执行 <任务ID> - 定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>] - -📝 时间选项 (三选一): - --cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *" - --interval <时间间隔> # 例: --interval 30m, 2h, 10s - --date "" # 例: --date "2024-01-01 08:00:00" - --daily "" # 例: --daily "08:30" - -📚 其他功能: - 定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER) - -🏷️ 别名支持: - 查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消 - 暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info - 更新: update, modify, 修改 | 插件列表: plugins - """.strip(), - extra=PluginExtraData( - author="HibiKier", - version="0.1.2", - plugin_type=PluginType.SUPERUSER, - is_show=False, - ).to_dict(), -) diff --git a/zhenxun/builtin_plugins/scheduler_admin/command.py b/zhenxun/builtin_plugins/scheduler_admin/command.py deleted file mode 100644 index 08a085fb..00000000 --- a/zhenxun/builtin_plugins/scheduler_admin/command.py +++ /dev/null @@ -1,836 +0,0 @@ -import asyncio -from datetime import datetime -import re - -from nonebot.adapters import Event -from nonebot.adapters.onebot.v11 import Bot -from nonebot.params import Depends -from nonebot.permission import SUPERUSER -from nonebot_plugin_alconna import ( - Alconna, - AlconnaMatch, - Args, - Arparma, - Match, - Option, - Query, - Subcommand, - on_alconna, -) -from pydantic import BaseModel, ValidationError - -from zhenxun.utils._image_template import ImageTemplate -from zhenxun.utils.manager.schedule_manager import scheduler_manager - - -def _get_type_name(annotation) -> str: - """获取类型注解的名称""" - if hasattr(annotation, "__name__"): - return annotation.__name__ - elif hasattr(annotation, "_name"): - return annotation._name - else: - return str(annotation) - - -from zhenxun.utils.message import MessageUtils -from zhenxun.utils.rules import admin_check - - -def _format_trigger(schedule_status: dict) -> str: - """将触发器配置格式化为人类可读的字符串""" - trigger_type = schedule_status["trigger_type"] - config = schedule_status["trigger_config"] - - if trigger_type == "cron": - minute = config.get("minute", "*") - hour = config.get("hour", "*") - day = config.get("day", "*") - month = config.get("month", "*") - day_of_week = config.get("day_of_week", "*") - - if day == "*" and month == "*" and day_of_week == "*": - formatted_hour = hour if hour == "*" else f"{int(hour):02d}" - formatted_minute = minute if minute == "*" else f"{int(minute):02d}" - return f"每天 {formatted_hour}:{formatted_minute}" - else: - return f"Cron: {minute} {hour} {day} {month} {day_of_week}" - elif trigger_type == "interval": - seconds = config.get("seconds", 0) - minutes = config.get("minutes", 0) - hours = config.get("hours", 0) - days = config.get("days", 0) - if days: - trigger_str = f"每 {days} 天" - elif hours: - trigger_str = f"每 {hours} 小时" - elif minutes: - trigger_str = f"每 {minutes} 分钟" - else: - trigger_str = f"每 {seconds} 秒" - elif trigger_type == "date": - run_date = config.get("run_date", "未知时间") - trigger_str = f"在 {run_date}" - else: - trigger_str = f"{trigger_type}: {config}" - - return trigger_str - - -def _format_params(schedule_status: dict) -> str: - """将任务参数格式化为人类可读的字符串""" - if kwargs := schedule_status.get("job_kwargs"): - kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items()) - return kwargs_str - return "-" - - -def _parse_interval(interval_str: str) -> dict: - """增强版解析器,支持 d(天)""" - match = re.match(r"(\d+)([smhd])", interval_str.lower()) - if not match: - raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。") - - value, unit = int(match.group(1)), match.group(2) - if unit == "s": - return {"seconds": value} - if unit == "m": - return {"minutes": value} - if unit == "h": - return {"hours": value} - if unit == "d": - return {"days": value} - return {} - - -def _parse_daily_time(time_str: str) -> dict: - """解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置""" - if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str): - hour, minute, second = match.groups() - hour, minute = int(hour), int(minute) - - if not (0 <= hour <= 23 and 0 <= minute <= 59): - raise ValueError("小时或分钟数值超出范围。") - - cron_config = { - "minute": str(minute), - "hour": str(hour), - "day": "*", - "month": "*", - "day_of_week": "*", - } - if second is not None: - if not (0 <= int(second) <= 59): - raise ValueError("秒数值超出范围。") - cron_config["second"] = str(second) - - return cron_config - else: - raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。") - - -async def GetBotId( - bot: Bot, - bot_id_match: Match[str] = AlconnaMatch("bot_id"), -) -> str: - """获取要操作的Bot ID""" - if bot_id_match.available: - return bot_id_match.result - return bot.self_id - - -class ScheduleTarget: - """定时任务操作目标的基类""" - - pass - - -class TargetByID(ScheduleTarget): - """按任务ID操作""" - - def __init__(self, id: int): - self.id = id - - -class TargetByPlugin(ScheduleTarget): - """按插件名操作""" - - def __init__( - self, plugin: str, group_id: str | None = None, all_groups: bool = False - ): - self.plugin = plugin - self.group_id = group_id - self.all_groups = all_groups - - -class TargetAll(ScheduleTarget): - """操作所有任务""" - - def __init__(self, for_group: str | None = None): - self.for_group = for_group - - -TargetScope = TargetByID | TargetByPlugin | TargetAll | None - - -def create_target_parser(subcommand_name: str): - """ - 创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。 - """ - - async def dependency( - event: Event, - schedule_id: Match[int] = AlconnaMatch("schedule_id"), - plugin_name: Match[str] = AlconnaMatch("plugin_name"), - group_id: Match[str] = AlconnaMatch("group_id"), - all_enabled: Query[bool] = Query(f"{subcommand_name}.all"), - ) -> TargetScope: - if schedule_id.available: - return TargetByID(schedule_id.result) - - if plugin_name.available: - p_name = plugin_name.result - if all_enabled.available: - return TargetByPlugin(plugin=p_name, all_groups=True) - elif group_id.available: - gid = group_id.result - if gid.lower() == "all": - return TargetByPlugin(plugin=p_name, all_groups=True) - return TargetByPlugin(plugin=p_name, group_id=gid) - else: - current_group_id = getattr(event, "group_id", None) - if current_group_id: - return TargetByPlugin(plugin=p_name, group_id=str(current_group_id)) - else: - await schedule_cmd.finish( - "私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。" - ) - - if all_enabled.available: - return TargetAll(for_group=group_id.result if group_id.available else None) - - return None - - return dependency - - -schedule_cmd = on_alconna( - Alconna( - "定时任务", - Subcommand( - "查看", - Option("-g", Args["target_group_id", str]), - Option("-all", help_text="查看所有群聊 (SUPERUSER)"), - Option("-p", Args["plugin_name", str], help_text="按插件名筛选"), - Option("--page", Args["page", int, 1], help_text="指定页码"), - alias=["ls", "list"], - help_text="查看定时任务", - ), - Subcommand( - "设置", - Args["plugin_name", str], - Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"), - Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"), - Option("--date", Args["date_expr", str], help_text="设置特定执行日期"), - Option( - "--daily", - Args["daily_expr", str], - help_text="设置每天执行的时间 (如 08:20)", - ), - Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"), - Option("-all", help_text="对所有群生效 (等同于 -g all)"), - Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"), - Option( - "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" - ), - alias=["add", "开启"], - help_text="设置/开启一个定时任务", - ), - Subcommand( - "删除", - Args["schedule_id?", int], - Option("-p", Args["plugin_name", str], help_text="指定插件名"), - Option("-g", Args["group_id", str], help_text="指定群组ID"), - Option("-all", help_text="对所有群生效"), - Option( - "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" - ), - alias=["del", "rm", "remove", "关闭", "取消"], - help_text="删除一个或多个定时任务", - ), - Subcommand( - "暂停", - Args["schedule_id?", int], - Option("-all", help_text="对当前群所有任务生效"), - Option("-p", Args["plugin_name", str], help_text="指定插件名"), - Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"), - Option( - "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" - ), - alias=["pause"], - help_text="暂停一个或多个定时任务", - ), - Subcommand( - "恢复", - Args["schedule_id?", int], - Option("-all", help_text="对当前群所有任务生效"), - Option("-p", Args["plugin_name", str], help_text="指定插件名"), - Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"), - Option( - "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" - ), - alias=["resume"], - help_text="恢复一个或多个定时任务", - ), - Subcommand( - "执行", - Args["schedule_id", int], - alias=["trigger", "run"], - help_text="立即执行一次任务", - ), - Subcommand( - "更新", - Args["schedule_id", int], - Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"), - Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"), - Option("--date", Args["date_expr", str], help_text="设置特定执行日期"), - Option( - "--daily", - Args["daily_expr", str], - help_text="更新每天执行的时间 (如 08:20)", - ), - Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"), - alias=["update", "modify", "修改"], - help_text="更新任务配置", - ), - Subcommand( - "状态", - Args["schedule_id", int], - alias=["status", "info"], - help_text="查看单个任务的详细状态", - ), - Subcommand( - "插件列表", - alias=["plugins"], - help_text="列出所有可用的插件", - ), - ), - priority=5, - block=True, - rule=admin_check(1), -) - -schedule_cmd.shortcut( - "任务状态", - command="定时任务", - arguments=["状态", "{%0}"], - prefix=True, -) - - -@schedule_cmd.handle() -async def _handle_time_options_mutex(arp: Arparma): - time_options = ["cron", "interval", "date", "daily"] - provided_options = [opt for opt in time_options if arp.query(opt) is not None] - if len(provided_options) > 1: - await schedule_cmd.finish( - f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。" - ) - - -@schedule_cmd.assign("查看") -async def _( - bot: Bot, - event: Event, - target_group_id: Match[str] = AlconnaMatch("target_group_id"), - all_groups: Query[bool] = Query("查看.all"), - plugin_name: Match[str] = AlconnaMatch("plugin_name"), - page: Match[int] = AlconnaMatch("page"), -): - is_superuser = await SUPERUSER(bot, event) - schedules = [] - title = "" - - current_group_id = getattr(event, "group_id", None) - if not (all_groups.available or target_group_id.available) and not current_group_id: - await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。") - - if all_groups.available: - if not is_superuser: - await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。") - schedules = await scheduler_manager.get_all_schedules() - title = "所有群组的定时任务" - elif target_group_id.available: - if not is_superuser: - await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。") - gid = target_group_id.result - schedules = [ - s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid - ] - title = f"群 {gid} 的定时任务" - else: - gid = str(current_group_id) - schedules = [ - s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid - ] - title = "本群的定时任务" - - if plugin_name.available: - schedules = [s for s in schedules if s.plugin_name == plugin_name.result] - title += f" [插件: {plugin_name.result}]" - - if not schedules: - await schedule_cmd.finish("没有找到任何相关的定时任务。") - - page_size = 15 - current_page = page.result - total_items = len(schedules) - total_pages = (total_items + page_size - 1) // page_size - start_index = (current_page - 1) * page_size - end_index = start_index + page_size - paginated_schedules = schedules[start_index:end_index] - - if not paginated_schedules: - await schedule_cmd.finish("这一页没有内容了哦~") - - status_tasks = [ - scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules - ] - all_statuses = await asyncio.gather(*status_tasks) - data_list = [ - [ - s["id"], - s["plugin_name"], - s.get("bot_id") or "N/A", - s["group_id"] or "全局", - s["next_run_time"], - _format_trigger(s), - _format_params(s), - "✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停", - ] - for s in all_statuses - if s - ] - - if not data_list: - await schedule_cmd.finish("没有找到任何相关的定时任务。") - - img = await ImageTemplate.table_page( - head_text=title, - tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务", - column_name=[ - "ID", - "插件", - "Bot ID", - "群组/目标", - "下次运行", - "触发规则", - "参数", - "状态", - ], - data_list=data_list, - column_space=20, - ) - await MessageUtils.build_message(img).send(reply_to=True) - - -@schedule_cmd.assign("设置") -async def _( - event: Event, - plugin_name: str, - cron_expr: str | None = None, - interval_expr: str | None = None, - date_expr: str | None = None, - daily_expr: str | None = None, - group_id: str | None = None, - kwargs_str: str | None = None, - all_enabled: Query[bool] = Query("设置.all"), - bot_id_to_operate: str = Depends(GetBotId), -): - if plugin_name not in scheduler_manager._registered_tasks: - await schedule_cmd.finish( - f"插件 '{plugin_name}' 没有注册可用的定时任务。\n" - f"可用插件: {list(scheduler_manager._registered_tasks.keys())}" - ) - - trigger_type = "" - trigger_config = {} - - try: - if cron_expr: - trigger_type = "cron" - parts = cron_expr.split() - if len(parts) != 5: - raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)") - cron_keys = ["minute", "hour", "day", "month", "day_of_week"] - trigger_config = dict(zip(cron_keys, parts)) - elif interval_expr: - trigger_type = "interval" - trigger_config = _parse_interval(interval_expr) - elif date_expr: - trigger_type = "date" - trigger_config = {"run_date": datetime.fromisoformat(date_expr)} - elif daily_expr: - trigger_type = "cron" - trigger_config = _parse_daily_time(daily_expr) - else: - await schedule_cmd.finish( - "必须提供一种时间选项: --cron, --interval, --date, 或 --daily。" - ) - except ValueError as e: - await schedule_cmd.finish(f"时间参数解析错误: {e}") - - job_kwargs = {} - if kwargs_str: - task_meta = scheduler_manager._registered_tasks[plugin_name] - params_model = task_meta.get("model") - if not params_model: - await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。") - - if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): - await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。") - - raw_kwargs = {} - try: - for item in kwargs_str.split(","): - key, value = item.strip().split("=", 1) - raw_kwargs[key.strip()] = value - except Exception as e: - await schedule_cmd.finish( - f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}" - ) - - try: - model_validate = getattr(params_model, "model_validate", None) - if not model_validate: - await schedule_cmd.finish( - f"插件 '{plugin_name}' 的参数模型不支持验证。" - ) - return - - validated_model = model_validate(raw_kwargs) - - model_dump = getattr(validated_model, "model_dump", None) - if not model_dump: - await schedule_cmd.finish( - f"插件 '{plugin_name}' 的参数模型不支持导出。" - ) - return - - job_kwargs = model_dump() - except ValidationError as e: - errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] - error_str = "\n".join(errors) - await schedule_cmd.finish( - f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}" - ) - return - - target_group_id: str | None - current_group_id = getattr(event, "group_id", None) - - if group_id and group_id.lower() == "all": - target_group_id = "__ALL_GROUPS__" - elif all_enabled.available: - target_group_id = "__ALL_GROUPS__" - elif group_id: - target_group_id = group_id - elif current_group_id: - target_group_id = str(current_group_id) - else: - await schedule_cmd.finish( - "私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。" - ) - return - - success, msg = await scheduler_manager.add_schedule( - plugin_name, - target_group_id, - trigger_type, - trigger_config, - job_kwargs, - bot_id=bot_id_to_operate, - ) - - if target_group_id == "__ALL_GROUPS__": - target_desc = f"所有群组 (Bot: {bot_id_to_operate})" - elif target_group_id is None: - target_desc = "全局" - else: - target_desc = f"群组 {target_group_id}" - - if success: - await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}") - else: - await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}") - - -@schedule_cmd.assign("删除") -async def _( - target: TargetScope = Depends(create_target_parser("删除")), - bot_id_to_operate: str = Depends(GetBotId), -): - if isinstance(target, TargetByID): - _, message = await scheduler_manager.remove_schedule_by_id(target.id) - await schedule_cmd.finish(message) - - elif isinstance(target, TargetByPlugin): - p_name = target.plugin - if p_name not in scheduler_manager.get_registered_plugins(): - await schedule_cmd.finish(f"未找到插件 '{p_name}'。") - - if target.all_groups: - removed_count = await scheduler_manager.remove_schedule_for_all( - p_name, bot_id=bot_id_to_operate - ) - message = ( - f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。" - if removed_count > 0 - else f"没有找到插件 '{p_name}' 的定时任务。" - ) - await schedule_cmd.finish(message) - else: - _, message = await scheduler_manager.remove_schedule( - p_name, target.group_id, bot_id=bot_id_to_operate - ) - await schedule_cmd.finish(message) - - elif isinstance(target, TargetAll): - if target.for_group: - _, message = await scheduler_manager.remove_schedules_by_group( - target.for_group - ) - await schedule_cmd.finish(message) - else: - _, message = await scheduler_manager.remove_all_schedules() - await schedule_cmd.finish(message) - - else: - await schedule_cmd.finish( - "删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。" - ) - - -@schedule_cmd.assign("暂停") -async def _( - target: TargetScope = Depends(create_target_parser("暂停")), - bot_id_to_operate: str = Depends(GetBotId), -): - if isinstance(target, TargetByID): - _, message = await scheduler_manager.pause_schedule(target.id) - await schedule_cmd.finish(message) - - elif isinstance(target, TargetByPlugin): - p_name = target.plugin - if p_name not in scheduler_manager.get_registered_plugins(): - await schedule_cmd.finish(f"未找到插件 '{p_name}'。") - - if target.all_groups: - _, message = await scheduler_manager.pause_schedules_by_plugin(p_name) - await schedule_cmd.finish(message) - else: - _, message = await scheduler_manager.pause_schedule_by_plugin_group( - p_name, target.group_id, bot_id=bot_id_to_operate - ) - await schedule_cmd.finish(message) - - elif isinstance(target, TargetAll): - if target.for_group: - _, message = await scheduler_manager.pause_schedules_by_group( - target.for_group - ) - await schedule_cmd.finish(message) - else: - _, message = await scheduler_manager.pause_all_schedules() - await schedule_cmd.finish(message) - - else: - await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。") - - -@schedule_cmd.assign("恢复") -async def _( - target: TargetScope = Depends(create_target_parser("恢复")), - bot_id_to_operate: str = Depends(GetBotId), -): - if isinstance(target, TargetByID): - _, message = await scheduler_manager.resume_schedule(target.id) - await schedule_cmd.finish(message) - - elif isinstance(target, TargetByPlugin): - p_name = target.plugin - if p_name not in scheduler_manager.get_registered_plugins(): - await schedule_cmd.finish(f"未找到插件 '{p_name}'。") - - if target.all_groups: - _, message = await scheduler_manager.resume_schedules_by_plugin(p_name) - await schedule_cmd.finish(message) - else: - _, message = await scheduler_manager.resume_schedule_by_plugin_group( - p_name, target.group_id, bot_id=bot_id_to_operate - ) - await schedule_cmd.finish(message) - - elif isinstance(target, TargetAll): - if target.for_group: - _, message = await scheduler_manager.resume_schedules_by_group( - target.for_group - ) - await schedule_cmd.finish(message) - else: - _, message = await scheduler_manager.resume_all_schedules() - await schedule_cmd.finish(message) - - else: - await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。") - - -@schedule_cmd.assign("执行") -async def _(schedule_id: int): - _, message = await scheduler_manager.trigger_now(schedule_id) - await schedule_cmd.finish(message) - - -@schedule_cmd.assign("更新") -async def _( - schedule_id: int, - cron_expr: str | None = None, - interval_expr: str | None = None, - date_expr: str | None = None, - daily_expr: str | None = None, - kwargs_str: str | None = None, -): - if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]): - await schedule_cmd.finish( - "请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)" - ) - - trigger_config = None - trigger_type = None - try: - if cron_expr: - trigger_type = "cron" - parts = cron_expr.split() - if len(parts) != 5: - raise ValueError("Cron 表达式必须有5个部分") - cron_keys = ["minute", "hour", "day", "month", "day_of_week"] - trigger_config = dict(zip(cron_keys, parts)) - elif interval_expr: - trigger_type = "interval" - trigger_config = _parse_interval(interval_expr) - elif date_expr: - trigger_type = "date" - trigger_config = {"run_date": datetime.fromisoformat(date_expr)} - elif daily_expr: - trigger_type = "cron" - trigger_config = _parse_daily_time(daily_expr) - except ValueError as e: - await schedule_cmd.finish(f"时间参数解析错误: {e}") - - job_kwargs = None - if kwargs_str: - schedule = await scheduler_manager.get_schedule_by_id(schedule_id) - if not schedule: - await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。") - - task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name) - if not task_meta or not (params_model := task_meta.get("model")): - await schedule_cmd.finish( - f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。" - ) - - if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): - await schedule_cmd.finish( - f"插件 '{schedule.plugin_name}' 的参数模型配置错误。" - ) - - raw_kwargs = {} - try: - for item in kwargs_str.split(","): - key, value = item.strip().split("=", 1) - raw_kwargs[key.strip()] = value - except Exception as e: - await schedule_cmd.finish( - f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}" - ) - - try: - model_validate = getattr(params_model, "model_validate", None) - if not model_validate: - await schedule_cmd.finish( - f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。" - ) - return - - validated_model = model_validate(raw_kwargs) - - model_dump = getattr(validated_model, "model_dump", None) - if not model_dump: - await schedule_cmd.finish( - f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。" - ) - return - - job_kwargs = model_dump(exclude_unset=True) - except ValidationError as e: - errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] - error_str = "\n".join(errors) - await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}") - return - - _, message = await scheduler_manager.update_schedule( - schedule_id, trigger_type, trigger_config, job_kwargs - ) - await schedule_cmd.finish(message) - - -@schedule_cmd.assign("插件列表") -async def _(): - registered_plugins = scheduler_manager.get_registered_plugins() - if not registered_plugins: - await schedule_cmd.finish("当前没有已注册的定时任务插件。") - - message_parts = ["📋 已注册的定时任务插件:"] - for i, plugin_name in enumerate(registered_plugins, 1): - task_meta = scheduler_manager._registered_tasks[plugin_name] - params_model = task_meta.get("model") - - if not params_model: - message_parts.append(f"{i}. {plugin_name} - 无参数") - continue - - if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): - message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误") - continue - - model_fields = getattr(params_model, "model_fields", None) - if model_fields: - param_info = ", ".join( - f"{field_name}({_get_type_name(field_info.annotation)})" - for field_name, field_info in model_fields.items() - ) - message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}") - else: - message_parts.append(f"{i}. {plugin_name} - 无参数") - - await schedule_cmd.finish("\n".join(message_parts)) - - -@schedule_cmd.assign("状态") -async def _(schedule_id: int): - status = await scheduler_manager.get_schedule_status(schedule_id) - if not status: - await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。") - - info_lines = [ - f"📋 定时任务详细信息 (ID: {schedule_id})", - "--------------------", - f"▫️ 插件: {status['plugin_name']}", - f"▫️ Bot ID: {status.get('bot_id') or '默认'}", - f"▫️ 目标: {status['group_id'] or '全局'}", - f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}", - f"▫️ 下次运行: {status['next_run_time']}", - f"▫️ 触发规则: {_format_trigger(status)}", - f"▫️ 任务参数: {_format_params(status)}", - ] - await schedule_cmd.finish("\n".join(info_lines)) diff --git a/zhenxun/builtin_plugins/scripts.py b/zhenxun/builtin_plugins/scripts.py index 27705301..b5fca300 100644 --- a/zhenxun/builtin_plugins/scripts.py +++ b/zhenxun/builtin_plugins/scripts.py @@ -1,67 +1,8 @@ -from asyncio.exceptions import TimeoutError - -import aiofiles -import nonebot -from nonebot.drivers import Driver -from nonebot_plugin_apscheduler import scheduler -import ujson as json - -from zhenxun.configs.path_config import TEXT_PATH from zhenxun.models.group_console import GroupConsole -from zhenxun.services.log import logger -from zhenxun.utils.http_utils import AsyncHttpx - -driver: Driver = nonebot.get_driver() +from zhenxun.utils.manager.priority_manager import PriorityLifecycle -@driver.on_startup -async def update_city(): - """ - 部分插件需要中国省份城市 - 这里直接更新,避免插件内代码重复 - """ - china_city = TEXT_PATH / "china_city.json" - if not china_city.exists(): - data = {} - try: - logger.debug("开始更新城市列表...") - res = await AsyncHttpx.get( - "http://www.weather.com.cn/data/city3jdata/china.html", timeout=5 - ) - res.encoding = "utf8" - provinces_data = json.loads(res.text) - for province in provinces_data.keys(): - data[provinces_data[province]] = [] - res = await AsyncHttpx.get( - f"http://www.weather.com.cn/data/city3jdata/provshi/{province}.html", - timeout=5, - ) - res.encoding = "utf8" - city_data = json.loads(res.text) - for city in city_data.keys(): - data[provinces_data[province]].append(city_data[city]) - async with aiofiles.open(china_city, "w", encoding="utf8") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - logger.info("自动更新城市列表完成.....") - except TimeoutError as e: - logger.warning("自动更新城市列表超时...", e=e) - except ValueError as e: - logger.warning("自动城市列表失败.....", e=e) - except Exception as e: - logger.error("自动城市列表未知错误", e=e) - - -# 自动更新城市列表 -@scheduler.scheduled_job( - "cron", - hour=6, - minute=1, -) -async def _(): - await update_city() - - -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): """开启/禁用插件格式修改""" _, is_create = await GroupConsole.get_or_create(group_id=133133133) diff --git a/zhenxun/builtin_plugins/shop/__init__.py b/zhenxun/builtin_plugins/shop/__init__.py index 89282d63..120d2198 100644 --- a/zhenxun/builtin_plugins/shop/__init__.py +++ b/zhenxun/builtin_plugins/shop/__init__.py @@ -5,7 +5,9 @@ from nonebot_plugin_alconna import ( AlconnaQuery, Args, Arparma, + At, Match, + MultiVar, Option, Query, Subcommand, @@ -33,6 +35,7 @@ __plugin_meta__ = PluginMetadata( usage=""" 商品操作 指令: + 商店 我的金币 我的道具 使用道具 [名称/Id] @@ -46,6 +49,7 @@ __plugin_meta__ = PluginMetadata( plugin_type=PluginType.NORMAL, menu_type="商店", commands=[ + Command(command="商店"), Command(command="我的金币"), Command(command="我的道具"), Command(command="购买道具"), @@ -74,13 +78,21 @@ _matcher = on_alconna( Subcommand("my-cost", help_text="我的金币"), Subcommand("my-props", help_text="我的道具"), Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"), - Subcommand("use", Args["name?", str]["num?", int], help_text="使用道具"), Subcommand("gold-list", Args["num?", int], help_text="金币排行"), ), priority=5, block=True, ) +_use_matcher = on_alconna( + Alconna( + "使用道具", + Args["name?", str]["num?", int]["at_users?", MultiVar(At)], + ), + priority=5, + block=True, +) + _matcher.shortcut( "我的金币", command="商店", @@ -102,13 +114,6 @@ _matcher.shortcut( prefix=True, ) -_matcher.shortcut( - "使用道具(?P.*?)", - command="商店", - arguments=["use", "{name}"], - prefix=True, -) - _matcher.shortcut( "金币排行", command="商店", @@ -172,7 +177,7 @@ async def _( await MessageUtils.build_message(result).send(reply_to=True) -@_matcher.assign("use") +@_use_matcher.handle() async def _( bot: Bot, event: Event, @@ -181,6 +186,7 @@ async def _( arparma: Arparma, name: Match[str], num: Query[int] = AlconnaQuery("num", 1), + at_users: Query[list[At]] = AlconnaQuery("at_users", []), ): if not name.available: await MessageUtils.build_message( @@ -188,7 +194,7 @@ async def _( ).finish(reply_to=True) try: result = await ShopManage.use( - bot, event, session, message, name.result, num.result, "" + bot, event, session, message, name.result, num.result, "", at_users.result ) logger.info( f"使用道具 {name.result}, 数量: {num.result}", diff --git a/zhenxun/builtin_plugins/shop/_data_source.py b/zhenxun/builtin_plugins/shop/_data_source.py index 4c35c6ff..682bd85e 100644 --- a/zhenxun/builtin_plugins/shop/_data_source.py +++ b/zhenxun/builtin_plugins/shop/_data_source.py @@ -8,7 +8,7 @@ from typing import Any, Literal from nonebot.adapters import Bot, Event from nonebot.compat import model_dump -from nonebot_plugin_alconna import UniMessage, UniMsg +from nonebot_plugin_alconna import At, UniMessage, UniMsg from nonebot_plugin_uninfo import Uninfo from pydantic import BaseModel, Field, create_model from tortoise.expressions import Q @@ -48,6 +48,10 @@ class Goods(BaseModel): """model""" session: Uninfo | None = None """Uninfo""" + at_user: str | None = None + """At对象""" + at_users: list[str] = [] + """At对象列表""" class ShopParam(BaseModel): @@ -73,6 +77,10 @@ class ShopParam(BaseModel): """Uninfo""" message: UniMsg """UniMessage""" + at_user: str | None = None + """At对象""" + at_users: list[str] = [] + """At对象列表""" extra_data: dict[str, Any] = Field(default_factory=dict) """额外数据""" @@ -156,6 +164,7 @@ class ShopManage: goods: Goods, num: int, text: str, + at_users: list[str] = [], ) -> tuple[ShopParam, dict[str, Any]]: """构造参数 @@ -165,6 +174,7 @@ class ShopManage: goods_name: 商品名称 num: 数量 text: 其他信息 + at_users: at用户 """ group_id = None if session.group: @@ -172,6 +182,7 @@ class ShopManage: session.group.parent.id if session.group.parent else session.group.id ) _kwargs = goods.params + at_user = at_users[0] if at_users else None model = goods.model( **{ "goods_name": goods.name, @@ -183,6 +194,8 @@ class ShopManage: "text": text, "session": session, "message": message, + "at_user": at_user, + "at_users": at_users, } ) return model, { @@ -194,6 +207,8 @@ class ShopManage: "num": num, "text": text, "goods_name": goods.name, + "at_user": at_user, + "at_users": at_users, } @classmethod @@ -223,6 +238,7 @@ class ShopManage: **param.extra_data, "session": session, "message": message, + "shop_param": ShopParam, } for key in list(param_json.keys()): if key not in args: @@ -308,6 +324,7 @@ class ShopManage: goods_name: str, num: int, text: str, + at_users: list[At] = [], ) -> str | UniMessage | None: """使用道具 @@ -319,6 +336,7 @@ class ShopManage: goods_name: 商品名称 num: 使用数量 text: 其他信息 + at_users: at用户 返回: str | MessageFactory | None: 使用完成后返回信息 @@ -339,8 +357,9 @@ class ShopManage: goods = cls.uuid2goods.get(goods_info.uuid) if not goods or not goods.func: return f"{goods_info.goods_name} 未注册使用函数, 无法使用..." + at_user_ids = [at.target for at in at_users] param, kwargs = cls.__build_params( - bot, event, session, message, goods, num, text + bot, event, session, message, goods, num, text, at_user_ids ) if num > param.max_num_limit: return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..." @@ -480,10 +499,13 @@ class ShopManage: if not user.props: return None - user.props = {uuid: count for uuid, count in user.props.items() if count > 0} - goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all() goods_by_uuid = {item.uuid: item for item in goods_list} + user.props = { + uuid: count + for uuid, count in user.props.items() + if count > 0 and goods_by_uuid.get(uuid) + } table_rows = [] for i, prop_uuid in enumerate(user.props): diff --git a/zhenxun/builtin_plugins/sign_in/__init__.py b/zhenxun/builtin_plugins/sign_in/__init__.py index 0b48a0e7..0986e476 100644 --- a/zhenxun/builtin_plugins/sign_in/__init__.py +++ b/zhenxun/builtin_plugins/sign_in/__init__.py @@ -10,7 +10,6 @@ from nonebot_plugin_alconna import ( store_true, ) from nonebot_plugin_apscheduler import scheduler -from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.utils import ( Command, @@ -23,7 +22,7 @@ from zhenxun.utils.depends import UserName from zhenxun.utils.message import MessageUtils from ._data_source import SignManage -from .goods_register import driver # noqa: F401 +from .goods_register import Uninfo from .utils import clear_sign_data_pic __plugin_meta__ = PluginMetadata( diff --git a/zhenxun/builtin_plugins/sign_in/goods_register.py b/zhenxun/builtin_plugins/sign_in/goods_register.py index f7a65359..6c8e39bb 100644 --- a/zhenxun/builtin_plugins/sign_in/goods_register.py +++ b/zhenxun/builtin_plugins/sign_in/goods_register.py @@ -1,7 +1,6 @@ from decimal import Decimal import nonebot -from nonebot.drivers import Driver from nonebot_plugin_uninfo import Uninfo from zhenxun.models.sign_user import SignUser @@ -9,14 +8,7 @@ from zhenxun.models.user_console import UserConsole from zhenxun.utils.decorator.shop import shop_register from zhenxun.utils.platform import PlatformUtils -driver: Driver = nonebot.get_driver() - - -# @driver.on_startup -# async def _(): -# """ -# 导入内置的三个商品 -# """ +driver = nonebot.get_driver() @shop_register( diff --git a/zhenxun/builtin_plugins/sign_in/utils.py b/zhenxun/builtin_plugins/sign_in/utils.py index 9faf1120..910b90d8 100644 --- a/zhenxun/builtin_plugins/sign_in/utils.py +++ b/zhenxun/builtin_plugins/sign_in/utils.py @@ -16,6 +16,7 @@ from zhenxun.models.sign_log import SignLog from zhenxun.models.sign_user import SignUser from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.image_utils import BuildImage +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils from .config import ( @@ -54,7 +55,7 @@ LG_MESSAGE = [ ] -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def init_image(): SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True) SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True) diff --git a/zhenxun/builtin_plugins/superuser/broadcast/__init__.py b/zhenxun/builtin_plugins/superuser/broadcast/__init__.py index c025fd0c..3fc08e4c 100644 --- a/zhenxun/builtin_plugins/superuser/broadcast/__init__.py +++ b/zhenxun/builtin_plugins/superuser/broadcast/__init__.py @@ -1,32 +1,77 @@ -from typing import Annotated - -from nonebot import on_command -from nonebot.adapters import Bot -from nonebot.params import Command +from arclet.alconna import AllParam +from nepattern import UnionPattern +from nonebot.adapters import Bot, Event from nonebot.permission import SUPERUSER from nonebot.plugin import PluginMetadata from nonebot.rule import to_me -from nonebot_plugin_alconna import Text as alcText -from nonebot_plugin_alconna import UniMsg +import nonebot_plugin_alconna as alc +from nonebot_plugin_alconna import ( + Alconna, + Args, + on_alconna, +) +from nonebot_plugin_alconna.uniseg.segment import ( + At, + AtAll, + Audio, + Button, + Emoji, + File, + Hyper, + Image, + Keyboard, + Reference, + Reply, + Text, + Video, + Voice, +) from nonebot_plugin_session import EventSession from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task -from zhenxun.services.log import logger from zhenxun.utils.enum import PluginType from zhenxun.utils.message import MessageUtils -from ._data_source import BroadcastManage +from .broadcast_manager import BroadcastManager +from .message_processor import ( + _extract_broadcast_content, + get_broadcast_target_groups, + send_broadcast_and_notify, +) + +BROADCAST_SEND_DELAY_RANGE = (1, 3) __plugin_meta__ = PluginMetadata( name="广播", description="昭告天下!", usage=""" - 广播 [消息] [图片] - 示例:广播 你们好! + 广播 [消息内容] + - 直接发送消息到除当前群组外的所有群组 + - 支持文本、图片、@、表情、视频等多种消息类型 + - 示例:广播 你们好! + - 示例:广播 [图片] 新活动开始啦! + + 广播 + 引用消息 + - 将引用的消息作为广播内容发送 + - 支持引用普通消息或合并转发消息 + - 示例:(引用一条消息) 广播 + + 广播撤回 + - 撤回最近一次由您触发的广播消息 + - 仅能撤回短时间内的消息 + - 示例:广播撤回 + + 特性: + - 在群组中使用广播时,不会将消息发送到当前群组 + - 在私聊中使用广播时,会发送到所有群组 + + 别名: + - bc (广播的简写) + - recall (广播撤回的别名) """.strip(), extra=PluginExtraData( author="HibiKier", - version="0.1", + version="1.2", plugin_type=PluginType.SUPERUSER, configs=[ RegisterConfig( @@ -42,26 +87,106 @@ __plugin_meta__ = PluginMetadata( ).to_dict(), ) -_matcher = on_command( - "广播", priority=1, permission=SUPERUSER, block=True, rule=to_me() +AnySeg = ( + UnionPattern( + [ + Text, + Image, + At, + AtAll, + Audio, + Video, + File, + Emoji, + Reply, + Reference, + Hyper, + Button, + Keyboard, + Voice, + ] + ) + @ "AnySeg" +) + +_matcher = on_alconna( + Alconna( + "广播", + Args["content?", AllParam], + ), + aliases={"bc"}, + priority=1, + permission=SUPERUSER, + block=True, + rule=to_me(), + use_origin=False, +) + +_recall_matcher = on_alconna( + Alconna("广播撤回"), + aliases={"recall"}, + priority=1, + permission=SUPERUSER, + block=True, + rule=to_me(), ) @_matcher.handle() -async def _( +async def handle_broadcast( bot: Bot, + event: Event, session: EventSession, - message: UniMsg, - command: Annotated[tuple[str, ...], Command()], + arp: alc.Arparma, ): - for msg in message: - if isinstance(msg, alcText) and msg.text.strip().startswith(command[0]): - msg.text = msg.text.replace(command[0], "", 1).strip() - break - await MessageUtils.build_message("正在发送..请等一下哦!").send() - count, error_count = await BroadcastManage.send(bot, message, session) - result = f"成功广播 {count} 个群组" - if error_count: - result += f"\n广播失败 {error_count} 个群组" - await MessageUtils.build_message(f"发送广播完成!\n{result}").send(reply_to=True) - logger.info(f"发送广播信息: {message}", "广播", session=session) + broadcast_content_msg = await _extract_broadcast_content(bot, event, arp, session) + if not broadcast_content_msg: + return + + target_groups, enabled_groups = await get_broadcast_target_groups(bot, session) + if not target_groups or not enabled_groups: + return + + try: + await send_broadcast_and_notify( + bot, event, broadcast_content_msg, enabled_groups, target_groups, session + ) + except Exception as e: + error_msg = "发送广播失败" + BroadcastManager.log_error(error_msg, e, session) + await MessageUtils.build_message(f"{error_msg}。").send(reply_to=True) + + +@_recall_matcher.handle() +async def handle_broadcast_recall( + bot: Bot, + event: Event, + session: EventSession, +): + """处理广播撤回命令""" + await MessageUtils.build_message("正在尝试撤回最近一次广播...").send() + + try: + success_count, error_count = await BroadcastManager.recall_last_broadcast( + bot, session + ) + + user_id = str(event.get_user_id()) + if success_count == 0 and error_count == 0: + await bot.send_private_msg( + user_id=user_id, + message="没有找到最近的广播消息记录,可能已经撤回或超过可撤回时间。", + ) + else: + result = f"广播撤回完成!\n成功撤回 {success_count} 条消息" + if error_count: + result += f"\n撤回失败 {error_count} 条消息 (可能已过期或无权限)" + await bot.send_private_msg(user_id=user_id, message=result) + BroadcastManager.log_info( + f"广播撤回完成: 成功 {success_count}, 失败 {error_count}", session + ) + except Exception as e: + error_msg = "撤回广播消息失败" + BroadcastManager.log_error(error_msg, e, session) + user_id = str(event.get_user_id()) + await bot.send_private_msg(user_id=user_id, message=f"{error_msg}。") diff --git a/zhenxun/builtin_plugins/superuser/broadcast/_data_source.py b/zhenxun/builtin_plugins/superuser/broadcast/_data_source.py deleted file mode 100644 index 1ee1a28c..00000000 --- a/zhenxun/builtin_plugins/superuser/broadcast/_data_source.py +++ /dev/null @@ -1,72 +0,0 @@ -import asyncio -import random - -from nonebot.adapters import Bot -import nonebot_plugin_alconna as alc -from nonebot_plugin_alconna import Image, UniMsg -from nonebot_plugin_session import EventSession - -from zhenxun.services.log import logger -from zhenxun.utils.common_utils import CommonUtils -from zhenxun.utils.message import MessageUtils -from zhenxun.utils.platform import PlatformUtils - - -class BroadcastManage: - @classmethod - async def send( - cls, bot: Bot, message: UniMsg, session: EventSession - ) -> tuple[int, int]: - """发送广播消息 - - 参数: - bot: Bot - message: 消息内容 - session: Session - - 返回: - tuple[int, int]: 发送成功的群组数量, 发送失败的群组数量 - """ - message_list = [] - for msg in message: - if isinstance(msg, alc.Image) and msg.url: - message_list.append(Image(url=msg.url)) - elif isinstance(msg, alc.Text): - message_list.append(msg.text) - group_list, _ = await PlatformUtils.get_group_list(bot) - if group_list: - error_count = 0 - for group in group_list: - try: - if not await CommonUtils.task_is_block( - bot, - "broadcast", # group.channel_id - group.group_id, - ): - target = PlatformUtils.get_target( - group_id=group.group_id, channel_id=group.channel_id - ) - if target: - await MessageUtils.build_message(message_list).send( - target, bot - ) - logger.debug( - "发送成功", - "广播", - session=session, - target=f"{group.group_id}:{group.channel_id}", - ) - await asyncio.sleep(random.randint(1, 3)) - else: - logger.warning("target为空", "广播", session=session) - except Exception as e: - error_count += 1 - logger.error( - "发送失败", - "广播", - session=session, - target=f"{group.group_id}:{group.channel_id}", - e=e, - ) - return len(group_list) - error_count, error_count - return 0, 0 diff --git a/zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py b/zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py new file mode 100644 index 00000000..c3d7b5cc --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py @@ -0,0 +1,490 @@ +import asyncio +import random +import traceback +from typing import ClassVar + +from nonebot.adapters import Bot +from nonebot.adapters.onebot.v11 import Bot as V11Bot +from nonebot.exception import ActionFailed +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_alconna.uniseg import Receipt, Reference +from nonebot_plugin_session import EventSession + +from zhenxun.models.group_console import GroupConsole +from zhenxun.services.log import logger +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.platform import PlatformUtils + +from .models import BroadcastDetailResult, BroadcastResult +from .utils import custom_nodes_to_v11_nodes, uni_message_to_v11_list_of_dicts + + +class BroadcastManager: + """广播管理器""" + + _last_broadcast_msg_ids: ClassVar[dict[str, int]] = {} + + @staticmethod + def _get_session_info(session: EventSession | None) -> str: + """获取会话信息字符串""" + if not session: + return "" + + try: + platform = getattr(session, "platform", "unknown") + session_id = str(session) + return f"[{platform}:{session_id}]" + except Exception: + return "[session-info-error]" + + @staticmethod + def log_error( + message: str, error: Exception, session: EventSession | None = None, **kwargs + ): + """记录错误日志""" + session_info = BroadcastManager._get_session_info(session) + error_type = type(error).__name__ + stack_trace = traceback.format_exc() + error_details = f"\n类型: {error_type}\n信息: {error!s}\n堆栈: {stack_trace}" + + logger.error( + f"{session_info} {message}{error_details}", "广播", e=error, **kwargs + ) + + @staticmethod + def log_warning(message: str, session: EventSession | None = None, **kwargs): + """记录警告级别日志""" + session_info = BroadcastManager._get_session_info(session) + logger.warning(f"{session_info} {message}", "广播", **kwargs) + + @staticmethod + def log_info(message: str, session: EventSession | None = None, **kwargs): + """记录信息级别日志""" + session_info = BroadcastManager._get_session_info(session) + logger.info(f"{session_info} {message}", "广播", **kwargs) + + @classmethod + def get_last_broadcast_msg_ids(cls) -> dict[str, int]: + """获取最近广播消息ID""" + return cls._last_broadcast_msg_ids.copy() + + @classmethod + def clear_last_broadcast_msg_ids(cls) -> None: + """清空消息ID记录""" + cls._last_broadcast_msg_ids.clear() + + @classmethod + async def get_all_groups(cls, bot: Bot) -> tuple[list[GroupConsole], str]: + """获取群组列表""" + return await PlatformUtils.get_group_list(bot) + + @classmethod + async def send( + cls, bot: Bot, message: UniMessage, session: EventSession + ) -> BroadcastResult: + """发送广播到所有群组""" + logger.debug( + f"开始广播(send - 广播到所有群组),Bot ID: {bot.self_id}", + "广播", + session=session, + ) + + logger.debug("清空上一次的广播消息ID记录", "广播", session=session) + cls.clear_last_broadcast_msg_ids() + + all_groups, _ = await cls.get_all_groups(bot) + return await cls.send_to_specific_groups(bot, message, all_groups, session) + + @classmethod + async def send_to_specific_groups( + cls, + bot: Bot, + message: UniMessage, + target_groups: list[GroupConsole], + session_info: EventSession | str | None = None, + ) -> BroadcastResult: + """发送广播到指定群组""" + log_session = session_info or bot.self_id + logger.debug( + f"开始广播,目标 {len(target_groups)} 个群组,Bot ID: {bot.self_id}", + "广播", + session=log_session, + ) + + if not target_groups: + logger.debug("目标群组列表为空,广播结束", "广播", session=log_session) + return 0, 0 + + platform = PlatformUtils.get_platform(bot) + is_forward_broadcast = any( + isinstance(seg, Reference) and getattr(seg, "nodes", None) + for seg in message + ) + + if platform == "qq" and isinstance(bot, V11Bot) and is_forward_broadcast: + if ( + len(message) == 1 + and isinstance(message[0], Reference) + and getattr(message[0], "nodes", None) + ): + nodes_list = getattr(message[0], "nodes", []) + v11_nodes = custom_nodes_to_v11_nodes(nodes_list) + node_count = len(v11_nodes) + logger.debug( + f"从 UniMessage 构造转发节点数: {node_count}", + "广播", + session=log_session, + ) + else: + logger.warning( + "广播消息包含合并转发段和其他段,将尝试打平成一个节点发送", + "广播", + session=log_session, + ) + v11_content_list = uni_message_to_v11_list_of_dicts(message) + v11_nodes = ( + [ + { + "type": "node", + "data": { + "user_id": bot.self_id, + "nickname": "广播", + "content": v11_content_list, + }, + } + ] + if v11_content_list + else [] + ) + + if not v11_nodes: + logger.warning( + "构造出的 V11 合并转发节点为空,无法发送", + "广播", + session=log_session, + ) + return 0, len(target_groups) + success_count, error_count, skip_count = await cls._broadcast_forward( + bot, log_session, target_groups, v11_nodes + ) + else: + if is_forward_broadcast: + logger.warning( + f"合并转发消息在适配器 ({platform}) 不支持,将作为普通消息发送", + "广播", + session=log_session, + ) + success_count, error_count, skip_count = await cls._broadcast_normal( + bot, log_session, target_groups, message + ) + + total = len(target_groups) + stats = f"成功: {success_count}, 失败: {error_count}" + stats += f", 跳过: {skip_count}, 总计: {total}" + logger.debug( + f"广播统计 - {stats}", + "广播", + session=log_session, + ) + + msg_ids = cls.get_last_broadcast_msg_ids() + if msg_ids: + id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids.items()]) + logger.debug( + f"广播结束,记录了 {len(msg_ids)} 条消息ID: {id_list_str}", + "广播", + session=log_session, + ) + else: + logger.warning( + "广播结束,但没有记录任何消息ID", + "广播", + session=log_session, + ) + + return success_count, error_count + + @classmethod + async def _extract_message_id_from_result( + cls, + result: dict | Receipt, + group_key: str, + session_info: EventSession | str, + msg_type: str = "普通", + ) -> None: + """提取消息ID并记录""" + if isinstance(result, dict) and "message_id" in result: + msg_id = result["message_id"] + try: + msg_id_int = int(msg_id) + cls._last_broadcast_msg_ids[group_key] = msg_id_int + logger.debug( + f"记录群 {group_key} 的{msg_type}消息ID: {msg_id_int}", + "广播", + session=session_info, + ) + except (ValueError, TypeError): + logger.warning( + f"{msg_type}结果中的 message_id 不是有效整数: {msg_id}", + "广播", + session=session_info, + ) + elif isinstance(result, Receipt) and result.msg_ids: + try: + first_id_info = result.msg_ids[0] + msg_id = None + if isinstance(first_id_info, dict) and "message_id" in first_id_info: + msg_id = first_id_info["message_id"] + logger.debug( + f"从 Receipt.msg_ids[0] 提取到 ID: {msg_id}", + "广播", + session=session_info, + ) + elif isinstance(first_id_info, int | str): + msg_id = first_id_info + logger.debug( + f"从 Receipt.msg_ids[0] 提取到原始ID: {msg_id}", + "广播", + session=session_info, + ) + + if msg_id is not None: + try: + msg_id_int = int(msg_id) + cls._last_broadcast_msg_ids[group_key] = msg_id_int + logger.debug( + f"记录群 {group_key} 的消息ID: {msg_id_int}", + "广播", + session=session_info, + ) + except (ValueError, TypeError): + logger.warning( + f"提取的ID ({msg_id}) 不是有效整数", + "广播", + session=session_info, + ) + else: + info_str = str(first_id_info) + logger.warning( + f"无法从 Receipt.msg_ids[0] 提取ID: {info_str}", + "广播", + session=session_info, + ) + except IndexError: + logger.warning("Receipt.msg_ids 为空", "广播", session=session_info) + except Exception as e_extract: + logger.error( + f"从 Receipt 提取 msg_id 时出错: {e_extract}", + "广播", + session=session_info, + e=e_extract, + ) + else: + logger.warning( + f"发送成功但无法从结果获取消息 ID. 结果: {result}", + "广播", + session=session_info, + ) + + @classmethod + async def _check_group_availability(cls, bot: Bot, group: GroupConsole) -> bool: + """检查群组是否可用""" + if not group.group_id: + return False + + if await CommonUtils.task_is_block(bot, "broadcast", group.group_id): + return False + + return True + + @classmethod + async def _broadcast_forward( + cls, + bot: V11Bot, + session_info: EventSession | str, + group_list: list[GroupConsole], + v11_nodes: list[dict], + ) -> BroadcastDetailResult: + """发送合并转发""" + success_count = 0 + error_count = 0 + skip_count = 0 + + for _, group in enumerate(group_list): + group_key = group.group_id or group.channel_id + + if not await cls._check_group_availability(bot, group): + skip_count += 1 + continue + + try: + result = await bot.send_group_forward_msg( + group_id=int(group.group_id), messages=v11_nodes + ) + + logger.debug( + f"合并转发消息发送结果: {result}, 类型: {type(result)}", + "广播", + session=session_info, + ) + + await cls._extract_message_id_from_result( + result, group_key, session_info, "合并转发" + ) + + success_count += 1 + await asyncio.sleep(random.randint(1, 3)) + except ActionFailed as af_e: + error_count += 1 + logger.error( + f"发送失败(合并转发) to {group_key}: {af_e}", + "广播", + session=session_info, + e=af_e, + ) + except Exception as e: + error_count += 1 + logger.error( + f"发送失败(合并转发) to {group_key}: {e}", + "广播", + session=session_info, + e=e, + ) + + return success_count, error_count, skip_count + + @classmethod + async def _broadcast_normal( + cls, + bot: Bot, + session_info: EventSession | str, + group_list: list[GroupConsole], + message: UniMessage, + ) -> BroadcastDetailResult: + """发送普通消息""" + success_count = 0 + error_count = 0 + skip_count = 0 + + for _, group in enumerate(group_list): + group_key = ( + f"{group.group_id}:{group.channel_id}" + if group.channel_id + else str(group.group_id) + ) + + if not await cls._check_group_availability(bot, group): + skip_count += 1 + continue + + try: + target = PlatformUtils.get_target( + group_id=group.group_id, channel_id=group.channel_id + ) + + if target: + receipt: Receipt = await message.send(target, bot=bot) + + logger.debug( + f"广播消息发送结果: {receipt}, 类型: {type(receipt)}", + "广播", + session=session_info, + ) + + await cls._extract_message_id_from_result( + receipt, group_key, session_info + ) + + success_count += 1 + await asyncio.sleep(random.randint(1, 3)) + else: + logger.warning( + "target为空", "广播", session=session_info, target=group_key + ) + skip_count += 1 + except Exception as e: + error_count += 1 + logger.error( + f"发送失败(普通) to {group_key}: {e}", + "广播", + session=session_info, + e=e, + ) + + return success_count, error_count, skip_count + + @classmethod + async def recall_last_broadcast( + cls, bot: Bot, session_info: EventSession | str + ) -> BroadcastResult: + """撤回最近广播""" + msg_ids_to_recall = cls.get_last_broadcast_msg_ids() + + if not msg_ids_to_recall: + logger.warning( + "没有找到最近的广播消息ID记录", "广播撤回", session=session_info + ) + return 0, 0 + + id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids_to_recall.items()]) + logger.debug( + f"找到 {len(msg_ids_to_recall)} 条广播消息ID记录: {id_list_str}", + "广播撤回", + session=session_info, + ) + + success_count = 0 + error_count = 0 + + logger.info( + f"准备撤回 {len(msg_ids_to_recall)} 条广播消息", + "广播撤回", + session=session_info, + ) + + for group_key, msg_id in msg_ids_to_recall.items(): + try: + logger.debug( + f"尝试撤回消息 (ID: {msg_id}) in {group_key}", + "广播撤回", + session=session_info, + ) + await bot.call_api("delete_msg", message_id=msg_id) + success_count += 1 + except ActionFailed as af_e: + retcode = getattr(af_e, "retcode", None) + wording = getattr(af_e, "wording", "") + if retcode == 100 and "MESSAGE_NOT_FOUND" in wording.upper(): + logger.warning( + f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}", + "广播撤回", + session=session_info, + ) + elif retcode == 300 and "delete message" in wording.lower(): + logger.warning( + f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}", + "广播撤回", + session=session_info, + ) + else: + error_count += 1 + logger.error( + f"撤回消息失败 (ID: {msg_id}) in {group_key}: {af_e}", + "广播撤回", + session=session_info, + e=af_e, + ) + except Exception as e: + error_count += 1 + logger.error( + f"撤回消息时发生未知错误 (ID: {msg_id}) in {group_key}: {e}", + "广播撤回", + session=session_info, + e=e, + ) + await asyncio.sleep(0.2) + + logger.debug("撤回操作完成,清空消息ID记录", "广播撤回", session=session_info) + cls.clear_last_broadcast_msg_ids() + + return success_count, error_count diff --git a/zhenxun/builtin_plugins/superuser/broadcast/message_processor.py b/zhenxun/builtin_plugins/superuser/broadcast/message_processor.py new file mode 100644 index 00000000..809e3645 --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/message_processor.py @@ -0,0 +1,584 @@ +import base64 +import json +from typing import Any + +from nonebot.adapters import Bot, Event +from nonebot.adapters.onebot.v11 import Message as V11Message +from nonebot.adapters.onebot.v11 import MessageSegment as V11MessageSegment +from nonebot.exception import ActionFailed +import nonebot_plugin_alconna as alc +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_alconna.uniseg.segment import ( + At, + AtAll, + CustomNode, + Image, + Reference, + Reply, + Text, + Video, +) +from nonebot_plugin_alconna.uniseg.tools import reply_fetch +from nonebot_plugin_session import EventSession + +from zhenxun.services.log import logger +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.message import MessageUtils + +from .broadcast_manager import BroadcastManager + +MAX_FORWARD_DEPTH = 3 + + +async def _process_forward_content( + forward_content: Any, forward_id: str | None, bot: Bot, depth: int +) -> list[CustomNode]: + """处理转发消息内容""" + nodes_for_alc = [] + content_parsed = False + + if forward_content: + nodes_from_content = None + if isinstance(forward_content, list): + nodes_from_content = forward_content + elif isinstance(forward_content, str): + try: + parsed_content = json.loads(forward_content) + if isinstance(parsed_content, list): + nodes_from_content = parsed_content + except Exception as json_e: + logger.debug( + f"[Depth {depth}] JSON解析失败: {json_e}", + "广播", + ) + + if nodes_from_content is not None: + logger.debug( + f"[D{depth}] 节点数: {len(nodes_from_content)}", + "广播", + ) + content_parsed = True + for node_data in nodes_from_content: + node = await _create_custom_node_from_data(node_data, bot, depth + 1) + if node: + nodes_for_alc.append(node) + + if not content_parsed and forward_id: + logger.debug( + f"[D{depth}] 尝试API调用ID: {forward_id}", + "广播", + ) + try: + forward_data = await bot.call_api("get_forward_msg", id=forward_id) + nodes_list = None + + if isinstance(forward_data, dict) and "messages" in forward_data: + nodes_list = forward_data["messages"] + elif ( + isinstance(forward_data, dict) + and "data" in forward_data + and isinstance(forward_data["data"], dict) + and "message" in forward_data["data"] + ): + nodes_list = forward_data["data"]["message"] + elif isinstance(forward_data, list): + nodes_list = forward_data + + if nodes_list: + node_count = len(nodes_list) + logger.debug( + f"[D{depth + 1}] 节点:{node_count}", + "广播", + ) + for node_data in nodes_list: + node = await _create_custom_node_from_data( + node_data, bot, depth + 1 + ) + if node: + nodes_for_alc.append(node) + else: + logger.warning( + f"[D{depth + 1}] ID:{forward_id}无节点", + "广播", + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[嵌套转发消息获取失败]", + ) + ) + except ActionFailed as af_e: + logger.error( + f"[D{depth + 1}] API失败: {af_e}", + "广播", + e=af_e, + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[嵌套转发消息获取失败]", + ) + ) + except Exception as e: + logger.error( + f"[D{depth + 1}] 处理出错: {e}", + "广播", + e=e, + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[处理嵌套转发时出错]", + ) + ) + elif not content_parsed and not forward_id: + logger.warning( + f"[D{depth}] 转发段无内容也无ID", + "广播", + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="错误", + content="[嵌套转发消息无法解析]", + ) + ) + elif content_parsed and not nodes_for_alc: + logger.warning( + f"[D{depth}] 解析成功但无有效节点", + "广播", + ) + nodes_for_alc.append( + CustomNode( + uid="0", + name="信息", + content="[嵌套转发内容为空]", + ) + ) + + return nodes_for_alc + + +async def _create_custom_node_from_data( + node_data: dict, bot: Bot, depth: int +) -> CustomNode | None: + """从节点数据创建CustomNode""" + node_content_raw = node_data.get("message") or node_data.get("content") + if not node_content_raw: + logger.warning(f"[D{depth}] 节点缺少消息内容", "广播") + return None + + sender = node_data.get("sender", {}) + uid = str(sender.get("user_id", "10000")) + name = sender.get("nickname", f"用户{uid[:4]}") + + extracted_uni_msg = await _extract_content_from_message( + node_content_raw, bot, depth + ) + if not extracted_uni_msg: + return None + + return CustomNode(uid=uid, name=name, content=extracted_uni_msg) + + +async def _extract_broadcast_content( + bot: Bot, + event: Event, + arp: alc.Arparma, + session: EventSession, +) -> UniMessage | None: + """从命令参数或引用消息中提取广播内容""" + broadcast_content_msg: UniMessage | None = None + + command_content_list = arp.all_matched_args.get("content", []) + + processed_command_list = [] + has_command_content = False + + if command_content_list: + for item in command_content_list: + if isinstance(item, alc.Segment): + processed_command_list.append(item) + if not (isinstance(item, Text) and not item.text.strip()): + has_command_content = True + elif isinstance(item, str): + if item.strip(): + processed_command_list.append(Text(item.strip())) + has_command_content = True + else: + logger.warning( + f"Unexpected type in command content: {type(item)}", "广播" + ) + + if has_command_content: + logger.debug("检测到命令参数内容,优先使用参数内容", "广播", session=session) + broadcast_content_msg = UniMessage(processed_command_list) + + if not broadcast_content_msg.filter( + lambda x: not (isinstance(x, Text) and not x.text.strip()) + ): + logger.warning( + "命令参数内容解析后为空或只包含空白", "广播", session=session + ) + broadcast_content_msg = None + + if not broadcast_content_msg: + reply_segment_obj: Reply | None = await reply_fetch(event, bot) + if ( + reply_segment_obj + and hasattr(reply_segment_obj, "msg") + and reply_segment_obj.msg + ): + logger.debug( + "未检测到有效命令参数,检测到引用消息", "广播", session=session + ) + raw_quoted_content = reply_segment_obj.msg + is_forward = False + forward_id = None + + if isinstance(raw_quoted_content, V11Message): + for seg in raw_quoted_content: + if isinstance(seg, V11MessageSegment): + if seg.type == "forward": + forward_id = seg.data.get("id") + is_forward = bool(forward_id) + break + elif seg.type == "json": + try: + json_data_str = seg.data.get("data", "{}") + if isinstance(json_data_str, str): + import json + + json_data = json.loads(json_data_str) + if ( + json_data.get("app") == "com.tencent.multimsg" + or json_data.get("view") == "Forward" + ) and json_data.get("meta", {}).get( + "detail", {} + ).get("resid"): + forward_id = json_data["meta"]["detail"][ + "resid" + ] + is_forward = True + break + except Exception: + pass + + if is_forward and forward_id: + logger.info( + f"尝试获取并构造合并转发内容 (ID: {forward_id})", + "广播", + session=session, + ) + nodes_to_forward: list[CustomNode] = [] + try: + forward_data = await bot.call_api("get_forward_msg", id=forward_id) + nodes_list = None + if isinstance(forward_data, dict) and "messages" in forward_data: + nodes_list = forward_data["messages"] + elif ( + isinstance(forward_data, dict) + and "data" in forward_data + and isinstance(forward_data["data"], dict) + and "message" in forward_data["data"] + ): + nodes_list = forward_data["data"]["message"] + elif isinstance(forward_data, list): + nodes_list = forward_data + + if nodes_list is not None: + for node_data in nodes_list: + node_sender = node_data.get("sender", {}) + node_user_id = str(node_sender.get("user_id", "10000")) + node_nickname = node_sender.get( + "nickname", f"用户{node_user_id[:4]}" + ) + node_content_raw = node_data.get( + "message" + ) or node_data.get("content") + if node_content_raw: + extracted_node_uni_msg = ( + await _extract_content_from_message( + node_content_raw, bot + ) + ) + if extracted_node_uni_msg: + nodes_to_forward.append( + CustomNode( + uid=node_user_id, + name=node_nickname, + content=extracted_node_uni_msg, + ) + ) + if nodes_to_forward: + broadcast_content_msg = UniMessage( + Reference(nodes=nodes_to_forward) + ) + except ActionFailed: + await MessageUtils.build_message( + "获取合并转发消息失败,可能不支持此 API。" + ).send(reply_to=True) + return None + except Exception as api_e: + logger.error(f"处理合并转发时出错: {api_e}", "广播", e=api_e) + await MessageUtils.build_message( + "处理合并转发消息时发生内部错误。" + ).send(reply_to=True) + return None + else: + broadcast_content_msg = await _extract_content_from_message( + raw_quoted_content, bot + ) + else: + logger.debug("未检测到命令参数和引用消息", "广播", session=session) + await MessageUtils.build_message("请提供广播内容或引用要广播的消息").send( + reply_to=True + ) + return None + + if not broadcast_content_msg: + logger.error( + "未能从命令参数或引用消息中获取有效的广播内容", "广播", session=session + ) + await MessageUtils.build_message("错误:未能获取有效的广播内容。").send( + reply_to=True + ) + return None + + return broadcast_content_msg + + +async def _process_v11_segment( + seg_obj: V11MessageSegment | dict, depth: int, index: int, bot: Bot +) -> list[alc.Segment]: + """处理V11消息段""" + result = [] + seg_type = None + data_dict = None + + if isinstance(seg_obj, V11MessageSegment): + seg_type = seg_obj.type + data_dict = seg_obj.data + elif isinstance(seg_obj, dict): + seg_type = seg_obj.get("type") + data_dict = seg_obj.get("data") + else: + return result + + if not (seg_type and data_dict is not None): + logger.warning(f"[D{depth}] 跳过无效数据: {type(seg_obj)}", "广播") + return result + + if seg_type == "text": + text_content = data_dict.get("text", "") + if isinstance(text_content, str) and text_content.strip(): + result.append(Text(text_content)) + elif seg_type == "image": + img_seg = None + if data_dict.get("url"): + img_seg = Image(url=data_dict["url"]) + elif data_dict.get("file"): + file_val = data_dict["file"] + if isinstance(file_val, str) and file_val.startswith("base64://"): + b64_data = file_val[9:] + raw_bytes = base64.b64decode(b64_data) + img_seg = Image(raw=raw_bytes) + else: + img_seg = Image(path=file_val) + if img_seg: + result.append(img_seg) + else: + logger.warning(f"[Depth {depth}] V11 图片 {index} 缺少URL/文件", "广播") + elif seg_type == "at": + target_qq = data_dict.get("qq", "") + if target_qq.lower() == "all": + result.append(AtAll()) + elif target_qq: + result.append(At(flag="user", target=target_qq)) + elif seg_type == "video": + video_seg = None + if data_dict.get("url"): + video_seg = Video(url=data_dict["url"]) + elif data_dict.get("file"): + file_val = data_dict["file"] + if isinstance(file_val, str) and file_val.startswith("base64://"): + b64_data = file_val[9:] + raw_bytes = base64.b64decode(b64_data) + video_seg = Video(raw=raw_bytes) + else: + video_seg = Video(path=file_val) + if video_seg: + result.append(video_seg) + logger.debug(f"[Depth {depth}] 处理视频消息成功", "广播") + else: + logger.warning(f"[Depth {depth}] V11 视频 {index} 缺少URL/文件", "广播") + elif seg_type == "forward": + nested_forward_id = data_dict.get("id") or data_dict.get("resid") + nested_forward_content = data_dict.get("content") + + logger.debug(f"[D{depth}] 嵌套转发ID: {nested_forward_id}", "广播") + + nested_nodes = await _process_forward_content( + nested_forward_content, nested_forward_id, bot, depth + ) + + if nested_nodes: + result.append(Reference(nodes=nested_nodes)) + else: + logger.warning(f"[D{depth}] 跳过类型: {seg_type}", "广播") + + return result + + +async def _extract_content_from_message( + message_content: Any, bot: Bot, depth: int = 0 +) -> UniMessage: + """提取消息内容到UniMessage""" + temp_msg = UniMessage() + input_type_str = str(type(message_content)) + + if depth >= MAX_FORWARD_DEPTH: + logger.warning( + f"[Depth {depth}] 达到最大递归深度 {MAX_FORWARD_DEPTH},停止解析嵌套转发。", + "广播", + ) + temp_msg.append(Text("[嵌套转发层数过多,内容已省略]")) + return temp_msg + + segments_to_process = [] + + if isinstance(message_content, UniMessage): + segments_to_process = list(message_content) + elif isinstance(message_content, V11Message): + segments_to_process = list(message_content) + elif isinstance(message_content, list): + segments_to_process = message_content + elif ( + isinstance(message_content, dict) + and "type" in message_content + and "data" in message_content + ): + segments_to_process = [message_content] + elif isinstance(message_content, str): + if message_content.strip(): + temp_msg.append(Text(message_content)) + return temp_msg + else: + logger.warning(f"[Depth {depth}] 无法处理的输入类型: {input_type_str}", "广播") + return temp_msg + + if segments_to_process: + for index, seg_obj in enumerate(segments_to_process): + try: + if isinstance(seg_obj, Text): + text_content = getattr(seg_obj, "text", None) + if isinstance(text_content, str) and text_content.strip(): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, Image): + if ( + getattr(seg_obj, "url", None) + or getattr(seg_obj, "path", None) + or getattr(seg_obj, "raw", None) + ): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, At): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, AtAll): + temp_msg.append(seg_obj) + elif isinstance(seg_obj, Video): + if ( + getattr(seg_obj, "url", None) + or getattr(seg_obj, "path", None) + or getattr(seg_obj, "raw", None) + ): + temp_msg.append(seg_obj) + logger.debug(f"[D{depth}] 处理Video对象成功", "广播") + else: + processed_segments = await _process_v11_segment( + seg_obj, depth, index, bot + ) + temp_msg.extend(processed_segments) + except Exception as e_conv_seg: + logger.warning( + f"[D{depth}] 处理段 {index} 出错: {e_conv_seg}", + "广播", + e=e_conv_seg, + ) + + if not temp_msg and message_content: + logger.warning(f"未能从类型 {input_type_str} 中提取内容", "广播") + + return temp_msg + + +async def get_broadcast_target_groups( + bot: Bot, session: EventSession +) -> tuple[list, list]: + """获取广播目标群组和启用了广播功能的群组""" + target_groups = [] + all_groups, _ = await BroadcastManager.get_all_groups(bot) + + current_group_id = None + if hasattr(session, "id2") and session.id2: + current_group_id = session.id2 + + if current_group_id: + target_groups = [ + group for group in all_groups if group.group_id != current_group_id + ] + logger.info( + f"向除当前群组({current_group_id})外的所有群组广播", "广播", session=session + ) + else: + target_groups = all_groups + logger.info("向所有群组广播", "广播", session=session) + + if not target_groups: + await MessageUtils.build_message("没有找到符合条件的广播目标群组。").send( + reply_to=True + ) + return [], [] + + enabled_groups = [] + for group in target_groups: + if not await CommonUtils.task_is_block(bot, "broadcast", group.group_id): + enabled_groups.append(group) + + if not enabled_groups: + await MessageUtils.build_message( + "没有启用了广播功能的目标群组可供立即发送。" + ).send(reply_to=True) + return target_groups, [] + + return target_groups, enabled_groups + + +async def send_broadcast_and_notify( + bot: Bot, + event: Event, + message: UniMessage, + enabled_groups: list, + target_groups: list, + session: EventSession, +) -> None: + """发送广播并通知结果""" + BroadcastManager.clear_last_broadcast_msg_ids() + count, error_count = await BroadcastManager.send_to_specific_groups( + bot, message, enabled_groups, session + ) + + result = f"成功广播 {count} 个群组" + if error_count: + result += f"\n发送失败 {error_count} 个群组" + result += f"\n有效: {len(enabled_groups)} / 总计: {len(target_groups)}" + + user_id = str(event.get_user_id()) + await bot.send_private_msg(user_id=user_id, message=f"发送广播完成!\n{result}") + + BroadcastManager.log_info( + f"广播完成,有效/总计: {len(enabled_groups)}/{len(target_groups)}", + session, + ) diff --git a/zhenxun/builtin_plugins/superuser/broadcast/models.py b/zhenxun/builtin_plugins/superuser/broadcast/models.py new file mode 100644 index 00000000..4bcdf936 --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/models.py @@ -0,0 +1,64 @@ +from datetime import datetime +from typing import Any + +from nonebot_plugin_alconna import UniMessage + +from zhenxun.models.group_console import GroupConsole + +GroupKey = str +MessageID = int +BroadcastResult = tuple[int, int] +BroadcastDetailResult = tuple[int, int, int] + + +class BroadcastTarget: + """广播目标""" + + def __init__(self, group_id: str, channel_id: str | None = None): + self.group_id = group_id + self.channel_id = channel_id + + def to_dict(self) -> dict[str, str | None]: + """转换为字典格式""" + return {"group_id": self.group_id, "channel_id": self.channel_id} + + @classmethod + def from_group_console(cls, group: GroupConsole) -> "BroadcastTarget": + """从 GroupConsole 对象创建""" + return cls(group_id=group.group_id, channel_id=group.channel_id) + + @property + def key(self) -> str: + """获取群组的唯一标识""" + if self.channel_id: + return f"{self.group_id}:{self.channel_id}" + return str(self.group_id) + + +class BroadcastTask: + """广播任务""" + + def __init__( + self, + bot_id: str, + message: UniMessage, + targets: list[BroadcastTarget], + scheduled_time: datetime | None = None, + task_id: str | None = None, + ): + self.bot_id = bot_id + self.message = message + self.targets = targets + self.scheduled_time = scheduled_time + self.task_id = task_id + + def to_dict(self) -> dict[str, Any]: + """转换为字典格式,用于序列化""" + return { + "bot_id": self.bot_id, + "targets": [t.to_dict() for t in self.targets], + "scheduled_time": self.scheduled_time.isoformat() + if self.scheduled_time + else None, + "task_id": self.task_id, + } diff --git a/zhenxun/builtin_plugins/superuser/broadcast/utils.py b/zhenxun/builtin_plugins/superuser/broadcast/utils.py new file mode 100644 index 00000000..748559fd --- /dev/null +++ b/zhenxun/builtin_plugins/superuser/broadcast/utils.py @@ -0,0 +1,175 @@ +import base64 + +import nonebot_plugin_alconna as alc +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_alconna.uniseg import Reference +from nonebot_plugin_alconna.uniseg.segment import CustomNode, Video + +from zhenxun.services.log import logger + + +def uni_segment_to_v11_segment_dict( + seg: alc.Segment, depth: int = 0 +) -> dict | list[dict] | None: + """UniSeg段转V11字典""" + if isinstance(seg, alc.Text): + return {"type": "text", "data": {"text": seg.text}} + elif isinstance(seg, alc.Image): + if getattr(seg, "url", None): + return { + "type": "image", + "data": {"file": seg.url}, + } + elif getattr(seg, "raw", None): + raw_data = seg.raw + if isinstance(raw_data, str): + if len(raw_data) >= 9 and raw_data[:9] == "base64://": + return {"type": "image", "data": {"file": raw_data}} + elif isinstance(raw_data, bytes): + b64_str = base64.b64encode(raw_data).decode() + return {"type": "image", "data": {"file": f"base64://{b64_str}"}} + else: + logger.warning(f"无法处理 Image.raw 的类型: {type(raw_data)}", "广播") + elif getattr(seg, "path", None): + logger.warning( + f"在合并转发中使用了本地图片路径,可能无法显示: {seg.path}", "广播" + ) + return {"type": "image", "data": {"file": f"file:///{seg.path}"}} + else: + logger.warning(f"alc.Image 缺少有效数据,无法转换为 V11 段: {seg}", "广播") + elif isinstance(seg, alc.At): + return {"type": "at", "data": {"qq": seg.target}} + elif isinstance(seg, alc.AtAll): + return {"type": "at", "data": {"qq": "all"}} + elif isinstance(seg, Video): + if getattr(seg, "url", None): + return { + "type": "video", + "data": {"file": seg.url}, + } + elif getattr(seg, "raw", None): + raw_data = seg.raw + if isinstance(raw_data, str): + if len(raw_data) >= 9 and raw_data[:9] == "base64://": + return {"type": "video", "data": {"file": raw_data}} + elif isinstance(raw_data, bytes): + b64_str = base64.b64encode(raw_data).decode() + return {"type": "video", "data": {"file": f"base64://{b64_str}"}} + else: + logger.warning(f"无法处理 Video.raw 的类型: {type(raw_data)}", "广播") + elif getattr(seg, "path", None): + logger.warning( + f"在合并转发中使用了本地视频路径,可能无法显示: {seg.path}", "广播" + ) + return {"type": "video", "data": {"file": f"file:///{seg.path}"}} + else: + logger.warning(f"Video 缺少有效数据,无法转换为 V11 段: {seg}", "广播") + elif isinstance(seg, Reference) and getattr(seg, "nodes", None): + if depth >= 3: + logger.warning( + f"嵌套转发深度超过限制 (depth={depth}),不再继续解析", "广播" + ) + return {"type": "text", "data": {"text": "[嵌套转发层数过多,内容已省略]"}} + + nested_v11_content_list = [] + nodes_list = getattr(seg, "nodes", []) + for node in nodes_list: + if isinstance(node, CustomNode): + node_v11_content = [] + if isinstance(node.content, UniMessage): + for nested_seg in node.content: + converted_dict = uni_segment_to_v11_segment_dict( + nested_seg, depth + 1 + ) + if isinstance(converted_dict, list): + node_v11_content.extend(converted_dict) + elif converted_dict: + node_v11_content.append(converted_dict) + elif isinstance(node.content, str): + node_v11_content.append( + {"type": "text", "data": {"text": node.content}} + ) + if node_v11_content: + separator = { + "type": "text", + "data": { + "text": f"\n--- 来自 {node.name} ({node.uid}) 的消息 ---\n" + }, + } + nested_v11_content_list.insert(0, separator) + nested_v11_content_list.extend(node_v11_content) + nested_v11_content_list.append( + {"type": "text", "data": {"text": "\n---\n"}} + ) + + return nested_v11_content_list + + else: + logger.warning(f"广播时跳过不支持的 UniSeg 段类型: {type(seg)}", "广播") + return None + + +def uni_message_to_v11_list_of_dicts(uni_msg: UniMessage | str | list) -> list[dict]: + """UniMessage转V11字典列表""" + try: + if isinstance(uni_msg, str): + return [{"type": "text", "data": {"text": uni_msg}}] + + if isinstance(uni_msg, list): + if not uni_msg: + return [] + + if all(isinstance(item, str) for item in uni_msg): + return [{"type": "text", "data": {"text": item}} for item in uni_msg] + + result = [] + for item in uni_msg: + if hasattr(item, "__iter__") and not isinstance(item, str | bytes): + result.extend(uni_message_to_v11_list_of_dicts(item)) + elif hasattr(item, "text") and not isinstance(item, str | bytes): + text_value = getattr(item, "text", "") + result.append({"type": "text", "data": {"text": str(text_value)}}) + elif hasattr(item, "url") and not isinstance(item, str | bytes): + url_value = getattr(item, "url", "") + if isinstance(item, Video): + result.append( + {"type": "video", "data": {"file": str(url_value)}} + ) + else: + result.append( + {"type": "image", "data": {"file": str(url_value)}} + ) + else: + try: + result.append({"type": "text", "data": {"text": str(item)}}) + except Exception as e: + logger.warning(f"无法转换列表元素: {item}, 错误: {e}", "广播") + return result + except Exception as e: + logger.warning(f"消息转换过程中出错: {e}", "广播") + + return [{"type": "text", "data": {"text": str(uni_msg)}}] + + +def custom_nodes_to_v11_nodes(custom_nodes: list[CustomNode]) -> list[dict]: + """CustomNode列表转V11节点""" + v11_nodes = [] + for node in custom_nodes: + v11_content_list = uni_message_to_v11_list_of_dicts(node.content) + + if v11_content_list: + v11_nodes.append( + { + "type": "node", + "data": { + "user_id": str(node.uid), + "nickname": node.name, + "content": v11_content_list, + }, + } + ) + else: + logger.warning( + f"CustomNode (uid={node.uid}) 内容转换后为空,跳过此节点", "广播" + ) + return v11_nodes diff --git a/zhenxun/builtin_plugins/web_ui/__init__.py b/zhenxun/builtin_plugins/web_ui/__init__.py index 90772bc5..619d56bf 100644 --- a/zhenxun/builtin_plugins/web_ui/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/__init__.py @@ -10,7 +10,9 @@ from zhenxun.configs.config import Config as gConfig from zhenxun.configs.utils import PluginExtraData, RegisterConfig from zhenxun.services.log import logger, logger_ from zhenxun.utils.enum import PluginType +from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from .api.configure import router as configure_router from .api.logs import router as ws_log_routes from .api.logs.log_manager import LOG_STORAGE from .api.menu import router as menu_router @@ -81,6 +83,7 @@ BaseApiRouter.include_router(database_router) BaseApiRouter.include_router(plugin_router) BaseApiRouter.include_router(system_router) BaseApiRouter.include_router(menu_router) +BaseApiRouter.include_router(configure_router) WsApiRouter = APIRouter(prefix="/zhenxun/socket") @@ -89,7 +92,7 @@ WsApiRouter.include_router(status_routes) WsApiRouter.include_router(chat_routes) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=0) async def _(): try: # 存储任务引用的列表,防止任务被垃圾回收 diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py new file mode 100644 index 00000000..0ecde197 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py @@ -0,0 +1,133 @@ +import asyncio +import os +from pathlib import Path +import re +import subprocess +import sys +import time + +from fastapi import APIRouter +from fastapi.responses import JSONResponse +import nonebot + +from zhenxun.configs.config import BotConfig, Config + +from ...base_model import Result +from .data_source import test_db_connection +from .model import Setting + +router = APIRouter(prefix="/configure") + +driver = nonebot.get_driver() + +port = driver.config.port + +BAT_FILE = Path() / "win启动.bat" + +FILE_NAME = ".configure_restart" + + +@router.post( + "/set_configure", + response_model=Result, + response_class=JSONResponse, + description="设置基础配置", +) +async def _(setting: Setting) -> Result: + global port + password = Config.get_config("web-ui", "password") + if password or BotConfig.db_url: + return Result.fail("配置已存在,请先删除DB_URL内容和前端密码再进行设置。") + env_file = Path() / ".env.dev" + if not env_file.exists(): + return Result.fail("配置文件.env.dev不存在。") + env_text = env_file.read_text(encoding="utf-8") + if setting.db_url: + if setting.db_url.startswith("sqlite"): + base_dir = Path().resolve() + # 清理和验证数据库路径 + db_path_str = setting.db_url.split(":")[-1].strip() + # 移除任何可能的路径遍历尝试 + db_path_str = re.sub(r"[\\/]\.\.[\\/]", "", db_path_str) + # 规范化路径 + db_path = Path(db_path_str).resolve() + parent_path = db_path.parent + + # 验证路径是否在项目根目录内 + try: + if not parent_path.absolute().is_relative_to(base_dir): + return Result.fail("数据库路径不在项目根目录内。") + except ValueError: + return Result.fail("无效的数据库路径。") + + # 创建目录 + try: + parent_path.mkdir(parents=True, exist_ok=True) + except Exception as e: + return Result.fail(f"创建数据库目录失败: {e!s}") + + env_text = env_text.replace('DB_URL = ""', f'DB_URL = "{setting.db_url}"') + if setting.superusers: + superusers = ", ".join([f'"{s}"' for s in setting.superusers]) + env_text = re.sub(r"SUPERUSERS=\[.*?\]", f"SUPERUSERS=[{superusers}]", env_text) + if setting.host: + env_text = env_text.replace("HOST = 127.0.0.1", f"HOST = {setting.host}") + if setting.port: + env_text = env_text.replace("PORT = 8080", f"PORT = {setting.port}") + port = setting.port + if setting.username: + Config.set_config("web-ui", "username", setting.username) + Config.set_config("web-ui", "password", setting.password, True) + env_file.write_text(env_text, encoding="utf-8") + if BAT_FILE.exists(): + for file in os.listdir(Path()): + if file.startswith(FILE_NAME): + Path(file).unlink() + flag_file = Path() / f"{FILE_NAME}_{int(time.time())}" + flag_file.touch() + return Result.ok(BAT_FILE.exists(), info="设置成功,请重启真寻以完成配置!") + + +@router.get( + "/test_db", + response_model=Result, + response_class=JSONResponse, + description="设置基础配置", +) +async def _(db_url: str) -> Result: + result = await test_db_connection(db_url) + if isinstance(result, str): + return Result.fail(result) + return Result.ok(info="数据库连接成功!") + + +async def run_restart_command(bat_path: Path, port: int): + """在后台执行重启命令""" + await asyncio.sleep(1) # 确保 FastAPI 已返回响应 + subprocess.Popen([bat_path, str(port)], shell=True) # noqa: ASYNC220 + sys.exit(0) # 退出当前进程 + + +@router.post( + "/restart", + response_model=Result, + response_class=JSONResponse, + description="重启", +) +async def _() -> Result: + if not BAT_FILE.exists(): + return Result.fail("自动重启仅支持意见整合包,请尝试手动重启") + flag_file = next( + (Path() / file for file in os.listdir(Path()) if file.startswith(FILE_NAME)), + None, + ) + if not flag_file or not flag_file.exists(): + return Result.fail("重启标志文件不存在...") + set_time = flag_file.name.split("_")[-1] + if time.time() - float(set_time) > 10 * 60: + return Result.fail("重启标志文件已过期,请重新设置配置。") + flag_file.unlink() + try: + return Result.ok(info="执行重启命令成功") + finally: + asyncio.create_task(run_restart_command(BAT_FILE, port)) # noqa: RUF006 diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py b/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py new file mode 100644 index 00000000..ad8c73c9 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/data_source.py @@ -0,0 +1,18 @@ +from tortoise import Tortoise + + +async def test_db_connection(db_url: str) -> bool | str: + try: + # 初始化 Tortoise ORM + await Tortoise.init( + db_url=db_url, + modules={"models": ["__main__"]}, # 这里不需要实际模型 + ) + # 测试连接 + await Tortoise.get_connection("default").execute_query("SELECT 1") + return True + except Exception as e: + return str(e) + finally: + # 关闭连接 + await Tortoise.close_connections() diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/model.py b/zhenxun/builtin_plugins/web_ui/api/configure/model.py new file mode 100644 index 00000000..4a6b3486 --- /dev/null +++ b/zhenxun/builtin_plugins/web_ui/api/configure/model.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel + + +class Setting(BaseModel): + superusers: list[str] + """超级用户列表""" + db_url: str + """数据库地址""" + host: str + """主机地址""" + port: int + """端口""" + username: str + """前端用户名""" + password: str + """前端密码""" diff --git a/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py b/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py index 14f5c928..e54bf9e5 100644 --- a/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/menu/data_source.py @@ -5,54 +5,63 @@ from zhenxun.services.log import logger from .model import MenuData, MenuItem +default_menus = [ + MenuItem( + name="仪表盘", + module="dashboard", + router="/dashboard", + icon="dashboard", + default=True, + ), + MenuItem( + name="真寻控制台", + module="command", + router="/command", + icon="command", + ), + MenuItem(name="插件列表", module="plugin", router="/plugin", icon="plugin"), + MenuItem(name="插件商店", module="store", router="/store", icon="store"), + MenuItem(name="好友/群组", module="manage", router="/manage", icon="user"), + MenuItem( + name="数据库管理", + module="database", + router="/database", + icon="database", + ), + MenuItem(name="系统信息", module="system", router="/system", icon="system"), + MenuItem(name="关于我们", module="about", router="/about", icon="about"), +] -class MenuManage: + +class MenuManager: def __init__(self) -> None: self.file = DATA_PATH / "web_ui" / "menu.json" self.menu = [] if self.file.exists(): try: + temp_menu = [] self.menu = json.load(self.file.open(encoding="utf8")) + self_menu_name = [menu["name"] for menu in self.menu] + for module in [m.module for m in default_menus]: + if module in self_menu_name: + temp_menu.append( + MenuItem( + **next(m for m in self.menu if m["module"] == module) + ) + ) + else: + temp_menu.append(self.__get_menu_model(module)) + self.menu = temp_menu except Exception as e: logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e) if not self.menu: - self.menu = [ - MenuItem( - name="仪表盘", - module="dashboard", - router="/dashboard", - icon="dashboard", - default=True, - ), - MenuItem( - name="真寻控制台", - module="command", - router="/command", - icon="command", - ), - MenuItem( - name="插件列表", module="plugin", router="/plugin", icon="plugin" - ), - MenuItem( - name="插件商店", module="store", router="/store", icon="store" - ), - MenuItem( - name="好友/群组", module="manage", router="/manage", icon="user" - ), - MenuItem( - name="数据库管理", - module="database", - router="/database", - icon="database", - ), - MenuItem( - name="文件管理", module="system", router="/system", icon="system" - ), - MenuItem( - name="关于我们", module="about", router="/about", icon="about" - ), - ] - self.save() + self.menu = default_menus + self.save() + + def __get_menu_model(self, module: str): + return default_menus[ + next(i for i, m in enumerate(default_menus) if m.module == module) + ] def get_menus(self): return MenuData(menus=self.menu) @@ -64,4 +73,4 @@ class MenuManage: json.dump(temp, f, ensure_ascii=False, indent=4) -menu_manage = MenuManage() +menu_manage = MenuManager() diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py index 6c312db3..87011c93 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/dashboard/data_source.py @@ -13,6 +13,7 @@ from zhenxun.models.bot_connect_log import BotConnectLog from zhenxun.models.chat_history import ChatHistory from zhenxun.models.statistics import Statistics from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils from ....base_model import BaseResultModel, QueryModel @@ -31,7 +32,7 @@ driver: Driver = nonebot.get_driver() CONNECT_TIME = 0 -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): global CONNECT_TIME CONNECT_TIME = int(time.time()) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py index b963e291..91fbc5c0 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/database/__init__.py @@ -8,6 +8,7 @@ from zhenxun.configs.config import BotConfig from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from ....base_model import BaseResultModel, QueryModel, Result from ....utils import authentication @@ -21,7 +22,7 @@ router = APIRouter(prefix="/database") driver: Driver = nonebot.get_driver() -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): for plugin in nonebot.get_loaded_plugins(): module = plugin.name diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py index e87647dd..2a783b22 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/main/data_source.py @@ -119,7 +119,7 @@ class ApiDataSource: (await PlatformUtils.get_friend_list(select_bot.bot))[0] ) except Exception as e: - logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e) + logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e) select_bot.group_count = 0 select_bot.friend_count = 0 select_bot.status = await BotConsole.get_bot_status(select_bot.self_id) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py index ffcd05be..7617be09 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from zhenxun.utils._build_image import BuildImage from ....base_model import Result, SystemFolderSize -from ....utils import authentication, get_system_disk +from ....utils import authentication, get_system_disk, validate_path from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile router = APIRouter(prefix="/system") @@ -25,22 +25,30 @@ IMAGE_TYPE = ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg"] description="获取文件列表", ) async def _(path: str | None = None) -> Result[list[DirFile]]: - base_path = Path(path) if path else Path() - data_list = [] - for file in os.listdir(base_path): - file_path = base_path / file - is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE) - data_list.append( - DirFile( - is_file=not file_path.is_dir(), - is_image=is_image, - name=file, - parent=path, - size=None if file_path.is_dir() else file_path.stat().st_size, - mtime=file_path.stat().st_mtime, + try: + base_path, error = validate_path(path) + if error: + return Result.fail(error) + if not base_path: + return Result.fail("无效的路径") + data_list = [] + for file in os.listdir(base_path): + file_path = base_path / file + is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE) + data_list.append( + DirFile( + is_file=not file_path.is_dir(), + is_image=is_image, + name=file, + parent=path, + size=None if file_path.is_dir() else file_path.stat().st_size, + mtime=file_path.stat().st_mtime, + ) ) - ) - return Result.ok(data_list) + sorted(data_list, key=lambda f: f.name) + return Result.ok(data_list) + except Exception as e: + return Result.fail(f"获取文件列表失败: {e!s}") @router.get( @@ -62,8 +70,12 @@ async def _(full_path: str | None = None) -> Result[list[SystemFolderSize]]: description="删除文件", ) async def _(param: DeleteFile) -> Result: - path = Path(param.full_path) - if not path or not path.exists(): + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") + if not path.exists(): return Result.warning_("文件不存在...") try: path.unlink() @@ -80,8 +92,12 @@ async def _(param: DeleteFile) -> Result: description="删除文件夹", ) async def _(param: DeleteFile) -> Result: - path = Path(param.full_path) - if not path or not path.exists() or path.is_file(): + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") + if not path.exists() or path.is_file(): return Result.warning_("文件夹不存在...") try: shutil.rmtree(path.absolute()) @@ -98,10 +114,14 @@ async def _(param: DeleteFile) -> Result: description="重命名文件", ) async def _(param: RenameFile) -> Result: - path = ( - (Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) - ) - if not path or not path.exists(): + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.old_name) if param.parent else Path(param.old_name) + if not path.exists(): return Result.warning_("文件不存在...") try: path.rename(path.parent / param.name) @@ -118,10 +138,14 @@ async def _(param: RenameFile) -> Result: description="重命名文件夹", ) async def _(param: RenameFile) -> Result: - path = ( - (Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) - ) - if not path or not path.exists() or path.is_file(): + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.old_name) if param.parent else Path(param.old_name) + if not path.exists() or path.is_file(): return Result.warning_("文件夹不存在...") try: new_path = path.parent / param.name @@ -139,7 +163,13 @@ async def _(param: RenameFile) -> Result: description="新建文件", ) async def _(param: AddFile) -> Result: - path = (Path(param.parent) / param.name) if param.parent else Path(param.name) + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.name) if param.parent else Path(param.name) if path.exists(): return Result.warning_("文件已存在...") try: @@ -157,7 +187,13 @@ async def _(param: AddFile) -> Result: description="新建文件夹", ) async def _(param: AddFile) -> Result: - path = (Path(param.parent) / param.name) if param.parent else Path(param.name) + parent_path, error = validate_path(param.parent) + if error: + return Result.fail(error) + if not parent_path: + return Result.fail("无效的路径") + + path = (parent_path / param.name) if param.parent else Path(param.name) if path.exists(): return Result.warning_("文件夹已存在...") try: @@ -175,7 +211,11 @@ async def _(param: AddFile) -> Result: description="读取文件", ) async def _(full_path: str) -> Result: - path = Path(full_path) + path, error = validate_path(full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") if not path.exists(): return Result.warning_("文件不存在...") try: @@ -193,9 +233,13 @@ async def _(full_path: str) -> Result: description="读取文件", ) async def _(param: SaveFile) -> Result[str]: - path = Path(param.full_path) + path, error = validate_path(param.full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") try: - async with aiofiles.open(path, "w", encoding="utf-8") as f: + async with aiofiles.open(str(path), "w", encoding="utf-8") as f: await f.write(param.content) return Result.ok("更新成功!") except Exception as e: @@ -210,7 +254,11 @@ async def _(param: SaveFile) -> Result[str]: description="读取图片base64", ) async def _(full_path: str) -> Result[str]: - path = Path(full_path) + path, error = validate_path(full_path) + if error: + return Result.fail(error) + if not path: + return Result.fail("无效的路径") if not path.exists(): return Result.warning_("文件不存在...") try: diff --git a/zhenxun/builtin_plugins/web_ui/config.py b/zhenxun/builtin_plugins/web_ui/config.py index bddcb062..4a88aad9 100644 --- a/zhenxun/builtin_plugins/web_ui/config.py +++ b/zhenxun/builtin_plugins/web_ui/config.py @@ -1,6 +1,12 @@ +import sys + from fastapi.middleware.cors import CORSMiddleware import nonebot -from strenum import StrEnum + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH diff --git a/zhenxun/builtin_plugins/web_ui/public/data_source.py b/zhenxun/builtin_plugins/web_ui/public/data_source.py index 9f5a657e..51b29533 100644 --- a/zhenxun/builtin_plugins/web_ui/public/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/public/data_source.py @@ -18,6 +18,7 @@ async def update_webui_assets(): download_url = await GithubUtils.parse_github_url( WEBUI_DIST_GITHUB_URL ).get_archive_download_urls() + logger.info("开始下载 webui_assets 资源...", COMMAND_NAME) if await AsyncHttpx.download_file( download_url, webui_assets_path, follow_redirects=True ): diff --git a/zhenxun/builtin_plugins/web_ui/utils.py b/zhenxun/builtin_plugins/web_ui/utils.py index a7e22a07..84459114 100644 --- a/zhenxun/builtin_plugins/web_ui/utils.py +++ b/zhenxun/builtin_plugins/web_ui/utils.py @@ -2,6 +2,7 @@ import contextlib from datetime import datetime, timedelta, timezone import os from pathlib import Path +import re from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer @@ -28,6 +29,45 @@ if token_file.exists(): token_data = json.load(open(token_file, encoding="utf8")) +def validate_path(path_str: str | None) -> tuple[Path | None, str | None]: + """验证路径是否安全 + + 参数: + path_str: 用户输入的路径 + + 返回: + tuple[Path | None, str | None]: (验证后的路径, 错误信息) + """ + try: + if not path_str: + return Path().resolve(), None + + # 1. 移除任何可能的路径遍历尝试 + path_str = re.sub(r"[\\/]\.\.[\\/]", "", path_str) + + # 2. 规范化路径并转换为绝对路径 + path = Path(path_str).resolve() + + # 3. 获取项目根目录 + root_dir = Path().resolve() + + # 4. 验证路径是否在项目根目录内 + try: + if not path.is_relative_to(root_dir): + return None, "访问路径超出允许范围" + except ValueError: + return None, "无效的路径格式" + + # 5. 验证路径是否包含任何危险字符 + if any(c in str(path) for c in ["..", "~", "*", "?", ">", "<", "|", '"']): + return None, "路径包含非法字符" + + # 6. 验证路径长度是否合理 + return (None, "路径长度超出限制") if len(str(path)) > 4096 else (path, None) + except Exception as e: + return None, f"路径验证失败: {e!s}" + + GROUP_HELP_PATH = DATA_PATH / "group_help" SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png" SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png" diff --git a/zhenxun/models/bot_console.py b/zhenxun/models/bot_console.py index d329c551..cf1faaf2 100644 --- a/zhenxun/models/bot_console.py +++ b/zhenxun/models/bot_console.py @@ -29,7 +29,7 @@ class BotConsole(Model): class Meta: # pyright: ignore [reportIncompatibleVariableOverride] table = "bot_console" table_description = "Bot数据表" - + cache_type = CacheType.BOT @staticmethod diff --git a/zhenxun/models/level_user.py b/zhenxun/models/level_user.py index 0a926e6a..88d69274 100644 --- a/zhenxun/models/level_user.py +++ b/zhenxun/models/level_user.py @@ -3,7 +3,6 @@ from tortoise import fields from zhenxun.services.db_context import Model from zhenxun.utils.enum import CacheType - class LevelUser(Model): id = fields.IntField(pk=True, generated=True, auto_increment=True) """自增id""" @@ -56,7 +55,7 @@ class LevelUser(Model): level: 权限等级 group_flag: 是否被自动更新刷新权限 0:是, 1:否. """ - if await cls.exists(user_id=user_id, group_id=group_id, user_level=level): + if await cls.exists(user_id=user_id, group_id=group_id, level=level): # 权限相同时跳过 return await cls.update_or_create( diff --git a/zhenxun/models/plugin_info.py b/zhenxun/models/plugin_info.py index aeecc71b..793b31d5 100644 --- a/zhenxun/models/plugin_info.py +++ b/zhenxun/models/plugin_info.py @@ -58,7 +58,7 @@ class PluginInfo(Model): class Meta: # pyright: ignore [reportIncompatibleVariableOverride] table = "plugin_info" table_description = "插件基本信息" - + cache_type = CacheType.PLUGINS @classmethod diff --git a/zhenxun/models/schedule_info.py b/zhenxun/models/schedule_info.py deleted file mode 100644 index c7583078..00000000 --- a/zhenxun/models/schedule_info.py +++ /dev/null @@ -1,38 +0,0 @@ -from tortoise import fields - -from zhenxun.services.db_context import Model - - -class ScheduleInfo(Model): - id = fields.IntField(pk=True, generated=True, auto_increment=True) - """自增id""" - bot_id = fields.CharField( - 255, null=True, default=None, description="任务关联的Bot ID" - ) - """任务关联的Bot ID""" - plugin_name = fields.CharField(255, description="插件模块名") - """插件模块名""" - group_id = fields.CharField( - 255, - null=True, - description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务", - ) - """群组ID, 为空表示全局任务""" - trigger_type = fields.CharField( - max_length=20, default="cron", description="触发器类型 (cron, interval, date)" - ) - """触发器类型 (cron, interval, date)""" - trigger_config = fields.JSONField(description="触发器具体配置") - """触发器具体配置""" - job_kwargs = fields.JSONField( - default=dict, description="传递给任务函数的额外关键字参数" - ) - """传递给任务函数的额外关键字参数""" - is_enabled = fields.BooleanField(default=True, description="是否启用") - """是否启用""" - create_time = fields.DatetimeField(auto_now_add=True) - """创建时间""" - - class Meta: # pyright: ignore [reportIncompatibleVariableOverride] - table = "schedule_info" - table_description = "通用定时任务表" diff --git a/zhenxun/plugins/__init__.py b/zhenxun/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 4957029f..8c5fa820 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -12,7 +12,7 @@ from aiocache.serializers import JsonSerializer import nonebot from nonebot.compat import model_dump from nonebot.utils import is_coroutine_callable -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr from zhenxun.services.log import logger @@ -121,7 +121,7 @@ class CacheData(BaseModel): lazy_load: bool = True # 默认延迟加载 result_model: type | None = None _keys: set[str] = set() # 存储所有缓存键 - _cache: BaseCache | AioCache + _cache: BaseCache | AioCache = PrivateAttr() class Config: arbitrary_types_allowed = True diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 928334e3..85aee620 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -2,7 +2,6 @@ from asyncio import Semaphore from collections.abc import Iterable from typing import Any, ClassVar from typing_extensions import Self -from urllib.parse import urlparse import nonebot from nonebot.utils import is_coroutine_callable @@ -32,17 +31,20 @@ def _(): global CACHE_FLAG CACHE_FLAG = True +driver = nonebot.get_driver() + class Model(TortoiseModel): """ 自动添加模块 + + Args: + Model_: Model """ - sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {} - + def __init_subclass__(cls, **kwargs): - if cls.__module__ not in MODELS: - MODELS.append(cls.__module__) + MODELS.append(cls.__module__) if func := getattr(cls, "_run_script", None): SCRIPT_METHOD.append((cls.__module__, func)) @@ -169,7 +171,7 @@ class Model(TortoiseModel): await CacheRoot.reload(cache_type) -class DbUrlIsNode(HookPriorityException): +class DbUrlMissing(Exception): """ 数据库链接地址为空 """ @@ -188,6 +190,7 @@ class DbConnectError(Exception): @PriorityLifecycle.on_startup(priority=1) async def init(): if not BotConfig.db_url: + # raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...") error = f""" ********************************************************************** 🌟 **************************** 配置为空 ************************* 🌟 @@ -196,74 +199,13 @@ async def init(): *********************************************************************** *********************************************************************** """ - raise DbUrlIsNode("\n" + error.strip()) + raise DbUrlMissing("\n" + error.strip()) try: - db_url = BotConfig.db_url - url_scheme = db_url.split(":", 1)[0] - config = None - - if url_scheme in ("postgres", "postgresql"): - # 解析 db_url - url = urlparse(db_url) - credentials = { - "host": url.hostname, - "port": url.port or 5432, - "user": url.username, - "password": url.password, - "database": url.path.lstrip("/"), - "minsize": 1, - "maxsize": 50, - } - config = { - "connections": { - "default": { - "engine": "tortoise.backends.asyncpg", - "credentials": credentials, - } - }, - "apps": { - "models": { - "models": MODELS, - "default_connection": "default", - } - }, - } - elif url_scheme in ("mysql", "mysql+aiomysql"): - url = urlparse(db_url) - credentials = { - "host": url.hostname, - "port": url.port or 3306, - "user": url.username, - "password": url.password, - "database": url.path.lstrip("/"), - "minsize": 1, - "maxsize": 50, - } - config = { - "connections": { - "default": { - "engine": "tortoise.backends.mysql", - "credentials": credentials, - } - }, - "apps": { - "models": { - "models": MODELS, - "default_connection": "default", - } - }, - } - else: - # sqlite 或其它,直接用 db_url - await Tortoise.init( - db_url=db_url, - modules={"models": MODELS}, - timezone="Asia/Shanghai", - ) - - if config: - await Tortoise.init(config=config) - + await Tortoise.init( + db_url=BotConfig.db_url, + modules={"models": MODELS}, + timezone="Asia/Shanghai", + ) if SCRIPT_METHOD: db = Tortoise.get_connection("default") logger.debug( @@ -282,6 +224,7 @@ async def init(): logger.debug(f"执行SQL: {sql}") try: await db.execute_query_dict(sql) + # await TestSQL.raw(sql) except Exception as e: logger.debug(f"执行SQL: {sql} 错误...", e=e) if sql_list: @@ -293,4 +236,4 @@ async def init(): async def disconnect(): - await connections.close_all() \ No newline at end of file + await connections.close_all() diff --git a/zhenxun/services/plugin_init.py b/zhenxun/services/plugin_init.py index 159e042c..a622a9e8 100644 --- a/zhenxun/services/plugin_init.py +++ b/zhenxun/services/plugin_init.py @@ -6,6 +6,7 @@ from nonebot.utils import is_coroutine_callable from pydantic import BaseModel from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle driver = nonebot.get_driver() @@ -100,6 +101,6 @@ class PluginInitManager: logger.error(f"执行: {module_path}:remove 失败", e=e) -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): await PluginInitManager.install_all() diff --git a/zhenxun/utils/_build_mat.py b/zhenxun/utils/_build_mat.py index de73e69d..a3de3087 100644 --- a/zhenxun/utils/_build_mat.py +++ b/zhenxun/utils/_build_mat.py @@ -1,12 +1,17 @@ from io import BytesIO from pathlib import Path import random +import sys from pydantic import BaseModel, Field -from strenum import StrEnum from ._build_image import BuildImage +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + class MatType(StrEnum): LINE = "LINE" diff --git a/zhenxun/utils/_image_template.py b/zhenxun/utils/_image_template.py index 7f27db76..c7678b2f 100644 --- a/zhenxun/utils/_image_template.py +++ b/zhenxun/utils/_image_template.py @@ -3,9 +3,12 @@ from io import BytesIO from pathlib import Path import random +from nonebot_plugin_htmlrender import md_to_pic, template_to_pic from PIL.ImageFont import FreeTypeFont from pydantic import BaseModel +from zhenxun.configs.path_config import TEMPLATE_PATH + from ._build_image import BuildImage @@ -283,3 +286,191 @@ class ImageTemplate: width = max(width, w) height += h return width, height + + +class MarkdownTable: + def __init__(self, headers: list[str], rows: list[list[str]]): + self.headers = headers + self.rows = rows + + def to_markdown(self) -> str: + """将表格转换为Markdown格式""" + header_row = "| " + " | ".join(self.headers) + " |" + separator_row = "| " + " | ".join(["---"] * len(self.headers)) + " |" + data_rows = "\n".join( + "| " + " | ".join(map(str, row)) + " |" for row in self.rows + ) + return f"{header_row}\n{separator_row}\n{data_rows}" + + +class Markdown: + def __init__(self, data: list[str] | None = None): + if data is None: + data = [] + self._data = data + + def text(self, text: str) -> "Markdown": + """添加Markdown文本""" + self._data.append(text) + return self + + def head(self, text: str, level: int = 1) -> "Markdown": + """添加Markdown标题""" + if level < 1 or level > 6: + raise ValueError("标题级别必须在1到6之间") + self._data.append(f"{'#' * level} {text}") + return self + + def image(self, content: str | Path, add_empty_line: bool = True) -> "Markdown": + """添加Markdown图片 + + 参数: + content: 图片内容,可以是url地址,图片路径或base64字符串. + add_empty_line: 默认添加换行. + + 返回: + Markdown: Markdown + """ + if isinstance(content, Path): + content = str(content.absolute()) + if content.startswith("base64"): + content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}" + self._data.append(f"![image]({content})") + if add_empty_line: + self._add_empty_line() + return self + + def quote(self, text: str | list[str]) -> "Markdown": + """添加Markdown引用文本 + + 参数: + text: 引用文本内容,可以是字符串或字符串列表. + 如果是列表,则每个元素都会被单独引用。 + + 返回: + Markdown: Markdown + """ + if isinstance(text, str): + self._data.append(f"> {text}") + elif isinstance(text, list): + for t in text: + self._data.append(f"> {t}") + self._add_empty_line() + return self + + def code(self, code: str, language: str = "python") -> "Markdown": + """添加Markdown代码块""" + self._data.append(f"```{language}\n{code}\n```") + return self + + def table(self, headers: list[str], rows: list[list[str]]) -> "Markdown": + """添加Markdown表格""" + table = MarkdownTable(headers, rows) + self._data.append(table.to_markdown()) + return self + + def list(self, items: list[str | list[str]]) -> "Markdown": + """添加Markdown列表""" + self._add_empty_line() + _text = "\n".join( + f"- {item}" + if isinstance(item, str) + else "\n".join(f"- {sub_item}" for sub_item in item) + for item in items + ) + self._data.append(_text) + return self + + def _add_empty_line(self): + """添加空行""" + self._data.append("") + + async def build(self, width: int = 800, css_path: Path | None = None) -> bytes: + """构建Markdown文本""" + if css_path is not None: + return await md_to_pic( + md="\n".join(self._data), width=width, css_path=str(css_path.absolute()) + ) + return await md_to_pic(md="\n".join(self._data), width=width) + + +class Notebook: + def __init__(self, data: list[dict] | None = None): + self._data = data if data is not None else [] + + def text(self, text: str) -> "Notebook": + """添加Notebook文本""" + self._data.append({"type": "paragraph", "text": text}) + return self + + def head(self, text: str, level: int = 1) -> "Notebook": + """添加Notebook标题""" + if not 1 <= level <= 4: + raise ValueError("标题级别必须在1-4之间") + self._data.append({"type": "heading", "text": text, "level": level}) + return self + + def image( + self, + content: str | Path, + caption: str | None = None, + ) -> "Notebook": + """添加Notebook图片 + + 参数: + content: 图片内容,可以是url地址,图片路径或base64字符串. + caption: 图片说明. + + 返回: + Notebook: Notebook + """ + if isinstance(content, Path): + content = str(content.absolute()) + if content.startswith("base64"): + content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}" + self._data.append({"type": "image", "src": content, "caption": caption}) + return self + + def quote(self, text: str | list[str]) -> "Notebook": + """添加Notebook引用文本 + + 参数: + text: 引用文本内容,可以是字符串或字符串列表. + 如果是列表,则每个元素都会被单独引用。 + + 返回: + Notebook: Notebook + """ + if isinstance(text, str): + self._data.append({"type": "blockquote", "text": text}) + elif isinstance(text, list): + for t in text: + self._data.append({"type": "blockquote", "text": text}) + return self + + def code(self, code: str, language: str = "python") -> "Notebook": + """添加Notebook代码块""" + self._data.append({"type": "code", "code": code, "language": language}) + return self + + def list(self, items: list[str], ordered: bool = False) -> "Notebook": + """添加Notebook列表""" + self._data.append({"type": "list", "data": items, "ordered": ordered}) + return self + + def add_divider(self) -> None: + """添加分隔线""" + self._data.append({"type": "divider"}) + + async def build(self) -> bytes: + """构建Notebook""" + return await template_to_pic( + template_path=str((TEMPLATE_PATH / "notebook").absolute()), + template_name="main.html", + templates={"elements": self._data}, + pages={ + "viewport": {"width": 700, "height": 1000}, + "base_url": f"file://{TEMPLATE_PATH}", + }, + wait=2, + ) diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py index 889c6c5c..8ec925ec 100644 --- a/zhenxun/utils/exception.py +++ b/zhenxun/utils/exception.py @@ -10,8 +10,6 @@ class HookPriorityException(BaseException): return self.info - - class NotFoundError(Exception): """ 未发现 diff --git a/zhenxun/utils/github_utils/const.py b/zhenxun/utils/github_utils/const.py index 23effa4c..68fffad9 100644 --- a/zhenxun/utils/github_utils/const.py +++ b/zhenxun/utils/github_utils/const.py @@ -21,6 +21,9 @@ CACHED_API_TTL = 300 RAW_CONTENT_FORMAT = "https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" """raw content格式""" +GITEE_RAW_CONTENT_FORMAT = "https://gitee.com/{owner}/{repo}/raw/main/{path}" +"""gitee raw content格式""" + ARCHIVE_URL_FORMAT = "https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" """archive url格式""" diff --git a/zhenxun/utils/github_utils/func.py b/zhenxun/utils/github_utils/func.py index b568b5bd..db3afa03 100644 --- a/zhenxun/utils/github_utils/func.py +++ b/zhenxun/utils/github_utils/func.py @@ -4,6 +4,7 @@ from zhenxun.utils.http_utils import AsyncHttpx from .const import ( ARCHIVE_URL_FORMAT, + GITEE_RAW_CONTENT_FORMAT, RAW_CONTENT_FORMAT, RELEASE_ASSETS_FORMAT, RELEASE_SOURCE_FORMAT, @@ -21,15 +22,11 @@ async def __get_fastest_formats(formats: dict[str, str]) -> list[str]: async def get_fastest_raw_formats() -> list[str]: """获取最快的raw下载地址格式""" formats: dict[str, str] = { + "https://gitee.com/": GITEE_RAW_CONTENT_FORMAT, "https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT, "https://ghproxy.cc/": f"https://ghproxy.cc/{RAW_CONTENT_FORMAT}", - "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RAW_CONTENT_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}", "https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}", - #"https://raw.gitcode.com/": "https://raw.gitcode.com/qq_41605780/{repo}/raw/{branch}/{path}", # ✅ 新增 GitCode raw 格式 - #"https://raw.gitcode.com/": "https://raw.gitcode.com/ATTomato/{repo}/raw/{branch}/{path}" - #"https://raw.gitcode.com/": "https://raw.gitcode.com/{owner}/{repo}/raw/{branch}/{path}" - } return await __get_fastest_formats(formats) @@ -40,7 +37,6 @@ async def get_fastest_archive_formats() -> list[str]: formats: dict[str, str] = { "https://github.com/": ARCHIVE_URL_FORMAT, "https://ghproxy.cc/": f"https://ghproxy.cc/{ARCHIVE_URL_FORMAT}", - "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}", } return await __get_fastest_formats(formats) @@ -52,7 +48,6 @@ async def get_fastest_release_formats() -> list[str]: formats: dict[str, str] = { "https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT, "https://ghproxy.cc/": f"https://ghproxy.cc/{RELEASE_ASSETS_FORMAT}", - "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}", } return await __get_fastest_formats(formats) @@ -65,4 +60,4 @@ async def get_fastest_release_source_formats() -> list[str]: "https://codeload.github.com/": RELEASE_SOURCE_FORMAT, "https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}", } - return await __get_fastest_formats(formats) \ No newline at end of file + return await __get_fastest_formats(formats) diff --git a/zhenxun/utils/github_utils/models.py b/zhenxun/utils/github_utils/models.py index e3e5dfe3..fb690616 100644 --- a/zhenxun/utils/github_utils/models.py +++ b/zhenxun/utils/github_utils/models.py @@ -1,13 +1,18 @@ import contextlib +import sys from typing import Protocol from aiocache import cached from nonebot.compat import model_dump from pydantic import BaseModel, Field -from strenum import StrEnum from zhenxun.utils.http_utils import AsyncHttpx +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + from .const import ( CACHED_API_TTL, GIT_API_COMMIT_FORMAT, diff --git a/zhenxun/utils/html_template/__init__.py b/zhenxun/utils/html_template/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/zhenxun/utils/html_template/__init__.py @@ -0,0 +1 @@ + diff --git a/zhenxun/utils/html_template/component.py b/zhenxun/utils/html_template/component.py new file mode 100644 index 00000000..c23ed503 --- /dev/null +++ b/zhenxun/utils/html_template/component.py @@ -0,0 +1,36 @@ +from abc import ABC +from typing import Literal + +from pydantic import BaseModel + + +class Style(BaseModel): + """常用样式""" + + padding: str = "0px" + margin: str = "0px" + border: str = "0px" + border_radius: str = "0px" + text_align: Literal["left", "right", "center"] = "left" + color: str = "#000" + font_size: str = "16px" + + +class Component(ABC): + def __init__(self, background_color: str = "#fff", is_container: bool = False): + self.extra_style = [] + self.style = Style() + self.background_color = background_color + self.is_container = is_container + self.children = [] + + def add_child(self, child: "Component | str"): + self.children.append(child) + + def set_style(self, style: Style): + self.style = style + + def add_style(self, style: str): + self.extra_style.append(style) + + def to_html(self) -> str: ... diff --git a/zhenxun/utils/html_template/components/title.py b/zhenxun/utils/html_template/components/title.py new file mode 100644 index 00000000..860ad17e --- /dev/null +++ b/zhenxun/utils/html_template/components/title.py @@ -0,0 +1,15 @@ +from ..component import Component, Style +from ..container import Row + + +class Title(Component): + def __init__(self, text: str, color: str = "#000"): + self.text = text + self.color = color + + def build(self): + row = Row() + style = Style(font_size="36px", color=self.color) + row.set_style(style) + + # def diff --git a/zhenxun/utils/html_template/container.py b/zhenxun/utils/html_template/container.py new file mode 100644 index 00000000..3d5341c0 --- /dev/null +++ b/zhenxun/utils/html_template/container.py @@ -0,0 +1,31 @@ +from .component import Component + + +class Row(Component): + def __init__(self, background_color: str = "#fff"): + super().__init__(background_color, True) + + +class Col(Component): + def __init__(self, background_color: str = "#fff"): + super().__init__(background_color, True) + + +class Container(Component): + def __init__(self, background_color: str = "#fff"): + super().__init__(background_color, True) + self.children = [] + + +class GlobalOverview: + def __init__(self, name: str): + self.name = name + self.class_name: dict[str, list[str]] = {} + self.content = None + + def set_content(self, content: Container): + self.content = content + + def add_class(self, class_name: str, contents: list[str]): + """全局样式""" + self.class_name[class_name] = contents diff --git a/zhenxun/utils/manager/message_manager.py b/zhenxun/utils/manager/message_manager.py index e714c8d8..ee34369d 100644 --- a/zhenxun/utils/manager/message_manager.py +++ b/zhenxun/utils/manager/message_manager.py @@ -22,6 +22,4 @@ class MessageManager: @classmethod def get(cls, uid: str) -> list[str]: - if uid in cls.data: - return cls.data[uid] - return [] + return cls.data[uid] if uid in cls.data else [] diff --git a/zhenxun/utils/manager/schedule_manager.py b/zhenxun/utils/manager/schedule_manager.py deleted file mode 100644 index a3b21272..00000000 --- a/zhenxun/utils/manager/schedule_manager.py +++ /dev/null @@ -1,810 +0,0 @@ -import asyncio -from collections.abc import Callable, Coroutine -import copy -import inspect -import random -from typing import ClassVar - -import nonebot -from nonebot import get_bots -from nonebot_plugin_apscheduler import scheduler -from pydantic import BaseModel, ValidationError - -from zhenxun.configs.config import Config -from zhenxun.models.schedule_info import ScheduleInfo -from zhenxun.services.log import logger -from zhenxun.utils.common_utils import CommonUtils -from zhenxun.utils.manager.priority_manager import PriorityLifecycle -from zhenxun.utils.platform import PlatformUtils - -SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit" - - -class SchedulerManager: - """ - 一个通用的、持久化的定时任务管理器,供所有插件使用。 - """ - - _registered_tasks: ClassVar[ - dict[str, dict[str, Callable | type[BaseModel] | None]] - ] = {} - _JOB_PREFIX = "zhenxun_schedule_" - _running_tasks: ClassVar[set] = set() - - def register( - self, plugin_name: str, params_model: type[BaseModel] | None = None - ) -> Callable: - """ - 注册一个可调度的任务函数。 - 被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)` - - Args: - plugin_name (str): 插件的唯一名称 (通常是模块名)。 - params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类, - 用于定义和验证任务函数接受的额外参数。 - """ - - def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: - if plugin_name in self._registered_tasks: - logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。") - self._registered_tasks[plugin_name] = { - "func": func, - "model": params_model, - } - model_name = params_model.__name__ if params_model else "无" - logger.debug( - f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}" - ) - return func - - return decorator - - def get_registered_plugins(self) -> list[str]: - """获取所有已注册定时任务的插件列表。""" - return list(self._registered_tasks.keys()) - - def _get_job_id(self, schedule_id: int) -> str: - """根据数据库ID生成唯一的 APScheduler Job ID。""" - return f"{self._JOB_PREFIX}{schedule_id}" - - async def _execute_job(self, schedule_id: int): - """ - APScheduler 调度的入口函数。 - 根据 schedule_id 处理特定任务、所有群组任务或全局任务。 - """ - schedule = await ScheduleInfo.get_or_none(id=schedule_id) - if not schedule or not schedule.is_enabled: - logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。") - return - - plugin_name = schedule.plugin_name - - task_meta = self._registered_tasks.get(plugin_name) - if not task_meta: - logger.error( - f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。" - ) - schedule.is_enabled = False - await schedule.save(update_fields=["is_enabled"]) - self._remove_aps_job(schedule.id) - return - - try: - if schedule.bot_id: - bot = nonebot.get_bot(schedule.bot_id) - else: - bot = nonebot.get_bot() - logger.debug( - f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}" - ) - except KeyError: - logger.warning( - f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} " - f"不在线,本次执行跳过。" - ) - return - except ValueError: - logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。") - return - - if schedule.group_id == "__ALL_GROUPS__": - await self._execute_for_all_groups(schedule, task_meta, bot) - else: - await self._execute_for_single_target(schedule, task_meta, bot) - - async def _execute_for_all_groups( - self, schedule: ScheduleInfo, task_meta: dict, bot - ): - """为所有群组执行任务,并处理优先级覆盖。""" - plugin_name = schedule.plugin_name - - concurrency_limit = Config.get_config( - "SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5 - ) - if not isinstance(concurrency_limit, int) or concurrency_limit <= 0: - logger.warning( - f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。" - ) - concurrency_limit = 5 - - logger.info( - f"开始执行针对 [所有群组] 的任务 " - f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})," - f"并发限制: {concurrency_limit}" - ) - - all_gids = set() - try: - group_list, _ = await PlatformUtils.get_group_list(bot) - all_gids.update( - g.group_id for g in group_list if g.group_id and not g.channel_id - ) - except Exception as e: - logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e) - return - - specific_tasks_gids = set( - await ScheduleInfo.filter( - plugin_name=plugin_name, group_id__in=list(all_gids) - ).values_list("group_id", flat=True) - ) - - semaphore = asyncio.Semaphore(concurrency_limit) - - async def worker(gid: str): - """使用 Semaphore 包装单个群组的任务执行""" - async with semaphore: - temp_schedule = copy.deepcopy(schedule) - temp_schedule.group_id = gid - await self._execute_for_single_target(temp_schedule, task_meta, bot) - await asyncio.sleep(random.uniform(0.1, 0.5)) - - tasks_to_run = [] - for gid in all_gids: - if gid in specific_tasks_gids: - logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。") - continue - tasks_to_run.append(worker(gid)) - - if tasks_to_run: - await asyncio.gather(*tasks_to_run) - - async def _execute_for_single_target( - self, schedule: ScheduleInfo, task_meta: dict, bot - ): - """为单个目标(具体群组或全局)执行任务。""" - plugin_name = schedule.plugin_name - group_id = schedule.group_id - - try: - is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id) - if is_blocked: - target_desc = f"群 {group_id}" if group_id else "全局" - logger.info( - f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]" - "因功能被禁用而跳过执行。" - ) - return - - task_func = task_meta["func"] - job_kwargs = schedule.job_kwargs - if not isinstance(job_kwargs, dict): - logger.error( - f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}" - ) - return - - sig = inspect.signature(task_func) - if "bot" in sig.parameters: - job_kwargs["bot"] = bot - - logger.info( - f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] " - f"执行定时任务 (ID: {schedule.id})。" - ) - task = asyncio.create_task(task_func(group_id, **job_kwargs)) - self._running_tasks.add(task) - task.add_done_callback(self._running_tasks.discard) - await task - except Exception as e: - logger.error( - f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, " - f"目标: {group_id or '全局'}) 时发生异常", - e=e, - ) - - def _validate_and_prepare_kwargs( - self, plugin_name: str, job_kwargs: dict | None - ) -> tuple[bool, str | dict]: - """验证并准备任务参数,应用默认值""" - task_meta = self._registered_tasks.get(plugin_name) - if not task_meta: - return False, f"插件 '{plugin_name}' 未注册。" - - params_model = task_meta.get("model") - job_kwargs = job_kwargs if job_kwargs is not None else {} - - if not params_model: - if job_kwargs: - logger.warning( - f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}" - ) - return True, job_kwargs - - if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): - logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类") - return False, f"插件 '{plugin_name}' 的参数模型配置错误" - - try: - model_validate = getattr(params_model, "model_validate", None) - if not model_validate: - return False, f"插件 '{plugin_name}' 的参数模型不支持验证" - - validated_model = model_validate(job_kwargs) - - model_dump = getattr(validated_model, "model_dump", None) - if not model_dump: - return False, f"插件 '{plugin_name}' 的参数模型不支持导出" - - return True, model_dump() - except ValidationError as e: - errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] - error_str = "\n".join(errors) - msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}" - return False, msg - - def _add_aps_job(self, schedule: ScheduleInfo): - """根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。""" - job_id = self._get_job_id(schedule.id) - try: - scheduler.remove_job(job_id) - except Exception: - pass - - if not isinstance(schedule.trigger_config, dict): - logger.error( - f"任务 {schedule.id} 的 trigger_config 不是字典类型: " - f"{type(schedule.trigger_config)}" - ) - return - - scheduler.add_job( - self._execute_job, - trigger=schedule.trigger_type, - id=job_id, - misfire_grace_time=300, - args=[schedule.id], - **schedule.trigger_config, - ) - logger.debug( - f"已在 APScheduler 中添加/更新任务: {job_id} " - f"with trigger: {schedule.trigger_config}" - ) - - def _remove_aps_job(self, schedule_id: int): - """移除一个 APScheduler 任务。""" - job_id = self._get_job_id(schedule_id) - try: - scheduler.remove_job(job_id) - logger.debug(f"已从 APScheduler 中移除任务: {job_id}") - except Exception: - pass - - async def add_schedule( - self, - plugin_name: str, - group_id: str | None, - trigger_type: str, - trigger_config: dict, - job_kwargs: dict | None = None, - bot_id: str | None = None, - ) -> tuple[bool, str]: - """ - 添加或更新一个定时任务。 - """ - if plugin_name not in self._registered_tasks: - return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。" - - is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs) - if not is_valid: - return False, str(result) - - validated_job_kwargs = result - - effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None - - search_kwargs = { - "plugin_name": plugin_name, - "group_id": group_id, - } - if effective_bot_id: - search_kwargs["bot_id"] = effective_bot_id - else: - search_kwargs["bot_id__isnull"] = True - - defaults = { - "trigger_type": trigger_type, - "trigger_config": trigger_config, - "job_kwargs": validated_job_kwargs, - "is_enabled": True, - } - - schedule = await ScheduleInfo.filter(**search_kwargs).first() - created = False - - if schedule: - for key, value in defaults.items(): - setattr(schedule, key, value) - await schedule.save() - else: - creation_kwargs = { - "plugin_name": plugin_name, - "group_id": group_id, - "bot_id": effective_bot_id, - **defaults, - } - schedule = await ScheduleInfo.create(**creation_kwargs) - created = True - self._add_aps_job(schedule) - action = "设置" if created else "更新" - return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。" - - async def add_schedule_for_all( - self, - plugin_name: str, - trigger_type: str, - trigger_config: dict, - job_kwargs: dict | None = None, - ) -> tuple[int, int]: - """为所有机器人所在的群组添加定时任务。""" - if plugin_name not in self._registered_tasks: - raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。") - - groups = set() - for bot in get_bots().values(): - try: - group_list, _ = await PlatformUtils.get_group_list(bot) - groups.update( - g.group_id for g in group_list if g.group_id and not g.channel_id - ) - except Exception as e: - logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e) - - success_count = 0 - fail_count = 0 - for gid in groups: - try: - success, _ = await self.add_schedule( - plugin_name, gid, trigger_type, trigger_config, job_kwargs - ) - if success: - success_count += 1 - else: - fail_count += 1 - except Exception as e: - logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e) - fail_count += 1 - await asyncio.sleep(0.05) - return success_count, fail_count - - async def update_schedule( - self, - schedule_id: int, - trigger_type: str | None = None, - trigger_config: dict | None = None, - job_kwargs: dict | None = None, - ) -> tuple[bool, str]: - """部分更新一个已存在的定时任务。""" - schedule = await self.get_schedule_by_id(schedule_id) - if not schedule: - return False, f"未找到 ID 为 {schedule_id} 的任务。" - - updated_fields = [] - if trigger_config is not None: - schedule.trigger_config = trigger_config - updated_fields.append("trigger_config") - - if trigger_type is not None and schedule.trigger_type != trigger_type: - schedule.trigger_type = trigger_type - updated_fields.append("trigger_type") - - if job_kwargs is not None: - if not isinstance(schedule.job_kwargs, dict): - return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。" - - merged_kwargs = schedule.job_kwargs.copy() - merged_kwargs.update(job_kwargs) - - is_valid, result = self._validate_and_prepare_kwargs( - schedule.plugin_name, merged_kwargs - ) - if not is_valid: - return False, str(result) - - schedule.job_kwargs = result # type: ignore - updated_fields.append("job_kwargs") - - if not updated_fields: - return True, "没有任何需要更新的配置。" - - await schedule.save(update_fields=updated_fields) - self._add_aps_job(schedule) - return True, f"成功更新了任务 ID: {schedule_id} 的配置。" - - async def remove_schedule( - self, plugin_name: str, group_id: str | None, bot_id: str | None = None - ) -> tuple[bool, str]: - """移除指定插件和群组的定时任务。""" - query = {"plugin_name": plugin_name, "group_id": group_id} - if bot_id: - query["bot_id"] = bot_id - - schedules = await ScheduleInfo.filter(**query) - if not schedules: - msg = ( - f"未找到与 Bot {bot_id} 相关的群 {group_id} " - f"的插件 '{plugin_name}' 定时任务。" - ) - return (False, msg) - - for schedule in schedules: - self._remove_aps_job(schedule.id) - await schedule.delete() - - target_desc = f"群 {group_id}" if group_id else "全局" - msg = ( - f"已取消 Bot {bot_id} 在 [{target_desc}] " - f"的插件 '{plugin_name}' 所有定时任务。" - ) - return (True, msg) - - async def remove_schedule_for_all( - self, plugin_name: str, bot_id: str | None = None - ) -> int: - """移除指定插件在所有群组的定时任务。""" - query = {"plugin_name": plugin_name} - if bot_id: - query["bot_id"] = bot_id - - schedules_to_delete = await ScheduleInfo.filter(**query).all() - if not schedules_to_delete: - return 0 - - for schedule in schedules_to_delete: - self._remove_aps_job(schedule.id) - await schedule.delete() - await asyncio.sleep(0.01) - - return len(schedules_to_delete) - - async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]: - """移除指定群组的所有定时任务。""" - schedules = await ScheduleInfo.filter(group_id=group_id) - if not schedules: - return False, f"群 {group_id} 没有任何定时任务。" - - count = 0 - for schedule in schedules: - self._remove_aps_job(schedule.id) - await schedule.delete() - count += 1 - await asyncio.sleep(0.01) - - return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。" - - async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]: - """暂停指定群组的所有定时任务。""" - schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True) - if not schedules: - return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。" - - count = 0 - for schedule in schedules: - success, _ = await self.pause_schedule(schedule.id) - if success: - count += 1 - await asyncio.sleep(0.01) - - return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。" - - async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]: - """恢复指定群组的所有定时任务。""" - schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False) - if not schedules: - return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。" - - count = 0 - for schedule in schedules: - success, _ = await self.resume_schedule(schedule.id) - if success: - count += 1 - await asyncio.sleep(0.01) - - return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。" - - async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]: - """暂停指定插件在所有群组的定时任务。""" - schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True) - if not schedules: - return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。" - - count = 0 - for schedule in schedules: - success, _ = await self.pause_schedule(schedule.id) - if success: - count += 1 - await asyncio.sleep(0.01) - - return ( - count, - f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。", - ) - - async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]: - """恢复指定插件在所有群组的定时任务。""" - schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False) - if not schedules: - return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。" - - count = 0 - for schedule in schedules: - success, _ = await self.resume_schedule(schedule.id) - if success: - count += 1 - await asyncio.sleep(0.01) - - return ( - count, - f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。", - ) - - async def pause_schedule_by_plugin_group( - self, plugin_name: str, group_id: str | None, bot_id: str | None = None - ) -> tuple[bool, str]: - """暂停指定插件在指定群组的定时任务。""" - query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True} - if bot_id: - query["bot_id"] = bot_id - - schedules = await ScheduleInfo.filter(**query) - if not schedules: - return ( - False, - f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。", - ) - - count = 0 - for schedule in schedules: - success, _ = await self.pause_schedule(schedule.id) - if success: - count += 1 - - return ( - True, - f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。", - ) - - async def resume_schedule_by_plugin_group( - self, plugin_name: str, group_id: str | None, bot_id: str | None = None - ) -> tuple[bool, str]: - """恢复指定插件在指定群组的定时任务。""" - query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False} - if bot_id: - query["bot_id"] = bot_id - - schedules = await ScheduleInfo.filter(**query) - if not schedules: - return ( - False, - f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。", - ) - - count = 0 - for schedule in schedules: - success, _ = await self.resume_schedule(schedule.id) - if success: - count += 1 - - return ( - True, - f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。", - ) - - async def remove_all_schedules(self) -> tuple[int, str]: - """移除所有群组的所有定时任务。""" - schedules = await ScheduleInfo.all() - if not schedules: - return 0, "当前没有任何定时任务。" - - count = 0 - for schedule in schedules: - self._remove_aps_job(schedule.id) - await schedule.delete() - count += 1 - await asyncio.sleep(0.01) - - return count, f"已成功移除所有群组的 {count} 个定时任务。" - - async def pause_all_schedules(self) -> tuple[int, str]: - """暂停所有群组的所有定时任务。""" - schedules = await ScheduleInfo.filter(is_enabled=True) - if not schedules: - return 0, "当前没有正在运行的定时任务可暂停。" - - count = 0 - for schedule in schedules: - success, _ = await self.pause_schedule(schedule.id) - if success: - count += 1 - await asyncio.sleep(0.01) - - return count, f"已成功暂停所有群组的 {count} 个定时任务。" - - async def resume_all_schedules(self) -> tuple[int, str]: - """恢复所有群组的所有定时任务。""" - schedules = await ScheduleInfo.filter(is_enabled=False) - if not schedules: - return 0, "当前没有已暂停的定时任务可恢复。" - - count = 0 - for schedule in schedules: - success, _ = await self.resume_schedule(schedule.id) - if success: - count += 1 - await asyncio.sleep(0.01) - - return count, f"已成功恢复所有群组的 {count} 个定时任务。" - - async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]: - """通过ID移除指定的定时任务。""" - schedule = await self.get_schedule_by_id(schedule_id) - if not schedule: - return False, f"未找到 ID 为 {schedule_id} 的定时任务。" - - self._remove_aps_job(schedule.id) - await schedule.delete() - - return ( - True, - f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " - f"的定时任务 (ID: {schedule.id})。", - ) - - async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None: - """通过ID获取定时任务信息。""" - return await ScheduleInfo.get_or_none(id=schedule_id) - - async def get_schedules( - self, plugin_name: str, group_id: str | None - ) -> list[ScheduleInfo]: - """获取特定群组特定插件的所有定时任务。""" - return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id) - - async def get_schedule( - self, plugin_name: str, group_id: str | None - ) -> ScheduleInfo | None: - """获取特定群组的定时任务信息。""" - return await ScheduleInfo.get_or_none( - plugin_name=plugin_name, group_id=group_id - ) - - async def get_all_schedules( - self, plugin_name: str | None = None - ) -> list[ScheduleInfo]: - """获取所有定时任务信息,可按插件名过滤。""" - if plugin_name: - return await ScheduleInfo.filter(plugin_name=plugin_name).all() - return await ScheduleInfo.all() - - async def get_schedule_status(self, schedule_id: int) -> dict | None: - """获取任务的详细状态。""" - schedule = await self.get_schedule_by_id(schedule_id) - if not schedule: - return None - - job_id = self._get_job_id(schedule.id) - job = scheduler.get_job(job_id) - - status = { - "id": schedule.id, - "bot_id": schedule.bot_id, - "plugin_name": schedule.plugin_name, - "group_id": schedule.group_id, - "is_enabled": schedule.is_enabled, - "trigger_type": schedule.trigger_type, - "trigger_config": schedule.trigger_config, - "job_kwargs": schedule.job_kwargs, - "next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S") - if job and job.next_run_time - else "N/A", - "is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A", - } - return status - - async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]: - """暂停一个定时任务。""" - schedule = await self.get_schedule_by_id(schedule_id) - if not schedule or not schedule.is_enabled: - return False, "任务不存在或已暂停。" - - schedule.is_enabled = False - await schedule.save(update_fields=["is_enabled"]) - - job_id = self._get_job_id(schedule.id) - try: - scheduler.pause_job(job_id) - except Exception: - pass - - return ( - True, - f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " - f"的定时任务 (ID: {schedule.id})。", - ) - - async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]: - """恢复一个定时任务。""" - schedule = await self.get_schedule_by_id(schedule_id) - if not schedule or schedule.is_enabled: - return False, "任务不存在或已启用。" - - schedule.is_enabled = True - await schedule.save(update_fields=["is_enabled"]) - - job_id = self._get_job_id(schedule.id) - try: - scheduler.resume_job(job_id) - except Exception: - self._add_aps_job(schedule) - - return ( - True, - f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " - f"的定时任务 (ID: {schedule.id})。", - ) - - async def trigger_now(self, schedule_id: int) -> tuple[bool, str]: - """手动触发一个定时任务。""" - schedule = await self.get_schedule_by_id(schedule_id) - if not schedule: - return False, f"未找到 ID 为 {schedule_id} 的定时任务。" - - if schedule.plugin_name not in self._registered_tasks: - return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。" - - try: - await self._execute_job(schedule.id) - return ( - True, - f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " - f"的定时任务 (ID: {schedule.id})。", - ) - except Exception as e: - logger.error(f"手动触发任务失败: {e}") - return False, f"手动触发任务失败: {e}" - - -scheduler_manager = SchedulerManager() - - -@PriorityLifecycle.on_startup(priority=90) -async def _load_schedules_from_db(): - """在服务启动时从数据库加载并调度所有任务。""" - Config.add_plugin_config( - "SchedulerManager", - SCHEDULE_CONCURRENCY_KEY, - 5, - help="“所有群组”类型定时任务的并发执行数量限制", - type=int, - ) - - logger.info("正在从数据库加载并调度所有定时任务...") - schedules = await ScheduleInfo.filter(is_enabled=True).all() - count = 0 - for schedule in schedules: - if schedule.plugin_name in scheduler_manager._registered_tasks: - scheduler_manager._add_aps_job(schedule) - count += 1 - else: - logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。") - logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。") diff --git a/zhenxun/utils/platform.py b/zhenxun/utils/platform.py index 634c8226..f01aec3f 100644 --- a/zhenxun/utils/platform.py +++ b/zhenxun/utils/platform.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import Awaitable, Callable import random -from typing import Literal +from typing import cast import httpx import nonebot @@ -227,7 +227,7 @@ class PlatformUtils: url = None if platform == "qq": if user_id.isdigit(): - url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640" + url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160" else: url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640" return await AsyncHttpx.get_content(url) if url else None @@ -486,15 +486,134 @@ class PlatformUtils: return target +class BroadcastEngine: + def __init__( + self, + message: str | UniMessage, + bot: Bot | list[Bot] | None = None, + bot_id: str | set[str] | None = None, + ignore_group: list[str] | None = None, + check_func: Callable[[Bot, str], Awaitable] | None = None, + log_cmd: str | None = None, + platform: str | None = None, + ): + """广播引擎 + + 参数: + message: 广播消息内容 + bot: 指定bot对象. + bot_id: 指定bot id. + ignore_group: 忽略群聊列表. + check_func: 发送前对群聊检测方法,判断是否发送. + log_cmd: 日志标记. + platform: 指定平台. + + 异常: + ValueError: 没有可用的Bot对象 + """ + if ignore_group is None: + ignore_group = [] + self.message = MessageUtils.build_message(message) + self.ignore_group = ignore_group + self.check_func = check_func + self.log_cmd = log_cmd + self.platform = platform + self.bot_list = [] + self.count = 0 + if bot: + self.bot_list = [bot] if isinstance(bot, Bot) else bot + if isinstance(bot_id, str): + bot_id = set(bot_id) + if bot_id: + for i in bot_id: + try: + self.bot_list.append(nonebot.get_bot(i)) + except KeyError: + logger.warning(f"Bot:{i} 对象未连接或不存在") + if not self.bot_list: + raise ValueError("当前没有可用的Bot对象...", log_cmd) + + async def call_check(self, bot: Bot, group_id: str) -> bool: + """运行发送检测函数 + + 参数: + bot: Bot + group_id: 群组id + + 返回: + bool: 是否发送 + """ + if not self.check_func: + return True + if is_coroutine_callable(self.check_func): + is_run = await self.check_func(bot, group_id) + else: + is_run = self.check_func(bot, group_id) + return cast(bool, is_run) + + async def __send_message(self, bot: Bot, group: GroupConsole): + """群组发送消息 + + 参数: + bot: Bot + group: GroupConsole + """ + key = f"{group.group_id}:{group.channel_id}" + if not await self.call_check(bot, group.group_id): + logger.debug( + "广播方法检测运行方法为 False, 已跳过该群组...", + self.log_cmd, + group_id=group.group_id, + ) + return + if target := PlatformUtils.get_target( + group_id=group.group_id, + channel_id=group.channel_id, + ): + self.ignore_group.append(key) + await MessageUtils.build_message(self.message).send(target, bot) + logger.debug("广播消息发送成功...", self.log_cmd, target=key) + else: + logger.warning("广播消息获取Target失败...", self.log_cmd, target=key) + + async def broadcast(self) -> int: + """广播消息 + + 返回: + int: 成功发送次数 + """ + for bot in self.bot_list: + if self.platform and self.platform != PlatformUtils.get_platform(bot): + continue + group_list, _ = await PlatformUtils.get_group_list(bot) + if not group_list: + continue + for group in group_list: + if ( + group.group_id in self.ignore_group + or group.channel_id in self.ignore_group + ): + continue + try: + await self.__send_message(bot, group) + await asyncio.sleep(random.randint(1, 3)) + self.count += 1 + except Exception as e: + logger.warning( + "广播消息发送失败", self.log_cmd, target=group.group_id, e=e + ) + return self.count + + async def broadcast_group( message: str | UniMessage, bot: Bot | list[Bot] | None = None, bot_id: str | set[str] | None = None, - ignore_group: set[int] | None = None, + ignore_group: list[str] = [], check_func: Callable[[Bot, str], Awaitable] | None = None, log_cmd: str | None = None, - platform: Literal["qq", "dodo", "kaiheila"] | None = None, -): + platform: str | None = None, +) -> int: """获取所有Bot或指定Bot对象广播群聊 参数: @@ -505,81 +624,18 @@ async def broadcast_group( check_func: 发送前对群聊检测方法,判断是否发送. log_cmd: 日志标记. platform: 指定平台 + + 返回: + int: 成功发送次数 """ - if platform and platform not in ["qq", "dodo", "kaiheila"]: - raise ValueError("指定平台不支持") - if not message: - raise ValueError("群聊广播消息不能为空") - bot_dict = nonebot.get_bots() - bot_list: list[Bot] = [] - if bot: - if isinstance(bot, list): - bot_list = bot - else: - bot_list.append(bot) - elif bot_id: - _bot_id_list = bot_id - if isinstance(bot_id, str): - _bot_id_list = [bot_id] - for id_ in _bot_id_list: - if bot_id in bot_dict: - bot_list.append(bot_dict[bot_id]) - else: - logger.warning(f"Bot:{id_} 对象未连接或不存在") - else: - bot_list = list(bot_dict.values()) - _used_group = [] - for _bot in bot_list: - try: - if platform and platform != PlatformUtils.get_platform(_bot): - continue - group_list, _ = await PlatformUtils.get_group_list(_bot) - if group_list: - for group in group_list: - key = f"{group.group_id}:{group.channel_id}" - try: - if ( - ignore_group - and ( - group.group_id in ignore_group - or group.channel_id in ignore_group - ) - ) or key in _used_group: - logger.debug( - "广播方法群组重复, 已跳过...", - log_cmd, - group_id=group.group_id, - ) - continue - is_run = False - if check_func: - if is_coroutine_callable(check_func): - is_run = await check_func(_bot, group.group_id) - else: - is_run = check_func(_bot, group.group_id) - if not is_run: - logger.debug( - "广播方法检测运行方法为 False, 已跳过...", - log_cmd, - group_id=group.group_id, - ) - continue - target = PlatformUtils.get_target( - user_id=None, - group_id=group.group_id, - channel_id=group.channel_id, - ) - if target: - _used_group.append(key) - message_list = message - await MessageUtils.build_message(message_list).send( - target, _bot - ) - logger.debug("发送成功", log_cmd, target=key) - await asyncio.sleep(random.randint(1, 3)) - else: - logger.warning("target为空", log_cmd, target=key) - except Exception as e: - logger.error("发送失败", log_cmd, target=key, e=e) - except Exception as e: - logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e) + if not message.strip(): + raise ValueError("群聊广播消息不能为空...") + return await BroadcastEngine( + message=message, + bot=bot, + bot_id=bot_id, + ignore_group=ignore_group, + check_func=check_func, + log_cmd=log_cmd, + platform=platform, + ).broadcast() diff --git a/zhenxun/utils/utils.py b/zhenxun/utils/utils.py index 1e3d26e3..5b3ec6a9 100644 --- a/zhenxun/utils/utils.py +++ b/zhenxun/utils/utils.py @@ -244,10 +244,8 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool: def get_entity_ids(session: Uninfo) -> EntityIDs: """获取用户id,群组id,频道id - 参数: session: Uninfo - 返回: EntityIDs: 用户id,群组id,频道id """