导入最近的PR

导入最近的PR
修复LevelUser权限,以及格式化传入原始名称问题
This commit is contained in:
ManyManyTomato 2025-07-05 23:46:23 +08:00 committed by GitHub
commit d7d36af497
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 2586 additions and 396 deletions

View File

@ -18,3 +18,4 @@ jobs:
uses: astral-sh/ruff-action@v3
- name: Run Ruff Check
run: ruff check

View File

@ -4,11 +4,13 @@ from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.cache import Cache
from zhenxun.utils.enum import BlockType, CacheType, PluginType
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.utils.exception import GroupInfoNotFound
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
# from zhenxun.services.cache import Cache
# from zhenxun.utils.enum import CacheType
HELP_FILE = IMAGE_PATH / "SIMPLE_HELP.png"
GROUP_HELP_PATH = DATA_PATH / "group_help"
@ -251,7 +253,7 @@ class PluginManage:
)
group.status = False
await group.save(update_fields=["status"])
@classmethod
async def wake(cls, group_id: str):
"""醒来

View File

@ -1,12 +1,17 @@
import random
from typing import Any
from nonebot import on_regex
from nonebot.adapters import Bot
from nonebot.params import Depends, RegexGroup
from nonebot.plugin import PluginMetadata
from nonebot.rule import to_me
from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true
from nonebot_plugin_alconna import (
Alconna,
Args,
Arparma,
CommandMeta,
Option,
on_alconna,
store_true,
)
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import BotConfig, Config
@ -54,15 +59,22 @@ __plugin_meta__ = PluginMetadata(
).to_dict(),
)
_nickname_matcher = on_regex(
"(?:以后)?(?:叫我|请叫我|称呼我)(.*)",
_nickname_matcher = on_alconna(
Alconna(
"re:(?:以后)?(?:叫我|请叫我|称呼我)",
Args["name?", str],
meta=CommandMeta(compact=True),
),
rule=to_me(),
priority=5,
block=True,
)
_global_nickname_matcher = on_regex(
"设置全局昵称(.*)", rule=to_me(), priority=5, block=True
_global_nickname_matcher = on_alconna(
Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)),
rule=to_me(),
priority=5,
block=True,
)
_matcher = on_alconna(
@ -117,34 +129,32 @@ CANCEL = [
]
def CheckNickname():
async def CheckNickname(
bot: Bot,
session: Uninfo,
params: Arparma,
):
"""
检查名称是否合法
"""
async def dependency(
bot: Bot,
session: Uninfo,
reg_group: tuple[Any, ...] = RegexGroup(),
):
black_word = Config.get_config("nickname", "BLACK_WORD")
(name,) = reg_group
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
if not name:
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
at_sender=True
)
if session.user.id in bot.config.superusers:
logger.debug(
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
)
return
black_word = Config.get_config("nickname", "BLACK_WORD")
name = params.query("name")
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
if not name:
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
at_sender=True
)
if session.user.id in bot.config.superusers:
logger.debug(
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
)
else:
if len(name) > 20:
await MessageUtils.build_message("昵称可不能超过20个字!").finish(
at_sender=True
)
if name in bot.config.nickname:
await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish(
await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish(
at_sender=True
)
if black_word:
@ -162,17 +172,17 @@ def CheckNickname():
await MessageUtils.build_message(
f"字符 [{word}] 为禁止字符!"
).finish(at_sender=True)
return Depends(dependency)
return name
@_nickname_matcher.handle(parameterless=[CheckNickname()])
@_nickname_matcher.handle()
async def _(
bot: Bot,
session: Uninfo,
name_: Arparma,
uname: str = UserName(),
reg_group: tuple[Any, ...] = RegexGroup(),
):
(name,) = reg_group
name = await CheckNickname(bot, session, name_)
if len(name) < 5 and random.random() < 0.3:
name = "~".join(name)
group_id = None
@ -200,13 +210,14 @@ async def _(
)
@_global_nickname_matcher.handle(parameterless=[CheckNickname()])
@_global_nickname_matcher.handle()
async def _(
bot: Bot,
session: Uninfo,
name_: Arparma,
nickname: str = UserName(),
reg_group: tuple[Any, ...] = RegexGroup(),
):
(name,) = reg_group
name = await CheckNickname(bot, session, name_)
await FriendUser.set_user_nickname(
session.user.id,
name,
@ -227,15 +238,14 @@ async def _(session: Uninfo, uname: str = UserName()):
group_id = session.group.parent.id if session.group.parent else session.group.id
if group_id:
nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id)
card = uname
else:
nickname = await FriendUser.get_user_nickname(session.user.id)
card = uname
if nickname:
await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish(
reply_to=True
)
else:
card = uname
await MessageUtils.build_message(
random.choice(
[

View File

@ -31,4 +31,4 @@ async def _(session: Uninfo):
await FriendUser.create(
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
)
logger.info("添加当前好友用户信息", "", session=session)
logger.info("添加当前好友用户信息", "", session=session)

View File

@ -0,0 +1,51 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
from zhenxun.utils.enum import PluginType
from . import command # noqa: F401
# NoneBot plugin metadata for the schedule-management plugin.
# Fixes in the user-facing usage text: the time-option list contains FOUR
# options, so "三选一" was wrong ("四选一" now); the private-chat hint was
# missing the conjunction "或" between "-g <群号>" and "-all".
__plugin_meta__ = PluginMetadata(
    name="定时任务管理",
    description="查看和管理由 SchedulerManager 控制的定时任务。",
    usage="""
