Merge branch 'main' into feature/fix-config

This commit is contained in:
HibiKier 2025-07-11 10:09:04 +08:00 committed by GitHub
commit 1992e478b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
56 changed files with 5131 additions and 5340 deletions

15
.gitignore vendored
View File

@ -139,22 +139,9 @@ dmypy.json
# Cython debug symbols
cython_debug/
demo.py
test.py
server_ip.py
member_activity_handle.py
Yu-Gi-Oh/
csgo/
fantasy_card/
data/
log/
backup/
extensive_plugin/
test/
bot.py
.idea/
resources/
/configs/config.py
configs/config.yaml
.vscode/launch.json
plugins_/
.vscode/launch.json

File diff suppressed because it is too large Load Diff

View File

@ -116,6 +116,7 @@ async def app(app: App, tmp_path: Path, mocker: MockerFixture):
await init()
# await driver._lifespan.startup()
os.environ["AIOCACHE_DISABLE"] = "1"
os.environ["PYTEST_CURRENT_TEST"] = "1"
yield app

View File

@ -1,7 +1,7 @@
from nonebot.adapters import Bot
from nonebot.plugin import PluginMetadata
from nonebot_plugin_alconna import AlconnaQuery, Arparma, Match, Query
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
@ -9,7 +9,7 @@ from zhenxun.services.log import logger
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.utils.message import MessageUtils
from ._data_source import PluginManage, build_plugin, build_task, delete_help_image
from ._data_source import PluginManager, build_plugin, build_task, delete_help_image
from .command import _group_status_matcher, _status_matcher
base_config = Config.get("plugin_switch")
@ -57,6 +57,11 @@ __plugin_meta__ = PluginMetadata(
关闭群被动早晚安
关闭群被动早晚安 -g 12355555
开启/关闭默认群被动 [被动名称]
私聊下: 开启/关闭群被动默认状态
示例:
关闭默认群被动 早晚安
开启/关闭所有群被动 ?[-g [group_id]]
私聊中: 开启/关闭全局或指定群组被动状态
示例:
@ -87,10 +92,10 @@ __plugin_meta__ = PluginMetadata(
@_status_matcher.assign("$main")
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
arparma: Arparma,
):
if session.id1 in bot.config.superusers:
if session.user.id in bot.config.superusers:
image = await build_plugin()
logger.info(
"查看功能列表",
@ -105,7 +110,7 @@ async def _(
@_status_matcher.assign("open")
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
arparma: Arparma,
plugin_name: Match[str],
group: Match[str],
@ -114,22 +119,23 @@ async def _(
all: Query[bool] = AlconnaQuery("all.value", False),
):
if not all.result and not plugin_name.available:
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
name = plugin_name.result
if gid := session.id3 or session.id2:
if session.group:
group_id = session.group.id
"""修改当前群组的数据"""
if task.result:
if all.result:
result = await PluginManage.unblock_group_all_task(gid)
result = await PluginManager.unblock_group_all_task(group_id)
logger.info("开启所有群组被动", arparma.header_result, session=session)
else:
result = await PluginManage.unblock_group_task(name, gid)
result = await PluginManager.unblock_group_task(name, group_id)
logger.info(
f"开启群组被动 {name}", arparma.header_result, session=session
)
elif session.id1 in bot.config.superusers and default_status.result:
elif session.user.id in bot.config.superusers and default_status.result:
"""单个插件的进群默认修改"""
result = await PluginManage.set_default_status(name, True)
result = await PluginManager.set_default_status(name, True)
logger.info(
f"超级用户开启 {name} 功能进群默认开关",
arparma.header_result,
@ -137,8 +143,8 @@ async def _(
)
elif all.result:
"""所有插件"""
result = await PluginManage.set_all_plugin_status(
True, default_status.result, gid
result = await PluginManager.set_all_plugin_status(
True, default_status.result, group_id
)
logger.info(
"开启群组中全部功能",
@ -146,22 +152,24 @@ async def _(
session=session,
)
else:
result = await PluginManage.unblock_group_plugin(name, gid)
result = await PluginManager.unblock_group_plugin(name, group_id)
logger.info(f"开启功能 {name}", arparma.header_result, session=session)
delete_help_image(gid)
delete_help_image(group_id)
await MessageUtils.build_message(result).finish(reply_to=True)
elif session.id1 in bot.config.superusers:
elif session.user.id in bot.config.superusers:
"""私聊"""
group_id = group.result if group.available else None
if all.result:
if task.result:
"""关闭全局或指定群全部被动"""
if group_id:
result = await PluginManage.unblock_group_all_task(group_id)
result = await PluginManager.unblock_group_all_task(group_id)
else:
result = await PluginManage.unblock_global_all_task()
result = await PluginManager.unblock_global_all_task(
default_status.result
)
else:
result = await PluginManage.set_all_plugin_status(
result = await PluginManager.set_all_plugin_status(
True, default_status.result, group_id
)
logger.info(
@ -171,8 +179,8 @@ async def _(
session=session,
)
await MessageUtils.build_message(result).finish(reply_to=True)
if default_status.result:
result = await PluginManage.set_default_status(name, True)
if default_status.result and not task.result:
result = await PluginManager.set_default_status(name, True)
logger.info(
f"超级用户开启 {name} 功能进群默认开关",
arparma.header_result,
@ -186,7 +194,7 @@ async def _(
name = split_list[0]
group_id = split_list[1]
if group_id:
result = await PluginManage.superuser_task_handle(name, group_id, True)
result = await PluginManager.superuser_task_handle(name, group_id, True)
logger.info(
f"超级用户开启被动技能 {name}",
arparma.header_result,
@ -194,14 +202,16 @@ async def _(
target=group_id,
)
else:
result = await PluginManage.unblock_global_task(name)
result = await PluginManager.unblock_global_task(
name, default_status.result
)
logger.info(
f"超级用户开启全局被动技能 {name}",
arparma.header_result,
session=session,
)
else:
result = await PluginManage.superuser_unblock(name, None, group_id)
result = await PluginManager.superuser_unblock(name, None, group_id)
logger.info(
f"超级用户开启功能 {name}",
arparma.header_result,
@ -215,7 +225,7 @@ async def _(
@_status_matcher.assign("close")
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
arparma: Arparma,
plugin_name: Match[str],
block_type: Match[str],
@ -225,22 +235,23 @@ async def _(
all: Query[bool] = AlconnaQuery("all.value", False),
):
if not all.result and not plugin_name.available:
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
name = plugin_name.result
if gid := session.id3 or session.id2:
if session.group:
group_id = session.group.id
"""修改当前群组的数据"""
if task.result:
if all.result:
result = await PluginManage.block_group_all_task(gid)
result = await PluginManager.block_group_all_task(group_id)
logger.info("开启所有群组被动", arparma.header_result, session=session)
else:
result = await PluginManage.block_group_task(name, gid)
result = await PluginManager.block_group_task(name, group_id)
logger.info(
f"关闭群组被动 {name}", arparma.header_result, session=session
)
elif session.id1 in bot.config.superusers and default_status.result:
elif session.user.id in bot.config.superusers and default_status.result:
"""单个插件的进群默认修改"""
result = await PluginManage.set_default_status(name, False)
result = await PluginManager.set_default_status(name, False)
logger.info(
f"超级用户开启 {name} 功能进群默认开关",
arparma.header_result,
@ -248,26 +259,28 @@ async def _(
)
elif all.result:
"""所有插件"""
result = await PluginManage.set_all_plugin_status(
False, default_status.result, gid
result = await PluginManager.set_all_plugin_status(
False, default_status.result, group_id
)
logger.info("关闭群组中全部功能", arparma.header_result, session=session)
else:
result = await PluginManage.block_group_plugin(name, gid)
result = await PluginManager.block_group_plugin(name, group_id)
logger.info(f"关闭功能 {name}", arparma.header_result, session=session)
delete_help_image(gid)
delete_help_image(group_id)
await MessageUtils.build_message(result).finish(reply_to=True)
elif session.id1 in bot.config.superusers:
elif session.user.id in bot.config.superusers:
group_id = group.result if group.available else None
if all.result:
if task.result:
"""关闭全局或指定群全部被动"""
if group_id:
result = await PluginManage.block_group_all_task(group_id)
result = await PluginManager.block_group_all_task(group_id)
else:
result = await PluginManage.block_global_all_task()
result = await PluginManager.block_global_all_task(
default_status.result
)
else:
result = await PluginManage.set_all_plugin_status(
result = await PluginManager.set_all_plugin_status(
False, default_status.result, group_id
)
logger.info(
@ -277,8 +290,8 @@ async def _(
session=session,
)
await MessageUtils.build_message(result).finish(reply_to=True)
if default_status.result:
result = await PluginManage.set_default_status(name, False)
if default_status.result and not task.result:
result = await PluginManager.set_default_status(name, False)
logger.info(
f"超级用户关闭 {name} 功能进群默认开关",
arparma.header_result,
@ -292,7 +305,9 @@ async def _(
name = split_list[0]
group_id = split_list[1]
if group_id:
result = await PluginManage.superuser_task_handle(name, group_id, False)
result = await PluginManager.superuser_task_handle(
name, group_id, False
)
logger.info(
f"超级用户关闭被动技能 {name}",
arparma.header_result,
@ -300,7 +315,9 @@ async def _(
target=group_id,
)
else:
result = await PluginManage.block_global_task(name)
result = await PluginManager.block_global_task(
name, default_status.result
)
logger.info(
f"超级用户关闭全局被动技能 {name}",
arparma.header_result,
@ -314,7 +331,7 @@ async def _(
elif block_type.result in ["g", "group"]:
if block_type.available:
_type = BlockType.GROUP
result = await PluginManage.superuser_block(name, _type, group_id)
result = await PluginManager.superuser_block(name, _type, group_id)
logger.info(
f"超级用户关闭功能 {name}, 禁用类型: {_type}",
arparma.header_result,
@ -327,19 +344,20 @@ async def _(
@_group_status_matcher.handle()
async def _(
session: EventSession,
session: Uninfo,
arparma: Arparma,
status: str,
):
if gid := session.id3 or session.id2:
if session.group:
group_id = session.group.id
if status == "sleep":
await PluginManage.sleep(gid)
await PluginManager.sleep(group_id)
logger.info("进行休眠", arparma.header_result, session=session)
await MessageUtils.build_message("那我先睡觉了...").finish()
else:
if await PluginManage.is_wake(gid):
if await PluginManager.is_wake(group_id):
await MessageUtils.build_message("我还醒着呢!").finish()
await PluginManage.wake(gid)
await PluginManager.wake(group_id)
logger.info("醒来", arparma.header_result, session=session)
await MessageUtils.build_message("呜..醒来了...").finish()
return MessageUtils.build_message("群组id为空...").send()
@ -347,10 +365,10 @@ async def _(
@_status_matcher.assign("task")
async def _(
session: EventSession,
session: Uninfo,
arparma: Arparma,
):
image = await build_task(session.id3 or session.id2)
image = await build_task(session.group.id if session.group else None)
if image:
logger.info("查看群被动列表", arparma.header_result, session=session)
await MessageUtils.build_message(image).finish(reply_to=True)

View File

@ -155,7 +155,7 @@ async def build_task(group_id: str | None) -> BuildImage:
)
class PluginManage:
class PluginManager:
@classmethod
async def set_default_status(cls, plugin_name: str, status: bool) -> str:
"""设置插件进群默认状态
@ -342,17 +342,21 @@ class PluginManage:
return await cls._change_group_task("", group_id, True, True)
@classmethod
async def block_global_all_task(cls) -> str:
async def block_global_all_task(cls, is_default: bool) -> str:
"""禁用全局被动技能
返回:
str: 返回信息
"""
await TaskInfo.all().update(status=False)
return "已全局禁用所有被动状态"
if is_default:
await TaskInfo.all().update(default_status=False)
return "已禁用所有被动进群默认状态"
else:
await TaskInfo.all().update(status=False)
return "已全局禁用所有被动状态"
@classmethod
async def block_global_task(cls, name: str) -> str:
async def block_global_task(cls, name: str, is_default: bool = False) -> str:
"""禁用全局被动技能
参数:
@ -361,31 +365,47 @@ class PluginManage:
返回:
str: 返回信息
"""
await TaskInfo.filter(name=name).update(status=False)
return f"已全局禁用被动状态 {name}"
if is_default:
await TaskInfo.filter(name=name).update(default_status=False)
return f"已禁用被动进群默认状态 {name}"
else:
await TaskInfo.filter(name=name).update(status=False)
return f"已全局禁用被动状态 {name}"
@classmethod
async def unblock_global_all_task(cls) -> str:
async def unblock_global_all_task(cls, is_default: bool) -> str:
"""开启全局被动技能
参数:
is_default: 是否为默认状态
返回:
str: 返回信息
"""
await TaskInfo.all().update(status=True)
return "已全局开启所有被动状态"
if is_default:
await TaskInfo.all().update(default_status=True)
return "已开启所有被动进群默认状态"
else:
await TaskInfo.all().update(status=True)
return "已全局开启所有被动状态"
@classmethod
async def unblock_global_task(cls, name: str) -> str:
async def unblock_global_task(cls, name: str, is_default: bool = False) -> str:
"""开启全局被动技能
参数:
name: 被动技能名称
is_default: 是否为默认状态
返回:
str: 返回信息
"""
await TaskInfo.filter(name=name).update(status=True)
return f"已全局开启被动状态 {name}"
if is_default:
await TaskInfo.filter(name=name).update(default_status=True)
return f"已开启被动进群默认状态 {name}"
else:
await TaskInfo.filter(name=name).update(status=True)
return f"已全局开启被动状态 {name}"
@classmethod
async def unblock_group_plugin(cls, plugin_name: str, group_id: str) -> str:

View File

@ -58,6 +58,19 @@ _status_matcher.shortcut(
prefix=True,
)
_status_matcher.shortcut(
r"开启(所有|全部)默认群被动",
command="switch",
arguments=["open", "--task", "--all", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"关闭(所有|全部)默认群被动",
command="switch",
arguments=["close", "--task", "--all", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"开启群被动\s*(?P<name>.+)",
@ -73,6 +86,20 @@ _status_matcher.shortcut(
prefix=True,
)
_status_matcher.shortcut(
r"开启默认群被动\s*(?P<name>.+)",
command="switch",
arguments=["open", "{name}", "--task", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"关闭默认群被动\s*(?P<name>.+)",
command="switch",
arguments=["close", "{name}", "--task", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"开启(所有|全部)群被动",

View File

@ -54,22 +54,6 @@ __plugin_meta__ = PluginMetadata(
default_value=5,
type=int,
),
RegisterConfig(
module="_task",
key="DEFAULT_GROUP_WELCOME",
value=True,
help="被动 进群欢迎 进群默认开关状态",
default_value=True,
type=bool,
),
RegisterConfig(
module="_task",
key="DEFAULT_REFUND_GROUP_REMIND",
value=True,
help="被动 退群提醒 进群默认开关状态",
default_value=True,
type=bool,
),
],
tasks=[
Task(

View File

@ -1,9 +1,11 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
from zhenxun.utils.enum import PluginType
from . import command # noqa: F401
from . import commands, handlers
__all__ = ["commands", "handlers"]
__plugin_meta__ = PluginMetadata(
name="定时任务管理",
@ -27,6 +29,8 @@ __plugin_meta__ = PluginMetadata(
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 执行 <任务ID>
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
# [修改] 增加说明
说明: -p 选项可单独使用用于操作指定插件的所有任务
📝 时间选项 (三选一):
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
@ -47,5 +51,35 @@ __plugin_meta__ = PluginMetadata(
version="0.1.2",
plugin_type=PluginType.SUPERUSER,
is_show=False,
configs=[
RegisterConfig(
module="SchedulerManager",
key="ALL_GROUPS_CONCURRENCY_LIMIT",
value=5,
help="“所有群组”类型定时任务的并发执行数量限制",
type=int,
),
RegisterConfig(
module="SchedulerManager",
key="JOB_MAX_RETRIES",
value=2,
help="定时任务执行失败时的最大重试次数",
type=int,
),
RegisterConfig(
module="SchedulerManager",
key="JOB_RETRY_DELAY",
value=10,
help="定时任务执行重试的间隔时间(秒)",
type=int,
),
RegisterConfig(
module="SchedulerManager",
key="SCHEDULER_TIMEZONE",
value="Asia/Shanghai",
help="定时任务使用的时区,默认为 Asia/Shanghai",
type=str,
),
],
).to_dict(),
)

View File

@ -1,836 +0,0 @@
import asyncio
from datetime import datetime
import re
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import (
Alconna,
AlconnaMatch,
Args,
Arparma,
Match,
Option,
Query,
Subcommand,
on_alconna,
)
from pydantic import BaseModel, ValidationError
from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.manager.schedule_manager import scheduler_manager
def _get_type_name(annotation) -> str:
"""获取类型注解的名称"""
if hasattr(annotation, "__name__"):
return annotation.__name__
elif hasattr(annotation, "_name"):
return annotation._name
else:
return str(annotation)
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.rules import admin_check
def _format_trigger(schedule_status: dict) -> str:
"""将触发器配置格式化为人类可读的字符串"""
trigger_type = schedule_status["trigger_type"]
config = schedule_status["trigger_config"]
if trigger_type == "cron":
minute = config.get("minute", "*")
hour = config.get("hour", "*")
day = config.get("day", "*")
month = config.get("month", "*")
day_of_week = config.get("day_of_week", "*")
if day == "*" and month == "*" and day_of_week == "*":
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
return f"每天 {formatted_hour}:{formatted_minute}"
else:
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
elif trigger_type == "interval":
seconds = config.get("seconds", 0)
minutes = config.get("minutes", 0)
hours = config.get("hours", 0)
days = config.get("days", 0)
if days:
trigger_str = f"{days}"
elif hours:
trigger_str = f"{hours} 小时"
elif minutes:
trigger_str = f"{minutes} 分钟"
else:
trigger_str = f"{seconds}"
elif trigger_type == "date":
run_date = config.get("run_date", "未知时间")
trigger_str = f"{run_date}"
else:
trigger_str = f"{trigger_type}: {config}"
return trigger_str
def _format_params(schedule_status: dict) -> str:
"""将任务参数格式化为人类可读的字符串"""
if kwargs := schedule_status.get("job_kwargs"):
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
return kwargs_str
return "-"
def _parse_interval(interval_str: str) -> dict:
"""增强版解析器,支持 d(天)"""
match = re.match(r"(\d+)([smhd])", interval_str.lower())
if not match:
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return {"seconds": value}
if unit == "m":
return {"minutes": value}
if unit == "h":
return {"hours": value}
if unit == "d":
return {"days": value}
return {}
def _parse_daily_time(time_str: str) -> dict:
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
hour, minute, second = match.groups()
hour, minute = int(hour), int(minute)
if not (0 <= hour <= 23 and 0 <= minute <= 59):
raise ValueError("小时或分钟数值超出范围。")
cron_config = {
"minute": str(minute),
"hour": str(hour),
"day": "*",
"month": "*",
"day_of_week": "*",
}
if second is not None:
if not (0 <= int(second) <= 59):
raise ValueError("秒数值超出范围。")
cron_config["second"] = str(second)
return cron_config
else:
raise ValueError("时间格式错误,请使用 'HH:MM''HH:MM:SS' 格式。")
async def GetBotId(
bot: Bot,
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
) -> str:
"""获取要操作的Bot ID"""
if bot_id_match.available:
return bot_id_match.result
return bot.self_id
class ScheduleTarget:
"""定时任务操作目标的基类"""
pass
class TargetByID(ScheduleTarget):
"""按任务ID操作"""
def __init__(self, id: int):
self.id = id
class TargetByPlugin(ScheduleTarget):
"""按插件名操作"""
def __init__(
self, plugin: str, group_id: str | None = None, all_groups: bool = False
):
self.plugin = plugin
self.group_id = group_id
self.all_groups = all_groups
class TargetAll(ScheduleTarget):
"""操作所有任务"""
def __init__(self, for_group: str | None = None):
self.for_group = for_group
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
def create_target_parser(subcommand_name: str):
"""
创建一个依赖注入函数用于解析删除暂停恢复等命令的操作目标
"""
async def dependency(
event: Event,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
) -> TargetScope:
if schedule_id.available:
return TargetByID(schedule_id.result)
if plugin_name.available:
p_name = plugin_name.result
if all_enabled.available:
return TargetByPlugin(plugin=p_name, all_groups=True)
elif group_id.available:
gid = group_id.result
if gid.lower() == "all":
return TargetByPlugin(plugin=p_name, all_groups=True)
return TargetByPlugin(plugin=p_name, group_id=gid)
else:
current_group_id = getattr(event, "group_id", None)
if current_group_id:
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
else:
await schedule_cmd.finish(
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
)
if all_enabled.available:
return TargetAll(for_group=group_id.result if group_id.available else None)
return None
return dependency
schedule_cmd = on_alconna(
Alconna(
"定时任务",
Subcommand(
"查看",
Option("-g", Args["target_group_id", str]),
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
Option("--page", Args["page", int, 1], help_text="指定页码"),
alias=["ls", "list"],
help_text="查看定时任务",
),
Subcommand(
"设置",
Args["plugin_name", str],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="设置每天执行的时间 (如 08:20)",
),
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["add", "开启"],
help_text="设置/开启一个定时任务",
),
Subcommand(
"删除",
Args["schedule_id?", int],
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID"),
Option("-all", help_text="对所有群生效"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["del", "rm", "remove", "关闭", "取消"],
help_text="删除一个或多个定时任务",
),
Subcommand(
"暂停",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["pause"],
help_text="暂停一个或多个定时任务",
),
Subcommand(
"恢复",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["resume"],
help_text="恢复一个或多个定时任务",
),
Subcommand(
"执行",
Args["schedule_id", int],
alias=["trigger", "run"],
help_text="立即执行一次任务",
),
Subcommand(
"更新",
Args["schedule_id", int],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="更新每天执行的时间 (如 08:20)",
),
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
alias=["update", "modify", "修改"],
help_text="更新任务配置",
),
Subcommand(
"状态",
Args["schedule_id", int],
alias=["status", "info"],
help_text="查看单个任务的详细状态",
),
Subcommand(
"插件列表",
alias=["plugins"],
help_text="列出所有可用的插件",
),
),
priority=5,
block=True,
rule=admin_check(1),
)
schedule_cmd.shortcut(
"任务状态",
command="定时任务",
arguments=["状态", "{%0}"],
prefix=True,
)
@schedule_cmd.handle()
async def _handle_time_options_mutex(arp: Arparma):
time_options = ["cron", "interval", "date", "daily"]
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
if len(provided_options) > 1:
await schedule_cmd.finish(
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
)
@schedule_cmd.assign("查看")
async def _(
bot: Bot,
event: Event,
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
all_groups: Query[bool] = Query("查看.all"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
page: Match[int] = AlconnaMatch("page"),
):
is_superuser = await SUPERUSER(bot, event)
schedules = []
title = ""
current_group_id = getattr(event, "group_id", None)
if not (all_groups.available or target_group_id.available) and not current_group_id:
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
if all_groups.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
schedules = await scheduler_manager.get_all_schedules()
title = "所有群组的定时任务"
elif target_group_id.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
gid = target_group_id.result
schedules = [
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
]
title = f"{gid} 的定时任务"
else:
gid = str(current_group_id)
schedules = [
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
]
title = "本群的定时任务"
if plugin_name.available:
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
title += f" [插件: {plugin_name.result}]"
if not schedules:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
page_size = 15
current_page = page.result
total_items = len(schedules)
total_pages = (total_items + page_size - 1) // page_size
start_index = (current_page - 1) * page_size
end_index = start_index + page_size
paginated_schedules = schedules[start_index:end_index]
if not paginated_schedules:
await schedule_cmd.finish("这一页没有内容了哦~")
status_tasks = [
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
]
all_statuses = await asyncio.gather(*status_tasks)
data_list = [
[
s["id"],
s["plugin_name"],
s.get("bot_id") or "N/A",
s["group_id"] or "全局",
s["next_run_time"],
_format_trigger(s),
_format_params(s),
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
]
for s in all_statuses
if s
]
if not data_list:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
img = await ImageTemplate.table_page(
head_text=title,
tip_text=f"{current_page}/{total_pages} 页,共 {total_items} 条任务",
column_name=[
"ID",
"插件",
"Bot ID",
"群组/目标",
"下次运行",
"触发规则",
"参数",
"状态",
],
data_list=data_list,
column_space=20,
)
await MessageUtils.build_message(img).send(reply_to=True)
@schedule_cmd.assign("设置")
async def _(
event: Event,
plugin_name: str,
cron_expr: str | None = None,
interval_expr: str | None = None,
date_expr: str | None = None,
daily_expr: str | None = None,
group_id: str | None = None,
kwargs_str: str | None = None,
all_enabled: Query[bool] = Query("设置.all"),
bot_id_to_operate: str = Depends(GetBotId),
):
if plugin_name not in scheduler_manager._registered_tasks:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
)
trigger_type = ""
trigger_config = {}
try:
if cron_expr:
trigger_type = "cron"
parts = cron_expr.split()
if len(parts) != 5:
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
trigger_config = dict(zip(cron_keys, parts))
elif interval_expr:
trigger_type = "interval"
trigger_config = _parse_interval(interval_expr)
elif date_expr:
trigger_type = "date"
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
elif daily_expr:
trigger_type = "cron"
trigger_config = _parse_daily_time(daily_expr)
else:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = {}
if kwargs_str:
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
if not params_model:
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
raw_kwargs = {}
try:
for item in kwargs_str.split(","):
key, value = item.strip().split("=", 1)
raw_kwargs[key.strip()] = value
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的参数模型不支持验证。"
)
return
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的参数模型不支持导出。"
)
return
job_kwargs = model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
)
return
target_group_id: str | None
current_group_id = getattr(event, "group_id", None)
if group_id and group_id.lower() == "all":
target_group_id = "__ALL_GROUPS__"
elif all_enabled.available:
target_group_id = "__ALL_GROUPS__"
elif group_id:
target_group_id = group_id
elif current_group_id:
target_group_id = str(current_group_id)
else:
await schedule_cmd.finish(
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
)
return
success, msg = await scheduler_manager.add_schedule(
plugin_name,
target_group_id,
trigger_type,
trigger_config,
job_kwargs,
bot_id=bot_id_to_operate,
)
if target_group_id == "__ALL_GROUPS__":
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
elif target_group_id is None:
target_desc = "全局"
else:
target_desc = f"群组 {target_group_id}"
if success:
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
else:
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
@schedule_cmd.assign("删除")
async def _(
target: TargetScope = Depends(create_target_parser("删除")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
removed_count = await scheduler_manager.remove_schedule_for_all(
p_name, bot_id=bot_id_to_operate
)
message = (
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
if removed_count > 0
else f"没有找到插件 '{p_name}' 的定时任务。"
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.remove_schedule(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.remove_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.remove_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish(
"删除任务失败请提供任务ID或通过 -p <插件> 或 -all 指定要删除的任务。"
)
@schedule_cmd.assign("暂停")
async def _(
target: TargetScope = Depends(create_target_parser("暂停")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.pause_schedule(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.pause_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.pause_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("恢复")
async def _(
target: TargetScope = Depends(create_target_parser("恢复")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.resume_schedule(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.resume_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.resume_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("执行")
async def _(schedule_id: int):
_, message = await scheduler_manager.trigger_now(schedule_id)
await schedule_cmd.finish(message)
@schedule_cmd.assign("更新")
async def _(
schedule_id: int,
cron_expr: str | None = None,
interval_expr: str | None = None,
date_expr: str | None = None,
daily_expr: str | None = None,
kwargs_str: str | None = None,
):
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
await schedule_cmd.finish(
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
)
trigger_config = None
trigger_type = None
try:
if cron_expr:
trigger_type = "cron"
parts = cron_expr.split()
if len(parts) != 5:
raise ValueError("Cron 表达式必须有5个部分")
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
trigger_config = dict(zip(cron_keys, parts))
elif interval_expr:
trigger_type = "interval"
trigger_config = _parse_interval(interval_expr)
elif date_expr:
trigger_type = "date"
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
elif daily_expr:
trigger_type = "cron"
trigger_config = _parse_daily_time(daily_expr)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = None
if kwargs_str:
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
if not schedule:
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
if not task_meta or not (params_model := task_meta.get("model")):
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
)
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
)
raw_kwargs = {}
try:
for item in kwargs_str.split(","):
key, value = item.strip().split("=", 1)
raw_kwargs[key.strip()] = value
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
)
return
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
)
return
job_kwargs = model_dump(exclude_unset=True)
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
return
_, message = await scheduler_manager.update_schedule(
schedule_id, trigger_type, trigger_config, job_kwargs
)
await schedule_cmd.finish(message)
@schedule_cmd.assign("插件列表")
async def _():
registered_plugins = scheduler_manager.get_registered_plugins()
if not registered_plugins:
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
message_parts = ["📋 已注册的定时任务插件:"]
for i, plugin_name in enumerate(registered_plugins, 1):
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
if not params_model:
message_parts.append(f"{i}. {plugin_name} - 无参数")
continue
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
continue
model_fields = getattr(params_model, "model_fields", None)
if model_fields:
param_info = ", ".join(
f"{field_name}({_get_type_name(field_info.annotation)})"
for field_name, field_info in model_fields.items()
)
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
else:
message_parts.append(f"{i}. {plugin_name} - 无参数")
await schedule_cmd.finish("\n".join(message_parts))
@schedule_cmd.assign("状态")
async def _(schedule_id: int):
status = await scheduler_manager.get_schedule_status(schedule_id)
if not status:
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
info_lines = [
f"📋 定时任务详细信息 (ID: {schedule_id})",
"--------------------",
f"▫️ 插件: {status['plugin_name']}",
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
f"▫️ 目标: {status['group_id'] or '全局'}",
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
f"▫️ 下次运行: {status['next_run_time']}",
f"▫️ 触发规则: {_format_trigger(status)}",
f"▫️ 任务参数: {_format_params(status)}",
]
await schedule_cmd.finish("\n".join(info_lines))

View File

@ -0,0 +1,298 @@
import re
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import (
Alconna,
AlconnaMatch,
Args,
Match,
Option,
Query,
Subcommand,
on_alconna,
)
from zhenxun.configs.config import Config
from zhenxun.services.scheduler import scheduler_manager
from zhenxun.services.scheduler.targeter import ScheduleTargeter
from zhenxun.utils.rules import admin_check
schedule_cmd = on_alconna(
Alconna(
"定时任务",
Subcommand(
"查看",
Option("-g", Args["target_group_id", str]),
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
Option("--page", Args["page", int, 1], help_text="指定页码"),
alias=["ls", "list"],
help_text="查看定时任务",
),
Subcommand(
"设置",
Args["plugin_name", str],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="设置每天执行的时间 (如 08:20)",
),
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["add", "开启"],
help_text="设置/开启一个定时任务",
),
Subcommand(
"删除",
Args["schedule_id?", int],
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID"),
Option("-all", help_text="对所有群生效"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["del", "rm", "remove", "关闭", "取消"],
help_text="删除一个或多个定时任务",
),
Subcommand(
"暂停",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["pause"],
help_text="暂停一个或多个定时任务",
),
Subcommand(
"恢复",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["resume"],
help_text="恢复一个或多个定时任务",
),
Subcommand(
"执行",
Args["schedule_id", int],
alias=["trigger", "run"],
help_text="立即执行一次任务",
),
Subcommand(
"更新",
Args["schedule_id", int],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="更新每天执行的时间 (如 08:20)",
),
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
alias=["update", "modify", "修改"],
help_text="更新任务配置",
),
Subcommand(
"状态",
Args["schedule_id", int],
alias=["status", "info"],
help_text="查看单个任务的详细状态",
),
Subcommand(
"插件列表",
alias=["plugins"],
help_text="列出所有可用的插件",
),
),
priority=5,
block=True,
rule=admin_check(1),
)
schedule_cmd.shortcut(
"任务状态",
command="定时任务",
arguments=["状态", "{%0}"],
prefix=True,
)
class ScheduleTarget:
pass
class TargetByID(ScheduleTarget):
def __init__(self, id: int):
self.id = id
class TargetByPlugin(ScheduleTarget):
def __init__(
self, plugin: str, group_id: str | None = None, all_groups: bool = False
):
self.plugin = plugin
self.group_id = group_id
self.all_groups = all_groups
class TargetAll(ScheduleTarget):
def __init__(self, for_group: str | None = None):
self.for_group = for_group
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
def create_target_parser(subcommand_name: str):
async def dependency(
event: Event,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
) -> TargetScope:
if schedule_id.available:
return TargetByID(schedule_id.result)
if plugin_name.available:
p_name = plugin_name.result
if all_enabled.available:
return TargetByPlugin(plugin=p_name, all_groups=True)
elif group_id.available:
gid = group_id.result
if gid.lower() == "all":
return TargetByPlugin(plugin=p_name, all_groups=True)
return TargetByPlugin(plugin=p_name, group_id=gid)
else:
current_group_id = getattr(event, "group_id", None)
return TargetByPlugin(
plugin=p_name,
group_id=str(current_group_id) if current_group_id else None,
)
if all_enabled.available:
current_group_id = getattr(event, "group_id", None)
if not current_group_id:
await schedule_cmd.finish(
"私聊中单独使用 -all 选项时,必须使用 -g <群号> 指定目标。"
)
return TargetAll(for_group=str(current_group_id))
return None
return dependency
def parse_interval(interval_str: str) -> dict:
match = re.match(r"(\d+)([smhd])", interval_str.lower())
if not match:
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return {"seconds": value}
if unit == "m":
return {"minutes": value}
if unit == "h":
return {"hours": value}
if unit == "d":
return {"days": value}
return {}
def parse_daily_time(time_str: str) -> dict:
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
hour, minute, second = match.groups()
hour, minute = int(hour), int(minute)
if not (0 <= hour <= 23 and 0 <= minute <= 59):
raise ValueError("小时或分钟数值超出范围。")
cron_config = {
"minute": str(minute),
"hour": str(hour),
"day": "*",
"month": "*",
"day_of_week": "*",
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
}
if second is not None:
if not (0 <= int(second) <= 59):
raise ValueError("秒数值超出范围。")
cron_config["second"] = str(second)
return cron_config
else:
raise ValueError("时间格式错误,请使用 'HH:MM''HH:MM:SS' 格式。")
async def GetBotId(bot: Bot, bot_id_match: Match[str] = AlconnaMatch("bot_id")) -> str:
if bot_id_match.available:
return bot_id_match.result
return bot.self_id
def GetTargeter(subcommand: str):
"""
依赖注入函数用于解析命令参数并返回一个配置好的 ScheduleTargeter 实例
"""
async def dependency(
event: Event,
bot: Bot,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand}.all"),
bot_id_to_operate: str = Depends(GetBotId),
) -> ScheduleTargeter:
if schedule_id.available:
return scheduler_manager.target(id=schedule_id.result)
if plugin_name.available:
if all_enabled.available:
return scheduler_manager.target(plugin_name=plugin_name.result)
current_group_id = getattr(event, "group_id", None)
gid = group_id.result if group_id.available else current_group_id
return scheduler_manager.target(
plugin_name=plugin_name.result,
group_id=str(gid) if gid else None,
bot_id=bot_id_to_operate,
)
if all_enabled.available:
current_group_id = getattr(event, "group_id", None)
gid = group_id.result if group_id.available else current_group_id
is_su = await SUPERUSER(bot, event)
if not gid and not is_su:
await schedule_cmd.finish(
f"在私聊中对所有任务进行'{subcommand}'操作需要超级用户权限。"
)
if (gid and str(gid).lower() == "all") or (not gid and is_su):
return scheduler_manager.target()
return scheduler_manager.target(
group_id=str(gid) if gid else None, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(
f"'{subcommand}'操作失败请提供任务ID"
f"或通过 -p <插件名> 或 -all 指定要操作的任务。"
)
return Depends(dependency)

View File

@ -0,0 +1,380 @@
from datetime import datetime
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import AlconnaMatch, Arparma, Match, Query
from pydantic import BaseModel, ValidationError
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.scheduler import scheduler_manager
from zhenxun.services.scheduler.targeter import ScheduleTargeter
from zhenxun.utils.message import MessageUtils
from . import presenters
from .commands import (
GetBotId,
GetTargeter,
parse_daily_time,
parse_interval,
schedule_cmd,
)
@schedule_cmd.handle()
async def _handle_time_options_mutex(arp: Arparma):
time_options = ["cron", "interval", "date", "daily"]
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
if len(provided_options) > 1:
await schedule_cmd.finish(
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
)
@schedule_cmd.assign("查看")
async def handle_view(
bot: Bot,
event: Event,
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
all_groups: Query[bool] = Query("查看.all"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
page: Match[int] = AlconnaMatch("page"),
):
is_superuser = await SUPERUSER(bot, event)
title = ""
gid_filter = None
current_group_id = getattr(event, "group_id", None)
if not (all_groups.available or target_group_id.available) and not current_group_id:
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
if all_groups.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
title = "所有群组的定时任务"
elif target_group_id.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
gid_filter = target_group_id.result
title = f"{gid_filter} 的定时任务"
else:
gid_filter = str(current_group_id)
title = "本群的定时任务"
p_name_filter = plugin_name.result if plugin_name.available else None
schedules = await scheduler_manager.get_schedules(
plugin_name=p_name_filter, group_id=gid_filter
)
if p_name_filter:
title += f" [插件: {p_name_filter}]"
if not schedules:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
img = await presenters.format_schedule_list_as_image(
schedules=schedules, title=title, current_page=page.result
)
await MessageUtils.build_message(img).send(reply_to=True)
@schedule_cmd.assign("设置")
async def handle_set(
event: Event,
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
date_expr: Match[str] = AlconnaMatch("date_expr"),
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
group_id: Match[str] = AlconnaMatch("group_id"),
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
all_enabled: Query[bool] = Query("设置.all"),
bot_id_to_operate: str = Depends(GetBotId),
):
if not plugin_name.available:
await schedule_cmd.finish("设置任务时必须提供插件名称。")
has_time_option = any(
[
cron_expr.available,
interval_expr.available,
date_expr.available,
daily_expr.available,
]
)
if not has_time_option:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
p_name = plugin_name.result
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(
f"插件 '{p_name}' 没有注册可用的定时任务。\n"
f"可用插件: {list(scheduler_manager.get_registered_plugins())}"
)
trigger_type, trigger_config = "", {}
try:
if cron_expr.available:
trigger_type, trigger_config = (
"cron",
dict(
zip(
["minute", "hour", "day", "month", "day_of_week"],
cron_expr.result.split(),
)
),
)
elif interval_expr.available:
trigger_type, trigger_config = (
"interval",
parse_interval(interval_expr.result),
)
elif date_expr.available:
trigger_type, trigger_config = (
"date",
{"run_date": datetime.fromisoformat(date_expr.result)},
)
elif daily_expr.available:
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
else:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = {}
if kwargs_str.available:
task_meta = scheduler_manager._registered_tasks[p_name]
params_model = task_meta.get("model")
if not (
params_model
and isinstance(params_model, type)
and issubclass(params_model, BaseModel)
):
await schedule_cmd.finish(f"插件 '{p_name}' 不支持或配置了无效的参数模型。")
try:
raw_kwargs = dict(
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
)
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持验证")
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持导出")
job_kwargs = model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
await schedule_cmd.finish(
f"插件 '{p_name}' 的任务参数验证失败:\n" + "\n".join(errors)
)
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
gid_str = group_id.result if group_id.available else None
target_group_id = (
scheduler_manager.ALL_GROUPS
if (gid_str and gid_str.lower() == "all") or all_enabled.available
else gid_str or getattr(event, "group_id", None)
)
if not target_group_id:
await schedule_cmd.finish(
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
)
schedule = await scheduler_manager.add_schedule(
p_name,
str(target_group_id),
trigger_type,
trigger_config,
job_kwargs,
bot_id=bot_id_to_operate,
)
target_desc = (
f"所有群组 (Bot: {bot_id_to_operate})"
if target_group_id == scheduler_manager.ALL_GROUPS
else f"群组 {target_group_id}"
)
if schedule:
await schedule_cmd.finish(
f"为 [{target_desc}] 已成功设置插件 '{p_name}' 的定时任务 "
f"(ID: {schedule.id})。"
)
else:
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败。")
@schedule_cmd.assign("删除")
async def handle_delete(targeter: ScheduleTargeter = GetTargeter("删除")):
schedules_to_remove: list[ScheduleInfo] = await targeter._get_schedules()
if not schedules_to_remove:
await schedule_cmd.finish("没有找到可删除的任务。")
count, _ = await targeter.remove()
if count > 0 and schedules_to_remove:
if len(schedules_to_remove) == 1:
message = presenters.format_remove_success(schedules_to_remove[0])
else:
target_desc = targeter._generate_target_description()
message = f"✅ 成功移除了{target_desc} {count} 个任务。"
else:
message = "没有任务被移除。"
await schedule_cmd.finish(message)
@schedule_cmd.assign("暂停")
async def handle_pause(targeter: ScheduleTargeter = GetTargeter("暂停")):
schedules_to_pause: list[ScheduleInfo] = await targeter._get_schedules()
if not schedules_to_pause:
await schedule_cmd.finish("没有找到可暂停的任务。")
count, _ = await targeter.pause()
if count > 0 and schedules_to_pause:
if len(schedules_to_pause) == 1:
message = presenters.format_pause_success(schedules_to_pause[0])
else:
target_desc = targeter._generate_target_description()
message = f"✅ 成功暂停了{target_desc} {count} 个任务。"
else:
message = "没有任务被暂停。"
await schedule_cmd.finish(message)
@schedule_cmd.assign("恢复")
async def handle_resume(targeter: ScheduleTargeter = GetTargeter("恢复")):
schedules_to_resume: list[ScheduleInfo] = await targeter._get_schedules()
if not schedules_to_resume:
await schedule_cmd.finish("没有找到可恢复的任务。")
count, _ = await targeter.resume()
if count > 0 and schedules_to_resume:
if len(schedules_to_resume) == 1:
message = presenters.format_resume_success(schedules_to_resume[0])
else:
target_desc = targeter._generate_target_description()
message = f"✅ 成功恢复了{target_desc} {count} 个任务。"
else:
message = "没有任务被恢复。"
await schedule_cmd.finish(message)
@schedule_cmd.assign("执行")
async def handle_trigger(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
from zhenxun.services.scheduler.repository import ScheduleRepository
schedule_info = await ScheduleRepository.get_by_id(schedule_id.result)
if not schedule_info:
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id.result} 的任务。")
success, message = await scheduler_manager.trigger_now(schedule_id.result)
if success:
final_message = presenters.format_trigger_success(schedule_info)
else:
final_message = f"❌ 手动触发失败: {message}"
await schedule_cmd.finish(final_message)
@schedule_cmd.assign("更新")
async def handle_update(
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
date_expr: Match[str] = AlconnaMatch("date_expr"),
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
):
if not any(
[
cron_expr.available,
interval_expr.available,
date_expr.available,
daily_expr.available,
kwargs_str.available,
]
):
await schedule_cmd.finish(
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
)
trigger_type, trigger_config, job_kwargs = None, None, None
try:
if cron_expr.available:
trigger_type, trigger_config = (
"cron",
dict(
zip(
["minute", "hour", "day", "month", "day_of_week"],
cron_expr.result.split(),
)
),
)
elif interval_expr.available:
trigger_type, trigger_config = (
"interval",
parse_interval(interval_expr.result),
)
elif date_expr.available:
trigger_type, trigger_config = (
"date",
{"run_date": datetime.fromisoformat(date_expr.result)},
)
elif daily_expr.available:
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
if kwargs_str.available:
job_kwargs = dict(
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
)
success, message = await scheduler_manager.update_schedule(
schedule_id.result, trigger_type, trigger_config, job_kwargs
)
if success:
from zhenxun.services.scheduler.repository import ScheduleRepository
updated_schedule = await ScheduleRepository.get_by_id(schedule_id.result)
if updated_schedule:
final_message = presenters.format_update_success(updated_schedule)
else:
final_message = "✅ 更新成功,但无法获取更新后的任务详情。"
else:
final_message = f"❌ 更新失败: {message}"
await schedule_cmd.finish(final_message)
@schedule_cmd.assign("插件列表")
async def handle_plugins_list():
message = await presenters.format_plugins_list()
await schedule_cmd.finish(message)
@schedule_cmd.assign("状态")
async def handle_status(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
status = await scheduler_manager.get_schedule_status(schedule_id.result)
if not status:
await schedule_cmd.finish(f"未找到ID为 {schedule_id.result} 的定时任务。")
message = presenters.format_single_status_message(status)
await schedule_cmd.finish(message)

View File

@ -0,0 +1,274 @@
import asyncio
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.scheduler import scheduler_manager
from zhenxun.utils._image_template import ImageTemplate, RowStyle
def _get_type_name(annotation) -> str:
"""获取类型注解的名称"""
if hasattr(annotation, "__name__"):
return annotation.__name__
elif hasattr(annotation, "_name"):
return annotation._name
else:
return str(annotation)
def _format_trigger(schedule: dict) -> str:
"""格式化触发器信息为可读字符串"""
trigger_type = schedule.get("trigger_type")
config = schedule.get("trigger_config")
if not isinstance(config, dict):
return f"配置错误: {config}"
if trigger_type == "cron":
hour = config.get("hour", "??")
minute = config.get("minute", "??")
try:
hour_int = int(hour)
minute_int = int(minute)
return f"每天 {hour_int:02d}:{minute_int:02d}"
except (ValueError, TypeError):
return f"每天 {hour}:{minute}"
elif trigger_type == "interval":
units = {
"weeks": "",
"days": "",
"hours": "小时",
"minutes": "分钟",
"seconds": "",
}
for unit, unit_name in units.items():
if value := config.get(unit):
return f"{value} {unit_name}"
return "未知间隔"
elif trigger_type == "date":
run_date = config.get("run_date", "N/A")
return f"特定时间 {run_date}"
else:
return f"未知触发器类型: {trigger_type}"
def _format_trigger_for_card(schedule_info: ScheduleInfo | dict) -> str:
"""为信息卡片格式化触发器规则"""
trigger_type = (
schedule_info.get("trigger_type")
if isinstance(schedule_info, dict)
else schedule_info.trigger_type
)
config = (
schedule_info.get("trigger_config")
if isinstance(schedule_info, dict)
else schedule_info.trigger_config
)
if not isinstance(config, dict):
return f"配置错误: {config}"
if trigger_type == "cron":
hour = config.get("hour", "??")
minute = config.get("minute", "??")
try:
hour_int = int(hour)
minute_int = int(minute)
return f"每天 {hour_int:02d}:{minute_int:02d}"
except (ValueError, TypeError):
return f"每天 {hour}:{minute}"
elif trigger_type == "interval":
units = {
"weeks": "",
"days": "",
"hours": "小时",
"minutes": "分钟",
"seconds": "",
}
for unit, unit_name in units.items():
if value := config.get(unit):
return f"{value} {unit_name}"
return "未知间隔"
elif trigger_type == "date":
run_date = config.get("run_date", "N/A")
return f"特定时间 {run_date}"
else:
return f"未知规则: {trigger_type}"
def _format_operation_result_card(
title: str, schedule_info: ScheduleInfo, extra_info: list[str] | None = None
) -> str:
"""
生成一个标准的操作结果信息卡片
参数:
title: 卡片的标题 (例如 "✅ 成功暂停定时任务!")
schedule_info: 相关的 ScheduleInfo 对象
extra_info: (可选) 额外的补充信息行
"""
target_desc = (
f"群组 {schedule_info.group_id}"
if schedule_info.group_id
and schedule_info.group_id != scheduler_manager.ALL_GROUPS
else "所有群组"
if schedule_info.group_id == scheduler_manager.ALL_GROUPS
else "全局"
)
info_lines = [
title,
f"✓ 任务 ID: {schedule_info.id}",
f"🖋 插件: {schedule_info.plugin_name}",
f"🎯 目标: {target_desc}",
f"⏰ 时间: {_format_trigger_for_card(schedule_info)}",
]
if extra_info:
info_lines.extend(extra_info)
return "\n".join(info_lines)
def format_pause_success(schedule_info: ScheduleInfo) -> str:
"""格式化暂停成功的消息"""
return _format_operation_result_card("✅ 成功暂停定时任务!", schedule_info)
def format_resume_success(schedule_info: ScheduleInfo) -> str:
"""格式化恢复成功的消息"""
return _format_operation_result_card("▶️ 成功恢复定时任务!", schedule_info)
def format_remove_success(schedule_info: ScheduleInfo) -> str:
"""格式化删除成功的消息"""
return _format_operation_result_card("❌ 成功删除定时任务!", schedule_info)
def format_trigger_success(schedule_info: ScheduleInfo) -> str:
"""格式化手动触发成功的消息"""
return _format_operation_result_card("🚀 成功手动触发定时任务!", schedule_info)
def format_update_success(schedule_info: ScheduleInfo) -> str:
"""格式化更新成功的消息"""
return _format_operation_result_card("🔄️ 成功更新定时任务配置!", schedule_info)
def _status_row_style(column: str, text: str) -> RowStyle:
"""为状态列设置颜色"""
style = RowStyle()
if column == "状态":
if text == "启用":
style.font_color = "#67C23A"
elif text == "暂停":
style.font_color = "#F56C6C"
elif text == "运行中":
style.font_color = "#409EFF"
return style
def _format_params(schedule_status: dict) -> str:
"""将任务参数格式化为人类可读的字符串"""
if kwargs := schedule_status.get("job_kwargs"):
return " | ".join(f"{k}: {v}" for k, v in kwargs.items())
return "-"
async def format_schedule_list_as_image(
schedules: list[ScheduleInfo], title: str, current_page: int
):
"""将任务列表格式化为图片"""
page_size = 15
total_items = len(schedules)
total_pages = (total_items + page_size - 1) // page_size
start_index = (current_page - 1) * page_size
end_index = start_index + page_size
paginated_schedules = schedules[start_index:end_index]
if not paginated_schedules:
return "这一页没有内容了哦~"
status_tasks = [
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
]
all_statuses = await asyncio.gather(*status_tasks)
def get_status_text(status_value):
if isinstance(status_value, bool):
return "启用" if status_value else "暂停"
return str(status_value)
data_list = [
[
s["id"],
s["plugin_name"],
s.get("bot_id") or "N/A",
s["group_id"] or "全局",
s["next_run_time"],
_format_trigger(s),
_format_params(s),
get_status_text(s["is_enabled"]),
]
for s in all_statuses
if s
]
if not data_list:
return "没有找到任何相关的定时任务。"
return await ImageTemplate.table_page(
head_text=title,
tip_text=f"{current_page}/{total_pages} 页,共 {total_items} 条任务",
column_name=["ID", "插件", "Bot", "目标", "下次运行", "规则", "参数", "状态"],
data_list=data_list,
column_space=20,
text_style=_status_row_style,
)
def format_single_status_message(status: dict) -> str:
"""格式化单个任务状态为文本消息"""
info_lines = [
f"📋 定时任务详细信息 (ID: {status['id']})",
"--------------------",
f"▫️ 插件: {status['plugin_name']}",
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
f"▫️ 目标: {status['group_id'] or '全局'}",
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
f"▫️ 下次运行: {status['next_run_time']}",
f"▫️ 触发规则: {_format_trigger(status)}",
f"▫️ 任务参数: {_format_params(status)}",
]
return "\n".join(info_lines)
async def format_plugins_list() -> str:
"""格式化可用插件列表为文本消息"""
from pydantic import BaseModel
registered_plugins = scheduler_manager.get_registered_plugins()
if not registered_plugins:
return "当前没有已注册的定时任务插件。"
message_parts = ["📋 已注册的定时任务插件:"]
for i, plugin_name in enumerate(registered_plugins, 1):
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
param_info_str = "无参数"
if (
params_model
and isinstance(params_model, type)
and issubclass(params_model, BaseModel)
):
model_fields = getattr(params_model, "model_fields", None)
if model_fields:
param_info_str = "参数: " + ", ".join(
f"{field_name}({_get_type_name(field_info.annotation)})"
for field_name, field_info in model_fields.items()
)
elif params_model:
param_info_str = "⚠️ 参数模型配置错误"
message_parts.append(f"{i}. {plugin_name} - {param_info_str}")
return "\n".join(message_parts)

View File

@ -1,5 +1,3 @@
from datetime import datetime, timedelta
from tortoise.functions import Count
from zhenxun.models.group_console import GroupConsole
@ -10,6 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
from zhenxun.utils.echart_utils.models import Barh
from zhenxun.utils.enum import PluginType
from zhenxun.utils.image_utils import BuildImage
from zhenxun.utils.utils import TimeUtils
class StatisticsManage:
@ -68,8 +67,7 @@ class StatisticsManage:
if plugin_name:
query = query.filter(plugin_name=plugin_name)
if day:
time = datetime.now() - timedelta(days=day)
query = query.filter(create_time__gte=time)
query = query.filter(create_time__gte=TimeUtils.get_day_start())
data_list = (
await query.annotate(count=Count("id"))
.group_by("plugin_name")
@ -89,8 +87,7 @@ class StatisticsManage:
if group_id:
query = query.filter(group_id=group_id)
if day:
time = datetime.now() - timedelta(days=day)
query = query.filter(create_time__gte=time)
query = query.filter(create_time__gte=TimeUtils.get_day_start())
data_list = (
await query.annotate(count=Count("id"))
.group_by("plugin_name")
@ -106,8 +103,7 @@ class StatisticsManage:
async def get_group_statistics(cls, group_id: str, day: int | None, title: str):
query = Statistics.filter(group_id=group_id)
if day:
time = datetime.now() - timedelta(days=day)
query = query.filter(create_time__gte=time)
query = query.filter(create_time__gte=TimeUtils.get_day_start())
data_list = (
await query.annotate(count=Count("id"))
.group_by("plugin_name")

View File

@ -28,7 +28,7 @@ from nonebot_plugin_alconna.uniseg.segment import (
)
from nonebot_plugin_session import EventSession
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
from zhenxun.configs.utils import PluginExtraData, Task
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
@ -73,16 +73,6 @@ __plugin_meta__ = PluginMetadata(
author="HibiKier",
version="1.2",
plugin_type=PluginType.SUPERUSER,
configs=[
RegisterConfig(
module="_task",
key="DEFAULT_BROADCAST",
value=True,
help="被动 广播 进群默认开关状态",
default_value=True,
type=bool,
)
],
tasks=[Task(module="broadcast", name="广播")],
).to_dict(),
)

View File

@ -106,18 +106,34 @@ class ConfigGroup(BaseModel):
if value_to_process is None:
return default
if cfg.type:
if build_model and _is_pydantic_type(cfg.type):
try:
return parse_as(cfg.type, value_to_process)
except Exception as e:
logger.warning(f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e)
if cfg.arg_parser:
try:
return cattrs.structure(value_to_process, cfg.type)
return cfg.arg_parser(value_to_process)
except Exception as e:
logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
logger.debug(
f"配置项类型转换 MODULE: [<u><y>{self.module}</y></u>] | "
f"KEY: [<u><y>{key}</y></u>] 的自定义解析器失败,将使用原始值",
e=e,
)
return value_to_process
return value_to_process
if not build_model or not cfg.type:
return value_to_process
try:
if _is_pydantic_type(cfg.type):
parsed_value = parse_as(cfg.type, value_to_process)
return parsed_value
else:
structured_value = cattrs.structure(value_to_process, cfg.type)
return structured_value
except Exception as e:
logger.error(
f"❌ 配置项 '{self.module}.{key}' 自动类型转换失败 "
f"(目标类型: {cfg.type}),将返回原始值。请检查配置文件格式。错误: {e}",
e=e,
)
return value_to_process
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)

View File

@ -49,7 +49,8 @@ class ChatHistory(Model):
o = "-" if order == "DESC" else ""
query = cls.filter(group_id=gid) if gid else cls
if date_scope:
query = query.filter(create_time__range=date_scope)
filter_scope = (date_scope[0].isoformat(" "), date_scope[1].isoformat(" "))
query = query.filter(create_time__range=filter_scope)
return list(
await query.annotate(count=Count("user_id"))
.order_by(f"{o}count")

View File

@ -90,13 +90,14 @@ class LevelUser(Model):
返回:
bool: 是否大于level
"""
if level == 0:
return True
if group_id:
if user := await cls.get_or_none(user_id=user_id, group_id=group_id):
return user.user_level >= level
else:
if user_list := await cls.filter(user_id=user_id).all():
user = max(user_list, key=lambda x: x.user_level)
return user.user_level >= level
elif user_list := await cls.filter(user_id=user_id).all():
user = max(user_list, key=lambda x: x.user_level)
return user.user_level >= level
return False
@classmethod
@ -119,8 +120,7 @@ class LevelUser(Model):
return [
# 将user_id改为user_id
"ALTER TABLE level_users RENAME COLUMN user_qq TO user_id;",
"ALTER TABLE level_users "
"ALTER COLUMN user_id TYPE character varying(255);",
"ALTER TABLE level_users ALTER COLUMN user_id TYPE character varying(255);",
# 将user_id字段类型改为character varying(255)
"ALTER TABLE level_users "
"ALTER COLUMN group_id TYPE character varying(255);",

View File

@ -1,3 +1,14 @@
"""
Zhenxun Bot - 核心服务模块
主要服务包括
- 数据库上下文 (db_context): 提供数据库模型基类和连接管理
- 日志服务 (log): 提供增强的带上下文的日志记录器
- LLM服务 (llm): 提供与大语言模型交互的统一API
- 插件生命周期管理 (plugin_init): 支持插件安装和卸载时的钩子函数
- 定时任务调度器 (scheduler): 提供持久化的可管理的定时任务服务
"""
from nonebot import require
require("nonebot_plugin_apscheduler")
@ -6,3 +17,33 @@ require("nonebot_plugin_session")
require("nonebot_plugin_htmlrender")
require("nonebot_plugin_uninfo")
require("nonebot_plugin_waiter")
from .db_context import Model, disconnect
from .llm import (
AI,
LLMContentPart,
LLMException,
LLMMessage,
get_model_instance,
list_available_models,
tool_registry,
)
from .log import logger
from .plugin_init import PluginInit, PluginInitManager
from .scheduler import scheduler_manager
__all__ = [
"AI",
"LLMContentPart",
"LLMException",
"LLMMessage",
"Model",
"PluginInit",
"PluginInitManager",
"disconnect",
"get_model_instance",
"list_available_models",
"logger",
"scheduler_manager",
"tool_registry",
]

File diff suppressed because it is too large Load Diff

View File

@ -10,10 +10,10 @@ from .api import (
TaskType,
analyze,
analyze_multimodal,
analyze_with_images,
chat,
code,
embed,
pipeline_chat,
search,
search_multimodal,
)
@ -35,6 +35,7 @@ from .manager import (
list_model_identifiers,
set_global_default_model_name,
)
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -43,6 +44,7 @@ from .types import (
LLMMessage,
LLMResponse,
LLMTool,
MCPCompatible,
ModelDetail,
ModelInfo,
ModelProvider,
@ -51,7 +53,7 @@ from .types import (
ToolMetadata,
UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
__all__ = [
"AI",
@ -65,6 +67,7 @@ __all__ = [
"LLMMessage",
"LLMResponse",
"LLMTool",
"MCPCompatible",
"ModelDetail",
"ModelInfo",
"ModelName",
@ -76,7 +79,6 @@ __all__ = [
"UsageInfo",
"analyze",
"analyze_multimodal",
"analyze_with_images",
"chat",
"clear_model_cache",
"code",
@ -88,9 +90,12 @@ __all__ = [
"list_available_models",
"list_embedding_models",
"list_model_identifiers",
"message_to_unimessage",
"pipeline_chat",
"register_llm_configs",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"unimsg_to_llm_parts",
]

View File

@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
LLMAdapterFactory.initialize()
@ -20,7 +19,6 @@ __all__ = [
"OpenAICompatAdapter",
"RequestData",
"ResponseData",
"ZhipuAdapter",
"get_adapter_for_api_type",
"register_adapter",
]

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool
class RequestData(BaseModel):
@ -60,7 +61,7 @@ class BaseAdapter(ABC):
"""支持的API类型列表"""
pass
def prepare_simple_request(
async def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
@ -86,7 +87,7 @@ class BaseAdapter(ABC):
config = model._generation_config
return self.prepare_advanced_request(
return await self.prepare_advanced_request(
model=model,
api_key=api_key,
messages=messages,
@ -96,13 +97,13 @@ class BaseAdapter(ABC):
)
@abstractmethod
def prepare_advanced_request(
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -238,6 +239,9 @@ class BaseAdapter(ABC):
message = choice.get("message", {})
content = message.get("content", "")
if content:
content = content.strip()
parsed_tool_calls: list[LLMToolCall] | None = None
if message_tool_calls := message.get("tool_calls"):
from ..types.models import LLMToolFunction
@ -375,7 +379,7 @@ class BaseAdapter(ABC):
if model.temperature is not None:
base_config["temperature"] = model.temperature
if model.max_tokens is not None:
if model.api_type in ["gemini", "gemini_native"]:
if model.api_type == "gemini":
base_config["maxOutputTokens"] = model.max_tokens
else:
base_config["max_tokens"] = model.max_tokens
@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
"""
@abstractmethod
def get_chat_endpoint(self) -> str:
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 chat completions 的端点"""
pass
@abstractmethod
def get_embedding_endpoint(self) -> str:
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 embeddings 的端点"""
pass
def prepare_advanced_request(
async def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI兼容API的通用实现"""
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
url = self.get_api_url(model, self.get_chat_endpoint())
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
openai_messages = self.convert_messages_to_openai_format(messages)
@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
}
if tools:
body["tools"] = tools
openai_tools = []
for tool in tools:
if tool.type == "function" and tool.function:
openai_tools.append({"type": "function", "function": tool.function})
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
openai_tools.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
if openai_tools:
body["tools"] = openai_tools
if tool_choice:
body["tool_choice"] = tool_choice
@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
_ = model, is_advanced # 未使用的参数
_ = model, is_advanced
return self.parse_openai_response(response_json)
def prepare_embedding_request(
@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
**kwargs: Any,
) -> RequestData:
"""准备嵌入请求 - OpenAI兼容格式"""
_ = task_type # 未使用的参数
url = self.get_api_url(model, self.get_embedding_endpoint())
_ = task_type
url = self.get_api_url(model, self.get_embedding_endpoint(model))
headers = self.get_base_headers(api_key)
body = {
@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
"input": texts,
}
# 应用额外的配置参数
if kwargs:
body.update(kwargs)

View File

@ -22,10 +22,8 @@ class LLMAdapterFactory:
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
cls.register_adapter(OpenAIAdapter())
cls.register_adapter(ZhipuAdapter())
cls.register_adapter(GeminiAdapter())
@classmethod

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMToolCall
from ..types.models import LLMTool, LLMToolCall
class GeminiAdapter(BaseAdapter):
@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):
return headers
def prepare_advanced_request(
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
return self._prepare_request(
model, api_key, messages, config, tools, tool_choice
)
def _prepare_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备 Gemini API 请求 - 支持所有高级功能"""
effective_config = config if config is not None else model._generation_config
endpoint = self._get_gemini_endpoint(model, effective_config)
@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list):
system_instruction_parts = [
part.convert_for_api("gemini") for part in msg.content
await part.convert_for_api_async("gemini")
for part in msg.content
]
continue
@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
gemini_contents.append({"role": "user", "parts": current_parts})
elif msg.role == "assistant" or msg.role == "model":
@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
if msg.tool_calls:
import json
@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):
all_tools_for_request = []
if tools:
for tool_item in tools:
if isinstance(tool_item, dict):
if "name" in tool_item and "description" in tool_item:
all_tools_for_request.append(
{"functionDeclarations": [tool_item]}
for tool in tools:
if tool.type == "function" and tool.function:
all_tools_for_request.append(
{"functionDeclarations": [tool.function]}
)
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
else:
all_tools_for_request.append(tool_item)
else:
all_tools_for_request.append(tool_item)
all_tools_for_request.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
elif tool.type == "google_search":
all_tools_for_request.append({"googleSearch": {}})
if effective_config:
if getattr(effective_config, "enable_grounding", False):
@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
logger.debug("隐式启用代码执行工具。")
if all_tools_for_request:
gemini_api_tools = self._convert_tools_to_gemini_format(
all_tools_for_request
)
if gemini_api_tools:
body["tools"] = gemini_api_tools
body["tools"] = all_tools_for_request
final_tool_choice = tool_choice
if final_tool_choice is None and effective_config:
@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):
return f"/v1beta/models/{model.model_name}:generateContent"
def _convert_tools_to_gemini_format(
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""转换工具格式为Gemini格式"""
gemini_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
gemini_tool = {
"functionDeclarations": [
{
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
]
}
gemini_tools.append(gemini_tool)
elif tool.get("type") == "code_execution":
gemini_tools.append(
{"codeExecution": {"language": tool.get("language", "python")}}
)
elif tool.get("type") == "google_search":
gemini_tools.append({"googleSearch": {}})
elif "googleSearch" in tool:
gemini_tools.append({"googleSearch": tool["googleSearch"]})
elif "codeExecution" in tool:
gemini_tools.append({"codeExecution": tool["codeExecution"]})
return gemini_tools
def _convert_tool_choice_to_gemini(
self, tool_choice_value: str | dict[str, Any]
) -> dict[str, Any]:
@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
for category, threshold in custom_safety_settings.items():
safety_settings.append({"category": category, "threshold": threshold})
else:
from ..config.providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
for category in safety_categories:
safety_settings.append(
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
)
safety_settings.append({"category": category, "threshold": threshold})
return safety_settings if safety_settings else None

View File

@ -1,12 +1,12 @@
"""
OpenAI API 适配器
支持 OpenAIDeepSeek 和其他 OpenAI 兼容的 API 服务
支持 OpenAIDeepSeek智谱AI 和其他 OpenAI 兼容的 API 服务
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
from .base import OpenAICompatAdapter
if TYPE_CHECKING:
from ..service import LLMModel
@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):
@property
def supported_api_types(self) -> list[str]:
return ["openai", "deepseek", "general_openai_compat"]
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
def get_chat_endpoint(self) -> str:
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""返回聊天完成端点"""
if model.api_type == "ark":
return "/api/v3/chat/completions"
if model.api_type == "zhipu":
return "/api/paas/v4/chat/completions"
return "/v1/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回嵌入端点"""
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""根据API类型返回嵌入端点"""
if model.api_type == "zhipu":
return "/v4/embeddings"
return "/v1/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -1,57 +0,0 @@
"""
智谱 AI API 适配器
支持智谱 AI GLM 系列模型使用 OpenAI 兼容的接口格式
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
if TYPE_CHECKING:
from ..service import LLMModel
class ZhipuAdapter(OpenAICompatAdapter):
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
@property
def api_type(self) -> str:
return "zhipu"
@property
def supported_api_types(self) -> list[str]:
return ["zhipu"]
def get_chat_endpoint(self) -> str:
"""返回智谱AI聊天完成端点"""
return "/api/paas/v4/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回智谱AI嵌入端点"""
return "/v4/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - 智谱AI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -2,6 +2,7 @@
LLM 服务的高级 API 接口
"""
import copy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@ -14,6 +15,7 @@ from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -56,6 +58,7 @@ class AIConfig:
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
def __post_init__(self):
"""初始化后从配置中读取默认值"""
@ -81,7 +84,7 @@ class AI:
"""
初始化AI服务
Args:
参数:
config: AI 配置.
history: 可选的初始对话历史.
"""
@ -93,16 +96,65 @@ class AI:
self.history = []
logger.info("AI session history cleared.")
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
净化用于存入历史记录的消息
将非文本的多模态内容部分替换为文本占位符以避免重复处理
"""
if not isinstance(message.content, list):
return message
sanitized_message = copy.deepcopy(message)
content_list = sanitized_message.content
if not isinstance(content_list, list):
return sanitized_message
new_content_parts: list[LLMContentPart] = []
has_multimodal_content = False
for part in content_list:
if isinstance(part, LLMContentPart) and part.type == "text":
new_content_parts.append(part)
else:
has_multimodal_content = True
if has_multimodal_content:
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
text_part_found = False
for part in new_content_parts:
if part.type == "text":
part.text = f"{placeholder} {part.text or ''}".strip()
text_part_found = True
break
if not text_part_found:
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
sanitized_message.content = new_content_parts
return sanitized_message
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
preserve_media_in_history: bool | None = None,
**kwargs: Any,
) -> str:
"""
进行一次聊天对话
此方法会自动使用和更新会话内的历史记录
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留用于深度多轮媒体分析
- False: 不保留替换为占位符提高效率
- None (默认): 使用AI实例配置的默认值
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
current_message: LLMMessage
if isinstance(message, str):
@ -127,7 +179,20 @@ class AI:
final_messages, model, "聊天失败", kwargs
)
self.history.append(current_message)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(LLMMessage.assistant_text_response(response.text))
return response.text
@ -140,7 +205,18 @@ class AI:
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行"""
"""
代码执行
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典包含textcode_executions和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
@ -168,7 +244,18 @@ class AI:
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索 - 支持多模态输入"""
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典包含textsourcesqueries和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
@ -217,63 +304,69 @@ class AI:
async def analyze(
self,
message: UniMessage,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
) -> LLMResponse:
"""
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫
这是处理复杂互动的主要方法
参数:
message: 要分析的消息内容支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
返回:
LLMResponse: 模型的完整响应结果
"""
content_parts = await unimsg_to_llm_parts(message)
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
else:
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools = None
if tools:
llm_tools = []
for tool_dict in tools:
if isinstance(tool_dict, dict):
if "name" in tool_dict and "description" in tool_dict:
llm_tool = LLMTool(
type="function",
function={
"name": tool_dict["name"],
"description": tool_dict["description"],
"parameters": tool_dict.get("parameters", {}),
},
)
llm_tools.append(llm_tool)
else:
llm_tools.append(LLMTool(**tool_dict))
else:
llm_tools.append(tool_dict)
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode == "auto":
tool_choice = "auto"
elif mode == "any":
tool_choice = "any"
elif mode == "none":
tool_choice = "none"
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
final_messages,
@ -284,9 +377,7 @@ class AI:
tool_choice=tool_choice,
)
if response.tool_calls:
return response
return response.text
return response
async def _execute_generation(
self,
@ -298,7 +389,7 @@ class AI:
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
"""通用的生成执行方法,封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
@ -311,7 +402,9 @@ class AI:
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages, tools=llm_tools, tool_choice=tool_choice
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
@ -380,7 +473,18 @@ class AI:
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
@ -420,7 +524,17 @@ async def chat(
model: ModelName = None,
**kwargs: Any,
) -> str:
"""聊天对话便捷函数"""
"""
聊天对话便捷函数
参数:
message: 用户输入的消息
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
ai = AI()
return await ai.chat(message, model=model, **kwargs)
@ -432,7 +546,18 @@ async def code(
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行便捷函数"""
"""
代码执行便捷函数
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典
"""
ai = AI()
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
@ -444,45 +569,56 @@ async def search(
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索便捷函数"""
"""
信息搜索便捷函数
参数:
query: 搜索查询内容
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
ai = AI()
return await ai.search(query, model=model, instruction=instruction, **kwargs)
async def analyze(
message: UniMessage,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""内容分析便捷函数"""
"""
内容分析便捷函数
参数:
message: 要分析的消息内容
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
ai = AI()
return await ai.analyze(
message,
instruction=instruction,
model=model,
tools=tools,
use_tools=use_tools,
tool_config=tool_config,
**kwargs,
)
async def analyze_with_images(
text: str,
images: list[str | Path | bytes] | str | Path | bytes,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""图片分析便捷函数"""
message = create_multimodal_message(text=text, images=images)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def analyze_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
@ -493,7 +629,21 @@ async def analyze_multimodal(
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""多模态分析便捷函数"""
"""
多模态分析便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 分析指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
@ -510,7 +660,21 @@ async def search_multimodal(
model: ModelName = None,
**kwargs: Any,
) -> dict[str, Any]:
"""多模态搜索便捷函数"""
"""
多模态搜索便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 搜索指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
@ -525,6 +689,101 @@ async def embed(
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""文本嵌入便捷函数"""
"""
文本嵌入便捷函数
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
ai = AI()
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
async def pipeline_chat(
message: UniMessage | str | list[LLMContentPart],
model_chain: list[ModelName],
*,
initial_instruction: str = "",
final_instruction: str = "",
**kwargs: Any,
) -> LLMResponse:
"""
AI模型链式调用前一个模型的输出作为下一个模型的输入
参数:
message: 初始输入消息支持多模态
model_chain: 模型名称列表
initial_instruction: 第一个模型的系统指令
final_instruction: 最后一个模型的系统指令
**kwargs: 传递给模型实例的其他参数
返回:
LLMResponse: 最后一个模型的响应结果
"""
if not model_chain:
raise ValueError("模型链`model_chain`不能为空。")
current_content: str | list[LLMContentPart]
if isinstance(message, str):
current_content = message
elif isinstance(message, list):
current_content = message
else:
current_content = await unimsg_to_llm_parts(message)
final_response: LLMResponse | None = None
for i, model_name in enumerate(model_chain):
if not model_name:
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
is_first_step = i == 0
is_last_step = i == len(model_chain) - 1
messages_for_step: list[LLMMessage] = []
instruction_for_step = ""
if is_first_step and initial_instruction:
instruction_for_step = initial_instruction
elif is_last_step and final_instruction:
instruction_for_step = final_instruction
if instruction_for_step:
messages_for_step.append(LLMMessage.system(instruction_for_step))
messages_for_step.append(LLMMessage.user(current_content))
logger.info(
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
f"使用模型 '{model_name}' 进行处理..."
)
try:
async with await get_model_instance(model_name, **kwargs) as model:
response = await model.generate_response(messages_for_step)
final_response = response
current_content = response.text.strip()
if not current_content and not is_last_step:
logger.warning(
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
)
break
except Exception as e:
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
raise LLMException(
f"流水线在模型 '{model_name}' 处执行失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
if final_response is None:
raise LLMException(
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
)
return final_response

View File

@ -14,6 +14,8 @@ from .generation import (
from .presets import CommonOverrides
from .providers import (
LLMConfig,
ToolConfig,
get_gemini_safety_threshold,
get_llm_config,
register_llm_configs,
set_default_model,
@ -25,8 +27,10 @@ __all__ = [
"LLMConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"ToolConfig",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_gemini_safety_threshold",
"get_llm_config",
"register_llm_configs",
"set_default_model",

View File

@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
params["temperature"] = self.temperature
if self.max_tokens is not None:
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
params["maxOutputTokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if self.top_k is not None:
params["topK"] = self.top_k
if self.top_p is not None:
@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = {"type": "json_object"}
logger.debug(f"{api_type} 启用 JSON 对象输出模式")
elif api_type in ["gemini", "gemini_native"]:
elif api_type == "gemini":
params["responseMimeType"] = "application/json"
if self.response_schema:
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if (
self.response_format != ResponseFormat.JSON
and self.response_mime_type is not None
@ -214,7 +214,7 @@ def apply_api_specific_mappings(
"""应用API特定的参数映射"""
mapped_params = params.copy()
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if "max_tokens" in mapped_params:
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
if "top_k" in mapped_params:

View File

@ -71,14 +71,17 @@ class CommonOverrides:
@staticmethod
def gemini_safe() -> LLMGenerationConfig:
"""Gemini 安全模式:严格安全设置"""
"""Gemini 安全模式:使用配置的安全设置"""
from .providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
return LLMGenerationConfig(
temperature=0.5,
safety_settings={
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HARASSMENT": threshold,
"HARM_CATEGORY_HATE_SPEECH": threshold,
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
},
)

View File

@ -4,15 +4,33 @@ LLM 提供商配置管理
负责注册和管理 AI 服务提供商的配置项
"""
from functools import lru_cache
import json
import sys
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ..types.models import ModelDetail, ProviderConfig
class ToolConfig(BaseModel):
"""MCP类型工具的配置定义"""
type: str = "mcp"
name: str = Field(..., description="工具的唯一名称标识")
description: str | None = Field(None, description="工具功能的描述")
mcp_config: dict[str, Any] | BaseModel = Field(
..., description="MCP服务器的特定配置"
)
AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
mcp_tools: list[ToolConfig] = Field(
default_factory=list, description="配置可用的外部MCP工具"
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
return [
{
"name": "DeepSeek",
"api_key": "sk-******",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://api.deepseek.com",
"api_type": "openai",
"models": [
@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
},
],
},
{
"name": "ARK",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://ark.cn-beijing.volces.com",
"api_type": "ark",
"models": [
{"model_name": "deepseek-r1-250528"},
{"model_name": "doubao-seed-1-6-250615"},
{"model_name": "doubao-seed-1-6-flash-250615"},
{"model_name": "doubao-seed-1-6-thinking-250615"},
],
},
{
"name": "siliconflow",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://api.siliconflow.cn",
"api_type": "openai",
"models": [
{"model_name": "deepseek-ai/DeepSeek-V3"},
],
},
{
"name": "GLM",
"api_key": "",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://open.bigmodel.cn",
"api_type": "zhipu",
"models": [
@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
"api_type": "gemini",
"models": [
{"model_name": "gemini-2.0-flash"},
{"model_name": "gemini-2.5-flash-preview-05-20"},
{"model_name": "gemini-2.5-flash"},
{"model_name": "gemini-2.5-pro"},
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
],
},
]
def get_default_mcp_tools() -> dict[str, Any]:
"""
获取默认的MCP工具配置用于在文件不存在时创建
包含了 baidu-map, Context7, sequential-thinking.
"""
return {
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
"description": "百度地图工具,提供地理编码、路线规划等功能。",
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
},
"Context7": {
"command": "npx",
"args": ["-y", "@upstash/context7-mcp@latest"],
"description": "Upstash 提供的上下文管理和记忆工具。",
},
}
}
def register_llm_configs():
"""注册 LLM 服务的配置项"""
logger.info("注册 LLM 服务的配置项")
@ -214,6 +285,19 @@ def register_llm_configs():
help="LLM服务请求重试的基础延迟时间",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"gemini_safety_threshold",
"BLOCK_MEDIUM_AND_ABOVE",
help=(
"Gemini 安全过滤阈值 "
"(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
"BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
"BLOCK_ONLY_HIGH: 只阻止高级别, "
"BLOCK_NONE: 不阻止)"
),
type=str,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
@ -225,24 +309,111 @@ def register_llm_configs():
)
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例
返回:
LLMConfig: LLM 配置实例
"""
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
ai_config = get_ai_config()
llm_data_path = DATA_PATH / "llm"
mcp_tools_path = llm_data_path / "mcp_tools.json"
mcp_tools_list = []
mcp_servers_dict = {}
if not mcp_tools_path.exists():
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
llm_data_path.mkdir(parents=True, exist_ok=True)
default_mcp_config = get_default_mcp_tools()
try:
with mcp_tools_path.open("w", encoding="utf-8") as f:
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
except Exception as e:
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
mcp_servers_dict = {}
else:
try:
with mcp_tools_path.open("r", encoding="utf-8") as f:
mcp_data = json.load(f)
mcp_servers_dict = mcp_data.get("mcpServers", {})
if not isinstance(mcp_servers_dict, dict):
logger.warning(
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
f"将使用空配置。"
)
mcp_servers_dict = {}
except json.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
except Exception as e:
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
mcp_servers_dict = {}
if sys.platform == "win32":
logger.debug("检测到Windows平台正在调整MCP工具的npx命令...")
for name, config in mcp_servers_dict.items():
if isinstance(config, dict) and config.get("command") == "npx":
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
original_args = config.get("args", [])
config["command"] = "cmd"
config["args"] = ["/c", "npx", *original_args]
if mcp_servers_dict:
mcp_tools_list = [
{
"name": name,
"type": "mcp",
"description": config.get("description", f"MCP tool for {name}"),
"mcp_config": config,
}
for name, config in mcp_servers_dict.items()
if isinstance(config, dict)
]
from ..tools.registry import tool_registry
for tool_dict in mcp_tools_list:
if isinstance(tool_dict, dict):
tool_name = tool_dict.get("name")
if not tool_name:
continue
config_model = tool_registry.get_mcp_config_model(tool_name)
if not config_model:
logger.debug(
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
f"将跳过特定配置验证,直接使用原始配置字典。"
)
continue
mcp_config_data = tool_dict.get("mcp_config", {})
try:
parsed_mcp_config = parse_as(config_model, mcp_config_data)
tool_dict["mcp_config"] = parsed_mcp_config
except Exception as e:
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
"timeout": ai_config.get("timeout", 180),
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
"mcp_tools": mcp_tools_list,
}
return LLMConfig(**config_data)
return parse_as(LLMConfig, config_data)
def get_gemini_safety_threshold() -> str:
"""获取 Gemini 安全过滤阈值配置
返回:
str: 安全过滤阈值
"""
ai_config = get_ai_config()
return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")
def validate_llm_config() -> tuple[bool, list[str]]:
@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
logger.info("默认模型已清除")
return True
@PriorityLifecycle.on_startup(priority=10)
async def _init_llm_config_on_startup():
"""
在服务启动时主动调用一次 get_llm_config
以触发必要的初始化操作例如创建默认的 mcp_tools.json 文件
"""
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
try:
get_llm_config()
logger.info("LLM 配置初始化完成。")
except Exception as e:
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)

View File

@ -49,12 +49,36 @@ class LLMHttpClient:
max_keepalive_connections=self.config.max_keepalive_connections,
)
timeout = httpx.Timeout(self.config.timeout)
client_kwargs = {}
if self.config.proxy:
try:
version_parts = httpx.__version__.split(".")
major = int(
"".join(c for c in version_parts[0] if c.isdigit())
)
minor = (
int("".join(c for c in version_parts[1] if c.isdigit()))
if len(version_parts) > 1
else 0
)
if (major, minor) >= (0, 28):
client_kwargs["proxy"] = self.config.proxy
else:
client_kwargs["proxies"] = self.config.proxy
except (ValueError, IndexError):
client_kwargs["proxies"] = self.config.proxy
logger.warning(
f"无法解析 httpx 版本 '{httpx.__version__}'"
"LLM模块将默认使用旧版 'proxies' 参数语法。"
)
self._client = httpx.AsyncClient(
headers=headers,
limits=limits,
timeout=timeout,
proxies=self.config.proxy,
follow_redirects=True,
**client_kwargs,
)
if self._client is None:
raise LLMException(
@ -156,7 +180,16 @@ async def create_llm_http_client(
timeout: int = 180,
proxy: str | None = None,
) -> LLMHttpClient:
"""创建LLM HTTP客户端"""
"""
创建LLM HTTP客户端
参数:
timeout: 超时时间
proxy: 代理服务器地址
返回:
LLMHttpClient: HTTP客户端实例
"""
config = HttpClientConfig(timeout=timeout, proxy=proxy)
return LLMHttpClient(config)
@ -185,7 +218,20 @@ async def with_smart_retry(
provider_name: str | None = None,
**kwargs: Any,
) -> Any:
"""智能重试装饰器 - 支持Key轮询和错误分类"""
"""
智能重试装饰器 - 支持Key轮询和错误分类
参数:
func: 要重试的异步函数
*args: 传递给函数的位置参数
retry_config: 重试配置
key_store: API密钥状态存储
provider_name: 提供商名称
**kwargs: 传递给函数的关键字参数
返回:
Any: 函数执行结果
"""
config = retry_config or RetryConfig()
last_exception: Exception | None = None
failed_keys: set[str] = set()
@ -294,7 +340,17 @@ class KeyStatusStore:
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""获取下一个可用的API密钥轮询策略"""
"""
获取下一个可用的API密钥轮询策略
参数:
provider_name: 提供商名称
api_keys: API密钥列表
exclude_keys: 要排除的密钥集合
返回:
str | None: 可用的API密钥如果没有可用密钥则返回None
"""
if not api_keys:
return None
@ -338,7 +394,13 @@ class KeyStatusStore:
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
async def record_failure(self, api_key: str, status_code: int | None):
"""记录失败使用"""
"""
记录失败使用
参数:
api_key: API密钥
status_code: HTTP状态码
"""
key_id = self._get_key_id(api_key)
async with self._lock:
if status_code in [401, 403]:
@ -356,7 +418,15 @@ class KeyStatusStore:
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""获取密钥使用统计"""
"""
获取密钥使用统计
参数:
api_keys: API密钥列表
返回:
dict[str, dict]: 密钥统计信息字典
"""
stats = {}
async with self._lock:
for key in api_keys:

View File

@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
from .types.capabilities import get_model_capabilities
DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
def get_configured_providers() -> list[ProviderConfig]:
"""从配置中获取Provider列表 - 简化版本"""
"""从配置中获取Provider列表 - 简化和修正版本"""
ai_config = get_ai_config()
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers_raw, list):
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers, list):
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
for i, item in enumerate(providers_raw):
if not isinstance(item, dict):
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
continue
try:
if not item.get("name"):
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
continue
if not item.get("api_key"):
logger.warning(
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
)
continue
if "api_type" not in item or not item["api_type"]:
provider_name = item.get("name", "").lower()
if "glm" in provider_name or "zhipu" in provider_name:
item["api_type"] = "zhipu"
elif "gemini" in provider_name or "google" in provider_name:
item["api_type"] = "gemini"
else:
item["api_type"] = "openai"
if "api_base" not in item or not item["api_base"]:
api_type = item.get("api_type")
if api_type:
default_api_base = get_default_api_base_for_type(api_type)
if default_api_base:
item["api_base"] = default_api_base
if "models" not in item:
item["models"] = [{"model_name": item.get("name", "default")}]
provider_conf = ProviderConfig(**item)
valid_providers.append(provider_conf)
except Exception as e:
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
for i, item in enumerate(providers):
if isinstance(item, ProviderConfig):
if not item.api_base:
default_api_base = get_default_api_base_for_type(item.api_type)
if default_api_base:
item.api_base = default_api_base
valid_providers.append(item)
else:
logger.warning(
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
f"已跳过。实际类型: {type(item)}"
)
return valid_providers
@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
def find_model_config(
provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""在配置中查找指定的 Provider 和 ModelDetail
"""
在配置中查找指定的 Provider ModelDetail
Args:
参数:
provider_name: 提供商名称
model_name: 模型名称
Returns:
找到的 (ProviderConfig, ModelDetail) 元组未找到则返回 None
返回:
tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组未找到则返回 None
"""
providers = get_configured_providers()
@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis
def list_model_identifiers() -> dict[str, list[str]]:
"""列出所有模型的可用标识符
"""
列出所有模型的可用标识符
Returns:
字典键为模型的完整名称值为该模型的所有可用标识符列表
返回:
dict[str, list[str]]: 字典键为模型的完整名称值为该模型的所有可用标识符列表
"""
providers = get_configured_providers()
result = {}
@ -248,7 +224,16 @@ async def get_model_instance(
provider_model_name: str | None = None,
override_config: dict[str, Any] | None = None,
) -> LLMModel:
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
"""
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
参数:
provider_model_name: 模型名称格式为 'ProviderName/ModelName'
override_config: 覆盖配置字典
返回:
LLMModel: 模型实例
"""
cache_key = _make_cache_key(provider_model_name, override_config)
cached_model = _get_cached_model(cache_key)
if cached_model:
@ -292,6 +277,10 @@ async def get_model_instance(
provider_config_found, model_detail_found = config_tuple_found
capabilities = get_model_capabilities(model_detail_found.model_name)
model_detail_found.is_embedding_model = capabilities.is_embedding_model
ai_config = get_ai_config()
global_proxy_setting = ai_config.get(PROXY_KEY)
default_timeout = (
@ -322,6 +311,7 @@ async def get_model_instance(
model_detail=model_detail_found,
key_store=key_store,
http_client=shared_http_client,
capabilities=capabilities,
)
if override_config:
@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:
def set_global_default_model_name(provider_model_name: str | None) -> bool:
"""设置全局默认模型名称"""
"""
设置全局默认模型名称
参数:
provider_model_name: 模型名称格式为 'ProviderName/ModelName'
返回:
bool: 设置是否成功
"""
if provider_model_name:
prov_name, mod_name = parse_provider_model_string(provider_model_name)
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
async def get_key_usage_stats() -> dict[str, Any]:
"""获取所有Provider的Key使用统计"""
"""
获取所有Provider的Key使用统计
返回:
dict[str, Any]: 包含所有Provider的Key使用统计信息
"""
providers = get_configured_providers()
stats = {}
@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
"""重置指定Provider的Key状态"""
"""
重置指定Provider的Key状态
参数:
provider_name: 提供商名称
api_key: 要重置的特定API密钥如果为None则重置所有密钥
返回:
bool: 重置是否成功
"""
providers = get_configured_providers()
target_provider = None

View File

@ -6,11 +6,13 @@ LLM 模型实现类
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
import json
from typing import Any
from zhenxun.services.log import logger
from .adapters.base import RequestData
from .config import LLMGenerationConfig
from .config.providers import get_ai_config
from .core import (
@ -30,6 +32,8 @@ from .types import (
ModelDetail,
ProviderConfig,
)
from .types.capabilities import ModelCapabilities, ModelModality
from .utils import _sanitize_request_body_for_logging
class LLMModelBase(ABC):
@ -42,7 +46,17 @@ class LLMModelBase(ABC):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
"""
生成文本
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
pass
@abstractmethod
@ -54,7 +68,19 @@ class LLMModelBase(ABC):
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应"""
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
pass
@abstractmethod
@ -64,7 +90,17 @@ class LLMModelBase(ABC):
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
pass
@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
model_detail: ModelDetail,
key_store: KeyStatusStore,
http_client: LLMHttpClient,
capabilities: ModelCapabilities,
config_override: LLMGenerationConfig | None = None,
):
self.provider_config = provider_config
self.model_detail = model_detail
self.key_store = key_store
self.http_client: LLMHttpClient = http_client
self.capabilities = capabilities
self._generation_config = config_override
self.provider_name = provider_config.name
@ -99,6 +137,34 @@ class LLMModel(LLMModelBase):
self._is_closed = False
def can_process_images(self) -> bool:
"""检查模型是否支持图片作为输入。"""
return ModelModality.IMAGE in self.capabilities.input_modalities
def can_process_video(self) -> bool:
"""检查模型是否支持视频作为输入。"""
return ModelModality.VIDEO in self.capabilities.input_modalities
def can_process_audio(self) -> bool:
"""检查模型是否支持音频作为输入。"""
return ModelModality.AUDIO in self.capabilities.input_modalities
def can_generate_images(self) -> bool:
"""检查模型是否支持生成图片。"""
return ModelModality.IMAGE in self.capabilities.output_modalities
def can_generate_audio(self) -> bool:
"""检查模型是否支持生成音频 (TTS)。"""
return ModelModality.AUDIO in self.capabilities.output_modalities
def can_use_tools(self) -> bool:
"""检查模型是否支持工具调用/函数调用。"""
return self.capabilities.supports_tool_calling
def is_embedding_model(self) -> bool:
"""检查这是否是一个嵌入模型。"""
return self.capabilities.is_embedding_model
async def _get_http_client(self) -> LLMHttpClient:
"""获取HTTP客户端"""
if self.http_client.is_closed:
@ -135,24 +201,54 @@ class LLMModel(LLMModelBase):
return selected_key
async def _execute_embedding_request(
async def _perform_api_call(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
prepare_request_func: Callable[[str], Awaitable["RequestData"]],
parse_response_func: Callable[[dict[str, Any]], Any],
http_client: "LLMHttpClient",
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
log_context: str = "API",
) -> Any:
"""
执行API调用的通用核心方法
该方法封装了以下通用逻辑:
1. 选择API密钥
2. 准备和记录请求
3. 发送HTTP POST请求
4. 处理HTTP错误和API特定错误
5. 记录密钥使用状态
6. 解析成功的响应
参数:
prepare_request_func: 准备请求的函数
parse_response_func: 解析响应的函数
http_client: HTTP客户端
failed_keys: 失败的密钥集合
log_context: 日志上下文
返回:
Any: 解析后的响应数据
"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
request_data = await prepare_request_func(api_key)
logger.info(
f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
f"[{log_context}]"
)
logger.debug(f"📡 请求URL: {request_data.url}")
masked_key = (
f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
)
logger.debug(f"🔑 API密钥: {masked_key}")
logger.debug(f"📋 请求头: {dict(request_data.headers)}")
sanitized_body = _sanitize_request_body_for_logging(request_data.body)
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
logger.debug(f"📦 请求体: {request_body_str}")
http_response = await http_client.post(
request_data.url,
@ -160,121 +256,16 @@ class LLMModel(LLMModelBase):
json=request_data.body,
)
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
error_code = LLMErrorCode.API_REQUEST_FAILED
if http_response.status_code in [401, 403]:
error_code = LLMErrorCode.API_KEY_INVALID
elif http_response.status_code == 429:
error_code = LLMErrorCode.API_RATE_LIMITED
raise LLMException(
f"HTTP嵌入请求失败: {http_response.status_code}",
code=error_code,
details={
"status_code": http_response.status_code,
"response": error_text,
"api_key": api_key,
},
)
try:
response_json = http_response.json()
adapter.validate_embedding_response(response_json)
embeddings = adapter.parse_embedding_response(response_json)
except Exception as e:
logger.error(f"解析嵌入响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API嵌入响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return embeddings
except LLMException:
raise
except Exception as e:
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成嵌入失败: {e}",
code=LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools_dict,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools_dict,
tool_choice=tool_choice,
)
http_response = await http_client.post(
request_data.url,
headers=request_data.headers,
json=request_data.body,
)
logger.debug(f"📥 响应状态码: {http_response.status_code}")
logger.debug(f"📄 响应头: {dict(http_response.headers)}")
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP请求失败: {http_response.status_code} - {error_text}"
f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
f"[{log_context}]"
)
logger.debug(f"💥 完整错误响应: {error_text}")
await self.key_store.record_failure(api_key, http_response.status_code)
@ -299,69 +290,165 @@ class LLMModel(LLMModelBase):
try:
response_json = http_response.json()
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
llm_response = LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
response_json_str = json.dumps(
response_json, ensure_ascii=False, indent=2
)
logger.debug(f"📋 响应JSON: {response_json_str}")
parsed_data = parse_response_func(response_json)
except Exception as e:
logger.error(f"解析响应失败: {e}", e=e)
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API响应失败: {e}",
f"解析API {log_context} 响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return llm_response
logger.debug(f"✅ API密钥使用成功: {masked_key}")
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
return parsed_data
except LLMException:
raise
except Exception as e:
logger.error(f"生成响应时发生未预期错误: {e}", e=e)
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
logger.error(error_log_msg, e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成响应失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
error_log_msg,
code=LLMErrorCode.GENERATION_FAILED
if log_context == "Generation"
else LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_embedding_request(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
async def prepare_request(api_key: str) -> RequestData:
return adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
)
def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
adapter.validate_embedding_response(response_json)
return adapter.parse_embedding_response(response_json)
return await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Embedding",
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
async def prepare_request(api_key: str) -> RequestData:
return await adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools,
tool_choice=tool_choice,
)
def parse_response(response_json: dict[str, Any]) -> LLMResponse:
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
return LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
)
return await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Generation",
)
async def close(self):
"""
标记模型实例的当前使用周期结束
@ -400,7 +487,17 @@ class LLMModel(LLMModelBase):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本 - 通过 generate_response 实现"""
"""
生成文本 - 通过 generate_response 实现
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
self._check_not_closed()
messages: list[LLMMessage] = []
@ -439,11 +536,21 @@ class LLMModel(LLMModelBase):
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
max_tool_iterations: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应 - 实现完整的工具调用循环"""
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
@ -468,109 +575,43 @@ class LLMModel(LLMModelBase):
merged_dict.update(config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
tools_dict: list[dict[str, Any]] | None = None
if tools:
tools_dict = []
for tool in tools:
if hasattr(tool, "model_dump"):
model_dump_func = getattr(tool, "model_dump")
tools_dict.append(model_dump_func(exclude_none=True))
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
try:
tools_dict.append(dict(tool))
except (TypeError, ValueError):
logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
http_client = await self._get_http_client()
current_messages = list(messages)
for iteration in range(max_tool_iterations):
logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
async with AsyncExitStack() as stack:
activated_tools = []
if tools:
for tool in tools:
if tool.type == "mcp" and callable(tool.mcp_session):
func_obj = getattr(tool.mcp_session, "func", None)
tool_name = (
getattr(func_obj, "__name__", "unknown")
if func_obj
else "unknown"
)
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
active_session = await stack.enter_async_context(
tool.mcp_session()
)
activated_tools.append(
LLMTool.from_mcp_session(
session=active_session, annotations=tool.annotations
)
)
else:
activated_tools.append(tool)
llm_response = await self._execute_with_smart_retry(
adapter,
current_messages,
messages,
final_request_config,
tools_dict if iteration == 0 else None,
tool_choice if iteration == 0 else None,
activated_tools if activated_tools else None,
tool_choice,
http_client,
)
response_tool_calls = llm_response.tool_calls or []
if not response_tool_calls or not tool_executor:
logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
return llm_response
logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
assistant_message_content = llm_response.text if llm_response.text else ""
current_messages.append(
LLMMessage.assistant_tool_calls(
content=assistant_message_content, tool_calls=response_tool_calls
)
)
tool_response_messages: list[LLMMessage] = []
for tool_call in response_tool_calls:
tool_name = tool_call.function.name
try:
tool_args_dict = json.loads(tool_call.function.arguments)
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
tool_result = await tool_executor(tool_name, tool_args_dict)
logger.debug(
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result=tool_result,
)
)
except json.JSONDecodeError as e:
logger.error(
f"工具 '{tool_name}' 参数JSON解析失败: "
f"{tool_call.function.arguments}, 错误: {e}"
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Argument JSON parsing failed",
"details": str(e),
},
)
)
except Exception as e:
logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Tool execution failed",
"details": str(e),
},
)
)
current_messages.extend(tool_response_messages)
logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
raise LLMException(
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
code=LLMErrorCode.GENERATION_FAILED,
details={
"iterations": max_tool_iterations,
"last_messages": current_messages[-2:],
},
)
return llm_response
async def generate_embeddings(
self,
@ -578,7 +619,17 @@ class LLMModel(LLMModelBase):
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
self._check_not_closed()
if not texts:
return []

View File

@ -0,0 +1,7 @@
"""
工具模块导出
"""
from .registry import tool_registry
__all__ = ["tool_registry"]

View File

@ -0,0 +1,181 @@
"""
工具注册表
负责加载管理和实例化来自配置的工具
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import TYPE_CHECKING
from pydantic import BaseModel
from zhenxun.services.log import logger
from ..types import LLMTool
if TYPE_CHECKING:
from ..config.providers import ToolConfig
from ..types.protocols import MCPCompatible
class ToolRegistry:
"""工具注册表,用于管理和实例化配置的工具。"""
def __init__(self):
self._function_tools: dict[str, LLMTool] = {}
self._mcp_config_models: dict[str, type[BaseModel]] = {}
if TYPE_CHECKING:
self._mcp_factories: dict[
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
] = {}
else:
self._mcp_factories: dict[str, Callable] = {}
self._tool_configs: dict[str, "ToolConfig"] | None = None
self._tool_cache: dict[str, "LLMTool"] = {}
def _load_configs_if_needed(self):
"""如果尚未加载则从主配置中加载MCP工具定义。"""
if self._tool_configs is None:
logger.debug("首次访问正在加载MCP工具配置...")
from ..config.providers import get_llm_config
llm_config = get_llm_config()
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
def function_tool(
self,
name: str,
description: str,
parameters: dict,
required: list[str] | None = None,
):
"""
装饰器在代码中注册一个简单的无状态的函数工具
参数:
name: 工具的唯一名称
description: 工具功能的描述
parameters: OpenAPI格式的函数参数schema的properties部分
required: 必需的参数列表
"""
def decorator(func: Callable):
if name in self._function_tools or name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的工具: {name}")
tool_definition = LLMTool.create(
name=name,
description=description,
parameters=parameters,
required=required,
)
self._function_tools[name] = tool_definition
logger.info(f"已在代码中注册函数工具: '{name}'")
tool_definition.annotations = tool_definition.annotations or {}
tool_definition.annotations["executable"] = func
return func
return decorator
def mcp_tool(self, name: str, config_model: type[BaseModel]):
"""
装饰器注册一个MCP工具及其配置模型
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
config_model: 一个Pydantic模型用于定义和验证该工具的 `mcp_config`
"""
def decorator(factory_func: Callable):
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory_func
self._mcp_config_models[name] = config_model
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
return factory_func
return decorator
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
"""根据名称获取MCP工具的配置模型。"""
return self._mcp_config_models.get(name)
def register_mcp_factory(
self,
name: str,
factory: Callable,
):
"""
在代码中注册一个 MCP 会话工厂将其与配置中的工具名称关联
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
factory: 一个返回异步生成器的可调用对象会话工厂
"""
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory
logger.info(f"已注册 MCP 会话工厂: '{name}'")
def get_tool(self, name: str) -> "LLMTool":
"""
根据名称获取一个 LLMTool 定义
对于MCP工具返回的 LLMTool 实例包含一个可调用的会话工厂
而不是一个已激活的会话
"""
logger.debug(f"🔍 请求获取工具定义: {name}")
if name in self._tool_cache:
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
return self._tool_cache[name]
if name in self._function_tools:
logger.debug(f"🛠️ 获取函数工具定义: {name}")
tool = self._function_tools[name]
self._tool_cache[name] = tool
return tool
self._load_configs_if_needed()
if self._tool_configs is None or name not in self._tool_configs:
known_tools = list(self._function_tools.keys()) + (
list(self._tool_configs.keys()) if self._tool_configs else []
)
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
logger.debug(f"📋 可用工具定义列表: {known_tools}")
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
config = self._tool_configs[name]
tool: "LLMTool"
if name not in self._mcp_factories:
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
available_factories = list(self._mcp_factories.keys())
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
raise ValueError(
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
)
logger.info(f"🔧 创建MCP工具定义: {name}")
factory = self._mcp_factories[name]
typed_mcp_config = config.mcp_config
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
configured_factory = partial(factory, config=typed_mcp_config)
tool = LLMTool.from_mcp_session(session=configured_factory)
self._tool_cache[name] = tool
logger.debug(f"💾 MCP工具定义已缓存: {name}")
return tool
def get_tools(self, names: list[str]) -> list["LLMTool"]:
"""根据名称列表获取多个 LLMTool 实例。"""
return [self.get_tool(name) for name in names]
tool_registry = ToolRegistry()

View File

@ -4,6 +4,7 @@ LLM 类型定义模块
统一导出所有核心类型协议和异常定义
"""
from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
from .content import (
LLMContentPart,
LLMMessage,
@ -26,6 +27,7 @@ from .models import (
ToolMetadata,
UsageInfo,
)
from .protocols import MCPCompatible
__all__ = [
"EmbeddingTaskType",
@ -41,8 +43,11 @@ __all__ = [
"LLMTool",
"LLMToolCall",
"LLMToolFunction",
"MCPCompatible",
"ModelCapabilities",
"ModelDetail",
"ModelInfo",
"ModelModality",
"ModelName",
"ModelProvider",
"ProviderConfig",
@ -50,5 +55,6 @@ __all__ = [
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"get_model_capabilities",
"get_user_friendly_error_message",
]

View File

@ -0,0 +1,128 @@
"""
LLM 模型能力定义模块
定义模型的输入输出模态工具调用支持等核心能力
"""
from enum import Enum
import fnmatch
from pydantic import BaseModel, Field
class ModelModality(str, Enum):
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
EMBEDDING = "embedding"
class ModelCapabilities(BaseModel):
"""定义一个模型的核心、稳定能力。"""
input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
supports_tool_calling: bool = False
is_embedding_model: bool = False
STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
GEMINI_CAPABILITIES = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
MODEL_ALIAS_MAPPING: dict[str, str] = {
"deepseek-v3*": "deepseek-chat",
"deepseek-ai/DeepSeek-V3": "deepseek-chat",
"deepseek-r1*": "deepseek-reasoner",
}
MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
"gemini-*-tts": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.AUDIO},
),
"gemini-*-native-audio-*": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
supports_tool_calling=True,
),
"gemini-2.0-flash-preview-image-generation": ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
supports_tool_calling=True,
),
"gemini-embedding-exp": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.EMBEDDING},
is_embedding_model=True,
),
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
"gemini-2.0-flash*": GEMINI_CAPABILITIES,
"gemini-1.5-flash*": GEMINI_CAPABILITIES,
"GLM-4V-Flash": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"GLM-4V-Plus*": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
}
def get_model_capabilities(model_name: str) -> ModelCapabilities:
"""
从注册表获取模型能力支持别名映射和通配符匹配
查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
"""
canonical_name = model_name
for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
if fnmatch.fnmatch(model_name, alias_pattern):
canonical_name = c_name
break
if canonical_name in MODEL_CAPABILITIES_REGISTRY:
return MODEL_CAPABILITIES_REGISTRY[canonical_name]
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
return capabilities
return ModelCapabilities()

View File

@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
return None
def convert_for_api(self, api_type: str) -> dict[str, Any]:
async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
"""根据API类型转换多模态内容格式"""
from zhenxun.utils.http_utils import AsyncHttpx
if self.type == "text":
if api_type == "openai":
return {"type": "text", "text": self.text}
@ -248,20 +250,23 @@ class LLMContentPart(BaseModel):
mime_type, data = base64_info
return {"inlineData": {"mimeType": mime_type, "data": data}}
else:
# 如果无法解析 Base64 数据,抛出异常
raise ValueError(
f"无法解析Base64图像数据: {self.image_source[:50]}..."
)
else:
logger.warning(
f"Gemini API需要Base64格式但提供的是URL: {self.image_source}"
)
return {
"inlineData": {
"mimeType": "image/jpeg",
"data": self.image_source,
elif self.is_image_url():
logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
try:
image_bytes = await AsyncHttpx.get_content(self.image_source)
mime_type = self.mime_type or "image/jpeg"
base64_data = base64.b64encode(image_bytes).decode("utf-8")
return {
"inlineData": {"mimeType": mime_type, "data": base64_data}
}
}
except Exception as e:
logger.error(f"下载或编码URL图片失败: {e}", e=e)
raise ValueError(f"无法处理图片URL: {e}")
else:
raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
else:
return {"type": "image_url", "image_url": {"url": self.image_source}}

View File

@ -4,13 +4,25 @@ LLM 数据模型定义
包含模型信息配置工具定义和响应数据的模型类
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
if TYPE_CHECKING:
from .protocols import MCPCompatible
MCPSessionType = (
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
)
else:
MCPCompatible = object
MCPSessionType = Any
ModelName = str | None
@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
class LLMTool(BaseModel):
"""LLM 工具定义(支持 MCP 风格)"""
model_config = {"arbitrary_types_allowed": True}
type: str = "function"
function: dict[str, Any]
function: dict[str, Any] | None = None
mcp_session: MCPSessionType = None
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
def model_post_init(self, /, __context: Any) -> None:
"""验证工具定义的有效性"""
_ = __context
if self.type == "function" and self.function is None:
raise ValueError("函数类型的工具必须包含 'function' 字段。")
if self.type == "mcp" and self.mcp_session is None:
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
@classmethod
def create(
cls,
@ -111,7 +134,7 @@ class LLMTool(BaseModel):
required: list[str] | None = None,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""创建工具"""
"""创建函数工具"""
function_def = {
"name": name,
"description": description,
@ -123,6 +146,15 @@ class LLMTool(BaseModel):
}
return cls(type="function", function=function_def, annotations=annotations)
@classmethod
def from_mcp_session(
cls,
session: Any,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""从 MCP 会话创建工具"""
return cls(type="mcp", mcp_session=session, annotations=annotations)
class LLMCodeExecution(BaseModel):
"""代码执行结果"""

View File

@ -0,0 +1,24 @@
"""
LLM 模块的协议定义
"""
from typing import Any, Protocol
class MCPCompatible(Protocol):
"""
一个协议定义了与LLM模块兼容的MCP会话对象应具备的行为
任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible
"""
def to_api_tool(self, api_type: str) -> dict[str, Any]:
"""
将此MCP会话转换为特定LLM提供商API所需的工具格式
参数:
api_type: 目标API的类型 (例如 'gemini', 'openai')
返回:
dict[str, Any]: 一个字典代表可以在API请求中使用的工具定义
"""
...

View File

@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
"""
import base64
import copy
from pathlib import Path
from nonebot.adapters import Message as PlatformMessage
from nonebot_plugin_alconna.uniseg import (
At,
File,
@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
)
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from .types import LLMContentPart
@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
"""
UniMessage 实例转换为一个 LLMContentPart 列表
这是处理多模态输入的核心转换逻辑
参数:
message: 要转换的UniMessage实例
返回:
list[LLMContentPart]: 转换后的内容部分列表
"""
parts: list[LLMContentPart] = []
for seg in message:
@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
if seg.path:
part = await LLMContentPart.from_path(seg.path)
elif seg.url:
logger.warning(
f"直接使用 URL 的 {type(seg).__name__} 段,"
f"API 可能不支持: {seg.url}"
)
part = LLMContentPart.text_part(
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
)
elif hasattr(seg, "raw") and seg.raw:
try:
logger.debug(f"检测到媒体URL开始下载: {seg.url}")
media_bytes = await AsyncHttpx.get_content(seg.url)
new_seg = copy.copy(seg)
new_seg.raw = media_bytes
seg = new_seg
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
except Exception as e:
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
part = LLMContentPart.text_part(
f"[下载媒体失败: {seg.name or seg.url}]"
)
if part:
parts.append(part)
continue
if hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", None)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
@ -127,50 +147,19 @@ def create_multimodal_message(
audio_mimetypes: list[str] | str | None = None,
) -> UniMessage:
"""
创建多模态消息的便捷函数方便第三方调用
创建多模态消息的便捷函数
Args:
参数:
text: 文本内容
images: 图片数据支持路径字节数据或URL
videos: 视频数据支持路径字节数据或URL
audios: 音频数据支持路径字节数据或URL
image_mimetypes: 图片MIME类型当images为bytes时需要指定
video_mimetypes: 视频MIME类型当videos为bytes时需要指定
audio_mimetypes: 音频MIME类型当audios为bytes时需要指定
videos: 视频数据
audios: 音频数据
image_mimetypes: 图片MIME类型bytes数据时需要指定
video_mimetypes: 视频MIME类型bytes数据时需要指定
audio_mimetypes: 音频MIME类型bytes数据时需要指定
Returns:
返回:
UniMessage: 构建好的多模态消息
Examples:
# 纯文本
msg = create_multimodal_message("请分析这段文字")
# 文本 + 单张图片(路径)
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
# 文本 + 多张图片
msg = create_multimodal_message(
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
)
# 文本 + 图片字节数据
msg = create_multimodal_message(
"分析", images=image_data, image_mimetypes="image/jpeg"
)
# 文本 + 视频
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
# 文本 + 音频
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
# 混合多模态
msg = create_multimodal_message(
"分析这些媒体文件",
images="/path/to/image.jpg",
videos="/path/to/video.mp4",
audios="/path/to/audio.wav"
)
"""
message = UniMessage()
@ -196,7 +185,7 @@ def _add_media_to_message(
media_class: type,
default_mimetype: str,
) -> None:
"""添加媒体文件到 UniMessage 的辅助函数"""
"""添加媒体文件到 UniMessage"""
if not isinstance(media_items, list):
media_items = [media_items]
@ -216,3 +205,80 @@ def _add_media_to_message(
elif isinstance(item, bytes):
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
message.append(media_class(raw=item, mimetype=mimetype))
def message_to_unimessage(message: PlatformMessage) -> UniMessage:
"""
将平台特定的 Message 对象转换为通用的 UniMessage
主要用于处理引用消息等未被自动转换的消息体
参数:
message: 平台特定的Message对象
返回:
UniMessage: 转换后的通用消息对象
"""
uni_segments = []
for seg in message:
if seg.type == "text":
uni_segments.append(Text(seg.data.get("text", "")))
elif seg.type == "image":
uni_segments.append(Image(url=seg.data.get("url")))
elif seg.type == "record":
uni_segments.append(Voice(url=seg.data.get("url")))
elif seg.type == "video":
uni_segments.append(Video(url=seg.data.get("url")))
elif seg.type == "at":
uni_segments.append(At("user", str(seg.data.get("qq", ""))))
else:
logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
return UniMessage(uni_segments)
def _sanitize_request_body_for_logging(body: dict) -> dict:
"""
净化请求体用于日志记录移除大数据字段并添加摘要信息
参数:
body: 原始请求体字典
返回:
dict: 净化后的请求体字典
"""
try:
sanitized_body = copy.deepcopy(body)
if "contents" in sanitized_body and isinstance(
sanitized_body["contents"], list
):
for content_item in sanitized_body["contents"]:
if "parts" in content_item and isinstance(content_item["parts"], list):
media_summary = []
new_parts = []
for part in content_item["parts"]:
if "inlineData" in part and isinstance(
part["inlineData"], dict
):
data = part["inlineData"].get("data")
if isinstance(data, str):
mime_type = part["inlineData"].get(
"mimeType", "unknown"
)
media_summary.append(f"{mime_type} ({len(data)} chars)")
continue
new_parts.append(part)
if media_summary:
summary_text = (
f"[多模态内容: {len(media_summary)}个文件 - "
f"{', '.join(media_summary)}]"
)
new_parts.insert(0, {"text": summary_text})
content_item["parts"] = new_parts
return sanitized_body
except Exception as e:
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
return body

View File

@ -0,0 +1,12 @@
"""
定时调度服务模块
提供一个统一的持久化的定时任务管理器供所有插件使用
"""
from .lifecycle import _load_schedules_from_db
from .service import scheduler_manager
_ = _load_schedules_from_db
__all__ = ["scheduler_manager"]

View File

@ -0,0 +1,102 @@
"""
引擎适配层 (Adapter)
封装所有对具体调度器引擎 (APScheduler) 的操作
使上层服务与调度器实现解耦
"""
from nonebot_plugin_apscheduler import scheduler
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from .job import _execute_job
JOB_PREFIX = "zhenxun_schedule_"
class APSchedulerAdapter:
"""封装对 APScheduler 的操作"""
@staticmethod
def _get_job_id(schedule_id: int) -> str:
"""生成 APScheduler 的 Job ID"""
return f"{JOB_PREFIX}{schedule_id}"
@staticmethod
def add_or_reschedule_job(schedule: ScheduleInfo):
"""根据 ScheduleInfo 添加或重新调度一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule.id)
if not isinstance(schedule.trigger_config, dict):
logger.error(
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
f"{type(schedule.trigger_config)}"
)
return
job = scheduler.get_job(job_id)
if job:
scheduler.reschedule_job(
job_id, trigger=schedule.trigger_type, **schedule.trigger_config
)
logger.debug(f"已更新APScheduler任务: {job_id}")
else:
scheduler.add_job(
_execute_job,
trigger=schedule.trigger_type,
id=job_id,
misfire_grace_time=300,
args=[schedule.id],
**schedule.trigger_config,
)
logger.debug(f"已添加新的APScheduler任务: {job_id}")
@staticmethod
def remove_job(schedule_id: int):
"""移除一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
try:
scheduler.remove_job(job_id)
logger.debug(f"已从APScheduler中移除任务: {job_id}")
except Exception:
pass
@staticmethod
def pause_job(schedule_id: int):
"""暂停一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
try:
scheduler.pause_job(job_id)
except Exception:
pass
@staticmethod
def resume_job(schedule_id: int):
"""恢复一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
try:
scheduler.resume_job(job_id)
except Exception:
import asyncio
from .repository import ScheduleRepository
async def _re_add_job():
schedule = await ScheduleRepository.get_by_id(schedule_id)
if schedule:
APSchedulerAdapter.add_or_reschedule_job(schedule)
asyncio.create_task(_re_add_job()) # noqa: RUF006
@staticmethod
def get_job_status(schedule_id: int) -> dict:
"""获取 APScheduler Job 的状态"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
job = scheduler.get_job(job_id)
return {
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
if job and job.next_run_time
else "N/A",
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
}

View File

@ -0,0 +1,192 @@
"""
定时任务的执行逻辑
包含被 APScheduler 实际调度的函数以及处理不同目标单个所有群组的执行策略
"""
import asyncio
import copy
import inspect
import random
import nonebot
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.platform import PlatformUtils
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
async def _execute_job(schedule_id: int):
"""
APScheduler 调度的入口函数
根据 schedule_id 处理特定任务所有群组任务或全局任务
"""
from .repository import ScheduleRepository
from .service import scheduler_manager
scheduler_manager._running_tasks.add(schedule_id)
try:
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
return
plugin_name = schedule.plugin_name
task_meta = scheduler_manager._registered_tasks.get(plugin_name)
if not task_meta:
logger.error(
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
)
schedule.is_enabled = False
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
from .adapter import APSchedulerAdapter
APSchedulerAdapter.remove_job(schedule.id)
return
try:
if schedule.bot_id:
bot = nonebot.get_bot(schedule.bot_id)
else:
bot = nonebot.get_bot()
logger.debug(
f"任务 {schedule_id} 未关联特定Bot使用默认Bot {bot.self_id}"
)
except KeyError:
logger.warning(
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
f"不在线,本次执行跳过。"
)
return
except ValueError:
logger.warning(f"当前没有Bot在线定时任务 {schedule_id} 跳过。")
return
if schedule.group_id == scheduler_manager.ALL_GROUPS:
await _execute_for_all_groups(schedule, task_meta, bot)
else:
await _execute_for_single_target(schedule, task_meta, bot)
finally:
scheduler_manager._running_tasks.discard(schedule_id)
async def _execute_for_all_groups(schedule: ScheduleInfo, task_meta: dict, bot):
"""为所有群组执行任务,并处理优先级覆盖。"""
plugin_name = schedule.plugin_name
concurrency_limit = Config.get_config(
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
)
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
logger.warning(
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
)
concurrency_limit = 5
logger.info(
f"开始执行针对 [所有群组] 的任务 "
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})"
f"并发限制: {concurrency_limit}"
)
all_gids = set()
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
all_gids.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
return
specific_tasks_gids = set(
await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=list(all_gids)
).values_list("group_id", flat=True)
)
semaphore = asyncio.Semaphore(concurrency_limit)
async def worker(gid: str):
"""使用 Semaphore 包装单个群组的任务执行"""
await asyncio.sleep(random.uniform(0, 59))
async with semaphore:
temp_schedule = copy.deepcopy(schedule)
temp_schedule.group_id = gid
await _execute_for_single_target(temp_schedule, task_meta, bot)
await asyncio.sleep(random.uniform(0.1, 0.5))
tasks_to_run = []
for gid in all_gids:
if gid in specific_tasks_gids:
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
continue
tasks_to_run.append(worker(gid))
if tasks_to_run:
await asyncio.gather(*tasks_to_run)
async def _execute_for_single_target(schedule: ScheduleInfo, task_meta: dict, bot):
"""为单个目标(具体群组或全局)执行任务。"""
plugin_name = schedule.plugin_name
group_id = schedule.group_id
try:
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
if is_blocked:
target_desc = f"{group_id}" if group_id else "全局"
logger.info(
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
"因功能被禁用而跳过执行。"
)
return
max_retries = Config.get_config("SchedulerManager", "JOB_MAX_RETRIES", 2)
retry_delay = Config.get_config("SchedulerManager", "JOB_RETRY_DELAY", 10)
@Retry.simple(
stop_max_attempt=max_retries + 1,
wait_fixed_seconds=retry_delay,
log_name=f"定时任务执行:{schedule.plugin_name}",
)
async def _execute_task_with_retry():
task_func = task_meta["func"]
job_kwargs = schedule.job_kwargs
if not isinstance(job_kwargs, dict):
logger.error(
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
)
return
sig = inspect.signature(task_func)
if "bot" in sig.parameters:
job_kwargs["bot"] = bot
await task_func(group_id, **job_kwargs)
try:
logger.info(
f"插件 '{schedule.plugin_name}' 开始为目标 "
f"[{schedule.group_id or '全局'}] 执行定时任务 (ID: {schedule.id})。"
)
await _execute_task_with_retry()
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {schedule.plugin_name}, "
f"目标: {schedule.group_id or '全局'}) 在所有重试后最终失败",
e=e,
)
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
f"目标: {group_id or '全局'}) 时发生异常",
e=e,
)

View File

@ -0,0 +1,62 @@
"""
定时任务的生命周期管理
包含在机器人启动时加载和调度数据库中保存的任务的逻辑
"""
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .adapter import APSchedulerAdapter
from .repository import ScheduleRepository
from .service import scheduler_manager
@PriorityLifecycle.on_startup(priority=90)
async def _load_schedules_from_db():
"""在服务启动时从数据库加载并调度所有任务。"""
logger.info("正在从数据库加载并调度所有定时任务...")
schedules = await ScheduleRepository.get_all_enabled()
count = 0
for schedule in schedules:
if schedule.plugin_name in scheduler_manager._registered_tasks:
APSchedulerAdapter.add_or_reschedule_job(schedule)
count += 1
else:
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
logger.info(f"数据库定时任务加载完成,共成功加载 {count} 个任务。")
logger.info("正在检查并注册声明式默认任务...")
declared_count = 0
for task_info in scheduler_manager._declared_tasks:
plugin_name = task_info["plugin_name"]
group_id = task_info["group_id"]
bot_id = task_info["bot_id"]
query_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
"bot_id": bot_id,
}
exists = await ScheduleRepository.exists(**query_kwargs)
if not exists:
logger.info(f"为插件 '{plugin_name}' 注册新的默认定时任务...")
schedule = await scheduler_manager.add_schedule(
plugin_name=plugin_name,
group_id=group_id,
trigger_type=task_info["trigger_type"],
trigger_config=task_info["trigger_config"],
job_kwargs=task_info["job_kwargs"],
bot_id=bot_id,
)
if schedule:
declared_count += 1
logger.debug(f"默认任务 '{plugin_name}' 注册成功 (ID: {schedule.id})")
else:
logger.error(f"默认任务 '{plugin_name}' 注册失败")
else:
logger.debug(f"插件 '{plugin_name}' 的默认任务已存在于数据库中,跳过注册。")
if declared_count > 0:
logger.info(f"声明式任务检查完成,新注册了 {declared_count} 个默认任务。")

View File

@ -0,0 +1,79 @@
"""
数据持久层 (Repository)
封装所有对 ScheduleInfo 模型的数据库操作将数据访问逻辑与业务逻辑分离
"""
from typing import Any
from tortoise.queryset import QuerySet
from zhenxun.models.schedule_info import ScheduleInfo
class ScheduleRepository:
"""封装 ScheduleInfo 模型的数据库操作"""
@staticmethod
async def get_by_id(schedule_id: int) -> ScheduleInfo | None:
"""通过ID获取任务"""
return await ScheduleInfo.get_or_none(id=schedule_id)
@staticmethod
async def get_all_enabled() -> list[ScheduleInfo]:
"""获取所有启用的任务"""
return await ScheduleInfo.filter(is_enabled=True).all()
@staticmethod
async def get_all(plugin_name: str | None = None) -> list[ScheduleInfo]:
"""获取所有任务,可按插件名过滤"""
if plugin_name:
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
return await ScheduleInfo.all()
@staticmethod
async def save(schedule: ScheduleInfo, update_fields: list[str] | None = None):
"""保存任务"""
await schedule.save(update_fields=update_fields)
@staticmethod
async def exists(**kwargs: Any) -> bool:
"""检查任务是否存在"""
return await ScheduleInfo.exists(**kwargs)
@staticmethod
async def get_by_plugin_and_group(
plugin_name: str, group_ids: list[str]
) -> list[ScheduleInfo]:
"""根据插件和群组ID列表获取任务"""
return await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=group_ids
).all()
@staticmethod
async def update_or_create(
defaults: dict, **kwargs: Any
) -> tuple[ScheduleInfo, bool]:
"""更新或创建任务"""
return await ScheduleInfo.update_or_create(defaults=defaults, **kwargs)
@staticmethod
async def query_schedules(**filters: Any) -> list[ScheduleInfo]:
"""
根据任意条件查询任务列表
参数:
**filters: 过滤条件 group_id="123", plugin_name="abc"
返回:
list[ScheduleInfo]: 任务列表
"""
cleaned_filters = {k: v for k, v in filters.items() if v is not None}
if not cleaned_filters:
return await ScheduleInfo.all()
return await ScheduleInfo.filter(**cleaned_filters).all()
@staticmethod
def filter(**kwargs: Any) -> QuerySet[ScheduleInfo]:
"""提供一个通用的过滤查询接口供Targeter使用"""
return ScheduleInfo.filter(**kwargs)

View File

@ -0,0 +1,448 @@
"""
服务层 (Service)
定义 SchedulerManager 类作为定时任务服务的公共 API 入口
它负责编排业务逻辑并调用 Repository Adapter 层来完成具体工作
"""
from collections.abc import Callable, Coroutine
from datetime import datetime
from typing import Any, ClassVar
import nonebot
from pydantic import BaseModel
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from .adapter import APSchedulerAdapter
from .job import _execute_job
from .repository import ScheduleRepository
from .targeter import ScheduleTargeter
class SchedulerManager:
ALL_GROUPS: ClassVar[str] = "__ALL_GROUPS__"
_registered_tasks: ClassVar[
dict[str, dict[str, Callable | type[BaseModel] | None]]
] = {}
_declared_tasks: ClassVar[list[dict[str, Any]]] = []
_running_tasks: ClassVar[set] = set()
def target(self, **filters: Any) -> ScheduleTargeter:
"""
创建目标选择器以执行批量操作
参数:
**filters: 过滤条件支持plugin_namegroup_idbot_id等字段
返回:
ScheduleTargeter: 目标选择器对象可用于批量操作
"""
return ScheduleTargeter(self, **filters)
def task(
self,
trigger: str,
group_id: str | None = None,
bot_id: str | None = None,
**trigger_kwargs,
):
"""
声明式定时任务装饰器
参数:
trigger: 触发器类型'cron''interval'
group_id: 目标群组IDNone表示全局任务
bot_id: 目标Bot IDNone表示使用默认Bot
**trigger_kwargs: 触发器配置参数
返回:
Callable: 装饰器函数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
try:
plugin = nonebot.get_plugin_by_module_name(func.__module__)
if not plugin:
raise ValueError(f"函数 {func.__name__} 不在任何已加载的插件中。")
plugin_name = plugin.name
task_declaration = {
"plugin_name": plugin_name,
"func": func,
"group_id": group_id,
"bot_id": bot_id,
"trigger_type": trigger,
"trigger_config": trigger_kwargs,
"job_kwargs": {},
}
self._declared_tasks.append(task_declaration)
logger.debug(
f"发现声明式定时任务 '{plugin_name}',将在启动时进行注册。"
)
except Exception as e:
logger.error(f"注册声明式定时任务失败: {func.__name__}, 错误: {e}")
return func
return decorator
def register(
self, plugin_name: str, params_model: type[BaseModel] | None = None
) -> Callable:
"""
注册可调度的任务函数
参数:
plugin_name: 插件名称用于标识任务
params_model: 参数验证模型继承自BaseModel的类
返回:
Callable: 装饰器函数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
if plugin_name in self._registered_tasks:
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
self._registered_tasks[plugin_name] = {
"func": func,
"model": params_model,
}
model_name = params_model.__name__ if params_model else ""
logger.debug(
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
)
return func
return decorator
def get_registered_plugins(self) -> list[str]:
"""
获取已注册插件列表
返回:
list[str]: 已注册的插件名称列表
"""
return list(self._registered_tasks.keys())
async def add_daily_task(
self,
plugin_name: str,
group_id: str | None,
hour: int,
minute: int,
second: int = 0,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> "ScheduleInfo | None":
"""
添加每日定时任务
参数:
plugin_name: 插件名称
group_id: 目标群组IDNone表示全局任务
hour: 执行小时0-23
minute: 执行分钟0-59
second: 执行秒数0-59默认为0
job_kwargs: 任务参数字典
bot_id: 目标Bot IDNone表示使用默认Bot
返回:
ScheduleInfo | None: 创建的任务信息失败时返回None
"""
trigger_config = {
"hour": hour,
"minute": minute,
"second": second,
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
}
return await self.add_schedule(
plugin_name,
group_id,
"cron",
trigger_config,
job_kwargs=job_kwargs,
bot_id=bot_id,
)
async def add_interval_task(
self,
plugin_name: str,
group_id: str | None,
*,
weeks: int = 0,
days: int = 0,
hours: int = 0,
minutes: int = 0,
seconds: int = 0,
start_date: str | datetime | None = None,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> "ScheduleInfo | None":
"""添加间隔性定时任务"""
trigger_config = {
"weeks": weeks,
"days": days,
"hours": hours,
"minutes": minutes,
"seconds": seconds,
"start_date": start_date,
}
trigger_config = {k: v for k, v in trigger_config.items() if v}
return await self.add_schedule(
plugin_name,
group_id,
"interval",
trigger_config,
job_kwargs=job_kwargs,
bot_id=bot_id,
)
def _validate_and_prepare_kwargs(
self, plugin_name: str, job_kwargs: dict | None
) -> tuple[bool, str | dict]:
"""验证并准备任务参数,应用默认值"""
from pydantic import ValidationError
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
return False, f"插件 '{plugin_name}' 未注册。"
params_model = task_meta.get("model")
job_kwargs = job_kwargs if job_kwargs is not None else {}
if not params_model:
if job_kwargs:
logger.warning(
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
)
return True, job_kwargs
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
validated_model = model_validate(job_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
return True, model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
return False, msg
async def add_schedule(
self,
plugin_name: str,
group_id: str | None,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> "ScheduleInfo | None":
"""
添加定时任务通用方法
参数:
plugin_name: 插件名称
group_id: 目标群组IDNone表示全局任务
trigger_type: 触发器类型'cron''interval'
trigger_config: 触发器配置字典
job_kwargs: 任务参数字典
bot_id: 目标Bot IDNone表示使用默认Bot
返回:
ScheduleInfo | None: 创建的任务信息失败时返回None
"""
if plugin_name not in self._registered_tasks:
logger.error(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
return None
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
if not is_valid:
logger.error(f"任务参数校验失败: {result}")
return None
search_kwargs = {"plugin_name": plugin_name, "group_id": group_id}
if bot_id and group_id == self.ALL_GROUPS:
search_kwargs["bot_id"] = bot_id
else:
search_kwargs["bot_id__isnull"] = True
defaults = {
"trigger_type": trigger_type,
"trigger_config": trigger_config,
"job_kwargs": result,
"is_enabled": True,
}
schedule, created = await ScheduleRepository.update_or_create(
defaults, **search_kwargs
)
APSchedulerAdapter.add_or_reschedule_job(schedule)
action = "设置" if created else "更新"
logger.info(
f"已成功{action}插件 '{plugin_name}' 的定时任务 (ID: {schedule.id})。"
)
return schedule
async def get_all_schedules(self) -> list[ScheduleInfo]:
"""
获取所有定时任务信息
"""
return await self.get_schedules()
async def get_schedules(
self,
plugin_name: str | None = None,
group_id: str | None = None,
bot_id: str | None = None,
) -> list[ScheduleInfo]:
"""
根据条件获取定时任务列表
参数:
plugin_name: 插件名称None表示不限制
group_id: 群组IDNone表示不限制
bot_id: Bot IDNone表示不限制
返回:
list[ScheduleInfo]: 符合条件的任务信息列表
"""
return await ScheduleRepository.query_schedules(
plugin_name=plugin_name, group_id=group_id, bot_id=bot_id
)
async def update_schedule(
self,
schedule_id: int,
trigger_type: str | None = None,
trigger_config: dict | None = None,
job_kwargs: dict | None = None,
) -> tuple[bool, str]:
"""
更新定时任务配置
参数:
schedule_id: 任务ID
trigger_type: 新的触发器类型None表示不更新
trigger_config: 新的触发器配置None表示不更新
job_kwargs: 新的任务参数None表示不更新
返回:
tuple[bool, str]: (是否成功, 结果消息)
"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的任务。"
updated_fields = []
if trigger_config is not None:
schedule.trigger_config = trigger_config
updated_fields.append("trigger_config")
if trigger_type is not None and schedule.trigger_type != trigger_type:
schedule.trigger_type = trigger_type
updated_fields.append("trigger_type")
if job_kwargs is not None:
existing_kwargs = (
schedule.job_kwargs.copy()
if isinstance(schedule.job_kwargs, dict)
else {}
)
existing_kwargs.update(job_kwargs)
is_valid, result = self._validate_and_prepare_kwargs(
schedule.plugin_name, existing_kwargs
)
if not is_valid:
return False, str(result)
assert isinstance(result, dict), "验证成功时 result 应该是字典类型"
schedule.job_kwargs = result
updated_fields.append("job_kwargs")
if not updated_fields:
return True, "没有任何需要更新的配置。"
await ScheduleRepository.save(schedule, update_fields=updated_fields)
APSchedulerAdapter.add_or_reschedule_job(schedule)
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
async def get_schedule_status(self, schedule_id: int) -> dict | None:
"""获取定时任务的详细状态信息"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule:
return None
status_from_scheduler = APSchedulerAdapter.get_job_status(schedule.id)
status_text = (
"运行中"
if schedule_id in self._running_tasks
else ("启用" if schedule.is_enabled else "暂停")
)
return {
"id": schedule.id,
"bot_id": schedule.bot_id,
"plugin_name": schedule.plugin_name,
"group_id": schedule.group_id,
"is_enabled": status_text,
"trigger_type": schedule.trigger_type,
"trigger_config": schedule.trigger_config,
"job_kwargs": schedule.job_kwargs,
**status_from_scheduler,
}
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""暂停指定的定时任务"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
return False, "任务不存在或已暂停。"
schedule.is_enabled = False
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
APSchedulerAdapter.pause_job(schedule_id)
return True, f"已暂停任务 (ID: {schedule.id})。"
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""恢复指定的定时任务"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule or schedule.is_enabled:
return False, "任务不存在或已启用。"
schedule.is_enabled = True
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
APSchedulerAdapter.resume_job(schedule_id)
return True, f"已恢复任务 (ID: {schedule.id})。"
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
"""立即手动触发指定的定时任务"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
if schedule.plugin_name not in self._registered_tasks:
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
try:
await _execute_job(schedule.id)
return True, f"已手动触发任务 (ID: {schedule.id})。"
except Exception as e:
logger.error(f"手动触发任务失败: {e}")
return False, f"手动触发任务失败: {e}"
scheduler_manager = SchedulerManager()

View File

@ -0,0 +1,109 @@
"""
目标选择器 (Targeter)
提供链式API用于构建和执行对多个定时任务的批量操作
"""
from collections.abc import Callable, Coroutine
from typing import Any
from .adapter import APSchedulerAdapter
from .repository import ScheduleRepository
class ScheduleTargeter:
"""
一个用于构建和执行定时任务批量操作的目标选择器
"""
def __init__(self, manager: Any, **filters: Any):
"""初始化目标选择器"""
self._manager = manager
self._filters = {k: v for k, v in filters.items() if v is not None}
async def _get_schedules(self):
"""根据过滤器获取任务"""
query = ScheduleRepository.filter(**self._filters)
return await query.all()
def _generate_target_description(self) -> str:
"""根据过滤条件生成友好的目标描述"""
if "id" in self._filters:
return f"任务 ID {self._filters['id']}"
parts = []
if "group_id" in self._filters:
group_id = self._filters["group_id"]
if group_id == self._manager.ALL_GROUPS:
parts.append("所有群组中")
else:
parts.append(f"{group_id}")
if "plugin_name" in self._filters:
parts.append(f"插件 '{self._filters['plugin_name']}'")
if not parts:
return "所有"
return "".join(parts)
async def _apply_operation(
self,
operation_func: Callable[[int], Coroutine[Any, Any, tuple[bool, str]]],
operation_name: str,
) -> tuple[int, str]:
"""通用的操作应用模板"""
schedules = await self._get_schedules()
if not schedules:
target_desc = self._generate_target_description()
return 0, f"没有找到{target_desc}可供{operation_name}的任务。"
success_count = 0
for schedule in schedules:
success, _ = await operation_func(schedule.id)
if success:
success_count += 1
target_desc = self._generate_target_description()
return (
success_count,
f"成功{operation_name}{target_desc} {success_count} 个任务。",
)
async def pause(self) -> tuple[int, str]:
"""
暂停匹配的定时任务
返回:
tuple[int, str]: (成功暂停的任务数量, 操作结果消息)
"""
return await self._apply_operation(self._manager.pause_schedule, "暂停")
async def resume(self) -> tuple[int, str]:
"""
恢复匹配的定时任务
返回:
tuple[int, str]: (成功恢复的任务数量, 操作结果消息)
"""
return await self._apply_operation(self._manager.resume_schedule, "恢复")
async def remove(self) -> tuple[int, str]:
"""
移除匹配的定时任务
返回:
tuple[int, str]: (成功移除的任务数量, 操作结果消息)
"""
schedules = await self._get_schedules()
if not schedules:
target_desc = self._generate_target_description()
return 0, f"没有找到{target_desc}可供移除的任务。"
for schedule in schedules:
APSchedulerAdapter.remove_job(schedule.id)
query = ScheduleRepository.filter(**self._filters)
count = await query.delete()
target_desc = self._generate_target_description()
return count, f"成功移除了{target_desc} {count} 个任务。"

View File

@ -1,91 +1,94 @@
import os
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Literal
from nonebot import get_driver
from playwright.__main__ import main
from playwright.async_api import Browser, Playwright, async_playwright
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_htmlrender import get_browser
from playwright.async_api import Page
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
driver = get_driver()
_playwright: Playwright | None = None
_browser: Browser | None = None
from zhenxun.utils.message import MessageUtils
# @driver.on_startup
# async def start_browser():
# global _playwright
# global _browser
# install()
# await check_playwright_env()
# _playwright = await async_playwright().start()
# _browser = await _playwright.chromium.launch()
class BrowserIsNone(Exception):
pass
# @driver.on_shutdown
# async def shutdown_browser():
# if _browser:
# await _browser.close()
# if _playwright:
# await _playwright.stop() # type: ignore
class AsyncPlaywright:
@classmethod
@asynccontextmanager
async def new_page(
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
) -> AsyncGenerator[Page, None]:
"""获取一个新页面
# def get_browser() -> Browser:
# if not _browser:
# raise RuntimeError("playwright is not initalized")
# return _browser
def install():
"""自动安装、更新 Chromium"""
def set_env_variables():
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = (
"https://npmmirror.com/mirrors/playwright/"
)
if BotConfig.system_proxy:
os.environ["HTTPS_PROXY"] = BotConfig.system_proxy
def restore_env_variables():
os.environ.pop("PLAYWRIGHT_DOWNLOAD_HOST", None)
if BotConfig.system_proxy:
os.environ.pop("HTTPS_PROXY", None)
if original_proxy is not None:
os.environ["HTTPS_PROXY"] = original_proxy
def try_install_chromium():
参数:
cookies: cookies
"""
browser = await get_browser()
ctx = await browser.new_context(**kwargs)
if cookies:
if isinstance(cookies, dict):
cookies = [cookies]
await ctx.add_cookies(cookies) # type: ignore
page = await ctx.new_page()
try:
sys.argv = ["", "install", "chromium"]
main()
except SystemExit as e:
return e.code == 0
return False
yield page
finally:
await page.close()
await ctx.close()
logger.info("检查 Chromium 更新")
@classmethod
async def screenshot(
cls,
url: str,
path: Path | str,
element: str | list[str],
*,
wait_time: int | None = None,
viewport_size: dict[str, int] | None = None,
wait_until: (
Literal["domcontentloaded", "load", "networkidle"] | None
) = "networkidle",
timeout: float | None = None,
type_: Literal["jpeg", "png"] | None = None,
user_agent: str | None = None,
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
**kwargs,
) -> UniMessage | None:
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
original_proxy = os.environ.get("HTTPS_PROXY")
set_env_variables()
success = try_install_chromium()
if not success:
logger.info("Chromium 更新失败,尝试从原始仓库下载,速度较慢")
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = ""
success = try_install_chromium()
restore_env_variables()
if not success:
raise RuntimeError("未知错误Chromium 下载失败")
async def check_playwright_env():
"""检查 Playwright 依赖"""
logger.info("检查 Playwright 依赖")
try:
async with async_playwright() as p:
await p.chromium.launch()
except Exception as e:
raise ImportError("加载失败Playwright 依赖不全,") from e
参数:
url: 网址
path: 存储路径
element: 元素选择
wait_time: 等待截取超时时间
viewport_size: 窗口大小
wait_until: 等待类型
timeout: 超时限制
type_: 保存类型
user_agent: user_agent
cookies: cookies
"""
if viewport_size is None:
viewport_size = {"width": 2560, "height": 1080}
if isinstance(path, str):
path = Path(path)
wait_time = wait_time * 1000 if wait_time else None
element_list = [element] if isinstance(element, str) else element
async with cls.new_page(
cookies,
viewport=viewport_size,
user_agent=user_agent,
**kwargs,
) as page:
await page.goto(url, timeout=timeout, wait_until=wait_until)
card = page
for e in element_list:
if not card:
return None
card = await card.wait_for_selector(e, timeout=wait_time)
if card:
await card.screenshot(path=path, timeout=timeout, type=type_)
return MessageUtils.build_message(path)
return None

View File

@ -1,24 +1,226 @@
from collections.abc import Callable
from functools import partial, wraps
from typing import Any, Literal
from anyio import EndOfStream
from httpx import ConnectError, HTTPStatusError, TimeoutException
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from httpx import (
ConnectError,
HTTPStatusError,
RemoteProtocolError,
StreamError,
TimeoutException,
)
from nonebot.utils import is_coroutine_callable
from tenacity import (
RetryCallState,
retry,
retry_if_exception_type,
retry_if_result,
stop_after_attempt,
wait_exponential,
wait_fixed,
)
from zhenxun.services.log import logger
LOG_COMMAND = "RetryDecorator"
_SENTINEL = object()
def _log_before_sleep(log_name: str | None, retry_state: RetryCallState):
"""
tenacity 重试前的日志记录回调函数
"""
func_name = retry_state.fn.__name__ if retry_state.fn else "unknown_function"
log_context = f"函数 '{func_name}'"
if log_name:
log_context = f"操作 '{log_name}' ({log_context})"
reason = ""
if retry_state.outcome:
if exc := retry_state.outcome.exception():
reason = f"触发异常: {exc.__class__.__name__}({exc})"
else:
reason = f"不满足结果条件: result={retry_state.outcome.result()}"
wait_time = (
getattr(retry_state.next_action, "sleep", 0) if retry_state.next_action else 0
)
logger.warning(
f"{log_context}{retry_state.attempt_number} 次重试... "
f"等待 {wait_time:.2f} 秒. {reason}",
LOG_COMMAND,
)
class Retry:
@staticmethod
def api(
retry_count: int = 3, wait: int = 1, exception: tuple[type[Exception], ...] = ()
def simple(
stop_max_attempt: int = 3,
wait_fixed_seconds: int = 2,
exception: tuple[type[Exception], ...] = (),
*,
log_name: str | None = None,
on_failure: Callable[[Exception], Any] | None = None,
return_on_failure: Any = _SENTINEL,
):
"""接口调用重试"""
"""
一个简单的用于通用网络请求的重试装饰器预设
参数:
stop_max_attempt: 最大重试次数
wait_fixed_seconds: 固定等待策略的等待秒数
exception: 额外需要重试的异常类型元组
log_name: 用于日志记录的操作名称
on_failure: (可选) 所有重试失败后的回调
return_on_failure: (可选) 所有重试失败后的返回值
"""
return Retry.api(
stop_max_attempt=stop_max_attempt,
wait_fixed_seconds=wait_fixed_seconds,
exception=exception,
strategy="fixed",
log_name=log_name,
on_failure=on_failure,
return_on_failure=return_on_failure,
)
@staticmethod
def download(
stop_max_attempt: int = 3,
exception: tuple[type[Exception], ...] = (),
*,
wait_exp_multiplier: int = 2,
wait_exp_max: int = 15,
log_name: str | None = None,
on_failure: Callable[[Exception], Any] | None = None,
return_on_failure: Any = _SENTINEL,
):
"""
一个适用于文件下载的重试装饰器预设使用指数退避策略
参数:
stop_max_attempt: 最大重试次数
exception: 额外需要重试的异常类型元组
wait_exp_multiplier: 指数退避的乘数
wait_exp_max: 指数退避的最大等待时间
log_name: 用于日志记录的操作名称
on_failure: (可选) 所有重试失败后的回调
return_on_failure: (可选) 所有重试失败后的返回值
"""
return Retry.api(
stop_max_attempt=stop_max_attempt,
exception=exception,
strategy="exponential",
wait_exp_multiplier=wait_exp_multiplier,
wait_exp_max=wait_exp_max,
log_name=log_name,
on_failure=on_failure,
return_on_failure=return_on_failure,
)
@staticmethod
def api(
stop_max_attempt: int = 3,
wait_fixed_seconds: int = 1,
exception: tuple[type[Exception], ...] = (),
*,
strategy: Literal["fixed", "exponential"] = "fixed",
retry_on_result: Callable[[Any], bool] | None = None,
wait_exp_multiplier: int = 1,
wait_exp_max: int = 10,
log_name: str | None = None,
on_failure: Callable[[Exception], Any] | None = None,
return_on_failure: Any = _SENTINEL,
):
"""
通用可配置的API调用重试装饰器
参数:
stop_max_attempt: 最大重试次数
wait_fixed_seconds: 固定等待策略的等待秒数
exception: 额外需要重试的异常类型元组
strategy: 重试等待策略, 'fixed' (固定) 'exponential' (指数退避)
retry_on_result: 一个回调函数接收函数返回值如果返回 True则触发重试
例如 `lambda r: r.status_code != 200`
wait_exp_multiplier: 指数退避的乘数
wait_exp_max: 指数退避的最大等待时间
log_name: 用于日志记录的操作名称方便区分不同的重试场景
on_failure: (可选) 当所有重试都失败后在抛出异常或返回默认值之前
会调用此函数并将最终的异常实例作为参数传入
return_on_failure: (可选) 如果设置了此参数当所有重试失败后
将不再抛出异常而是返回此参数指定的值
"""
base_exceptions = (
TimeoutException,
ConnectError,
HTTPStatusError,
StreamError,
RemoteProtocolError,
EndOfStream,
*exception,
)
return retry(
reraise=True,
stop=stop_after_attempt(retry_count),
wait=wait_fixed(wait),
retry=retry_if_exception_type(base_exceptions),
)
def decorator(func: Callable) -> Callable:
if strategy == "exponential":
wait_strategy = wait_exponential(
multiplier=wait_exp_multiplier, max=wait_exp_max
)
else:
wait_strategy = wait_fixed(wait_fixed_seconds)
retry_conditions = retry_if_exception_type(base_exceptions)
if retry_on_result:
retry_conditions |= retry_if_result(retry_on_result)
log_callback = partial(_log_before_sleep, log_name)
tenacity_retry_decorator = retry(
stop=stop_after_attempt(stop_max_attempt),
wait=wait_strategy,
retry=retry_conditions,
before_sleep=log_callback,
reraise=True,
)
decorated_func = tenacity_retry_decorator(func)
if return_on_failure is _SENTINEL:
return decorated_func
if is_coroutine_callable(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
return await decorated_func(*args, **kwargs)
except Exception as e:
if on_failure:
if is_coroutine_callable(on_failure):
await on_failure(e)
else:
on_failure(e)
return return_on_failure
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
return decorated_func(*args, **kwargs)
except Exception as e:
if on_failure:
if is_coroutine_callable(on_failure):
logger.error(
f"不能在同步函数 '{func.__name__}' 中调用异步的 "
f"on_failure 回调。",
LOG_COMMAND,
)
else:
on_failure(e)
return return_on_failure
return sync_wrapper
return decorator

View File

@ -64,3 +64,23 @@ class GoodsNotFound(Exception):
"""
pass
class AllURIsFailedError(Exception):
"""
当所有备用URL都尝试失败后抛出此异常
"""
def __init__(self, urls: list[str], exceptions: list[Exception]):
self.urls = urls
self.exceptions = exceptions
super().__init__(
f"All {len(urls)} URIs failed. Last exception: {exceptions[-1]}"
)
def __str__(self) -> str:
exc_info = "\n".join(
f" - {url}: {exc.__class__.__name__}({exc})"
for url, exc in zip(self.urls, self.exceptions)
)
return f"All {len(self.urls)} URIs failed:\n{exc_info}"

View File

@ -1,16 +1,15 @@
import asyncio
from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
from contextlib import asynccontextmanager
import os
from pathlib import Path
import time
from typing import Any, ClassVar, Literal, cast
from typing import Any, ClassVar, cast
import aiofiles
import httpx
from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_htmlrender import get_browser
from playwright.async_api import Page
from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Proxy, Response
import nonebot
from rich.progress import (
BarColumn,
DownloadColumn,
@ -18,13 +17,84 @@ from rich.progress import (
TextColumn,
TransferSpeedColumn,
)
import ujson as json
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.exception import AllURIsFailedError
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.user_agent import get_user_agent
CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"]
from .browser import AsyncPlaywright, BrowserIsNone # noqa: F401
_SENTINEL = object()
driver = nonebot.get_driver()
_client: AsyncClient | None = None
@PriorityLifecycle.on_startup(priority=0)
async def _():
"""
在Bot启动时初始化全局httpx客户端
"""
global _client
client_kwargs = {}
if proxy_url := BotConfig.system_proxy or None:
try:
version_parts = httpx.__version__.split(".")
major = int("".join(c for c in version_parts[0] if c.isdigit()))
minor = (
int("".join(c for c in version_parts[1] if c.isdigit()))
if len(version_parts) > 1
else 0
)
if (major, minor) >= (0, 28):
client_kwargs["proxy"] = proxy_url
else:
client_kwargs["proxies"] = proxy_url
except (ValueError, IndexError):
client_kwargs["proxy"] = proxy_url
logger.warning(
f"无法解析 httpx 版本 '{httpx.__version__}'"
"将默认使用新版 'proxy' 参数语法。"
)
_client = httpx.AsyncClient(
headers=get_user_agent(),
follow_redirects=True,
**client_kwargs,
)
logger.info("全局 httpx.AsyncClient 已启动。", "HTTPClient")
@driver.on_shutdown
async def _():
"""
在Bot关闭时关闭全局httpx客户端
"""
if _client:
await _client.aclose()
logger.info("全局 httpx.AsyncClient 已关闭。", "HTTPClient")
def get_client() -> AsyncClient:
"""
获取全局 httpx.AsyncClient 实例
"""
global _client
if not _client:
if not os.environ.get("PYTEST_CURRENT_TEST"):
raise RuntimeError("全局 httpx.AsyncClient 未初始化,请检查启动流程。")
# 在测试环境中创建临时客户端
logger.warning("在测试环境中创建临时HTTP客户端", "HTTPClient")
_client = httpx.AsyncClient(
headers=get_user_agent(),
follow_redirects=True,
)
return _client
def get_async_client(
@ -33,6 +103,10 @@ def get_async_client(
verify: bool = False,
**kwargs,
) -> httpx.AsyncClient:
"""
[向后兼容] 创建 httpx.AsyncClient 实例的工厂函数
此函数完全保留了旧版本的接口确保现有代码无需修改即可使用
"""
transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
if proxies:
http_proxy = proxies.get("http://")
@ -62,6 +136,30 @@ def get_async_client(
class AsyncHttpx:
"""
一个高级的健壮的异步HTTP客户端工具类
设计理念:
- **全局共享客户端**: 默认情况下所有请求都通过一个在应用启动时初始化的全局
`httpx.AsyncClient` 实例发出这个实例共享连接池提高了效率和性能
- **向后兼容与灵活性**: 完全兼容旧的API同时提供了两种方式来处理需要
特殊网络配置如不同代理超时的请求
1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时直接传入 `proxies`,
`timeout` 等参数将为该次请求创建一个临时的独立的客户端
2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器可以
获取一个独立的可配置的客户端用于执行一系列需要相同特殊配置的请求
- **健壮性**: 内置了自动重试多镜像URL回退fallback机制并提供了便捷的
JSON解析和文件下载方法
"""
CLIENT_KEY: ClassVar[list[str]] = [
"use_proxy",
"proxies",
"proxy",
"verify",
"headers",
]
default_proxy: ClassVar[dict[str, str] | None] = (
{
"http://": BotConfig.system_proxy,
@ -72,155 +170,346 @@ class AsyncHttpx:
)
@classmethod
@asynccontextmanager
async def _create_client(
cls,
*,
use_proxy: bool = True,
proxies: dict[str, str] | None = None,
proxy: str | None = None,
headers: dict[str, str] | None = None,
verify: bool = False,
**kwargs,
) -> AsyncGenerator[httpx.AsyncClient, None]:
"""创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。
def _prepare_temporary_client_config(cls, client_kwargs: dict) -> dict:
"""
[向后兼容] 处理旧式的客户端kwargs将其转换为get_async_client可用的配置
主要负责处理 use_proxy 标志这是为了兼容旧版本代码中使用的 use_proxy 参数
"""
final_config = client_kwargs.copy()
说明:
此方法用于内部统一创建客户端处理代理和请求头逻辑减少代码重复
use_proxy = final_config.pop("use_proxy", True)
if "proxies" not in final_config and "proxy" not in final_config:
final_config["proxies"] = cls.default_proxy if use_proxy else None
return final_config
@classmethod
def _split_kwargs(cls, kwargs: dict) -> tuple[dict, dict]:
"""[优化] 分离客户端配置和请求参数,使逻辑更清晰。"""
client_kwargs = {k: v for k, v in kwargs.items() if k in cls.CLIENT_KEY}
request_kwargs = {k: v for k, v in kwargs.items() if k not in cls.CLIENT_KEY}
return client_kwargs, request_kwargs
@classmethod
@asynccontextmanager
async def _get_active_client_context(
cls, client: AsyncClient | None = None, **kwargs
) -> AsyncGenerator[AsyncClient, None]:
"""
内部辅助方法根据 kwargs 决定并提供一个活动的 HTTP 客户端
- 如果 kwargs 中有客户端配置则创建并返回一个临时客户端
- 否则返回传入的 client 或全局客户端
- 自动处理临时客户端的关闭
"""
if kwargs:
logger.debug(f"为单次请求创建临时客户端,配置: {kwargs}")
temp_client_config = cls._prepare_temporary_client_config(kwargs)
async with get_async_client(**temp_client_config) as temp_client:
yield temp_client
else:
yield client or get_client()
@Retry.simple(log_name="内部HTTP请求")
async def _execute_request_inner(
self, client: AsyncClient, method: str, url: str, **kwargs
) -> Response:
"""
[内部] 执行单次HTTP请求的私有核心方法被重试装饰器包裹
"""
return await client.request(method, url, **kwargs)
@classmethod
async def _single_request(
cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
) -> Response:
"""
执行单次HTTP请求的私有方法内置了默认的重试逻辑
"""
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
async with cls._get_active_client_context(
client=client, **client_kwargs
) as active_client:
response = await cls()._execute_request_inner(
active_client, method, url, **request_kwargs
)
response.raise_for_status()
return response
@classmethod
async def _execute_with_fallbacks(
cls,
urls: str | list[str],
worker: Callable[..., Awaitable[Any]],
*,
client: AsyncClient | None = None,
**kwargs,
) -> Any:
"""
通用执行器按顺序尝试多个URL直到成功
参数:
use_proxy: 是否使用在类中定义的默认代理
proxies: 手动指定的代理会覆盖默认代理
proxy: 单个代理,用于兼容旧版本不再使用
headers: 需要合并到客户端的自定义请求头
verify: 是否验证 SSL 证书
**kwargs: 其他所有传递给 httpx.AsyncClient 的参数
返回:
AsyncGenerator[httpx.AsyncClient, None]: 生成器
urls: 单个URL或URL列表
worker: 一个接受单个URL和其他kwargs并执行请求的协程函数
client: 可选的HTTP客户端
**kwargs: 传递给worker的额外参数
"""
proxies_to_use = proxies or (cls.default_proxy if use_proxy else None)
url_list = [urls] if isinstance(urls, str) else urls
exceptions = []
final_headers = get_user_agent()
if headers:
final_headers.update(headers)
for i, url in enumerate(url_list):
try:
result = await worker(url, client=client, **kwargs)
if i > 0:
logger.info(
f"成功从镜像 '{url}' 获取资源 "
f"(在尝试了 {i} 个失败的镜像之后)。",
"AsyncHttpx:FallbackExecutor",
)
return result
except Exception as e:
exceptions.append(e)
if url != url_list[-1]:
logger.warning(
f"Worker '{worker.__name__}' on {url} failed, trying next. "
f"Error: {e.__class__.__name__}",
"AsyncHttpx:FallbackExecutor",
)
async with get_async_client(
proxies=proxies_to_use,
proxy=proxy,
verify=verify,
headers=final_headers,
**kwargs,
) as client:
yield client
raise AllURIsFailedError(url_list, exceptions)
@classmethod
async def get(
cls,
url: str | list[str],
*,
follow_redirects: bool = True,
check_status_code: int | None = None,
client: AsyncClient | None = None,
**kwargs,
) -> Response: # sourcery skip: use-assigned-variable
) -> Response:
"""发送 GET 请求,并返回第一个成功的响应。
说明:
本方法是 httpx.get 的高级包装增加了多链接尝试自动重试和统一的代理管理
如果提供 URL 列表它将依次尝试直到成功为止
本方法是 httpx.get 的高级包装增加了多链接尝试自动重试和统一的
客户端管理如果提供 URL 列表它将依次尝试直到成功为止
用法建议:
- **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端
- **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)`
将为本次请求创建一个独立的临时客户端
参数:
url: 单个请求 URL 或一个 URL 列表
follow_redirects: 是否跟随重定向
check_status_code: (可选) 若提供将检查响应状态码是否匹配否则抛出异常
**kwargs: 其他所有传递给 httpx.get 的参数
( `params`, `headers`, `timeout`)
client: (可选) 指定一个活动的HTTP客户端实例若提供则忽略
`**kwargs`中的客户端配置
**kwargs: 其他所有传递给 httpx.get 的参数 ( `params`, `headers`,
`timeout`)如果包含 `proxies`, `verify` 等客户端配置参数
将创建一个临时客户端
返回:
Response: Response
Response: httpx 的响应对象
Raises:
AllURIsFailedError: 当所有提供的URL都请求失败时抛出
"""
urls = [url] if isinstance(url, str) else url
last_exception = None
for current_url in urls:
try:
logger.info(f"开始获取 {current_url}..")
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
response = await client.get(current_url, **kwargs)
if check_status_code and response.status_code != check_status_code:
raise HTTPStatusError(
f"状态码错误: {response.status_code}!={check_status_code}",
request=response.request,
response=response,
)
return response
except Exception as e:
last_exception = e
if current_url != urls[-1]:
logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e)
async def worker(current_url: str, **worker_kwargs) -> Response:
logger.info(f"开始获取 {current_url}..", "AsyncHttpx:get")
response = await cls._single_request(
"GET", current_url, follow_redirects=follow_redirects, **worker_kwargs
)
if check_status_code and response.status_code != check_status_code:
raise HTTPStatusError(
f"状态码错误: {response.status_code}!={check_status_code}",
request=response.request,
response=response,
)
return response
raise last_exception or Exception("所有URL都获取失败")
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
@classmethod
async def head(cls, url: str, **kwargs) -> Response:
"""发送 HEAD 请求。
async def head(
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
) -> Response:
"""发送 HEAD 请求,并返回第一个成功的响应。"""
说明:
本方法是对 httpx.head 的封装通常用于检查资源的元信息如大小类型
async def worker(current_url: str, **worker_kwargs) -> Response:
return await cls._single_request("HEAD", current_url, **worker_kwargs)
参数:
url: 请求的 URL
**kwargs: 其他所有传递给 httpx.head 的参数
( `headers`, `timeout`, `allow_redirects`)
返回:
Response: Response
"""
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
return await client.head(url, **kwargs)
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
@classmethod
async def post(cls, url: str, **kwargs) -> Response:
"""发送 POST 请求。
async def post(
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
) -> Response:
"""发送 POST 请求,并返回第一个成功的响应。"""
说明:
本方法是对 httpx.post 的封装提供了统一的代理和客户端管理
async def worker(current_url: str, **worker_kwargs) -> Response:
return await cls._single_request("POST", current_url, **worker_kwargs)
参数:
url: 请求的 URL
**kwargs: 其他所有传递给 httpx.post 的参数
( `data`, `json`, `content` )
返回:
Response: Response
"""
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
return await client.post(url, **kwargs)
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
@classmethod
async def get_content(cls, url: str, **kwargs) -> bytes:
"""获取指定 URL 的二进制内容。
说明:
这是一个便捷方法等同于调用 get() 后再访问 .content 属性
参数:
url: 请求的 URL
**kwargs: 所有传递给 get() 方法的参数
返回:
bytes: 响应内容的二进制字节流 (bytes)
"""
res = await cls.get(url, **kwargs)
async def get_content(
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
) -> bytes:
"""获取指定 URL 的二进制内容。"""
res = await cls.get(url, client=client, **kwargs)
return res.content
@classmethod
@Retry.api(
log_name="JSON请求",
exception=(json.JSONDecodeError,),
return_on_failure=_SENTINEL,
)
async def _request_and_parse_json(
cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
) -> Any:
"""
[私有] 执行单个HTTP请求并解析JSON用于内部统一处理
"""
async with cls._get_active_client_context(
client=client, **kwargs
) as active_client:
_, request_kwargs = cls._split_kwargs(kwargs)
response = await active_client.request(method, url, **request_kwargs)
response.raise_for_status()
return response.json()
@classmethod
async def get_json(
cls,
url: str | list[str],
*,
default: Any = None,
raise_on_failure: bool = False,
client: AsyncClient | None = None,
**kwargs,
) -> Any:
"""
发送GET请求并自动解析为JSON支持重试和多链接尝试
说明:
这是一个高度便捷的方法封装了请求重试JSON解析和错误处理
它会在网络错误或JSON解析错误时自动重试
如果所有尝试都失败它会安全地返回一个默认值
参数:
url: 单个请求 URL 或一个备用 URL 列表
default: (可选) 当所有尝试都失败时返回的默认值默认为None
raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出
`AllURIsFailedError` 异常, 默认为 False.
client: (可选) 指定的HTTP客户端
**kwargs: 其他所有传递给 httpx.get 的参数
例如 `params`, `headers`, `timeout`
返回:
Any: 解析后的JSON数据或在失败时返回 `default`
Raises:
AllURIsFailedError: `raise_on_failure` True 且所有URL都请求失败时抛出
"""
async def worker(current_url: str, **worker_kwargs):
logger.debug(f"开始GET JSON: {current_url}", "AsyncHttpx:get_json")
return await cls._request_and_parse_json(
"GET", current_url, **worker_kwargs
)
try:
result = await cls._execute_with_fallbacks(
url, worker, client=client, **kwargs
)
return default if result is _SENTINEL else result
except AllURIsFailedError as e:
logger.error(f"所有URL的JSON GET均失败: {e}", "AsyncHttpx:get_json")
if raise_on_failure:
raise e
return default
@classmethod
async def post_json(
cls,
url: str | list[str],
*,
json: Any = None,
data: Any = None,
default: Any = None,
raise_on_failure: bool = False,
client: AsyncClient | None = None,
**kwargs,
) -> Any:
"""
发送POST请求并自动解析为JSON功能与 get_json 类似
参数:
url: 单个请求 URL 或一个备用 URL 列表
json: (可选) 作为请求体发送的JSON数据
data: (可选) 作为请求体发送的表单数据
default: (可选) 当所有尝试都失败时返回的默认值默认为None
raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出
AllURIsFailedError 异常, 默认为 False.
client: (可选) 指定的HTTP客户端
**kwargs: 其他所有传递给 httpx.post 的参数
返回:
Any: 解析后的JSON数据或在失败时返回 `default`
"""
if json is not None:
kwargs["json"] = json
if data is not None:
kwargs["data"] = data
async def worker(current_url: str, **worker_kwargs):
logger.debug(f"开始POST JSON: {current_url}", "AsyncHttpx:post_json")
return await cls._request_and_parse_json(
"POST", current_url, **worker_kwargs
)
try:
result = await cls._execute_with_fallbacks(
url, worker, client=client, **kwargs
)
return default if result is _SENTINEL else result
except AllURIsFailedError as e:
logger.error(f"所有URL的JSON POST均失败: {e}", "AsyncHttpx:post_json")
if raise_on_failure:
raise e
return default
@classmethod
@Retry.api(log_name="文件下载(流式)")
async def _stream_download(
cls, url: str, path: Path, *, client: AsyncClient | None = None, **kwargs
) -> None:
"""
执行单个流式下载的私有方法被重试装饰器包裹
"""
async with cls._get_active_client_context(
client=client, **kwargs
) as active_client:
async with active_client.stream("GET", url, **kwargs) as response:
response.raise_for_status()
total = int(response.headers.get("Content-Length", 0))
with Progress(
TextColumn(path.name),
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(bar_width=None),
DownloadColumn(),
TransferSpeedColumn(),
) as progress:
task_id = progress.add_task("Download", total=total)
async with aiofiles.open(path, "wb") as f:
async for chunk in response.aiter_bytes():
await f.write(chunk)
progress.update(task_id, advance=len(chunk))
@classmethod
async def download_file(
cls,
@ -228,6 +517,7 @@ class AsyncHttpx:
path: str | Path,
*,
stream: bool = False,
client: AsyncClient | None = None,
**kwargs,
) -> bool:
"""下载文件到指定路径。
@ -239,6 +529,7 @@ class AsyncHttpx:
url: 单个文件 URL 或一个备用 URL 列表
path: 文件保存的本地路径
stream: (可选) 是否使用流式下载适用于大文件默认为 False
client: (可选) 指定的HTTP客户端
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数
返回:
@ -247,49 +538,29 @@ class AsyncHttpx:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
urls = [url] if isinstance(url, str) else url
async def worker(current_url: str, **worker_kwargs) -> bool:
if not stream:
content = await cls.get_content(current_url, **worker_kwargs)
async with aiofiles.open(path, "wb") as f:
await f.write(content)
else:
await cls._stream_download(current_url, path, **worker_kwargs)
for current_url in urls:
try:
if not stream:
response = await cls.get(current_url, **kwargs)
response.raise_for_status()
async with aiofiles.open(path, "wb") as f:
await f.write(response.content)
else:
async with cls._create_client(**kwargs) as client:
stream_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["use_proxy", "proxy", "verify"]
}
async with client.stream(
"GET", current_url, **stream_kwargs
) as response:
response.raise_for_status()
total = int(response.headers.get("Content-Length", 0))
logger.info(
f"下载 {current_url} 成功 -> {path.absolute()}",
"AsyncHttpx:download",
)
return True
with Progress(
TextColumn(path.name),
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(bar_width=None),
DownloadColumn(),
TransferSpeedColumn(),
) as progress:
task_id = progress.add_task("Download", total=total)
async with aiofiles.open(path, "wb") as f:
async for chunk in response.aiter_bytes():
await f.write(chunk)
progress.update(task_id, advance=len(chunk))
logger.info(f"下载 {current_url} 成功 -> {path.absolute()}")
return True
except Exception as e:
logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}")
logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}")
return False
try:
return await cls._execute_with_fallbacks(
url, worker, client=client, **kwargs
)
except AllURIsFailedError:
logger.error(
f"所有URL下载均失败 -> {path.absolute()}", "AsyncHttpx:download"
)
return False
@classmethod
async def gather_download_file(
@ -346,7 +617,6 @@ class AsyncHttpx:
logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result)
final_results.append(False)
else:
# download_file 返回的是 bool可以直接附加
final_results.append(cast(bool, result))
return final_results
@ -395,86 +665,30 @@ class AsyncHttpx:
_results = sorted(iter(_results), key=lambda r: r["elapsed_time"])
return [result["url"] for result in _results]
class AsyncPlaywright:
@classmethod
@asynccontextmanager
async def new_page(
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
) -> AsyncGenerator[Page, None]:
"""获取一个新页面
async def temporary_client(cls, **kwargs) -> AsyncGenerator[AsyncClient, None]:
"""
创建一个临时的可配置的HTTP客户端上下文并直接返回该客户端实例
此方法返回一个标准的 `httpx.AsyncClient`它不使用全局连接池
拥有独立的配置(如代理headers超时等)并在退出上下文后自动关闭
适用于需要用一套特殊网络配置执行一系列请求的场景
用法:
async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client:
# client 是一个标准的 httpx.AsyncClient 实例
response1 = await client.get("http://some.internal.api/1")
response2 = await client.get("http://some.internal.api/2")
data = response2.json()
参数:
cookies: cookies
**kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数
例如: `proxies`, `headers`, `verify`, `timeout`,
`follow_redirects`
Yields:
httpx.AsyncClient: 一个配置好的临时的客户端实例
"""
browser = await get_browser()
ctx = await browser.new_context(**kwargs)
if cookies:
if isinstance(cookies, dict):
cookies = [cookies]
await ctx.add_cookies(cookies) # type: ignore
page = await ctx.new_page()
try:
yield page
finally:
await page.close()
await ctx.close()
@classmethod
async def screenshot(
cls,
url: str,
path: Path | str,
element: str | list[str],
*,
wait_time: int | None = None,
viewport_size: dict[str, int] | None = None,
wait_until: (
Literal["domcontentloaded", "load", "networkidle"] | None
) = "networkidle",
timeout: float | None = None,
type_: Literal["jpeg", "png"] | None = None,
user_agent: str | None = None,
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
**kwargs,
) -> UniMessage | None:
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
参数:
url: 网址
path: 存储路径
element: 元素选择
wait_time: 等待截取超时时间
viewport_size: 窗口大小
wait_until: 等待类型
timeout: 超时限制
type_: 保存类型
user_agent: user_agent
cookies: cookies
"""
if viewport_size is None:
viewport_size = {"width": 2560, "height": 1080}
if isinstance(path, str):
path = Path(path)
wait_time = wait_time * 1000 if wait_time else None
element_list = [element] if isinstance(element, str) else element
async with cls.new_page(
cookies,
viewport=viewport_size,
user_agent=user_agent,
**kwargs,
) as page:
await page.goto(url, timeout=timeout, wait_until=wait_until)
card = page
for e in element_list:
if not card:
return None
card = await card.wait_for_selector(e, timeout=wait_time)
if card:
await card.screenshot(path=path, timeout=timeout, type=type_)
return MessageUtils.build_message(path)
return None
class BrowserIsNone(Exception):
pass
async with get_async_client(**kwargs) as client:
yield client

View File

@ -1,810 +0,0 @@
import asyncio
from collections.abc import Callable, Coroutine
import copy
import inspect
import random
from typing import ClassVar
import nonebot
from nonebot import get_bots
from nonebot_plugin_apscheduler import scheduler
from pydantic import BaseModel, ValidationError
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
class SchedulerManager:
"""
一个通用的持久化的定时任务管理器供所有插件使用
"""
_registered_tasks: ClassVar[
dict[str, dict[str, Callable | type[BaseModel] | None]]
] = {}
_JOB_PREFIX = "zhenxun_schedule_"
_running_tasks: ClassVar[set] = set()
def register(
self, plugin_name: str, params_model: type[BaseModel] | None = None
) -> Callable:
"""
注册一个可调度的任务函数
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
Args:
plugin_name (str): 插件的唯一名称 (通常是模块名)
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel
用于定义和验证任务函数接受的额外参数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
if plugin_name in self._registered_tasks:
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
self._registered_tasks[plugin_name] = {
"func": func,
"model": params_model,
}
model_name = params_model.__name__ if params_model else ""
logger.debug(
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
)
return func
return decorator
def get_registered_plugins(self) -> list[str]:
"""获取所有已注册定时任务的插件列表。"""
return list(self._registered_tasks.keys())
def _get_job_id(self, schedule_id: int) -> str:
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
return f"{self._JOB_PREFIX}{schedule_id}"
async def _execute_job(self, schedule_id: int):
"""
APScheduler 调度的入口函数
根据 schedule_id 处理特定任务所有群组任务或全局任务
"""
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
if not schedule or not schedule.is_enabled:
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
return
plugin_name = schedule.plugin_name
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
logger.error(
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
)
schedule.is_enabled = False
await schedule.save(update_fields=["is_enabled"])
self._remove_aps_job(schedule.id)
return
try:
if schedule.bot_id:
bot = nonebot.get_bot(schedule.bot_id)
else:
bot = nonebot.get_bot()
logger.debug(
f"任务 {schedule_id} 未关联特定Bot使用默认Bot {bot.self_id}"
)
except KeyError:
logger.warning(
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
f"不在线,本次执行跳过。"
)
return
except ValueError:
logger.warning(f"当前没有Bot在线定时任务 {schedule_id} 跳过。")
return
if schedule.group_id == "__ALL_GROUPS__":
await self._execute_for_all_groups(schedule, task_meta, bot)
else:
await self._execute_for_single_target(schedule, task_meta, bot)
async def _execute_for_all_groups(
self, schedule: ScheduleInfo, task_meta: dict, bot
):
"""为所有群组执行任务,并处理优先级覆盖。"""
plugin_name = schedule.plugin_name
concurrency_limit = Config.get_config(
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
)
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
logger.warning(
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
)
concurrency_limit = 5
logger.info(
f"开始执行针对 [所有群组] 的任务 "
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})"
f"并发限制: {concurrency_limit}"
)
all_gids = set()
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
all_gids.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
return
specific_tasks_gids = set(
await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=list(all_gids)
).values_list("group_id", flat=True)
)
semaphore = asyncio.Semaphore(concurrency_limit)
async def worker(gid: str):
"""使用 Semaphore 包装单个群组的任务执行"""
async with semaphore:
temp_schedule = copy.deepcopy(schedule)
temp_schedule.group_id = gid
await self._execute_for_single_target(temp_schedule, task_meta, bot)
await asyncio.sleep(random.uniform(0.1, 0.5))
tasks_to_run = []
for gid in all_gids:
if gid in specific_tasks_gids:
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
continue
tasks_to_run.append(worker(gid))
if tasks_to_run:
await asyncio.gather(*tasks_to_run)
async def _execute_for_single_target(
self, schedule: ScheduleInfo, task_meta: dict, bot
):
"""为单个目标(具体群组或全局)执行任务。"""
plugin_name = schedule.plugin_name
group_id = schedule.group_id
try:
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
if is_blocked:
target_desc = f"{group_id}" if group_id else "全局"
logger.info(
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
"因功能被禁用而跳过执行。"
)
return
task_func = task_meta["func"]
job_kwargs = schedule.job_kwargs
if not isinstance(job_kwargs, dict):
logger.error(
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
)
return
sig = inspect.signature(task_func)
if "bot" in sig.parameters:
job_kwargs["bot"] = bot
logger.info(
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
f"执行定时任务 (ID: {schedule.id})。"
)
task = asyncio.create_task(task_func(group_id, **job_kwargs))
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)
await task
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
f"目标: {group_id or '全局'}) 时发生异常",
e=e,
)
def _validate_and_prepare_kwargs(
self, plugin_name: str, job_kwargs: dict | None
) -> tuple[bool, str | dict]:
"""验证并准备任务参数,应用默认值"""
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
return False, f"插件 '{plugin_name}' 未注册。"
params_model = task_meta.get("model")
job_kwargs = job_kwargs if job_kwargs is not None else {}
if not params_model:
if job_kwargs:
logger.warning(
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
)
return True, job_kwargs
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
validated_model = model_validate(job_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
return True, model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
return False, msg
def _add_aps_job(self, schedule: ScheduleInfo):
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
job_id = self._get_job_id(schedule.id)
try:
scheduler.remove_job(job_id)
except Exception:
pass
if not isinstance(schedule.trigger_config, dict):
logger.error(
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
f"{type(schedule.trigger_config)}"
)
return
scheduler.add_job(
self._execute_job,
trigger=schedule.trigger_type,
id=job_id,
misfire_grace_time=300,
args=[schedule.id],
**schedule.trigger_config,
)
logger.debug(
f"已在 APScheduler 中添加/更新任务: {job_id} "
f"with trigger: {schedule.trigger_config}"
)
def _remove_aps_job(self, schedule_id: int):
"""移除一个 APScheduler 任务。"""
job_id = self._get_job_id(schedule_id)
try:
scheduler.remove_job(job_id)
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
except Exception:
pass
async def add_schedule(
self,
plugin_name: str,
group_id: str | None,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> tuple[bool, str]:
"""
添加或更新一个定时任务
"""
if plugin_name not in self._registered_tasks:
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
if not is_valid:
return False, str(result)
validated_job_kwargs = result
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
search_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
}
if effective_bot_id:
search_kwargs["bot_id"] = effective_bot_id
else:
search_kwargs["bot_id__isnull"] = True
defaults = {
"trigger_type": trigger_type,
"trigger_config": trigger_config,
"job_kwargs": validated_job_kwargs,
"is_enabled": True,
}
schedule = await ScheduleInfo.filter(**search_kwargs).first()
created = False
if schedule:
for key, value in defaults.items():
setattr(schedule, key, value)
await schedule.save()
else:
creation_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
"bot_id": effective_bot_id,
**defaults,
}
schedule = await ScheduleInfo.create(**creation_kwargs)
created = True
self._add_aps_job(schedule)
action = "设置" if created else "更新"
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
async def add_schedule_for_all(
self,
plugin_name: str,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
) -> tuple[int, int]:
"""为所有机器人所在的群组添加定时任务。"""
if plugin_name not in self._registered_tasks:
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
groups = set()
for bot in get_bots().values():
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
groups.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
success_count = 0
fail_count = 0
for gid in groups:
try:
success, _ = await self.add_schedule(
plugin_name, gid, trigger_type, trigger_config, job_kwargs
)
if success:
success_count += 1
else:
fail_count += 1
except Exception as e:
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
fail_count += 1
await asyncio.sleep(0.05)
return success_count, fail_count
async def update_schedule(
self,
schedule_id: int,
trigger_type: str | None = None,
trigger_config: dict | None = None,
job_kwargs: dict | None = None,
) -> tuple[bool, str]:
"""部分更新一个已存在的定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的任务。"
updated_fields = []
if trigger_config is not None:
schedule.trigger_config = trigger_config
updated_fields.append("trigger_config")
if trigger_type is not None and schedule.trigger_type != trigger_type:
schedule.trigger_type = trigger_type
updated_fields.append("trigger_type")
if job_kwargs is not None:
if not isinstance(schedule.job_kwargs, dict):
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
merged_kwargs = schedule.job_kwargs.copy()
merged_kwargs.update(job_kwargs)
is_valid, result = self._validate_and_prepare_kwargs(
schedule.plugin_name, merged_kwargs
)
if not is_valid:
return False, str(result)
schedule.job_kwargs = result # type: ignore
updated_fields.append("job_kwargs")
if not updated_fields:
return True, "没有任何需要更新的配置。"
await schedule.save(update_fields=updated_fields)
self._add_aps_job(schedule)
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
async def remove_schedule(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""移除指定插件和群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
msg = (
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
f"的插件 '{plugin_name}' 定时任务。"
)
return (False, msg)
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
target_desc = f"{group_id}" if group_id else "全局"
msg = (
f"已取消 Bot {bot_id} 在 [{target_desc}] "
f"的插件 '{plugin_name}' 所有定时任务。"
)
return (True, msg)
async def remove_schedule_for_all(
self, plugin_name: str, bot_id: str | None = None
) -> int:
"""移除指定插件在所有群组的定时任务。"""
query = {"plugin_name": plugin_name}
if bot_id:
query["bot_id"] = bot_id
schedules_to_delete = await ScheduleInfo.filter(**query).all()
if not schedules_to_delete:
return 0
for schedule in schedules_to_delete:
self._remove_aps_job(schedule.id)
await schedule.delete()
await asyncio.sleep(0.01)
return len(schedules_to_delete)
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
"""移除指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id)
if not schedules:
return False, f"{group_id} 没有任何定时任务。"
count = 0
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
count += 1
await asyncio.sleep(0.01)
return True, f"已成功移除群 {group_id}{count} 个定时任务。"
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
"""暂停指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
if not schedules:
return 0, f"{group_id} 没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功暂停群 {group_id}{count} 个定时任务。"
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
"""恢复指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
if not schedules:
return 0, f"{group_id} 没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功恢复群 {group_id}{count} 个定时任务。"
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
"""暂停指定插件在所有群组的定时任务。"""
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
if not schedules:
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return (
count,
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
)
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
"""恢复指定插件在所有群组的定时任务。"""
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
if not schedules:
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return (
count,
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
)
async def pause_schedule_by_plugin_group(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""暂停指定插件在指定群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
return (
False,
f"{group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
)
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
return (
True,
f"已成功暂停群 {group_id} 的插件 '{plugin_name}'{count} 个定时任务。",
)
async def resume_schedule_by_plugin_group(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""恢复指定插件在指定群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
return (
False,
f"{group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
)
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
return (
True,
f"已成功恢复群 {group_id} 的插件 '{plugin_name}'{count} 个定时任务。",
)
async def remove_all_schedules(self) -> tuple[int, str]:
"""移除所有群组的所有定时任务。"""
schedules = await ScheduleInfo.all()
if not schedules:
return 0, "当前没有任何定时任务。"
count = 0
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
count += 1
await asyncio.sleep(0.01)
return count, f"已成功移除所有群组的 {count} 个定时任务。"
async def pause_all_schedules(self) -> tuple[int, str]:
"""暂停所有群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(is_enabled=True)
if not schedules:
return 0, "当前没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
async def resume_all_schedules(self) -> tuple[int, str]:
"""恢复所有群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(is_enabled=False)
if not schedules:
return 0, "当前没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
"""通过ID移除指定的定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
self._remove_aps_job(schedule.id)
await schedule.delete()
return (
True,
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
"""通过ID获取定时任务信息。"""
return await ScheduleInfo.get_or_none(id=schedule_id)
async def get_schedules(
self, plugin_name: str, group_id: str | None
) -> list[ScheduleInfo]:
"""获取特定群组特定插件的所有定时任务。"""
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
async def get_schedule(
self, plugin_name: str, group_id: str | None
) -> ScheduleInfo | None:
"""获取特定群组的定时任务信息。"""
return await ScheduleInfo.get_or_none(
plugin_name=plugin_name, group_id=group_id
)
async def get_all_schedules(
self, plugin_name: str | None = None
) -> list[ScheduleInfo]:
"""获取所有定时任务信息,可按插件名过滤。"""
if plugin_name:
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
return await ScheduleInfo.all()
async def get_schedule_status(self, schedule_id: int) -> dict | None:
"""获取任务的详细状态。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return None
job_id = self._get_job_id(schedule.id)
job = scheduler.get_job(job_id)
status = {
"id": schedule.id,
"bot_id": schedule.bot_id,
"plugin_name": schedule.plugin_name,
"group_id": schedule.group_id,
"is_enabled": schedule.is_enabled,
"trigger_type": schedule.trigger_type,
"trigger_config": schedule.trigger_config,
"job_kwargs": schedule.job_kwargs,
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
if job and job.next_run_time
else "N/A",
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
}
return status
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""暂停一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
return False, "任务不存在或已暂停。"
schedule.is_enabled = False
await schedule.save(update_fields=["is_enabled"])
job_id = self._get_job_id(schedule.id)
try:
scheduler.pause_job(job_id)
except Exception:
pass
return (
True,
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""恢复一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule or schedule.is_enabled:
return False, "任务不存在或已启用。"
schedule.is_enabled = True
await schedule.save(update_fields=["is_enabled"])
job_id = self._get_job_id(schedule.id)
try:
scheduler.resume_job(job_id)
except Exception:
self._add_aps_job(schedule)
return (
True,
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
"""手动触发一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
if schedule.plugin_name not in self._registered_tasks:
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
try:
await self._execute_job(schedule.id)
return (
True,
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
except Exception as e:
logger.error(f"手动触发任务失败: {e}")
return False, f"手动触发任务失败: {e}"
scheduler_manager = SchedulerManager()
@PriorityLifecycle.on_startup(priority=90)
async def _load_schedules_from_db():
"""在服务启动时从数据库加载并调度所有任务。"""
Config.add_plugin_config(
"SchedulerManager",
SCHEDULE_CONCURRENCY_KEY,
5,
help="“所有群组”类型定时任务的并发执行数量限制",
type=int,
)
logger.info("正在从数据库加载并调度所有定时任务...")
schedules = await ScheduleInfo.filter(is_enabled=True).all()
count = 0
for schedule in schedules:
if schedule.plugin_name in scheduler_manager._registered_tasks:
scheduler_manager._add_aps_job(schedule)
count += 1
else:
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")

View File

@ -80,14 +80,14 @@ class PlatformUtils:
@classmethod
async def send_superuser(
cls,
bot: Bot,
bot: Bot | None,
message: UniMessage | str,
superuser_id: str | None = None,
) -> list[tuple[str, Receipt]]:
"""发送消息给超级用户
参数:
bot: Bot
bot: Bot没有传入时使用get_bot随机获取
message: 消息
superuser_id: 指定超级用户id.
@ -97,6 +97,8 @@ class PlatformUtils:
返回:
Receipt | None: Receipt
"""
if not bot:
bot = nonebot.get_bot()
superuser_ids = []
if superuser_id:
superuser_ids.append(superuser_id)
@ -529,9 +531,16 @@ class BroadcastEngine:
try:
self.bot_list.append(nonebot.get_bot(i))
except KeyError:
logger.warning(f"Bot:{i} 对象未连接或不存在")
logger.warning(f"Bot:{i} 对象未连接或不存在", log_cmd)
if not self.bot_list:
raise ValueError("当前没有可用的Bot对象...", log_cmd)
try:
bot = nonebot.get_bot()
self.bot_list.append(bot)
logger.warning(
f"广播任务未传入Bot对象使用默认Bot {bot.self_id}", log_cmd
)
except Exception as e:
raise ValueError("当前没有可用的Bot对象...", log_cmd) from e
async def call_check(self, bot: Bot, group_id: str) -> bool:
"""运行发送检测函数

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from datetime import datetime
from datetime import date, datetime
import os
from pathlib import Path
import time
@ -244,3 +244,20 @@ def is_number(text: str) -> bool:
return True
except ValueError:
return False
class TimeUtils:
@classmethod
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
"""获取某天的0点时间
返回:
datetime: 今天某天的0点时间
"""
if not target_date:
target_date = datetime.now()
return (
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
if isinstance(target_date, datetime)
else datetime.combine(target_date, datetime.min.time())
)