Merge branch 'main' into feature/db-cache

This commit is contained in:
HibiKier 2025-07-11 10:20:14 +08:00 committed by GitHub
commit 205325b994
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 4601 additions and 3223 deletions

View File

@ -29,6 +29,7 @@
"unban",
"Uninfo",
"userinfo",
"webui",
"zhenxun"
],
"python.analysis.autoImportCompletions": true,

View File

@ -13,7 +13,11 @@ from pytest_mock import MockerFixture
from respx import MockRouter
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event, _v11_private_message_send
from tests.utils import (
_v11_group_message_event,
_v11_private_message_send,
get_reply_cq,
)
from tests.utils import get_response_json as _get_response_json
@ -311,6 +315,12 @@ async def test_check_update_release(
to_me=True,
)
ctx.receive_event(bot, event)
ctx.should_call_send(
event=event,
message=Message(f"{get_reply_cq(MessageId.MESSAGE_ID)}正在进行检查更新..."),
result=None,
bot=bot,
)
ctx.should_call_api(
"send_msg",
_v11_private_message_send(
@ -401,6 +411,12 @@ async def test_check_update_main(
to_me=True,
)
ctx.receive_event(bot, event)
ctx.should_call_send(
event=event,
message=Message(f"{get_reply_cq(MessageId.MESSAGE_ID)}正在进行检查更新..."),
result=None,
bot=bot,
)
ctx.should_call_api(
"send_msg",
_v11_private_message_send(

View File

@ -5,6 +5,10 @@ from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message, MessageSegme
from nonebot.adapters.onebot.v11.event import Sender
def get_reply_cq(uid: int | str) -> str:
return f"[CQ:reply,id={uid}]"
def get_response_json(base_path: Path, file: str) -> dict:
try:
return json.loads(

View File

@ -1,7 +1,7 @@
from nonebot.adapters import Bot
from nonebot.plugin import PluginMetadata
from nonebot_plugin_alconna import AlconnaQuery, Arparma, Match, Query
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
@ -9,7 +9,7 @@ from zhenxun.services.log import logger
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.utils.message import MessageUtils
from ._data_source import PluginManage, build_plugin, build_task, delete_help_image
from ._data_source import PluginManager, build_plugin, build_task, delete_help_image
from .command import _group_status_matcher, _status_matcher
base_config = Config.get("plugin_switch")
@ -57,6 +57,11 @@ __plugin_meta__ = PluginMetadata(
关闭群被动早晚安
关闭群被动早晚安 -g 12355555
开启/关闭默认群被动 [被动名称]
私聊下: 开启/关闭群被动默认状态
示例:
关闭默认群被动 早晚安
开启/关闭所有群被动 ?[-g [group_id]]
私聊中: 开启/关闭全局或指定群组被动状态
示例:
@ -87,10 +92,10 @@ __plugin_meta__ = PluginMetadata(
@_status_matcher.assign("$main")
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
arparma: Arparma,
):
if session.id1 in bot.config.superusers:
if session.user.id in bot.config.superusers:
image = await build_plugin()
logger.info(
"查看功能列表",
@ -105,7 +110,7 @@ async def _(
@_status_matcher.assign("open")
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
arparma: Arparma,
plugin_name: Match[str],
group: Match[str],
@ -114,22 +119,23 @@ async def _(
all: Query[bool] = AlconnaQuery("all.value", False),
):
if not all.result and not plugin_name.available:
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
name = plugin_name.result
if gid := session.id3 or session.id2:
if session.group:
group_id = session.group.id
"""修改当前群组的数据"""
if task.result:
if all.result:
result = await PluginManage.unblock_group_all_task(gid)
result = await PluginManager.unblock_group_all_task(group_id)
logger.info("开启所有群组被动", arparma.header_result, session=session)
else:
result = await PluginManage.unblock_group_task(name, gid)
result = await PluginManager.unblock_group_task(name, group_id)
logger.info(
f"开启群组被动 {name}", arparma.header_result, session=session
)
elif session.id1 in bot.config.superusers and default_status.result:
elif session.user.id in bot.config.superusers and default_status.result:
"""单个插件的进群默认修改"""
result = await PluginManage.set_default_status(name, True)
result = await PluginManager.set_default_status(name, True)
logger.info(
f"超级用户开启 {name} 功能进群默认开关",
arparma.header_result,
@ -137,8 +143,8 @@ async def _(
)
elif all.result:
"""所有插件"""
result = await PluginManage.set_all_plugin_status(
True, default_status.result, gid
result = await PluginManager.set_all_plugin_status(
True, default_status.result, group_id
)
logger.info(
"开启群组中全部功能",
@ -146,22 +152,24 @@ async def _(
session=session,
)
else:
result = await PluginManage.unblock_group_plugin(name, gid)
result = await PluginManager.unblock_group_plugin(name, group_id)
logger.info(f"开启功能 {name}", arparma.header_result, session=session)
delete_help_image(gid)
delete_help_image(group_id)
await MessageUtils.build_message(result).finish(reply_to=True)
elif session.id1 in bot.config.superusers:
elif session.user.id in bot.config.superusers:
"""私聊"""
group_id = group.result if group.available else None
if all.result:
if task.result:
"""关闭全局或指定群全部被动"""
if group_id:
result = await PluginManage.unblock_group_all_task(group_id)
result = await PluginManager.unblock_group_all_task(group_id)
else:
result = await PluginManage.unblock_global_all_task()
result = await PluginManager.unblock_global_all_task(
default_status.result
)
else:
result = await PluginManage.set_all_plugin_status(
result = await PluginManager.set_all_plugin_status(
True, default_status.result, group_id
)
logger.info(
@ -171,8 +179,8 @@ async def _(
session=session,
)
await MessageUtils.build_message(result).finish(reply_to=True)
if default_status.result:
result = await PluginManage.set_default_status(name, True)
if default_status.result and not task.result:
result = await PluginManager.set_default_status(name, True)
logger.info(
f"超级用户开启 {name} 功能进群默认开关",
arparma.header_result,
@ -186,7 +194,7 @@ async def _(
name = split_list[0]
group_id = split_list[1]
if group_id:
result = await PluginManage.superuser_task_handle(name, group_id, True)
result = await PluginManager.superuser_task_handle(name, group_id, True)
logger.info(
f"超级用户开启被动技能 {name}",
arparma.header_result,
@ -194,14 +202,16 @@ async def _(
target=group_id,
)
else:
result = await PluginManage.unblock_global_task(name)
result = await PluginManager.unblock_global_task(
name, default_status.result
)
logger.info(
f"超级用户开启全局被动技能 {name}",
arparma.header_result,
session=session,
)
else:
result = await PluginManage.superuser_unblock(name, None, group_id)
result = await PluginManager.superuser_unblock(name, None, group_id)
logger.info(
f"超级用户开启功能 {name}",
arparma.header_result,
@ -215,7 +225,7 @@ async def _(
@_status_matcher.assign("close")
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
arparma: Arparma,
plugin_name: Match[str],
block_type: Match[str],
@ -225,22 +235,23 @@ async def _(
all: Query[bool] = AlconnaQuery("all.value", False),
):
if not all.result and not plugin_name.available:
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
name = plugin_name.result
if gid := session.id3 or session.id2:
if session.group:
group_id = session.group.id
"""修改当前群组的数据"""
if task.result:
if all.result:
result = await PluginManage.block_group_all_task(gid)
result = await PluginManager.block_group_all_task(group_id)
logger.info("开启所有群组被动", arparma.header_result, session=session)
else:
result = await PluginManage.block_group_task(name, gid)
result = await PluginManager.block_group_task(name, group_id)
logger.info(
f"关闭群组被动 {name}", arparma.header_result, session=session
)
elif session.id1 in bot.config.superusers and default_status.result:
elif session.user.id in bot.config.superusers and default_status.result:
"""单个插件的进群默认修改"""
result = await PluginManage.set_default_status(name, False)
result = await PluginManager.set_default_status(name, False)
logger.info(
f"超级用户开启 {name} 功能进群默认开关",
arparma.header_result,
@ -248,26 +259,28 @@ async def _(
)
elif all.result:
"""所有插件"""
result = await PluginManage.set_all_plugin_status(
False, default_status.result, gid
result = await PluginManager.set_all_plugin_status(
False, default_status.result, group_id
)
logger.info("关闭群组中全部功能", arparma.header_result, session=session)
else:
result = await PluginManage.block_group_plugin(name, gid)
result = await PluginManager.block_group_plugin(name, group_id)
logger.info(f"关闭功能 {name}", arparma.header_result, session=session)
delete_help_image(gid)
delete_help_image(group_id)
await MessageUtils.build_message(result).finish(reply_to=True)
elif session.id1 in bot.config.superusers:
elif session.user.id in bot.config.superusers:
group_id = group.result if group.available else None
if all.result:
if task.result:
"""关闭全局或指定群全部被动"""
if group_id:
result = await PluginManage.block_group_all_task(group_id)
result = await PluginManager.block_group_all_task(group_id)
else:
result = await PluginManage.block_global_all_task()
result = await PluginManager.block_global_all_task(
default_status.result
)
else:
result = await PluginManage.set_all_plugin_status(
result = await PluginManager.set_all_plugin_status(
False, default_status.result, group_id
)
logger.info(
@ -277,8 +290,8 @@ async def _(
session=session,
)
await MessageUtils.build_message(result).finish(reply_to=True)
if default_status.result:
result = await PluginManage.set_default_status(name, False)
if default_status.result and not task.result:
result = await PluginManager.set_default_status(name, False)
logger.info(
f"超级用户关闭 {name} 功能进群默认开关",
arparma.header_result,
@ -292,7 +305,9 @@ async def _(
name = split_list[0]
group_id = split_list[1]
if group_id:
result = await PluginManage.superuser_task_handle(name, group_id, False)
result = await PluginManager.superuser_task_handle(
name, group_id, False
)
logger.info(
f"超级用户关闭被动技能 {name}",
arparma.header_result,
@ -300,7 +315,9 @@ async def _(
target=group_id,
)
else:
result = await PluginManage.block_global_task(name)
result = await PluginManager.block_global_task(
name, default_status.result
)
logger.info(
f"超级用户关闭全局被动技能 {name}",
arparma.header_result,
@ -314,7 +331,7 @@ async def _(
elif block_type.result in ["g", "group"]:
if block_type.available:
_type = BlockType.GROUP
result = await PluginManage.superuser_block(name, _type, group_id)
result = await PluginManager.superuser_block(name, _type, group_id)
logger.info(
f"超级用户关闭功能 {name}, 禁用类型: {_type}",
arparma.header_result,
@ -327,19 +344,20 @@ async def _(
@_group_status_matcher.handle()
async def _(
session: EventSession,
session: Uninfo,
arparma: Arparma,
status: str,
):
if gid := session.id3 or session.id2:
if session.group:
group_id = session.group.id
if status == "sleep":
await PluginManage.sleep(gid)
await PluginManager.sleep(group_id)
logger.info("进行休眠", arparma.header_result, session=session)
await MessageUtils.build_message("那我先睡觉了...").finish()
else:
if await PluginManage.is_wake(gid):
if await PluginManager.is_wake(group_id):
await MessageUtils.build_message("我还醒着呢!").finish()
await PluginManage.wake(gid)
await PluginManager.wake(group_id)
logger.info("醒来", arparma.header_result, session=session)
await MessageUtils.build_message("呜..醒来了...").finish()
return MessageUtils.build_message("群组id为空...").send()
@ -347,10 +365,10 @@ async def _(
@_status_matcher.assign("task")
async def _(
session: EventSession,
session: Uninfo,
arparma: Arparma,
):
image = await build_task(session.id3 or session.id2)
image = await build_task(session.group.id if session.group else None)
if image:
logger.info("查看群被动列表", arparma.header_result, session=session)
await MessageUtils.build_message(image).finish(reply_to=True)

View File

@ -156,7 +156,7 @@ async def build_task(group_id: str | None) -> BuildImage:
)
class PluginManage:
class PluginManager:
@classmethod
async def set_default_status(cls, plugin_name: str, status: bool) -> str:
"""设置插件进群默认状态
@ -350,17 +350,21 @@ class PluginManage:
return await cls._change_group_task("", group_id, True, True)
@classmethod
async def block_global_all_task(cls) -> str:
async def block_global_all_task(cls, is_default: bool) -> str:
"""禁用全局被动技能
返回:
str: 返回信息
"""
await TaskInfo.all().update(status=False)
return "已全局禁用所有被动状态"
if is_default:
await TaskInfo.all().update(default_status=False)
return "已禁用所有被动进群默认状态"
else:
await TaskInfo.all().update(status=False)
return "已全局禁用所有被动状态"
@classmethod
async def block_global_task(cls, name: str) -> str:
async def block_global_task(cls, name: str, is_default: bool = False) -> str:
"""禁用全局被动技能
参数:
@ -369,31 +373,47 @@ class PluginManage:
返回:
str: 返回信息
"""
await TaskInfo.filter(name=name).update(status=False)
return f"已全局禁用被动状态 {name}"
if is_default:
await TaskInfo.filter(name=name).update(default_status=False)
return f"已禁用被动进群默认状态 {name}"
else:
await TaskInfo.filter(name=name).update(status=False)
return f"已全局禁用被动状态 {name}"
@classmethod
async def unblock_global_all_task(cls) -> str:
async def unblock_global_all_task(cls, is_default: bool) -> str:
"""开启全局被动技能
参数:
is_default: 是否为默认状态
返回:
str: 返回信息
"""
await TaskInfo.all().update(status=True)
return "已全局开启所有被动状态"
if is_default:
await TaskInfo.all().update(default_status=True)
return "已开启所有被动进群默认状态"
else:
await TaskInfo.all().update(status=True)
return "已全局开启所有被动状态"
@classmethod
async def unblock_global_task(cls, name: str) -> str:
async def unblock_global_task(cls, name: str, is_default: bool = False) -> str:
"""开启全局被动技能
参数:
name: 被动技能名称
is_default: 是否为默认状态
返回:
str: 返回信息
"""
await TaskInfo.filter(name=name).update(status=True)
return f"已全局开启被动状态 {name}"
if is_default:
await TaskInfo.filter(name=name).update(default_status=True)
return f"已开启被动进群默认状态 {name}"
else:
await TaskInfo.filter(name=name).update(status=True)
return f"已全局开启被动状态 {name}"
@classmethod
async def unblock_group_plugin(cls, plugin_name: str, group_id: str) -> str:

View File

@ -58,6 +58,19 @@ _status_matcher.shortcut(
prefix=True,
)
_status_matcher.shortcut(
r"开启(所有|全部)默认群被动",
command="switch",
arguments=["open", "--task", "--all", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"关闭(所有|全部)默认群被动",
command="switch",
arguments=["close", "--task", "--all", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"开启群被动\s*(?P<name>.+)",
@ -73,6 +86,20 @@ _status_matcher.shortcut(
prefix=True,
)
_status_matcher.shortcut(
r"开启默认群被动\s*(?P<name>.+)",
command="switch",
arguments=["open", "{name}", "--task", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"关闭默认群被动\s*(?P<name>.+)",
command="switch",
arguments=["close", "{name}", "--task", "-df"],
prefix=True,
)
_status_matcher.shortcut(
r"开启(所有|全部)群被动",

View File

@ -11,7 +11,7 @@ from nonebot_plugin_alconna import (
on_alconna,
store_true,
)
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.utils import PluginExtraData
from zhenxun.services.log import logger
@ -22,7 +22,7 @@ from zhenxun.utils.manager.resource_manager import (
)
from zhenxun.utils.message import MessageUtils
from ._data_source import UpdateManage
from ._data_source import UpdateManager
__plugin_meta__ = PluginMetadata(
name="自动更新",
@ -32,16 +32,18 @@ __plugin_meta__ = PluginMetadata(
检查更新真寻最新版本包括了自动更新
资源文件大小一般在130mb左右除非必须更新一般仅更新代码文件
指令
检查更新 [main|release|resource] ?[-r]
检查更新 [main|release|resource|webui] ?[-r]
main: main分支
release: 最新release
resource: 资源文件
webui: webui文件
-r: 下载资源文件一般在更新main或release时使用
示例:
检查更新 main
检查更新 main -r
检查更新 release -r
检查更新 resource
检查更新 webui
""".strip(),
extra=PluginExtraData(
author="HibiKier",
@ -53,7 +55,7 @@ __plugin_meta__ = PluginMetadata(
_matcher = on_alconna(
Alconna(
"检查更新",
Args["ver_type?", ["main", "release", "resource"]],
Args["ver_type?", ["main", "release", "resource", "webui"]],
Option("-r|--resource", action=store_true, help_text="下载资源文件"),
),
priority=1,
@ -66,23 +68,24 @@ _matcher = on_alconna(
@_matcher.handle()
async def _(
bot: Bot,
session: EventSession,
session: Uninfo,
ver_type: Match[str],
resource: Query[bool] = Query("resource", False),
):
if not session.id1:
await MessageUtils.build_message("用户id为空...").finish()
result = ""
await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True)
if ver_type.result in {"main", "release"}:
if not ver_type.available:
result = await UpdateManage.check_version()
result = await UpdateManager.check_version()
logger.info("查看当前版本...", "检查更新", session=session)
await MessageUtils.build_message(result).finish()
try:
result = await UpdateManage.update(bot, session.id1, ver_type.result)
result = await UpdateManager.update(bot, session.user.id, ver_type.result)
except Exception as e:
logger.error("版本更新失败...", "检查更新", session=session, e=e)
await MessageUtils.build_message(f"更新版本失败...e: {e}").finish()
elif ver_type.result == "webui":
result = await UpdateManager.update_webui()
if resource.result or ver_type.result == "resource":
try:
await ResourceManager.init_resources(True)

View File

@ -7,6 +7,7 @@ import zipfile
from nonebot.adapters import Bot
from nonebot.utils import run_sync
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import RepoInfo
@ -17,6 +18,7 @@ from .config import (
BACKUP_PATH,
BASE_PATH,
BASE_PATH_STRING,
COMMAND,
DEFAULT_GITHUB_URL,
DOWNLOAD_GZ_FILE,
DOWNLOAD_ZIP_FILE,
@ -38,7 +40,7 @@ def install_requirement():
if not requirement_path.exists():
logger.debug(
f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", "插件管理"
f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", COMMAND
)
return
try:
@ -48,9 +50,9 @@ def install_requirement():
capture_output=True,
text=True,
)
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", "插件管理")
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", COMMAND)
except subprocess.CalledProcessError as e:
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", "插件管理", e=e)
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", COMMAND, e=e)
@run_sync
@ -61,7 +63,7 @@ def _file_handle(latest_version: str | None):
latest_version: 版本号
"""
BACKUP_PATH.mkdir(exist_ok=True, parents=True)
logger.debug("开始解压文件压缩包...", "检查更新")
logger.debug("开始解压文件压缩包...", COMMAND)
download_file = DOWNLOAD_GZ_FILE
if DOWNLOAD_GZ_FILE.exists():
tf = tarfile.open(DOWNLOAD_GZ_FILE)
@ -69,7 +71,7 @@ def _file_handle(latest_version: str | None):
download_file = DOWNLOAD_ZIP_FILE
tf = zipfile.ZipFile(DOWNLOAD_ZIP_FILE)
tf.extractall(TMP_PATH)
logger.debug("解压文件压缩包完成...", "检查更新")
logger.debug("解压文件压缩包完成...", COMMAND)
download_file_path = TMP_PATH / next(
x for x in os.listdir(TMP_PATH) if (TMP_PATH / x).is_dir()
)
@ -79,52 +81,52 @@ def _file_handle(latest_version: str | None):
extract_path = download_file_path / BASE_PATH_STRING
target_path = BASE_PATH
if PYPROJECT_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", "检查更新")
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", COMMAND)
shutil.move(PYPROJECT_FILE, BACKUP_PATH / PYPROJECT_FILE_STRING)
if PYPROJECT_LOCK_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", "检查更新")
logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", COMMAND)
shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / PYPROJECT_LOCK_FILE_STRING)
if REQ_TXT_FILE.exists():
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", "检查更新")
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", COMMAND)
shutil.move(REQ_TXT_FILE, BACKUP_PATH / REQ_TXT_FILE_STRING)
if _pyproject.exists():
logger.debug("移动文件: pyproject.toml", "检查更新")
logger.debug("移动文件: pyproject.toml", COMMAND)
shutil.move(_pyproject, PYPROJECT_FILE)
if _lock_file.exists():
logger.debug("移动文件: poetry.lock", "检查更新")
logger.debug("移动文件: poetry.lock", COMMAND)
shutil.move(_lock_file, PYPROJECT_LOCK_FILE)
if _req_file.exists():
logger.debug("移动文件: requirements.txt", "检查更新")
logger.debug("移动文件: requirements.txt", COMMAND)
shutil.move(_req_file, REQ_TXT_FILE)
for folder in REPLACE_FOLDERS:
"""移动指定文件夹"""
_dir = BASE_PATH / folder
_backup_dir = BACKUP_PATH / folder
if _backup_dir.exists():
logger.debug(f"删除备份文件夹 {_backup_dir}", "检查更新")
logger.debug(f"删除备份文件夹 {_backup_dir}", COMMAND)
shutil.rmtree(_backup_dir)
if _dir.exists():
logger.debug(f"移动旧文件夹 {_dir}", "检查更新")
logger.debug(f"移动旧文件夹 {_dir}", COMMAND)
shutil.move(_dir, _backup_dir)
else:
logger.warning(f"文件夹 {_dir} 不存在,跳过删除", "检查更新")
logger.warning(f"文件夹 {_dir} 不存在,跳过删除", COMMAND)
for folder in REPLACE_FOLDERS:
src_folder_path = extract_path / folder
dest_folder_path = target_path / folder
if src_folder_path.exists():
logger.debug(
f"移动文件夹: {src_folder_path} -> {dest_folder_path}", "检查更新"
f"移动文件夹: {src_folder_path} -> {dest_folder_path}", COMMAND
)
shutil.move(src_folder_path, dest_folder_path)
else:
logger.debug(f"源文件夹不存在: {src_folder_path}", "检查更新")
logger.debug(f"源文件夹不存在: {src_folder_path}", COMMAND)
if tf:
tf.close()
if download_file.exists():
logger.debug(f"删除下载文件: {download_file}", "检查更新")
logger.debug(f"删除下载文件: {download_file}", COMMAND)
download_file.unlink()
if extract_path.exists():
logger.debug(f"删除解压文件夹: {extract_path}", "检查更新")
logger.debug(f"删除解压文件夹: {extract_path}", COMMAND)
shutil.rmtree(extract_path)
if TMP_PATH.exists():
shutil.rmtree(TMP_PATH)
@ -134,7 +136,35 @@ def _file_handle(latest_version: str | None):
install_requirement()
class UpdateManage:
class UpdateManager:
@classmethod
async def update_webui(cls) -> str:
from zhenxun.builtin_plugins.web_ui.public.data_source import (
update_webui_assets,
)
WEBUI_PATH = DATA_PATH / "web_ui" / "public"
BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public"
if WEBUI_PATH.exists():
if BACKUP_PATH.exists():
logger.debug(f"删除旧的备份webui文件夹 {BACKUP_PATH}", COMMAND)
shutil.rmtree(BACKUP_PATH)
WEBUI_PATH.rename(BACKUP_PATH)
try:
await update_webui_assets()
logger.info("更新webui成功...", COMMAND)
if BACKUP_PATH.exists():
logger.debug(f"删除旧的webui文件夹 {BACKUP_PATH}", COMMAND)
shutil.rmtree(BACKUP_PATH)
return "Webui更新成功"
except Exception as e:
logger.error("更新webui失败...", COMMAND, e=e)
if BACKUP_PATH.exists():
logger.debug(f"恢复旧的webui文件夹 {BACKUP_PATH}", COMMAND)
BACKUP_PATH.rename(WEBUI_PATH)
raise e
return ""
@classmethod
async def check_version(cls) -> str:
"""检查更新版本
@ -166,7 +196,7 @@ class UpdateManage:
返回:
str | None: 返回消息
"""
logger.info("开始下载真寻最新版文件....", "检查更新")
logger.info("开始下载真寻最新版文件....", COMMAND)
cur_version = cls.__get_version()
url = None
new_version = None
@ -186,11 +216,11 @@ class UpdateManage:
if not url:
return "获取版本下载链接失败..."
if TMP_PATH.exists():
logger.debug(f"删除临时文件夹 {TMP_PATH}", "检查更新")
logger.debug(f"删除临时文件夹 {TMP_PATH}", COMMAND)
shutil.rmtree(TMP_PATH)
logger.debug(
f"开始更新版本:{cur_version} -> {new_version} | 下载链接:{url}",
"检查更新",
COMMAND,
)
await PlatformUtils.send_superuser(
bot,
@ -201,7 +231,7 @@ class UpdateManage:
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
)
if await AsyncHttpx.download_file(url, download_file, stream=True):
logger.debug("下载真寻最新版文件完成...", "检查更新")
logger.debug("下载真寻最新版文件完成...", COMMAND)
await _file_handle(new_version)
result = "版本更新完成"
return (
@ -210,7 +240,7 @@ class UpdateManage:
"请重新启动真寻以完成更新!"
)
else:
logger.debug("下载真寻最新版文件失败...", "检查更新")
logger.debug("下载真寻最新版文件失败...", COMMAND)
return ""
@classmethod

View File

@ -34,3 +34,5 @@ REPLACE_FOLDERS = [
"models",
"configs",
]
COMMAND = "检查更新"

View File

@ -54,22 +54,6 @@ __plugin_meta__ = PluginMetadata(
default_value=5,
type=int,
),
RegisterConfig(
module="_task",
key="DEFAULT_GROUP_WELCOME",
value=True,
help="被动 进群欢迎 进群默认开关状态",
default_value=True,
type=bool,
),
RegisterConfig(
module="_task",
key="DEFAULT_REFUND_GROUP_REMIND",
value=True,
help="被动 退群提醒 进群默认开关状态",
default_value=True,
type=bool,
),
],
tasks=[
Task(

View File

@ -1,9 +1,11 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
from zhenxun.utils.enum import PluginType
from . import command # noqa: F401
from . import commands, handlers
__all__ = ["commands", "handlers"]
__plugin_meta__ = PluginMetadata(
name="定时任务管理",
@ -27,6 +29,8 @@ __plugin_meta__ = PluginMetadata(
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 执行 <任务ID>
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
# [修改] 增加说明
说明: -p 选项可单独使用用于操作指定插件的所有任务
📝 时间选项 (三选一):
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
@ -47,5 +51,35 @@ __plugin_meta__ = PluginMetadata(
version="0.1.2",
plugin_type=PluginType.SUPERUSER,
is_show=False,
configs=[
RegisterConfig(
module="SchedulerManager",
key="ALL_GROUPS_CONCURRENCY_LIMIT",
value=5,
help="“所有群组”类型定时任务的并发执行数量限制",
type=int,
),
RegisterConfig(
module="SchedulerManager",
key="JOB_MAX_RETRIES",
value=2,
help="定时任务执行失败时的最大重试次数",
type=int,
),
RegisterConfig(
module="SchedulerManager",
key="JOB_RETRY_DELAY",
value=10,
help="定时任务执行重试的间隔时间(秒)",
type=int,
),
RegisterConfig(
module="SchedulerManager",
key="SCHEDULER_TIMEZONE",
value="Asia/Shanghai",
help="定时任务使用的时区,默认为 Asia/Shanghai",
type=str,
),
],
).to_dict(),
)

View File

@ -1,836 +0,0 @@
import asyncio
from datetime import datetime
import re
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import (
Alconna,
AlconnaMatch,
Args,
Arparma,
Match,
Option,
Query,
Subcommand,
on_alconna,
)
from pydantic import BaseModel, ValidationError
from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.manager.schedule_manager import scheduler_manager
def _get_type_name(annotation) -> str:
"""获取类型注解的名称"""
if hasattr(annotation, "__name__"):
return annotation.__name__
elif hasattr(annotation, "_name"):
return annotation._name
else:
return str(annotation)
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.rules import admin_check
def _format_trigger(schedule_status: dict) -> str:
"""将触发器配置格式化为人类可读的字符串"""
trigger_type = schedule_status["trigger_type"]
config = schedule_status["trigger_config"]
if trigger_type == "cron":
minute = config.get("minute", "*")
hour = config.get("hour", "*")
day = config.get("day", "*")
month = config.get("month", "*")
day_of_week = config.get("day_of_week", "*")
if day == "*" and month == "*" and day_of_week == "*":
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
return f"每天 {formatted_hour}:{formatted_minute}"
else:
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
elif trigger_type == "interval":
seconds = config.get("seconds", 0)
minutes = config.get("minutes", 0)
hours = config.get("hours", 0)
days = config.get("days", 0)
if days:
trigger_str = f"{days}"
elif hours:
trigger_str = f"{hours} 小时"
elif minutes:
trigger_str = f"{minutes} 分钟"
else:
trigger_str = f"{seconds}"
elif trigger_type == "date":
run_date = config.get("run_date", "未知时间")
trigger_str = f"{run_date}"
else:
trigger_str = f"{trigger_type}: {config}"
return trigger_str
def _format_params(schedule_status: dict) -> str:
"""将任务参数格式化为人类可读的字符串"""
if kwargs := schedule_status.get("job_kwargs"):
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
return kwargs_str
return "-"
def _parse_interval(interval_str: str) -> dict:
"""增强版解析器,支持 d(天)"""
match = re.match(r"(\d+)([smhd])", interval_str.lower())
if not match:
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return {"seconds": value}
if unit == "m":
return {"minutes": value}
if unit == "h":
return {"hours": value}
if unit == "d":
return {"days": value}
return {}
def _parse_daily_time(time_str: str) -> dict:
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
hour, minute, second = match.groups()
hour, minute = int(hour), int(minute)
if not (0 <= hour <= 23 and 0 <= minute <= 59):
raise ValueError("小时或分钟数值超出范围。")
cron_config = {
"minute": str(minute),
"hour": str(hour),
"day": "*",
"month": "*",
"day_of_week": "*",
}
if second is not None:
if not (0 <= int(second) <= 59):
raise ValueError("秒数值超出范围。")
cron_config["second"] = str(second)
return cron_config
else:
raise ValueError("时间格式错误,请使用 'HH:MM''HH:MM:SS' 格式。")
async def GetBotId(
bot: Bot,
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
) -> str:
"""获取要操作的Bot ID"""
if bot_id_match.available:
return bot_id_match.result
return bot.self_id
class ScheduleTarget:
"""定时任务操作目标的基类"""
pass
class TargetByID(ScheduleTarget):
"""按任务ID操作"""
def __init__(self, id: int):
self.id = id
class TargetByPlugin(ScheduleTarget):
"""按插件名操作"""
def __init__(
self, plugin: str, group_id: str | None = None, all_groups: bool = False
):
self.plugin = plugin
self.group_id = group_id
self.all_groups = all_groups
class TargetAll(ScheduleTarget):
"""操作所有任务"""
def __init__(self, for_group: str | None = None):
self.for_group = for_group
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
def create_target_parser(subcommand_name: str):
"""
创建一个依赖注入函数用于解析删除暂停恢复等命令的操作目标
"""
async def dependency(
event: Event,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
) -> TargetScope:
if schedule_id.available:
return TargetByID(schedule_id.result)
if plugin_name.available:
p_name = plugin_name.result
if all_enabled.available:
return TargetByPlugin(plugin=p_name, all_groups=True)
elif group_id.available:
gid = group_id.result
if gid.lower() == "all":
return TargetByPlugin(plugin=p_name, all_groups=True)
return TargetByPlugin(plugin=p_name, group_id=gid)
else:
current_group_id = getattr(event, "group_id", None)
if current_group_id:
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
else:
await schedule_cmd.finish(
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
)
if all_enabled.available:
return TargetAll(for_group=group_id.result if group_id.available else None)
return None
return dependency
schedule_cmd = on_alconna(
Alconna(
"定时任务",
Subcommand(
"查看",
Option("-g", Args["target_group_id", str]),
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
Option("--page", Args["page", int, 1], help_text="指定页码"),
alias=["ls", "list"],
help_text="查看定时任务",
),
Subcommand(
"设置",
Args["plugin_name", str],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="设置每天执行的时间 (如 08:20)",
),
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["add", "开启"],
help_text="设置/开启一个定时任务",
),
Subcommand(
"删除",
Args["schedule_id?", int],
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID"),
Option("-all", help_text="对所有群生效"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["del", "rm", "remove", "关闭", "取消"],
help_text="删除一个或多个定时任务",
),
Subcommand(
"暂停",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["pause"],
help_text="暂停一个或多个定时任务",
),
Subcommand(
"恢复",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["resume"],
help_text="恢复一个或多个定时任务",
),
Subcommand(
"执行",
Args["schedule_id", int],
alias=["trigger", "run"],
help_text="立即执行一次任务",
),
Subcommand(
"更新",
Args["schedule_id", int],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="更新每天执行的时间 (如 08:20)",
),
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
alias=["update", "modify", "修改"],
help_text="更新任务配置",
),
Subcommand(
"状态",
Args["schedule_id", int],
alias=["status", "info"],
help_text="查看单个任务的详细状态",
),
Subcommand(
"插件列表",
alias=["plugins"],
help_text="列出所有可用的插件",
),
),
priority=5,
block=True,
rule=admin_check(1),
)
schedule_cmd.shortcut(
"任务状态",
command="定时任务",
arguments=["状态", "{%0}"],
prefix=True,
)
@schedule_cmd.handle()
async def _handle_time_options_mutex(arp: Arparma):
time_options = ["cron", "interval", "date", "daily"]
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
if len(provided_options) > 1:
await schedule_cmd.finish(
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
)
@schedule_cmd.assign("查看")
async def _(
bot: Bot,
event: Event,
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
all_groups: Query[bool] = Query("查看.all"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
page: Match[int] = AlconnaMatch("page"),
):
is_superuser = await SUPERUSER(bot, event)
schedules = []
title = ""
current_group_id = getattr(event, "group_id", None)
if not (all_groups.available or target_group_id.available) and not current_group_id:
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
if all_groups.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
schedules = await scheduler_manager.get_all_schedules()
title = "所有群组的定时任务"
elif target_group_id.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
gid = target_group_id.result
schedules = [
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
]
title = f"{gid} 的定时任务"
else:
gid = str(current_group_id)
schedules = [
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
]
title = "本群的定时任务"
if plugin_name.available:
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
title += f" [插件: {plugin_name.result}]"
if not schedules:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
page_size = 15
current_page = page.result
total_items = len(schedules)
total_pages = (total_items + page_size - 1) // page_size
start_index = (current_page - 1) * page_size
end_index = start_index + page_size
paginated_schedules = schedules[start_index:end_index]
if not paginated_schedules:
await schedule_cmd.finish("这一页没有内容了哦~")
status_tasks = [
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
]
all_statuses = await asyncio.gather(*status_tasks)
data_list = [
[
s["id"],
s["plugin_name"],
s.get("bot_id") or "N/A",
s["group_id"] or "全局",
s["next_run_time"],
_format_trigger(s),
_format_params(s),
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
]
for s in all_statuses
if s
]
if not data_list:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
img = await ImageTemplate.table_page(
head_text=title,
tip_text=f"{current_page}/{total_pages} 页,共 {total_items} 条任务",
column_name=[
"ID",
"插件",
"Bot ID",
"群组/目标",
"下次运行",
"触发规则",
"参数",
"状态",
],
data_list=data_list,
column_space=20,
)
await MessageUtils.build_message(img).send(reply_to=True)
@schedule_cmd.assign("设置")
async def _(
event: Event,
plugin_name: str,
cron_expr: str | None = None,
interval_expr: str | None = None,
date_expr: str | None = None,
daily_expr: str | None = None,
group_id: str | None = None,
kwargs_str: str | None = None,
all_enabled: Query[bool] = Query("设置.all"),
bot_id_to_operate: str = Depends(GetBotId),
):
if plugin_name not in scheduler_manager._registered_tasks:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
)
trigger_type = ""
trigger_config = {}
try:
if cron_expr:
trigger_type = "cron"
parts = cron_expr.split()
if len(parts) != 5:
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
trigger_config = dict(zip(cron_keys, parts))
elif interval_expr:
trigger_type = "interval"
trigger_config = _parse_interval(interval_expr)
elif date_expr:
trigger_type = "date"
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
elif daily_expr:
trigger_type = "cron"
trigger_config = _parse_daily_time(daily_expr)
else:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = {}
if kwargs_str:
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
if not params_model:
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
raw_kwargs = {}
try:
for item in kwargs_str.split(","):
key, value = item.strip().split("=", 1)
raw_kwargs[key.strip()] = value
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的参数模型不支持验证。"
)
return
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的参数模型不支持导出。"
)
return
job_kwargs = model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
)
return
target_group_id: str | None
current_group_id = getattr(event, "group_id", None)
if group_id and group_id.lower() == "all":
target_group_id = "__ALL_GROUPS__"
elif all_enabled.available:
target_group_id = "__ALL_GROUPS__"
elif group_id:
target_group_id = group_id
elif current_group_id:
target_group_id = str(current_group_id)
else:
await schedule_cmd.finish(
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
)
return
success, msg = await scheduler_manager.add_schedule(
plugin_name,
target_group_id,
trigger_type,
trigger_config,
job_kwargs,
bot_id=bot_id_to_operate,
)
if target_group_id == "__ALL_GROUPS__":
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
elif target_group_id is None:
target_desc = "全局"
else:
target_desc = f"群组 {target_group_id}"
if success:
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
else:
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
@schedule_cmd.assign("删除")
async def _(
target: TargetScope = Depends(create_target_parser("删除")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
removed_count = await scheduler_manager.remove_schedule_for_all(
p_name, bot_id=bot_id_to_operate
)
message = (
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
if removed_count > 0
else f"没有找到插件 '{p_name}' 的定时任务。"
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.remove_schedule(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.remove_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.remove_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish(
"删除任务失败请提供任务ID或通过 -p <插件> 或 -all 指定要删除的任务。"
)
@schedule_cmd.assign("暂停")
async def _(
target: TargetScope = Depends(create_target_parser("暂停")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.pause_schedule(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.pause_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.pause_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("恢复")
async def _(
target: TargetScope = Depends(create_target_parser("恢复")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.resume_schedule(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.resume_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.resume_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("执行")
async def _(schedule_id: int):
_, message = await scheduler_manager.trigger_now(schedule_id)
await schedule_cmd.finish(message)
@schedule_cmd.assign("更新")
async def _(
schedule_id: int,
cron_expr: str | None = None,
interval_expr: str | None = None,
date_expr: str | None = None,
daily_expr: str | None = None,
kwargs_str: str | None = None,
):
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
await schedule_cmd.finish(
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
)
trigger_config = None
trigger_type = None
try:
if cron_expr:
trigger_type = "cron"
parts = cron_expr.split()
if len(parts) != 5:
raise ValueError("Cron 表达式必须有5个部分")
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
trigger_config = dict(zip(cron_keys, parts))
elif interval_expr:
trigger_type = "interval"
trigger_config = _parse_interval(interval_expr)
elif date_expr:
trigger_type = "date"
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
elif daily_expr:
trigger_type = "cron"
trigger_config = _parse_daily_time(daily_expr)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = None
if kwargs_str:
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
if not schedule:
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
if not task_meta or not (params_model := task_meta.get("model")):
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
)
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
)
raw_kwargs = {}
try:
for item in kwargs_str.split(","):
key, value = item.strip().split("=", 1)
raw_kwargs[key.strip()] = value
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
)
return
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
)
return
job_kwargs = model_dump(exclude_unset=True)
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
return
_, message = await scheduler_manager.update_schedule(
schedule_id, trigger_type, trigger_config, job_kwargs
)
await schedule_cmd.finish(message)
@schedule_cmd.assign("插件列表")
async def _():
registered_plugins = scheduler_manager.get_registered_plugins()
if not registered_plugins:
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
message_parts = ["📋 已注册的定时任务插件:"]
for i, plugin_name in enumerate(registered_plugins, 1):
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
if not params_model:
message_parts.append(f"{i}. {plugin_name} - 无参数")
continue
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
continue
model_fields = getattr(params_model, "model_fields", None)
if model_fields:
param_info = ", ".join(
f"{field_name}({_get_type_name(field_info.annotation)})"
for field_name, field_info in model_fields.items()
)
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
else:
message_parts.append(f"{i}. {plugin_name} - 无参数")
await schedule_cmd.finish("\n".join(message_parts))
@schedule_cmd.assign("状态")
async def _(schedule_id: int):
status = await scheduler_manager.get_schedule_status(schedule_id)
if not status:
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
info_lines = [
f"📋 定时任务详细信息 (ID: {schedule_id})",
"--------------------",
f"▫️ 插件: {status['plugin_name']}",
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
f"▫️ 目标: {status['group_id'] or '全局'}",
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
f"▫️ 下次运行: {status['next_run_time']}",
f"▫️ 触发规则: {_format_trigger(status)}",
f"▫️ 任务参数: {_format_params(status)}",
]
await schedule_cmd.finish("\n".join(info_lines))

View File

@ -0,0 +1,298 @@
import re
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import (
Alconna,
AlconnaMatch,
Args,
Match,
Option,
Query,
Subcommand,
on_alconna,
)
from zhenxun.configs.config import Config
from zhenxun.services.scheduler import scheduler_manager
from zhenxun.services.scheduler.targeter import ScheduleTargeter
from zhenxun.utils.rules import admin_check
schedule_cmd = on_alconna(
Alconna(
"定时任务",
Subcommand(
"查看",
Option("-g", Args["target_group_id", str]),
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
Option("--page", Args["page", int, 1], help_text="指定页码"),
alias=["ls", "list"],
help_text="查看定时任务",
),
Subcommand(
"设置",
Args["plugin_name", str],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="设置每天执行的时间 (如 08:20)",
),
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["add", "开启"],
help_text="设置/开启一个定时任务",
),
Subcommand(
"删除",
Args["schedule_id?", int],
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID"),
Option("-all", help_text="对所有群生效"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["del", "rm", "remove", "关闭", "取消"],
help_text="删除一个或多个定时任务",
),
Subcommand(
"暂停",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["pause"],
help_text="暂停一个或多个定时任务",
),
Subcommand(
"恢复",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["resume"],
help_text="恢复一个或多个定时任务",
),
Subcommand(
"执行",
Args["schedule_id", int],
alias=["trigger", "run"],
help_text="立即执行一次任务",
),
Subcommand(
"更新",
Args["schedule_id", int],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="更新每天执行的时间 (如 08:20)",
),
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
alias=["update", "modify", "修改"],
help_text="更新任务配置",
),
Subcommand(
"状态",
Args["schedule_id", int],
alias=["status", "info"],
help_text="查看单个任务的详细状态",
),
Subcommand(
"插件列表",
alias=["plugins"],
help_text="列出所有可用的插件",
),
),
priority=5,
block=True,
rule=admin_check(1),
)
schedule_cmd.shortcut(
"任务状态",
command="定时任务",
arguments=["状态", "{%0}"],
prefix=True,
)
class ScheduleTarget:
pass
class TargetByID(ScheduleTarget):
def __init__(self, id: int):
self.id = id
class TargetByPlugin(ScheduleTarget):
def __init__(
self, plugin: str, group_id: str | None = None, all_groups: bool = False
):
self.plugin = plugin
self.group_id = group_id
self.all_groups = all_groups
class TargetAll(ScheduleTarget):
def __init__(self, for_group: str | None = None):
self.for_group = for_group
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
def create_target_parser(subcommand_name: str):
async def dependency(
event: Event,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
) -> TargetScope:
if schedule_id.available:
return TargetByID(schedule_id.result)
if plugin_name.available:
p_name = plugin_name.result
if all_enabled.available:
return TargetByPlugin(plugin=p_name, all_groups=True)
elif group_id.available:
gid = group_id.result
if gid.lower() == "all":
return TargetByPlugin(plugin=p_name, all_groups=True)
return TargetByPlugin(plugin=p_name, group_id=gid)
else:
current_group_id = getattr(event, "group_id", None)
return TargetByPlugin(
plugin=p_name,
group_id=str(current_group_id) if current_group_id else None,
)
if all_enabled.available:
current_group_id = getattr(event, "group_id", None)
if not current_group_id:
await schedule_cmd.finish(
"私聊中单独使用 -all 选项时,必须使用 -g <群号> 指定目标。"
)
return TargetAll(for_group=str(current_group_id))
return None
return dependency
def parse_interval(interval_str: str) -> dict:
match = re.match(r"(\d+)([smhd])", interval_str.lower())
if not match:
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return {"seconds": value}
if unit == "m":
return {"minutes": value}
if unit == "h":
return {"hours": value}
if unit == "d":
return {"days": value}
return {}
def parse_daily_time(time_str: str) -> dict:
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
hour, minute, second = match.groups()
hour, minute = int(hour), int(minute)
if not (0 <= hour <= 23 and 0 <= minute <= 59):
raise ValueError("小时或分钟数值超出范围。")
cron_config = {
"minute": str(minute),
"hour": str(hour),
"day": "*",
"month": "*",
"day_of_week": "*",
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
}
if second is not None:
if not (0 <= int(second) <= 59):
raise ValueError("秒数值超出范围。")
cron_config["second"] = str(second)
return cron_config
else:
raise ValueError("时间格式错误,请使用 'HH:MM''HH:MM:SS' 格式。")
async def GetBotId(bot: Bot, bot_id_match: Match[str] = AlconnaMatch("bot_id")) -> str:
if bot_id_match.available:
return bot_id_match.result
return bot.self_id
def GetTargeter(subcommand: str):
"""
依赖注入函数用于解析命令参数并返回一个配置好的 ScheduleTargeter 实例
"""
async def dependency(
event: Event,
bot: Bot,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand}.all"),
bot_id_to_operate: str = Depends(GetBotId),
) -> ScheduleTargeter:
if schedule_id.available:
return scheduler_manager.target(id=schedule_id.result)
if plugin_name.available:
if all_enabled.available:
return scheduler_manager.target(plugin_name=plugin_name.result)
current_group_id = getattr(event, "group_id", None)
gid = group_id.result if group_id.available else current_group_id
return scheduler_manager.target(
plugin_name=plugin_name.result,
group_id=str(gid) if gid else None,
bot_id=bot_id_to_operate,
)
if all_enabled.available:
current_group_id = getattr(event, "group_id", None)
gid = group_id.result if group_id.available else current_group_id
is_su = await SUPERUSER(bot, event)
if not gid and not is_su:
await schedule_cmd.finish(
f"在私聊中对所有任务进行'{subcommand}'操作需要超级用户权限。"
)
if (gid and str(gid).lower() == "all") or (not gid and is_su):
return scheduler_manager.target()
return scheduler_manager.target(
group_id=str(gid) if gid else None, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(
f"'{subcommand}'操作失败请提供任务ID"
f"或通过 -p <插件名> 或 -all 指定要操作的任务。"
)
return Depends(dependency)

View File

@ -0,0 +1,380 @@
from datetime import datetime
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import AlconnaMatch, Arparma, Match, Query
from pydantic import BaseModel, ValidationError
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.scheduler import scheduler_manager
from zhenxun.services.scheduler.targeter import ScheduleTargeter
from zhenxun.utils.message import MessageUtils
from . import presenters
from .commands import (
GetBotId,
GetTargeter,
parse_daily_time,
parse_interval,
schedule_cmd,
)
@schedule_cmd.handle()
async def _handle_time_options_mutex(arp: Arparma):
time_options = ["cron", "interval", "date", "daily"]
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
if len(provided_options) > 1:
await schedule_cmd.finish(
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
)
@schedule_cmd.assign("查看")
async def handle_view(
bot: Bot,
event: Event,
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
all_groups: Query[bool] = Query("查看.all"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
page: Match[int] = AlconnaMatch("page"),
):
is_superuser = await SUPERUSER(bot, event)
title = ""
gid_filter = None
current_group_id = getattr(event, "group_id", None)
if not (all_groups.available or target_group_id.available) and not current_group_id:
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
if all_groups.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
title = "所有群组的定时任务"
elif target_group_id.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
gid_filter = target_group_id.result
title = f"{gid_filter} 的定时任务"
else:
gid_filter = str(current_group_id)
title = "本群的定时任务"
p_name_filter = plugin_name.result if plugin_name.available else None
schedules = await scheduler_manager.get_schedules(
plugin_name=p_name_filter, group_id=gid_filter
)
if p_name_filter:
title += f" [插件: {p_name_filter}]"
if not schedules:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
img = await presenters.format_schedule_list_as_image(
schedules=schedules, title=title, current_page=page.result
)
await MessageUtils.build_message(img).send(reply_to=True)
@schedule_cmd.assign("设置")
async def handle_set(
event: Event,
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
date_expr: Match[str] = AlconnaMatch("date_expr"),
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
group_id: Match[str] = AlconnaMatch("group_id"),
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
all_enabled: Query[bool] = Query("设置.all"),
bot_id_to_operate: str = Depends(GetBotId),
):
if not plugin_name.available:
await schedule_cmd.finish("设置任务时必须提供插件名称。")
has_time_option = any(
[
cron_expr.available,
interval_expr.available,
date_expr.available,
daily_expr.available,
]
)
if not has_time_option:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
p_name = plugin_name.result
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(
f"插件 '{p_name}' 没有注册可用的定时任务。\n"
f"可用插件: {list(scheduler_manager.get_registered_plugins())}"
)
trigger_type, trigger_config = "", {}
try:
if cron_expr.available:
trigger_type, trigger_config = (
"cron",
dict(
zip(
["minute", "hour", "day", "month", "day_of_week"],
cron_expr.result.split(),
)
),
)
elif interval_expr.available:
trigger_type, trigger_config = (
"interval",
parse_interval(interval_expr.result),
)
elif date_expr.available:
trigger_type, trigger_config = (
"date",
{"run_date": datetime.fromisoformat(date_expr.result)},
)
elif daily_expr.available:
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
else:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = {}
if kwargs_str.available:
task_meta = scheduler_manager._registered_tasks[p_name]
params_model = task_meta.get("model")
if not (
params_model
and isinstance(params_model, type)
and issubclass(params_model, BaseModel)
):
await schedule_cmd.finish(f"插件 '{p_name}' 不支持或配置了无效的参数模型。")
try:
raw_kwargs = dict(
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
)
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持验证")
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持导出")
job_kwargs = model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
await schedule_cmd.finish(
f"插件 '{p_name}' 的任务参数验证失败:\n" + "\n".join(errors)
)
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
gid_str = group_id.result if group_id.available else None
target_group_id = (
scheduler_manager.ALL_GROUPS
if (gid_str and gid_str.lower() == "all") or all_enabled.available
else gid_str or getattr(event, "group_id", None)
)
if not target_group_id:
await schedule_cmd.finish(
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
)
schedule = await scheduler_manager.add_schedule(
p_name,
str(target_group_id),
trigger_type,
trigger_config,
job_kwargs,
bot_id=bot_id_to_operate,
)
target_desc = (
f"所有群组 (Bot: {bot_id_to_operate})"
if target_group_id == scheduler_manager.ALL_GROUPS
else f"群组 {target_group_id}"
)
if schedule:
await schedule_cmd.finish(
f"为 [{target_desc}] 已成功设置插件 '{p_name}' 的定时任务 "
f"(ID: {schedule.id})。"
)
else:
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败。")
@schedule_cmd.assign("删除")
async def handle_delete(targeter: ScheduleTargeter = GetTargeter("删除")):
schedules_to_remove: list[ScheduleInfo] = await targeter._get_schedules()
if not schedules_to_remove:
await schedule_cmd.finish("没有找到可删除的任务。")
count, _ = await targeter.remove()
if count > 0 and schedules_to_remove:
if len(schedules_to_remove) == 1:
message = presenters.format_remove_success(schedules_to_remove[0])
else:
target_desc = targeter._generate_target_description()
message = f"✅ 成功移除了{target_desc} {count} 个任务。"
else:
message = "没有任务被移除。"
await schedule_cmd.finish(message)
@schedule_cmd.assign("暂停")
async def handle_pause(targeter: ScheduleTargeter = GetTargeter("暂停")):
schedules_to_pause: list[ScheduleInfo] = await targeter._get_schedules()
if not schedules_to_pause:
await schedule_cmd.finish("没有找到可暂停的任务。")
count, _ = await targeter.pause()
if count > 0 and schedules_to_pause:
if len(schedules_to_pause) == 1:
message = presenters.format_pause_success(schedules_to_pause[0])
else:
target_desc = targeter._generate_target_description()
message = f"✅ 成功暂停了{target_desc} {count} 个任务。"
else:
message = "没有任务被暂停。"
await schedule_cmd.finish(message)
@schedule_cmd.assign("恢复")
async def handle_resume(targeter: ScheduleTargeter = GetTargeter("恢复")):
schedules_to_resume: list[ScheduleInfo] = await targeter._get_schedules()
if not schedules_to_resume:
await schedule_cmd.finish("没有找到可恢复的任务。")
count, _ = await targeter.resume()
if count > 0 and schedules_to_resume:
if len(schedules_to_resume) == 1:
message = presenters.format_resume_success(schedules_to_resume[0])
else:
target_desc = targeter._generate_target_description()
message = f"✅ 成功恢复了{target_desc} {count} 个任务。"
else:
message = "没有任务被恢复。"
await schedule_cmd.finish(message)
@schedule_cmd.assign("执行")
async def handle_trigger(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
from zhenxun.services.scheduler.repository import ScheduleRepository
schedule_info = await ScheduleRepository.get_by_id(schedule_id.result)
if not schedule_info:
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id.result} 的任务。")
success, message = await scheduler_manager.trigger_now(schedule_id.result)
if success:
final_message = presenters.format_trigger_success(schedule_info)
else:
final_message = f"❌ 手动触发失败: {message}"
await schedule_cmd.finish(final_message)
@schedule_cmd.assign("更新")
async def handle_update(
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
date_expr: Match[str] = AlconnaMatch("date_expr"),
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
):
if not any(
[
cron_expr.available,
interval_expr.available,
date_expr.available,
daily_expr.available,
kwargs_str.available,
]
):
await schedule_cmd.finish(
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
)
trigger_type, trigger_config, job_kwargs = None, None, None
try:
if cron_expr.available:
trigger_type, trigger_config = (
"cron",
dict(
zip(
["minute", "hour", "day", "month", "day_of_week"],
cron_expr.result.split(),
)
),
)
elif interval_expr.available:
trigger_type, trigger_config = (
"interval",
parse_interval(interval_expr.result),
)
elif date_expr.available:
trigger_type, trigger_config = (
"date",
{"run_date": datetime.fromisoformat(date_expr.result)},
)
elif daily_expr.available:
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
if kwargs_str.available:
job_kwargs = dict(
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
)
success, message = await scheduler_manager.update_schedule(
schedule_id.result, trigger_type, trigger_config, job_kwargs
)
if success:
from zhenxun.services.scheduler.repository import ScheduleRepository
updated_schedule = await ScheduleRepository.get_by_id(schedule_id.result)
if updated_schedule:
final_message = presenters.format_update_success(updated_schedule)
else:
final_message = "✅ 更新成功,但无法获取更新后的任务详情。"
else:
final_message = f"❌ 更新失败: {message}"
await schedule_cmd.finish(final_message)
@schedule_cmd.assign("插件列表")
async def handle_plugins_list():
message = await presenters.format_plugins_list()
await schedule_cmd.finish(message)
@schedule_cmd.assign("状态")
async def handle_status(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
status = await scheduler_manager.get_schedule_status(schedule_id.result)
if not status:
await schedule_cmd.finish(f"未找到ID为 {schedule_id.result} 的定时任务。")
message = presenters.format_single_status_message(status)
await schedule_cmd.finish(message)

View File

@ -0,0 +1,274 @@
import asyncio
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.scheduler import scheduler_manager
from zhenxun.utils._image_template import ImageTemplate, RowStyle
def _get_type_name(annotation) -> str:
"""获取类型注解的名称"""
if hasattr(annotation, "__name__"):
return annotation.__name__
elif hasattr(annotation, "_name"):
return annotation._name
else:
return str(annotation)
def _format_trigger(schedule: dict) -> str:
"""格式化触发器信息为可读字符串"""
trigger_type = schedule.get("trigger_type")
config = schedule.get("trigger_config")
if not isinstance(config, dict):
return f"配置错误: {config}"
if trigger_type == "cron":
hour = config.get("hour", "??")
minute = config.get("minute", "??")
try:
hour_int = int(hour)
minute_int = int(minute)
return f"每天 {hour_int:02d}:{minute_int:02d}"
except (ValueError, TypeError):
return f"每天 {hour}:{minute}"
elif trigger_type == "interval":
units = {
"weeks": "",
"days": "",
"hours": "小时",
"minutes": "分钟",
"seconds": "",
}
for unit, unit_name in units.items():
if value := config.get(unit):
return f"{value} {unit_name}"
return "未知间隔"
elif trigger_type == "date":
run_date = config.get("run_date", "N/A")
return f"特定时间 {run_date}"
else:
return f"未知触发器类型: {trigger_type}"
def _format_trigger_for_card(schedule_info: ScheduleInfo | dict) -> str:
"""为信息卡片格式化触发器规则"""
trigger_type = (
schedule_info.get("trigger_type")
if isinstance(schedule_info, dict)
else schedule_info.trigger_type
)
config = (
schedule_info.get("trigger_config")
if isinstance(schedule_info, dict)
else schedule_info.trigger_config
)
if not isinstance(config, dict):
return f"配置错误: {config}"
if trigger_type == "cron":
hour = config.get("hour", "??")
minute = config.get("minute", "??")
try:
hour_int = int(hour)
minute_int = int(minute)
return f"每天 {hour_int:02d}:{minute_int:02d}"
except (ValueError, TypeError):
return f"每天 {hour}:{minute}"
elif trigger_type == "interval":
units = {
"weeks": "",
"days": "",
"hours": "小时",
"minutes": "分钟",
"seconds": "",
}
for unit, unit_name in units.items():
if value := config.get(unit):
return f"{value} {unit_name}"
return "未知间隔"
elif trigger_type == "date":
run_date = config.get("run_date", "N/A")
return f"特定时间 {run_date}"
else:
return f"未知规则: {trigger_type}"
def _format_operation_result_card(
title: str, schedule_info: ScheduleInfo, extra_info: list[str] | None = None
) -> str:
"""
生成一个标准的操作结果信息卡片
参数:
title: 卡片的标题 (例如 "✅ 成功暂停定时任务!")
schedule_info: 相关的 ScheduleInfo 对象
extra_info: (可选) 额外的补充信息行
"""
target_desc = (
f"群组 {schedule_info.group_id}"
if schedule_info.group_id
and schedule_info.group_id != scheduler_manager.ALL_GROUPS
else "所有群组"
if schedule_info.group_id == scheduler_manager.ALL_GROUPS
else "全局"
)
info_lines = [
title,
f"✓ 任务 ID: {schedule_info.id}",
f"🖋 插件: {schedule_info.plugin_name}",
f"🎯 目标: {target_desc}",
f"⏰ 时间: {_format_trigger_for_card(schedule_info)}",
]
if extra_info:
info_lines.extend(extra_info)
return "\n".join(info_lines)
def format_pause_success(schedule_info: ScheduleInfo) -> str:
"""格式化暂停成功的消息"""
return _format_operation_result_card("✅ 成功暂停定时任务!", schedule_info)
def format_resume_success(schedule_info: ScheduleInfo) -> str:
"""格式化恢复成功的消息"""
return _format_operation_result_card("▶️ 成功恢复定时任务!", schedule_info)
def format_remove_success(schedule_info: ScheduleInfo) -> str:
"""格式化删除成功的消息"""
return _format_operation_result_card("❌ 成功删除定时任务!", schedule_info)
def format_trigger_success(schedule_info: ScheduleInfo) -> str:
"""格式化手动触发成功的消息"""
return _format_operation_result_card("🚀 成功手动触发定时任务!", schedule_info)
def format_update_success(schedule_info: ScheduleInfo) -> str:
"""格式化更新成功的消息"""
return _format_operation_result_card("🔄️ 成功更新定时任务配置!", schedule_info)
def _status_row_style(column: str, text: str) -> RowStyle:
"""为状态列设置颜色"""
style = RowStyle()
if column == "状态":
if text == "启用":
style.font_color = "#67C23A"
elif text == "暂停":
style.font_color = "#F56C6C"
elif text == "运行中":
style.font_color = "#409EFF"
return style
def _format_params(schedule_status: dict) -> str:
"""将任务参数格式化为人类可读的字符串"""
if kwargs := schedule_status.get("job_kwargs"):
return " | ".join(f"{k}: {v}" for k, v in kwargs.items())
return "-"
async def format_schedule_list_as_image(
schedules: list[ScheduleInfo], title: str, current_page: int
):
"""将任务列表格式化为图片"""
page_size = 15
total_items = len(schedules)
total_pages = (total_items + page_size - 1) // page_size
start_index = (current_page - 1) * page_size
end_index = start_index + page_size
paginated_schedules = schedules[start_index:end_index]
if not paginated_schedules:
return "这一页没有内容了哦~"
status_tasks = [
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
]
all_statuses = await asyncio.gather(*status_tasks)
def get_status_text(status_value):
if isinstance(status_value, bool):
return "启用" if status_value else "暂停"
return str(status_value)
data_list = [
[
s["id"],
s["plugin_name"],
s.get("bot_id") or "N/A",
s["group_id"] or "全局",
s["next_run_time"],
_format_trigger(s),
_format_params(s),
get_status_text(s["is_enabled"]),
]
for s in all_statuses
if s
]
if not data_list:
return "没有找到任何相关的定时任务。"
return await ImageTemplate.table_page(
head_text=title,
tip_text=f"{current_page}/{total_pages} 页,共 {total_items} 条任务",
column_name=["ID", "插件", "Bot", "目标", "下次运行", "规则", "参数", "状态"],
data_list=data_list,
column_space=20,
text_style=_status_row_style,
)
def format_single_status_message(status: dict) -> str:
"""格式化单个任务状态为文本消息"""
info_lines = [
f"📋 定时任务详细信息 (ID: {status['id']})",
"--------------------",
f"▫️ 插件: {status['plugin_name']}",
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
f"▫️ 目标: {status['group_id'] or '全局'}",
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
f"▫️ 下次运行: {status['next_run_time']}",
f"▫️ 触发规则: {_format_trigger(status)}",
f"▫️ 任务参数: {_format_params(status)}",
]
return "\n".join(info_lines)
async def format_plugins_list() -> str:
"""格式化可用插件列表为文本消息"""
from pydantic import BaseModel
registered_plugins = scheduler_manager.get_registered_plugins()
if not registered_plugins:
return "当前没有已注册的定时任务插件。"
message_parts = ["📋 已注册的定时任务插件:"]
for i, plugin_name in enumerate(registered_plugins, 1):
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
param_info_str = "无参数"
if (
params_model
and isinstance(params_model, type)
and issubclass(params_model, BaseModel)
):
model_fields = getattr(params_model, "model_fields", None)
if model_fields:
param_info_str = "参数: " + ", ".join(
f"{field_name}({_get_type_name(field_info.annotation)})"
for field_name, field_info in model_fields.items()
)
elif params_model:
param_info_str = "⚠️ 参数模型配置错误"
message_parts.append(f"{i}. {plugin_name} - {param_info_str}")
return "\n".join(message_parts)

View File

@ -1,5 +1,3 @@
from datetime import datetime, timedelta
from tortoise.functions import Count
from zhenxun.models.group_console import GroupConsole
@ -10,6 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
from zhenxun.utils.echart_utils.models import Barh
from zhenxun.utils.enum import PluginType
from zhenxun.utils.image_utils import BuildImage
from zhenxun.utils.utils import TimeUtils
class StatisticsManage:
@ -66,8 +65,7 @@ class StatisticsManage:
if plugin_name:
query = query.filter(plugin_name=plugin_name)
if day:
time = datetime.now() - timedelta(days=day)
query = query.filter(create_time__gte=time)
query = query.filter(create_time__gte=TimeUtils.get_day_start())
data_list = (
await query.annotate(count=Count("id"))
.group_by("plugin_name")
@ -87,8 +85,7 @@ class StatisticsManage:
if group_id:
query = query.filter(group_id=group_id)
if day:
time = datetime.now() - timedelta(days=day)
query = query.filter(create_time__gte=time)
query = query.filter(create_time__gte=TimeUtils.get_day_start())
data_list = (
await query.annotate(count=Count("id"))
.group_by("plugin_name")
@ -104,8 +101,7 @@ class StatisticsManage:
async def get_group_statistics(cls, group_id: str, day: int | None, title: str):
query = Statistics.filter(group_id=group_id)
if day:
time = datetime.now() - timedelta(days=day)
query = query.filter(create_time__gte=time)
query = query.filter(create_time__gte=TimeUtils.get_day_start())
data_list = (
await query.annotate(count=Count("id"))
.group_by("plugin_name")

View File

@ -28,7 +28,7 @@ from nonebot_plugin_alconna.uniseg.segment import (
)
from nonebot_plugin_session import EventSession
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
from zhenxun.configs.utils import PluginExtraData, Task
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
@ -73,16 +73,6 @@ __plugin_meta__ = PluginMetadata(
author="HibiKier",
version="1.2",
plugin_type=PluginType.SUPERUSER,
configs=[
RegisterConfig(
module="_task",
key="DEFAULT_BROADCAST",
value=True,
help="被动 广播 进群默认开关状态",
default_value=True,
type=bool,
)
],
tasks=[Task(module="broadcast", name="广播")],
).to_dict(),
)

View File

@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse
from zhenxun.models.plugin_info import PluginInfo as DbPluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from ....base_model import Result
from ....utils import authentication, clear_help_image
@ -11,6 +12,7 @@ from .data_source import ApiDataSource
from .model import (
BatchUpdatePlugins,
BatchUpdateResult,
InstallDependenciesPayload,
PluginCount,
PluginDetail,
PluginInfo,
@ -162,9 +164,9 @@ async def _(module: str) -> Result[PluginDetail]:
dependencies=[authentication()],
response_model=Result[BatchUpdateResult],
response_class=JSONResponse,
summary="批量更新插件配置",
description="批量更新插件配置",
)
async def batch_update_plugin_config_api(
async def _(
params: BatchUpdatePlugins,
) -> Result[BatchUpdateResult]:
"""批量更新插件配置,如开关、类型等"""
@ -187,9 +189,9 @@ async def batch_update_plugin_config_api(
"/menu_type/rename",
dependencies=[authentication()],
response_model=Result,
summary="重命名菜单类型",
description="重命名菜单类型",
)
async def rename_menu_type_api(payload: RenameMenuTypePayload) -> Result:
async def _(payload: RenameMenuTypePayload) -> Result[str]:
try:
result = await ApiDataSource.rename_menu_type(
old_name=payload.old_name, new_name=payload.new_name
@ -213,3 +215,24 @@ async def rename_menu_type_api(payload: RenameMenuTypePayload) -> Result:
except Exception as e:
logger.error(f"{router.prefix}/menu_type/rename 调用错误", "WebUi", e=e)
return Result.fail(info=f"发生未知错误: {type(e).__name__}")
@router.post(
"/install_dependencies",
dependencies=[authentication()],
response_model=Result,
response_class=JSONResponse,
description="安装/卸载依赖",
)
async def _(payload: InstallDependenciesPayload) -> Result:
try:
if not payload.dependencies:
return Result.fail("依赖列表不能为空")
if payload.handle_type == "install":
result = VirtualEnvPackageManager.install(payload.dependencies)
else:
result = VirtualEnvPackageManager.uninstall(payload.dependencies)
return Result.ok(result)
except Exception as e:
logger.error(f"{router.prefix}/install_dependencies 调用错误", "WebUi", e=e)
return Result.fail(f"发生了一点错误捏 {type(e)}: {e}")

View File

@ -167,7 +167,7 @@ class ApiDataSource:
)
return {
"success": len(errors) == 0,
"success": not errors,
"updated_count": updated_count + bulk_updated_count,
"errors": errors,
}
@ -184,19 +184,24 @@ class ApiDataSource:
config: ConfigGroup
返回:
lPluginConfig: 配置数据
PluginConfig: 配置数据
"""
type_str = ""
type_inner = None
if r := re.search(r"<class '(.*)'>", str(config.configs[cfg].type)):
ct = str(config.configs[cfg].type)
if r := re.search(r"<class '(.*)'>", ct):
type_str = r[1]
elif r := re.search(r"typing\.(.*)\[(.*)\]", str(config.configs[cfg].type)):
elif (r := re.search(r"typing\.(.*)\[(.*)\]", ct)) or (
r := re.search(r"(.*)\[(.*)\]", ct)
):
type_str = r[1]
if type_str:
type_str = type_str.lower()
type_inner = r[2]
if type_inner:
type_inner = [x.strip() for x in type_inner.split(",")]
else:
type_str = ct
return PluginConfig(
module=module,
key=cfg,

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, Field
@ -162,3 +162,15 @@ class BatchUpdateResult(BaseModel):
default_factory=list, description="错误信息列表"
)
"""错误信息列表"""
class InstallDependenciesPayload(BaseModel):
"""
安装依赖
"""
handle_type: Literal["install", "uninstall"] = Field(..., description="处理类型")
"""处理类型"""
dependencies: list[str] = Field(..., description="依赖列表")
"""依赖列表"""

View File

@ -106,21 +106,34 @@ class ConfigGroup(BaseModel):
if value_to_process is None:
return default
if cfg.type:
if _is_pydantic_type(cfg.type):
if build_model:
try:
return parse_as(cfg.type, value_to_process)
except Exception as e:
logger.warning(
f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e
)
if cfg.arg_parser:
try:
return cattrs.structure(value_to_process, cfg.type)
return cfg.arg_parser(value_to_process)
except Exception as e:
logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
logger.debug(
f"配置项类型转换 MODULE: [<u><y>{self.module}</y></u>] | "
f"KEY: [<u><y>{key}</y></u>] 的自定义解析器失败,将使用原始值",
e=e,
)
return value_to_process
return value_to_process
if not build_model or not cfg.type:
return value_to_process
try:
if _is_pydantic_type(cfg.type):
parsed_value = parse_as(cfg.type, value_to_process)
return parsed_value
else:
structured_value = cattrs.structure(value_to_process, cfg.type)
return structured_value
except Exception as e:
logger.error(
f"❌ 配置项 '{self.module}.{key}' 自动类型转换失败 "
f"(目标类型: {cfg.type}),将返回原始值。请检查配置文件格式。错误: {e}",
e=e,
)
return value_to_process
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)
@ -167,6 +180,57 @@ class ConfigsManager:
if data := self._data.get(module):
data.name = name
def _merge_dicts(self, new_data: dict, original_data: dict) -> dict:
"""合并两个字典只进行key值的新增和删除操作不修改原有key的值
递归处理嵌套字典确保所有层级的key保持一致
参数:
new_data: 新数据字典
original_data: 原数据字典
返回:
合并后的字典
"""
result = dict(original_data)
# 遍历新数据的键
for key, value in new_data.items():
# 如果键不在原数据中,添加它
if key not in original_data:
result[key] = value
# 如果两边都是字典,递归处理
elif isinstance(value, dict) and isinstance(original_data[key], dict):
result[key] = self._merge_dicts(value, original_data[key])
# 如果键已存在,保留原值,不覆盖
# (不做任何操作,保持原值)
return result
def _normalize_config_data(self, value: Any, original_value: Any = None) -> Any:
"""标准化配置数据处理BaseModel和字典的情况
参数:
value: 要标准化的值
original_value: 原始值用于合并字典
返回:
标准化后的值
"""
# 处理BaseModel
processed_value = _dump_pydantic_obj(value)
# 如果处理后的值是字典,且原始值也存在
if isinstance(processed_value, dict) and original_value is not None:
# 处理原始值
processed_original = _dump_pydantic_obj(original_value)
# 如果原始值也是字典,合并它们
if isinstance(processed_original, dict):
return self._merge_dicts(processed_value, processed_original)
return processed_value
def add_plugin_config(
self,
module: str,
@ -195,16 +259,18 @@ class ConfigsManager:
ValueError: module和key不能为为空
ValueError: 填写错误
"""
key = key.upper()
if not module or not key:
raise ValueError("add_plugin_config: module和key不能为为空")
if isinstance(value, BaseModel):
value = model_dump(value)
if isinstance(default_value, BaseModel):
default_value = model_dump(default_value)
processed_value = _dump_pydantic_obj(value)
processed_default_value = _dump_pydantic_obj(default_value)
# 获取现有配置值(如果存在)
existing_value = None
if module in self._data and (config := self._data[module].configs.get(key)):
existing_value = config.value
# 标准化值和默认值
processed_value = self._normalize_config_data(value, existing_value)
processed_default_value = self._normalize_config_data(default_value)
self.add_module.append(f"{module}:{key}".lower())
if module in self._data and (config := self._data[module].configs.get(key)):
@ -338,14 +404,13 @@ class ConfigsManager:
with open(self._simple_file, "w", encoding="utf8") as f:
_yaml.dump(self._simple_data, f)
path = path or self.file
save_data = {}
for module, config_group in self._data.items():
save_data[module] = {}
for config_key, config_model in config_group.configs.items():
save_data[module][config_key] = model_dump(
config_model, exclude={"type", "arg_parser"}
)
save_data = {
module: {
config_key: model_dump(config_model, exclude={"type", "arg_parser"})
for config_key, config_model in config_group.configs.items()
}
for module, config_group in self._data.items()
}
with open(path, "w", encoding="utf8") as f:
_yaml.dump(save_data, f)

View File

@ -65,7 +65,7 @@ class RegisterConfig(BaseModel):
"""配置注解"""
default_value: Any | None = None
"""默认值"""
type: Any = None
type: Any = str
"""参数类型"""
arg_parser: Callable | None = None
"""参数解析"""

View File

@ -49,7 +49,8 @@ class ChatHistory(Model):
o = "-" if order == "DESC" else ""
query = cls.filter(group_id=gid) if gid else cls
if date_scope:
query = query.filter(create_time__range=date_scope)
filter_scope = (date_scope[0].isoformat(" "), date_scope[1].isoformat(" "))
query = query.filter(create_time__range=filter_scope)
return list(
await query.annotate(count=Count("user_id"))
.order_by(f"{o}count")

View File

@ -99,6 +99,8 @@ class LevelUser(Model):
返回:
bool: 是否大于level
"""
if level == 0:
return True
if group_id:
if user := await cls.get_or_none(user_id=user_id, group_id=group_id):
return user.user_level >= level

View File

@ -1,3 +1,14 @@
"""
Zhenxun Bot - 核心服务模块
主要服务包括
- 数据库上下文 (db_context): 提供数据库模型基类和连接管理
- 日志服务 (log): 提供增强的带上下文的日志记录器
- LLM服务 (llm): 提供与大语言模型交互的统一API
- 插件生命周期管理 (plugin_init): 支持插件安装和卸载时的钩子函数
- 定时任务调度器 (scheduler): 提供持久化的可管理的定时任务服务
"""
from nonebot import require
require("nonebot_plugin_apscheduler")
@ -6,3 +17,33 @@ require("nonebot_plugin_session")
require("nonebot_plugin_htmlrender")
require("nonebot_plugin_uninfo")
require("nonebot_plugin_waiter")
from .db_context import Model, disconnect
from .llm import (
AI,
LLMContentPart,
LLMException,
LLMMessage,
get_model_instance,
list_available_models,
tool_registry,
)
from .log import logger
from .plugin_init import PluginInit, PluginInitManager
from .scheduler import scheduler_manager
__all__ = [
"AI",
"LLMContentPart",
"LLMException",
"LLMMessage",
"Model",
"PluginInit",
"PluginInitManager",
"disconnect",
"get_model_instance",
"list_available_models",
"logger",
"scheduler_manager",
"tool_registry",
]

File diff suppressed because it is too large Load Diff

View File

@ -10,10 +10,10 @@ from .api import (
TaskType,
analyze,
analyze_multimodal,
analyze_with_images,
chat,
code,
embed,
pipeline_chat,
search,
search_multimodal,
)
@ -35,6 +35,7 @@ from .manager import (
list_model_identifiers,
set_global_default_model_name,
)
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -43,6 +44,7 @@ from .types import (
LLMMessage,
LLMResponse,
LLMTool,
MCPCompatible,
ModelDetail,
ModelInfo,
ModelProvider,
@ -51,7 +53,7 @@ from .types import (
ToolMetadata,
UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
__all__ = [
"AI",
@ -65,6 +67,7 @@ __all__ = [
"LLMMessage",
"LLMResponse",
"LLMTool",
"MCPCompatible",
"ModelDetail",
"ModelInfo",
"ModelName",
@ -76,7 +79,6 @@ __all__ = [
"UsageInfo",
"analyze",
"analyze_multimodal",
"analyze_with_images",
"chat",
"clear_model_cache",
"code",
@ -88,9 +90,12 @@ __all__ = [
"list_available_models",
"list_embedding_models",
"list_model_identifiers",
"message_to_unimessage",
"pipeline_chat",
"register_llm_configs",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"unimsg_to_llm_parts",
]

View File

@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
LLMAdapterFactory.initialize()
@ -20,7 +19,6 @@ __all__ = [
"OpenAICompatAdapter",
"RequestData",
"ResponseData",
"ZhipuAdapter",
"get_adapter_for_api_type",
"register_adapter",
]

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool
class RequestData(BaseModel):
@ -60,7 +61,7 @@ class BaseAdapter(ABC):
"""支持的API类型列表"""
pass
def prepare_simple_request(
async def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
@ -86,7 +87,7 @@ class BaseAdapter(ABC):
config = model._generation_config
return self.prepare_advanced_request(
return await self.prepare_advanced_request(
model=model,
api_key=api_key,
messages=messages,
@ -96,13 +97,13 @@ class BaseAdapter(ABC):
)
@abstractmethod
def prepare_advanced_request(
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -238,6 +239,9 @@ class BaseAdapter(ABC):
message = choice.get("message", {})
content = message.get("content", "")
if content:
content = content.strip()
parsed_tool_calls: list[LLMToolCall] | None = None
if message_tool_calls := message.get("tool_calls"):
from ..types.models import LLMToolFunction
@ -375,7 +379,7 @@ class BaseAdapter(ABC):
if model.temperature is not None:
base_config["temperature"] = model.temperature
if model.max_tokens is not None:
if model.api_type in ["gemini", "gemini_native"]:
if model.api_type == "gemini":
base_config["maxOutputTokens"] = model.max_tokens
else:
base_config["max_tokens"] = model.max_tokens
@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
"""
@abstractmethod
def get_chat_endpoint(self) -> str:
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 chat completions 的端点"""
pass
@abstractmethod
def get_embedding_endpoint(self) -> str:
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 embeddings 的端点"""
pass
def prepare_advanced_request(
async def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI兼容API的通用实现"""
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
url = self.get_api_url(model, self.get_chat_endpoint())
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
openai_messages = self.convert_messages_to_openai_format(messages)
@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
}
if tools:
body["tools"] = tools
openai_tools = []
for tool in tools:
if tool.type == "function" and tool.function:
openai_tools.append({"type": "function", "function": tool.function})
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
openai_tools.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
if openai_tools:
body["tools"] = openai_tools
if tool_choice:
body["tool_choice"] = tool_choice
@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
_ = model, is_advanced # 未使用的参数
_ = model, is_advanced
return self.parse_openai_response(response_json)
def prepare_embedding_request(
@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
**kwargs: Any,
) -> RequestData:
"""准备嵌入请求 - OpenAI兼容格式"""
_ = task_type # 未使用的参数
url = self.get_api_url(model, self.get_embedding_endpoint())
_ = task_type
url = self.get_api_url(model, self.get_embedding_endpoint(model))
headers = self.get_base_headers(api_key)
body = {
@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
"input": texts,
}
# 应用额外的配置参数
if kwargs:
body.update(kwargs)

View File

@ -22,10 +22,8 @@ class LLMAdapterFactory:
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
cls.register_adapter(OpenAIAdapter())
cls.register_adapter(ZhipuAdapter())
cls.register_adapter(GeminiAdapter())
@classmethod

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMToolCall
from ..types.models import LLMTool, LLMToolCall
class GeminiAdapter(BaseAdapter):
@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):
return headers
def prepare_advanced_request(
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
return self._prepare_request(
model, api_key, messages, config, tools, tool_choice
)
def _prepare_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备 Gemini API 请求 - 支持所有高级功能"""
effective_config = config if config is not None else model._generation_config
endpoint = self._get_gemini_endpoint(model, effective_config)
@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list):
system_instruction_parts = [
part.convert_for_api("gemini") for part in msg.content
await part.convert_for_api_async("gemini")
for part in msg.content
]
continue
@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
gemini_contents.append({"role": "user", "parts": current_parts})
elif msg.role == "assistant" or msg.role == "model":
@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
if msg.tool_calls:
import json
@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):
all_tools_for_request = []
if tools:
for tool_item in tools:
if isinstance(tool_item, dict):
if "name" in tool_item and "description" in tool_item:
all_tools_for_request.append(
{"functionDeclarations": [tool_item]}
for tool in tools:
if tool.type == "function" and tool.function:
all_tools_for_request.append(
{"functionDeclarations": [tool.function]}
)
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
else:
all_tools_for_request.append(tool_item)
else:
all_tools_for_request.append(tool_item)
all_tools_for_request.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
elif tool.type == "google_search":
all_tools_for_request.append({"googleSearch": {}})
if effective_config:
if getattr(effective_config, "enable_grounding", False):
@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
logger.debug("隐式启用代码执行工具。")
if all_tools_for_request:
gemini_api_tools = self._convert_tools_to_gemini_format(
all_tools_for_request
)
if gemini_api_tools:
body["tools"] = gemini_api_tools
body["tools"] = all_tools_for_request
final_tool_choice = tool_choice
if final_tool_choice is None and effective_config:
@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):
return f"/v1beta/models/{model.model_name}:generateContent"
def _convert_tools_to_gemini_format(
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""转换工具格式为Gemini格式"""
gemini_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
gemini_tool = {
"functionDeclarations": [
{
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
]
}
gemini_tools.append(gemini_tool)
elif tool.get("type") == "code_execution":
gemini_tools.append(
{"codeExecution": {"language": tool.get("language", "python")}}
)
elif tool.get("type") == "google_search":
gemini_tools.append({"googleSearch": {}})
elif "googleSearch" in tool:
gemini_tools.append({"googleSearch": tool["googleSearch"]})
elif "codeExecution" in tool:
gemini_tools.append({"codeExecution": tool["codeExecution"]})
return gemini_tools
def _convert_tool_choice_to_gemini(
self, tool_choice_value: str | dict[str, Any]
) -> dict[str, Any]:
@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
for category, threshold in custom_safety_settings.items():
safety_settings.append({"category": category, "threshold": threshold})
else:
from ..config.providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
for category in safety_categories:
safety_settings.append(
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
)
safety_settings.append({"category": category, "threshold": threshold})
return safety_settings if safety_settings else None

View File

@ -1,12 +1,12 @@
"""
OpenAI API 适配器
支持 OpenAIDeepSeek 和其他 OpenAI 兼容的 API 服务
支持 OpenAIDeepSeek智谱AI 和其他 OpenAI 兼容的 API 服务
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
from .base import OpenAICompatAdapter
if TYPE_CHECKING:
from ..service import LLMModel
@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):
@property
def supported_api_types(self) -> list[str]:
return ["openai", "deepseek", "general_openai_compat"]
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
def get_chat_endpoint(self) -> str:
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""返回聊天完成端点"""
if model.api_type == "ark":
return "/api/v3/chat/completions"
if model.api_type == "zhipu":
return "/api/paas/v4/chat/completions"
return "/v1/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回嵌入端点"""
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""根据API类型返回嵌入端点"""
if model.api_type == "zhipu":
return "/v4/embeddings"
return "/v1/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -1,57 +0,0 @@
"""
智谱 AI API 适配器
支持智谱 AI GLM 系列模型使用 OpenAI 兼容的接口格式
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
if TYPE_CHECKING:
from ..service import LLMModel
class ZhipuAdapter(OpenAICompatAdapter):
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
@property
def api_type(self) -> str:
return "zhipu"
@property
def supported_api_types(self) -> list[str]:
return ["zhipu"]
def get_chat_endpoint(self) -> str:
"""返回智谱AI聊天完成端点"""
return "/api/paas/v4/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回智谱AI嵌入端点"""
return "/v4/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - 智谱AI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -2,6 +2,7 @@
LLM 服务的高级 API 接口
"""
import copy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@ -14,6 +15,7 @@ from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -56,6 +58,7 @@ class AIConfig:
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
def __post_init__(self):
"""初始化后从配置中读取默认值"""
@ -81,7 +84,7 @@ class AI:
"""
初始化AI服务
Args:
参数:
config: AI 配置.
history: 可选的初始对话历史.
"""
@ -93,16 +96,65 @@ class AI:
self.history = []
logger.info("AI session history cleared.")
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
净化用于存入历史记录的消息
将非文本的多模态内容部分替换为文本占位符以避免重复处理
"""
if not isinstance(message.content, list):
return message
sanitized_message = copy.deepcopy(message)
content_list = sanitized_message.content
if not isinstance(content_list, list):
return sanitized_message
new_content_parts: list[LLMContentPart] = []
has_multimodal_content = False
for part in content_list:
if isinstance(part, LLMContentPart) and part.type == "text":
new_content_parts.append(part)
else:
has_multimodal_content = True
if has_multimodal_content:
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
text_part_found = False
for part in new_content_parts:
if part.type == "text":
part.text = f"{placeholder} {part.text or ''}".strip()
text_part_found = True
break
if not text_part_found:
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
sanitized_message.content = new_content_parts
return sanitized_message
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
preserve_media_in_history: bool | None = None,
**kwargs: Any,
) -> str:
"""
进行一次聊天对话
此方法会自动使用和更新会话内的历史记录
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留用于深度多轮媒体分析
- False: 不保留替换为占位符提高效率
- None (默认): 使用AI实例配置的默认值
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
current_message: LLMMessage
if isinstance(message, str):
@ -127,7 +179,20 @@ class AI:
final_messages, model, "聊天失败", kwargs
)
self.history.append(current_message)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(LLMMessage.assistant_text_response(response.text))
return response.text
@ -140,7 +205,18 @@ class AI:
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行"""
"""
代码执行
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典包含textcode_executions和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
@ -168,7 +244,18 @@ class AI:
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索 - 支持多模态输入"""
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典包含textsourcesqueries和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
@ -217,63 +304,69 @@ class AI:
async def analyze(
self,
message: UniMessage,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
) -> LLMResponse:
"""
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫
这是处理复杂互动的主要方法
参数:
message: 要分析的消息内容支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
返回:
LLMResponse: 模型的完整响应结果
"""
content_parts = await unimsg_to_llm_parts(message)
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
else:
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools = None
if tools:
llm_tools = []
for tool_dict in tools:
if isinstance(tool_dict, dict):
if "name" in tool_dict and "description" in tool_dict:
llm_tool = LLMTool(
type="function",
function={
"name": tool_dict["name"],
"description": tool_dict["description"],
"parameters": tool_dict.get("parameters", {}),
},
)
llm_tools.append(llm_tool)
else:
llm_tools.append(LLMTool(**tool_dict))
else:
llm_tools.append(tool_dict)
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode == "auto":
tool_choice = "auto"
elif mode == "any":
tool_choice = "any"
elif mode == "none":
tool_choice = "none"
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
final_messages,
@ -284,9 +377,7 @@ class AI:
tool_choice=tool_choice,
)
if response.tool_calls:
return response
return response.text
return response
async def _execute_generation(
self,
@ -298,7 +389,7 @@ class AI:
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
"""通用的生成执行方法,封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
@ -311,7 +402,9 @@ class AI:
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages, tools=llm_tools, tool_choice=tool_choice
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
@ -380,7 +473,18 @@ class AI:
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
@ -420,7 +524,17 @@ async def chat(
model: ModelName = None,
**kwargs: Any,
) -> str:
"""聊天对话便捷函数"""
"""
聊天对话便捷函数
参数:
message: 用户输入的消息
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
ai = AI()
return await ai.chat(message, model=model, **kwargs)
@ -432,7 +546,18 @@ async def code(
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行便捷函数"""
"""
代码执行便捷函数
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典
"""
ai = AI()
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
@ -444,45 +569,56 @@ async def search(
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索便捷函数"""
"""
信息搜索便捷函数
参数:
query: 搜索查询内容
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
ai = AI()
return await ai.search(query, model=model, instruction=instruction, **kwargs)
async def analyze(
message: UniMessage,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""内容分析便捷函数"""
"""
内容分析便捷函数
参数:
message: 要分析的消息内容
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
ai = AI()
return await ai.analyze(
message,
instruction=instruction,
model=model,
tools=tools,
use_tools=use_tools,
tool_config=tool_config,
**kwargs,
)
async def analyze_with_images(
text: str,
images: list[str | Path | bytes] | str | Path | bytes,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""图片分析便捷函数"""
message = create_multimodal_message(text=text, images=images)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def analyze_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
@ -493,7 +629,21 @@ async def analyze_multimodal(
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""多模态分析便捷函数"""
"""
多模态分析便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 分析指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
@ -510,7 +660,21 @@ async def search_multimodal(
model: ModelName = None,
**kwargs: Any,
) -> dict[str, Any]:
"""多模态搜索便捷函数"""
"""
多模态搜索便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 搜索指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
@ -525,6 +689,101 @@ async def embed(
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""文本嵌入便捷函数"""
"""
文本嵌入便捷函数
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
ai = AI()
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
async def pipeline_chat(
message: UniMessage | str | list[LLMContentPart],
model_chain: list[ModelName],
*,
initial_instruction: str = "",
final_instruction: str = "",
**kwargs: Any,
) -> LLMResponse:
"""
AI模型链式调用前一个模型的输出作为下一个模型的输入
参数:
message: 初始输入消息支持多模态
model_chain: 模型名称列表
initial_instruction: 第一个模型的系统指令
final_instruction: 最后一个模型的系统指令
**kwargs: 传递给模型实例的其他参数
返回:
LLMResponse: 最后一个模型的响应结果
"""
if not model_chain:
raise ValueError("模型链`model_chain`不能为空。")
current_content: str | list[LLMContentPart]
if isinstance(message, str):
current_content = message
elif isinstance(message, list):
current_content = message
else:
current_content = await unimsg_to_llm_parts(message)
final_response: LLMResponse | None = None
for i, model_name in enumerate(model_chain):
if not model_name:
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
is_first_step = i == 0
is_last_step = i == len(model_chain) - 1
messages_for_step: list[LLMMessage] = []
instruction_for_step = ""
if is_first_step and initial_instruction:
instruction_for_step = initial_instruction
elif is_last_step and final_instruction:
instruction_for_step = final_instruction
if instruction_for_step:
messages_for_step.append(LLMMessage.system(instruction_for_step))
messages_for_step.append(LLMMessage.user(current_content))
logger.info(
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
f"使用模型 '{model_name}' 进行处理..."
)
try:
async with await get_model_instance(model_name, **kwargs) as model:
response = await model.generate_response(messages_for_step)
final_response = response
current_content = response.text.strip()
if not current_content and not is_last_step:
logger.warning(
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
)
break
except Exception as e:
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
raise LLMException(
f"流水线在模型 '{model_name}' 处执行失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
if final_response is None:
raise LLMException(
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
)
return final_response

View File

@ -14,6 +14,8 @@ from .generation import (
from .presets import CommonOverrides
from .providers import (
LLMConfig,
ToolConfig,
get_gemini_safety_threshold,
get_llm_config,
register_llm_configs,
set_default_model,
@ -25,8 +27,10 @@ __all__ = [
"LLMConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"ToolConfig",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_gemini_safety_threshold",
"get_llm_config",
"register_llm_configs",
"set_default_model",

View File

@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
params["temperature"] = self.temperature
if self.max_tokens is not None:
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
params["maxOutputTokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if self.top_k is not None:
params["topK"] = self.top_k
if self.top_p is not None:
@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = {"type": "json_object"}
logger.debug(f"{api_type} 启用 JSON 对象输出模式")
elif api_type in ["gemini", "gemini_native"]:
elif api_type == "gemini":
params["responseMimeType"] = "application/json"
if self.response_schema:
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if (
self.response_format != ResponseFormat.JSON
and self.response_mime_type is not None
@ -214,7 +214,7 @@ def apply_api_specific_mappings(
"""应用API特定的参数映射"""
mapped_params = params.copy()
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if "max_tokens" in mapped_params:
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
if "top_k" in mapped_params:

View File

@ -71,14 +71,17 @@ class CommonOverrides:
@staticmethod
def gemini_safe() -> LLMGenerationConfig:
"""Gemini 安全模式:严格安全设置"""
"""Gemini 安全模式:使用配置的安全设置"""
from .providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
return LLMGenerationConfig(
temperature=0.5,
safety_settings={
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HARASSMENT": threshold,
"HARM_CATEGORY_HATE_SPEECH": threshold,
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
},
)

View File

@ -4,15 +4,33 @@ LLM 提供商配置管理
负责注册和管理 AI 服务提供商的配置项
"""
from functools import lru_cache
import json
import sys
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ..types.models import ModelDetail, ProviderConfig
class ToolConfig(BaseModel):
"""MCP类型工具的配置定义"""
type: str = "mcp"
name: str = Field(..., description="工具的唯一名称标识")
description: str | None = Field(None, description="工具功能的描述")
mcp_config: dict[str, Any] | BaseModel = Field(
..., description="MCP服务器的特定配置"
)
AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
mcp_tools: list[ToolConfig] = Field(
default_factory=list, description="配置可用的外部MCP工具"
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
return [
{
"name": "DeepSeek",
"api_key": "sk-******",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://api.deepseek.com",
"api_type": "openai",
"models": [
@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
},
],
},
{
"name": "ARK",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://ark.cn-beijing.volces.com",
"api_type": "ark",
"models": [
{"model_name": "deepseek-r1-250528"},
{"model_name": "doubao-seed-1-6-250615"},
{"model_name": "doubao-seed-1-6-flash-250615"},
{"model_name": "doubao-seed-1-6-thinking-250615"},
],
},
{
"name": "siliconflow",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://api.siliconflow.cn",
"api_type": "openai",
"models": [
{"model_name": "deepseek-ai/DeepSeek-V3"},
],
},
{
"name": "GLM",
"api_key": "",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://open.bigmodel.cn",
"api_type": "zhipu",
"models": [
@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
"api_type": "gemini",
"models": [
{"model_name": "gemini-2.0-flash"},
{"model_name": "gemini-2.5-flash-preview-05-20"},
{"model_name": "gemini-2.5-flash"},
{"model_name": "gemini-2.5-pro"},
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
],
},
]
def get_default_mcp_tools() -> dict[str, Any]:
"""
获取默认的MCP工具配置用于在文件不存在时创建
包含了 baidu-map, Context7, sequential-thinking.
"""
return {
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
"description": "百度地图工具,提供地理编码、路线规划等功能。",
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
},
"Context7": {
"command": "npx",
"args": ["-y", "@upstash/context7-mcp@latest"],
"description": "Upstash 提供的上下文管理和记忆工具。",
},
}
}
def register_llm_configs():
"""注册 LLM 服务的配置项"""
logger.info("注册 LLM 服务的配置项")
@ -214,6 +285,19 @@ def register_llm_configs():
help="LLM服务请求重试的基础延迟时间",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"gemini_safety_threshold",
"BLOCK_MEDIUM_AND_ABOVE",
help=(
"Gemini 安全过滤阈值 "
"(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
"BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
"BLOCK_ONLY_HIGH: 只阻止高级别, "
"BLOCK_NONE: 不阻止)"
),
type=str,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
@ -225,24 +309,111 @@ def register_llm_configs():
)
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例
返回:
LLMConfig: LLM 配置实例
"""
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
ai_config = get_ai_config()
llm_data_path = DATA_PATH / "llm"
mcp_tools_path = llm_data_path / "mcp_tools.json"
mcp_tools_list = []
mcp_servers_dict = {}
if not mcp_tools_path.exists():
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
llm_data_path.mkdir(parents=True, exist_ok=True)
default_mcp_config = get_default_mcp_tools()
try:
with mcp_tools_path.open("w", encoding="utf-8") as f:
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
except Exception as e:
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
mcp_servers_dict = {}
else:
try:
with mcp_tools_path.open("r", encoding="utf-8") as f:
mcp_data = json.load(f)
mcp_servers_dict = mcp_data.get("mcpServers", {})
if not isinstance(mcp_servers_dict, dict):
logger.warning(
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
f"将使用空配置。"
)
mcp_servers_dict = {}
except json.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
except Exception as e:
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
mcp_servers_dict = {}
if sys.platform == "win32":
logger.debug("检测到Windows平台正在调整MCP工具的npx命令...")
for name, config in mcp_servers_dict.items():
if isinstance(config, dict) and config.get("command") == "npx":
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
original_args = config.get("args", [])
config["command"] = "cmd"
config["args"] = ["/c", "npx", *original_args]
if mcp_servers_dict:
mcp_tools_list = [
{
"name": name,
"type": "mcp",
"description": config.get("description", f"MCP tool for {name}"),
"mcp_config": config,
}
for name, config in mcp_servers_dict.items()
if isinstance(config, dict)
]
from ..tools.registry import tool_registry
for tool_dict in mcp_tools_list:
if isinstance(tool_dict, dict):
tool_name = tool_dict.get("name")
if not tool_name:
continue
config_model = tool_registry.get_mcp_config_model(tool_name)
if not config_model:
logger.debug(
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
f"将跳过特定配置验证,直接使用原始配置字典。"
)
continue
mcp_config_data = tool_dict.get("mcp_config", {})
try:
parsed_mcp_config = parse_as(config_model, mcp_config_data)
tool_dict["mcp_config"] = parsed_mcp_config
except Exception as e:
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
"timeout": ai_config.get("timeout", 180),
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
"mcp_tools": mcp_tools_list,
}
return LLMConfig(**config_data)
return parse_as(LLMConfig, config_data)
def get_gemini_safety_threshold() -> str:
"""获取 Gemini 安全过滤阈值配置
返回:
str: 安全过滤阈值
"""
ai_config = get_ai_config()
return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")
def validate_llm_config() -> tuple[bool, list[str]]:
@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
logger.info("默认模型已清除")
return True
@PriorityLifecycle.on_startup(priority=10)
async def _init_llm_config_on_startup():
"""
在服务启动时主动调用一次 get_llm_config
以触发必要的初始化操作例如创建默认的 mcp_tools.json 文件
"""
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
try:
get_llm_config()
logger.info("LLM 配置初始化完成。")
except Exception as e:
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)

View File

@ -49,12 +49,36 @@ class LLMHttpClient:
max_keepalive_connections=self.config.max_keepalive_connections,
)
timeout = httpx.Timeout(self.config.timeout)
client_kwargs = {}
if self.config.proxy:
try:
version_parts = httpx.__version__.split(".")
major = int(
"".join(c for c in version_parts[0] if c.isdigit())
)
minor = (
int("".join(c for c in version_parts[1] if c.isdigit()))
if len(version_parts) > 1
else 0
)
if (major, minor) >= (0, 28):
client_kwargs["proxy"] = self.config.proxy
else:
client_kwargs["proxies"] = self.config.proxy
except (ValueError, IndexError):
client_kwargs["proxies"] = self.config.proxy
logger.warning(
f"无法解析 httpx 版本 '{httpx.__version__}'"
"LLM模块将默认使用旧版 'proxies' 参数语法。"
)
self._client = httpx.AsyncClient(
headers=headers,
limits=limits,
timeout=timeout,
proxies=self.config.proxy,
follow_redirects=True,
**client_kwargs,
)
if self._client is None:
raise LLMException(
@ -156,7 +180,16 @@ async def create_llm_http_client(
timeout: int = 180,
proxy: str | None = None,
) -> LLMHttpClient:
"""创建LLM HTTP客户端"""
"""
创建LLM HTTP客户端
参数:
timeout: 超时时间
proxy: 代理服务器地址
返回:
LLMHttpClient: HTTP客户端实例
"""
config = HttpClientConfig(timeout=timeout, proxy=proxy)
return LLMHttpClient(config)
@ -185,7 +218,20 @@ async def with_smart_retry(
provider_name: str | None = None,
**kwargs: Any,
) -> Any:
"""智能重试装饰器 - 支持Key轮询和错误分类"""
"""
智能重试装饰器 - 支持Key轮询和错误分类
参数:
func: 要重试的异步函数
*args: 传递给函数的位置参数
retry_config: 重试配置
key_store: API密钥状态存储
provider_name: 提供商名称
**kwargs: 传递给函数的关键字参数
返回:
Any: 函数执行结果
"""
config = retry_config or RetryConfig()
last_exception: Exception | None = None
failed_keys: set[str] = set()
@ -294,7 +340,17 @@ class KeyStatusStore:
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""获取下一个可用的API密钥轮询策略"""
"""
获取下一个可用的API密钥轮询策略
参数:
provider_name: 提供商名称
api_keys: API密钥列表
exclude_keys: 要排除的密钥集合
返回:
str | None: 可用的API密钥如果没有可用密钥则返回None
"""
if not api_keys:
return None
@ -338,7 +394,13 @@ class KeyStatusStore:
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
async def record_failure(self, api_key: str, status_code: int | None):
"""记录失败使用"""
"""
记录失败使用
参数:
api_key: API密钥
status_code: HTTP状态码
"""
key_id = self._get_key_id(api_key)
async with self._lock:
if status_code in [401, 403]:
@ -356,7 +418,15 @@ class KeyStatusStore:
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""获取密钥使用统计"""
"""
获取密钥使用统计
参数:
api_keys: API密钥列表
返回:
dict[str, dict]: 密钥统计信息字典
"""
stats = {}
async with self._lock:
for key in api_keys:

View File

@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
from .types.capabilities import get_model_capabilities
DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
def get_configured_providers() -> list[ProviderConfig]:
"""从配置中获取Provider列表 - 简化版本"""
"""从配置中获取Provider列表 - 简化和修正版本"""
ai_config = get_ai_config()
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers_raw, list):
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers, list):
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
for i, item in enumerate(providers_raw):
if not isinstance(item, dict):
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
continue
try:
if not item.get("name"):
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
continue
if not item.get("api_key"):
logger.warning(
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
)
continue
if "api_type" not in item or not item["api_type"]:
provider_name = item.get("name", "").lower()
if "glm" in provider_name or "zhipu" in provider_name:
item["api_type"] = "zhipu"
elif "gemini" in provider_name or "google" in provider_name:
item["api_type"] = "gemini"
else:
item["api_type"] = "openai"
if "api_base" not in item or not item["api_base"]:
api_type = item.get("api_type")
if api_type:
default_api_base = get_default_api_base_for_type(api_type)
if default_api_base:
item["api_base"] = default_api_base
if "models" not in item:
item["models"] = [{"model_name": item.get("name", "default")}]
provider_conf = ProviderConfig(**item)
valid_providers.append(provider_conf)
except Exception as e:
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
for i, item in enumerate(providers):
if isinstance(item, ProviderConfig):
if not item.api_base:
default_api_base = get_default_api_base_for_type(item.api_type)
if default_api_base:
item.api_base = default_api_base
valid_providers.append(item)
else:
logger.warning(
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
f"已跳过。实际类型: {type(item)}"
)
return valid_providers
@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
def find_model_config(
provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""在配置中查找指定的 Provider 和 ModelDetail
"""
在配置中查找指定的 Provider ModelDetail
Args:
参数:
provider_name: 提供商名称
model_name: 模型名称
Returns:
找到的 (ProviderConfig, ModelDetail) 元组未找到则返回 None
返回:
tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组未找到则返回 None
"""
providers = get_configured_providers()
@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis
def list_model_identifiers() -> dict[str, list[str]]:
"""列出所有模型的可用标识符
"""
列出所有模型的可用标识符
Returns:
字典键为模型的完整名称值为该模型的所有可用标识符列表
返回:
dict[str, list[str]]: 字典键为模型的完整名称值为该模型的所有可用标识符列表
"""
providers = get_configured_providers()
result = {}
@ -248,7 +224,16 @@ async def get_model_instance(
provider_model_name: str | None = None,
override_config: dict[str, Any] | None = None,
) -> LLMModel:
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
"""
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
参数:
provider_model_name: 模型名称格式为 'ProviderName/ModelName'
override_config: 覆盖配置字典
返回:
LLMModel: 模型实例
"""
cache_key = _make_cache_key(provider_model_name, override_config)
cached_model = _get_cached_model(cache_key)
if cached_model:
@ -292,6 +277,10 @@ async def get_model_instance(
provider_config_found, model_detail_found = config_tuple_found
capabilities = get_model_capabilities(model_detail_found.model_name)
model_detail_found.is_embedding_model = capabilities.is_embedding_model
ai_config = get_ai_config()
global_proxy_setting = ai_config.get(PROXY_KEY)
default_timeout = (
@ -322,6 +311,7 @@ async def get_model_instance(
model_detail=model_detail_found,
key_store=key_store,
http_client=shared_http_client,
capabilities=capabilities,
)
if override_config:
@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:
def set_global_default_model_name(provider_model_name: str | None) -> bool:
"""设置全局默认模型名称"""
"""
设置全局默认模型名称
参数:
provider_model_name: 模型名称格式为 'ProviderName/ModelName'
返回:
bool: 设置是否成功
"""
if provider_model_name:
prov_name, mod_name = parse_provider_model_string(provider_model_name)
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
async def get_key_usage_stats() -> dict[str, Any]:
"""获取所有Provider的Key使用统计"""
"""
获取所有Provider的Key使用统计
返回:
dict[str, Any]: 包含所有Provider的Key使用统计信息
"""
providers = get_configured_providers()
stats = {}
@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
"""重置指定Provider的Key状态"""
"""
重置指定Provider的Key状态
参数:
provider_name: 提供商名称
api_key: 要重置的特定API密钥如果为None则重置所有密钥
返回:
bool: 重置是否成功
"""
providers = get_configured_providers()
target_provider = None

View File

@ -6,11 +6,13 @@ LLM 模型实现类
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
import json
from typing import Any
from zhenxun.services.log import logger
from .adapters.base import RequestData
from .config import LLMGenerationConfig
from .config.providers import get_ai_config
from .core import (
@ -30,6 +32,8 @@ from .types import (
ModelDetail,
ProviderConfig,
)
from .types.capabilities import ModelCapabilities, ModelModality
from .utils import _sanitize_request_body_for_logging
class LLMModelBase(ABC):
@ -42,7 +46,17 @@ class LLMModelBase(ABC):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
"""
生成文本
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
pass
@abstractmethod
@ -54,7 +68,19 @@ class LLMModelBase(ABC):
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应"""
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
pass
@abstractmethod
@ -64,7 +90,17 @@ class LLMModelBase(ABC):
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
pass
@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
model_detail: ModelDetail,
key_store: KeyStatusStore,
http_client: LLMHttpClient,
capabilities: ModelCapabilities,
config_override: LLMGenerationConfig | None = None,
):
self.provider_config = provider_config
self.model_detail = model_detail
self.key_store = key_store
self.http_client: LLMHttpClient = http_client
self.capabilities = capabilities
self._generation_config = config_override
self.provider_name = provider_config.name
@ -99,6 +137,34 @@ class LLMModel(LLMModelBase):
self._is_closed = False
def can_process_images(self) -> bool:
"""检查模型是否支持图片作为输入。"""
return ModelModality.IMAGE in self.capabilities.input_modalities
def can_process_video(self) -> bool:
"""检查模型是否支持视频作为输入。"""
return ModelModality.VIDEO in self.capabilities.input_modalities
def can_process_audio(self) -> bool:
"""检查模型是否支持音频作为输入。"""
return ModelModality.AUDIO in self.capabilities.input_modalities
def can_generate_images(self) -> bool:
"""检查模型是否支持生成图片。"""
return ModelModality.IMAGE in self.capabilities.output_modalities
def can_generate_audio(self) -> bool:
"""检查模型是否支持生成音频 (TTS)。"""
return ModelModality.AUDIO in self.capabilities.output_modalities
def can_use_tools(self) -> bool:
"""检查模型是否支持工具调用/函数调用。"""
return self.capabilities.supports_tool_calling
def is_embedding_model(self) -> bool:
"""检查这是否是一个嵌入模型。"""
return self.capabilities.is_embedding_model
async def _get_http_client(self) -> LLMHttpClient:
"""获取HTTP客户端"""
if self.http_client.is_closed:
@ -135,24 +201,54 @@ class LLMModel(LLMModelBase):
return selected_key
async def _execute_embedding_request(
async def _perform_api_call(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
prepare_request_func: Callable[[str], Awaitable["RequestData"]],
parse_response_func: Callable[[dict[str, Any]], Any],
http_client: "LLMHttpClient",
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
log_context: str = "API",
) -> Any:
"""
执行API调用的通用核心方法
该方法封装了以下通用逻辑:
1. 选择API密钥
2. 准备和记录请求
3. 发送HTTP POST请求
4. 处理HTTP错误和API特定错误
5. 记录密钥使用状态
6. 解析成功的响应
参数:
prepare_request_func: 准备请求的函数
parse_response_func: 解析响应的函数
http_client: HTTP客户端
failed_keys: 失败的密钥集合
log_context: 日志上下文
返回:
Any: 解析后的响应数据
"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
request_data = await prepare_request_func(api_key)
logger.info(
f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
f"[{log_context}]"
)
logger.debug(f"📡 请求URL: {request_data.url}")
masked_key = (
f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
)
logger.debug(f"🔑 API密钥: {masked_key}")
logger.debug(f"📋 请求头: {dict(request_data.headers)}")
sanitized_body = _sanitize_request_body_for_logging(request_data.body)
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
logger.debug(f"📦 请求体: {request_body_str}")
http_response = await http_client.post(
request_data.url,
@ -160,121 +256,16 @@ class LLMModel(LLMModelBase):
json=request_data.body,
)
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
error_code = LLMErrorCode.API_REQUEST_FAILED
if http_response.status_code in [401, 403]:
error_code = LLMErrorCode.API_KEY_INVALID
elif http_response.status_code == 429:
error_code = LLMErrorCode.API_RATE_LIMITED
raise LLMException(
f"HTTP嵌入请求失败: {http_response.status_code}",
code=error_code,
details={
"status_code": http_response.status_code,
"response": error_text,
"api_key": api_key,
},
)
try:
response_json = http_response.json()
adapter.validate_embedding_response(response_json)
embeddings = adapter.parse_embedding_response(response_json)
except Exception as e:
logger.error(f"解析嵌入响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API嵌入响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return embeddings
except LLMException:
raise
except Exception as e:
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成嵌入失败: {e}",
code=LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools_dict,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools_dict,
tool_choice=tool_choice,
)
http_response = await http_client.post(
request_data.url,
headers=request_data.headers,
json=request_data.body,
)
logger.debug(f"📥 响应状态码: {http_response.status_code}")
logger.debug(f"📄 响应头: {dict(http_response.headers)}")
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP请求失败: {http_response.status_code} - {error_text}"
f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
f"[{log_context}]"
)
logger.debug(f"💥 完整错误响应: {error_text}")
await self.key_store.record_failure(api_key, http_response.status_code)
@ -299,69 +290,165 @@ class LLMModel(LLMModelBase):
try:
response_json = http_response.json()
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
llm_response = LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
response_json_str = json.dumps(
response_json, ensure_ascii=False, indent=2
)
logger.debug(f"📋 响应JSON: {response_json_str}")
parsed_data = parse_response_func(response_json)
except Exception as e:
logger.error(f"解析响应失败: {e}", e=e)
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API响应失败: {e}",
f"解析API {log_context} 响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return llm_response
logger.debug(f"✅ API密钥使用成功: {masked_key}")
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
return parsed_data
except LLMException:
raise
except Exception as e:
logger.error(f"生成响应时发生未预期错误: {e}", e=e)
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
logger.error(error_log_msg, e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成响应失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
error_log_msg,
code=LLMErrorCode.GENERATION_FAILED
if log_context == "Generation"
else LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_embedding_request(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
async def prepare_request(api_key: str) -> RequestData:
return adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
)
def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
adapter.validate_embedding_response(response_json)
return adapter.parse_embedding_response(response_json)
return await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Embedding",
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
async def prepare_request(api_key: str) -> RequestData:
return await adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools,
tool_choice=tool_choice,
)
def parse_response(response_json: dict[str, Any]) -> LLMResponse:
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
return LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
)
return await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Generation",
)
async def close(self):
"""
标记模型实例的当前使用周期结束
@ -400,7 +487,17 @@ class LLMModel(LLMModelBase):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本 - 通过 generate_response 实现"""
"""
生成文本 - 通过 generate_response 实现
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
self._check_not_closed()
messages: list[LLMMessage] = []
@ -439,11 +536,21 @@ class LLMModel(LLMModelBase):
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
max_tool_iterations: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应 - 实现完整的工具调用循环"""
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
@ -468,109 +575,43 @@ class LLMModel(LLMModelBase):
merged_dict.update(config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
tools_dict: list[dict[str, Any]] | None = None
if tools:
tools_dict = []
for tool in tools:
if hasattr(tool, "model_dump"):
model_dump_func = getattr(tool, "model_dump")
tools_dict.append(model_dump_func(exclude_none=True))
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
try:
tools_dict.append(dict(tool))
except (TypeError, ValueError):
logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
http_client = await self._get_http_client()
current_messages = list(messages)
for iteration in range(max_tool_iterations):
logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
async with AsyncExitStack() as stack:
activated_tools = []
if tools:
for tool in tools:
if tool.type == "mcp" and callable(tool.mcp_session):
func_obj = getattr(tool.mcp_session, "func", None)
tool_name = (
getattr(func_obj, "__name__", "unknown")
if func_obj
else "unknown"
)
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
active_session = await stack.enter_async_context(
tool.mcp_session()
)
activated_tools.append(
LLMTool.from_mcp_session(
session=active_session, annotations=tool.annotations
)
)
else:
activated_tools.append(tool)
llm_response = await self._execute_with_smart_retry(
adapter,
current_messages,
messages,
final_request_config,
tools_dict if iteration == 0 else None,
tool_choice if iteration == 0 else None,
activated_tools if activated_tools else None,
tool_choice,
http_client,
)
response_tool_calls = llm_response.tool_calls or []
if not response_tool_calls or not tool_executor:
logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
return llm_response
logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
assistant_message_content = llm_response.text if llm_response.text else ""
current_messages.append(
LLMMessage.assistant_tool_calls(
content=assistant_message_content, tool_calls=response_tool_calls
)
)
tool_response_messages: list[LLMMessage] = []
for tool_call in response_tool_calls:
tool_name = tool_call.function.name
try:
tool_args_dict = json.loads(tool_call.function.arguments)
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
tool_result = await tool_executor(tool_name, tool_args_dict)
logger.debug(
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result=tool_result,
)
)
except json.JSONDecodeError as e:
logger.error(
f"工具 '{tool_name}' 参数JSON解析失败: "
f"{tool_call.function.arguments}, 错误: {e}"
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Argument JSON parsing failed",
"details": str(e),
},
)
)
except Exception as e:
logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Tool execution failed",
"details": str(e),
},
)
)
current_messages.extend(tool_response_messages)
logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
raise LLMException(
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
code=LLMErrorCode.GENERATION_FAILED,
details={
"iterations": max_tool_iterations,
"last_messages": current_messages[-2:],
},
)
return llm_response
async def generate_embeddings(
self,
@ -578,7 +619,17 @@ class LLMModel(LLMModelBase):
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
self._check_not_closed()
if not texts:
return []

View File

@ -0,0 +1,7 @@
"""
工具模块导出
"""
from .registry import tool_registry
__all__ = ["tool_registry"]

View File

@ -0,0 +1,181 @@
"""
工具注册表
负责加载管理和实例化来自配置的工具
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import TYPE_CHECKING
from pydantic import BaseModel
from zhenxun.services.log import logger
from ..types import LLMTool
if TYPE_CHECKING:
from ..config.providers import ToolConfig
from ..types.protocols import MCPCompatible
class ToolRegistry:
"""工具注册表,用于管理和实例化配置的工具。"""
def __init__(self):
self._function_tools: dict[str, LLMTool] = {}
self._mcp_config_models: dict[str, type[BaseModel]] = {}
if TYPE_CHECKING:
self._mcp_factories: dict[
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
] = {}
else:
self._mcp_factories: dict[str, Callable] = {}
self._tool_configs: dict[str, "ToolConfig"] | None = None
self._tool_cache: dict[str, "LLMTool"] = {}
def _load_configs_if_needed(self):
"""如果尚未加载则从主配置中加载MCP工具定义。"""
if self._tool_configs is None:
logger.debug("首次访问正在加载MCP工具配置...")
from ..config.providers import get_llm_config
llm_config = get_llm_config()
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
def function_tool(
self,
name: str,
description: str,
parameters: dict,
required: list[str] | None = None,
):
"""
装饰器在代码中注册一个简单的无状态的函数工具
参数:
name: 工具的唯一名称
description: 工具功能的描述
parameters: OpenAPI格式的函数参数schema的properties部分
required: 必需的参数列表
"""
def decorator(func: Callable):
if name in self._function_tools or name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的工具: {name}")
tool_definition = LLMTool.create(
name=name,
description=description,
parameters=parameters,
required=required,
)
self._function_tools[name] = tool_definition
logger.info(f"已在代码中注册函数工具: '{name}'")
tool_definition.annotations = tool_definition.annotations or {}
tool_definition.annotations["executable"] = func
return func
return decorator
def mcp_tool(self, name: str, config_model: type[BaseModel]):
"""
装饰器注册一个MCP工具及其配置模型
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
config_model: 一个Pydantic模型用于定义和验证该工具的 `mcp_config`
"""
def decorator(factory_func: Callable):
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory_func
self._mcp_config_models[name] = config_model
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
return factory_func
return decorator
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
"""根据名称获取MCP工具的配置模型。"""
return self._mcp_config_models.get(name)
def register_mcp_factory(
self,
name: str,
factory: Callable,
):
"""
在代码中注册一个 MCP 会话工厂将其与配置中的工具名称关联
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
factory: 一个返回异步生成器的可调用对象会话工厂
"""
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory
logger.info(f"已注册 MCP 会话工厂: '{name}'")
def get_tool(self, name: str) -> "LLMTool":
"""
根据名称获取一个 LLMTool 定义
对于MCP工具返回的 LLMTool 实例包含一个可调用的会话工厂
而不是一个已激活的会话
"""
logger.debug(f"🔍 请求获取工具定义: {name}")
if name in self._tool_cache:
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
return self._tool_cache[name]
if name in self._function_tools:
logger.debug(f"🛠️ 获取函数工具定义: {name}")
tool = self._function_tools[name]
self._tool_cache[name] = tool
return tool
self._load_configs_if_needed()
if self._tool_configs is None or name not in self._tool_configs:
known_tools = list(self._function_tools.keys()) + (
list(self._tool_configs.keys()) if self._tool_configs else []
)
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
logger.debug(f"📋 可用工具定义列表: {known_tools}")
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
config = self._tool_configs[name]
tool: "LLMTool"
if name not in self._mcp_factories:
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
available_factories = list(self._mcp_factories.keys())
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
raise ValueError(
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
)
logger.info(f"🔧 创建MCP工具定义: {name}")
factory = self._mcp_factories[name]
typed_mcp_config = config.mcp_config
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
configured_factory = partial(factory, config=typed_mcp_config)
tool = LLMTool.from_mcp_session(session=configured_factory)
self._tool_cache[name] = tool
logger.debug(f"💾 MCP工具定义已缓存: {name}")
return tool
def get_tools(self, names: list[str]) -> list["LLMTool"]:
"""根据名称列表获取多个 LLMTool 实例。"""
return [self.get_tool(name) for name in names]
tool_registry = ToolRegistry()

View File

@ -4,6 +4,7 @@ LLM 类型定义模块
统一导出所有核心类型协议和异常定义
"""
from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
from .content import (
LLMContentPart,
LLMMessage,
@ -26,6 +27,7 @@ from .models import (
ToolMetadata,
UsageInfo,
)
from .protocols import MCPCompatible
__all__ = [
"EmbeddingTaskType",
@ -41,8 +43,11 @@ __all__ = [
"LLMTool",
"LLMToolCall",
"LLMToolFunction",
"MCPCompatible",
"ModelCapabilities",
"ModelDetail",
"ModelInfo",
"ModelModality",
"ModelName",
"ModelProvider",
"ProviderConfig",
@ -50,5 +55,6 @@ __all__ = [
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"get_model_capabilities",
"get_user_friendly_error_message",
]

View File

@ -0,0 +1,128 @@
"""
LLM 模型能力定义模块
定义模型的输入输出模态工具调用支持等核心能力
"""
from enum import Enum
import fnmatch
from pydantic import BaseModel, Field
class ModelModality(str, Enum):
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
EMBEDDING = "embedding"
class ModelCapabilities(BaseModel):
"""定义一个模型的核心、稳定能力。"""
input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
supports_tool_calling: bool = False
is_embedding_model: bool = False
STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
GEMINI_CAPABILITIES = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
MODEL_ALIAS_MAPPING: dict[str, str] = {
"deepseek-v3*": "deepseek-chat",
"deepseek-ai/DeepSeek-V3": "deepseek-chat",
"deepseek-r1*": "deepseek-reasoner",
}
MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
"gemini-*-tts": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.AUDIO},
),
"gemini-*-native-audio-*": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
supports_tool_calling=True,
),
"gemini-2.0-flash-preview-image-generation": ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
supports_tool_calling=True,
),
"gemini-embedding-exp": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.EMBEDDING},
is_embedding_model=True,
),
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
"gemini-2.0-flash*": GEMINI_CAPABILITIES,
"gemini-1.5-flash*": GEMINI_CAPABILITIES,
"GLM-4V-Flash": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"GLM-4V-Plus*": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
}
def get_model_capabilities(model_name: str) -> ModelCapabilities:
"""
从注册表获取模型能力支持别名映射和通配符匹配
查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
"""
canonical_name = model_name
for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
if fnmatch.fnmatch(model_name, alias_pattern):
canonical_name = c_name
break
if canonical_name in MODEL_CAPABILITIES_REGISTRY:
return MODEL_CAPABILITIES_REGISTRY[canonical_name]
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
return capabilities
return ModelCapabilities()

View File

@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
return None
def convert_for_api(self, api_type: str) -> dict[str, Any]:
async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
"""根据API类型转换多模态内容格式"""
from zhenxun.utils.http_utils import AsyncHttpx
if self.type == "text":
if api_type == "openai":
return {"type": "text", "text": self.text}
@ -248,20 +250,23 @@ class LLMContentPart(BaseModel):
mime_type, data = base64_info
return {"inlineData": {"mimeType": mime_type, "data": data}}
else:
# 如果无法解析 Base64 数据,抛出异常
raise ValueError(
f"无法解析Base64图像数据: {self.image_source[:50]}..."
)
else:
logger.warning(
f"Gemini API需要Base64格式但提供的是URL: {self.image_source}"
)
return {
"inlineData": {
"mimeType": "image/jpeg",
"data": self.image_source,
elif self.is_image_url():
logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
try:
image_bytes = await AsyncHttpx.get_content(self.image_source)
mime_type = self.mime_type or "image/jpeg"
base64_data = base64.b64encode(image_bytes).decode("utf-8")
return {
"inlineData": {"mimeType": mime_type, "data": base64_data}
}
}
except Exception as e:
logger.error(f"下载或编码URL图片失败: {e}", e=e)
raise ValueError(f"无法处理图片URL: {e}")
else:
raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
else:
return {"type": "image_url", "image_url": {"url": self.image_source}}

View File

@ -4,13 +4,25 @@ LLM 数据模型定义
包含模型信息配置工具定义和响应数据的模型类
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
if TYPE_CHECKING:
from .protocols import MCPCompatible
MCPSessionType = (
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
)
else:
MCPCompatible = object
MCPSessionType = Any
ModelName = str | None
@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
class LLMTool(BaseModel):
"""LLM 工具定义(支持 MCP 风格)"""
model_config = {"arbitrary_types_allowed": True}
type: str = "function"
function: dict[str, Any]
function: dict[str, Any] | None = None
mcp_session: MCPSessionType = None
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
def model_post_init(self, /, __context: Any) -> None:
"""验证工具定义的有效性"""
_ = __context
if self.type == "function" and self.function is None:
raise ValueError("函数类型的工具必须包含 'function' 字段。")
if self.type == "mcp" and self.mcp_session is None:
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
@classmethod
def create(
cls,
@ -111,7 +134,7 @@ class LLMTool(BaseModel):
required: list[str] | None = None,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""创建工具"""
"""创建函数工具"""
function_def = {
"name": name,
"description": description,
@ -123,6 +146,15 @@ class LLMTool(BaseModel):
}
return cls(type="function", function=function_def, annotations=annotations)
@classmethod
def from_mcp_session(
cls,
session: Any,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""从 MCP 会话创建工具"""
return cls(type="mcp", mcp_session=session, annotations=annotations)
class LLMCodeExecution(BaseModel):
"""代码执行结果"""

View File

@ -0,0 +1,24 @@
"""
LLM 模块的协议定义
"""
from typing import Any, Protocol
class MCPCompatible(Protocol):
"""
一个协议定义了与LLM模块兼容的MCP会话对象应具备的行为
任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible
"""
def to_api_tool(self, api_type: str) -> dict[str, Any]:
"""
将此MCP会话转换为特定LLM提供商API所需的工具格式
参数:
api_type: 目标API的类型 (例如 'gemini', 'openai')
返回:
dict[str, Any]: 一个字典代表可以在API请求中使用的工具定义
"""
...

View File

@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
"""
import base64
import copy
from pathlib import Path
from nonebot.adapters import Message as PlatformMessage
from nonebot_plugin_alconna.uniseg import (
At,
File,
@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
)
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from .types import LLMContentPart
@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
"""
UniMessage 实例转换为一个 LLMContentPart 列表
这是处理多模态输入的核心转换逻辑
参数:
message: 要转换的UniMessage实例
返回:
list[LLMContentPart]: 转换后的内容部分列表
"""
parts: list[LLMContentPart] = []
for seg in message:
@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
if seg.path:
part = await LLMContentPart.from_path(seg.path)
elif seg.url:
logger.warning(
f"直接使用 URL 的 {type(seg).__name__} 段,"
f"API 可能不支持: {seg.url}"
)
part = LLMContentPart.text_part(
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
)
elif hasattr(seg, "raw") and seg.raw:
try:
logger.debug(f"检测到媒体URL开始下载: {seg.url}")
media_bytes = await AsyncHttpx.get_content(seg.url)
new_seg = copy.copy(seg)
new_seg.raw = media_bytes
seg = new_seg
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
except Exception as e:
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
part = LLMContentPart.text_part(
f"[下载媒体失败: {seg.name or seg.url}]"
)
if part:
parts.append(part)
continue
if hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", None)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
@ -127,50 +147,19 @@ def create_multimodal_message(
audio_mimetypes: list[str] | str | None = None,
) -> UniMessage:
"""
创建多模态消息的便捷函数方便第三方调用
创建多模态消息的便捷函数
Args:
参数:
text: 文本内容
images: 图片数据支持路径字节数据或URL
videos: 视频数据支持路径字节数据或URL
audios: 音频数据支持路径字节数据或URL
image_mimetypes: 图片MIME类型当images为bytes时需要指定
video_mimetypes: 视频MIME类型当videos为bytes时需要指定
audio_mimetypes: 音频MIME类型当audios为bytes时需要指定
videos: 视频数据
audios: 音频数据
image_mimetypes: 图片MIME类型bytes数据时需要指定
video_mimetypes: 视频MIME类型bytes数据时需要指定
audio_mimetypes: 音频MIME类型bytes数据时需要指定
Returns:
返回:
UniMessage: 构建好的多模态消息
Examples:
# 纯文本
msg = create_multimodal_message("请分析这段文字")
# 文本 + 单张图片(路径)
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
# 文本 + 多张图片
msg = create_multimodal_message(
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
)
# 文本 + 图片字节数据
msg = create_multimodal_message(
"分析", images=image_data, image_mimetypes="image/jpeg"
)
# 文本 + 视频
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
# 文本 + 音频
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
# 混合多模态
msg = create_multimodal_message(
"分析这些媒体文件",
images="/path/to/image.jpg",
videos="/path/to/video.mp4",
audios="/path/to/audio.wav"
)
"""
message = UniMessage()
@ -196,7 +185,7 @@ def _add_media_to_message(
media_class: type,
default_mimetype: str,
) -> None:
"""添加媒体文件到 UniMessage 的辅助函数"""
"""添加媒体文件到 UniMessage"""
if not isinstance(media_items, list):
media_items = [media_items]
@ -216,3 +205,80 @@ def _add_media_to_message(
elif isinstance(item, bytes):
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
message.append(media_class(raw=item, mimetype=mimetype))
def message_to_unimessage(message: PlatformMessage) -> UniMessage:
"""
将平台特定的 Message 对象转换为通用的 UniMessage
主要用于处理引用消息等未被自动转换的消息体
参数:
message: 平台特定的Message对象
返回:
UniMessage: 转换后的通用消息对象
"""
uni_segments = []
for seg in message:
if seg.type == "text":
uni_segments.append(Text(seg.data.get("text", "")))
elif seg.type == "image":
uni_segments.append(Image(url=seg.data.get("url")))
elif seg.type == "record":
uni_segments.append(Voice(url=seg.data.get("url")))
elif seg.type == "video":
uni_segments.append(Video(url=seg.data.get("url")))
elif seg.type == "at":
uni_segments.append(At("user", str(seg.data.get("qq", ""))))
else:
logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
return UniMessage(uni_segments)
def _sanitize_request_body_for_logging(body: dict) -> dict:
"""
净化请求体用于日志记录移除大数据字段并添加摘要信息
参数:
body: 原始请求体字典
返回:
dict: 净化后的请求体字典
"""
try:
sanitized_body = copy.deepcopy(body)
if "contents" in sanitized_body and isinstance(
sanitized_body["contents"], list
):
for content_item in sanitized_body["contents"]:
if "parts" in content_item and isinstance(content_item["parts"], list):
media_summary = []
new_parts = []
for part in content_item["parts"]:
if "inlineData" in part and isinstance(
part["inlineData"], dict
):
data = part["inlineData"].get("data")
if isinstance(data, str):
mime_type = part["inlineData"].get(
"mimeType", "unknown"
)
media_summary.append(f"{mime_type} ({len(data)} chars)")
continue
new_parts.append(part)
if media_summary:
summary_text = (
f"[多模态内容: {len(media_summary)}个文件 - "
f"{', '.join(media_summary)}]"
)
new_parts.insert(0, {"text": summary_text})
content_item["parts"] = new_parts
return sanitized_body
except Exception as e:
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
return body

View File

@ -0,0 +1,12 @@
"""
定时调度服务模块
提供一个统一的持久化的定时任务管理器供所有插件使用
"""
from .lifecycle import _load_schedules_from_db
from .service import scheduler_manager
_ = _load_schedules_from_db
__all__ = ["scheduler_manager"]

View File

@ -0,0 +1,102 @@
"""
引擎适配层 (Adapter)
封装所有对具体调度器引擎 (APScheduler) 的操作
使上层服务与调度器实现解耦
"""
from nonebot_plugin_apscheduler import scheduler
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from .job import _execute_job
JOB_PREFIX = "zhenxun_schedule_"
class APSchedulerAdapter:
"""封装对 APScheduler 的操作"""
@staticmethod
def _get_job_id(schedule_id: int) -> str:
"""生成 APScheduler 的 Job ID"""
return f"{JOB_PREFIX}{schedule_id}"
@staticmethod
def add_or_reschedule_job(schedule: ScheduleInfo):
"""根据 ScheduleInfo 添加或重新调度一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule.id)
if not isinstance(schedule.trigger_config, dict):
logger.error(
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
f"{type(schedule.trigger_config)}"
)
return
job = scheduler.get_job(job_id)
if job:
scheduler.reschedule_job(
job_id, trigger=schedule.trigger_type, **schedule.trigger_config
)
logger.debug(f"已更新APScheduler任务: {job_id}")
else:
scheduler.add_job(
_execute_job,
trigger=schedule.trigger_type,
id=job_id,
misfire_grace_time=300,
args=[schedule.id],
**schedule.trigger_config,
)
logger.debug(f"已添加新的APScheduler任务: {job_id}")
@staticmethod
def remove_job(schedule_id: int):
"""移除一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
try:
scheduler.remove_job(job_id)
logger.debug(f"已从APScheduler中移除任务: {job_id}")
except Exception:
pass
@staticmethod
def pause_job(schedule_id: int):
"""暂停一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
try:
scheduler.pause_job(job_id)
except Exception:
pass
@staticmethod
def resume_job(schedule_id: int):
"""恢复一个 APScheduler 任务"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
try:
scheduler.resume_job(job_id)
except Exception:
import asyncio
from .repository import ScheduleRepository
async def _re_add_job():
schedule = await ScheduleRepository.get_by_id(schedule_id)
if schedule:
APSchedulerAdapter.add_or_reschedule_job(schedule)
asyncio.create_task(_re_add_job()) # noqa: RUF006
@staticmethod
def get_job_status(schedule_id: int) -> dict:
"""获取 APScheduler Job 的状态"""
job_id = APSchedulerAdapter._get_job_id(schedule_id)
job = scheduler.get_job(job_id)
return {
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
if job and job.next_run_time
else "N/A",
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
}

View File

@ -0,0 +1,192 @@
"""
定时任务的执行逻辑
包含被 APScheduler 实际调度的函数以及处理不同目标单个所有群组的执行策略
"""
import asyncio
import copy
import inspect
import random
import nonebot
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.platform import PlatformUtils
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
async def _execute_job(schedule_id: int):
"""
APScheduler 调度的入口函数
根据 schedule_id 处理特定任务所有群组任务或全局任务
"""
from .repository import ScheduleRepository
from .service import scheduler_manager
scheduler_manager._running_tasks.add(schedule_id)
try:
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
return
plugin_name = schedule.plugin_name
task_meta = scheduler_manager._registered_tasks.get(plugin_name)
if not task_meta:
logger.error(
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
)
schedule.is_enabled = False
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
from .adapter import APSchedulerAdapter
APSchedulerAdapter.remove_job(schedule.id)
return
try:
if schedule.bot_id:
bot = nonebot.get_bot(schedule.bot_id)
else:
bot = nonebot.get_bot()
logger.debug(
f"任务 {schedule_id} 未关联特定Bot使用默认Bot {bot.self_id}"
)
except KeyError:
logger.warning(
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
f"不在线,本次执行跳过。"
)
return
except ValueError:
logger.warning(f"当前没有Bot在线定时任务 {schedule_id} 跳过。")
return
if schedule.group_id == scheduler_manager.ALL_GROUPS:
await _execute_for_all_groups(schedule, task_meta, bot)
else:
await _execute_for_single_target(schedule, task_meta, bot)
finally:
scheduler_manager._running_tasks.discard(schedule_id)
async def _execute_for_all_groups(schedule: ScheduleInfo, task_meta: dict, bot):
"""为所有群组执行任务,并处理优先级覆盖。"""
plugin_name = schedule.plugin_name
concurrency_limit = Config.get_config(
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
)
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
logger.warning(
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
)
concurrency_limit = 5
logger.info(
f"开始执行针对 [所有群组] 的任务 "
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})"
f"并发限制: {concurrency_limit}"
)
all_gids = set()
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
all_gids.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
return
specific_tasks_gids = set(
await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=list(all_gids)
).values_list("group_id", flat=True)
)
semaphore = asyncio.Semaphore(concurrency_limit)
async def worker(gid: str):
"""使用 Semaphore 包装单个群组的任务执行"""
await asyncio.sleep(random.uniform(0, 59))
async with semaphore:
temp_schedule = copy.deepcopy(schedule)
temp_schedule.group_id = gid
await _execute_for_single_target(temp_schedule, task_meta, bot)
await asyncio.sleep(random.uniform(0.1, 0.5))
tasks_to_run = []
for gid in all_gids:
if gid in specific_tasks_gids:
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
continue
tasks_to_run.append(worker(gid))
if tasks_to_run:
await asyncio.gather(*tasks_to_run)
async def _execute_for_single_target(schedule: ScheduleInfo, task_meta: dict, bot):
"""为单个目标(具体群组或全局)执行任务。"""
plugin_name = schedule.plugin_name
group_id = schedule.group_id
try:
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
if is_blocked:
target_desc = f"{group_id}" if group_id else "全局"
logger.info(
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
"因功能被禁用而跳过执行。"
)
return
max_retries = Config.get_config("SchedulerManager", "JOB_MAX_RETRIES", 2)
retry_delay = Config.get_config("SchedulerManager", "JOB_RETRY_DELAY", 10)
@Retry.simple(
stop_max_attempt=max_retries + 1,
wait_fixed_seconds=retry_delay,
log_name=f"定时任务执行:{schedule.plugin_name}",
)
async def _execute_task_with_retry():
task_func = task_meta["func"]
job_kwargs = schedule.job_kwargs
if not isinstance(job_kwargs, dict):
logger.error(
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
)
return
sig = inspect.signature(task_func)
if "bot" in sig.parameters:
job_kwargs["bot"] = bot
await task_func(group_id, **job_kwargs)
try:
logger.info(
f"插件 '{schedule.plugin_name}' 开始为目标 "
f"[{schedule.group_id or '全局'}] 执行定时任务 (ID: {schedule.id})。"
)
await _execute_task_with_retry()
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {schedule.plugin_name}, "
f"目标: {schedule.group_id or '全局'}) 在所有重试后最终失败",
e=e,
)
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
f"目标: {group_id or '全局'}) 时发生异常",
e=e,
)

View File

@ -0,0 +1,62 @@
"""
定时任务的生命周期管理
包含在机器人启动时加载和调度数据库中保存的任务的逻辑
"""
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .adapter import APSchedulerAdapter
from .repository import ScheduleRepository
from .service import scheduler_manager
@PriorityLifecycle.on_startup(priority=90)
async def _load_schedules_from_db():
"""在服务启动时从数据库加载并调度所有任务。"""
logger.info("正在从数据库加载并调度所有定时任务...")
schedules = await ScheduleRepository.get_all_enabled()
count = 0
for schedule in schedules:
if schedule.plugin_name in scheduler_manager._registered_tasks:
APSchedulerAdapter.add_or_reschedule_job(schedule)
count += 1
else:
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
logger.info(f"数据库定时任务加载完成,共成功加载 {count} 个任务。")
logger.info("正在检查并注册声明式默认任务...")
declared_count = 0
for task_info in scheduler_manager._declared_tasks:
plugin_name = task_info["plugin_name"]
group_id = task_info["group_id"]
bot_id = task_info["bot_id"]
query_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
"bot_id": bot_id,
}
exists = await ScheduleRepository.exists(**query_kwargs)
if not exists:
logger.info(f"为插件 '{plugin_name}' 注册新的默认定时任务...")
schedule = await scheduler_manager.add_schedule(
plugin_name=plugin_name,
group_id=group_id,
trigger_type=task_info["trigger_type"],
trigger_config=task_info["trigger_config"],
job_kwargs=task_info["job_kwargs"],
bot_id=bot_id,
)
if schedule:
declared_count += 1
logger.debug(f"默认任务 '{plugin_name}' 注册成功 (ID: {schedule.id})")
else:
logger.error(f"默认任务 '{plugin_name}' 注册失败")
else:
logger.debug(f"插件 '{plugin_name}' 的默认任务已存在于数据库中,跳过注册。")
if declared_count > 0:
logger.info(f"声明式任务检查完成,新注册了 {declared_count} 个默认任务。")

View File

@ -0,0 +1,79 @@
"""
数据持久层 (Repository)
封装所有对 ScheduleInfo 模型的数据库操作将数据访问逻辑与业务逻辑分离
"""
from typing import Any
from tortoise.queryset import QuerySet
from zhenxun.models.schedule_info import ScheduleInfo
class ScheduleRepository:
"""封装 ScheduleInfo 模型的数据库操作"""
@staticmethod
async def get_by_id(schedule_id: int) -> ScheduleInfo | None:
"""通过ID获取任务"""
return await ScheduleInfo.get_or_none(id=schedule_id)
@staticmethod
async def get_all_enabled() -> list[ScheduleInfo]:
"""获取所有启用的任务"""
return await ScheduleInfo.filter(is_enabled=True).all()
@staticmethod
async def get_all(plugin_name: str | None = None) -> list[ScheduleInfo]:
"""获取所有任务,可按插件名过滤"""
if plugin_name:
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
return await ScheduleInfo.all()
@staticmethod
async def save(schedule: ScheduleInfo, update_fields: list[str] | None = None):
"""保存任务"""
await schedule.save(update_fields=update_fields)
@staticmethod
async def exists(**kwargs: Any) -> bool:
"""检查任务是否存在"""
return await ScheduleInfo.exists(**kwargs)
@staticmethod
async def get_by_plugin_and_group(
plugin_name: str, group_ids: list[str]
) -> list[ScheduleInfo]:
"""根据插件和群组ID列表获取任务"""
return await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=group_ids
).all()
@staticmethod
async def update_or_create(
defaults: dict, **kwargs: Any
) -> tuple[ScheduleInfo, bool]:
"""更新或创建任务"""
return await ScheduleInfo.update_or_create(defaults=defaults, **kwargs)
@staticmethod
async def query_schedules(**filters: Any) -> list[ScheduleInfo]:
"""
根据任意条件查询任务列表
参数:
**filters: 过滤条件 group_id="123", plugin_name="abc"
返回:
list[ScheduleInfo]: 任务列表
"""
cleaned_filters = {k: v for k, v in filters.items() if v is not None}
if not cleaned_filters:
return await ScheduleInfo.all()
return await ScheduleInfo.filter(**cleaned_filters).all()
@staticmethod
def filter(**kwargs: Any) -> QuerySet[ScheduleInfo]:
"""提供一个通用的过滤查询接口供Targeter使用"""
return ScheduleInfo.filter(**kwargs)

View File

@ -0,0 +1,448 @@
"""
服务层 (Service)
定义 SchedulerManager 类作为定时任务服务的公共 API 入口
它负责编排业务逻辑并调用 Repository Adapter 层来完成具体工作
"""
from collections.abc import Callable, Coroutine
from datetime import datetime
from typing import Any, ClassVar
import nonebot
from pydantic import BaseModel
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from .adapter import APSchedulerAdapter
from .job import _execute_job
from .repository import ScheduleRepository
from .targeter import ScheduleTargeter
class SchedulerManager:
ALL_GROUPS: ClassVar[str] = "__ALL_GROUPS__"
_registered_tasks: ClassVar[
dict[str, dict[str, Callable | type[BaseModel] | None]]
] = {}
_declared_tasks: ClassVar[list[dict[str, Any]]] = []
_running_tasks: ClassVar[set] = set()
def target(self, **filters: Any) -> ScheduleTargeter:
"""
创建目标选择器以执行批量操作
参数:
**filters: 过滤条件支持plugin_namegroup_idbot_id等字段
返回:
ScheduleTargeter: 目标选择器对象可用于批量操作
"""
return ScheduleTargeter(self, **filters)
def task(
self,
trigger: str,
group_id: str | None = None,
bot_id: str | None = None,
**trigger_kwargs,
):
"""
声明式定时任务装饰器
参数:
trigger: 触发器类型'cron''interval'
group_id: 目标群组IDNone表示全局任务
bot_id: 目标Bot IDNone表示使用默认Bot
**trigger_kwargs: 触发器配置参数
返回:
Callable: 装饰器函数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
try:
plugin = nonebot.get_plugin_by_module_name(func.__module__)
if not plugin:
raise ValueError(f"函数 {func.__name__} 不在任何已加载的插件中。")
plugin_name = plugin.name
task_declaration = {
"plugin_name": plugin_name,
"func": func,
"group_id": group_id,
"bot_id": bot_id,
"trigger_type": trigger,
"trigger_config": trigger_kwargs,
"job_kwargs": {},
}
self._declared_tasks.append(task_declaration)
logger.debug(
f"发现声明式定时任务 '{plugin_name}',将在启动时进行注册。"
)
except Exception as e:
logger.error(f"注册声明式定时任务失败: {func.__name__}, 错误: {e}")
return func
return decorator
def register(
self, plugin_name: str, params_model: type[BaseModel] | None = None
) -> Callable:
"""
注册可调度的任务函数
参数:
plugin_name: 插件名称用于标识任务
params_model: 参数验证模型继承自BaseModel的类
返回:
Callable: 装饰器函数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
if plugin_name in self._registered_tasks:
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
self._registered_tasks[plugin_name] = {
"func": func,
"model": params_model,
}
model_name = params_model.__name__ if params_model else ""
logger.debug(
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
)
return func
return decorator
def get_registered_plugins(self) -> list[str]:
"""
获取已注册插件列表
返回:
list[str]: 已注册的插件名称列表
"""
return list(self._registered_tasks.keys())
async def add_daily_task(
self,
plugin_name: str,
group_id: str | None,
hour: int,
minute: int,
second: int = 0,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> "ScheduleInfo | None":
"""
添加每日定时任务
参数:
plugin_name: 插件名称
group_id: 目标群组IDNone表示全局任务
hour: 执行小时0-23
minute: 执行分钟0-59
second: 执行秒数0-59默认为0
job_kwargs: 任务参数字典
bot_id: 目标Bot IDNone表示使用默认Bot
返回:
ScheduleInfo | None: 创建的任务信息失败时返回None
"""
trigger_config = {
"hour": hour,
"minute": minute,
"second": second,
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
}
return await self.add_schedule(
plugin_name,
group_id,
"cron",
trigger_config,
job_kwargs=job_kwargs,
bot_id=bot_id,
)
async def add_interval_task(
self,
plugin_name: str,
group_id: str | None,
*,
weeks: int = 0,
days: int = 0,
hours: int = 0,
minutes: int = 0,
seconds: int = 0,
start_date: str | datetime | None = None,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> "ScheduleInfo | None":
"""添加间隔性定时任务"""
trigger_config = {
"weeks": weeks,
"days": days,
"hours": hours,
"minutes": minutes,
"seconds": seconds,
"start_date": start_date,
}
trigger_config = {k: v for k, v in trigger_config.items() if v}
return await self.add_schedule(
plugin_name,
group_id,
"interval",
trigger_config,
job_kwargs=job_kwargs,
bot_id=bot_id,
)
def _validate_and_prepare_kwargs(
self, plugin_name: str, job_kwargs: dict | None
) -> tuple[bool, str | dict]:
"""验证并准备任务参数,应用默认值"""
from pydantic import ValidationError
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
return False, f"插件 '{plugin_name}' 未注册。"
params_model = task_meta.get("model")
job_kwargs = job_kwargs if job_kwargs is not None else {}
if not params_model:
if job_kwargs:
logger.warning(
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
)
return True, job_kwargs
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
validated_model = model_validate(job_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
return True, model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
return False, msg
async def add_schedule(
self,
plugin_name: str,
group_id: str | None,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> "ScheduleInfo | None":
"""
添加定时任务通用方法
参数:
plugin_name: 插件名称
group_id: 目标群组IDNone表示全局任务
trigger_type: 触发器类型'cron''interval'
trigger_config: 触发器配置字典
job_kwargs: 任务参数字典
bot_id: 目标Bot IDNone表示使用默认Bot
返回:
ScheduleInfo | None: 创建的任务信息失败时返回None
"""
if plugin_name not in self._registered_tasks:
logger.error(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
return None
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
if not is_valid:
logger.error(f"任务参数校验失败: {result}")
return None
search_kwargs = {"plugin_name": plugin_name, "group_id": group_id}
if bot_id and group_id == self.ALL_GROUPS:
search_kwargs["bot_id"] = bot_id
else:
search_kwargs["bot_id__isnull"] = True
defaults = {
"trigger_type": trigger_type,
"trigger_config": trigger_config,
"job_kwargs": result,
"is_enabled": True,
}
schedule, created = await ScheduleRepository.update_or_create(
defaults, **search_kwargs
)
APSchedulerAdapter.add_or_reschedule_job(schedule)
action = "设置" if created else "更新"
logger.info(
f"已成功{action}插件 '{plugin_name}' 的定时任务 (ID: {schedule.id})。"
)
return schedule
async def get_all_schedules(self) -> list[ScheduleInfo]:
"""
获取所有定时任务信息
"""
return await self.get_schedules()
async def get_schedules(
self,
plugin_name: str | None = None,
group_id: str | None = None,
bot_id: str | None = None,
) -> list[ScheduleInfo]:
"""
根据条件获取定时任务列表
参数:
plugin_name: 插件名称None表示不限制
group_id: 群组IDNone表示不限制
bot_id: Bot IDNone表示不限制
返回:
list[ScheduleInfo]: 符合条件的任务信息列表
"""
return await ScheduleRepository.query_schedules(
plugin_name=plugin_name, group_id=group_id, bot_id=bot_id
)
async def update_schedule(
self,
schedule_id: int,
trigger_type: str | None = None,
trigger_config: dict | None = None,
job_kwargs: dict | None = None,
) -> tuple[bool, str]:
"""
更新定时任务配置
参数:
schedule_id: 任务ID
trigger_type: 新的触发器类型None表示不更新
trigger_config: 新的触发器配置None表示不更新
job_kwargs: 新的任务参数None表示不更新
返回:
tuple[bool, str]: (是否成功, 结果消息)
"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的任务。"
updated_fields = []
if trigger_config is not None:
schedule.trigger_config = trigger_config
updated_fields.append("trigger_config")
if trigger_type is not None and schedule.trigger_type != trigger_type:
schedule.trigger_type = trigger_type
updated_fields.append("trigger_type")
if job_kwargs is not None:
existing_kwargs = (
schedule.job_kwargs.copy()
if isinstance(schedule.job_kwargs, dict)
else {}
)
existing_kwargs.update(job_kwargs)
is_valid, result = self._validate_and_prepare_kwargs(
schedule.plugin_name, existing_kwargs
)
if not is_valid:
return False, str(result)
assert isinstance(result, dict), "验证成功时 result 应该是字典类型"
schedule.job_kwargs = result
updated_fields.append("job_kwargs")
if not updated_fields:
return True, "没有任何需要更新的配置。"
await ScheduleRepository.save(schedule, update_fields=updated_fields)
APSchedulerAdapter.add_or_reschedule_job(schedule)
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
async def get_schedule_status(self, schedule_id: int) -> dict | None:
"""获取定时任务的详细状态信息"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule:
return None
status_from_scheduler = APSchedulerAdapter.get_job_status(schedule.id)
status_text = (
"运行中"
if schedule_id in self._running_tasks
else ("启用" if schedule.is_enabled else "暂停")
)
return {
"id": schedule.id,
"bot_id": schedule.bot_id,
"plugin_name": schedule.plugin_name,
"group_id": schedule.group_id,
"is_enabled": status_text,
"trigger_type": schedule.trigger_type,
"trigger_config": schedule.trigger_config,
"job_kwargs": schedule.job_kwargs,
**status_from_scheduler,
}
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""暂停指定的定时任务"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
return False, "任务不存在或已暂停。"
schedule.is_enabled = False
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
APSchedulerAdapter.pause_job(schedule_id)
return True, f"已暂停任务 (ID: {schedule.id})。"
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""恢复指定的定时任务"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule or schedule.is_enabled:
return False, "任务不存在或已启用。"
schedule.is_enabled = True
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
APSchedulerAdapter.resume_job(schedule_id)
return True, f"已恢复任务 (ID: {schedule.id})。"
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
"""立即手动触发指定的定时任务"""
schedule = await ScheduleRepository.get_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
if schedule.plugin_name not in self._registered_tasks:
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
try:
await _execute_job(schedule.id)
return True, f"已手动触发任务 (ID: {schedule.id})。"
except Exception as e:
logger.error(f"手动触发任务失败: {e}")
return False, f"手动触发任务失败: {e}"
scheduler_manager = SchedulerManager()

View File

@ -0,0 +1,109 @@
"""
目标选择器 (Targeter)
提供链式API用于构建和执行对多个定时任务的批量操作
"""
from collections.abc import Callable, Coroutine
from typing import Any
from .adapter import APSchedulerAdapter
from .repository import ScheduleRepository
class ScheduleTargeter:
"""
一个用于构建和执行定时任务批量操作的目标选择器
"""
def __init__(self, manager: Any, **filters: Any):
"""初始化目标选择器"""
self._manager = manager
self._filters = {k: v for k, v in filters.items() if v is not None}
async def _get_schedules(self):
"""根据过滤器获取任务"""
query = ScheduleRepository.filter(**self._filters)
return await query.all()
def _generate_target_description(self) -> str:
"""根据过滤条件生成友好的目标描述"""
if "id" in self._filters:
return f"任务 ID {self._filters['id']}"
parts = []
if "group_id" in self._filters:
group_id = self._filters["group_id"]
if group_id == self._manager.ALL_GROUPS:
parts.append("所有群组中")
else:
parts.append(f"{group_id}")
if "plugin_name" in self._filters:
parts.append(f"插件 '{self._filters['plugin_name']}'")
if not parts:
return "所有"
return "".join(parts)
async def _apply_operation(
self,
operation_func: Callable[[int], Coroutine[Any, Any, tuple[bool, str]]],
operation_name: str,
) -> tuple[int, str]:
"""通用的操作应用模板"""
schedules = await self._get_schedules()
if not schedules:
target_desc = self._generate_target_description()
return 0, f"没有找到{target_desc}可供{operation_name}的任务。"
success_count = 0
for schedule in schedules:
success, _ = await operation_func(schedule.id)
if success:
success_count += 1
target_desc = self._generate_target_description()
return (
success_count,
f"成功{operation_name}{target_desc} {success_count} 个任务。",
)
async def pause(self) -> tuple[int, str]:
"""
暂停匹配的定时任务
返回:
tuple[int, str]: (成功暂停的任务数量, 操作结果消息)
"""
return await self._apply_operation(self._manager.pause_schedule, "暂停")
async def resume(self) -> tuple[int, str]:
"""
恢复匹配的定时任务
返回:
tuple[int, str]: (成功恢复的任务数量, 操作结果消息)
"""
return await self._apply_operation(self._manager.resume_schedule, "恢复")
async def remove(self) -> tuple[int, str]:
"""
移除匹配的定时任务
返回:
tuple[int, str]: (成功移除的任务数量, 操作结果消息)
"""
schedules = await self._get_schedules()
if not schedules:
target_desc = self._generate_target_description()
return 0, f"没有找到{target_desc}可供移除的任务。"
for schedule in schedules:
APSchedulerAdapter.remove_job(schedule.id)
query = ScheduleRepository.filter(**self._filters)
count = await query.delete()
target_desc = self._generate_target_description()
return count, f"成功移除了{target_desc} {count} 个任务。"

View File

@ -137,19 +137,13 @@ def get_async_client(
class AsyncHttpx:
"""
一个高级的健壮的异步HTTP客户端工具类
高性能异步HTTP客户端工具类
设计理念:
- **全局共享客户端**: 默认情况下所有请求都通过一个在应用启动时初始化的全局
`httpx.AsyncClient` 实例发出这个实例共享连接池提高了效率和性能
- **向后兼容与灵活性**: 完全兼容旧的API同时提供了两种方式来处理需要
特殊网络配置如不同代理超时的请求
1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时直接传入 `proxies`,
`timeout` 等参数将为该次请求创建一个临时的独立的客户端
2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器可以
获取一个独立的可配置的客户端用于执行一系列需要相同特殊配置的请求
- **健壮性**: 内置了自动重试多镜像URL回退fallback机制并提供了便捷的
JSON解析和文件下载方法
特性:
- 全局共享连接池提升性能
- 支持临时客户端配置代理超时等
- 内置重试机制和多URL回退
- 提供JSON解析和文件下载功能
"""
CLIENT_KEY: ClassVar[list[str]] = [
@ -157,7 +151,6 @@ class AsyncHttpx:
"proxies",
"proxy",
"verify",
"headers",
]
default_proxy: ClassVar[dict[str, str] | None] = (
@ -290,15 +283,6 @@ class AsyncHttpx:
) -> Response:
"""发送 GET 请求,并返回第一个成功的响应。
说明:
本方法是 httpx.get 的高级包装增加了多链接尝试自动重试和统一的
客户端管理如果提供 URL 列表它将依次尝试直到成功为止
用法建议:
- **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端
- **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)`
将为本次请求创建一个独立的临时客户端
参数:
url: 单个请求 URL 或一个 URL 列表
follow_redirects: 是否跟随重定向
@ -312,7 +296,7 @@ class AsyncHttpx:
返回:
Response: httpx 的响应对象
Raises:
异常:
AllURIsFailedError: 当所有提供的URL都请求失败时抛出
"""
@ -373,10 +357,11 @@ class AsyncHttpx:
"""
[私有] 执行单个HTTP请求并解析JSON用于内部统一处理
"""
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
async with cls._get_active_client_context(
client=client, **kwargs
client=client, **client_kwargs
) as active_client:
_, request_kwargs = cls._split_kwargs(kwargs)
response = await active_client.request(method, url, **request_kwargs)
response.raise_for_status()
return response.json()
@ -394,11 +379,6 @@ class AsyncHttpx:
"""
发送GET请求并自动解析为JSON支持重试和多链接尝试
说明:
这是一个高度便捷的方法封装了请求重试JSON解析和错误处理
它会在网络错误或JSON解析错误时自动重试
如果所有尝试都失败它会安全地返回一个默认值
参数:
url: 单个请求 URL 或一个备用 URL 列表
default: (可选) 当所有尝试都失败时返回的默认值默认为None
@ -411,7 +391,7 @@ class AsyncHttpx:
返回:
Any: 解析后的JSON数据或在失败时返回 `default`
Raises:
异常:
AllURIsFailedError: `raise_on_failure` True 且所有URL都请求失败时抛出
"""
@ -490,25 +470,33 @@ class AsyncHttpx:
"""
执行单个流式下载的私有方法被重试装饰器包裹
"""
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
show_progress = request_kwargs.pop("show_progress", False)
async with cls._get_active_client_context(
client=client, **kwargs
client=client, **client_kwargs
) as active_client:
async with active_client.stream("GET", url, **kwargs) as response:
async with active_client.stream("GET", url, **request_kwargs) as response:
response.raise_for_status()
total = int(response.headers.get("Content-Length", 0))
with Progress(
TextColumn(path.name),
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(bar_width=None),
DownloadColumn(),
TransferSpeedColumn(),
) as progress:
task_id = progress.add_task("Download", total=total)
if show_progress:
with Progress(
TextColumn(path.name),
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(bar_width=None),
DownloadColumn(),
TransferSpeedColumn(),
) as progress:
task_id = progress.add_task("Download", total=total)
async with aiofiles.open(path, "wb") as f:
async for chunk in response.aiter_bytes():
await f.write(chunk)
progress.update(task_id, advance=len(chunk))
else:
async with aiofiles.open(path, "wb") as f:
async for chunk in response.aiter_bytes():
await f.write(chunk)
progress.update(task_id, advance=len(chunk))
@classmethod
async def download_file(
@ -517,6 +505,7 @@ class AsyncHttpx:
path: str | Path,
*,
stream: bool = False,
show_progress: bool = False,
client: AsyncClient | None = None,
**kwargs,
) -> bool:
@ -529,6 +518,7 @@ class AsyncHttpx:
url: 单个文件 URL 或一个备用 URL 列表
path: 文件保存的本地路径
stream: (可选) 是否使用流式下载适用于大文件默认为 False
show_progress: (可选) stream=True 是否显示下载进度条默认为 False
client: (可选) 指定的HTTP客户端
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数
@ -544,7 +534,9 @@ class AsyncHttpx:
async with aiofiles.open(path, "wb") as f:
await f.write(content)
else:
await cls._stream_download(current_url, path, **worker_kwargs)
await cls._stream_download(
current_url, path, show_progress=show_progress, **worker_kwargs
)
logger.info(
f"下载 {current_url} 成功 -> {path.absolute()}",
@ -573,10 +565,6 @@ class AsyncHttpx:
) -> list[bool]:
"""并发下载多个文件,支持为每个文件提供备用镜像链接。
说明:
使用 asyncio.Semaphore 来控制并发请求的数量
对于 url_list 中的每个元素如果它是一个列表则会依次尝试直到下载成功
参数:
url_list: 包含所有文件下载任务的列表每个元素可以是
- 一个字符串 (str): 代表该任务的唯一URL
@ -625,9 +613,6 @@ class AsyncHttpx:
async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]:
"""测试并返回最快的镜像地址。
说明:
通过并发发送 HEAD 请求来测试每个 URL 的响应时间和可用性并按响应速度排序
参数:
url_list: 需要测试的镜像 URL 列表
@ -671,23 +656,12 @@ class AsyncHttpx:
"""
创建一个临时的可配置的HTTP客户端上下文并直接返回该客户端实例
此方法返回一个标准的 `httpx.AsyncClient`它不使用全局连接池
拥有独立的配置(如代理headers超时等)并在退出上下文后自动关闭
适用于需要用一套特殊网络配置执行一系列请求的场景
用法:
async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client:
# client 是一个标准的 httpx.AsyncClient 实例
response1 = await client.get("http://some.internal.api/1")
response2 = await client.get("http://some.internal.api/2")
data = response2.json()
参数:
**kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数
例如: `proxies`, `headers`, `verify`, `timeout`,
`follow_redirects`
Yields:
返回:
httpx.AsyncClient: 一个配置好的临时的客户端实例
"""
async with get_async_client(**kwargs) as client:

View File

@ -1,810 +0,0 @@
import asyncio
from collections.abc import Callable, Coroutine
import copy
import inspect
import random
from typing import ClassVar
import nonebot
from nonebot import get_bots
from nonebot_plugin_apscheduler import scheduler
from pydantic import BaseModel, ValidationError
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
class SchedulerManager:
"""
一个通用的持久化的定时任务管理器供所有插件使用
"""
_registered_tasks: ClassVar[
dict[str, dict[str, Callable | type[BaseModel] | None]]
] = {}
_JOB_PREFIX = "zhenxun_schedule_"
_running_tasks: ClassVar[set] = set()
def register(
self, plugin_name: str, params_model: type[BaseModel] | None = None
) -> Callable:
"""
注册一个可调度的任务函数
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
Args:
plugin_name (str): 插件的唯一名称 (通常是模块名)
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel
用于定义和验证任务函数接受的额外参数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
if plugin_name in self._registered_tasks:
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
self._registered_tasks[plugin_name] = {
"func": func,
"model": params_model,
}
model_name = params_model.__name__ if params_model else ""
logger.debug(
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
)
return func
return decorator
def get_registered_plugins(self) -> list[str]:
"""获取所有已注册定时任务的插件列表。"""
return list(self._registered_tasks.keys())
def _get_job_id(self, schedule_id: int) -> str:
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
return f"{self._JOB_PREFIX}{schedule_id}"
async def _execute_job(self, schedule_id: int):
"""
APScheduler 调度的入口函数
根据 schedule_id 处理特定任务所有群组任务或全局任务
"""
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
if not schedule or not schedule.is_enabled:
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
return
plugin_name = schedule.plugin_name
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
logger.error(
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
)
schedule.is_enabled = False
await schedule.save(update_fields=["is_enabled"])
self._remove_aps_job(schedule.id)
return
try:
if schedule.bot_id:
bot = nonebot.get_bot(schedule.bot_id)
else:
bot = nonebot.get_bot()
logger.debug(
f"任务 {schedule_id} 未关联特定Bot使用默认Bot {bot.self_id}"
)
except KeyError:
logger.warning(
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
f"不在线,本次执行跳过。"
)
return
except ValueError:
logger.warning(f"当前没有Bot在线定时任务 {schedule_id} 跳过。")
return
if schedule.group_id == "__ALL_GROUPS__":
await self._execute_for_all_groups(schedule, task_meta, bot)
else:
await self._execute_for_single_target(schedule, task_meta, bot)
async def _execute_for_all_groups(
self, schedule: ScheduleInfo, task_meta: dict, bot
):
"""为所有群组执行任务,并处理优先级覆盖。"""
plugin_name = schedule.plugin_name
concurrency_limit = Config.get_config(
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
)
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
logger.warning(
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
)
concurrency_limit = 5
logger.info(
f"开始执行针对 [所有群组] 的任务 "
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})"
f"并发限制: {concurrency_limit}"
)
all_gids = set()
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
all_gids.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
return
specific_tasks_gids = set(
await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=list(all_gids)
).values_list("group_id", flat=True)
)
semaphore = asyncio.Semaphore(concurrency_limit)
async def worker(gid: str):
"""使用 Semaphore 包装单个群组的任务执行"""
async with semaphore:
temp_schedule = copy.deepcopy(schedule)
temp_schedule.group_id = gid
await self._execute_for_single_target(temp_schedule, task_meta, bot)
await asyncio.sleep(random.uniform(0.1, 0.5))
tasks_to_run = []
for gid in all_gids:
if gid in specific_tasks_gids:
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
continue
tasks_to_run.append(worker(gid))
if tasks_to_run:
await asyncio.gather(*tasks_to_run)
async def _execute_for_single_target(
self, schedule: ScheduleInfo, task_meta: dict, bot
):
"""为单个目标(具体群组或全局)执行任务。"""
plugin_name = schedule.plugin_name
group_id = schedule.group_id
try:
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
if is_blocked:
target_desc = f"{group_id}" if group_id else "全局"
logger.info(
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
"因功能被禁用而跳过执行。"
)
return
task_func = task_meta["func"]
job_kwargs = schedule.job_kwargs
if not isinstance(job_kwargs, dict):
logger.error(
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
)
return
sig = inspect.signature(task_func)
if "bot" in sig.parameters:
job_kwargs["bot"] = bot
logger.info(
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
f"执行定时任务 (ID: {schedule.id})。"
)
task = asyncio.create_task(task_func(group_id, **job_kwargs))
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)
await task
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
f"目标: {group_id or '全局'}) 时发生异常",
e=e,
)
def _validate_and_prepare_kwargs(
self, plugin_name: str, job_kwargs: dict | None
) -> tuple[bool, str | dict]:
"""验证并准备任务参数,应用默认值"""
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
return False, f"插件 '{plugin_name}' 未注册。"
params_model = task_meta.get("model")
job_kwargs = job_kwargs if job_kwargs is not None else {}
if not params_model:
if job_kwargs:
logger.warning(
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
)
return True, job_kwargs
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
validated_model = model_validate(job_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
return True, model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
return False, msg
def _add_aps_job(self, schedule: ScheduleInfo):
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
job_id = self._get_job_id(schedule.id)
try:
scheduler.remove_job(job_id)
except Exception:
pass
if not isinstance(schedule.trigger_config, dict):
logger.error(
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
f"{type(schedule.trigger_config)}"
)
return
scheduler.add_job(
self._execute_job,
trigger=schedule.trigger_type,
id=job_id,
misfire_grace_time=300,
args=[schedule.id],
**schedule.trigger_config,
)
logger.debug(
f"已在 APScheduler 中添加/更新任务: {job_id} "
f"with trigger: {schedule.trigger_config}"
)
def _remove_aps_job(self, schedule_id: int):
"""移除一个 APScheduler 任务。"""
job_id = self._get_job_id(schedule_id)
try:
scheduler.remove_job(job_id)
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
except Exception:
pass
async def add_schedule(
self,
plugin_name: str,
group_id: str | None,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> tuple[bool, str]:
"""
添加或更新一个定时任务
"""
if plugin_name not in self._registered_tasks:
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
if not is_valid:
return False, str(result)
validated_job_kwargs = result
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
search_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
}
if effective_bot_id:
search_kwargs["bot_id"] = effective_bot_id
else:
search_kwargs["bot_id__isnull"] = True
defaults = {
"trigger_type": trigger_type,
"trigger_config": trigger_config,
"job_kwargs": validated_job_kwargs,
"is_enabled": True,
}
schedule = await ScheduleInfo.filter(**search_kwargs).first()
created = False
if schedule:
for key, value in defaults.items():
setattr(schedule, key, value)
await schedule.save()
else:
creation_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
"bot_id": effective_bot_id,
**defaults,
}
schedule = await ScheduleInfo.create(**creation_kwargs)
created = True
self._add_aps_job(schedule)
action = "设置" if created else "更新"
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
async def add_schedule_for_all(
self,
plugin_name: str,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
) -> tuple[int, int]:
"""为所有机器人所在的群组添加定时任务。"""
if plugin_name not in self._registered_tasks:
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
groups = set()
for bot in get_bots().values():
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
groups.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
success_count = 0
fail_count = 0
for gid in groups:
try:
success, _ = await self.add_schedule(
plugin_name, gid, trigger_type, trigger_config, job_kwargs
)
if success:
success_count += 1
else:
fail_count += 1
except Exception as e:
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
fail_count += 1
await asyncio.sleep(0.05)
return success_count, fail_count
async def update_schedule(
self,
schedule_id: int,
trigger_type: str | None = None,
trigger_config: dict | None = None,
job_kwargs: dict | None = None,
) -> tuple[bool, str]:
"""部分更新一个已存在的定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的任务。"
updated_fields = []
if trigger_config is not None:
schedule.trigger_config = trigger_config
updated_fields.append("trigger_config")
if trigger_type is not None and schedule.trigger_type != trigger_type:
schedule.trigger_type = trigger_type
updated_fields.append("trigger_type")
if job_kwargs is not None:
if not isinstance(schedule.job_kwargs, dict):
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
merged_kwargs = schedule.job_kwargs.copy()
merged_kwargs.update(job_kwargs)
is_valid, result = self._validate_and_prepare_kwargs(
schedule.plugin_name, merged_kwargs
)
if not is_valid:
return False, str(result)
schedule.job_kwargs = result # type: ignore
updated_fields.append("job_kwargs")
if not updated_fields:
return True, "没有任何需要更新的配置。"
await schedule.save(update_fields=updated_fields)
self._add_aps_job(schedule)
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
async def remove_schedule(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""移除指定插件和群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
msg = (
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
f"的插件 '{plugin_name}' 定时任务。"
)
return (False, msg)
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
target_desc = f"{group_id}" if group_id else "全局"
msg = (
f"已取消 Bot {bot_id} 在 [{target_desc}] "
f"的插件 '{plugin_name}' 所有定时任务。"
)
return (True, msg)
async def remove_schedule_for_all(
self, plugin_name: str, bot_id: str | None = None
) -> int:
"""移除指定插件在所有群组的定时任务。"""
query = {"plugin_name": plugin_name}
if bot_id:
query["bot_id"] = bot_id
schedules_to_delete = await ScheduleInfo.filter(**query).all()
if not schedules_to_delete:
return 0
for schedule in schedules_to_delete:
self._remove_aps_job(schedule.id)
await schedule.delete()
await asyncio.sleep(0.01)
return len(schedules_to_delete)
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
"""移除指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id)
if not schedules:
return False, f"{group_id} 没有任何定时任务。"
count = 0
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
count += 1
await asyncio.sleep(0.01)
return True, f"已成功移除群 {group_id}{count} 个定时任务。"
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
"""暂停指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
if not schedules:
return 0, f"{group_id} 没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功暂停群 {group_id}{count} 个定时任务。"
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
"""恢复指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
if not schedules:
return 0, f"{group_id} 没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功恢复群 {group_id}{count} 个定时任务。"
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
"""暂停指定插件在所有群组的定时任务。"""
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
if not schedules:
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return (
count,
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
)
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
"""恢复指定插件在所有群组的定时任务。"""
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
if not schedules:
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return (
count,
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
)
async def pause_schedule_by_plugin_group(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""暂停指定插件在指定群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
return (
False,
f"{group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
)
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
return (
True,
f"已成功暂停群 {group_id} 的插件 '{plugin_name}'{count} 个定时任务。",
)
async def resume_schedule_by_plugin_group(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""恢复指定插件在指定群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
return (
False,
f"{group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
)
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
return (
True,
f"已成功恢复群 {group_id} 的插件 '{plugin_name}'{count} 个定时任务。",
)
async def remove_all_schedules(self) -> tuple[int, str]:
"""移除所有群组的所有定时任务。"""
schedules = await ScheduleInfo.all()
if not schedules:
return 0, "当前没有任何定时任务。"
count = 0
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
count += 1
await asyncio.sleep(0.01)
return count, f"已成功移除所有群组的 {count} 个定时任务。"
async def pause_all_schedules(self) -> tuple[int, str]:
"""暂停所有群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(is_enabled=True)
if not schedules:
return 0, "当前没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
async def resume_all_schedules(self) -> tuple[int, str]:
"""恢复所有群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(is_enabled=False)
if not schedules:
return 0, "当前没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
"""通过ID移除指定的定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
self._remove_aps_job(schedule.id)
await schedule.delete()
return (
True,
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
"""通过ID获取定时任务信息。"""
return await ScheduleInfo.get_or_none(id=schedule_id)
async def get_schedules(
self, plugin_name: str, group_id: str | None
) -> list[ScheduleInfo]:
"""获取特定群组特定插件的所有定时任务。"""
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
async def get_schedule(
self, plugin_name: str, group_id: str | None
) -> ScheduleInfo | None:
"""获取特定群组的定时任务信息。"""
return await ScheduleInfo.get_or_none(
plugin_name=plugin_name, group_id=group_id
)
async def get_all_schedules(
self, plugin_name: str | None = None
) -> list[ScheduleInfo]:
"""获取所有定时任务信息,可按插件名过滤。"""
if plugin_name:
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
return await ScheduleInfo.all()
async def get_schedule_status(self, schedule_id: int) -> dict | None:
"""获取任务的详细状态。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return None
job_id = self._get_job_id(schedule.id)
job = scheduler.get_job(job_id)
status = {
"id": schedule.id,
"bot_id": schedule.bot_id,
"plugin_name": schedule.plugin_name,
"group_id": schedule.group_id,
"is_enabled": schedule.is_enabled,
"trigger_type": schedule.trigger_type,
"trigger_config": schedule.trigger_config,
"job_kwargs": schedule.job_kwargs,
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
if job and job.next_run_time
else "N/A",
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
}
return status
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""暂停一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
return False, "任务不存在或已暂停。"
schedule.is_enabled = False
await schedule.save(update_fields=["is_enabled"])
job_id = self._get_job_id(schedule.id)
try:
scheduler.pause_job(job_id)
except Exception:
pass
return (
True,
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""恢复一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule or schedule.is_enabled:
return False, "任务不存在或已启用。"
schedule.is_enabled = True
await schedule.save(update_fields=["is_enabled"])
job_id = self._get_job_id(schedule.id)
try:
scheduler.resume_job(job_id)
except Exception:
self._add_aps_job(schedule)
return (
True,
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
"""手动触发一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
if schedule.plugin_name not in self._registered_tasks:
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
try:
await self._execute_job(schedule.id)
return (
True,
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
except Exception as e:
logger.error(f"手动触发任务失败: {e}")
return False, f"手动触发任务失败: {e}"
scheduler_manager = SchedulerManager()
@PriorityLifecycle.on_startup(priority=90)
async def _load_schedules_from_db():
"""在服务启动时从数据库加载并调度所有任务。"""
Config.add_plugin_config(
"SchedulerManager",
SCHEDULE_CONCURRENCY_KEY,
5,
help="“所有群组”类型定时任务的并发执行数量限制",
type=int,
)
logger.info("正在从数据库加载并调度所有定时任务...")
schedules = await ScheduleInfo.filter(is_enabled=True).all()
count = 0
for schedule in schedules:
if schedule.plugin_name in scheduler_manager._registered_tasks:
scheduler_manager._add_aps_job(schedule)
count += 1
else:
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")

View File

@ -31,7 +31,9 @@ class VirtualEnvPackageManager:
def __get_command(cls) -> list[str]:
if path := Config.get_config("virtualenv", "python_path"):
return [path, "-m", "pip"]
return cls.WIN_COMMAND if BAT_FILE.exists() else cls.DEFAULT_COMMAND
return (
cls.WIN_COMMAND.copy() if BAT_FILE.exists() else cls.DEFAULT_COMMAND.copy()
)
@classmethod
def install(cls, package: list[str] | str):
@ -57,8 +59,10 @@ class VirtualEnvPackageManager:
f"安装虚拟环境包指令执行完成: {result.stdout}",
LOG_COMMAND,
)
return result.stdout
except CalledProcessError as e:
logger.error(f"安装虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND)
return e.stderr
@classmethod
def uninstall(cls, package: list[str] | str):
@ -72,6 +76,7 @@ class VirtualEnvPackageManager:
try:
command = cls.__get_command()
command.append("uninstall")
command.append("-y")
command.append(" ".join(package))
logger.info(f"执行虚拟环境卸载包指令: {command}", LOG_COMMAND)
result = subprocess.run(
@ -84,8 +89,10 @@ class VirtualEnvPackageManager:
f"卸载虚拟环境包指令执行完成: {result.stdout}",
LOG_COMMAND,
)
return result.stdout
except CalledProcessError as e:
logger.error(f"卸载虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND)
return e.stderr
@classmethod
def update(cls, package: list[str] | str):
@ -109,8 +116,10 @@ class VirtualEnvPackageManager:
text=True,
)
logger.debug(f"更新虚拟环境包指令执行完成: {result.stdout}", LOG_COMMAND)
return result.stdout
except CalledProcessError as e:
logger.error(f"更新虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND)
return e.stderr
@classmethod
def install_requirement(cls, requirement_file: Path):
@ -140,11 +149,13 @@ class VirtualEnvPackageManager:
f"安装虚拟环境依赖文件指令执行完成: {result.stdout}",
LOG_COMMAND,
)
return result.stdout
except CalledProcessError as e:
logger.error(
f"安装虚拟环境依赖文件指令执行失败: {e.stderr}.",
LOG_COMMAND,
)
return e.stderr
@classmethod
def list(cls) -> str:

View File

@ -80,14 +80,14 @@ class PlatformUtils:
@classmethod
async def send_superuser(
cls,
bot: Bot,
bot: Bot | None,
message: UniMessage | str,
superuser_id: str | None = None,
) -> list[tuple[str, Receipt]]:
"""发送消息给超级用户
参数:
bot: Bot
bot: Bot没有传入时使用get_bot随机获取
message: 消息
superuser_id: 指定超级用户id.
@ -97,6 +97,8 @@ class PlatformUtils:
返回:
Receipt | None: Receipt
"""
if not bot:
bot = nonebot.get_bot()
superuser_ids = []
if superuser_id:
superuser_ids.append(superuser_id)

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from datetime import date, datetime
import os
from pathlib import Path
import time
@ -277,3 +277,20 @@ def is_number(text: str) -> bool:
return True
except ValueError:
return False
class TimeUtils:
@classmethod
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
"""获取某天的0点时间
返回:
datetime: 今天某天的0点时间
"""
if not target_date:
target_date = datetime.now()
return (
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
if isinstance(target_date, datetime)
else datetime.combine(target_date, datetime.min.time())
)