📋 定时任务管理 - 支持群聊和私聊操作
🔍 查看任务:
定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>]
群聊中: 查看本群任务
私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER)
📊 任务状态:
定时任务 状态 <任务ID> 任务状态 <任务ID>
查看单个任务的详细信息和状态
任务管理 (SUPERUSER):
定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>]
定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 执行 <任务ID>
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
📝 时间选项 (四选一):
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
--interval <时间间隔> # 例: --interval 30m, 2h, 10s
--date "<YYYY-MM-DD HH:MM:SS>" # 例: --date "2024-01-01 08:00:00"
--daily "<HH:MM>" # 例: --daily "08:30"
📚 其他功能:
定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER)
🏷 别名支持:
查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消
暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info
更新: update, modify, 修改 | 插件列表: plugins
""".strip(),
    extra=PluginExtraData(
        author="HibiKier",
        version="0.1.2",
        plugin_type=PluginType.SUPERUSER,
        is_show=False,
    ).to_dict(),
)

View File

@ -0,0 +1,836 @@
import asyncio
from datetime import datetime
import re
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import (
Alconna,
AlconnaMatch,
Args,
Arparma,
Match,
Option,
Query,
Subcommand,
on_alconna,
)
from pydantic import BaseModel, ValidationError
from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.manager.schedule_manager import scheduler_manager
def _get_type_name(annotation) -> str:
"""获取类型注解的名称"""
if hasattr(annotation, "__name__"):
return annotation.__name__
elif hasattr(annotation, "_name"):
return annotation._name
else:
return str(annotation)
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.rules import admin_check
def _format_trigger(schedule_status: dict) -> str:
"""将触发器配置格式化为人类可读的字符串"""
trigger_type = schedule_status["trigger_type"]
config = schedule_status["trigger_config"]
if trigger_type == "cron":
minute = config.get("minute", "*")
hour = config.get("hour", "*")
day = config.get("day", "*")
month = config.get("month", "*")
day_of_week = config.get("day_of_week", "*")
if day == "*" and month == "*" and day_of_week == "*":
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
return f"每天 {formatted_hour}:{formatted_minute}"
else:
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
elif trigger_type == "interval":
seconds = config.get("seconds", 0)
minutes = config.get("minutes", 0)
hours = config.get("hours", 0)
days = config.get("days", 0)
if days:
trigger_str = f"{days}"
elif hours:
trigger_str = f"{hours} 小时"
elif minutes:
trigger_str = f"{minutes} 分钟"
else:
trigger_str = f"{seconds}"
elif trigger_type == "date":
run_date = config.get("run_date", "未知时间")
trigger_str = f"{run_date}"
else:
trigger_str = f"{trigger_type}: {config}"
return trigger_str
def _format_params(schedule_status: dict) -> str:
"""将任务参数格式化为人类可读的字符串"""
if kwargs := schedule_status.get("job_kwargs"):
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
return kwargs_str
return "-"
def _parse_interval(interval_str: str) -> dict:
"""增强版解析器,支持 d(天)"""
match = re.match(r"(\d+)([smhd])", interval_str.lower())
if not match:
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return {"seconds": value}
if unit == "m":
return {"minutes": value}
if unit == "h":
return {"hours": value}
if unit == "d":
return {"days": value}
return {}
def _parse_daily_time(time_str: str) -> dict:
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
hour, minute, second = match.groups()
hour, minute = int(hour), int(minute)
if not (0 <= hour <= 23 and 0 <= minute <= 59):
raise ValueError("小时或分钟数值超出范围。")
cron_config = {
"minute": str(minute),
"hour": str(hour),
"day": "*",
"month": "*",
"day_of_week": "*",
}
if second is not None:
if not (0 <= int(second) <= 59):
raise ValueError("秒数值超出范围。")
cron_config["second"] = str(second)
return cron_config
else:
raise ValueError("时间格式错误,请使用 'HH:MM''HH:MM:SS' 格式。")
async def GetBotId(
    bot: Bot,
    bot_id_match: Match[str] = AlconnaMatch("bot_id"),
) -> str:
    """Resolve the Bot ID to operate on: explicit --bot option, else current bot."""
    return bot_id_match.result if bot_id_match.available else bot.self_id
class ScheduleTarget:
    """Base class for schedule-operation targets (pure marker type)."""

    pass
class TargetByID(ScheduleTarget):
    """Target a single schedule by its numeric task ID."""

    def __init__(self, id: int):
        # id: primary key of the schedule row to operate on
        self.id = id
class TargetByPlugin(ScheduleTarget):
    """Target the schedules of one plugin, optionally narrowed to a group."""

    def __init__(
        self, plugin: str, group_id: str | None = None, all_groups: bool = False
    ):
        self.plugin = plugin  # plugin module name
        self.group_id = group_id  # specific group, or None
        self.all_groups = all_groups  # True => apply to every group
class TargetAll(ScheduleTarget):
    """Target every schedule, optionally restricted to one group."""

    def __init__(self, for_group: str | None = None):
        # for_group: restrict the bulk operation to this group when set
        self.for_group = for_group


# Union of all parse results; None means "no target was specified".
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
def create_target_parser(subcommand_name: str):
    """
    Build a dependency-injection function that resolves the operation target
    (by task ID, by plugin, or all) for the delete/pause/resume subcommands.
    """

    async def dependency(
        event: Event,
        schedule_id: Match[int] = AlconnaMatch("schedule_id"),
        plugin_name: Match[str] = AlconnaMatch("plugin_name"),
        group_id: Match[str] = AlconnaMatch("group_id"),
        all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
    ) -> TargetScope:
        # An explicit task ID wins over every other selector.
        if schedule_id.available:
            return TargetByID(schedule_id.result)
        if plugin_name.available:
            p_name = plugin_name.result
            if all_enabled.available:
                return TargetByPlugin(plugin=p_name, all_groups=True)
            elif group_id.available:
                gid = group_id.result
                # "-g all" is shorthand for the all-groups flag.
                if gid.lower() == "all":
                    return TargetByPlugin(plugin=p_name, all_groups=True)
                return TargetByPlugin(plugin=p_name, group_id=gid)
            else:
                # No group given: fall back to the current chat's group, if any.
                current_group_id = getattr(event, "group_id", None)
                if current_group_id:
                    return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
                else:
                    # finish() raises, so this branch never falls through.
                    await schedule_cmd.finish(
                        "私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
                    )
        if all_enabled.available:
            return TargetAll(for_group=group_id.result if group_id.available else None)
        return None

    return dependency
# Root "定时任务" command. Each Subcommand below is paired with one
# @schedule_cmd.assign(...) handler; rule=admin_check(1) gates all of them.
schedule_cmd = on_alconna(
    Alconna(
        "定时任务",
        Subcommand(
            "查看",
            Option("-g", Args["target_group_id", str]),
            Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
            Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
            Option("--page", Args["page", int, 1], help_text="指定页码"),
            alias=["ls", "list"],
            help_text="查看定时任务",
        ),
        Subcommand(
            "设置",
            Args["plugin_name", str],
            Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
            Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
            Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
            Option(
                "--daily",
                Args["daily_expr", str],
                help_text="设置每天执行的时间 (如 08:20)",
            ),
            Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
            Option("-all", help_text="对所有群生效 (等同于 -g all)"),
            Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
            Option(
                "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
            ),
            alias=["add", "开启"],
            help_text="设置/开启一个定时任务",
        ),
        Subcommand(
            "删除",
            Args["schedule_id?", int],
            Option("-p", Args["plugin_name", str], help_text="指定插件名"),
            Option("-g", Args["group_id", str], help_text="指定群组ID"),
            Option("-all", help_text="对所有群生效"),
            Option(
                "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
            ),
            alias=["del", "rm", "remove", "关闭", "取消"],
            help_text="删除一个或多个定时任务",
        ),
        Subcommand(
            "暂停",
            Args["schedule_id?", int],
            Option("-all", help_text="对当前群所有任务生效"),
            Option("-p", Args["plugin_name", str], help_text="指定插件名"),
            Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
            Option(
                "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
            ),
            alias=["pause"],
            help_text="暂停一个或多个定时任务",
        ),
        Subcommand(
            "恢复",
            Args["schedule_id?", int],
            Option("-all", help_text="对当前群所有任务生效"),
            Option("-p", Args["plugin_name", str], help_text="指定插件名"),
            Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
            Option(
                "--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
            ),
            alias=["resume"],
            help_text="恢复一个或多个定时任务",
        ),
        Subcommand(
            "执行",
            Args["schedule_id", int],
            alias=["trigger", "run"],
            help_text="立即执行一次任务",
        ),
        Subcommand(
            "更新",
            Args["schedule_id", int],
            Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
            Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
            Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
            Option(
                "--daily",
                Args["daily_expr", str],
                help_text="更新每天执行的时间 (如 08:20)",
            ),
            Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
            alias=["update", "modify", "修改"],
            help_text="更新任务配置",
        ),
        Subcommand(
            "状态",
            Args["schedule_id", int],
            alias=["status", "info"],
            help_text="查看单个任务的详细状态",
        ),
        Subcommand(
            "插件列表",
            alias=["plugins"],
            help_text="列出所有可用的插件",
        ),
    ),
    priority=5,
    block=True,
    rule=admin_check(1),
)
# "任务状态 <id>" is sugar for "定时任务 状态 <id>".
schedule_cmd.shortcut(
    "任务状态",
    command="定时任务",
    arguments=["状态", "{%0}"],
    prefix=True,
)
@schedule_cmd.handle()
async def _handle_time_options_mutex(arp: Arparma):
    """Reject commands that supply more than one time option at once.

    NOTE(review): arp.query(opt) is queried with the bare option name even
    though the options live under subcommands — confirm the query path
    ("设置.cron" vs "cron") actually matches at runtime.
    """
    time_options = ["cron", "interval", "date", "daily"]
    provided_options = [opt for opt in time_options if arp.query(opt) is not None]
    if len(provided_options) > 1:
        await schedule_cmd.finish(
            f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
        )
@schedule_cmd.assign("查看")
async def _(
    bot: Bot,
    event: Event,
    target_group_id: Match[str] = AlconnaMatch("target_group_id"),
    all_groups: Query[bool] = Query("查看.all"),
    plugin_name: Match[str] = AlconnaMatch("plugin_name"),
    page: Match[int] = AlconnaMatch("page"),
):
    """Handle "定时任务 查看": render a paginated table image of schedules."""
    is_superuser = await SUPERUSER(bot, event)
    schedules = []
    title = ""
    current_group_id = getattr(event, "group_id", None)
    # In private chat a scope must be given explicitly.
    if not (all_groups.available or target_group_id.available) and not current_group_id:
        await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
    if all_groups.available:
        if not is_superuser:
            await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
        schedules = await scheduler_manager.get_all_schedules()
        title = "所有群组的定时任务"
    elif target_group_id.available:
        if not is_superuser:
            await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
        gid = target_group_id.result
        schedules = [
            s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
        ]
        title = f"群 {gid} 的定时任务"
    else:
        # Default: the group the command was sent from.
        gid = str(current_group_id)
        schedules = [
            s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
        ]
        title = "本群的定时任务"
    if plugin_name.available:
        schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
        title += f" [插件: {plugin_name.result}]"
    if not schedules:
        await schedule_cmd.finish("没有找到任何相关的定时任务。")
    # Paginate 15 rows per page.
    page_size = 15
    current_page = page.result
    total_items = len(schedules)
    total_pages = (total_items + page_size - 1) // page_size
    start_index = (current_page - 1) * page_size
    end_index = start_index + page_size
    paginated_schedules = schedules[start_index:end_index]
    if not paginated_schedules:
        await schedule_cmd.finish("这一页没有内容了哦~")
    # Fetch the detailed status of every row on this page concurrently.
    status_tasks = [
        scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
    ]
    all_statuses = await asyncio.gather(*status_tasks)
    data_list = [
        [
            s["id"],
            s["plugin_name"],
            s.get("bot_id") or "N/A",
            s["group_id"] or "全局",
            s["next_run_time"],
            _format_trigger(s),
            _format_params(s),
            "✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
        ]
        for s in all_statuses
        if s  # skip rows whose status lookup returned nothing
    ]
    if not data_list:
        await schedule_cmd.finish("没有找到任何相关的定时任务。")
    img = await ImageTemplate.table_page(
        head_text=title,
        tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
        column_name=[
            "ID",
            "插件",
            "Bot ID",
            "群组/目标",
            "下次运行",
            "触发规则",
            "参数",
            "状态",
        ],
        data_list=data_list,
        column_space=20,
    )
    await MessageUtils.build_message(img).send(reply_to=True)
@schedule_cmd.assign("设置")
async def _(
    event: Event,
    plugin_name: str,
    cron_expr: str | None = None,
    interval_expr: str | None = None,
    date_expr: str | None = None,
    daily_expr: str | None = None,
    group_id: str | None = None,
    kwargs_str: str | None = None,
    all_enabled: Query[bool] = Query("设置.all"),
    bot_id_to_operate: str = Depends(GetBotId),
):
    """Handle "定时任务 设置": create a schedule for a registered plugin.

    Steps: validate plugin, parse the (single) time option, validate optional
    kwargs against the plugin's pydantic model, resolve the target group,
    then delegate to scheduler_manager.add_schedule.
    """
    if plugin_name not in scheduler_manager._registered_tasks:
        await schedule_cmd.finish(
            f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
            f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
        )
    trigger_type = ""
    trigger_config = {}
    try:
        if cron_expr:
            trigger_type = "cron"
            parts = cron_expr.split()
            if len(parts) != 5:
                raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
            cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
            trigger_config = dict(zip(cron_keys, parts))
        elif interval_expr:
            trigger_type = "interval"
            trigger_config = _parse_interval(interval_expr)
        elif date_expr:
            trigger_type = "date"
            trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
        elif daily_expr:
            trigger_type = "cron"
            trigger_config = _parse_daily_time(daily_expr)
        else:
            await schedule_cmd.finish(
                "必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
            )
    except ValueError as e:
        await schedule_cmd.finish(f"时间参数解析错误: {e}")
    job_kwargs = {}
    if kwargs_str:
        task_meta = scheduler_manager._registered_tasks[plugin_name]
        params_model = task_meta.get("model")
        if not params_model:
            await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
        if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
            await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
        # Parse "key=value,key2=value2" into a raw dict first.
        raw_kwargs = {}
        try:
            for item in kwargs_str.split(","):
                key, value = item.strip().split("=", 1)
                raw_kwargs[key.strip()] = value
        except Exception as e:
            await schedule_cmd.finish(
                f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
            )
        try:
            # Duck-typed pydantic v2 access so a mis-built model fails gracefully.
            model_validate = getattr(params_model, "model_validate", None)
            if not model_validate:
                await schedule_cmd.finish(
                    f"插件 '{plugin_name}' 的参数模型不支持验证。"
                )
                return
            validated_model = model_validate(raw_kwargs)
            model_dump = getattr(validated_model, "model_dump", None)
            if not model_dump:
                await schedule_cmd.finish(
                    f"插件 '{plugin_name}' 的参数模型不支持导出。"
                )
                return
            job_kwargs = model_dump()
        except ValidationError as e:
            errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
            error_str = "\n".join(errors)
            await schedule_cmd.finish(
                f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
            )
            return
    # Resolve which group(s) the schedule applies to.
    target_group_id: str | None
    current_group_id = getattr(event, "group_id", None)
    if group_id and group_id.lower() == "all":
        target_group_id = "__ALL_GROUPS__"
    elif all_enabled.available:
        target_group_id = "__ALL_GROUPS__"
    elif group_id:
        target_group_id = group_id
    elif current_group_id:
        target_group_id = str(current_group_id)
    else:
        await schedule_cmd.finish(
            "私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
        )
        return
    success, msg = await scheduler_manager.add_schedule(
        plugin_name,
        target_group_id,
        trigger_type,
        trigger_config,
        job_kwargs,
        bot_id=bot_id_to_operate,
    )
    if target_group_id == "__ALL_GROUPS__":
        target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
    elif target_group_id is None:
        target_desc = "全局"
    else:
        target_desc = f"群组 {target_group_id}"
    if success:
        await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
    else:
        await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
@schedule_cmd.assign("删除")
async def _(
    target: TargetScope = Depends(create_target_parser("删除")),
    bot_id_to_operate: str = Depends(GetBotId),
):
    """Handle "定时任务 删除": remove schedules by ID, by plugin, or all."""
    if isinstance(target, TargetByID):
        _, message = await scheduler_manager.remove_schedule_by_id(target.id)
        await schedule_cmd.finish(message)
    elif isinstance(target, TargetByPlugin):
        p_name = target.plugin
        if p_name not in scheduler_manager.get_registered_plugins():
            await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
        if target.all_groups:
            removed_count = await scheduler_manager.remove_schedule_for_all(
                p_name, bot_id=bot_id_to_operate
            )
            message = (
                f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
                if removed_count > 0
                else f"没有找到插件 '{p_name}' 的定时任务。"
            )
            await schedule_cmd.finish(message)
        else:
            _, message = await scheduler_manager.remove_schedule(
                p_name, target.group_id, bot_id=bot_id_to_operate
            )
            await schedule_cmd.finish(message)
    elif isinstance(target, TargetAll):
        if target.for_group:
            _, message = await scheduler_manager.remove_schedules_by_group(
                target.for_group
            )
            await schedule_cmd.finish(message)
        else:
            _, message = await scheduler_manager.remove_all_schedules()
            await schedule_cmd.finish(message)
    else:
        # No selector could be resolved from the command.
        await schedule_cmd.finish(
            "删除任务失败请提供任务ID或通过 -p <插件> 或 -all 指定要删除的任务。"
        )
@schedule_cmd.assign("暂停")
async def _(
    target: TargetScope = Depends(create_target_parser("暂停")),
    bot_id_to_operate: str = Depends(GetBotId),
):
    """Handle "定时任务 暂停": pause schedules by ID, by plugin, or all."""
    if isinstance(target, TargetByID):
        _, message = await scheduler_manager.pause_schedule(target.id)
        await schedule_cmd.finish(message)
    elif isinstance(target, TargetByPlugin):
        p_name = target.plugin
        if p_name not in scheduler_manager.get_registered_plugins():
            await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
        if target.all_groups:
            _, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
            await schedule_cmd.finish(message)
        else:
            _, message = await scheduler_manager.pause_schedule_by_plugin_group(
                p_name, target.group_id, bot_id=bot_id_to_operate
            )
            await schedule_cmd.finish(message)
    elif isinstance(target, TargetAll):
        if target.for_group:
            _, message = await scheduler_manager.pause_schedules_by_group(
                target.for_group
            )
            await schedule_cmd.finish(message)
        else:
            _, message = await scheduler_manager.pause_all_schedules()
            await schedule_cmd.finish(message)
    else:
        await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("恢复")
async def _(
    target: TargetScope = Depends(create_target_parser("恢复")),
    bot_id_to_operate: str = Depends(GetBotId),
):
    """Handle "定时任务 恢复": resume schedules by ID, by plugin, or all."""
    if isinstance(target, TargetByID):
        _, message = await scheduler_manager.resume_schedule(target.id)
        await schedule_cmd.finish(message)
    elif isinstance(target, TargetByPlugin):
        p_name = target.plugin
        if p_name not in scheduler_manager.get_registered_plugins():
            await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
        if target.all_groups:
            _, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
            await schedule_cmd.finish(message)
        else:
            _, message = await scheduler_manager.resume_schedule_by_plugin_group(
                p_name, target.group_id, bot_id=bot_id_to_operate
            )
            await schedule_cmd.finish(message)
    elif isinstance(target, TargetAll):
        if target.for_group:
            _, message = await scheduler_manager.resume_schedules_by_group(
                target.for_group
            )
            await schedule_cmd.finish(message)
        else:
            _, message = await scheduler_manager.resume_all_schedules()
            await schedule_cmd.finish(message)
    else:
        await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("执行")
async def _(schedule_id: int):
    """Handle "定时任务 执行": run the given schedule once, immediately."""
    _, message = await scheduler_manager.trigger_now(schedule_id)
    await schedule_cmd.finish(message)
@schedule_cmd.assign("更新")
async def _(
    schedule_id: int,
    cron_expr: str | None = None,
    interval_expr: str | None = None,
    date_expr: str | None = None,
    daily_expr: str | None = None,
    kwargs_str: str | None = None,
):
    """Handle "定时任务 更新": change a schedule's trigger and/or job kwargs."""
    if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
        await schedule_cmd.finish(
            "请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
        )
    trigger_config = None
    trigger_type = None
    try:
        if cron_expr:
            trigger_type = "cron"
            parts = cron_expr.split()
            if len(parts) != 5:
                raise ValueError("Cron 表达式必须有5个部分")
            cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
            trigger_config = dict(zip(cron_keys, parts))
        elif interval_expr:
            trigger_type = "interval"
            trigger_config = _parse_interval(interval_expr)
        elif date_expr:
            trigger_type = "date"
            trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
        elif daily_expr:
            trigger_type = "cron"
            trigger_config = _parse_daily_time(daily_expr)
    except ValueError as e:
        await schedule_cmd.finish(f"时间参数解析错误: {e}")
    job_kwargs = None
    if kwargs_str:
        schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
        if not schedule:
            await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
        task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
        if not task_meta or not (params_model := task_meta.get("model")):
            await schedule_cmd.finish(
                f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
            )
        if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
            await schedule_cmd.finish(
                f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
            )
        # Parse "key=value,key2=value2" into a raw dict first.
        raw_kwargs = {}
        try:
            for item in kwargs_str.split(","):
                key, value = item.strip().split("=", 1)
                raw_kwargs[key.strip()] = value
        except Exception as e:
            await schedule_cmd.finish(
                f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
            )
        try:
            # Duck-typed pydantic v2 access so a mis-built model fails gracefully.
            model_validate = getattr(params_model, "model_validate", None)
            if not model_validate:
                await schedule_cmd.finish(
                    f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
                )
                return
            validated_model = model_validate(raw_kwargs)
            model_dump = getattr(validated_model, "model_dump", None)
            if not model_dump:
                await schedule_cmd.finish(
                    f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
                )
                return
            # Only include fields the user actually provided.
            job_kwargs = model_dump(exclude_unset=True)
        except ValidationError as e:
            errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
            error_str = "\n".join(errors)
            await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
            return
    _, message = await scheduler_manager.update_schedule(
        schedule_id, trigger_type, trigger_config, job_kwargs
    )
    await schedule_cmd.finish(message)
