mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⭐导入最最近PR
This commit is contained in:
parent
c7427ecb20
commit
6d322b0f13
@ -1,12 +1,17 @@
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
from nonebot import on_regex
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.params import Depends, RegexGroup
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot.rule import to_me
|
||||
from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
Arparma,
|
||||
CommandMeta,
|
||||
Option,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
@ -54,15 +59,22 @@ __plugin_meta__ = PluginMetadata(
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
_nickname_matcher = on_regex(
|
||||
"(?:以后)?(?:叫我|请叫我|称呼我)(.*)",
|
||||
_nickname_matcher = on_alconna(
|
||||
Alconna(
|
||||
"re:(?:以后)?(?:叫我|请叫我|称呼我)",
|
||||
Args["name?", str],
|
||||
meta=CommandMeta(compact=True),
|
||||
),
|
||||
rule=to_me(),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_global_nickname_matcher = on_regex(
|
||||
"设置全局昵称(.*)", rule=to_me(), priority=5, block=True
|
||||
_global_nickname_matcher = on_alconna(
|
||||
Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)),
|
||||
rule=to_me(),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_matcher = on_alconna(
|
||||
@ -117,34 +129,32 @@ CANCEL = [
|
||||
]
|
||||
|
||||
|
||||
def CheckNickname():
|
||||
async def CheckNickname(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
params: Arparma,
|
||||
):
|
||||
"""
|
||||
检查名称是否合法
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
black_word = Config.get_config("nickname", "BLACK_WORD")
|
||||
(name,) = reg_group
|
||||
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
|
||||
if not name:
|
||||
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if session.user.id in bot.config.superusers:
|
||||
logger.debug(
|
||||
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
|
||||
)
|
||||
return
|
||||
black_word = Config.get_config("nickname", "BLACK_WORD")
|
||||
name = params.query("name")
|
||||
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
|
||||
if not name:
|
||||
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if session.user.id in bot.config.superusers:
|
||||
logger.debug(
|
||||
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
|
||||
)
|
||||
else:
|
||||
if len(name) > 20:
|
||||
await MessageUtils.build_message("昵称可不能超过20个字!").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if name in bot.config.nickname:
|
||||
await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish(
|
||||
await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if black_word:
|
||||
@ -162,17 +172,17 @@ def CheckNickname():
|
||||
await MessageUtils.build_message(
|
||||
f"字符 [{word}] 为禁止字符!"
|
||||
).finish(at_sender=True)
|
||||
|
||||
return Depends(dependency)
|
||||
return name
|
||||
|
||||
|
||||
@_nickname_matcher.handle(parameterless=[CheckNickname()])
|
||||
@_nickname_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
name_: Arparma,
|
||||
uname: str = UserName(),
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
(name,) = reg_group
|
||||
name = await CheckNickname(bot, session, name_)
|
||||
if len(name) < 5 and random.random() < 0.3:
|
||||
name = "~".join(name)
|
||||
group_id = None
|
||||
@ -200,13 +210,14 @@ async def _(
|
||||
)
|
||||
|
||||
|
||||
@_global_nickname_matcher.handle(parameterless=[CheckNickname()])
|
||||
@_global_nickname_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
name_: Arparma,
|
||||
nickname: str = UserName(),
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
(name,) = reg_group
|
||||
name = await CheckNickname(bot, session, name_)
|
||||
await FriendUser.set_user_nickname(
|
||||
session.user.id,
|
||||
name,
|
||||
@ -227,15 +238,14 @@ async def _(session: Uninfo, uname: str = UserName()):
|
||||
group_id = session.group.parent.id if session.group.parent else session.group.id
|
||||
if group_id:
|
||||
nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id)
|
||||
card = uname
|
||||
else:
|
||||
nickname = await FriendUser.get_user_nickname(session.user.id)
|
||||
card = uname
|
||||
if nickname:
|
||||
await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish(
|
||||
reply_to=True
|
||||
)
|
||||
else:
|
||||
card = uname
|
||||
await MessageUtils.build_message(
|
||||
random.choice(
|
||||
[
|
||||
@ -270,4 +280,4 @@ async def _(bot: Bot, session: Uninfo):
|
||||
else:
|
||||
await MessageUtils.build_message("你在做梦吗?你没有昵称啊").finish(
|
||||
reply_to=True
|
||||
)
|
||||
)
|
||||
51
zhenxun/builtin_plugins/scheduler_admin/__init__.py
Normal file
51
zhenxun/builtin_plugins/scheduler_admin/__init__.py
Normal file
@ -0,0 +1,51 @@
|
||||
from nonebot.plugin import PluginMetadata
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.utils.enum import PluginType
|
||||
|
||||
from . import command # noqa: F401
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="定时任务管理",
|
||||
description="查看和管理由 SchedulerManager 控制的定时任务。",
|
||||
usage="""
|
||||
📋 定时任务管理 - 支持群聊和私聊操作
|
||||
|
||||
🔍 查看任务:
|
||||
定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>]
|
||||
• 群聊中: 查看本群任务
|
||||
• 私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER)
|
||||
|
||||
📊 任务状态:
|
||||
定时任务 状态 <任务ID> 或 任务状态 <任务ID>
|
||||
• 查看单个任务的详细信息和状态
|
||||
|
||||
⚙️ 任务管理 (SUPERUSER):
|
||||
定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>]
|
||||
定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 执行 <任务ID>
|
||||
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
|
||||
|
||||
📝 时间选项 (三选一):
|
||||
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
|
||||
--interval <时间间隔> # 例: --interval 30m, 2h, 10s
|
||||
--date "<YYYY-MM-DD HH:MM:SS>" # 例: --date "2024-01-01 08:00:00"
|
||||
--daily "<HH:MM>" # 例: --daily "08:30"
|
||||
|
||||
📚 其他功能:
|
||||
定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER)
|
||||
|
||||
🏷️ 别名支持:
|
||||
查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消
|
||||
暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info
|
||||
更新: update, modify, 修改 | 插件列表: plugins
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1.2",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
is_show=False,
|
||||
).to_dict(),
|
||||
)
|
||||
836
zhenxun/builtin_plugins/scheduler_admin/command.py
Normal file
836
zhenxun/builtin_plugins/scheduler_admin/command.py
Normal file
@ -0,0 +1,836 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaMatch,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.manager.schedule_manager import scheduler_manager
|
||||
|
||||
|
||||
def _get_type_name(annotation) -> str:
|
||||
"""获取类型注解的名称"""
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
elif hasattr(annotation, "_name"):
|
||||
return annotation._name
|
||||
else:
|
||||
return str(annotation)
|
||||
|
||||
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
|
||||
def _format_trigger(schedule_status: dict) -> str:
|
||||
"""将触发器配置格式化为人类可读的字符串"""
|
||||
trigger_type = schedule_status["trigger_type"]
|
||||
config = schedule_status["trigger_config"]
|
||||
|
||||
if trigger_type == "cron":
|
||||
minute = config.get("minute", "*")
|
||||
hour = config.get("hour", "*")
|
||||
day = config.get("day", "*")
|
||||
month = config.get("month", "*")
|
||||
day_of_week = config.get("day_of_week", "*")
|
||||
|
||||
if day == "*" and month == "*" and day_of_week == "*":
|
||||
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
|
||||
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
|
||||
return f"每天 {formatted_hour}:{formatted_minute}"
|
||||
else:
|
||||
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
|
||||
elif trigger_type == "interval":
|
||||
seconds = config.get("seconds", 0)
|
||||
minutes = config.get("minutes", 0)
|
||||
hours = config.get("hours", 0)
|
||||
days = config.get("days", 0)
|
||||
if days:
|
||||
trigger_str = f"每 {days} 天"
|
||||
elif hours:
|
||||
trigger_str = f"每 {hours} 小时"
|
||||
elif minutes:
|
||||
trigger_str = f"每 {minutes} 分钟"
|
||||
else:
|
||||
trigger_str = f"每 {seconds} 秒"
|
||||
elif trigger_type == "date":
|
||||
run_date = config.get("run_date", "未知时间")
|
||||
trigger_str = f"在 {run_date}"
|
||||
else:
|
||||
trigger_str = f"{trigger_type}: {config}"
|
||||
|
||||
return trigger_str
|
||||
|
||||
|
||||
def _format_params(schedule_status: dict) -> str:
|
||||
"""将任务参数格式化为人类可读的字符串"""
|
||||
if kwargs := schedule_status.get("job_kwargs"):
|
||||
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
||||
return kwargs_str
|
||||
return "-"
|
||||
|
||||
|
||||
def _parse_interval(interval_str: str) -> dict:
|
||||
"""增强版解析器,支持 d(天)"""
|
||||
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
||||
if not match:
|
||||
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
||||
|
||||
value, unit = int(match.group(1)), match.group(2)
|
||||
if unit == "s":
|
||||
return {"seconds": value}
|
||||
if unit == "m":
|
||||
return {"minutes": value}
|
||||
if unit == "h":
|
||||
return {"hours": value}
|
||||
if unit == "d":
|
||||
return {"days": value}
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_daily_time(time_str: str) -> dict:
|
||||
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
|
||||
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
||||
hour, minute, second = match.groups()
|
||||
hour, minute = int(hour), int(minute)
|
||||
|
||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||
raise ValueError("小时或分钟数值超出范围。")
|
||||
|
||||
cron_config = {
|
||||
"minute": str(minute),
|
||||
"hour": str(hour),
|
||||
"day": "*",
|
||||
"month": "*",
|
||||
"day_of_week": "*",
|
||||
}
|
||||
if second is not None:
|
||||
if not (0 <= int(second) <= 59):
|
||||
raise ValueError("秒数值超出范围。")
|
||||
cron_config["second"] = str(second)
|
||||
|
||||
return cron_config
|
||||
else:
|
||||
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
|
||||
|
||||
|
||||
async def GetBotId(
|
||||
bot: Bot,
|
||||
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
|
||||
) -> str:
|
||||
"""获取要操作的Bot ID"""
|
||||
if bot_id_match.available:
|
||||
return bot_id_match.result
|
||||
return bot.self_id
|
||||
|
||||
|
||||
class ScheduleTarget:
|
||||
"""定时任务操作目标的基类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TargetByID(ScheduleTarget):
|
||||
"""按任务ID操作"""
|
||||
|
||||
def __init__(self, id: int):
|
||||
self.id = id
|
||||
|
||||
|
||||
class TargetByPlugin(ScheduleTarget):
|
||||
"""按插件名操作"""
|
||||
|
||||
def __init__(
|
||||
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
||||
):
|
||||
self.plugin = plugin
|
||||
self.group_id = group_id
|
||||
self.all_groups = all_groups
|
||||
|
||||
|
||||
class TargetAll(ScheduleTarget):
|
||||
"""操作所有任务"""
|
||||
|
||||
def __init__(self, for_group: str | None = None):
|
||||
self.for_group = for_group
|
||||
|
||||
|
||||
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
||||
|
||||
|
||||
def create_target_parser(subcommand_name: str):
|
||||
"""
|
||||
创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
event: Event,
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
||||
) -> TargetScope:
|
||||
if schedule_id.available:
|
||||
return TargetByID(schedule_id.result)
|
||||
|
||||
if plugin_name.available:
|
||||
p_name = plugin_name.result
|
||||
if all_enabled.available:
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
elif group_id.available:
|
||||
gid = group_id.result
|
||||
if gid.lower() == "all":
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
return TargetByPlugin(plugin=p_name, group_id=gid)
|
||||
else:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if current_group_id:
|
||||
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
|
||||
)
|
||||
|
||||
if all_enabled.available:
|
||||
return TargetAll(for_group=group_id.result if group_id.available else None)
|
||||
|
||||
return None
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
schedule_cmd = on_alconna(
|
||||
Alconna(
|
||||
"定时任务",
|
||||
Subcommand(
|
||||
"查看",
|
||||
Option("-g", Args["target_group_id", str]),
|
||||
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
||||
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
||||
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
||||
alias=["ls", "list"],
|
||||
help_text="查看定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"设置",
|
||||
Args["plugin_name", str],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="设置每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
||||
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["add", "开启"],
|
||||
help_text="设置/开启一个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"删除",
|
||||
Args["schedule_id?", int],
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
||||
Option("-all", help_text="对所有群生效"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["del", "rm", "remove", "关闭", "取消"],
|
||||
help_text="删除一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"暂停",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["pause"],
|
||||
help_text="暂停一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"恢复",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["resume"],
|
||||
help_text="恢复一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"执行",
|
||||
Args["schedule_id", int],
|
||||
alias=["trigger", "run"],
|
||||
help_text="立即执行一次任务",
|
||||
),
|
||||
Subcommand(
|
||||
"更新",
|
||||
Args["schedule_id", int],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="更新每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
||||
alias=["update", "modify", "修改"],
|
||||
help_text="更新任务配置",
|
||||
),
|
||||
Subcommand(
|
||||
"状态",
|
||||
Args["schedule_id", int],
|
||||
alias=["status", "info"],
|
||||
help_text="查看单个任务的详细状态",
|
||||
),
|
||||
Subcommand(
|
||||
"插件列表",
|
||||
alias=["plugins"],
|
||||
help_text="列出所有可用的插件",
|
||||
),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
rule=admin_check(1),
|
||||
)
|
||||
|
||||
schedule_cmd.shortcut(
|
||||
"任务状态",
|
||||
command="定时任务",
|
||||
arguments=["状态", "{%0}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.handle()
|
||||
async def _handle_time_options_mutex(arp: Arparma):
|
||||
time_options = ["cron", "interval", "date", "daily"]
|
||||
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
||||
if len(provided_options) > 1:
|
||||
await schedule_cmd.finish(
|
||||
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("查看")
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
||||
all_groups: Query[bool] = Query("查看.all"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
page: Match[int] = AlconnaMatch("page"),
|
||||
):
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
schedules = []
|
||||
title = ""
|
||||
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
||||
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
||||
|
||||
if all_groups.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
||||
schedules = await scheduler_manager.get_all_schedules()
|
||||
title = "所有群组的定时任务"
|
||||
elif target_group_id.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
||||
gid = target_group_id.result
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = f"群 {gid} 的定时任务"
|
||||
else:
|
||||
gid = str(current_group_id)
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = "本群的定时任务"
|
||||
|
||||
if plugin_name.available:
|
||||
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
|
||||
title += f" [插件: {plugin_name.result}]"
|
||||
|
||||
if not schedules:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
page_size = 15
|
||||
current_page = page.result
|
||||
total_items = len(schedules)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
start_index = (current_page - 1) * page_size
|
||||
end_index = start_index + page_size
|
||||
paginated_schedules = schedules[start_index:end_index]
|
||||
|
||||
if not paginated_schedules:
|
||||
await schedule_cmd.finish("这一页没有内容了哦~")
|
||||
|
||||
status_tasks = [
|
||||
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
||||
]
|
||||
all_statuses = await asyncio.gather(*status_tasks)
|
||||
data_list = [
|
||||
[
|
||||
s["id"],
|
||||
s["plugin_name"],
|
||||
s.get("bot_id") or "N/A",
|
||||
s["group_id"] or "全局",
|
||||
s["next_run_time"],
|
||||
_format_trigger(s),
|
||||
_format_params(s),
|
||||
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
|
||||
]
|
||||
for s in all_statuses
|
||||
if s
|
||||
]
|
||||
|
||||
if not data_list:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
img = await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
||||
column_name=[
|
||||
"ID",
|
||||
"插件",
|
||||
"Bot ID",
|
||||
"群组/目标",
|
||||
"下次运行",
|
||||
"触发规则",
|
||||
"参数",
|
||||
"状态",
|
||||
],
|
||||
data_list=data_list,
|
||||
column_space=20,
|
||||
)
|
||||
await MessageUtils.build_message(img).send(reply_to=True)
|
||||
|
||||
|
||||
@schedule_cmd.assign("设置")
|
||||
async def _(
|
||||
event: Event,
|
||||
plugin_name: str,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
group_id: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
all_enabled: Query[bool] = Query("设置.all"),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if plugin_name not in scheduler_manager._registered_tasks:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
|
||||
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
|
||||
)
|
||||
|
||||
trigger_type = ""
|
||||
trigger_config = {}
|
||||
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||
)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = {}
|
||||
if kwargs_str:
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
if not params_model:
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
)
|
||||
return
|
||||
|
||||
target_group_id: str | None
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
|
||||
if group_id and group_id.lower() == "all":
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif all_enabled.available:
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif group_id:
|
||||
target_group_id = group_id
|
||||
elif current_group_id:
|
||||
target_group_id = str(current_group_id)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
||||
)
|
||||
return
|
||||
|
||||
success, msg = await scheduler_manager.add_schedule(
|
||||
plugin_name,
|
||||
target_group_id,
|
||||
trigger_type,
|
||||
trigger_config,
|
||||
job_kwargs,
|
||||
bot_id=bot_id_to_operate,
|
||||
)
|
||||
|
||||
if target_group_id == "__ALL_GROUPS__":
|
||||
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
|
||||
elif target_group_id is None:
|
||||
target_desc = "全局"
|
||||
else:
|
||||
target_desc = f"群组 {target_group_id}"
|
||||
|
||||
if success:
|
||||
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
|
||||
else:
|
||||
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
|
||||
|
||||
|
||||
@schedule_cmd.assign("删除")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("删除")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
removed_count = await scheduler_manager.remove_schedule_for_all(
|
||||
p_name, bot_id=bot_id_to_operate
|
||||
)
|
||||
message = (
|
||||
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
|
||||
if removed_count > 0
|
||||
else f"没有找到插件 '{p_name}' 的定时任务。"
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_schedule(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.remove_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("暂停")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("暂停")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.pause_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.pause_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("恢复")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("恢复")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.resume_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.resume_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("执行")
|
||||
async def _(schedule_id: int):
|
||||
_, message = await scheduler_manager.trigger_now(schedule_id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("更新")
|
||||
async def _(
|
||||
schedule_id: int,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
):
|
||||
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
|
||||
await schedule_cmd.finish(
|
||||
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
||||
)
|
||||
|
||||
trigger_config = None
|
||||
trigger_type = None
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = None
|
||||
if kwargs_str:
|
||||
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
|
||||
|
||||
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
|
||||
if not task_meta or not (params_model := task_meta.get("model")):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
|
||||
)
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
|
||||
)
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump(exclude_unset=True)
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
|
||||
return
|
||||
|
||||
_, message = await scheduler_manager.update_schedule(
|
||||
schedule_id, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("插件列表")
|
||||
async def _():
|
||||
registered_plugins = scheduler_manager.get_registered_plugins()
|
||||
if not registered_plugins:
|
||||
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
|
||||
|
||||
message_parts = ["📋 已注册的定时任务插件:"]
|
||||
for i, plugin_name in enumerate(registered_plugins, 1):
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
|
||||
if not params_model:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
continue
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
|
||||
continue
|
||||
|
||||
model_fields = getattr(params_model, "model_fields", None)
|
||||
if model_fields:
|
||||
param_info = ", ".join(
|
||||
f"{field_name}({_get_type_name(field_info.annotation)})"
|
||||
for field_name, field_info in model_fields.items()
|
||||
)
|
||||
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
|
||||
else:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
|
||||
await schedule_cmd.finish("\n".join(message_parts))
|
||||
|
||||
|
||||
@schedule_cmd.assign("状态")
|
||||
async def _(schedule_id: int):
|
||||
status = await scheduler_manager.get_schedule_status(schedule_id)
|
||||
if not status:
|
||||
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
|
||||
|
||||
info_lines = [
|
||||
f"📋 定时任务详细信息 (ID: {schedule_id})",
|
||||
"--------------------",
|
||||
f"▫️ 插件: {status['plugin_name']}",
|
||||
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
||||
f"▫️ 目标: {status['group_id'] or '全局'}",
|
||||
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
||||
f"▫️ 下次运行: {status['next_run_time']}",
|
||||
f"▫️ 触发规则: {_format_trigger(status)}",
|
||||
f"▫️ 任务参数: {_format_params(status)}",
|
||||
]
|
||||
await schedule_cmd.finish("\n".join(info_lines))
|
||||
38
zhenxun/models/schedule_info.py
Normal file
38
zhenxun/models/schedule_info.py
Normal file
@ -0,0 +1,38 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
|
||||
class ScheduleInfo(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
bot_id = fields.CharField(
|
||||
255, null=True, default=None, description="任务关联的Bot ID"
|
||||
)
|
||||
"""任务关联的Bot ID"""
|
||||
plugin_name = fields.CharField(255, description="插件模块名")
|
||||
"""插件模块名"""
|
||||
group_id = fields.CharField(
|
||||
255,
|
||||
null=True,
|
||||
description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务",
|
||||
)
|
||||
"""群组ID, 为空表示全局任务"""
|
||||
trigger_type = fields.CharField(
|
||||
max_length=20, default="cron", description="触发器类型 (cron, interval, date)"
|
||||
)
|
||||
"""触发器类型 (cron, interval, date)"""
|
||||
trigger_config = fields.JSONField(description="触发器具体配置")
|
||||
"""触发器具体配置"""
|
||||
job_kwargs = fields.JSONField(
|
||||
default=dict, description="传递给任务函数的额外关键字参数"
|
||||
)
|
||||
"""传递给任务函数的额外关键字参数"""
|
||||
is_enabled = fields.BooleanField(default=True, description="是否启用")
|
||||
"""是否启用"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "schedule_info"
|
||||
table_description = "通用定时任务表"
|
||||
@ -44,7 +44,8 @@ class Model(TortoiseModel):
|
||||
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
MODELS.append(cls.__module__)
|
||||
if cls.__module__ not in MODELS:
|
||||
MODELS.append(cls.__module__)
|
||||
|
||||
if func := getattr(cls, "_run_script", None):
|
||||
SCRIPT_METHOD.append((cls.__module__, func))
|
||||
@ -171,7 +172,7 @@ class Model(TortoiseModel):
|
||||
await CacheRoot.reload(cache_type)
|
||||
|
||||
|
||||
class DbUrlMissing(Exception):
|
||||
class DbUrlIsNode(HookPriorityException):
|
||||
"""
|
||||
数据库链接地址为空
|
||||
"""
|
||||
@ -190,7 +191,7 @@ class DbConnectError(Exception):
|
||||
@PriorityLifecycle.on_startup(priority=1)
|
||||
async def init():
|
||||
if not BotConfig.db_url:
|
||||
# raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...")
|
||||
# raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
|
||||
error = f"""
|
||||
**********************************************************************
|
||||
🌟 **************************** 配置为空 ************************* 🌟
|
||||
@ -199,7 +200,7 @@ async def init():
|
||||
***********************************************************************
|
||||
***********************************************************************
|
||||
"""
|
||||
raise DbUrlMissing("\n" + error.strip())
|
||||
raise DbUrlIsNode("\n" + error.strip())
|
||||
try:
|
||||
await Tortoise.init(
|
||||
db_url=BotConfig.db_url,
|
||||
|
||||
@ -1,91 +1,94 @@
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot import get_driver
|
||||
from playwright.__main__ import main
|
||||
from playwright.async_api import Browser, Playwright, async_playwright
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_htmlrender import get_browser
|
||||
from playwright.async_api import Page
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
_playwright: Playwright | None = None
|
||||
_browser: Browser | None = None
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
|
||||
# @driver.on_startup
|
||||
# async def start_browser():
|
||||
# global _playwright
|
||||
# global _browser
|
||||
# install()
|
||||
# await check_playwright_env()
|
||||
# _playwright = await async_playwright().start()
|
||||
# _browser = await _playwright.chromium.launch()
|
||||
class BrowserIsNone(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# @driver.on_shutdown
|
||||
# async def shutdown_browser():
|
||||
# if _browser:
|
||||
# await _browser.close()
|
||||
# if _playwright:
|
||||
# await _playwright.stop() # type: ignore
|
||||
class AsyncPlaywright:
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def new_page(
|
||||
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
|
||||
) -> AsyncGenerator[Page, None]:
|
||||
"""获取一个新页面
|
||||
|
||||
|
||||
# def get_browser() -> Browser:
|
||||
# if not _browser:
|
||||
# raise RuntimeError("playwright is not initalized")
|
||||
# return _browser
|
||||
|
||||
|
||||
def install():
|
||||
"""自动安装、更新 Chromium"""
|
||||
|
||||
def set_env_variables():
|
||||
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = (
|
||||
"https://npmmirror.com/mirrors/playwright/"
|
||||
)
|
||||
if BotConfig.system_proxy:
|
||||
os.environ["HTTPS_PROXY"] = BotConfig.system_proxy
|
||||
|
||||
def restore_env_variables():
|
||||
os.environ.pop("PLAYWRIGHT_DOWNLOAD_HOST", None)
|
||||
if BotConfig.system_proxy:
|
||||
os.environ.pop("HTTPS_PROXY", None)
|
||||
if original_proxy is not None:
|
||||
os.environ["HTTPS_PROXY"] = original_proxy
|
||||
|
||||
def try_install_chromium():
|
||||
参数:
|
||||
cookies: cookies
|
||||
"""
|
||||
browser = await get_browser()
|
||||
ctx = await browser.new_context(**kwargs)
|
||||
if cookies:
|
||||
if isinstance(cookies, dict):
|
||||
cookies = [cookies]
|
||||
await ctx.add_cookies(cookies) # type: ignore
|
||||
page = await ctx.new_page()
|
||||
try:
|
||||
sys.argv = ["", "install", "chromium"]
|
||||
main()
|
||||
except SystemExit as e:
|
||||
return e.code == 0
|
||||
return False
|
||||
yield page
|
||||
finally:
|
||||
await page.close()
|
||||
await ctx.close()
|
||||
|
||||
logger.info("检查 Chromium 更新")
|
||||
@classmethod
|
||||
async def screenshot(
|
||||
cls,
|
||||
url: str,
|
||||
path: Path | str,
|
||||
element: str | list[str],
|
||||
*,
|
||||
wait_time: int | None = None,
|
||||
viewport_size: dict[str, int] | None = None,
|
||||
wait_until: (
|
||||
Literal["domcontentloaded", "load", "networkidle"] | None
|
||||
) = "networkidle",
|
||||
timeout: float | None = None,
|
||||
type_: Literal["jpeg", "png"] | None = None,
|
||||
user_agent: str | None = None,
|
||||
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> UniMessage | None:
|
||||
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
|
||||
|
||||
original_proxy = os.environ.get("HTTPS_PROXY")
|
||||
set_env_variables()
|
||||
|
||||
success = try_install_chromium()
|
||||
|
||||
if not success:
|
||||
logger.info("Chromium 更新失败,尝试从原始仓库下载,速度较慢")
|
||||
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = ""
|
||||
success = try_install_chromium()
|
||||
|
||||
restore_env_variables()
|
||||
|
||||
if not success:
|
||||
raise RuntimeError("未知错误,Chromium 下载失败")
|
||||
|
||||
|
||||
async def check_playwright_env():
|
||||
"""检查 Playwright 依赖"""
|
||||
logger.info("检查 Playwright 依赖")
|
||||
try:
|
||||
async with async_playwright() as p:
|
||||
await p.chromium.launch()
|
||||
except Exception as e:
|
||||
raise ImportError("加载失败,Playwright 依赖不全,") from e
|
||||
参数:
|
||||
url: 网址
|
||||
path: 存储路径
|
||||
element: 元素选择
|
||||
wait_time: 等待截取超时时间
|
||||
viewport_size: 窗口大小
|
||||
wait_until: 等待类型
|
||||
timeout: 超时限制
|
||||
type_: 保存类型
|
||||
user_agent: user_agent
|
||||
cookies: cookies
|
||||
"""
|
||||
if viewport_size is None:
|
||||
viewport_size = {"width": 2560, "height": 1080}
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
wait_time = wait_time * 1000 if wait_time else None
|
||||
element_list = [element] if isinstance(element, str) else element
|
||||
async with cls.new_page(
|
||||
cookies,
|
||||
viewport=viewport_size,
|
||||
user_agent=user_agent,
|
||||
**kwargs,
|
||||
) as page:
|
||||
await page.goto(url, timeout=timeout, wait_until=wait_until)
|
||||
card = page
|
||||
for e in element_list:
|
||||
if not card:
|
||||
return None
|
||||
card = await card.wait_for_selector(e, timeout=wait_time)
|
||||
if card:
|
||||
await card.screenshot(path=path, timeout=timeout, type=type_)
|
||||
return MessageUtils.build_message(path)
|
||||
return None
|
||||
@ -1,24 +1,226 @@
|
||||
from collections.abc import Callable
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Literal
|
||||
|
||||
from anyio import EndOfStream
|
||||
from httpx import ConnectError, HTTPStatusError, TimeoutException
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
from httpx import (
|
||||
ConnectError,
|
||||
HTTPStatusError,
|
||||
RemoteProtocolError,
|
||||
StreamError,
|
||||
TimeoutException,
|
||||
)
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
retry_if_result,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
wait_fixed,
|
||||
)
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
LOG_COMMAND = "RetryDecorator"
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def _log_before_sleep(log_name: str | None, retry_state: RetryCallState):
|
||||
"""
|
||||
tenacity 重试前的日志记录回调函数。
|
||||
"""
|
||||
func_name = retry_state.fn.__name__ if retry_state.fn else "unknown_function"
|
||||
log_context = f"函数 '{func_name}'"
|
||||
if log_name:
|
||||
log_context = f"操作 '{log_name}' ({log_context})"
|
||||
|
||||
reason = ""
|
||||
if retry_state.outcome:
|
||||
if exc := retry_state.outcome.exception():
|
||||
reason = f"触发异常: {exc.__class__.__name__}({exc})"
|
||||
else:
|
||||
reason = f"不满足结果条件: result={retry_state.outcome.result()}"
|
||||
|
||||
wait_time = (
|
||||
getattr(retry_state.next_action, "sleep", 0) if retry_state.next_action else 0
|
||||
)
|
||||
logger.warning(
|
||||
f"{log_context} 第 {retry_state.attempt_number} 次重试... "
|
||||
f"等待 {wait_time:.2f} 秒. {reason}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
|
||||
|
||||
class Retry:
|
||||
@staticmethod
|
||||
def api(
|
||||
retry_count: int = 3, wait: int = 1, exception: tuple[type[Exception], ...] = ()
|
||||
def simple(
|
||||
stop_max_attempt: int = 3,
|
||||
wait_fixed_seconds: int = 2,
|
||||
exception: tuple[type[Exception], ...] = (),
|
||||
*,
|
||||
log_name: str | None = None,
|
||||
on_failure: Callable[[Exception], Any] | None = None,
|
||||
return_on_failure: Any = _SENTINEL,
|
||||
):
|
||||
"""接口调用重试"""
|
||||
"""
|
||||
一个简单的、用于通用网络请求的重试装饰器预设。
|
||||
|
||||
参数:
|
||||
stop_max_attempt: 最大重试次数。
|
||||
wait_fixed_seconds: 固定等待策略的等待秒数。
|
||||
exception: 额外需要重试的异常类型元组。
|
||||
log_name: 用于日志记录的操作名称。
|
||||
on_failure: (可选) 所有重试失败后的回调。
|
||||
return_on_failure: (可选) 所有重试失败后的返回值。
|
||||
"""
|
||||
return Retry.api(
|
||||
stop_max_attempt=stop_max_attempt,
|
||||
wait_fixed_seconds=wait_fixed_seconds,
|
||||
exception=exception,
|
||||
strategy="fixed",
|
||||
log_name=log_name,
|
||||
on_failure=on_failure,
|
||||
return_on_failure=return_on_failure,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download(
|
||||
stop_max_attempt: int = 3,
|
||||
exception: tuple[type[Exception], ...] = (),
|
||||
*,
|
||||
wait_exp_multiplier: int = 2,
|
||||
wait_exp_max: int = 15,
|
||||
log_name: str | None = None,
|
||||
on_failure: Callable[[Exception], Any] | None = None,
|
||||
return_on_failure: Any = _SENTINEL,
|
||||
):
|
||||
"""
|
||||
一个适用于文件下载的重试装饰器预设,使用指数退避策略。
|
||||
|
||||
参数:
|
||||
stop_max_attempt: 最大重试次数。
|
||||
exception: 额外需要重试的异常类型元组。
|
||||
wait_exp_multiplier: 指数退避的乘数。
|
||||
wait_exp_max: 指数退避的最大等待时间。
|
||||
log_name: 用于日志记录的操作名称。
|
||||
on_failure: (可选) 所有重试失败后的回调。
|
||||
return_on_failure: (可选) 所有重试失败后的返回值。
|
||||
"""
|
||||
return Retry.api(
|
||||
stop_max_attempt=stop_max_attempt,
|
||||
exception=exception,
|
||||
strategy="exponential",
|
||||
wait_exp_multiplier=wait_exp_multiplier,
|
||||
wait_exp_max=wait_exp_max,
|
||||
log_name=log_name,
|
||||
on_failure=on_failure,
|
||||
return_on_failure=return_on_failure,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def api(
|
||||
stop_max_attempt: int = 3,
|
||||
wait_fixed_seconds: int = 1,
|
||||
exception: tuple[type[Exception], ...] = (),
|
||||
*,
|
||||
strategy: Literal["fixed", "exponential"] = "fixed",
|
||||
retry_on_result: Callable[[Any], bool] | None = None,
|
||||
wait_exp_multiplier: int = 1,
|
||||
wait_exp_max: int = 10,
|
||||
log_name: str | None = None,
|
||||
on_failure: Callable[[Exception], Any] | None = None,
|
||||
return_on_failure: Any = _SENTINEL,
|
||||
):
|
||||
"""
|
||||
通用、可配置的API调用重试装饰器。
|
||||
|
||||
参数:
|
||||
stop_max_attempt: 最大重试次数。
|
||||
wait_fixed_seconds: 固定等待策略的等待秒数。
|
||||
exception: 额外需要重试的异常类型元组。
|
||||
strategy: 重试等待策略, 'fixed' (固定) 或 'exponential' (指数退避)。
|
||||
retry_on_result: 一个回调函数,接收函数返回值。如果返回 True,则触发重试。
|
||||
例如 `lambda r: r.status_code != 200`
|
||||
wait_exp_multiplier: 指数退避的乘数。
|
||||
wait_exp_max: 指数退避的最大等待时间。
|
||||
log_name: 用于日志记录的操作名称,方便区分不同的重试场景。
|
||||
on_failure: (可选) 当所有重试都失败后,在抛出异常或返回默认值之前,
|
||||
会调用此函数,并将最终的异常实例作为参数传入。
|
||||
return_on_failure: (可选) 如果设置了此参数,当所有重试失败后,
|
||||
将不再抛出异常,而是返回此参数指定的值。
|
||||
"""
|
||||
base_exceptions = (
|
||||
TimeoutException,
|
||||
ConnectError,
|
||||
HTTPStatusError,
|
||||
StreamError,
|
||||
RemoteProtocolError,
|
||||
EndOfStream,
|
||||
*exception,
|
||||
)
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(retry_count),
|
||||
wait=wait_fixed(wait),
|
||||
retry=retry_if_exception_type(base_exceptions),
|
||||
)
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
if strategy == "exponential":
|
||||
wait_strategy = wait_exponential(
|
||||
multiplier=wait_exp_multiplier, max=wait_exp_max
|
||||
)
|
||||
else:
|
||||
wait_strategy = wait_fixed(wait_fixed_seconds)
|
||||
|
||||
retry_conditions = retry_if_exception_type(base_exceptions)
|
||||
if retry_on_result:
|
||||
retry_conditions |= retry_if_result(retry_on_result)
|
||||
|
||||
log_callback = partial(_log_before_sleep, log_name)
|
||||
|
||||
tenacity_retry_decorator = retry(
|
||||
stop=stop_after_attempt(stop_max_attempt),
|
||||
wait=wait_strategy,
|
||||
retry=retry_conditions,
|
||||
before_sleep=log_callback,
|
||||
reraise=True,
|
||||
)
|
||||
|
||||
decorated_func = tenacity_retry_decorator(func)
|
||||
|
||||
if return_on_failure is _SENTINEL:
|
||||
return decorated_func
|
||||
|
||||
if is_coroutine_callable(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await decorated_func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if on_failure:
|
||||
if is_coroutine_callable(on_failure):
|
||||
await on_failure(e)
|
||||
else:
|
||||
on_failure(e)
|
||||
return return_on_failure
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return decorated_func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if on_failure:
|
||||
if is_coroutine_callable(on_failure):
|
||||
logger.error(
|
||||
f"不能在同步函数 '{func.__name__}' 中调用异步的 "
|
||||
f"on_failure 回调。",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
else:
|
||||
on_failure(e)
|
||||
return return_on_failure
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
@ -64,3 +64,23 @@ class GoodsNotFound(Exception):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AllURIsFailedError(Exception):
|
||||
"""
|
||||
当所有备用URL都尝试失败后抛出此异常
|
||||
"""
|
||||
|
||||
def __init__(self, urls: list[str], exceptions: list[Exception]):
|
||||
self.urls = urls
|
||||
self.exceptions = exceptions
|
||||
super().__init__(
|
||||
f"All {len(urls)} URIs failed. Last exception: {exceptions[-1]}"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
exc_info = "\n".join(
|
||||
f" - {url}: {exc.__class__.__name__}({exc})"
|
||||
for url, exc in zip(self.urls, self.exceptions)
|
||||
)
|
||||
return f"All {len(self.urls)} URIs failed:\n{exc_info}"
|
||||
@ -1,16 +1,15 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
from typing import Any, ClassVar, cast
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_htmlrender import get_browser
|
||||
from playwright.async_api import Page
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Proxy, Response
|
||||
import nonebot
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
@ -18,13 +17,84 @@ from rich.progress import (
|
||||
TextColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
import ujson as json
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.decorator.retry import Retry
|
||||
from zhenxun.utils.exception import AllURIsFailedError
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.user_agent import get_user_agent
|
||||
|
||||
CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"]
|
||||
from .browser import AsyncPlaywright, BrowserIsNone # noqa: F401
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
_client: AsyncClient | None = None
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=0)
|
||||
async def _():
|
||||
"""
|
||||
在Bot启动时初始化全局httpx客户端。
|
||||
"""
|
||||
global _client
|
||||
client_kwargs = {}
|
||||
if proxy_url := BotConfig.system_proxy or None:
|
||||
try:
|
||||
version_parts = httpx.__version__.split(".")
|
||||
major = int("".join(c for c in version_parts[0] if c.isdigit()))
|
||||
minor = (
|
||||
int("".join(c for c in version_parts[1] if c.isdigit()))
|
||||
if len(version_parts) > 1
|
||||
else 0
|
||||
)
|
||||
if (major, minor) >= (0, 28):
|
||||
client_kwargs["proxy"] = proxy_url
|
||||
else:
|
||||
client_kwargs["proxies"] = proxy_url
|
||||
except (ValueError, IndexError):
|
||||
client_kwargs["proxy"] = proxy_url
|
||||
logger.warning(
|
||||
f"无法解析 httpx 版本 '{httpx.__version__}',"
|
||||
"将默认使用新版 'proxy' 参数语法。"
|
||||
)
|
||||
|
||||
_client = httpx.AsyncClient(
|
||||
headers=get_user_agent(),
|
||||
follow_redirects=True,
|
||||
**client_kwargs,
|
||||
)
|
||||
|
||||
logger.info("全局 httpx.AsyncClient 已启动。", "HTTPClient")
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _():
|
||||
"""
|
||||
在Bot关闭时关闭全局httpx客户端。
|
||||
"""
|
||||
if _client:
|
||||
await _client.aclose()
|
||||
logger.info("全局 httpx.AsyncClient 已关闭。", "HTTPClient")
|
||||
|
||||
|
||||
def get_client() -> AsyncClient:
|
||||
"""
|
||||
获取全局 httpx.AsyncClient 实例。
|
||||
"""
|
||||
global _client
|
||||
if not _client:
|
||||
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
raise RuntimeError("全局 httpx.AsyncClient 未初始化,请检查启动流程。")
|
||||
# 在测试环境中创建临时客户端
|
||||
logger.warning("在测试环境中创建临时HTTP客户端", "HTTPClient")
|
||||
_client = httpx.AsyncClient(
|
||||
headers=get_user_agent(),
|
||||
follow_redirects=True,
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
def get_async_client(
|
||||
@ -33,6 +103,10 @@ def get_async_client(
|
||||
verify: bool = False,
|
||||
**kwargs,
|
||||
) -> httpx.AsyncClient:
|
||||
"""
|
||||
[向后兼容] 创建 httpx.AsyncClient 实例的工厂函数。
|
||||
此函数完全保留了旧版本的接口,确保现有代码无需修改即可使用。
|
||||
"""
|
||||
transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
|
||||
if proxies:
|
||||
http_proxy = proxies.get("http://")
|
||||
@ -62,6 +136,30 @@ def get_async_client(
|
||||
|
||||
|
||||
class AsyncHttpx:
|
||||
"""
|
||||
一个高级的、健壮的异步HTTP客户端工具类。
|
||||
|
||||
设计理念:
|
||||
- **全局共享客户端**: 默认情况下,所有请求都通过一个在应用启动时初始化的全局
|
||||
`httpx.AsyncClient` 实例发出。这个实例共享连接池,提高了效率和性能。
|
||||
- **向后兼容与灵活性**: 完全兼容旧的API,同时提供了两种方式来处理需要
|
||||
特殊网络配置(如不同代理、超时)的请求:
|
||||
1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时,直接传入 `proxies`,
|
||||
`timeout` 等参数,将为该次请求创建一个临时的、独立的客户端。
|
||||
2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器,可以
|
||||
获取一个独立的、可配置的客户端,用于执行一系列需要相同特殊配置的请求。
|
||||
- **健壮性**: 内置了自动重试、多镜像URL回退(fallback)机制,并提供了便捷的
|
||||
JSON解析和文件下载方法。
|
||||
"""
|
||||
|
||||
CLIENT_KEY: ClassVar[list[str]] = [
|
||||
"use_proxy",
|
||||
"proxies",
|
||||
"proxy",
|
||||
"verify",
|
||||
"headers",
|
||||
]
|
||||
|
||||
default_proxy: ClassVar[dict[str, str] | None] = (
|
||||
{
|
||||
"http://": BotConfig.system_proxy,
|
||||
@ -72,155 +170,346 @@ class AsyncHttpx:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def _create_client(
|
||||
cls,
|
||||
*,
|
||||
use_proxy: bool = True,
|
||||
proxies: dict[str, str] | None = None,
|
||||
proxy: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
verify: bool = False,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。
|
||||
def _prepare_temporary_client_config(cls, client_kwargs: dict) -> dict:
|
||||
"""
|
||||
[向后兼容] 处理旧式的客户端kwargs,将其转换为get_async_client可用的配置。
|
||||
主要负责处理 use_proxy 标志,这是为了兼容旧版本代码中使用的 use_proxy 参数。
|
||||
"""
|
||||
final_config = client_kwargs.copy()
|
||||
|
||||
说明:
|
||||
此方法用于内部统一创建客户端,处理代理和请求头逻辑,减少代码重复。
|
||||
use_proxy = final_config.pop("use_proxy", True)
|
||||
|
||||
if "proxies" not in final_config and "proxy" not in final_config:
|
||||
final_config["proxies"] = cls.default_proxy if use_proxy else None
|
||||
return final_config
|
||||
|
||||
@classmethod
|
||||
def _split_kwargs(cls, kwargs: dict) -> tuple[dict, dict]:
|
||||
"""[优化] 分离客户端配置和请求参数,使逻辑更清晰。"""
|
||||
client_kwargs = {k: v for k, v in kwargs.items() if k in cls.CLIENT_KEY}
|
||||
request_kwargs = {k: v for k, v in kwargs.items() if k not in cls.CLIENT_KEY}
|
||||
return client_kwargs, request_kwargs
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def _get_active_client_context(
|
||||
cls, client: AsyncClient | None = None, **kwargs
|
||||
) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""
|
||||
内部辅助方法,根据 kwargs 决定并提供一个活动的 HTTP 客户端。
|
||||
- 如果 kwargs 中有客户端配置,则创建并返回一个临时客户端。
|
||||
- 否则,返回传入的 client 或全局客户端。
|
||||
- 自动处理临时客户端的关闭。
|
||||
"""
|
||||
if kwargs:
|
||||
logger.debug(f"为单次请求创建临时客户端,配置: {kwargs}")
|
||||
temp_client_config = cls._prepare_temporary_client_config(kwargs)
|
||||
async with get_async_client(**temp_client_config) as temp_client:
|
||||
yield temp_client
|
||||
else:
|
||||
yield client or get_client()
|
||||
|
||||
@Retry.simple(log_name="内部HTTP请求")
|
||||
async def _execute_request_inner(
|
||||
self, client: AsyncClient, method: str, url: str, **kwargs
|
||||
) -> Response:
|
||||
"""
|
||||
[内部] 执行单次HTTP请求的私有核心方法,被重试装饰器包裹。
|
||||
"""
|
||||
return await client.request(method, url, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def _single_request(
|
||||
cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
|
||||
) -> Response:
|
||||
"""
|
||||
执行单次HTTP请求的私有方法,内置了默认的重试逻辑。
|
||||
"""
|
||||
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
|
||||
|
||||
async with cls._get_active_client_context(
|
||||
client=client, **client_kwargs
|
||||
) as active_client:
|
||||
response = await cls()._execute_request_inner(
|
||||
active_client, method, url, **request_kwargs
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
async def _execute_with_fallbacks(
|
||||
cls,
|
||||
urls: str | list[str],
|
||||
worker: Callable[..., Awaitable[Any]],
|
||||
*,
|
||||
client: AsyncClient | None = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
通用执行器,按顺序尝试多个URL,直到成功。
|
||||
|
||||
参数:
|
||||
use_proxy: 是否使用在类中定义的默认代理。
|
||||
proxies: 手动指定的代理,会覆盖默认代理。
|
||||
proxy: 单个代理,用于兼容旧版本,不再使用
|
||||
headers: 需要合并到客户端的自定义请求头。
|
||||
verify: 是否验证 SSL 证书。
|
||||
**kwargs: 其他所有传递给 httpx.AsyncClient 的参数。
|
||||
|
||||
返回:
|
||||
AsyncGenerator[httpx.AsyncClient, None]: 生成器。
|
||||
urls: 单个URL或URL列表。
|
||||
worker: 一个接受单个URL和其他kwargs并执行请求的协程函数。
|
||||
client: 可选的HTTP客户端。
|
||||
**kwargs: 传递给worker的额外参数。
|
||||
"""
|
||||
proxies_to_use = proxies or (cls.default_proxy if use_proxy else None)
|
||||
url_list = [urls] if isinstance(urls, str) else urls
|
||||
exceptions = []
|
||||
|
||||
final_headers = get_user_agent()
|
||||
if headers:
|
||||
final_headers.update(headers)
|
||||
for i, url in enumerate(url_list):
|
||||
try:
|
||||
result = await worker(url, client=client, **kwargs)
|
||||
if i > 0:
|
||||
logger.info(
|
||||
f"成功从镜像 '{url}' 获取资源 "
|
||||
f"(在尝试了 {i} 个失败的镜像之后)。",
|
||||
"AsyncHttpx:FallbackExecutor",
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
if url != url_list[-1]:
|
||||
logger.warning(
|
||||
f"Worker '{worker.__name__}' on {url} failed, trying next. "
|
||||
f"Error: {e.__class__.__name__}",
|
||||
"AsyncHttpx:FallbackExecutor",
|
||||
)
|
||||
|
||||
async with get_async_client(
|
||||
proxies=proxies_to_use,
|
||||
proxy=proxy,
|
||||
verify=verify,
|
||||
headers=final_headers,
|
||||
**kwargs,
|
||||
) as client:
|
||||
yield client
|
||||
raise AllURIsFailedError(url_list, exceptions)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls,
|
||||
url: str | list[str],
|
||||
*,
|
||||
follow_redirects: bool = True,
|
||||
check_status_code: int | None = None,
|
||||
client: AsyncClient | None = None,
|
||||
**kwargs,
|
||||
) -> Response: # sourcery skip: use-assigned-variable
|
||||
) -> Response:
|
||||
"""发送 GET 请求,并返回第一个成功的响应。
|
||||
|
||||
说明:
|
||||
本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的代理管理。
|
||||
如果提供 URL 列表,它将依次尝试直到成功为止。
|
||||
本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的
|
||||
客户端管理。如果提供 URL 列表,它将依次尝试直到成功为止。
|
||||
|
||||
用法建议:
|
||||
- **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端。
|
||||
- **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)`
|
||||
将为本次请求创建一个独立的临时客户端。
|
||||
|
||||
参数:
|
||||
url: 单个请求 URL 或一个 URL 列表。
|
||||
follow_redirects: 是否跟随重定向。
|
||||
check_status_code: (可选) 若提供,将检查响应状态码是否匹配,否则抛出异常。
|
||||
**kwargs: 其他所有传递给 httpx.get 的参数
|
||||
(如 `params`, `headers`, `timeout`等)。
|
||||
client: (可选) 指定一个活动的HTTP客户端实例。若提供,则忽略
|
||||
`**kwargs`中的客户端配置。
|
||||
**kwargs: 其他所有传递给 httpx.get 的参数 (如 `params`, `headers`,
|
||||
`timeout`)。如果包含 `proxies`, `verify` 等客户端配置参数,
|
||||
将创建一个临时客户端。
|
||||
|
||||
返回:
|
||||
Response: Response
|
||||
Response: httpx 的响应对象。
|
||||
|
||||
Raises:
|
||||
AllURIsFailedError: 当所有提供的URL都请求失败时抛出。
|
||||
"""
|
||||
urls = [url] if isinstance(url, str) else url
|
||||
last_exception = None
|
||||
for current_url in urls:
|
||||
try:
|
||||
logger.info(f"开始获取 {current_url}..")
|
||||
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
|
||||
for key in CLIENT_KEY:
|
||||
kwargs.pop(key, None)
|
||||
async with cls._create_client(**client_kwargs) as client:
|
||||
response = await client.get(current_url, **kwargs)
|
||||
|
||||
if check_status_code and response.status_code != check_status_code:
|
||||
raise HTTPStatusError(
|
||||
f"状态码错误: {response.status_code}!={check_status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if current_url != urls[-1]:
|
||||
logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e)
|
||||
async def worker(current_url: str, **worker_kwargs) -> Response:
|
||||
logger.info(f"开始获取 {current_url}..", "AsyncHttpx:get")
|
||||
response = await cls._single_request(
|
||||
"GET", current_url, follow_redirects=follow_redirects, **worker_kwargs
|
||||
)
|
||||
if check_status_code and response.status_code != check_status_code:
|
||||
raise HTTPStatusError(
|
||||
f"状态码错误: {response.status_code}!={check_status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
return response
|
||||
|
||||
raise last_exception or Exception("所有URL都获取失败")
|
||||
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def head(cls, url: str, **kwargs) -> Response:
|
||||
"""发送 HEAD 请求。
|
||||
async def head(
|
||||
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
|
||||
) -> Response:
|
||||
"""发送 HEAD 请求,并返回第一个成功的响应。"""
|
||||
|
||||
说明:
|
||||
本方法是对 httpx.head 的封装,通常用于检查资源的元信息(如大小、类型)。
|
||||
async def worker(current_url: str, **worker_kwargs) -> Response:
|
||||
return await cls._single_request("HEAD", current_url, **worker_kwargs)
|
||||
|
||||
参数:
|
||||
url: 请求的 URL。
|
||||
**kwargs: 其他所有传递给 httpx.head 的参数
|
||||
(如 `headers`, `timeout`, `allow_redirects`)。
|
||||
|
||||
返回:
|
||||
Response: Response
|
||||
"""
|
||||
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
|
||||
for key in CLIENT_KEY:
|
||||
kwargs.pop(key, None)
|
||||
async with cls._create_client(**client_kwargs) as client:
|
||||
return await client.head(url, **kwargs)
|
||||
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def post(cls, url: str, **kwargs) -> Response:
|
||||
"""发送 POST 请求。
|
||||
async def post(
|
||||
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
|
||||
) -> Response:
|
||||
"""发送 POST 请求,并返回第一个成功的响应。"""
|
||||
|
||||
说明:
|
||||
本方法是对 httpx.post 的封装,提供了统一的代理和客户端管理。
|
||||
async def worker(current_url: str, **worker_kwargs) -> Response:
|
||||
return await cls._single_request("POST", current_url, **worker_kwargs)
|
||||
|
||||
参数:
|
||||
url: 请求的 URL。
|
||||
**kwargs: 其他所有传递给 httpx.post 的参数
|
||||
(如 `data`, `json`, `content` 等)。
|
||||
|
||||
返回:
|
||||
Response: Response。
|
||||
"""
|
||||
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
|
||||
for key in CLIENT_KEY:
|
||||
kwargs.pop(key, None)
|
||||
async with cls._create_client(**client_kwargs) as client:
|
||||
return await client.post(url, **kwargs)
|
||||
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def get_content(cls, url: str, **kwargs) -> bytes:
|
||||
"""获取指定 URL 的二进制内容。
|
||||
|
||||
说明:
|
||||
这是一个便捷方法,等同于调用 get() 后再访问 .content 属性。
|
||||
|
||||
参数:
|
||||
url: 请求的 URL。
|
||||
**kwargs: 所有传递给 get() 方法的参数。
|
||||
|
||||
返回:
|
||||
bytes: 响应内容的二进制字节流 (bytes)。
|
||||
"""
|
||||
res = await cls.get(url, **kwargs)
|
||||
async def get_content(
|
||||
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
|
||||
) -> bytes:
|
||||
"""获取指定 URL 的二进制内容。"""
|
||||
res = await cls.get(url, client=client, **kwargs)
|
||||
return res.content
|
||||
|
||||
@classmethod
|
||||
@Retry.api(
|
||||
log_name="JSON请求",
|
||||
exception=(json.JSONDecodeError,),
|
||||
return_on_failure=_SENTINEL,
|
||||
)
|
||||
async def _request_and_parse_json(
|
||||
cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
[私有] 执行单个HTTP请求并解析JSON,用于内部统一处理。
|
||||
"""
|
||||
async with cls._get_active_client_context(
|
||||
client=client, **kwargs
|
||||
) as active_client:
|
||||
_, request_kwargs = cls._split_kwargs(kwargs)
|
||||
response = await active_client.request(method, url, **request_kwargs)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
async def get_json(
|
||||
cls,
|
||||
url: str | list[str],
|
||||
*,
|
||||
default: Any = None,
|
||||
raise_on_failure: bool = False,
|
||||
client: AsyncClient | None = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
发送GET请求并自动解析为JSON,支持重试和多链接尝试。
|
||||
|
||||
说明:
|
||||
这是一个高度便捷的方法,封装了请求、重试、JSON解析和错误处理。
|
||||
它会在网络错误或JSON解析错误时自动重试。
|
||||
如果所有尝试都失败,它会安全地返回一个默认值。
|
||||
|
||||
参数:
|
||||
url: 单个请求 URL 或一个备用 URL 列表。
|
||||
default: (可选) 当所有尝试都失败时返回的默认值,默认为None。
|
||||
raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出
|
||||
`AllURIsFailedError` 异常, 默认为 False.
|
||||
client: (可选) 指定的HTTP客户端。
|
||||
**kwargs: 其他所有传递给 httpx.get 的参数。
|
||||
例如 `params`, `headers`, `timeout`等。
|
||||
|
||||
返回:
|
||||
Any: 解析后的JSON数据,或在失败时返回 `default` 值。
|
||||
|
||||
Raises:
|
||||
AllURIsFailedError: 当 `raise_on_failure` 为 True 且所有URL都请求失败时抛出
|
||||
"""
|
||||
|
||||
async def worker(current_url: str, **worker_kwargs):
|
||||
logger.debug(f"开始GET JSON: {current_url}", "AsyncHttpx:get_json")
|
||||
return await cls._request_and_parse_json(
|
||||
"GET", current_url, **worker_kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
result = await cls._execute_with_fallbacks(
|
||||
url, worker, client=client, **kwargs
|
||||
)
|
||||
return default if result is _SENTINEL else result
|
||||
except AllURIsFailedError as e:
|
||||
logger.error(f"所有URL的JSON GET均失败: {e}", "AsyncHttpx:get_json")
|
||||
if raise_on_failure:
|
||||
raise e
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
async def post_json(
|
||||
cls,
|
||||
url: str | list[str],
|
||||
*,
|
||||
json: Any = None,
|
||||
data: Any = None,
|
||||
default: Any = None,
|
||||
raise_on_failure: bool = False,
|
||||
client: AsyncClient | None = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
发送POST请求并自动解析为JSON,功能与 get_json 类似。
|
||||
|
||||
参数:
|
||||
url: 单个请求 URL 或一个备用 URL 列表。
|
||||
json: (可选) 作为请求体发送的JSON数据。
|
||||
data: (可选) 作为请求体发送的表单数据。
|
||||
default: (可选) 当所有尝试都失败时返回的默认值,默认为None。
|
||||
raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出
|
||||
AllURIsFailedError 异常, 默认为 False.
|
||||
client: (可选) 指定的HTTP客户端。
|
||||
**kwargs: 其他所有传递给 httpx.post 的参数。
|
||||
|
||||
返回:
|
||||
Any: 解析后的JSON数据,或在失败时返回 `default` 值。
|
||||
"""
|
||||
if json is not None:
|
||||
kwargs["json"] = json
|
||||
if data is not None:
|
||||
kwargs["data"] = data
|
||||
|
||||
async def worker(current_url: str, **worker_kwargs):
|
||||
logger.debug(f"开始POST JSON: {current_url}", "AsyncHttpx:post_json")
|
||||
return await cls._request_and_parse_json(
|
||||
"POST", current_url, **worker_kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
result = await cls._execute_with_fallbacks(
|
||||
url, worker, client=client, **kwargs
|
||||
)
|
||||
return default if result is _SENTINEL else result
|
||||
except AllURIsFailedError as e:
|
||||
logger.error(f"所有URL的JSON POST均失败: {e}", "AsyncHttpx:post_json")
|
||||
if raise_on_failure:
|
||||
raise e
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
@Retry.api(log_name="文件下载(流式)")
|
||||
async def _stream_download(
|
||||
cls, url: str, path: Path, *, client: AsyncClient | None = None, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
执行单个流式下载的私有方法,被重试装饰器包裹。
|
||||
"""
|
||||
async with cls._get_active_client_context(
|
||||
client=client, **kwargs
|
||||
) as active_client:
|
||||
async with active_client.stream("GET", url, **kwargs) as response:
|
||||
response.raise_for_status()
|
||||
total = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
with Progress(
|
||||
TextColumn(path.name),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
BarColumn(bar_width=None),
|
||||
DownloadColumn(),
|
||||
TransferSpeedColumn(),
|
||||
) as progress:
|
||||
task_id = progress.add_task("Download", total=total)
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
await f.write(chunk)
|
||||
progress.update(task_id, advance=len(chunk))
|
||||
|
||||
@classmethod
|
||||
async def download_file(
|
||||
cls,
|
||||
@ -228,6 +517,7 @@ class AsyncHttpx:
|
||||
path: str | Path,
|
||||
*,
|
||||
stream: bool = False,
|
||||
client: AsyncClient | None = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""下载文件到指定路径。
|
||||
@ -239,6 +529,7 @@ class AsyncHttpx:
|
||||
url: 单个文件 URL 或一个备用 URL 列表。
|
||||
path: 文件保存的本地路径。
|
||||
stream: (可选) 是否使用流式下载,适用于大文件,默认为 False。
|
||||
client: (可选) 指定的HTTP客户端。
|
||||
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数。
|
||||
|
||||
返回:
|
||||
@ -247,49 +538,29 @@ class AsyncHttpx:
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
urls = [url] if isinstance(url, str) else url
|
||||
async def worker(current_url: str, **worker_kwargs) -> bool:
|
||||
if not stream:
|
||||
content = await cls.get_content(current_url, **worker_kwargs)
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
await f.write(content)
|
||||
else:
|
||||
await cls._stream_download(current_url, path, **worker_kwargs)
|
||||
|
||||
for current_url in urls:
|
||||
try:
|
||||
if not stream:
|
||||
response = await cls.get(current_url, **kwargs)
|
||||
response.raise_for_status()
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
await f.write(response.content)
|
||||
else:
|
||||
async with cls._create_client(**kwargs) as client:
|
||||
stream_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["use_proxy", "proxy", "verify"]
|
||||
}
|
||||
async with client.stream(
|
||||
"GET", current_url, **stream_kwargs
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
total = int(response.headers.get("Content-Length", 0))
|
||||
logger.info(
|
||||
f"下载 {current_url} 成功 -> {path.absolute()}",
|
||||
"AsyncHttpx:download",
|
||||
)
|
||||
return True
|
||||
|
||||
with Progress(
|
||||
TextColumn(path.name),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
BarColumn(bar_width=None),
|
||||
DownloadColumn(),
|
||||
TransferSpeedColumn(),
|
||||
) as progress:
|
||||
task_id = progress.add_task("Download", total=total)
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
await f.write(chunk)
|
||||
progress.update(task_id, advance=len(chunk))
|
||||
|
||||
logger.info(f"下载 {current_url} 成功 -> {path.absolute()}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}")
|
||||
|
||||
logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}")
|
||||
return False
|
||||
try:
|
||||
return await cls._execute_with_fallbacks(
|
||||
url, worker, client=client, **kwargs
|
||||
)
|
||||
except AllURIsFailedError:
|
||||
logger.error(
|
||||
f"所有URL下载均失败 -> {path.absolute()}", "AsyncHttpx:download"
|
||||
)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def gather_download_file(
|
||||
@ -346,7 +617,6 @@ class AsyncHttpx:
|
||||
logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result)
|
||||
final_results.append(False)
|
||||
else:
|
||||
# download_file 返回的是 bool,可以直接附加
|
||||
final_results.append(cast(bool, result))
|
||||
|
||||
return final_results
|
||||
@ -395,86 +665,30 @@ class AsyncHttpx:
|
||||
_results = sorted(iter(_results), key=lambda r: r["elapsed_time"])
|
||||
return [result["url"] for result in _results]
|
||||
|
||||
|
||||
class AsyncPlaywright:
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def new_page(
|
||||
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
|
||||
) -> AsyncGenerator[Page, None]:
|
||||
"""获取一个新页面
|
||||
async def temporary_client(cls, **kwargs) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""
|
||||
创建一个临时的、可配置的HTTP客户端上下文,并直接返回该客户端实例。
|
||||
|
||||
此方法返回一个标准的 `httpx.AsyncClient`,它不使用全局连接池,
|
||||
拥有独立的配置(如代理、headers、超时等),并在退出上下文后自动关闭。
|
||||
适用于需要用一套特殊网络配置执行一系列请求的场景。
|
||||
|
||||
用法:
|
||||
async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client:
|
||||
# client 是一个标准的 httpx.AsyncClient 实例
|
||||
response1 = await client.get("http://some.internal.api/1")
|
||||
response2 = await client.get("http://some.internal.api/2")
|
||||
data = response2.json()
|
||||
|
||||
参数:
|
||||
cookies: cookies
|
||||
**kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数。
|
||||
例如: `proxies`, `headers`, `verify`, `timeout`,
|
||||
`follow_redirects`。
|
||||
|
||||
Yields:
|
||||
httpx.AsyncClient: 一个配置好的、临时的客户端实例。
|
||||
"""
|
||||
browser = await get_browser()
|
||||
ctx = await browser.new_context(**kwargs)
|
||||
if cookies:
|
||||
if isinstance(cookies, dict):
|
||||
cookies = [cookies]
|
||||
await ctx.add_cookies(cookies) # type: ignore
|
||||
page = await ctx.new_page()
|
||||
try:
|
||||
yield page
|
||||
finally:
|
||||
await page.close()
|
||||
await ctx.close()
|
||||
|
||||
@classmethod
|
||||
async def screenshot(
|
||||
cls,
|
||||
url: str,
|
||||
path: Path | str,
|
||||
element: str | list[str],
|
||||
*,
|
||||
wait_time: int | None = None,
|
||||
viewport_size: dict[str, int] | None = None,
|
||||
wait_until: (
|
||||
Literal["domcontentloaded", "load", "networkidle"] | None
|
||||
) = "networkidle",
|
||||
timeout: float | None = None,
|
||||
type_: Literal["jpeg", "png"] | None = None,
|
||||
user_agent: str | None = None,
|
||||
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> UniMessage | None:
|
||||
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
|
||||
|
||||
参数:
|
||||
url: 网址
|
||||
path: 存储路径
|
||||
element: 元素选择
|
||||
wait_time: 等待截取超时时间
|
||||
viewport_size: 窗口大小
|
||||
wait_until: 等待类型
|
||||
timeout: 超时限制
|
||||
type_: 保存类型
|
||||
user_agent: user_agent
|
||||
cookies: cookies
|
||||
"""
|
||||
if viewport_size is None:
|
||||
viewport_size = {"width": 2560, "height": 1080}
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
wait_time = wait_time * 1000 if wait_time else None
|
||||
element_list = [element] if isinstance(element, str) else element
|
||||
async with cls.new_page(
|
||||
cookies,
|
||||
viewport=viewport_size,
|
||||
user_agent=user_agent,
|
||||
**kwargs,
|
||||
) as page:
|
||||
await page.goto(url, timeout=timeout, wait_until=wait_until)
|
||||
card = page
|
||||
for e in element_list:
|
||||
if not card:
|
||||
return None
|
||||
card = await card.wait_for_selector(e, timeout=wait_time)
|
||||
if card:
|
||||
await card.screenshot(path=path, timeout=timeout, type=type_)
|
||||
return MessageUtils.build_message(path)
|
||||
return None
|
||||
|
||||
|
||||
class BrowserIsNone(Exception):
|
||||
pass
|
||||
async with get_async_client(**kwargs) as client:
|
||||
yield client
|
||||
810
zhenxun/utils/manager/schedule_manager.py
Normal file
810
zhenxun/utils/manager/schedule_manager.py
Normal file
@ -0,0 +1,810 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
import copy
|
||||
import inspect
|
||||
import random
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot import get_bots
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
|
||||
|
||||
|
||||
class SchedulerManager:
|
||||
"""
|
||||
一个通用的、持久化的定时任务管理器,供所有插件使用。
|
||||
"""
|
||||
|
||||
_registered_tasks: ClassVar[
|
||||
dict[str, dict[str, Callable | type[BaseModel] | None]]
|
||||
] = {}
|
||||
_JOB_PREFIX = "zhenxun_schedule_"
|
||||
_running_tasks: ClassVar[set] = set()
|
||||
|
||||
def register(
|
||||
self, plugin_name: str, params_model: type[BaseModel] | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
注册一个可调度的任务函数。
|
||||
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的唯一名称 (通常是模块名)。
|
||||
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类,
|
||||
用于定义和验证任务函数接受的额外参数。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||
if plugin_name in self._registered_tasks:
|
||||
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
|
||||
self._registered_tasks[plugin_name] = {
|
||||
"func": func,
|
||||
"model": params_model,
|
||||
}
|
||||
model_name = params_model.__name__ if params_model else "无"
|
||||
logger.debug(
|
||||
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_registered_plugins(self) -> list[str]:
|
||||
"""获取所有已注册定时任务的插件列表。"""
|
||||
return list(self._registered_tasks.keys())
|
||||
|
||||
def _get_job_id(self, schedule_id: int) -> str:
|
||||
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
|
||||
return f"{self._JOB_PREFIX}{schedule_id}"
|
||||
|
||||
async def _execute_job(self, schedule_id: int):
|
||||
"""
|
||||
APScheduler 调度的入口函数。
|
||||
根据 schedule_id 处理特定任务、所有群组任务或全局任务。
|
||||
"""
|
||||
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
|
||||
return
|
||||
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
task_meta = self._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
logger.error(
|
||||
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
|
||||
)
|
||||
schedule.is_enabled = False
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
self._remove_aps_job(schedule.id)
|
||||
return
|
||||
|
||||
try:
|
||||
if schedule.bot_id:
|
||||
bot = nonebot.get_bot(schedule.bot_id)
|
||||
else:
|
||||
bot = nonebot.get_bot()
|
||||
logger.debug(
|
||||
f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}"
|
||||
)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
|
||||
f"不在线,本次执行跳过。"
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。")
|
||||
return
|
||||
|
||||
if schedule.group_id == "__ALL_GROUPS__":
|
||||
await self._execute_for_all_groups(schedule, task_meta, bot)
|
||||
else:
|
||||
await self._execute_for_single_target(schedule, task_meta, bot)
|
||||
|
||||
async def _execute_for_all_groups(
|
||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
||||
):
|
||||
"""为所有群组执行任务,并处理优先级覆盖。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
concurrency_limit = Config.get_config(
|
||||
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
|
||||
)
|
||||
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
|
||||
logger.warning(
|
||||
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
|
||||
)
|
||||
concurrency_limit = 5
|
||||
|
||||
logger.info(
|
||||
f"开始执行针对 [所有群组] 的任务 "
|
||||
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id}),"
|
||||
f"并发限制: {concurrency_limit}"
|
||||
)
|
||||
|
||||
all_gids = set()
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
all_gids.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
return
|
||||
|
||||
specific_tasks_gids = set(
|
||||
await ScheduleInfo.filter(
|
||||
plugin_name=plugin_name, group_id__in=list(all_gids)
|
||||
).values_list("group_id", flat=True)
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def worker(gid: str):
|
||||
"""使用 Semaphore 包装单个群组的任务执行"""
|
||||
async with semaphore:
|
||||
temp_schedule = copy.deepcopy(schedule)
|
||||
temp_schedule.group_id = gid
|
||||
await self._execute_for_single_target(temp_schedule, task_meta, bot)
|
||||
await asyncio.sleep(random.uniform(0.1, 0.5))
|
||||
|
||||
tasks_to_run = []
|
||||
for gid in all_gids:
|
||||
if gid in specific_tasks_gids:
|
||||
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
|
||||
continue
|
||||
tasks_to_run.append(worker(gid))
|
||||
|
||||
if tasks_to_run:
|
||||
await asyncio.gather(*tasks_to_run)
|
||||
|
||||
async def _execute_for_single_target(
|
||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
||||
):
|
||||
"""为单个目标(具体群组或全局)执行任务。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
group_id = schedule.group_id
|
||||
|
||||
try:
|
||||
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
|
||||
if is_blocked:
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
|
||||
"因功能被禁用而跳过执行。"
|
||||
)
|
||||
return
|
||||
|
||||
task_func = task_meta["func"]
|
||||
job_kwargs = schedule.job_kwargs
|
||||
if not isinstance(job_kwargs, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
|
||||
)
|
||||
return
|
||||
|
||||
sig = inspect.signature(task_func)
|
||||
if "bot" in sig.parameters:
|
||||
job_kwargs["bot"] = bot
|
||||
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
|
||||
f"执行定时任务 (ID: {schedule.id})。"
|
||||
)
|
||||
task = asyncio.create_task(task_func(group_id, **job_kwargs))
|
||||
self._running_tasks.add(task)
|
||||
task.add_done_callback(self._running_tasks.discard)
|
||||
await task
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
|
||||
f"目标: {group_id or '全局'}) 时发生异常",
|
||||
e=e,
|
||||
)
|
||||
|
||||
def _validate_and_prepare_kwargs(
|
||||
self, plugin_name: str, job_kwargs: dict | None
|
||||
) -> tuple[bool, str | dict]:
|
||||
"""验证并准备任务参数,应用默认值"""
|
||||
task_meta = self._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
return False, f"插件 '{plugin_name}' 未注册。"
|
||||
|
||||
params_model = task_meta.get("model")
|
||||
job_kwargs = job_kwargs if job_kwargs is not None else {}
|
||||
|
||||
if not params_model:
|
||||
if job_kwargs:
|
||||
logger.warning(
|
||||
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
|
||||
)
|
||||
return True, job_kwargs
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
|
||||
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
|
||||
|
||||
validated_model = model_validate(job_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
|
||||
|
||||
return True, model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
return False, msg
|
||||
|
||||
def _add_aps_job(self, schedule: ScheduleInfo):
|
||||
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not isinstance(schedule.trigger_config, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
|
||||
f"{type(schedule.trigger_config)}"
|
||||
)
|
||||
return
|
||||
|
||||
scheduler.add_job(
|
||||
self._execute_job,
|
||||
trigger=schedule.trigger_type,
|
||||
id=job_id,
|
||||
misfire_grace_time=300,
|
||||
args=[schedule.id],
|
||||
**schedule.trigger_config,
|
||||
)
|
||||
logger.debug(
|
||||
f"已在 APScheduler 中添加/更新任务: {job_id} "
|
||||
f"with trigger: {schedule.trigger_config}"
|
||||
)
|
||||
|
||||
def _remove_aps_job(self, schedule_id: int):
|
||||
"""移除一个 APScheduler 任务。"""
|
||||
job_id = self._get_job_id(schedule_id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def add_schedule(
|
||||
self,
|
||||
plugin_name: str,
|
||||
group_id: str | None,
|
||||
trigger_type: str,
|
||||
trigger_config: dict,
|
||||
job_kwargs: dict | None = None,
|
||||
bot_id: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
添加或更新一个定时任务。
|
||||
"""
|
||||
if plugin_name not in self._registered_tasks:
|
||||
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
|
||||
|
||||
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
|
||||
if not is_valid:
|
||||
return False, str(result)
|
||||
|
||||
validated_job_kwargs = result
|
||||
|
||||
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
|
||||
|
||||
search_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
}
|
||||
if effective_bot_id:
|
||||
search_kwargs["bot_id"] = effective_bot_id
|
||||
else:
|
||||
search_kwargs["bot_id__isnull"] = True
|
||||
|
||||
defaults = {
|
||||
"trigger_type": trigger_type,
|
||||
"trigger_config": trigger_config,
|
||||
"job_kwargs": validated_job_kwargs,
|
||||
"is_enabled": True,
|
||||
}
|
||||
|
||||
schedule = await ScheduleInfo.filter(**search_kwargs).first()
|
||||
created = False
|
||||
|
||||
if schedule:
|
||||
for key, value in defaults.items():
|
||||
setattr(schedule, key, value)
|
||||
await schedule.save()
|
||||
else:
|
||||
creation_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
"bot_id": effective_bot_id,
|
||||
**defaults,
|
||||
}
|
||||
schedule = await ScheduleInfo.create(**creation_kwargs)
|
||||
created = True
|
||||
self._add_aps_job(schedule)
|
||||
action = "设置" if created else "更新"
|
||||
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
|
||||
|
||||
async def add_schedule_for_all(
|
||||
self,
|
||||
plugin_name: str,
|
||||
trigger_type: str,
|
||||
trigger_config: dict,
|
||||
job_kwargs: dict | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""为所有机器人所在的群组添加定时任务。"""
|
||||
if plugin_name not in self._registered_tasks:
|
||||
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
|
||||
|
||||
groups = set()
|
||||
for bot in get_bots().values():
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
groups.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
for gid in groups:
|
||||
try:
|
||||
success, _ = await self.add_schedule(
|
||||
plugin_name, gid, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
|
||||
fail_count += 1
|
||||
await asyncio.sleep(0.05)
|
||||
return success_count, fail_count
|
||||
|
||||
async def update_schedule(
|
||||
self,
|
||||
schedule_id: int,
|
||||
trigger_type: str | None = None,
|
||||
trigger_config: dict | None = None,
|
||||
job_kwargs: dict | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""部分更新一个已存在的定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的任务。"
|
||||
|
||||
updated_fields = []
|
||||
if trigger_config is not None:
|
||||
schedule.trigger_config = trigger_config
|
||||
updated_fields.append("trigger_config")
|
||||
|
||||
if trigger_type is not None and schedule.trigger_type != trigger_type:
|
||||
schedule.trigger_type = trigger_type
|
||||
updated_fields.append("trigger_type")
|
||||
|
||||
if job_kwargs is not None:
|
||||
if not isinstance(schedule.job_kwargs, dict):
|
||||
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
|
||||
|
||||
merged_kwargs = schedule.job_kwargs.copy()
|
||||
merged_kwargs.update(job_kwargs)
|
||||
|
||||
is_valid, result = self._validate_and_prepare_kwargs(
|
||||
schedule.plugin_name, merged_kwargs
|
||||
)
|
||||
if not is_valid:
|
||||
return False, str(result)
|
||||
|
||||
schedule.job_kwargs = result # type: ignore
|
||||
updated_fields.append("job_kwargs")
|
||||
|
||||
if not updated_fields:
|
||||
return True, "没有任何需要更新的配置。"
|
||||
|
||||
await schedule.save(update_fields=updated_fields)
|
||||
self._add_aps_job(schedule)
|
||||
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
|
||||
|
||||
async def remove_schedule(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""移除指定插件和群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
msg = (
|
||||
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
|
||||
f"的插件 '{plugin_name}' 定时任务。"
|
||||
)
|
||||
return (False, msg)
|
||||
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
msg = (
|
||||
f"已取消 Bot {bot_id} 在 [{target_desc}] "
|
||||
f"的插件 '{plugin_name}' 所有定时任务。"
|
||||
)
|
||||
return (True, msg)
|
||||
|
||||
async def remove_schedule_for_all(
|
||||
self, plugin_name: str, bot_id: str | None = None
|
||||
) -> int:
|
||||
"""移除指定插件在所有群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules_to_delete = await ScheduleInfo.filter(**query).all()
|
||||
if not schedules_to_delete:
|
||||
return 0
|
||||
|
||||
for schedule in schedules_to_delete:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return len(schedules_to_delete)
|
||||
|
||||
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
|
||||
"""移除指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id)
|
||||
if not schedules:
|
||||
return False, f"群 {group_id} 没有任何定时任务。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
||||
"""暂停指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
||||
"""恢复指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
||||
"""暂停指定插件在所有群组的定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return (
|
||||
count,
|
||||
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
||||
"""恢复指定插件在所有群组的定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return (
|
||||
count,
|
||||
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def pause_schedule_by_plugin_group(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""暂停指定插件在指定群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
return (
|
||||
False,
|
||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
|
||||
)
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def resume_schedule_by_plugin_group(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""恢复指定插件在指定群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
return (
|
||||
False,
|
||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
|
||||
)
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def remove_all_schedules(self) -> tuple[int, str]:
|
||||
"""移除所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.all()
|
||||
if not schedules:
|
||||
return 0, "当前没有任何定时任务。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功移除所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def pause_all_schedules(self) -> tuple[int, str]:
|
||||
"""暂停所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, "当前没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def resume_all_schedules(self) -> tuple[int, str]:
|
||||
"""恢复所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, "当前没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""通过ID移除指定的定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
|
||||
"""通过ID获取定时任务信息。"""
|
||||
return await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
|
||||
async def get_schedules(
|
||||
self, plugin_name: str, group_id: str | None
|
||||
) -> list[ScheduleInfo]:
|
||||
"""获取特定群组特定插件的所有定时任务。"""
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
|
||||
|
||||
async def get_schedule(
|
||||
self, plugin_name: str, group_id: str | None
|
||||
) -> ScheduleInfo | None:
|
||||
"""获取特定群组的定时任务信息。"""
|
||||
return await ScheduleInfo.get_or_none(
|
||||
plugin_name=plugin_name, group_id=group_id
|
||||
)
|
||||
|
||||
async def get_all_schedules(
|
||||
self, plugin_name: str | None = None
|
||||
) -> list[ScheduleInfo]:
|
||||
"""获取所有定时任务信息,可按插件名过滤。"""
|
||||
if plugin_name:
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
|
||||
return await ScheduleInfo.all()
|
||||
|
||||
async def get_schedule_status(self, schedule_id: int) -> dict | None:
|
||||
"""获取任务的详细状态。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return None
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
job = scheduler.get_job(job_id)
|
||||
|
||||
status = {
|
||||
"id": schedule.id,
|
||||
"bot_id": schedule.bot_id,
|
||||
"plugin_name": schedule.plugin_name,
|
||||
"group_id": schedule.group_id,
|
||||
"is_enabled": schedule.is_enabled,
|
||||
"trigger_type": schedule.trigger_type,
|
||||
"trigger_config": schedule.trigger_config,
|
||||
"job_kwargs": schedule.job_kwargs,
|
||||
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if job and job.next_run_time
|
||||
else "N/A",
|
||||
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
|
||||
}
|
||||
return status
|
||||
|
||||
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""暂停一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
return False, "任务不存在或已暂停。"
|
||||
|
||||
schedule.is_enabled = False
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.pause_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""恢复一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule or schedule.is_enabled:
|
||||
return False, "任务不存在或已启用。"
|
||||
|
||||
schedule.is_enabled = True
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.resume_job(job_id)
|
||||
except Exception:
|
||||
self._add_aps_job(schedule)
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""手动触发一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||
|
||||
if schedule.plugin_name not in self._registered_tasks:
|
||||
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
|
||||
|
||||
try:
|
||||
await self._execute_job(schedule.id)
|
||||
return (
|
||||
True,
|
||||
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"手动触发任务失败: {e}")
|
||||
return False, f"手动触发任务失败: {e}"
|
||||
|
||||
|
||||
scheduler_manager = SchedulerManager()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=90)
|
||||
async def _load_schedules_from_db():
|
||||
"""在服务启动时从数据库加载并调度所有任务。"""
|
||||
Config.add_plugin_config(
|
||||
"SchedulerManager",
|
||||
SCHEDULE_CONCURRENCY_KEY,
|
||||
5,
|
||||
help="“所有群组”类型定时任务的并发执行数量限制",
|
||||
type=int,
|
||||
)
|
||||
|
||||
logger.info("正在从数据库加载并调度所有定时任务...")
|
||||
schedules = await ScheduleInfo.filter(is_enabled=True).all()
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
if schedule.plugin_name in scheduler_manager._registered_tasks:
|
||||
scheduler_manager._add_aps_job(schedule)
|
||||
count += 1
|
||||
else:
|
||||
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
|
||||
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")
|
||||
@ -227,7 +227,7 @@ class PlatformUtils:
|
||||
url = None
|
||||
if platform == "qq":
|
||||
if user_id.isdigit():
|
||||
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160"
|
||||
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
|
||||
else:
|
||||
url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640"
|
||||
return await AsyncHttpx.get_content(url) if url else None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user