diff --git a/zhenxun/builtin_plugins/nickname.py b/zhenxun/builtin_plugins/nickname.py index 7dd9a697..4d69abcf 100644 --- a/zhenxun/builtin_plugins/nickname.py +++ b/zhenxun/builtin_plugins/nickname.py @@ -1,12 +1,17 @@ import random -from typing import Any -from nonebot import on_regex from nonebot.adapters import Bot -from nonebot.params import Depends, RegexGroup from nonebot.plugin import PluginMetadata from nonebot.rule import to_me -from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true +from nonebot_plugin_alconna import ( + Alconna, + Args, + Arparma, + CommandMeta, + Option, + on_alconna, + store_true, +) from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.config import BotConfig, Config @@ -54,15 +59,22 @@ __plugin_meta__ = PluginMetadata( ).to_dict(), ) -_nickname_matcher = on_regex( - "(?:以后)?(?:叫我|请叫我|称呼我)(.*)", +_nickname_matcher = on_alconna( + Alconna( + "re:(?:以后)?(?:叫我|请叫我|称呼我)", + Args["name?", str], + meta=CommandMeta(compact=True), + ), rule=to_me(), priority=5, block=True, ) -_global_nickname_matcher = on_regex( - "设置全局昵称(.*)", rule=to_me(), priority=5, block=True +_global_nickname_matcher = on_alconna( + Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)), + rule=to_me(), + priority=5, + block=True, ) _matcher = on_alconna( @@ -117,34 +129,32 @@ CANCEL = [ ] -def CheckNickname(): +async def CheckNickname( + bot: Bot, + session: Uninfo, + params: Arparma, +): """ 检查名称是否合法 """ - - async def dependency( - bot: Bot, - session: Uninfo, - reg_group: tuple[Any, ...] = RegexGroup(), - ): - black_word = Config.get_config("nickname", "BLACK_WORD") - (name,) = reg_group - logger.debug(f"昵称检查: {name}", "昵称设置", session=session) - if not name: - await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish( - at_sender=True - ) - if session.user.id in bot.config.superusers: - logger.debug( - f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session - ) - return + black_word = Config.get_config("nickname", "BLACK_WORD") + name = params.query("name") + logger.debug(f"昵称检查: {name}", "昵称设置", session=session) + if not name: + await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish( + at_sender=True + ) + if session.user.id in bot.config.superusers: + logger.debug( + f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session + ) + else: if len(name) > 20: await MessageUtils.build_message("昵称可不能超过20个字!").finish( at_sender=True ) if name in bot.config.nickname: - await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish( + await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish( at_sender=True ) if black_word: @@ -162,17 +172,17 @@ def CheckNickname(): await MessageUtils.build_message( f"字符 [{word}] 为禁止字符!" ).finish(at_sender=True) - - return Depends(dependency) + return name -@_nickname_matcher.handle(parameterless=[CheckNickname()]) +@_nickname_matcher.handle() async def _( + bot: Bot, session: Uninfo, + name_: Arparma, uname: str = UserName(), - reg_group: tuple[Any, ...] = RegexGroup(), ): - (name,) = reg_group + name = await CheckNickname(bot, session, name_) if len(name) < 5 and random.random() < 0.3: name = "~".join(name) group_id = None @@ -200,13 +210,14 @@ async def _( ) -@_global_nickname_matcher.handle(parameterless=[CheckNickname()]) +@_global_nickname_matcher.handle() async def _( + bot: Bot, session: Uninfo, + name_: Arparma, nickname: str = UserName(), - reg_group: tuple[Any, ...] = RegexGroup(), ): - (name,) = reg_group + name = await CheckNickname(bot, session, name_) await FriendUser.set_user_nickname( session.user.id, name, @@ -227,15 +238,14 @@ async def _(session: Uninfo, uname: str = UserName()): group_id = session.group.parent.id if session.group.parent else session.group.id if group_id: nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id) - card = uname else: nickname = await FriendUser.get_user_nickname(session.user.id) - card = uname if nickname: await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish( reply_to=True ) else: + card = uname await MessageUtils.build_message( random.choice( [ @@ -270,4 +280,4 @@ async def _(bot: Bot, session: Uninfo): else: await MessageUtils.build_message("你在做梦吗?你没有昵称啊").finish( reply_to=True - ) + ) \ No newline at end of file diff --git a/zhenxun/builtin_plugins/scheduler_admin/__init__.py b/zhenxun/builtin_plugins/scheduler_admin/__init__.py new file mode 100644 index 00000000..a2865979 --- /dev/null +++ b/zhenxun/builtin_plugins/scheduler_admin/__init__.py @@ -0,0 +1,51 @@ +from nonebot.plugin import PluginMetadata + +from zhenxun.configs.utils import PluginExtraData +from zhenxun.utils.enum import PluginType + +from . import command # noqa: F401 + +__plugin_meta__ = PluginMetadata( + name="定时任务管理", + description="查看和管理由 SchedulerManager 控制的定时任务。", + usage=""" +📋 定时任务管理 - 支持群聊和私聊操作 + +🔍 查看任务: + 定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>] + • 群聊中: 查看本群任务 + • 私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER) + +📊 任务状态: + 定时任务 状态 <任务ID> 或 任务状态 <任务ID> + • 查看单个任务的详细信息和状态 + +⚙️ 任务管理 (SUPERUSER): + 定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>] + 定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all + 定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all + 定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all + 定时任务 执行 <任务ID> + 定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>] + +📝 时间选项 (三选一): + --cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *" + --interval <时间间隔> # 例: --interval 30m, 2h, 10s + --date "" # 例: --date "2024-01-01 08:00:00" + --daily "" # 例: --daily "08:30" + +📚 其他功能: + 定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER) + +🏷️ 别名支持: + 查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消 + 暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info + 更新: update, modify, 修改 | 插件列表: plugins + """.strip(), + extra=PluginExtraData( + author="HibiKier", + version="0.1.2", + plugin_type=PluginType.SUPERUSER, + is_show=False, + ).to_dict(), +) \ No newline at end of file diff --git a/zhenxun/builtin_plugins/scheduler_admin/command.py b/zhenxun/builtin_plugins/scheduler_admin/command.py new file mode 100644 index 00000000..0238d9c8 --- /dev/null +++ b/zhenxun/builtin_plugins/scheduler_admin/command.py @@ -0,0 +1,836 @@ +import asyncio +from datetime import datetime +import re + +from nonebot.adapters import Event +from nonebot.adapters.onebot.v11 import Bot +from nonebot.params import Depends +from nonebot.permission import SUPERUSER +from nonebot_plugin_alconna import ( + Alconna, + AlconnaMatch, + Args, + Arparma, + Match, + Option, + Query, + Subcommand, + on_alconna, +) +from pydantic import BaseModel, ValidationError + +from zhenxun.utils._image_template import ImageTemplate +from zhenxun.utils.manager.schedule_manager import scheduler_manager + + +def _get_type_name(annotation) -> str: + """获取类型注解的名称""" + if hasattr(annotation, "__name__"): + return annotation.__name__ + elif hasattr(annotation, "_name"): + return annotation._name + else: + return str(annotation) + + +from zhenxun.utils.message import MessageUtils +from zhenxun.utils.rules import admin_check + + +def _format_trigger(schedule_status: dict) -> str: + """将触发器配置格式化为人类可读的字符串""" + trigger_type = schedule_status["trigger_type"] + config = schedule_status["trigger_config"] + + if trigger_type == "cron": + minute = config.get("minute", "*") + hour = config.get("hour", "*") + day = config.get("day", "*") + month = config.get("month", "*") + day_of_week = config.get("day_of_week", "*") + + if day == "*" and month == "*" and day_of_week == "*": + formatted_hour = hour if hour == "*" else f"{int(hour):02d}" + formatted_minute = minute if minute == "*" else f"{int(minute):02d}" + return f"每天 {formatted_hour}:{formatted_minute}" + else: + return f"Cron: {minute} {hour} {day} {month} {day_of_week}" + elif trigger_type == "interval": + seconds = config.get("seconds", 0) + minutes = config.get("minutes", 0) + hours = config.get("hours", 0) + days = config.get("days", 0) + if days: + trigger_str = f"每 {days} 天" + elif hours: + trigger_str = f"每 {hours} 小时" + elif minutes: + trigger_str = f"每 {minutes} 分钟" + else: + trigger_str = f"每 {seconds} 秒" + elif trigger_type == "date": + run_date = config.get("run_date", "未知时间") + trigger_str = f"在 {run_date}" + else: + trigger_str = f"{trigger_type}: {config}" + + return trigger_str + + +def _format_params(schedule_status: dict) -> str: + """将任务参数格式化为人类可读的字符串""" + if kwargs := schedule_status.get("job_kwargs"): + kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items()) + return kwargs_str + return "-" + + +def _parse_interval(interval_str: str) -> dict: + """增强版解析器,支持 d(天)""" + match = re.match(r"(\d+)([smhd])", interval_str.lower()) + if not match: + raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。") + + value, unit = int(match.group(1)), match.group(2) + if unit == "s": + return {"seconds": value} + if unit == "m": + return {"minutes": value} + if unit == "h": + return {"hours": value} + if unit == "d": + return {"days": value} + return {} + + +def _parse_daily_time(time_str: str) -> dict: + """解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置""" + if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str): + hour, minute, second = match.groups() + hour, minute = int(hour), int(minute) + + if not (0 <= hour <= 23 and 0 <= minute <= 59): + raise ValueError("小时或分钟数值超出范围。") + + cron_config = { + "minute": str(minute), + "hour": str(hour), + "day": "*", + "month": "*", + "day_of_week": "*", + } + if second is not None: + if not (0 <= int(second) <= 59): + raise ValueError("秒数值超出范围。") + cron_config["second"] = str(second) + + return cron_config + else: + raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。") + + +async def GetBotId( + bot: Bot, + bot_id_match: Match[str] = AlconnaMatch("bot_id"), +) -> str: + """获取要操作的Bot ID""" + if bot_id_match.available: + return bot_id_match.result + return bot.self_id + + +class ScheduleTarget: + """定时任务操作目标的基类""" + + pass + + +class TargetByID(ScheduleTarget): + """按任务ID操作""" + + def __init__(self, id: int): + self.id = id + + +class TargetByPlugin(ScheduleTarget): + """按插件名操作""" + + def __init__( + self, plugin: str, group_id: str | None = None, all_groups: bool = False + ): + self.plugin = plugin + self.group_id = group_id + self.all_groups = all_groups + + +class TargetAll(ScheduleTarget): + """操作所有任务""" + + def __init__(self, for_group: str | None = None): + self.for_group = for_group + + +TargetScope = TargetByID | TargetByPlugin | TargetAll | None + + +def create_target_parser(subcommand_name: str): + """ + 创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。 + """ + + async def dependency( + event: Event, + schedule_id: Match[int] = AlconnaMatch("schedule_id"), + plugin_name: Match[str] = AlconnaMatch("plugin_name"), + group_id: Match[str] = AlconnaMatch("group_id"), + all_enabled: Query[bool] = Query(f"{subcommand_name}.all"), + ) -> TargetScope: + if schedule_id.available: + return TargetByID(schedule_id.result) + + if plugin_name.available: + p_name = plugin_name.result + if all_enabled.available: + return TargetByPlugin(plugin=p_name, all_groups=True) + elif group_id.available: + gid = group_id.result + if gid.lower() == "all": + return TargetByPlugin(plugin=p_name, all_groups=True) + return TargetByPlugin(plugin=p_name, group_id=gid) + else: + current_group_id = getattr(event, "group_id", None) + if current_group_id: + return TargetByPlugin(plugin=p_name, group_id=str(current_group_id)) + else: + await schedule_cmd.finish( + "私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。" + ) + + if all_enabled.available: + return TargetAll(for_group=group_id.result if group_id.available else None) + + return None + + return dependency + + +schedule_cmd = on_alconna( + Alconna( + "定时任务", + Subcommand( + "查看", + Option("-g", Args["target_group_id", str]), + Option("-all", help_text="查看所有群聊 (SUPERUSER)"), + Option("-p", Args["plugin_name", str], help_text="按插件名筛选"), + Option("--page", Args["page", int, 1], help_text="指定页码"), + alias=["ls", "list"], + help_text="查看定时任务", + ), + Subcommand( + "设置", + Args["plugin_name", str], + Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"), + Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"), + Option("--date", Args["date_expr", str], help_text="设置特定执行日期"), + Option( + "--daily", + Args["daily_expr", str], + help_text="设置每天执行的时间 (如 08:20)", + ), + Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"), + Option("-all", help_text="对所有群生效 (等同于 -g all)"), + Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["add", "开启"], + help_text="设置/开启一个定时任务", + ), + Subcommand( + "删除", + Args["schedule_id?", int], + Option("-p", Args["plugin_name", str], help_text="指定插件名"), + Option("-g", Args["group_id", str], help_text="指定群组ID"), + Option("-all", help_text="对所有群生效"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["del", "rm", "remove", "关闭", "取消"], + help_text="删除一个或多个定时任务", + ), + Subcommand( + "暂停", + Args["schedule_id?", int], + Option("-all", help_text="对当前群所有任务生效"), + Option("-p", Args["plugin_name", str], help_text="指定插件名"), + Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["pause"], + help_text="暂停一个或多个定时任务", + ), + Subcommand( + "恢复", + Args["schedule_id?", int], + Option("-all", help_text="对当前群所有任务生效"), + Option("-p", Args["plugin_name", str], help_text="指定插件名"), + Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"), + Option( + "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)" + ), + alias=["resume"], + help_text="恢复一个或多个定时任务", + ), + Subcommand( + "执行", + Args["schedule_id", int], + alias=["trigger", "run"], + help_text="立即执行一次任务", + ), + Subcommand( + "更新", + Args["schedule_id", int], + Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"), + Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"), + Option("--date", Args["date_expr", str], help_text="设置特定执行日期"), + Option( + "--daily", + Args["daily_expr", str], + help_text="更新每天执行的时间 (如 08:20)", + ), + Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"), + alias=["update", "modify", "修改"], + help_text="更新任务配置", + ), + Subcommand( + "状态", + Args["schedule_id", int], + alias=["status", "info"], + help_text="查看单个任务的详细状态", + ), + Subcommand( + "插件列表", + alias=["plugins"], + help_text="列出所有可用的插件", + ), + ), + priority=5, + block=True, + rule=admin_check(1), +) + +schedule_cmd.shortcut( + "任务状态", + command="定时任务", + arguments=["状态", "{%0}"], + prefix=True, +) + + +@schedule_cmd.handle() +async def _handle_time_options_mutex(arp: Arparma): + time_options = ["cron", "interval", "date", "daily"] + provided_options = [opt for opt in time_options if arp.query(opt) is not None] + if len(provided_options) > 1: + await schedule_cmd.finish( + f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。" + ) + + +@schedule_cmd.assign("查看") +async def _( + bot: Bot, + event: Event, + target_group_id: Match[str] = AlconnaMatch("target_group_id"), + all_groups: Query[bool] = Query("查看.all"), + plugin_name: Match[str] = AlconnaMatch("plugin_name"), + page: Match[int] = AlconnaMatch("page"), +): + is_superuser = await SUPERUSER(bot, event) + schedules = [] + title = "" + + current_group_id = getattr(event, "group_id", None) + if not (all_groups.available or target_group_id.available) and not current_group_id: + await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。") + + if all_groups.available: + if not is_superuser: + await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。") + schedules = await scheduler_manager.get_all_schedules() + title = "所有群组的定时任务" + elif target_group_id.available: + if not is_superuser: + await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。") + gid = target_group_id.result + schedules = [ + s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid + ] + title = f"群 {gid} 的定时任务" + else: + gid = str(current_group_id) + schedules = [ + s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid + ] + title = "本群的定时任务" + + if plugin_name.available: + schedules = [s for s in schedules if s.plugin_name == plugin_name.result] + title += f" [插件: {plugin_name.result}]" + + if not schedules: + await schedule_cmd.finish("没有找到任何相关的定时任务。") + + page_size = 15 + current_page = page.result + total_items = len(schedules) + total_pages = (total_items + page_size - 1) // page_size + start_index = (current_page - 1) * page_size + end_index = start_index + page_size + paginated_schedules = schedules[start_index:end_index] + + if not paginated_schedules: + await schedule_cmd.finish("这一页没有内容了哦~") + + status_tasks = [ + scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules + ] + all_statuses = await asyncio.gather(*status_tasks) + data_list = [ + [ + s["id"], + s["plugin_name"], + s.get("bot_id") or "N/A", + s["group_id"] or "全局", + s["next_run_time"], + _format_trigger(s), + _format_params(s), + "✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停", + ] + for s in all_statuses + if s + ] + + if not data_list: + await schedule_cmd.finish("没有找到任何相关的定时任务。") + + img = await ImageTemplate.table_page( + head_text=title, + tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务", + column_name=[ + "ID", + "插件", + "Bot ID", + "群组/目标", + "下次运行", + "触发规则", + "参数", + "状态", + ], + data_list=data_list, + column_space=20, + ) + await MessageUtils.build_message(img).send(reply_to=True) + + +@schedule_cmd.assign("设置") +async def _( + event: Event, + plugin_name: str, + cron_expr: str | None = None, + interval_expr: str | None = None, + date_expr: str | None = None, + daily_expr: str | None = None, + group_id: str | None = None, + kwargs_str: str | None = None, + all_enabled: Query[bool] = Query("设置.all"), + bot_id_to_operate: str = Depends(GetBotId), +): + if plugin_name not in scheduler_manager._registered_tasks: + await schedule_cmd.finish( + f"插件 '{plugin_name}' 没有注册可用的定时任务。\n" + f"可用插件: {list(scheduler_manager._registered_tasks.keys())}" + ) + + trigger_type = "" + trigger_config = {} + + try: + if cron_expr: + trigger_type = "cron" + parts = cron_expr.split() + if len(parts) != 5: + raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)") + cron_keys = ["minute", "hour", "day", "month", "day_of_week"] + trigger_config = dict(zip(cron_keys, parts)) + elif interval_expr: + trigger_type = "interval" + trigger_config = _parse_interval(interval_expr) + elif date_expr: + trigger_type = "date" + trigger_config = {"run_date": datetime.fromisoformat(date_expr)} + elif daily_expr: + trigger_type = "cron" + trigger_config = _parse_daily_time(daily_expr) + else: + await schedule_cmd.finish( + "必须提供一种时间选项: --cron, --interval, --date, 或 --daily。" + ) + except ValueError as e: + await schedule_cmd.finish(f"时间参数解析错误: {e}") + + job_kwargs = {} + if kwargs_str: + task_meta = scheduler_manager._registered_tasks[plugin_name] + params_model = task_meta.get("model") + if not params_model: + await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。") + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。") + + raw_kwargs = {} + try: + for item in kwargs_str.split(","): + key, value = item.strip().split("=", 1) + raw_kwargs[key.strip()] = value + except Exception as e: + await schedule_cmd.finish( + f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}" + ) + + try: + model_validate = getattr(params_model, "model_validate", None) + if not model_validate: + await schedule_cmd.finish( + f"插件 '{plugin_name}' 的参数模型不支持验证。" + ) + return + + validated_model = model_validate(raw_kwargs) + + model_dump = getattr(validated_model, "model_dump", None) + if not model_dump: + await schedule_cmd.finish( + f"插件 '{plugin_name}' 的参数模型不支持导出。" + ) + return + + job_kwargs = model_dump() + except ValidationError as e: + errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] + error_str = "\n".join(errors) + await schedule_cmd.finish( + f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}" + ) + return + + target_group_id: str | None + current_group_id = getattr(event, "group_id", None) + + if group_id and group_id.lower() == "all": + target_group_id = "__ALL_GROUPS__" + elif all_enabled.available: + target_group_id = "__ALL_GROUPS__" + elif group_id: + target_group_id = group_id + elif current_group_id: + target_group_id = str(current_group_id) + else: + await schedule_cmd.finish( + "私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。" + ) + return + + success, msg = await scheduler_manager.add_schedule( + plugin_name, + target_group_id, + trigger_type, + trigger_config, + job_kwargs, + bot_id=bot_id_to_operate, + ) + + if target_group_id == "__ALL_GROUPS__": + target_desc = f"所有群组 (Bot: {bot_id_to_operate})" + elif target_group_id is None: + target_desc = "全局" + else: + target_desc = f"群组 {target_group_id}" + + if success: + await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}") + else: + await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}") + + +@schedule_cmd.assign("删除") +async def _( + target: TargetScope = Depends(create_target_parser("删除")), + bot_id_to_operate: str = Depends(GetBotId), +): + if isinstance(target, TargetByID): + _, message = await scheduler_manager.remove_schedule_by_id(target.id) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetByPlugin): + p_name = target.plugin + if p_name not in scheduler_manager.get_registered_plugins(): + await schedule_cmd.finish(f"未找到插件 '{p_name}'。") + + if target.all_groups: + removed_count = await scheduler_manager.remove_schedule_for_all( + p_name, bot_id=bot_id_to_operate + ) + message = ( + f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。" + if removed_count > 0 + else f"没有找到插件 '{p_name}' 的定时任务。" + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.remove_schedule( + p_name, target.group_id, bot_id=bot_id_to_operate + ) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetAll): + if target.for_group: + _, message = await scheduler_manager.remove_schedules_by_group( + target.for_group + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.remove_all_schedules() + await schedule_cmd.finish(message) + + else: + await schedule_cmd.finish( + "删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。" + ) + + +@schedule_cmd.assign("暂停") +async def _( + target: TargetScope = Depends(create_target_parser("暂停")), + bot_id_to_operate: str = Depends(GetBotId), +): + if isinstance(target, TargetByID): + _, message = await scheduler_manager.pause_schedule(target.id) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetByPlugin): + p_name = target.plugin + if p_name not in scheduler_manager.get_registered_plugins(): + await schedule_cmd.finish(f"未找到插件 '{p_name}'。") + + if target.all_groups: + _, message = await scheduler_manager.pause_schedules_by_plugin(p_name) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.pause_schedule_by_plugin_group( + p_name, target.group_id, bot_id=bot_id_to_operate + ) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetAll): + if target.for_group: + _, message = await scheduler_manager.pause_schedules_by_group( + target.for_group + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.pause_all_schedules() + await schedule_cmd.finish(message) + + else: + await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。") + + +@schedule_cmd.assign("恢复") +async def _( + target: TargetScope = Depends(create_target_parser("恢复")), + bot_id_to_operate: str = Depends(GetBotId), +): + if isinstance(target, TargetByID): + _, message = await scheduler_manager.resume_schedule(target.id) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetByPlugin): + p_name = target.plugin + if p_name not in scheduler_manager.get_registered_plugins(): + await schedule_cmd.finish(f"未找到插件 '{p_name}'。") + + if target.all_groups: + _, message = await scheduler_manager.resume_schedules_by_plugin(p_name) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.resume_schedule_by_plugin_group( + p_name, target.group_id, bot_id=bot_id_to_operate + ) + await schedule_cmd.finish(message) + + elif isinstance(target, TargetAll): + if target.for_group: + _, message = await scheduler_manager.resume_schedules_by_group( + target.for_group + ) + await schedule_cmd.finish(message) + else: + _, message = await scheduler_manager.resume_all_schedules() + await schedule_cmd.finish(message) + + else: + await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。") + + +@schedule_cmd.assign("执行") +async def _(schedule_id: int): + _, message = await scheduler_manager.trigger_now(schedule_id) + await schedule_cmd.finish(message) + + +@schedule_cmd.assign("更新") +async def _( + schedule_id: int, + cron_expr: str | None = None, + interval_expr: str | None = None, + date_expr: str | None = None, + daily_expr: str | None = None, + kwargs_str: str | None = None, +): + if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]): + await schedule_cmd.finish( + "请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)" + ) + + trigger_config = None + trigger_type = None + try: + if cron_expr: + trigger_type = "cron" + parts = cron_expr.split() + if len(parts) != 5: + raise ValueError("Cron 表达式必须有5个部分") + cron_keys = ["minute", "hour", "day", "month", "day_of_week"] + trigger_config = dict(zip(cron_keys, parts)) + elif interval_expr: + trigger_type = "interval" + trigger_config = _parse_interval(interval_expr) + elif date_expr: + trigger_type = "date" + trigger_config = {"run_date": datetime.fromisoformat(date_expr)} + elif daily_expr: + trigger_type = "cron" + trigger_config = _parse_daily_time(daily_expr) + except ValueError as e: + await schedule_cmd.finish(f"时间参数解析错误: {e}") + + job_kwargs = None + if kwargs_str: + schedule = await scheduler_manager.get_schedule_by_id(schedule_id) + if not schedule: + await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。") + + task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name) + if not task_meta or not (params_model := task_meta.get("model")): + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。" + ) + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 的参数模型配置错误。" + ) + + raw_kwargs = {} + try: + for item in kwargs_str.split(","): + key, value = item.strip().split("=", 1) + raw_kwargs[key.strip()] = value + except Exception as e: + await schedule_cmd.finish( + f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}" + ) + + try: + model_validate = getattr(params_model, "model_validate", None) + if not model_validate: + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。" + ) + return + + validated_model = model_validate(raw_kwargs) + + model_dump = getattr(validated_model, "model_dump", None) + if not model_dump: + await schedule_cmd.finish( + f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。" + ) + return + + job_kwargs = model_dump(exclude_unset=True) + except ValidationError as e: + errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] + error_str = "\n".join(errors) + await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}") + return + + _, message = await scheduler_manager.update_schedule( + schedule_id, trigger_type, trigger_config, job_kwargs + ) + await schedule_cmd.finish(message) + + +@schedule_cmd.assign("插件列表") +async def _(): + registered_plugins = scheduler_manager.get_registered_plugins() + if not registered_plugins: + await schedule_cmd.finish("当前没有已注册的定时任务插件。") + + message_parts = ["📋 已注册的定时任务插件:"] + for i, plugin_name in enumerate(registered_plugins, 1): + task_meta = scheduler_manager._registered_tasks[plugin_name] + params_model = task_meta.get("model") + + if not params_model: + message_parts.append(f"{i}. {plugin_name} - 无参数") + continue + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误") + continue + + model_fields = getattr(params_model, "model_fields", None) + if model_fields: + param_info = ", ".join( + f"{field_name}({_get_type_name(field_info.annotation)})" + for field_name, field_info in model_fields.items() + ) + message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}") + else: + message_parts.append(f"{i}. {plugin_name} - 无参数") + + await schedule_cmd.finish("\n".join(message_parts)) + + +@schedule_cmd.assign("状态") +async def _(schedule_id: int): + status = await scheduler_manager.get_schedule_status(schedule_id) + if not status: + await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。") + + info_lines = [ + f"📋 定时任务详细信息 (ID: {schedule_id})", + "--------------------", + f"▫️ 插件: {status['plugin_name']}", + f"▫️ Bot ID: {status.get('bot_id') or '默认'}", + f"▫️ 目标: {status['group_id'] or '全局'}", + f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}", + f"▫️ 下次运行: {status['next_run_time']}", + f"▫️ 触发规则: {_format_trigger(status)}", + f"▫️ 任务参数: {_format_params(status)}", + ] + await schedule_cmd.finish("\n".join(info_lines)) \ No newline at end of file diff --git a/zhenxun/models/schedule_info.py b/zhenxun/models/schedule_info.py new file mode 100644 index 00000000..4b994756 --- /dev/null +++ b/zhenxun/models/schedule_info.py @@ -0,0 +1,38 @@ +from tortoise import fields + +from zhenxun.services.db_context import Model + + +class ScheduleInfo(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + """自增id""" + bot_id = fields.CharField( + 255, null=True, default=None, description="任务关联的Bot ID" + ) + """任务关联的Bot ID""" + plugin_name = fields.CharField(255, description="插件模块名") + """插件模块名""" + group_id = fields.CharField( + 255, + null=True, + description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务", + ) + """群组ID, 为空表示全局任务""" + trigger_type = fields.CharField( + max_length=20, default="cron", description="触发器类型 (cron, interval, date)" + ) + """触发器类型 (cron, interval, date)""" + trigger_config = fields.JSONField(description="触发器具体配置") + """触发器具体配置""" + job_kwargs = fields.JSONField( + default=dict, description="传递给任务函数的额外关键字参数" + ) + """传递给任务函数的额外关键字参数""" + is_enabled = fields.BooleanField(default=True, description="是否启用") + """是否启用""" + create_time = fields.DatetimeField(auto_now_add=True) + """创建时间""" + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + table = "schedule_info" + table_description = "通用定时任务表" \ No newline at end of file diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 85aee620..8cc66d95 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -44,7 +44,8 @@ class Model(TortoiseModel): sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {} def __init_subclass__(cls, **kwargs): - MODELS.append(cls.__module__) + if cls.__module__ not in MODELS: + MODELS.append(cls.__module__) if func := getattr(cls, "_run_script", None): SCRIPT_METHOD.append((cls.__module__, func)) @@ -171,7 +172,7 @@ class Model(TortoiseModel): await CacheRoot.reload(cache_type) -class DbUrlMissing(Exception): +class DbUrlIsNode(HookPriorityException): """ 数据库链接地址为空 """ @@ -190,7 +191,7 @@ class DbConnectError(Exception): @PriorityLifecycle.on_startup(priority=1) async def init(): if not BotConfig.db_url: - # raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...") + # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") error = f""" ********************************************************************** 🌟 **************************** 配置为空 ************************* 🌟 @@ -199,7 +200,7 @@ async def init(): *********************************************************************** *********************************************************************** """ - raise DbUrlMissing("\n" + error.strip()) + raise DbUrlIsNode("\n" + error.strip()) try: await Tortoise.init( db_url=BotConfig.db_url, diff --git a/zhenxun/utils/browser.py b/zhenxun/utils/browser.py index ca2e7755..5644bf88 100644 --- a/zhenxun/utils/browser.py +++ b/zhenxun/utils/browser.py @@ -1,91 +1,94 @@ -import os -import sys +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any, Literal -from nonebot import get_driver -from playwright.__main__ import main -from playwright.async_api import Browser, Playwright, async_playwright +from nonebot_plugin_alconna import UniMessage +from nonebot_plugin_htmlrender import get_browser +from playwright.async_api import Page -from zhenxun.configs.config import BotConfig -from zhenxun.services.log import logger - -driver = get_driver() - -_playwright: Playwright | None = None -_browser: Browser | None = None +from zhenxun.utils.message import MessageUtils -# @driver.on_startup -# async def start_browser(): -# global _playwright -# global _browser -# install() -# await check_playwright_env() -# _playwright = await async_playwright().start() -# _browser = await _playwright.chromium.launch() +class BrowserIsNone(Exception): + pass -# @driver.on_shutdown -# async def shutdown_browser(): -# if _browser: -# await _browser.close() -# if _playwright: -# await _playwright.stop() # type: ignore +class AsyncPlaywright: + @classmethod + @asynccontextmanager + async def new_page( + cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs + ) -> AsyncGenerator[Page, None]: + """获取一个新页面 - -# def get_browser() -> Browser: -# if not _browser: -# raise RuntimeError("playwright is not initalized") -# return _browser - - -def install(): - """自动安装、更新 Chromium""" - - def set_env_variables(): - os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = ( - "https://npmmirror.com/mirrors/playwright/" - ) - if BotConfig.system_proxy: - os.environ["HTTPS_PROXY"] = BotConfig.system_proxy - - def restore_env_variables(): - os.environ.pop("PLAYWRIGHT_DOWNLOAD_HOST", None) - if BotConfig.system_proxy: - os.environ.pop("HTTPS_PROXY", None) - if original_proxy is not None: - os.environ["HTTPS_PROXY"] = original_proxy - - def try_install_chromium(): + 参数: + cookies: cookies + """ + browser = await get_browser() + ctx = await browser.new_context(**kwargs) + if cookies: + if isinstance(cookies, dict): + cookies = [cookies] + await ctx.add_cookies(cookies) # type: ignore + page = await ctx.new_page() try: - sys.argv = ["", "install", "chromium"] - main() - except SystemExit as e: - return e.code == 0 - return False + yield page + finally: + await page.close() + await ctx.close() - logger.info("检查 Chromium 更新") + @classmethod + async def screenshot( + cls, + url: str, + path: Path | str, + element: str | list[str], + *, + wait_time: int | None = None, + viewport_size: dict[str, int] | None = None, + wait_until: ( + Literal["domcontentloaded", "load", "networkidle"] | None + ) = "networkidle", + timeout: float | None = None, + type_: Literal["jpeg", "png"] | None = None, + user_agent: str | None = None, + cookies: list[dict[str, Any]] | dict[str, Any] | None = None, + **kwargs, + ) -> UniMessage | None: + """截图,该方法仅用于简单快捷截图,复杂截图请操作 page - original_proxy = os.environ.get("HTTPS_PROXY") - set_env_variables() - - success = try_install_chromium() - - if not success: - logger.info("Chromium 更新失败,尝试从原始仓库下载,速度较慢") - os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = "" - success = try_install_chromium() - - restore_env_variables() - - if not success: - raise RuntimeError("未知错误,Chromium 下载失败") - - -async def check_playwright_env(): - """检查 Playwright 依赖""" - logger.info("检查 Playwright 依赖") - try: - async with async_playwright() as p: - await p.chromium.launch() - except Exception as e: - raise ImportError("加载失败,Playwright 依赖不全,") from e + 参数: + url: 网址 + path: 存储路径 + element: 元素选择 + wait_time: 等待截取超时时间 + viewport_size: 窗口大小 + wait_until: 等待类型 + timeout: 超时限制 + type_: 保存类型 + user_agent: user_agent + cookies: cookies + """ + if viewport_size is None: + viewport_size = {"width": 2560, "height": 1080} + if isinstance(path, str): + path = Path(path) + wait_time = wait_time * 1000 if wait_time else None + element_list = [element] if isinstance(element, str) else element + async with cls.new_page( + cookies, + viewport=viewport_size, + user_agent=user_agent, + **kwargs, + ) as page: + await page.goto(url, timeout=timeout, wait_until=wait_until) + card = page + for e in element_list: + if not card: + return None + card = await card.wait_for_selector(e, timeout=wait_time) + if card: + await card.screenshot(path=path, timeout=timeout, type=type_) + return MessageUtils.build_message(path) + return None \ No newline at end of file diff --git a/zhenxun/utils/decorator/retry.py b/zhenxun/utils/decorator/retry.py index ddc55584..892005bc 100644 --- a/zhenxun/utils/decorator/retry.py +++ b/zhenxun/utils/decorator/retry.py @@ -1,24 +1,226 @@ +from collections.abc import Callable +from functools import partial, wraps +from typing import Any, Literal + from anyio import EndOfStream -from httpx import ConnectError, HTTPStatusError, TimeoutException -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed +from httpx import ( + ConnectError, + HTTPStatusError, + RemoteProtocolError, + StreamError, + TimeoutException, +) +from nonebot.utils import is_coroutine_callable +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + retry_if_result, + stop_after_attempt, + wait_exponential, + wait_fixed, +) + +from zhenxun.services.log import logger + +LOG_COMMAND = "RetryDecorator" +_SENTINEL = object() + + +def _log_before_sleep(log_name: str | None, retry_state: RetryCallState): + """ + tenacity 重试前的日志记录回调函数。 + """ + func_name = retry_state.fn.__name__ if retry_state.fn else "unknown_function" + log_context = f"函数 '{func_name}'" + if log_name: + log_context = f"操作 '{log_name}' ({log_context})" + + reason = "" + if retry_state.outcome: + if exc := retry_state.outcome.exception(): + reason = f"触发异常: {exc.__class__.__name__}({exc})" + else: + reason = f"不满足结果条件: result={retry_state.outcome.result()}" + + wait_time = ( + getattr(retry_state.next_action, "sleep", 0) if retry_state.next_action else 0 + ) + logger.warning( + f"{log_context} 第 {retry_state.attempt_number} 次重试... " + f"等待 {wait_time:.2f} 秒. {reason}", + LOG_COMMAND, + ) class Retry: @staticmethod - def api( - retry_count: int = 3, wait: int = 1, exception: tuple[type[Exception], ...] = () + def simple( + stop_max_attempt: int = 3, + wait_fixed_seconds: int = 2, + exception: tuple[type[Exception], ...] = (), + *, + log_name: str | None = None, + on_failure: Callable[[Exception], Any] | None = None, + return_on_failure: Any = _SENTINEL, ): - """接口调用重试""" + """ + 一个简单的、用于通用网络请求的重试装饰器预设。 + + 参数: + stop_max_attempt: 最大重试次数。 + wait_fixed_seconds: 固定等待策略的等待秒数。 + exception: 额外需要重试的异常类型元组。 + log_name: 用于日志记录的操作名称。 + on_failure: (可选) 所有重试失败后的回调。 + return_on_failure: (可选) 所有重试失败后的返回值。 + """ + return Retry.api( + stop_max_attempt=stop_max_attempt, + wait_fixed_seconds=wait_fixed_seconds, + exception=exception, + strategy="fixed", + log_name=log_name, + on_failure=on_failure, + return_on_failure=return_on_failure, + ) + + @staticmethod + def download( + stop_max_attempt: int = 3, + exception: tuple[type[Exception], ...] = (), + *, + wait_exp_multiplier: int = 2, + wait_exp_max: int = 15, + log_name: str | None = None, + on_failure: Callable[[Exception], Any] | None = None, + return_on_failure: Any = _SENTINEL, + ): + """ + 一个适用于文件下载的重试装饰器预设,使用指数退避策略。 + + 参数: + stop_max_attempt: 最大重试次数。 + exception: 额外需要重试的异常类型元组。 + wait_exp_multiplier: 指数退避的乘数。 + wait_exp_max: 指数退避的最大等待时间。 + log_name: 用于日志记录的操作名称。 + on_failure: (可选) 所有重试失败后的回调。 + return_on_failure: (可选) 所有重试失败后的返回值。 + """ + return Retry.api( + stop_max_attempt=stop_max_attempt, + exception=exception, + strategy="exponential", + wait_exp_multiplier=wait_exp_multiplier, + wait_exp_max=wait_exp_max, + log_name=log_name, + on_failure=on_failure, + return_on_failure=return_on_failure, + ) + + @staticmethod + def api( + stop_max_attempt: int = 3, + wait_fixed_seconds: int = 1, + exception: tuple[type[Exception], ...] = (), + *, + strategy: Literal["fixed", "exponential"] = "fixed", + retry_on_result: Callable[[Any], bool] | None = None, + wait_exp_multiplier: int = 1, + wait_exp_max: int = 10, + log_name: str | None = None, + on_failure: Callable[[Exception], Any] | None = None, + return_on_failure: Any = _SENTINEL, + ): + """ + 通用、可配置的API调用重试装饰器。 + + 参数: + stop_max_attempt: 最大重试次数。 + wait_fixed_seconds: 固定等待策略的等待秒数。 + exception: 额外需要重试的异常类型元组。 + strategy: 重试等待策略, 'fixed' (固定) 或 'exponential' (指数退避)。 + retry_on_result: 一个回调函数,接收函数返回值。如果返回 True,则触发重试。 + 例如 `lambda r: r.status_code != 200` + wait_exp_multiplier: 指数退避的乘数。 + wait_exp_max: 指数退避的最大等待时间。 + log_name: 用于日志记录的操作名称,方便区分不同的重试场景。 + on_failure: (可选) 当所有重试都失败后,在抛出异常或返回默认值之前, + 会调用此函数,并将最终的异常实例作为参数传入。 + return_on_failure: (可选) 如果设置了此参数,当所有重试失败后, + 将不再抛出异常,而是返回此参数指定的值。 + """ base_exceptions = ( TimeoutException, ConnectError, HTTPStatusError, + StreamError, + RemoteProtocolError, EndOfStream, *exception, ) - return retry( - reraise=True, - stop=stop_after_attempt(retry_count), - wait=wait_fixed(wait), - retry=retry_if_exception_type(base_exceptions), - ) + + def decorator(func: Callable) -> Callable: + if strategy == "exponential": + wait_strategy = wait_exponential( + multiplier=wait_exp_multiplier, max=wait_exp_max + ) + else: + wait_strategy = wait_fixed(wait_fixed_seconds) + + retry_conditions = retry_if_exception_type(base_exceptions) + if retry_on_result: + retry_conditions |= retry_if_result(retry_on_result) + + log_callback = partial(_log_before_sleep, log_name) + + tenacity_retry_decorator = retry( + stop=stop_after_attempt(stop_max_attempt), + wait=wait_strategy, + retry=retry_conditions, + before_sleep=log_callback, + reraise=True, + ) + + decorated_func = tenacity_retry_decorator(func) + + if return_on_failure is _SENTINEL: + return decorated_func + + if is_coroutine_callable(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + try: + return await decorated_func(*args, **kwargs) + except Exception as e: + if on_failure: + if is_coroutine_callable(on_failure): + await on_failure(e) + else: + on_failure(e) + return return_on_failure + + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return decorated_func(*args, **kwargs) + except Exception as e: + if on_failure: + if is_coroutine_callable(on_failure): + logger.error( + f"不能在同步函数 '{func.__name__}' 中调用异步的 " + f"on_failure 回调。", + LOG_COMMAND, + ) + else: + on_failure(e) + return return_on_failure + + return sync_wrapper + + return decorator \ No newline at end of file diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py index 8ec925ec..eb62f72d 100644 --- a/zhenxun/utils/exception.py +++ b/zhenxun/utils/exception.py @@ -64,3 +64,23 @@ class GoodsNotFound(Exception): """ pass + + +class AllURIsFailedError(Exception): + """ + 当所有备用URL都尝试失败后抛出此异常 + """ + + def __init__(self, urls: list[str], exceptions: list[Exception]): + self.urls = urls + self.exceptions = exceptions + super().__init__( + f"All {len(urls)} URIs failed. Last exception: {exceptions[-1]}" + ) + + def __str__(self) -> str: + exc_info = "\n".join( + f" - {url}: {exc.__class__.__name__}({exc})" + for url, exc in zip(self.urls, self.exceptions) + ) + return f"All {len(self.urls)} URIs failed:\n{exc_info}" \ No newline at end of file diff --git a/zhenxun/utils/http_utils.py b/zhenxun/utils/http_utils.py index 0ccf777f..7f53226d 100644 --- a/zhenxun/utils/http_utils.py +++ b/zhenxun/utils/http_utils.py @@ -1,16 +1,15 @@ import asyncio -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence from contextlib import asynccontextmanager +import os from pathlib import Path import time -from typing import Any, ClassVar, Literal, cast +from typing import Any, ClassVar, cast import aiofiles import httpx -from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response -from nonebot_plugin_alconna import UniMessage -from nonebot_plugin_htmlrender import get_browser -from playwright.async_api import Page +from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Proxy, Response +import nonebot from rich.progress import ( BarColumn, DownloadColumn, @@ -18,13 +17,84 @@ from rich.progress import ( TextColumn, TransferSpeedColumn, ) +import ujson as json from zhenxun.configs.config import BotConfig from zhenxun.services.log import logger -from zhenxun.utils.message import MessageUtils +from zhenxun.utils.decorator.retry import Retry +from zhenxun.utils.exception import AllURIsFailedError +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.user_agent import get_user_agent -CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"] +from .browser import AsyncPlaywright, BrowserIsNone # noqa: F401 + +_SENTINEL = object() + +driver = nonebot.get_driver() +_client: AsyncClient | None = None + + +@PriorityLifecycle.on_startup(priority=0) +async def _(): + """ + 在Bot启动时初始化全局httpx客户端。 + """ + global _client + client_kwargs = {} + if proxy_url := BotConfig.system_proxy or None: + try: + version_parts = httpx.__version__.split(".") + major = int("".join(c for c in version_parts[0] if c.isdigit())) + minor = ( + int("".join(c for c in version_parts[1] if c.isdigit())) + if len(version_parts) > 1 + else 0 + ) + if (major, minor) >= (0, 28): + client_kwargs["proxy"] = proxy_url + else: + client_kwargs["proxies"] = proxy_url + except (ValueError, IndexError): + client_kwargs["proxy"] = proxy_url + logger.warning( + f"无法解析 httpx 版本 '{httpx.__version__}'," + "将默认使用新版 'proxy' 参数语法。" + ) + + _client = httpx.AsyncClient( + headers=get_user_agent(), + follow_redirects=True, + **client_kwargs, + ) + + logger.info("全局 httpx.AsyncClient 已启动。", "HTTPClient") + + +@driver.on_shutdown +async def _(): + """ + 在Bot关闭时关闭全局httpx客户端。 + """ + if _client: + await _client.aclose() + logger.info("全局 httpx.AsyncClient 已关闭。", "HTTPClient") + + +def get_client() -> AsyncClient: + """ + 获取全局 httpx.AsyncClient 实例。 + """ + global _client + if not _client: + if not os.environ.get("PYTEST_CURRENT_TEST"): + raise RuntimeError("全局 httpx.AsyncClient 未初始化,请检查启动流程。") + # 在测试环境中创建临时客户端 + logger.warning("在测试环境中创建临时HTTP客户端", "HTTPClient") + _client = httpx.AsyncClient( + headers=get_user_agent(), + follow_redirects=True, + ) + return _client def get_async_client( @@ -33,6 +103,10 @@ def get_async_client( verify: bool = False, **kwargs, ) -> httpx.AsyncClient: + """ + [向后兼容] 创建 httpx.AsyncClient 实例的工厂函数。 + 此函数完全保留了旧版本的接口,确保现有代码无需修改即可使用。 + """ transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify) if proxies: http_proxy = proxies.get("http://") @@ -62,6 +136,30 @@ def get_async_client( class AsyncHttpx: + """ + 一个高级的、健壮的异步HTTP客户端工具类。 + + 设计理念: + - **全局共享客户端**: 默认情况下,所有请求都通过一个在应用启动时初始化的全局 + `httpx.AsyncClient` 实例发出。这个实例共享连接池,提高了效率和性能。 + - **向后兼容与灵活性**: 完全兼容旧的API,同时提供了两种方式来处理需要 + 特殊网络配置(如不同代理、超时)的请求: + 1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时,直接传入 `proxies`, + `timeout` 等参数,将为该次请求创建一个临时的、独立的客户端。 + 2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器,可以 + 获取一个独立的、可配置的客户端,用于执行一系列需要相同特殊配置的请求。 + - **健壮性**: 内置了自动重试、多镜像URL回退(fallback)机制,并提供了便捷的 + JSON解析和文件下载方法。 + """ + + CLIENT_KEY: ClassVar[list[str]] = [ + "use_proxy", + "proxies", + "proxy", + "verify", + "headers", + ] + default_proxy: ClassVar[dict[str, str] | None] = ( { "http://": BotConfig.system_proxy, @@ -72,155 +170,346 @@ class AsyncHttpx: ) @classmethod - @asynccontextmanager - async def _create_client( - cls, - *, - use_proxy: bool = True, - proxies: dict[str, str] | None = None, - proxy: str | None = None, - headers: dict[str, str] | None = None, - verify: bool = False, - **kwargs, - ) -> AsyncGenerator[httpx.AsyncClient, None]: - """创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。 + def _prepare_temporary_client_config(cls, client_kwargs: dict) -> dict: + """ + [向后兼容] 处理旧式的客户端kwargs,将其转换为get_async_client可用的配置。 + 主要负责处理 use_proxy 标志,这是为了兼容旧版本代码中使用的 use_proxy 参数。 + """ + final_config = client_kwargs.copy() - 说明: - 此方法用于内部统一创建客户端,处理代理和请求头逻辑,减少代码重复。 + use_proxy = final_config.pop("use_proxy", True) + + if "proxies" not in final_config and "proxy" not in final_config: + final_config["proxies"] = cls.default_proxy if use_proxy else None + return final_config + + @classmethod + def _split_kwargs(cls, kwargs: dict) -> tuple[dict, dict]: + """[优化] 分离客户端配置和请求参数,使逻辑更清晰。""" + client_kwargs = {k: v for k, v in kwargs.items() if k in cls.CLIENT_KEY} + request_kwargs = {k: v for k, v in kwargs.items() if k not in cls.CLIENT_KEY} + return client_kwargs, request_kwargs + + @classmethod + @asynccontextmanager + async def _get_active_client_context( + cls, client: AsyncClient | None = None, **kwargs + ) -> AsyncGenerator[AsyncClient, None]: + """ + 内部辅助方法,根据 kwargs 决定并提供一个活动的 HTTP 客户端。 + - 如果 kwargs 中有客户端配置,则创建并返回一个临时客户端。 + - 否则,返回传入的 client 或全局客户端。 + - 自动处理临时客户端的关闭。 + """ + if kwargs: + logger.debug(f"为单次请求创建临时客户端,配置: {kwargs}") + temp_client_config = cls._prepare_temporary_client_config(kwargs) + async with get_async_client(**temp_client_config) as temp_client: + yield temp_client + else: + yield client or get_client() + + @Retry.simple(log_name="内部HTTP请求") + async def _execute_request_inner( + self, client: AsyncClient, method: str, url: str, **kwargs + ) -> Response: + """ + [内部] 执行单次HTTP请求的私有核心方法,被重试装饰器包裹。 + """ + return await client.request(method, url, **kwargs) + + @classmethod + async def _single_request( + cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs + ) -> Response: + """ + 执行单次HTTP请求的私有方法,内置了默认的重试逻辑。 + """ + client_kwargs, request_kwargs = cls._split_kwargs(kwargs) + + async with cls._get_active_client_context( + client=client, **client_kwargs + ) as active_client: + response = await cls()._execute_request_inner( + active_client, method, url, **request_kwargs + ) + response.raise_for_status() + return response + + @classmethod + async def _execute_with_fallbacks( + cls, + urls: str | list[str], + worker: Callable[..., Awaitable[Any]], + *, + client: AsyncClient | None = None, + **kwargs, + ) -> Any: + """ + 通用执行器,按顺序尝试多个URL,直到成功。 参数: - use_proxy: 是否使用在类中定义的默认代理。 - proxies: 手动指定的代理,会覆盖默认代理。 - proxy: 单个代理,用于兼容旧版本,不再使用 - headers: 需要合并到客户端的自定义请求头。 - verify: 是否验证 SSL 证书。 - **kwargs: 其他所有传递给 httpx.AsyncClient 的参数。 - - 返回: - AsyncGenerator[httpx.AsyncClient, None]: 生成器。 + urls: 单个URL或URL列表。 + worker: 一个接受单个URL和其他kwargs并执行请求的协程函数。 + client: 可选的HTTP客户端。 + **kwargs: 传递给worker的额外参数。 """ - proxies_to_use = proxies or (cls.default_proxy if use_proxy else None) + url_list = [urls] if isinstance(urls, str) else urls + exceptions = [] - final_headers = get_user_agent() - if headers: - final_headers.update(headers) + for i, url in enumerate(url_list): + try: + result = await worker(url, client=client, **kwargs) + if i > 0: + logger.info( + f"成功从镜像 '{url}' 获取资源 " + f"(在尝试了 {i} 个失败的镜像之后)。", + "AsyncHttpx:FallbackExecutor", + ) + return result + except Exception as e: + exceptions.append(e) + if url != url_list[-1]: + logger.warning( + f"Worker '{worker.__name__}' on {url} failed, trying next. " + f"Error: {e.__class__.__name__}", + "AsyncHttpx:FallbackExecutor", + ) - async with get_async_client( - proxies=proxies_to_use, - proxy=proxy, - verify=verify, - headers=final_headers, - **kwargs, - ) as client: - yield client + raise AllURIsFailedError(url_list, exceptions) @classmethod async def get( cls, url: str | list[str], *, + follow_redirects: bool = True, check_status_code: int | None = None, + client: AsyncClient | None = None, **kwargs, - ) -> Response: # sourcery skip: use-assigned-variable + ) -> Response: """发送 GET 请求,并返回第一个成功的响应。 说明: - 本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的代理管理。 - 如果提供 URL 列表,它将依次尝试直到成功为止。 + 本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的 + 客户端管理。如果提供 URL 列表,它将依次尝试直到成功为止。 + + 用法建议: + - **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端。 + - **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)` + 将为本次请求创建一个独立的临时客户端。 参数: url: 单个请求 URL 或一个 URL 列表。 + follow_redirects: 是否跟随重定向。 check_status_code: (可选) 若提供,将检查响应状态码是否匹配,否则抛出异常。 - **kwargs: 其他所有传递给 httpx.get 的参数 - (如 `params`, `headers`, `timeout`等)。 + client: (可选) 指定一个活动的HTTP客户端实例。若提供,则忽略 + `**kwargs`中的客户端配置。 + **kwargs: 其他所有传递给 httpx.get 的参数 (如 `params`, `headers`, + `timeout`)。如果包含 `proxies`, `verify` 等客户端配置参数, + 将创建一个临时客户端。 返回: - Response: Response + Response: httpx 的响应对象。 + + Raises: + AllURIsFailedError: 当所有提供的URL都请求失败时抛出。 """ - urls = [url] if isinstance(url, str) else url - last_exception = None - for current_url in urls: - try: - logger.info(f"开始获取 {current_url}..") - client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY} - for key in CLIENT_KEY: - kwargs.pop(key, None) - async with cls._create_client(**client_kwargs) as client: - response = await client.get(current_url, **kwargs) - if check_status_code and response.status_code != check_status_code: - raise HTTPStatusError( - f"状态码错误: {response.status_code}!={check_status_code}", - request=response.request, - response=response, - ) - return response - except Exception as e: - last_exception = e - if current_url != urls[-1]: - logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e) + async def worker(current_url: str, **worker_kwargs) -> Response: + logger.info(f"开始获取 {current_url}..", "AsyncHttpx:get") + response = await cls._single_request( + "GET", current_url, follow_redirects=follow_redirects, **worker_kwargs + ) + if check_status_code and response.status_code != check_status_code: + raise HTTPStatusError( + f"状态码错误: {response.status_code}!={check_status_code}", + request=response.request, + response=response, + ) + return response - raise last_exception or Exception("所有URL都获取失败") + return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs) @classmethod - async def head(cls, url: str, **kwargs) -> Response: - """发送 HEAD 请求。 + async def head( + cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs + ) -> Response: + """发送 HEAD 请求,并返回第一个成功的响应。""" - 说明: - 本方法是对 httpx.head 的封装,通常用于检查资源的元信息(如大小、类型)。 + async def worker(current_url: str, **worker_kwargs) -> Response: + return await cls._single_request("HEAD", current_url, **worker_kwargs) - 参数: - url: 请求的 URL。 - **kwargs: 其他所有传递给 httpx.head 的参数 - (如 `headers`, `timeout`, `allow_redirects`)。 - - 返回: - Response: Response - """ - client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY} - for key in CLIENT_KEY: - kwargs.pop(key, None) - async with cls._create_client(**client_kwargs) as client: - return await client.head(url, **kwargs) + return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs) @classmethod - async def post(cls, url: str, **kwargs) -> Response: - """发送 POST 请求。 + async def post( + cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs + ) -> Response: + """发送 POST 请求,并返回第一个成功的响应。""" - 说明: - 本方法是对 httpx.post 的封装,提供了统一的代理和客户端管理。 + async def worker(current_url: str, **worker_kwargs) -> Response: + return await cls._single_request("POST", current_url, **worker_kwargs) - 参数: - url: 请求的 URL。 - **kwargs: 其他所有传递给 httpx.post 的参数 - (如 `data`, `json`, `content` 等)。 - - 返回: - Response: Response。 - """ - client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY} - for key in CLIENT_KEY: - kwargs.pop(key, None) - async with cls._create_client(**client_kwargs) as client: - return await client.post(url, **kwargs) + return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs) @classmethod - async def get_content(cls, url: str, **kwargs) -> bytes: - """获取指定 URL 的二进制内容。 - - 说明: - 这是一个便捷方法,等同于调用 get() 后再访问 .content 属性。 - - 参数: - url: 请求的 URL。 - **kwargs: 所有传递给 get() 方法的参数。 - - 返回: - bytes: 响应内容的二进制字节流 (bytes)。 - """ - res = await cls.get(url, **kwargs) + async def get_content( + cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs + ) -> bytes: + """获取指定 URL 的二进制内容。""" + res = await cls.get(url, client=client, **kwargs) return res.content + @classmethod + @Retry.api( + log_name="JSON请求", + exception=(json.JSONDecodeError,), + return_on_failure=_SENTINEL, + ) + async def _request_and_parse_json( + cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs + ) -> Any: + """ + [私有] 执行单个HTTP请求并解析JSON,用于内部统一处理。 + """ + async with cls._get_active_client_context( + client=client, **kwargs + ) as active_client: + _, request_kwargs = cls._split_kwargs(kwargs) + response = await active_client.request(method, url, **request_kwargs) + response.raise_for_status() + return response.json() + + @classmethod + async def get_json( + cls, + url: str | list[str], + *, + default: Any = None, + raise_on_failure: bool = False, + client: AsyncClient | None = None, + **kwargs, + ) -> Any: + """ + 发送GET请求并自动解析为JSON,支持重试和多链接尝试。 + + 说明: + 这是一个高度便捷的方法,封装了请求、重试、JSON解析和错误处理。 + 它会在网络错误或JSON解析错误时自动重试。 + 如果所有尝试都失败,它会安全地返回一个默认值。 + + 参数: + url: 单个请求 URL 或一个备用 URL 列表。 + default: (可选) 当所有尝试都失败时返回的默认值,默认为None。 + raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出 + `AllURIsFailedError` 异常, 默认为 False. + client: (可选) 指定的HTTP客户端。 + **kwargs: 其他所有传递给 httpx.get 的参数。 + 例如 `params`, `headers`, `timeout`等。 + + 返回: + Any: 解析后的JSON数据,或在失败时返回 `default` 值。 + + Raises: + AllURIsFailedError: 当 `raise_on_failure` 为 True 且所有URL都请求失败时抛出 + """ + + async def worker(current_url: str, **worker_kwargs): + logger.debug(f"开始GET JSON: {current_url}", "AsyncHttpx:get_json") + return await cls._request_and_parse_json( + "GET", current_url, **worker_kwargs + ) + + try: + result = await cls._execute_with_fallbacks( + url, worker, client=client, **kwargs + ) + return default if result is _SENTINEL else result + except AllURIsFailedError as e: + logger.error(f"所有URL的JSON GET均失败: {e}", "AsyncHttpx:get_json") + if raise_on_failure: + raise e + return default + + @classmethod + async def post_json( + cls, + url: str | list[str], + *, + json: Any = None, + data: Any = None, + default: Any = None, + raise_on_failure: bool = False, + client: AsyncClient | None = None, + **kwargs, + ) -> Any: + """ + 发送POST请求并自动解析为JSON,功能与 get_json 类似。 + + 参数: + url: 单个请求 URL 或一个备用 URL 列表。 + json: (可选) 作为请求体发送的JSON数据。 + data: (可选) 作为请求体发送的表单数据。 + default: (可选) 当所有尝试都失败时返回的默认值,默认为None。 + raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出 + AllURIsFailedError 异常, 默认为 False. + client: (可选) 指定的HTTP客户端。 + **kwargs: 其他所有传递给 httpx.post 的参数。 + + 返回: + Any: 解析后的JSON数据,或在失败时返回 `default` 值。 + """ + if json is not None: + kwargs["json"] = json + if data is not None: + kwargs["data"] = data + + async def worker(current_url: str, **worker_kwargs): + logger.debug(f"开始POST JSON: {current_url}", "AsyncHttpx:post_json") + return await cls._request_and_parse_json( + "POST", current_url, **worker_kwargs + ) + + try: + result = await cls._execute_with_fallbacks( + url, worker, client=client, **kwargs + ) + return default if result is _SENTINEL else result + except AllURIsFailedError as e: + logger.error(f"所有URL的JSON POST均失败: {e}", "AsyncHttpx:post_json") + if raise_on_failure: + raise e + return default + + @classmethod + @Retry.api(log_name="文件下载(流式)") + async def _stream_download( + cls, url: str, path: Path, *, client: AsyncClient | None = None, **kwargs + ) -> None: + """ + 执行单个流式下载的私有方法,被重试装饰器包裹。 + """ + async with cls._get_active_client_context( + client=client, **kwargs + ) as active_client: + async with active_client.stream("GET", url, **kwargs) as response: + response.raise_for_status() + total = int(response.headers.get("Content-Length", 0)) + + with Progress( + TextColumn(path.name), + "[progress.percentage]{task.percentage:>3.0f}%", + BarColumn(bar_width=None), + DownloadColumn(), + TransferSpeedColumn(), + ) as progress: + task_id = progress.add_task("Download", total=total) + async with aiofiles.open(path, "wb") as f: + async for chunk in response.aiter_bytes(): + await f.write(chunk) + progress.update(task_id, advance=len(chunk)) + @classmethod async def download_file( cls, @@ -228,6 +517,7 @@ class AsyncHttpx: path: str | Path, *, stream: bool = False, + client: AsyncClient | None = None, **kwargs, ) -> bool: """下载文件到指定路径。 @@ -239,6 +529,7 @@ class AsyncHttpx: url: 单个文件 URL 或一个备用 URL 列表。 path: 文件保存的本地路径。 stream: (可选) 是否使用流式下载,适用于大文件,默认为 False。 + client: (可选) 指定的HTTP客户端。 **kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数。 返回: @@ -247,49 +538,29 @@ class AsyncHttpx: path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) - urls = [url] if isinstance(url, str) else url + async def worker(current_url: str, **worker_kwargs) -> bool: + if not stream: + content = await cls.get_content(current_url, **worker_kwargs) + async with aiofiles.open(path, "wb") as f: + await f.write(content) + else: + await cls._stream_download(current_url, path, **worker_kwargs) - for current_url in urls: - try: - if not stream: - response = await cls.get(current_url, **kwargs) - response.raise_for_status() - async with aiofiles.open(path, "wb") as f: - await f.write(response.content) - else: - async with cls._create_client(**kwargs) as client: - stream_kwargs = { - k: v - for k, v in kwargs.items() - if k not in ["use_proxy", "proxy", "verify"] - } - async with client.stream( - "GET", current_url, **stream_kwargs - ) as response: - response.raise_for_status() - total = int(response.headers.get("Content-Length", 0)) + logger.info( + f"下载 {current_url} 成功 -> {path.absolute()}", + "AsyncHttpx:download", + ) + return True - with Progress( - TextColumn(path.name), - "[progress.percentage]{task.percentage:>3.0f}%", - BarColumn(bar_width=None), - DownloadColumn(), - TransferSpeedColumn(), - ) as progress: - task_id = progress.add_task("Download", total=total) - async with aiofiles.open(path, "wb") as f: - async for chunk in response.aiter_bytes(): - await f.write(chunk) - progress.update(task_id, advance=len(chunk)) - - logger.info(f"下载 {current_url} 成功 -> {path.absolute()}") - return True - - except Exception as e: - logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}") - - logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}") - return False + try: + return await cls._execute_with_fallbacks( + url, worker, client=client, **kwargs + ) + except AllURIsFailedError: + logger.error( + f"所有URL下载均失败 -> {path.absolute()}", "AsyncHttpx:download" + ) + return False @classmethod async def gather_download_file( @@ -346,7 +617,6 @@ class AsyncHttpx: logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result) final_results.append(False) else: - # download_file 返回的是 bool,可以直接附加 final_results.append(cast(bool, result)) return final_results @@ -395,86 +665,30 @@ class AsyncHttpx: _results = sorted(iter(_results), key=lambda r: r["elapsed_time"]) return [result["url"] for result in _results] - -class AsyncPlaywright: @classmethod @asynccontextmanager - async def new_page( - cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs - ) -> AsyncGenerator[Page, None]: - """获取一个新页面 + async def temporary_client(cls, **kwargs) -> AsyncGenerator[AsyncClient, None]: + """ + 创建一个临时的、可配置的HTTP客户端上下文,并直接返回该客户端实例。 + + 此方法返回一个标准的 `httpx.AsyncClient`,它不使用全局连接池, + 拥有独立的配置(如代理、headers、超时等),并在退出上下文后自动关闭。 + 适用于需要用一套特殊网络配置执行一系列请求的场景。 + + 用法: + async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client: + # client 是一个标准的 httpx.AsyncClient 实例 + response1 = await client.get("http://some.internal.api/1") + response2 = await client.get("http://some.internal.api/2") + data = response2.json() 参数: - cookies: cookies + **kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数。 + 例如: `proxies`, `headers`, `verify`, `timeout`, + `follow_redirects`。 + + Yields: + httpx.AsyncClient: 一个配置好的、临时的客户端实例。 """ - browser = await get_browser() - ctx = await browser.new_context(**kwargs) - if cookies: - if isinstance(cookies, dict): - cookies = [cookies] - await ctx.add_cookies(cookies) # type: ignore - page = await ctx.new_page() - try: - yield page - finally: - await page.close() - await ctx.close() - - @classmethod - async def screenshot( - cls, - url: str, - path: Path | str, - element: str | list[str], - *, - wait_time: int | None = None, - viewport_size: dict[str, int] | None = None, - wait_until: ( - Literal["domcontentloaded", "load", "networkidle"] | None - ) = "networkidle", - timeout: float | None = None, - type_: Literal["jpeg", "png"] | None = None, - user_agent: str | None = None, - cookies: list[dict[str, Any]] | dict[str, Any] | None = None, - **kwargs, - ) -> UniMessage | None: - """截图,该方法仅用于简单快捷截图,复杂截图请操作 page - - 参数: - url: 网址 - path: 存储路径 - element: 元素选择 - wait_time: 等待截取超时时间 - viewport_size: 窗口大小 - wait_until: 等待类型 - timeout: 超时限制 - type_: 保存类型 - user_agent: user_agent - cookies: cookies - """ - if viewport_size is None: - viewport_size = {"width": 2560, "height": 1080} - if isinstance(path, str): - path = Path(path) - wait_time = wait_time * 1000 if wait_time else None - element_list = [element] if isinstance(element, str) else element - async with cls.new_page( - cookies, - viewport=viewport_size, - user_agent=user_agent, - **kwargs, - ) as page: - await page.goto(url, timeout=timeout, wait_until=wait_until) - card = page - for e in element_list: - if not card: - return None - card = await card.wait_for_selector(e, timeout=wait_time) - if card: - await card.screenshot(path=path, timeout=timeout, type=type_) - return MessageUtils.build_message(path) - return None - - -class BrowserIsNone(Exception): - pass + async with get_async_client(**kwargs) as client: + yield client \ No newline at end of file diff --git a/zhenxun/utils/manager/schedule_manager.py b/zhenxun/utils/manager/schedule_manager.py new file mode 100644 index 00000000..f69d7a81 --- /dev/null +++ b/zhenxun/utils/manager/schedule_manager.py @@ -0,0 +1,810 @@ +import asyncio +from collections.abc import Callable, Coroutine +import copy +import inspect +import random +from typing import ClassVar + +import nonebot +from nonebot import get_bots +from nonebot_plugin_apscheduler import scheduler +from pydantic import BaseModel, ValidationError + +from zhenxun.configs.config import Config +from zhenxun.models.schedule_info import ScheduleInfo +from zhenxun.services.log import logger +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from zhenxun.utils.platform import PlatformUtils + +SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit" + + +class SchedulerManager: + """ + 一个通用的、持久化的定时任务管理器,供所有插件使用。 + """ + + _registered_tasks: ClassVar[ + dict[str, dict[str, Callable | type[BaseModel] | None]] + ] = {} + _JOB_PREFIX = "zhenxun_schedule_" + _running_tasks: ClassVar[set] = set() + + def register( + self, plugin_name: str, params_model: type[BaseModel] | None = None + ) -> Callable: + """ + 注册一个可调度的任务函数。 + 被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)` + + Args: + plugin_name (str): 插件的唯一名称 (通常是模块名)。 + params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类, + 用于定义和验证任务函数接受的额外参数。 + """ + + def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + if plugin_name in self._registered_tasks: + logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。") + self._registered_tasks[plugin_name] = { + "func": func, + "model": params_model, + } + model_name = params_model.__name__ if params_model else "无" + logger.debug( + f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}" + ) + return func + + return decorator + + def get_registered_plugins(self) -> list[str]: + """获取所有已注册定时任务的插件列表。""" + return list(self._registered_tasks.keys()) + + def _get_job_id(self, schedule_id: int) -> str: + """根据数据库ID生成唯一的 APScheduler Job ID。""" + return f"{self._JOB_PREFIX}{schedule_id}" + + async def _execute_job(self, schedule_id: int): + """ + APScheduler 调度的入口函数。 + 根据 schedule_id 处理特定任务、所有群组任务或全局任务。 + """ + schedule = await ScheduleInfo.get_or_none(id=schedule_id) + if not schedule or not schedule.is_enabled: + logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。") + return + + plugin_name = schedule.plugin_name + + task_meta = self._registered_tasks.get(plugin_name) + if not task_meta: + logger.error( + f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。" + ) + schedule.is_enabled = False + await schedule.save(update_fields=["is_enabled"]) + self._remove_aps_job(schedule.id) + return + + try: + if schedule.bot_id: + bot = nonebot.get_bot(schedule.bot_id) + else: + bot = nonebot.get_bot() + logger.debug( + f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}" + ) + except KeyError: + logger.warning( + f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} " + f"不在线,本次执行跳过。" + ) + return + except ValueError: + logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。") + return + + if schedule.group_id == "__ALL_GROUPS__": + await self._execute_for_all_groups(schedule, task_meta, bot) + else: + await self._execute_for_single_target(schedule, task_meta, bot) + + async def _execute_for_all_groups( + self, schedule: ScheduleInfo, task_meta: dict, bot + ): + """为所有群组执行任务,并处理优先级覆盖。""" + plugin_name = schedule.plugin_name + + concurrency_limit = Config.get_config( + "SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5 + ) + if not isinstance(concurrency_limit, int) or concurrency_limit <= 0: + logger.warning( + f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。" + ) + concurrency_limit = 5 + + logger.info( + f"开始执行针对 [所有群组] 的任务 " + f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})," + f"并发限制: {concurrency_limit}" + ) + + all_gids = set() + try: + group_list, _ = await PlatformUtils.get_group_list(bot) + all_gids.update( + g.group_id for g in group_list if g.group_id and not g.channel_id + ) + except Exception as e: + logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e) + return + + specific_tasks_gids = set( + await ScheduleInfo.filter( + plugin_name=plugin_name, group_id__in=list(all_gids) + ).values_list("group_id", flat=True) + ) + + semaphore = asyncio.Semaphore(concurrency_limit) + + async def worker(gid: str): + """使用 Semaphore 包装单个群组的任务执行""" + async with semaphore: + temp_schedule = copy.deepcopy(schedule) + temp_schedule.group_id = gid + await self._execute_for_single_target(temp_schedule, task_meta, bot) + await asyncio.sleep(random.uniform(0.1, 0.5)) + + tasks_to_run = [] + for gid in all_gids: + if gid in specific_tasks_gids: + logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。") + continue + tasks_to_run.append(worker(gid)) + + if tasks_to_run: + await asyncio.gather(*tasks_to_run) + + async def _execute_for_single_target( + self, schedule: ScheduleInfo, task_meta: dict, bot + ): + """为单个目标(具体群组或全局)执行任务。""" + plugin_name = schedule.plugin_name + group_id = schedule.group_id + + try: + is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id) + if is_blocked: + target_desc = f"群 {group_id}" if group_id else "全局" + logger.info( + f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]" + "因功能被禁用而跳过执行。" + ) + return + + task_func = task_meta["func"] + job_kwargs = schedule.job_kwargs + if not isinstance(job_kwargs, dict): + logger.error( + f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}" + ) + return + + sig = inspect.signature(task_func) + if "bot" in sig.parameters: + job_kwargs["bot"] = bot + + logger.info( + f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] " + f"执行定时任务 (ID: {schedule.id})。" + ) + task = asyncio.create_task(task_func(group_id, **job_kwargs)) + self._running_tasks.add(task) + task.add_done_callback(self._running_tasks.discard) + await task + except Exception as e: + logger.error( + f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, " + f"目标: {group_id or '全局'}) 时发生异常", + e=e, + ) + + def _validate_and_prepare_kwargs( + self, plugin_name: str, job_kwargs: dict | None + ) -> tuple[bool, str | dict]: + """验证并准备任务参数,应用默认值""" + task_meta = self._registered_tasks.get(plugin_name) + if not task_meta: + return False, f"插件 '{plugin_name}' 未注册。" + + params_model = task_meta.get("model") + job_kwargs = job_kwargs if job_kwargs is not None else {} + + if not params_model: + if job_kwargs: + logger.warning( + f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}" + ) + return True, job_kwargs + + if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)): + logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类") + return False, f"插件 '{plugin_name}' 的参数模型配置错误" + + try: + model_validate = getattr(params_model, "model_validate", None) + if not model_validate: + return False, f"插件 '{plugin_name}' 的参数模型不支持验证" + + validated_model = model_validate(job_kwargs) + + model_dump = getattr(validated_model, "model_dump", None) + if not model_dump: + return False, f"插件 '{plugin_name}' 的参数模型不支持导出" + + return True, model_dump() + except ValidationError as e: + errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()] + error_str = "\n".join(errors) + msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}" + return False, msg + + def _add_aps_job(self, schedule: ScheduleInfo): + """根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。""" + job_id = self._get_job_id(schedule.id) + try: + scheduler.remove_job(job_id) + except Exception: + pass + + if not isinstance(schedule.trigger_config, dict): + logger.error( + f"任务 {schedule.id} 的 trigger_config 不是字典类型: " + f"{type(schedule.trigger_config)}" + ) + return + + scheduler.add_job( + self._execute_job, + trigger=schedule.trigger_type, + id=job_id, + misfire_grace_time=300, + args=[schedule.id], + **schedule.trigger_config, + ) + logger.debug( + f"已在 APScheduler 中添加/更新任务: {job_id} " + f"with trigger: {schedule.trigger_config}" + ) + + def _remove_aps_job(self, schedule_id: int): + """移除一个 APScheduler 任务。""" + job_id = self._get_job_id(schedule_id) + try: + scheduler.remove_job(job_id) + logger.debug(f"已从 APScheduler 中移除任务: {job_id}") + except Exception: + pass + + async def add_schedule( + self, + plugin_name: str, + group_id: str | None, + trigger_type: str, + trigger_config: dict, + job_kwargs: dict | None = None, + bot_id: str | None = None, + ) -> tuple[bool, str]: + """ + 添加或更新一个定时任务。 + """ + if plugin_name not in self._registered_tasks: + return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。" + + is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs) + if not is_valid: + return False, str(result) + + validated_job_kwargs = result + + effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None + + search_kwargs = { + "plugin_name": plugin_name, + "group_id": group_id, + } + if effective_bot_id: + search_kwargs["bot_id"] = effective_bot_id + else: + search_kwargs["bot_id__isnull"] = True + + defaults = { + "trigger_type": trigger_type, + "trigger_config": trigger_config, + "job_kwargs": validated_job_kwargs, + "is_enabled": True, + } + + schedule = await ScheduleInfo.filter(**search_kwargs).first() + created = False + + if schedule: + for key, value in defaults.items(): + setattr(schedule, key, value) + await schedule.save() + else: + creation_kwargs = { + "plugin_name": plugin_name, + "group_id": group_id, + "bot_id": effective_bot_id, + **defaults, + } + schedule = await ScheduleInfo.create(**creation_kwargs) + created = True + self._add_aps_job(schedule) + action = "设置" if created else "更新" + return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。" + + async def add_schedule_for_all( + self, + plugin_name: str, + trigger_type: str, + trigger_config: dict, + job_kwargs: dict | None = None, + ) -> tuple[int, int]: + """为所有机器人所在的群组添加定时任务。""" + if plugin_name not in self._registered_tasks: + raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。") + + groups = set() + for bot in get_bots().values(): + try: + group_list, _ = await PlatformUtils.get_group_list(bot) + groups.update( + g.group_id for g in group_list if g.group_id and not g.channel_id + ) + except Exception as e: + logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e) + + success_count = 0 + fail_count = 0 + for gid in groups: + try: + success, _ = await self.add_schedule( + plugin_name, gid, trigger_type, trigger_config, job_kwargs + ) + if success: + success_count += 1 + else: + fail_count += 1 + except Exception as e: + logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e) + fail_count += 1 + await asyncio.sleep(0.05) + return success_count, fail_count + + async def update_schedule( + self, + schedule_id: int, + trigger_type: str | None = None, + trigger_config: dict | None = None, + job_kwargs: dict | None = None, + ) -> tuple[bool, str]: + """部分更新一个已存在的定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return False, f"未找到 ID 为 {schedule_id} 的任务。" + + updated_fields = [] + if trigger_config is not None: + schedule.trigger_config = trigger_config + updated_fields.append("trigger_config") + + if trigger_type is not None and schedule.trigger_type != trigger_type: + schedule.trigger_type = trigger_type + updated_fields.append("trigger_type") + + if job_kwargs is not None: + if not isinstance(schedule.job_kwargs, dict): + return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。" + + merged_kwargs = schedule.job_kwargs.copy() + merged_kwargs.update(job_kwargs) + + is_valid, result = self._validate_and_prepare_kwargs( + schedule.plugin_name, merged_kwargs + ) + if not is_valid: + return False, str(result) + + schedule.job_kwargs = result # type: ignore + updated_fields.append("job_kwargs") + + if not updated_fields: + return True, "没有任何需要更新的配置。" + + await schedule.save(update_fields=updated_fields) + self._add_aps_job(schedule) + return True, f"成功更新了任务 ID: {schedule_id} 的配置。" + + async def remove_schedule( + self, plugin_name: str, group_id: str | None, bot_id: str | None = None + ) -> tuple[bool, str]: + """移除指定插件和群组的定时任务。""" + query = {"plugin_name": plugin_name, "group_id": group_id} + if bot_id: + query["bot_id"] = bot_id + + schedules = await ScheduleInfo.filter(**query) + if not schedules: + msg = ( + f"未找到与 Bot {bot_id} 相关的群 {group_id} " + f"的插件 '{plugin_name}' 定时任务。" + ) + return (False, msg) + + for schedule in schedules: + self._remove_aps_job(schedule.id) + await schedule.delete() + + target_desc = f"群 {group_id}" if group_id else "全局" + msg = ( + f"已取消 Bot {bot_id} 在 [{target_desc}] " + f"的插件 '{plugin_name}' 所有定时任务。" + ) + return (True, msg) + + async def remove_schedule_for_all( + self, plugin_name: str, bot_id: str | None = None + ) -> int: + """移除指定插件在所有群组的定时任务。""" + query = {"plugin_name": plugin_name} + if bot_id: + query["bot_id"] = bot_id + + schedules_to_delete = await ScheduleInfo.filter(**query).all() + if not schedules_to_delete: + return 0 + + for schedule in schedules_to_delete: + self._remove_aps_job(schedule.id) + await schedule.delete() + await asyncio.sleep(0.01) + + return len(schedules_to_delete) + + async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]: + """移除指定群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(group_id=group_id) + if not schedules: + return False, f"群 {group_id} 没有任何定时任务。" + + count = 0 + for schedule in schedules: + self._remove_aps_job(schedule.id) + await schedule.delete() + count += 1 + await asyncio.sleep(0.01) + + return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。" + + async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]: + """暂停指定群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True) + if not schedules: + return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。" + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。" + + async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]: + """恢复指定群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False) + if not schedules: + return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。" + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。" + + async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]: + """暂停指定插件在所有群组的定时任务。""" + schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True) + if not schedules: + return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。" + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return ( + count, + f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。", + ) + + async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]: + """恢复指定插件在所有群组的定时任务。""" + schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False) + if not schedules: + return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。" + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return ( + count, + f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。", + ) + + async def pause_schedule_by_plugin_group( + self, plugin_name: str, group_id: str | None, bot_id: str | None = None + ) -> tuple[bool, str]: + """暂停指定插件在指定群组的定时任务。""" + query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True} + if bot_id: + query["bot_id"] = bot_id + + schedules = await ScheduleInfo.filter(**query) + if not schedules: + return ( + False, + f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。", + ) + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + + return ( + True, + f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。", + ) + + async def resume_schedule_by_plugin_group( + self, plugin_name: str, group_id: str | None, bot_id: str | None = None + ) -> tuple[bool, str]: + """恢复指定插件在指定群组的定时任务。""" + query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False} + if bot_id: + query["bot_id"] = bot_id + + schedules = await ScheduleInfo.filter(**query) + if not schedules: + return ( + False, + f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。", + ) + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + + return ( + True, + f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。", + ) + + async def remove_all_schedules(self) -> tuple[int, str]: + """移除所有群组的所有定时任务。""" + schedules = await ScheduleInfo.all() + if not schedules: + return 0, "当前没有任何定时任务。" + + count = 0 + for schedule in schedules: + self._remove_aps_job(schedule.id) + await schedule.delete() + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功移除所有群组的 {count} 个定时任务。" + + async def pause_all_schedules(self) -> tuple[int, str]: + """暂停所有群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(is_enabled=True) + if not schedules: + return 0, "当前没有正在运行的定时任务可暂停。" + + count = 0 + for schedule in schedules: + success, _ = await self.pause_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功暂停所有群组的 {count} 个定时任务。" + + async def resume_all_schedules(self) -> tuple[int, str]: + """恢复所有群组的所有定时任务。""" + schedules = await ScheduleInfo.filter(is_enabled=False) + if not schedules: + return 0, "当前没有已暂停的定时任务可恢复。" + + count = 0 + for schedule in schedules: + success, _ = await self.resume_schedule(schedule.id) + if success: + count += 1 + await asyncio.sleep(0.01) + + return count, f"已成功恢复所有群组的 {count} 个定时任务。" + + async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]: + """通过ID移除指定的定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return False, f"未找到 ID 为 {schedule_id} 的定时任务。" + + self._remove_aps_job(schedule.id) + await schedule.delete() + + return ( + True, + f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + + async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None: + """通过ID获取定时任务信息。""" + return await ScheduleInfo.get_or_none(id=schedule_id) + + async def get_schedules( + self, plugin_name: str, group_id: str | None + ) -> list[ScheduleInfo]: + """获取特定群组特定插件的所有定时任务。""" + return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id) + + async def get_schedule( + self, plugin_name: str, group_id: str | None + ) -> ScheduleInfo | None: + """获取特定群组的定时任务信息。""" + return await ScheduleInfo.get_or_none( + plugin_name=plugin_name, group_id=group_id + ) + + async def get_all_schedules( + self, plugin_name: str | None = None + ) -> list[ScheduleInfo]: + """获取所有定时任务信息,可按插件名过滤。""" + if plugin_name: + return await ScheduleInfo.filter(plugin_name=plugin_name).all() + return await ScheduleInfo.all() + + async def get_schedule_status(self, schedule_id: int) -> dict | None: + """获取任务的详细状态。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return None + + job_id = self._get_job_id(schedule.id) + job = scheduler.get_job(job_id) + + status = { + "id": schedule.id, + "bot_id": schedule.bot_id, + "plugin_name": schedule.plugin_name, + "group_id": schedule.group_id, + "is_enabled": schedule.is_enabled, + "trigger_type": schedule.trigger_type, + "trigger_config": schedule.trigger_config, + "job_kwargs": schedule.job_kwargs, + "next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S") + if job and job.next_run_time + else "N/A", + "is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A", + } + return status + + async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]: + """暂停一个定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule or not schedule.is_enabled: + return False, "任务不存在或已暂停。" + + schedule.is_enabled = False + await schedule.save(update_fields=["is_enabled"]) + + job_id = self._get_job_id(schedule.id) + try: + scheduler.pause_job(job_id) + except Exception: + pass + + return ( + True, + f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + + async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]: + """恢复一个定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule or schedule.is_enabled: + return False, "任务不存在或已启用。" + + schedule.is_enabled = True + await schedule.save(update_fields=["is_enabled"]) + + job_id = self._get_job_id(schedule.id) + try: + scheduler.resume_job(job_id) + except Exception: + self._add_aps_job(schedule) + + return ( + True, + f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + + async def trigger_now(self, schedule_id: int) -> tuple[bool, str]: + """手动触发一个定时任务。""" + schedule = await self.get_schedule_by_id(schedule_id) + if not schedule: + return False, f"未找到 ID 为 {schedule_id} 的定时任务。" + + if schedule.plugin_name not in self._registered_tasks: + return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。" + + try: + await self._execute_job(schedule.id) + return ( + True, + f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} " + f"的定时任务 (ID: {schedule.id})。", + ) + except Exception as e: + logger.error(f"手动触发任务失败: {e}") + return False, f"手动触发任务失败: {e}" + + +scheduler_manager = SchedulerManager() + + +@PriorityLifecycle.on_startup(priority=90) +async def _load_schedules_from_db(): + """在服务启动时从数据库加载并调度所有任务。""" + Config.add_plugin_config( + "SchedulerManager", + SCHEDULE_CONCURRENCY_KEY, + 5, + help="“所有群组”类型定时任务的并发执行数量限制", + type=int, + ) + + logger.info("正在从数据库加载并调度所有定时任务...") + schedules = await ScheduleInfo.filter(is_enabled=True).all() + count = 0 + for schedule in schedules: + if schedule.plugin_name in scheduler_manager._registered_tasks: + scheduler_manager._add_aps_job(schedule) + count += 1 + else: + logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。") + logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。") \ No newline at end of file diff --git a/zhenxun/utils/platform.py b/zhenxun/utils/platform.py index f01aec3f..ffbd0114 100644 --- a/zhenxun/utils/platform.py +++ b/zhenxun/utils/platform.py @@ -227,7 +227,7 @@ class PlatformUtils: url = None if platform == "qq": if user_id.isdigit(): - url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160" + url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640" else: url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640" return await AsyncHttpx.get_content(url) if url else None