@schedule_cmd.assign("插件列表")
async def _():
    """Handle "定时任务 插件列表": list plugins that registered schedulable tasks."""
    registered_plugins = scheduler_manager.get_registered_plugins()
    if not registered_plugins:
        await schedule_cmd.finish("当前没有已注册的定时任务插件。")
    message_parts = ["📋 已注册的定时任务插件:"]
    for i, plugin_name in enumerate(registered_plugins, 1):
        task_meta = scheduler_manager._registered_tasks[plugin_name]
        params_model = task_meta.get("model")
        if not params_model:
            message_parts.append(f"{i}. {plugin_name} - 无参数")
            continue
        if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
            message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
            continue
        # pydantic v2 exposes declared fields via model_fields.
        model_fields = getattr(params_model, "model_fields", None)
        if model_fields:
            param_info = ", ".join(
                f"{field_name}({_get_type_name(field_info.annotation)})"
                for field_name, field_info in model_fields.items()
            )
            message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
        else:
            message_parts.append(f"{i}. {plugin_name} - 无参数")
    await schedule_cmd.finish("\n".join(message_parts))
@schedule_cmd.assign("状态")
async def _(schedule_id: int):
    """Handle "定时任务 状态": show one schedule's detailed status."""
    status = await scheduler_manager.get_schedule_status(schedule_id)
    if not status:
        await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
    info_lines = [
        f"📋 定时任务详细信息 (ID: {schedule_id})",
        "--------------------",
        f"▫️ 插件: {status['plugin_name']}",
        f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
        f"▫️ 目标: {status['group_id'] or '全局'}",
        f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
        f"▫️ 下次运行: {status['next_run_time']}",
        f"▫️ 触发规则: {_format_trigger(status)}",
        f"▫️ 任务参数: {_format_params(status)}",
    ]
    await schedule_cmd.finish("\n".join(info_lines))

View File

@ -29,7 +29,7 @@ class BotConsole(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "bot_console"
table_description = "Bot数据表"
cache_type = CacheType.BOT
@staticmethod

View File

@ -3,6 +3,7 @@ from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType
class LevelUser(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
"""自增id"""

View File

@ -58,7 +58,7 @@ class PluginInfo(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "plugin_info"
table_description = "插件基本信息"
cache_type = CacheType.PLUGINS
@classmethod

View File

@ -0,0 +1,38 @@
from tortoise import fields
from zhenxun.services.db_context import Model
class ScheduleInfo(Model):
    """ORM model: one row per scheduled task managed by the scheduler."""

    id = fields.IntField(pk=True, generated=True, auto_increment=True)
    """自增id"""
    bot_id = fields.CharField(
        255, null=True, default=None, description="任务关联的Bot ID"
    )
    """任务关联的Bot ID"""
    plugin_name = fields.CharField(255, description="插件模块名")
    """插件模块名"""
    group_id = fields.CharField(
        255,
        null=True,
        description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务",
    )
    """群组ID, 为空表示全局任务"""
    trigger_type = fields.CharField(
        max_length=20, default="cron", description="触发器类型 (cron, interval, date)"
    )
    """触发器类型 (cron, interval, date)"""
    trigger_config = fields.JSONField(description="触发器具体配置")
    """触发器具体配置"""
    job_kwargs = fields.JSONField(
        default=dict, description="传递给任务函数的额外关键字参数"
    )
    """传递给任务函数的额外关键字参数"""
    is_enabled = fields.BooleanField(default=True, description="是否启用")
    """是否启用"""
    create_time = fields.DatetimeField(auto_now_add=True)
    """创建时间"""

    class Meta:  # pyright: ignore [reportIncompatibleVariableOverride]
        table = "schedule_info"
        table_description = "通用定时任务表"

View File

@ -208,7 +208,7 @@ class CacheData(BaseModel):
async def get_data(self) -> Any:
"""从缓存获取数据"""
try:
data = await self._cache.get(self.name)
data = await self._cache.get(self.name)# type: ignore
logger.debug(f"获取缓存 {self.name} 数据: {data}")
# 如果数据为空,尝试重新加载
@ -307,11 +307,11 @@ class CacheData(BaseModel):
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
# 2. 删除旧数据
await self._cache.delete(self.name)
await self._cache.delete(self.name)# type: ignore
logger.debug(f"删除缓存 {self.name} 旧数据")
# 3. 设置新数据
await self._cache.set(self.name, serialized_value, ttl=self.expire)
await self._cache.set(self.name, serialized_value, ttl=self.expire)# type: ignore
logger.debug(f"设置缓存 {self.name} 新数据完成")
except Exception as e:
@ -321,7 +321,7 @@ class CacheData(BaseModel):
async def delete_data(self):
"""删除缓存数据"""
try:
await self._cache.delete(self.name)
await self._cache.delete(self.name)# type: ignore
except Exception as e:
logger.error(f"删除缓存 {self.name}", e=e)
@ -428,7 +428,7 @@ class CacheData(BaseModel):
"""
cache_key = self._get_cache_key(key)
try:
data = await self._cache.get(cache_key)
data = await self._cache.get(cache_key)# type: ignore
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
if self.result_model:
@ -467,7 +467,7 @@ class CacheData(BaseModel):
for key in self._keys:
# 提取原始键名(去掉前缀)
original_key = key.split(":", 1)[1]
data = await self._cache.get(key)
data = await self._cache.get(key)# type: ignore
if self.result_model:
result[original_key] = self._deserialize_value(
data, self.result_model
@ -484,7 +484,7 @@ class CacheData(BaseModel):
cache_key = self._get_cache_key(key)
try:
serialized_value = self._serialize_value(value)
await self._cache.set(cache_key, serialized_value, ttl=self.expire)
await self._cache.set(cache_key, serialized_value, ttl=self.expire)# type: ignore
self._keys.add(cache_key) # 添加到键列表
logger.debug(f"设置缓存 {cache_key} 数据完成")
except Exception as e:
@ -495,7 +495,7 @@ class CacheData(BaseModel):
"""删除指定键的缓存数据"""
cache_key = self._get_cache_key(key)
try:
await self._cache.delete(cache_key)
await self._cache.delete(cache_key)# type: ignore
self._keys.discard(cache_key) # 从键列表中移除
logger.debug(f"删除缓存 {cache_key} 完成")
except Exception as e:
@ -505,7 +505,7 @@ class CacheData(BaseModel):
"""清除所有缓存数据"""
try:
for key in list(self._keys): # 使用列表复制避免在迭代时修改
await self._cache.delete(key)
await self._cache.delete(key)# type: ignore
self._keys.clear()
logger.debug(f"清除缓存 {self.name} 完成")
except Exception as e:
@ -524,7 +524,7 @@ class CacheManager:
if self._cache_instance is None:
if config.redis_host:
self._cache_instance = AioCache(
AioCache.REDIS,
AioCache.REDIS,# type: ignore
serializer=JsonSerializer(),
namespace="zhenxun_cache",
timeout=30, # 操作超时时间
@ -546,12 +546,12 @@ class CacheManager:
async def close(self):
if self._cache_instance:
await self._cache_instance.close()
await self._cache_instance.close()# type: ignore
async def verify_connection(self):
"""连接测试"""
try:
await self._cache.get("__test__")
await self._cache.get("__test__")# type: ignore
except Exception as e:
logger.error("连接失败", LOG_COMMAND, e=e)
raise
@ -582,13 +582,14 @@ class CacheManager:
if _name in self._data:
raise DbCacheException(f"缓存 {name} 已存在")
self._data[_name] = CacheData(
cache_data = CacheData(
name=_name,
func=func,
expire=expire,
lazy_load=lazy_load,
_cache=self._cache,
)
cache_data._cache = self._cache
self._data[_name] = cache_data
return func
return wrapper

View File

@ -42,9 +42,10 @@ class Model(TortoiseModel):
Model_: Model
"""
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
def __init_subclass__(cls, **kwargs):
MODELS.append(cls.__module__)
if cls.__module__ not in MODELS:
MODELS.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func))
@ -171,7 +172,7 @@ class Model(TortoiseModel):
await CacheRoot.reload(cache_type)
class DbUrlMissing(Exception):
class DbUrlIsNode(HookPriorityException):
"""
数据库链接地址为空
"""
@ -190,7 +191,7 @@ class DbConnectError(Exception):
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
# raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...")
# raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
error = f"""
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
@ -199,7 +200,7 @@ async def init():
***********************************************************************
***********************************************************************
"""
raise DbUrlMissing("\n" + error.strip())
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
db_url=BotConfig.db_url,

View File

@ -1,91 +1,94 @@
import os
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Literal
from nonebot import get_driver
from playwright.__main__ import main
from playwright.async_api import Browser, Playwright, async_playwright
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_htmlrender import get_browser
from playwright.async_api import Page
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
driver = get_driver()
_playwright: Playwright | None = None
_browser: Browser | None = None
from zhenxun.utils.message import MessageUtils
# @driver.on_startup
# async def start_browser():
# global _playwright
# global _browser
# install()
# await check_playwright_env()
# _playwright = await async_playwright().start()
# _browser = await _playwright.chromium.launch()
class BrowserIsNone(Exception):
pass
# @driver.on_shutdown
# async def shutdown_browser():
# if _browser:
# await _browser.close()
# if _playwright:
# await _playwright.stop() # type: ignore
class AsyncPlaywright:
@classmethod
@asynccontextmanager
async def new_page(
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
) -> AsyncGenerator[Page, None]:
"""获取一个新页面
# def get_browser() -> Browser:
# if not _browser:
# raise RuntimeError("playwright is not initalized")
# return _browser
def install():
"""自动安装、更新 Chromium"""
def set_env_variables():
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = (
"https://npmmirror.com/mirrors/playwright/"
)
if BotConfig.system_proxy:
os.environ["HTTPS_PROXY"] = BotConfig.system_proxy
def restore_env_variables():
os.environ.pop("PLAYWRIGHT_DOWNLOAD_HOST", None)
if BotConfig.system_proxy:
os.environ.pop("HTTPS_PROXY", None)
if original_proxy is not None:
os.environ["HTTPS_PROXY"] = original_proxy
def try_install_chromium():
参数:
cookies: cookies
"""
browser = await get_browser()
ctx = await browser.new_context(**kwargs)
if cookies:
if isinstance(cookies, dict):
cookies = [cookies]
await ctx.add_cookies(cookies) # type: ignore
page = await ctx.new_page()
try:
sys.argv = ["", "install", "chromium"]
main()
except SystemExit as e:
return e.code == 0
return False
yield page
finally:
await page.close()
await ctx.close()
logger.info("检查 Chromium 更新")
@classmethod
async def screenshot(
cls,
url: str,
path: Path | str,
element: str | list[str],
*,
wait_time: int | None = None,
viewport_size: dict[str, int] | None = None,
wait_until: (
Literal["domcontentloaded", "load", "networkidle"] | None
) = "networkidle",
timeout: float | None = None,
type_: Literal["jpeg", "png"] | None = None,
user_agent: str | None = None,
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
**kwargs,
) -> UniMessage | None:
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
original_proxy = os.environ.get("HTTPS_PROXY")
set_env_variables()
success = try_install_chromium()
if not success:
logger.info("Chromium 更新失败,尝试从原始仓库下载,速度较慢")
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = ""
success = try_install_chromium()
restore_env_variables()
if not success:
raise RuntimeError("未知错误Chromium 下载失败")
async def check_playwright_env():
"""检查 Playwright 依赖"""
logger.info("检查 Playwright 依赖")
try:
async with async_playwright() as p:
await p.chromium.launch()
except Exception as e:
raise ImportError("加载失败Playwright 依赖不全,") from e
参数:
url: 网址
path: 存储路径
element: 元素选择
wait_time: 等待截取超时时间
viewport_size: 窗口大小
wait_until: 等待类型
timeout: 超时限制
type_: 保存类型
user_agent: user_agent
cookies: cookies
"""
if viewport_size is None:
viewport_size = {"width": 2560, "height": 1080}
if isinstance(path, str):
path = Path(path)
wait_time = wait_time * 1000 if wait_time else None
element_list = [element] if isinstance(element, str) else element
async with cls.new_page(
cookies,
viewport=viewport_size,
user_agent=user_agent,
**kwargs,
) as page:
await page.goto(url, timeout=timeout, wait_until=wait_until)
card = page
for e in element_list:
if not card:
return None
card = await card.wait_for_selector(e, timeout=wait_time)
if card:
await card.screenshot(path=path, timeout=timeout, type=type_)
return MessageUtils.build_message(path)
return None

View File

@ -1,24 +1,226 @@
from collections.abc import Callable
from functools import partial, wraps
from typing import Any, Literal
from anyio import EndOfStream
from httpx import ConnectError, HTTPStatusError, TimeoutException
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from httpx import (
ConnectError,
HTTPStatusError,
RemoteProtocolError,
StreamError,
TimeoutException,
)
from nonebot.utils import is_coroutine_callable
from tenacity import (
RetryCallState,
retry,
retry_if_exception_type,
retry_if_result,
stop_after_attempt,
wait_exponential,
wait_fixed,
)
from zhenxun.services.log import logger
LOG_COMMAND = "RetryDecorator"
_SENTINEL = object()
def _log_before_sleep(log_name: str | None, retry_state: RetryCallState):
    """Tenacity ``before_sleep`` callback: log why the call is about to be retried.

    Builds a context string from *log_name* and the wrapped function's name,
    reports whether an exception or a rejected result triggered the retry,
    and includes the upcoming sleep duration.
    """
    func_name = retry_state.fn.__name__ if retry_state.fn else "unknown_function"
    log_context = f"函数 '{func_name}'"
    if log_name:
        log_context = f"操作 '{log_name}' ({log_context})"
    reason = ""
    if outcome := retry_state.outcome:
        exc = outcome.exception()
        # An exception outcome and a rejected-result outcome are mutually
        # exclusive; only call .result() when no exception was captured.
        reason = (
            f"触发异常: {exc.__class__.__name__}({exc})"
            if exc
            else f"不满足结果条件: result={outcome.result()}"
        )
    next_action = retry_state.next_action
    wait_time = getattr(next_action, "sleep", 0) if next_action else 0
    logger.warning(
        f"{log_context}第 {retry_state.attempt_number} 次重试... "
        f"等待 {wait_time:.2f} 秒. {reason}",
        LOG_COMMAND,
    )
class Retry:
@staticmethod
def api(
retry_count: int = 3, wait: int = 1, exception: tuple[type[Exception], ...] = ()
def simple(
stop_max_attempt: int = 3,
wait_fixed_seconds: int = 2,
exception: tuple[type[Exception], ...] = (),
*,
log_name: str | None = None,
on_failure: Callable[[Exception], Any] | None = None,
return_on_failure: Any = _SENTINEL,
):
"""接口调用重试"""
"""
一个简单的用于通用网络请求的重试装饰器预设
参数:
stop_max_attempt: 最大重试次数
wait_fixed_seconds: 固定等待策略的等待秒数
exception: 额外需要重试的异常类型元组
log_name: 用于日志记录的操作名称
on_failure: (可选) 所有重试失败后的回调
return_on_failure: (可选) 所有重试失败后的返回值
"""
return Retry.api(
stop_max_attempt=stop_max_attempt,
wait_fixed_seconds=wait_fixed_seconds,
exception=exception,
strategy="fixed",
log_name=log_name,
on_failure=on_failure,
return_on_failure=return_on_failure,
)
@staticmethod
def download(
    stop_max_attempt: int = 3,
    exception: tuple[type[Exception], ...] = (),
    *,
    wait_exp_multiplier: int = 2,
    wait_exp_max: int = 15,
    log_name: str | None = None,
    on_failure: Callable[[Exception], Any] | None = None,
    return_on_failure: Any = _SENTINEL,
):
    """Retry-decorator preset for file downloads, using exponential backoff.

    Thin wrapper around :meth:`Retry.api` with ``strategy="exponential"``.

    Args:
        stop_max_attempt: maximum number of attempts.
        exception: extra exception types that should also trigger a retry.
        wait_exp_multiplier: multiplier for the exponential backoff wait.
        wait_exp_max: upper bound (seconds) on the backoff wait.
        log_name: operation name used in retry log messages.
        on_failure: optional callback invoked after all retries have failed.
        return_on_failure: optional value returned instead of raising after
            all retries have failed.
    """
    return Retry.api(
        stop_max_attempt=stop_max_attempt,
        exception=exception,
        strategy="exponential",
        wait_exp_multiplier=wait_exp_multiplier,
        wait_exp_max=wait_exp_max,
        log_name=log_name,
        on_failure=on_failure,
        return_on_failure=return_on_failure,
    )
@staticmethod
def api(
stop_max_attempt: int = 3,
wait_fixed_seconds: int = 1,
exception: tuple[type[Exception], ...] = (),
*,
strategy: Literal["fixed", "exponential"] = "fixed",
retry_on_result: Callable[[Any], bool] | None = None,
wait_exp_multiplier: int = 1,
wait_exp_max: int = 10,
log_name: str | None = None,
on_failure: Callable[[Exception], Any] | None = None,
return_on_failure: Any = _SENTINEL,
):
"""
通用可配置的API调用重试装饰器
参数:
stop_max_attempt: 最大重试次数
wait_fixed_seconds: 固定等待策略的等待秒数
exception: 额外需要重试的异常类型元组
strategy: 重试等待策略, 'fixed' (固定) 'exponential' (指数退避)
retry_on_result: 一个回调函数接收函数返回值如果返回 True则触发重试
例如 `lambda r: r.status_code != 200`
wait_exp_multiplier: 指数退避的乘数
wait_exp_max: 指数退避的最大等待时间
log_name: 用于日志记录的操作名称方便区分不同的重试场景
on_failure: (可选) 当所有重试都失败后在抛出异常或返回默认值之前
会调用此函数并将最终的异常实例作为参数传入
return_on_failure: (可选) 如果设置了此参数当所有重试失败后
将不再抛出异常而是返回此参数指定的值
"""
base_exceptions = (
TimeoutException,
ConnectError,
HTTPStatusError,
StreamError,
RemoteProtocolError,
EndOfStream,
*exception,
)
return retry(
reraise=True,
stop=stop_after_attempt(retry_count),
wait=wait_fixed(wait),
retry=retry_if_exception_type(base_exceptions),
)
def decorator(func: Callable) -> Callable:
if strategy == "exponential":
wait_strategy = wait_exponential(
multiplier=wait_exp_multiplier, max=wait_exp_max
)
else:
wait_strategy = wait_fixed(wait_fixed_seconds)
retry_conditions = retry_if_exception_type(base_exceptions)
if retry_on_result:
retry_conditions |= retry_if_result(retry_on_result)
log_callback = partial(_log_before_sleep, log_name)
tenacity_retry_decorator = retry(
stop=stop_after_attempt(stop_max_attempt),
wait=wait_strategy,
retry=retry_conditions,
before_sleep=log_callback,
reraise=True,
)
decorated_func = tenacity_retry_decorator(func)
if return_on_failure is _SENTINEL:
return decorated_func
if is_coroutine_callable(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
return await decorated_func(*args, **kwargs)
except Exception as e:
if on_failure:
if is_coroutine_callable(on_failure):
await on_failure(e)
else:
on_failure(e)
return return_on_failure
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
return decorated_func(*args, **kwargs)
except Exception as e:
if on_failure:
if is_coroutine_callable(on_failure):
logger.error(
f"不能在同步函数 '{func.__name__}' 中调用异步的 "
f"on_failure 回调。",
LOG_COMMAND,
)
else:
on_failure(e)
return return_on_failure
return sync_wrapper
return decorator

View File

@ -50,7 +50,7 @@ class CacheType(StrEnum):
"""用户权限"""
LIMIT = "GLOBAL_LIMIT"
"""插件限制"""
class BankHandleType(StrEnum):
DEPOSIT = "DEPOSIT"
"""存款"""

View File

@ -64,3 +64,23 @@ class GoodsNotFound(Exception):
"""
pass
class AllURIsFailedError(Exception):
    """Raised after every candidate/fallback URL has been tried and failed.

    Attributes:
        urls: the URLs that were attempted, in order.
        exceptions: the exception raised by each attempt (parallel to ``urls``).
    """

    def __init__(self, urls: list[str], exceptions: list[Exception]):
        self.urls = urls
        self.exceptions = exceptions
        # Fix: the original message unconditionally read exceptions[-1],
        # which raised IndexError when constructed with an empty list.
        last = f" Last exception: {exceptions[-1]}" if exceptions else ""
        super().__init__(f"All {len(urls)} URIs failed.{last}")

    def __str__(self) -> str:
        exc_info = "\n".join(
            f" - {url}: {exc.__class__.__name__}({exc})"
            for url, exc in zip(self.urls, self.exceptions)
        )
        return f"All {len(self.urls)} URIs failed:\n{exc_info}"

View File

@ -1,16 +1,15 @@
import asyncio
from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
from contextlib import asynccontextmanager
import os
from pathlib import Path
import time
from typing import Any, ClassVar, Literal, cast
from typing import Any, ClassVar, cast
import aiofiles
import httpx
from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_htmlrender import get_browser
from playwright.async_api import Page
from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Proxy, Response
import nonebot
from rich.progress import (
BarColumn,
DownloadColumn,
@ -18,13 +17,84 @@ from rich.progress import (
TextColumn,
TransferSpeedColumn,
)
import ujson as json
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.exception import AllURIsFailedError
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.user_agent import get_user_agent
CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"]
from .browser import AsyncPlaywright, BrowserIsNone # noqa: F401
_SENTINEL = object()
driver = nonebot.get_driver()
_client: AsyncClient | None = None
@PriorityLifecycle.on_startup(priority=0)
async def _():
    """Initialize the module-global httpx client when the bot starts.

    Selects the proxy keyword argument (``proxy`` vs ``proxies``) based on
    the installed httpx version — presumably because newer httpx (>= 0.28)
    dropped the ``proxies`` parameter; TODO confirm against httpx changelog.
    """
    global _client
    client_kwargs = {}
    if proxy_url := BotConfig.system_proxy or None:
        try:
            # Parse "major.minor" out of httpx.__version__, tolerating
            # suffixes such as "0.28.0b1" by keeping only the digits.
            version_parts = httpx.__version__.split(".")
            major = int("".join(c for c in version_parts[0] if c.isdigit()))
            minor = (
                int("".join(c for c in version_parts[1] if c.isdigit()))
                if len(version_parts) > 1
                else 0
            )
            if (major, minor) >= (0, 28):
                client_kwargs["proxy"] = proxy_url
            else:
                client_kwargs["proxies"] = proxy_url
        except (ValueError, IndexError):
            # Unparseable version string: fall back to the newer syntax.
            client_kwargs["proxy"] = proxy_url
            logger.warning(
                f"无法解析 httpx 版本 '{httpx.__version__}'"
                "将默认使用新版 'proxy' 参数语法。"
            )
    _client = httpx.AsyncClient(
        headers=get_user_agent(),
        follow_redirects=True,
        **client_kwargs,
    )
    logger.info("全局 httpx.AsyncClient 已启动。", "HTTPClient")
@driver.on_shutdown
async def _():
    """Close the module-global httpx client when the bot shuts down."""
    if _client:
        await _client.aclose()
        logger.info("全局 httpx.AsyncClient 已关闭。", "HTTPClient")
def get_client() -> AsyncClient:
    """Return the module-global httpx client.

    Under pytest (detected via the ``PYTEST_CURRENT_TEST`` env var) a
    throwaway client is created lazily so tests can run without the bot's
    startup lifecycle.

    Raises:
        RuntimeError: if the global client was never initialized and we are
            not running under pytest.
    """
    global _client
    if not _client:
        if not os.environ.get("PYTEST_CURRENT_TEST"):
            raise RuntimeError("全局 httpx.AsyncClient 未初始化,请检查启动流程。")
        # Build a temporary client for the test environment only.
        logger.warning("在测试环境中创建临时HTTP客户端", "HTTPClient")
        _client = httpx.AsyncClient(
            headers=get_user_agent(),
            follow_redirects=True,
        )
    return _client
def get_async_client(
@ -33,6 +103,10 @@ def get_async_client(
verify: bool = False,
**kwargs,
) -> httpx.AsyncClient:
"""
[向后兼容] 创建 httpx.AsyncClient 实例的工厂函数
此函数完全保留了旧版本的接口确保现有代码无需修改即可使用
"""
transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
if proxies:
http_proxy = proxies.get("http://")
@ -62,6 +136,30 @@ def get_async_client(
class AsyncHttpx:
"""
一个高级的健壮的异步HTTP客户端工具类
设计理念:
- **全局共享客户端**: 默认情况下所有请求都通过一个在应用启动时初始化的全局
`httpx.AsyncClient` 实例发出这个实例共享连接池提高了效率和性能
- **向后兼容与灵活性**: 完全兼容旧的API同时提供了两种方式来处理需要
特殊网络配置如不同代理超时的请求
1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时直接传入 `proxies`,
`timeout` 等参数将为该次请求创建一个临时的独立的客户端
2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器可以
获取一个独立的可配置的客户端用于执行一系列需要相同特殊配置的请求
- **健壮性**: 内置了自动重试多镜像URL回退fallback机制并提供了便捷的
JSON解析和文件下载方法
"""
CLIENT_KEY: ClassVar[list[str]] = [
"use_proxy",
"proxies",
"proxy",
"verify",
"headers",
]
default_proxy: ClassVar[dict[str, str] | None] = (
{
"http://": BotConfig.system_proxy,
@ -72,155 +170,346 @@ class AsyncHttpx:
)
@classmethod
@asynccontextmanager
async def _create_client(
cls,
*,
use_proxy: bool = True,
proxies: dict[str, str] | None = None,
proxy: str | None = None,
headers: dict[str, str] | None = None,
verify: bool = False,
**kwargs,
) -> AsyncGenerator[httpx.AsyncClient, None]:
"""创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。
def _prepare_temporary_client_config(cls, client_kwargs: dict) -> dict:
"""
[向后兼容] 处理旧式的客户端kwargs将其转换为get_async_client可用的配置
主要负责处理 use_proxy 标志这是为了兼容旧版本代码中使用的 use_proxy 参数
"""
final_config = client_kwargs.copy()
说明:
此方法用于内部统一创建客户端处理代理和请求头逻辑减少代码重复
use_proxy = final_config.pop("use_proxy", True)
if "proxies" not in final_config and "proxy" not in final_config:
final_config["proxies"] = cls.default_proxy if use_proxy else None
return final_config
@classmethod
def _split_kwargs(cls, kwargs: dict) -> tuple[dict, dict]:
    """Partition mixed kwargs into (client-config kwargs, request kwargs).

    A key belongs to the client configuration iff it is listed in
    ``cls.CLIENT_KEY``; everything else is a per-request parameter.
    """
    client_conf: dict = {}
    request_conf: dict = {}
    for key, value in kwargs.items():
        bucket = client_conf if key in cls.CLIENT_KEY else request_conf
        bucket[key] = value
    return client_conf, request_conf
@classmethod
@asynccontextmanager
async def _get_active_client_context(
    cls, client: AsyncClient | None = None, **kwargs
) -> AsyncGenerator[AsyncClient, None]:
    """Yield the HTTP client to use for one operation.

    Any keyword arguments are treated as *client configuration*: when
    present, a temporary client is built from them and closed on exit.
    Otherwise the caller-supplied ``client`` — or, failing that, the shared
    global client — is yielded.
    """
    if kwargs:
        logger.debug(f"为单次请求创建临时客户端,配置: {kwargs}")
        temp_client_config = cls._prepare_temporary_client_config(kwargs)
        async with get_async_client(**temp_client_config) as temp_client:
            yield temp_client
    else:
        yield client or get_client()
@Retry.simple(log_name="内部HTTP请求")
async def _execute_request_inner(
    self, client: AsyncClient, method: str, url: str, **kwargs
) -> Response:
    """Perform a single HTTP request on *client*; wrapped by the retry
    decorator so transient network errors are retried transparently.

    NOTE: instance method (invoked via ``cls()._execute_request_inner`` in
    `_single_request`) so the retry decorator binds per call site.
    """
    return await client.request(method, url, **kwargs)
@classmethod
async def _single_request(
    cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
) -> Response:
    """Execute one HTTP request with built-in retry.

    Splits ``kwargs`` into client configuration and request parameters,
    obtains the appropriate client (global, supplied, or temporary), and
    raises ``HTTPStatusError`` on 4xx/5xx via ``raise_for_status``.
    """
    client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
    async with cls._get_active_client_context(
        client=client, **client_kwargs
    ) as active_client:
        response = await cls()._execute_request_inner(
            active_client, method, url, **request_kwargs
        )
        response.raise_for_status()
        return response
@classmethod
async def _execute_with_fallbacks(
cls,
urls: str | list[str],
worker: Callable[..., Awaitable[Any]],
*,
client: AsyncClient | None = None,
**kwargs,
) -> Any:
"""
通用执行器按顺序尝试多个URL直到成功
参数:
use_proxy: 是否使用在类中定义的默认代理
proxies: 手动指定的代理会覆盖默认代理
proxy: 单个代理,用于兼容旧版本不再使用
headers: 需要合并到客户端的自定义请求头
verify: 是否验证 SSL 证书
**kwargs: 其他所有传递给 httpx.AsyncClient 的参数
返回:
AsyncGenerator[httpx.AsyncClient, None]: 生成器
urls: 单个URL或URL列表
worker: 一个接受单个URL和其他kwargs并执行请求的协程函数
client: 可选的HTTP客户端
**kwargs: 传递给worker的额外参数
"""
proxies_to_use = proxies or (cls.default_proxy if use_proxy else None)
url_list = [urls] if isinstance(urls, str) else urls
exceptions = []
final_headers = get_user_agent()
if headers:
final_headers.update(headers)
for i, url in enumerate(url_list):
try:
result = await worker(url, client=client, **kwargs)
if i > 0:
logger.info(
f"成功从镜像 '{url}' 获取资源 "
f"(在尝试了 {i} 个失败的镜像之后)。",
"AsyncHttpx:FallbackExecutor",
)
return result
except Exception as e:
exceptions.append(e)
if url != url_list[-1]:
logger.warning(
f"Worker '{worker.__name__}' on {url} failed, trying next. "
f"Error: {e.__class__.__name__}",
"AsyncHttpx:FallbackExecutor",
)
async with get_async_client(
proxies=proxies_to_use,
proxy=proxy,
verify=verify,
headers=final_headers,
**kwargs,
) as client:
yield client
raise AllURIsFailedError(url_list, exceptions)
@classmethod
async def get(
cls,
url: str | list[str],
*,
follow_redirects: bool = True,
check_status_code: int | None = None,
client: AsyncClient | None = None,
**kwargs,
) -> Response: # sourcery skip: use-assigned-variable
) -> Response:
"""发送 GET 请求,并返回第一个成功的响应。
说明:
本方法是 httpx.get 的高级包装增加了多链接尝试自动重试和统一的代理管理
如果提供 URL 列表它将依次尝试直到成功为止
本方法是 httpx.get 的高级包装增加了多链接尝试自动重试和统一的
客户端管理如果提供 URL 列表它将依次尝试直到成功为止
用法建议:
- **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端
- **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)`
将为本次请求创建一个独立的临时客户端
参数:
url: 单个请求 URL 或一个 URL 列表
follow_redirects: 是否跟随重定向
check_status_code: (可选) 若提供将检查响应状态码是否匹配否则抛出异常
**kwargs: 其他所有传递给 httpx.get 的参数
( `params`, `headers`, `timeout`)
client: (可选) 指定一个活动的HTTP客户端实例若提供则忽略
`**kwargs`中的客户端配置
**kwargs: 其他所有传递给 httpx.get 的参数 ( `params`, `headers`,
`timeout`)如果包含 `proxies`, `verify` 等客户端配置参数
将创建一个临时客户端
返回:
Response: Response
Response: httpx 的响应对象
Raises:
AllURIsFailedError: 当所有提供的URL都请求失败时抛出
"""
urls = [url] if isinstance(url, str) else url
last_exception = None
for current_url in urls:
try:
logger.info(f"开始获取 {current_url}..")
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
response = await client.get(current_url, **kwargs)
if check_status_code and response.status_code != check_status_code:
raise HTTPStatusError(
f"状态码错误: {response.status_code}!={check_status_code}",
request=response.request,
response=response,
)
return response
except Exception as e:
last_exception = e
if current_url != urls[-1]:
logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e)
async def worker(current_url: str, **worker_kwargs) -> Response:
logger.info(f"开始获取 {current_url}..", "AsyncHttpx:get")
response = await cls._single_request(
"GET", current_url, follow_redirects=follow_redirects, **worker_kwargs
)
if check_status_code and response.status_code != check_status_code:
raise HTTPStatusError(
f"状态码错误: {response.status_code}!={check_status_code}",
request=response.request,
response=response,
)
return response
raise last_exception or Exception("所有URL都获取失败")
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
@classmethod
async def head(cls, url: str, **kwargs) -> Response:
"""发送 HEAD 请求。
async def head(
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
) -> Response:
"""发送 HEAD 请求,并返回第一个成功的响应。"""
说明:
本方法是对 httpx.head 的封装通常用于检查资源的元信息如大小类型
async def worker(current_url: str, **worker_kwargs) -> Response:
return await cls._single_request("HEAD", current_url, **worker_kwargs)
参数:
url: 请求的 URL
**kwargs: 其他所有传递给 httpx.head 的参数
( `headers`, `timeout`, `allow_redirects`)
返回:
Response: Response
"""
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
return await client.head(url, **kwargs)
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
@classmethod
async def post(cls, url: str, **kwargs) -> Response:
"""发送 POST 请求。
async def post(
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
) -> Response:
"""发送 POST 请求,并返回第一个成功的响应。"""
说明:
本方法是对 httpx.post 的封装提供了统一的代理和客户端管理
async def worker(current_url: str, **worker_kwargs) -> Response:
return await cls._single_request("POST", current_url, **worker_kwargs)
参数:
url: 请求的 URL
**kwargs: 其他所有传递给 httpx.post 的参数
( `data`, `json`, `content` )
返回:
Response: Response
"""
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
return await client.post(url, **kwargs)
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
@classmethod
async def get_content(cls, url: str, **kwargs) -> bytes:
"""获取指定 URL 的二进制内容。
说明:
这是一个便捷方法等同于调用 get() 后再访问 .content 属性
参数:
url: 请求的 URL
**kwargs: 所有传递给 get() 方法的参数
返回:
bytes: 响应内容的二进制字节流 (bytes)
"""
res = await cls.get(url, **kwargs)
async def get_content(
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
) -> bytes:
"""获取指定 URL 的二进制内容。"""
res = await cls.get(url, client=client, **kwargs)
return res.content
@classmethod
@Retry.api(
    log_name="JSON请求",
    exception=(json.JSONDecodeError,),
    return_on_failure=_SENTINEL,
)
async def _request_and_parse_json(
    cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
) -> Any:
    """Execute a single HTTP request and parse the response body as JSON.

    Fix: split client-config kwargs from request kwargs BEFORE opening the
    client context (mirroring `_single_request`). Previously the full kwargs
    — including request parameters such as ``json``/``data``/``params`` —
    were forwarded to `_get_active_client_context`, which treats *any*
    kwarg as client configuration: every request carrying request kwargs
    spawned a temporary client and passed invalid arguments (e.g. ``json``)
    to ``httpx.AsyncClient``, raising TypeError.
    """
    client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
    async with cls._get_active_client_context(
        client=client, **client_kwargs
    ) as active_client:
        response = await active_client.request(method, url, **request_kwargs)
        response.raise_for_status()
        return response.json()
@classmethod
async def get_json(
    cls,
    url: str | list[str],
    *,
    default: Any = None,
    raise_on_failure: bool = False,
    client: AsyncClient | None = None,
    **kwargs,
) -> Any:
    """Send a GET request and parse the response as JSON.

    Convenience wrapper combining retries, multi-URL fallback, JSON parsing
    and error handling. On total failure it either returns *default* or, if
    *raise_on_failure* is True, re-raises ``AllURIsFailedError``.

    Args:
        url: a single URL or a list of fallback URLs tried in order.
        default: value returned when every attempt fails.
        raise_on_failure: raise ``AllURIsFailedError`` instead of returning
            *default* when all URLs fail.
        client: optional pre-configured HTTP client to use.
        **kwargs: forwarded to the request (e.g. ``params``, ``headers``,
            ``timeout``).

    Returns:
        The parsed JSON value, or *default* on failure.
    """

    async def worker(current_url: str, **worker_kwargs):
        logger.debug(f"开始GET JSON: {current_url}", "AsyncHttpx:get_json")
        return await cls._request_and_parse_json(
            "GET", current_url, **worker_kwargs
        )

    try:
        result = await cls._execute_with_fallbacks(
            url, worker, client=client, **kwargs
        )
        # _SENTINEL marks "all retries exhausted" from the retry decorator.
        return default if result is _SENTINEL else result
    except AllURIsFailedError as e:
        logger.error(f"所有URL的JSON GET均失败: {e}", "AsyncHttpx:get_json")
        if raise_on_failure:
            raise e
        return default
@classmethod
async def post_json(
    cls,
    url: str | list[str],
    *,
    json: Any = None,
    data: Any = None,
    default: Any = None,
    raise_on_failure: bool = False,
    client: AsyncClient | None = None,
    **kwargs,
) -> Any:
    """Send a POST request and parse the response as JSON.

    Behaves like ``get_json``: retries, multi-URL fallback, and a safe
    default on total failure.

    Args:
        url: a single URL or a list of fallback URLs tried in order.
        json: optional JSON request body.
        data: optional form-data request body.
        default: value returned when every attempt fails.
        raise_on_failure: raise ``AllURIsFailedError`` instead of returning
            *default* when all URLs fail.
        client: optional pre-configured HTTP client to use.
        **kwargs: forwarded to the request.

    Returns:
        The parsed JSON value, or *default* on failure.
    """
    # Fold the explicit body parameters back into kwargs for the worker.
    if json is not None:
        kwargs["json"] = json
    if data is not None:
        kwargs["data"] = data

    async def worker(current_url: str, **worker_kwargs):
        logger.debug(f"开始POST JSON: {current_url}", "AsyncHttpx:post_json")
        return await cls._request_and_parse_json(
            "POST", current_url, **worker_kwargs
        )

    try:
        result = await cls._execute_with_fallbacks(
            url, worker, client=client, **kwargs
        )
        # _SENTINEL marks "all retries exhausted" from the retry decorator.
        return default if result is _SENTINEL else result
    except AllURIsFailedError as e:
        logger.error(f"所有URL的JSON POST均失败: {e}", "AsyncHttpx:post_json")
        if raise_on_failure:
            raise e
        return default
@classmethod
@Retry.api(log_name="文件下载(流式)")
async def _stream_download(
    cls, url: str, path: Path, *, client: AsyncClient | None = None, **kwargs
) -> None:
    """Stream *url* to *path* with a rich progress bar; retried on failure.

    Fix: client-config kwargs (``proxies``/``verify``/``headers``/...) are
    separated from request kwargs before the call. Previously the full
    kwargs were passed both to the client context AND to ``client.stream()``,
    so any client-config key raised TypeError inside httpx's request API.
    """
    client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
    async with cls._get_active_client_context(
        client=client, **client_kwargs
    ) as active_client:
        async with active_client.stream(
            "GET", url, **request_kwargs
        ) as response:
            response.raise_for_status()
            # Content-Length may be absent (chunked responses): total=0
            # simply leaves the progress bar indeterminate.
            total = int(response.headers.get("Content-Length", 0))
            with Progress(
                TextColumn(path.name),
                "[progress.percentage]{task.percentage:>3.0f}%",
                BarColumn(bar_width=None),
                DownloadColumn(),
                TransferSpeedColumn(),
            ) as progress:
                task_id = progress.add_task("Download", total=total)
                async with aiofiles.open(path, "wb") as f:
                    async for chunk in response.aiter_bytes():
                        await f.write(chunk)
                        progress.update(task_id, advance=len(chunk))
@classmethod
async def download_file(
cls,
@ -228,6 +517,7 @@ class AsyncHttpx:
path: str | Path,
*,
stream: bool = False,
client: AsyncClient | None = None,
**kwargs,
) -> bool:
"""下载文件到指定路径。
@ -239,6 +529,7 @@ class AsyncHttpx:
url: 单个文件 URL 或一个备用 URL 列表
path: 文件保存的本地路径
stream: (可选) 是否使用流式下载适用于大文件默认为 False
client: (可选) 指定的HTTP客户端
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数
返回:
@ -247,49 +538,29 @@ class AsyncHttpx:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
urls = [url] if isinstance(url, str) else url
async def worker(current_url: str, **worker_kwargs) -> bool:
if not stream:
content = await cls.get_content(current_url, **worker_kwargs)
async with aiofiles.open(path, "wb") as f:
await f.write(content)
else:
await cls._stream_download(current_url, path, **worker_kwargs)
for current_url in urls:
try:
if not stream:
response = await cls.get(current_url, **kwargs)
response.raise_for_status()
async with aiofiles.open(path, "wb") as f:
await f.write(response.content)
else:
async with cls._create_client(**kwargs) as client:
stream_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["use_proxy", "proxy", "verify"]
}
async with client.stream(
"GET", current_url, **stream_kwargs
) as response:
response.raise_for_status()
total = int(response.headers.get("Content-Length", 0))
logger.info(
f"下载 {current_url} 成功 -> {path.absolute()}",
"AsyncHttpx:download",
)
return True
with Progress(
TextColumn(path.name),
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(bar_width=None),
DownloadColumn(),
TransferSpeedColumn(),
) as progress:
task_id = progress.add_task("Download", total=total)
async with aiofiles.open(path, "wb") as f:
async for chunk in response.aiter_bytes():
await f.write(chunk)
progress.update(task_id, advance=len(chunk))
logger.info(f"下载 {current_url} 成功 -> {path.absolute()}")
return True
except Exception as e:
logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}")
logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}")
return False
try:
return await cls._execute_with_fallbacks(
url, worker, client=client, **kwargs
)
except AllURIsFailedError:
logger.error(
f"所有URL下载均失败 -> {path.absolute()}", "AsyncHttpx:download"
)
return False
@classmethod
async def gather_download_file(
@ -346,7 +617,6 @@ class AsyncHttpx:
logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result)
final_results.append(False)
else:
# download_file 返回的是 bool可以直接附加
final_results.append(cast(bool, result))
return final_results
@ -395,86 +665,30 @@ class AsyncHttpx:
_results = sorted(iter(_results), key=lambda r: r["elapsed_time"])
return [result["url"] for result in _results]
class AsyncPlaywright:
@classmethod
@asynccontextmanager
async def new_page(
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
) -> AsyncGenerator[Page, None]:
"""获取一个新页面
async def temporary_client(cls, **kwargs) -> AsyncGenerator[AsyncClient, None]:
"""
创建一个临时的可配置的HTTP客户端上下文并直接返回该客户端实例
此方法返回一个标准的 `httpx.AsyncClient`它不使用全局连接池
拥有独立的配置(如代理headers超时等)并在退出上下文后自动关闭
适用于需要用一套特殊网络配置执行一系列请求的场景
用法:
async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client:
# client 是一个标准的 httpx.AsyncClient 实例
response1 = await client.get("http://some.internal.api/1")
response2 = await client.get("http://some.internal.api/2")
data = response2.json()
参数:
cookies: cookies
**kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数
例如: `proxies`, `headers`, `verify`, `timeout`,
`follow_redirects`
Yields:
httpx.AsyncClient: 一个配置好的临时的客户端实例
"""
browser = await get_browser()
ctx = await browser.new_context(**kwargs)
if cookies:
if isinstance(cookies, dict):
cookies = [cookies]
await ctx.add_cookies(cookies) # type: ignore
page = await ctx.new_page()
try:
yield page
finally:
await page.close()
await ctx.close()
    @classmethod
    async def screenshot(
        cls,
        url: str,
        path: Path | str,
        element: str | list[str],
        *,
        wait_time: int | None = None,
        viewport_size: dict[str, int] | None = None,
        wait_until: (
            Literal["domcontentloaded", "load", "networkidle"] | None
        ) = "networkidle",
        timeout: float | None = None,
        type_: Literal["jpeg", "png"] | None = None,
        user_agent: str | None = None,
        cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
        **kwargs,
    ) -> UniMessage | None:
        """Take a quick screenshot of an element on a page.

        This helper is only meant for simple one-shot captures; drive a
        page object directly for anything more complex.

        参数:
            url: page URL to open
            path: file path the screenshot is saved to
            element: a CSS selector, or a chain of selectors resolved one
                inside the previous match
            wait_time: per-selector wait timeout (seconds — converted to ms
                below; presumably seconds, TODO confirm against callers)
            viewport_size: browser viewport size (defaults to 2560x1080)
            wait_until: page load state to wait for on navigation
            timeout: navigation / screenshot timeout
            type_: image format ("jpeg" or "png")
            user_agent: user agent string for the browser context
            cookies: cookies to install into the context

        返回:
            UniMessage wrapping the image path, or None if any selector in
            the chain did not match.
        """
        if viewport_size is None:
            viewport_size = {"width": 2560, "height": 1080}
        if isinstance(path, str):
            path = Path(path)
        # Callers pass seconds; Playwright timeouts are in milliseconds.
        wait_time = wait_time * 1000 if wait_time else None
        # Normalise a single selector into a one-element chain.
        element_list = [element] if isinstance(element, str) else element
        async with cls.new_page(
            cookies,
            viewport=viewport_size,
            user_agent=user_agent,
            **kwargs,
        ) as page:
            await page.goto(url, timeout=timeout, wait_until=wait_until)
            card = page
            # Resolve each selector inside the previously matched node.
            for e in element_list:
                if not card:
                    return None
                card = await card.wait_for_selector(e, timeout=wait_time)
            if card:
                await card.screenshot(path=path, timeout=timeout, type=type_)
                return MessageUtils.build_message(path)
        return None
class BrowserIsNone(Exception):
    """Raised when a browser instance could not be obtained."""

    pass
async with get_async_client(**kwargs) as client:
yield client

View File

@ -0,0 +1,810 @@
import asyncio
from collections.abc import Callable, Coroutine
import copy
import inspect
import random
from typing import ClassVar
import nonebot
from nonebot import get_bots
from nonebot_plugin_apscheduler import scheduler
from pydantic import BaseModel, ValidationError
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
class SchedulerManager:
    """
    A generic, persistent scheduled-task manager shared by all plugins.

    Tasks are registered in-process via :meth:`register`, persisted as
    ``ScheduleInfo`` rows, and executed through APScheduler jobs whose ids
    are derived from the database primary key.
    """

    # plugin_name -> {"func": task coroutine, "model": optional Pydantic model}
    _registered_tasks: ClassVar[
        dict[str, dict[str, Callable | type[BaseModel] | None]]
    ] = {}
    # Prefix used to build APScheduler job ids from database schedule ids.
    _JOB_PREFIX = "zhenxun_schedule_"
    # Strong references to in-flight asyncio tasks so they are not GC'd.
    _running_tasks: ClassVar[set] = set()

    def register(
        self, plugin_name: str, params_model: type[BaseModel] | None = None
    ) -> Callable:
        """
        Register a schedulable task function.

        The decorated function is expected to have the signature
        ``async def func(group_id: str | None, **kwargs)``.

        Args:
            plugin_name (str): Unique plugin name (usually the module name).
            params_model (type[BaseModel], optional): A Pydantic BaseModel
                used to define and validate the extra parameters accepted
                by the task function.
        """

        def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
            # Re-registration overwrites the previous entry; warn so the
            # duplicate is visible in the logs.
            if plugin_name in self._registered_tasks:
                logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
            self._registered_tasks[plugin_name] = {
                "func": func,
                "model": params_model,
            }
            model_name = params_model.__name__ if params_model else ""
            logger.debug(
                f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
            )
            return func

        return decorator

    def get_registered_plugins(self) -> list[str]:
        """Return the names of all plugins that registered a scheduled task."""
        return list(self._registered_tasks.keys())

    def _get_job_id(self, schedule_id: int) -> str:
        """Build the unique APScheduler job id for a database schedule id."""
        return f"{self._JOB_PREFIX}{schedule_id}"

    async def _execute_job(self, schedule_id: int):
        """
        Entry point invoked by APScheduler.

        Loads the schedule by id and dispatches it either to a single
        target or, for ``__ALL_GROUPS__`` schedules, to every group.
        """
        schedule = await ScheduleInfo.get_or_none(id=schedule_id)
        if not schedule or not schedule.is_enabled:
            logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
            return

        plugin_name = schedule.plugin_name
        task_meta = self._registered_tasks.get(plugin_name)
        if not task_meta:
            # The plugin was unloaded since the schedule was created:
            # disable the row and drop the APScheduler job.
            logger.error(
                f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
            )
            schedule.is_enabled = False
            await schedule.save(update_fields=["is_enabled"])
            self._remove_aps_job(schedule.id)
            return

        try:
            if schedule.bot_id:
                bot = nonebot.get_bot(schedule.bot_id)
            else:
                # No bot pinned to this schedule: fall back to any online bot.
                bot = nonebot.get_bot()
                logger.debug(
                    f"任务 {schedule_id} 未关联特定Bot使用默认Bot {bot.self_id}"
                )
        except KeyError:
            # The pinned bot is currently offline; skip this run only.
            logger.warning(
                f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
                f"不在线,本次执行跳过。"
            )
            return
        except ValueError:
            logger.warning(f"当前没有Bot在线定时任务 {schedule_id} 跳过。")
            return

        if schedule.group_id == "__ALL_GROUPS__":
            await self._execute_for_all_groups(schedule, task_meta, bot)
        else:
            await self._execute_for_single_target(schedule, task_meta, bot)

    async def _execute_for_all_groups(
        self, schedule: ScheduleInfo, task_meta: dict, bot
    ):
        """Run the task for every group of *bot*, honouring per-group overrides.

        Groups that have their own specific schedule for the same plugin are
        skipped, so a group-level schedule takes priority over the
        ``__ALL_GROUPS__`` one.
        """
        plugin_name = schedule.plugin_name
        concurrency_limit = Config.get_config(
            "SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
        )
        if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
            logger.warning(
                f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
            )
            concurrency_limit = 5

        logger.info(
            f"开始执行针对 [所有群组] 的任务 "
            f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})"
            f"并发限制: {concurrency_limit}"
        )
        all_gids = set()
        try:
            group_list, _ = await PlatformUtils.get_group_list(bot)
            # Channels (guild sub-channels) are excluded on purpose.
            all_gids.update(
                g.group_id for g in group_list if g.group_id and not g.channel_id
            )
        except Exception as e:
            logger.error(f"'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
            return

        # Groups that already have a dedicated schedule for this plugin.
        specific_tasks_gids = set(
            await ScheduleInfo.filter(
                plugin_name=plugin_name, group_id__in=list(all_gids)
            ).values_list("group_id", flat=True)
        )
        semaphore = asyncio.Semaphore(concurrency_limit)

        async def worker(gid: str):
            """Run one group's task under the shared concurrency semaphore."""
            async with semaphore:
                # Deep-copy so the per-group group_id mutation does not leak
                # into the shared ScheduleInfo instance.
                temp_schedule = copy.deepcopy(schedule)
                temp_schedule.group_id = gid
                await self._execute_for_single_target(temp_schedule, task_meta, bot)
                # Small jitter to avoid hammering the platform API.
                await asyncio.sleep(random.uniform(0.1, 0.5))

        tasks_to_run = []
        for gid in all_gids:
            if gid in specific_tasks_gids:
                logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
                continue
            tasks_to_run.append(worker(gid))

        if tasks_to_run:
            await asyncio.gather(*tasks_to_run)

    async def _execute_for_single_target(
        self, schedule: ScheduleInfo, task_meta: dict, bot
    ):
        """Run the task once for one concrete target (a group or globally)."""
        plugin_name = schedule.plugin_name
        group_id = schedule.group_id
        try:
            # Respect the per-group/global feature block list.
            is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
            if is_blocked:
                target_desc = f"群 {group_id}" if group_id else "全局"
                logger.info(
                    f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
                    "因功能被禁用而跳过执行。"
                )
                return

            task_func = task_meta["func"]
            job_kwargs = schedule.job_kwargs
            if not isinstance(job_kwargs, dict):
                logger.error(
                    f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
                )
                return

            # Inject the bot only if the task function declares a `bot` param.
            sig = inspect.signature(task_func)
            if "bot" in sig.parameters:
                job_kwargs["bot"] = bot

            logger.info(
                f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
                f"执行定时任务 (ID: {schedule.id})。"
            )
            # Keep a strong reference while the task runs (see _running_tasks).
            task = asyncio.create_task(task_func(group_id, **job_kwargs))
            self._running_tasks.add(task)
            task.add_done_callback(self._running_tasks.discard)
            await task
        except Exception as e:
            logger.error(
                f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
                f"目标: {group_id or '全局'}) 时发生异常",
                e=e,
            )

    def _validate_and_prepare_kwargs(
        self, plugin_name: str, job_kwargs: dict | None
    ) -> tuple[bool, str | dict]:
        """Validate task kwargs against the plugin's model, applying defaults.

        Returns:
            (True, validated kwargs dict) on success,
            (False, error message) on failure.
        """
        task_meta = self._registered_tasks.get(plugin_name)
        if not task_meta:
            return False, f"插件 '{plugin_name}' 未注册。"

        params_model = task_meta.get("model")
        job_kwargs = job_kwargs if job_kwargs is not None else {}

        if not params_model:
            # No model declared: pass kwargs through, but warn when some
            # were supplied since they cannot be validated.
            if job_kwargs:
                logger.warning(
                    f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
                )
            return True, job_kwargs

        if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
            logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
            return False, f"插件 '{plugin_name}' 的参数模型配置错误"

        try:
            # Accessed via getattr to tolerate Pydantic v1 models that lack
            # the v2 model_validate/model_dump API.
            model_validate = getattr(params_model, "model_validate", None)
            if not model_validate:
                return False, f"插件 '{plugin_name}' 的参数模型不支持验证"

            validated_model = model_validate(job_kwargs)

            model_dump = getattr(validated_model, "model_dump", None)
            if not model_dump:
                return False, f"插件 '{plugin_name}' 的参数模型不支持导出"

            return True, model_dump()
        except ValidationError as e:
            errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
            error_str = "\n".join(errors)
            msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
            return False, msg

    def _add_aps_job(self, schedule: ScheduleInfo):
        """Add or replace the APScheduler job backing a ScheduleInfo row."""
        job_id = self._get_job_id(schedule.id)
        # Remove any stale job with the same id first; absence is fine.
        try:
            scheduler.remove_job(job_id)
        except Exception:
            pass

        if not isinstance(schedule.trigger_config, dict):
            logger.error(
                f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
                f"{type(schedule.trigger_config)}"
            )
            return

        scheduler.add_job(
            self._execute_job,
            trigger=schedule.trigger_type,
            id=job_id,
            misfire_grace_time=300,
            args=[schedule.id],
            **schedule.trigger_config,
        )
        logger.debug(
            f"已在 APScheduler 中添加/更新任务: {job_id} "
            f"with trigger: {schedule.trigger_config}"
        )

    def _remove_aps_job(self, schedule_id: int):
        """Remove the APScheduler job for a schedule id, ignoring absence."""
        job_id = self._get_job_id(schedule_id)
        try:
            scheduler.remove_job(job_id)
            logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
        except Exception:
            pass

    async def add_schedule(
        self,
        plugin_name: str,
        group_id: str | None,
        trigger_type: str,
        trigger_config: dict,
        job_kwargs: dict | None = None,
        bot_id: str | None = None,
    ) -> tuple[bool, str]:
        """
        Add or update a scheduled task.

        Args:
            plugin_name: registered plugin name.
            group_id: target group id, ``None`` for global, or
                ``"__ALL_GROUPS__"`` for all groups.
            trigger_type: APScheduler trigger type (e.g. "cron", "interval").
            trigger_config: kwargs forwarded to the APScheduler trigger.
            job_kwargs: extra kwargs for the task function (validated).
            bot_id: only honoured for ``__ALL_GROUPS__`` schedules.

        Returns:
            (success flag, human-readable message).
        """
        if plugin_name not in self._registered_tasks:
            return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"

        is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
        if not is_valid:
            return False, str(result)
        validated_job_kwargs = result

        # A bot binding only makes sense for "all groups" schedules.
        effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None

        search_kwargs = {
            "plugin_name": plugin_name,
            "group_id": group_id,
        }
        if effective_bot_id:
            search_kwargs["bot_id"] = effective_bot_id
        else:
            search_kwargs["bot_id__isnull"] = True

        defaults = {
            "trigger_type": trigger_type,
            "trigger_config": trigger_config,
            "job_kwargs": validated_job_kwargs,
            "is_enabled": True,
        }

        schedule = await ScheduleInfo.filter(**search_kwargs).first()
        created = False
        if schedule:
            # Update-in-place semantics for an existing schedule.
            for key, value in defaults.items():
                setattr(schedule, key, value)
            await schedule.save()
        else:
            creation_kwargs = {
                "plugin_name": plugin_name,
                "group_id": group_id,
                "bot_id": effective_bot_id,
                **defaults,
            }
            schedule = await ScheduleInfo.create(**creation_kwargs)
            created = True

        self._add_aps_job(schedule)
        action = "设置" if created else "更新"
        return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"

    async def add_schedule_for_all(
        self,
        plugin_name: str,
        trigger_type: str,
        trigger_config: dict,
        job_kwargs: dict | None = None,
    ) -> tuple[int, int]:
        """Add the schedule to every group any connected bot is in.

        Returns:
            (number of groups added successfully, number of failures).

        Raises:
            ValueError: if the plugin registered no scheduled task.
        """
        if plugin_name not in self._registered_tasks:
            raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")

        groups = set()
        for bot in get_bots().values():
            try:
                group_list, _ = await PlatformUtils.get_group_list(bot)
                groups.update(
                    g.group_id for g in group_list if g.group_id and not g.channel_id
                )
            except Exception as e:
                logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)

        success_count = 0
        fail_count = 0
        for gid in groups:
            try:
                success, _ = await self.add_schedule(
                    plugin_name, gid, trigger_type, trigger_config, job_kwargs
                )
                if success:
                    success_count += 1
                else:
                    fail_count += 1
            except Exception as e:
                logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
                fail_count += 1
            # Yield to the loop so a large group list does not block it.
            await asyncio.sleep(0.05)

        return success_count, fail_count

    async def update_schedule(
        self,
        schedule_id: int,
        trigger_type: str | None = None,
        trigger_config: dict | None = None,
        job_kwargs: dict | None = None,
    ) -> tuple[bool, str]:
        """Partially update an existing schedule; ``None`` fields are kept."""
        schedule = await self.get_schedule_by_id(schedule_id)
        if not schedule:
            return False, f"未找到 ID 为 {schedule_id} 的任务。"

        updated_fields = []
        if trigger_config is not None:
            schedule.trigger_config = trigger_config
            updated_fields.append("trigger_config")
        if trigger_type is not None and schedule.trigger_type != trigger_type:
            schedule.trigger_type = trigger_type
            updated_fields.append("trigger_type")

        if job_kwargs is not None:
            if not isinstance(schedule.job_kwargs, dict):
                return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
            # Merge new kwargs over the stored ones, then re-validate the
            # whole set against the plugin's parameter model.
            merged_kwargs = schedule.job_kwargs.copy()
            merged_kwargs.update(job_kwargs)
            is_valid, result = self._validate_and_prepare_kwargs(
                schedule.plugin_name, merged_kwargs
            )
            if not is_valid:
                return False, str(result)
            schedule.job_kwargs = result  # type: ignore
            updated_fields.append("job_kwargs")

        if not updated_fields:
            return True, "没有任何需要更新的配置。"

        await schedule.save(update_fields=updated_fields)
        # Re-create the APScheduler job so the new trigger takes effect.
        self._add_aps_job(schedule)
        return True, f"成功更新了任务 ID: {schedule_id} 的配置。"

    async def remove_schedule(
        self, plugin_name: str, group_id: str | None, bot_id: str | None = None
    ) -> tuple[bool, str]:
        """Remove all schedules for a plugin/group (optionally bot) combination."""
        query = {"plugin_name": plugin_name, "group_id": group_id}
        if bot_id:
            query["bot_id"] = bot_id

        schedules = await ScheduleInfo.filter(**query)
        if not schedules:
            msg = (
                f"未找到与 Bot {bot_id} 相关的群 {group_id} "
                f"的插件 '{plugin_name}' 定时任务。"
            )
            return (False, msg)

        for schedule in schedules:
            self._remove_aps_job(schedule.id)
            await schedule.delete()

        target_desc = f"群 {group_id}" if group_id else "全局"
        msg = (
            f"已取消 Bot {bot_id} 在 [{target_desc}] "
            f"的插件 '{plugin_name}' 所有定时任务。"
        )
        return (True, msg)

    async def remove_schedule_for_all(
        self, plugin_name: str, bot_id: str | None = None
    ) -> int:
        """Remove the plugin's schedules in every group; returns the count."""
        query = {"plugin_name": plugin_name}
        if bot_id:
            query["bot_id"] = bot_id

        schedules_to_delete = await ScheduleInfo.filter(**query).all()
        if not schedules_to_delete:
            return 0

        for schedule in schedules_to_delete:
            self._remove_aps_job(schedule.id)
            await schedule.delete()
            await asyncio.sleep(0.01)

        return len(schedules_to_delete)

    async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
        """Remove every schedule belonging to one group."""
        schedules = await ScheduleInfo.filter(group_id=group_id)
        if not schedules:
            return False, f"群 {group_id} 没有任何定时任务。"

        count = 0
        for schedule in schedules:
            self._remove_aps_job(schedule.id)
            await schedule.delete()
            count += 1
            await asyncio.sleep(0.01)

        return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。"

    async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
        """Pause every enabled schedule of one group; returns (count, message)."""
        schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
        if not schedules:
            return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。"

        count = 0
        for schedule in schedules:
            success, _ = await self.pause_schedule(schedule.id)
            if success:
                count += 1
            await asyncio.sleep(0.01)

        return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。"

    async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
        """Resume every paused schedule of one group; returns (count, message)."""
        schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
        if not schedules:
            return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。"

        count = 0
        for schedule in schedules:
            success, _ = await self.resume_schedule(schedule.id)
            if success:
                count += 1
            await asyncio.sleep(0.01)

        return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。"

    async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
        """Pause a plugin's enabled schedules across all groups."""
        schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
        if not schedules:
            return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"

        count = 0
        for schedule in schedules:
            success, _ = await self.pause_schedule(schedule.id)
            if success:
                count += 1
            await asyncio.sleep(0.01)

        return (
            count,
            f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
        )

    async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
        """Resume a plugin's paused schedules across all groups."""
        schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
        if not schedules:
            return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"

        count = 0
        for schedule in schedules:
            success, _ = await self.resume_schedule(schedule.id)
            if success:
                count += 1
            await asyncio.sleep(0.01)

        return (
            count,
            f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
        )

    async def pause_schedule_by_plugin_group(
        self, plugin_name: str, group_id: str | None, bot_id: str | None = None
    ) -> tuple[bool, str]:
        """Pause a plugin's enabled schedules in one group."""
        query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
        if bot_id:
            query["bot_id"] = bot_id

        schedules = await ScheduleInfo.filter(**query)
        if not schedules:
            return (
                False,
                f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
            )

        count = 0
        for schedule in schedules:
            success, _ = await self.pause_schedule(schedule.id)
            if success:
                count += 1

        return (
            True,
            f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 的 {count} 个定时任务。",
        )

    async def resume_schedule_by_plugin_group(
        self, plugin_name: str, group_id: str | None, bot_id: str | None = None
    ) -> tuple[bool, str]:
        """Resume a plugin's paused schedules in one group."""
        query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
        if bot_id:
            query["bot_id"] = bot_id

        schedules = await ScheduleInfo.filter(**query)
        if not schedules:
            return (
                False,
                f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
            )

        count = 0
        for schedule in schedules:
            success, _ = await self.resume_schedule(schedule.id)
            if success:
                count += 1

        return (
            True,
            f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 的 {count} 个定时任务。",
        )

    async def remove_all_schedules(self) -> tuple[int, str]:
        """Remove every schedule in every group; returns (count, message)."""
        schedules = await ScheduleInfo.all()
        if not schedules:
            return 0, "当前没有任何定时任务。"

        count = 0
        for schedule in schedules:
            self._remove_aps_job(schedule.id)
            await schedule.delete()
            count += 1
            await asyncio.sleep(0.01)

        return count, f"已成功移除所有群组的 {count} 个定时任务。"

    async def pause_all_schedules(self) -> tuple[int, str]:
        """Pause every enabled schedule; returns (count, message)."""
        schedules = await ScheduleInfo.filter(is_enabled=True)
        if not schedules:
            return 0, "当前没有正在运行的定时任务可暂停。"

        count = 0
        for schedule in schedules:
            success, _ = await self.pause_schedule(schedule.id)
            if success:
                count += 1
            await asyncio.sleep(0.01)

        return count, f"已成功暂停所有群组的 {count} 个定时任务。"

    async def resume_all_schedules(self) -> tuple[int, str]:
        """Resume every paused schedule; returns (count, message)."""
        schedules = await ScheduleInfo.filter(is_enabled=False)
        if not schedules:
            return 0, "当前没有已暂停的定时任务可恢复。"

        count = 0
        for schedule in schedules:
            success, _ = await self.resume_schedule(schedule.id)
            if success:
                count += 1
            await asyncio.sleep(0.01)

        return count, f"已成功恢复所有群组的 {count} 个定时任务。"

    async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
        """Remove one schedule identified by its database id."""
        schedule = await self.get_schedule_by_id(schedule_id)
        if not schedule:
            return False, f"未找到 ID 为 {schedule_id} 的定时任务。"

        self._remove_aps_job(schedule.id)
        await schedule.delete()
        return (
            True,
            f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
            f"的定时任务 (ID: {schedule.id})。",
        )

    async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
        """Fetch one schedule by database id, or ``None`` if absent."""
        return await ScheduleInfo.get_or_none(id=schedule_id)

    async def get_schedules(
        self, plugin_name: str, group_id: str | None
    ) -> list[ScheduleInfo]:
        """Fetch all schedules for one plugin within one group."""
        return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)

    async def get_schedule(
        self, plugin_name: str, group_id: str | None
    ) -> ScheduleInfo | None:
        """Fetch the single schedule for one plugin within one group."""
        return await ScheduleInfo.get_or_none(
            plugin_name=plugin_name, group_id=group_id
        )

    async def get_all_schedules(
        self, plugin_name: str | None = None
    ) -> list[ScheduleInfo]:
        """Fetch all schedules, optionally filtered by plugin name."""
        if plugin_name:
            return await ScheduleInfo.filter(plugin_name=plugin_name).all()
        return await ScheduleInfo.all()

    async def get_schedule_status(self, schedule_id: int) -> dict | None:
        """Return a status dict for one schedule, or ``None`` if absent.

        Combines the database row with live APScheduler state
        (next run time, paused flag).
        """
        schedule = await self.get_schedule_by_id(schedule_id)
        if not schedule:
            return None

        job_id = self._get_job_id(schedule.id)
        job = scheduler.get_job(job_id)
        status = {
            "id": schedule.id,
            "bot_id": schedule.bot_id,
            "plugin_name": schedule.plugin_name,
            "group_id": schedule.group_id,
            "is_enabled": schedule.is_enabled,
            "trigger_type": schedule.trigger_type,
            "trigger_config": schedule.trigger_config,
            "job_kwargs": schedule.job_kwargs,
            "next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
            if job and job.next_run_time
            else "N/A",
            # A job with no next_run_time is paused inside APScheduler.
            "is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
        }
        return status

    async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
        """Pause one schedule (DB flag + APScheduler job)."""
        schedule = await self.get_schedule_by_id(schedule_id)
        if not schedule or not schedule.is_enabled:
            return False, "任务不存在或已暂停。"

        schedule.is_enabled = False
        await schedule.save(update_fields=["is_enabled"])

        job_id = self._get_job_id(schedule.id)
        try:
            scheduler.pause_job(job_id)
        except Exception:
            # Job may not exist in the scheduler; the DB flag still rules.
            pass

        return (
            True,
            f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
            f"的定时任务 (ID: {schedule.id})。",
        )

    async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
        """Resume one schedule (DB flag + APScheduler job)."""
        schedule = await self.get_schedule_by_id(schedule_id)
        if not schedule or schedule.is_enabled:
            return False, "任务不存在或已启用。"

        schedule.is_enabled = True
        await schedule.save(update_fields=["is_enabled"])

        job_id = self._get_job_id(schedule.id)
        try:
            scheduler.resume_job(job_id)
        except Exception:
            # The job vanished from the scheduler: re-create it from the row.
            self._add_aps_job(schedule)

        return (
            True,
            f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
            f"的定时任务 (ID: {schedule.id})。",
        )

    async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
        """Manually trigger one schedule immediately."""
        schedule = await self.get_schedule_by_id(schedule_id)
        if not schedule:
            return False, f"未找到 ID 为 {schedule_id} 的定时任务。"

        if schedule.plugin_name not in self._registered_tasks:
            return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"

        try:
            await self._execute_job(schedule.id)
            return (
                True,
                f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
                f"的定时任务 (ID: {schedule.id})。",
            )
        except Exception as e:
            logger.error(f"手动触发任务失败: {e}")
            return False, f"手动触发任务失败: {e}"
# Module-level singleton; plugins import this to register and manage schedules.
scheduler_manager = SchedulerManager()
@PriorityLifecycle.on_startup(priority=90)
async def _load_schedules_from_db():
    """Load every enabled schedule from the database at startup and
    register it with APScheduler; schedules whose plugin is no longer
    registered are skipped with a warning."""
    # Declare the concurrency-limit option so it appears in the config file.
    Config.add_plugin_config(
        "SchedulerManager",
        SCHEDULE_CONCURRENCY_KEY,
        5,
        help="“所有群组”类型定时任务的并发执行数量限制",
        type=int,
    )

    logger.info("正在从数据库加载并调度所有定时任务...")
    schedules = await ScheduleInfo.filter(is_enabled=True).all()
    count = 0
    for schedule in schedules:
        if schedule.plugin_name in scheduler_manager._registered_tasks:
            scheduler_manager._add_aps_job(schedule)
            count += 1
        else:
            logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
    logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")

View File

@ -227,7 +227,7 @@ class PlatformUtils:
url = None
if platform == "qq":
if user_id.isdigit():
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160"
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
else:
url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640"
return await AsyncHttpx.get_content(url) if url else None