Mirror of https://github.com/zhenxun-org/zhenxun_bot.git (synced 2025-12-15 14:22:55 +08:00)

Merge branch 'main' into feature/db-cache
This commit is contained in: 205325b994

.vscode/settings.json (vendored): 1 change
@@ -29,6 +29,7 @@
     "unban",
+    "Uninfo",
     "userinfo",
     "webui",
     "zhenxun"
   ],
   "python.analysis.autoImportCompletions": true,

@@ -13,7 +13,11 @@ from pytest_mock import MockerFixture
 from respx import MockRouter

 from tests.config import BotId, GroupId, MessageId, UserId
-from tests.utils import _v11_group_message_event, _v11_private_message_send
+from tests.utils import (
+    _v11_group_message_event,
+    _v11_private_message_send,
+    get_reply_cq,
+)
 from tests.utils import get_response_json as _get_response_json

@@ -311,6 +315,12 @@ async def test_check_update_release(
         to_me=True,
     )
     ctx.receive_event(bot, event)
+    ctx.should_call_send(
+        event=event,
+        message=Message(f"{get_reply_cq(MessageId.MESSAGE_ID)}正在进行检查更新..."),
+        result=None,
+        bot=bot,
+    )
     ctx.should_call_api(
         "send_msg",
         _v11_private_message_send(
@@ -401,6 +411,12 @@ async def test_check_update_main(
         to_me=True,
     )
     ctx.receive_event(bot, event)
+    ctx.should_call_send(
+        event=event,
+        message=Message(f"{get_reply_cq(MessageId.MESSAGE_ID)}正在进行检查更新..."),
+        result=None,
+        bot=bot,
+    )
     ctx.should_call_api(
         "send_msg",
         _v11_private_message_send(

@@ -5,6 +5,10 @@ from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message, MessageSegme
 from nonebot.adapters.onebot.v11.event import Sender


+def get_reply_cq(uid: int | str) -> str:
+    return f"[CQ:reply,id={uid}]"
+
+
 def get_response_json(base_path: Path, file: str) -> dict:
     try:
         return json.loads(

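Added note (illustration, not part of the commit): get_reply_cq only builds the OneBot v11 reply CQ code as a string, and Message() parses that CQ code back into a reply segment, so the tests can prepend it to the expected text. A minimal sketch with a made-up message id:

    from nonebot.adapters.onebot.v11 import Message

    def get_reply_cq(uid: int | str) -> str:
        # same helper as the one added to tests/utils.py above
        return f"[CQ:reply,id={uid}]"

    # expected outgoing message: a reply segment pointing at message 42,
    # followed by the plain-text notice the plugin sends
    expected = Message(f"{get_reply_cq(42)}正在进行检查更新...")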
@@ -1,7 +1,7 @@
 from nonebot.adapters import Bot
 from nonebot.plugin import PluginMetadata
 from nonebot_plugin_alconna import AlconnaQuery, Arparma, Match, Query
-from nonebot_plugin_session import EventSession
+from nonebot_plugin_uninfo import Uninfo

 from zhenxun.configs.config import Config
 from zhenxun.configs.utils import PluginExtraData, RegisterConfig
@@ -9,7 +9,7 @@ from zhenxun.services.log import logger
 from zhenxun.utils.enum import BlockType, PluginType
 from zhenxun.utils.message import MessageUtils

-from ._data_source import PluginManage, build_plugin, build_task, delete_help_image
+from ._data_source import PluginManager, build_plugin, build_task, delete_help_image
 from .command import _group_status_matcher, _status_matcher

 base_config = Config.get("plugin_switch")
@@ -57,6 +57,11 @@ __plugin_meta__ = PluginMetadata(
         关闭群被动早晚安
         关闭群被动早晚安 -g 12355555

+        开启/关闭默认群被动 [被动名称]
+        私聊下: 开启/关闭群被动默认状态
+        示例:
+            关闭默认群被动 早晚安
+
         开启/关闭所有群被动 ?[-g [group_id]]
         私聊中: 开启/关闭全局或指定群组被动状态
         示例:
@@ -87,10 +92,10 @@ __plugin_meta__ = PluginMetadata(
 @_status_matcher.assign("$main")
 async def _(
     bot: Bot,
-    session: EventSession,
+    session: Uninfo,
     arparma: Arparma,
 ):
-    if session.id1 in bot.config.superusers:
+    if session.user.id in bot.config.superusers:
         image = await build_plugin()
         logger.info(
             "查看功能列表",
@@ -105,7 +110,7 @@ async def _(
 @_status_matcher.assign("open")
 async def _(
     bot: Bot,
-    session: EventSession,
+    session: Uninfo,
     arparma: Arparma,
     plugin_name: Match[str],
     group: Match[str],
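Added note (illustration, not part of the commit): the hunks above swap nonebot_plugin_session's EventSession for nonebot_plugin_uninfo's Uninfo, which changes how the handlers read ids. A hedged sketch of the access pattern the rewritten handlers rely on:

    from nonebot.adapters import Bot
    from nonebot_plugin_uninfo import Uninfo

    def sender_is_superuser(bot: Bot, session: Uninfo) -> bool:
        # the old EventSession code read session.id1; Uninfo exposes a user object
        return session.user.id in bot.config.superusers

    def current_group_id(session: Uninfo) -> str | None:
        # the old code used `session.id3 or session.id2`; Uninfo has an optional group scene
        return session.group.id if session.group else None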
@ -114,22 +119,23 @@ async def _(
|
||||
all: Query[bool] = AlconnaQuery("all.value", False),
|
||||
):
|
||||
if not all.result and not plugin_name.available:
|
||||
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
|
||||
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
|
||||
name = plugin_name.result
|
||||
if gid := session.id3 or session.id2:
|
||||
if session.group:
|
||||
group_id = session.group.id
|
||||
"""修改当前群组的数据"""
|
||||
if task.result:
|
||||
if all.result:
|
||||
result = await PluginManage.unblock_group_all_task(gid)
|
||||
result = await PluginManager.unblock_group_all_task(group_id)
|
||||
logger.info("开启所有群组被动", arparma.header_result, session=session)
|
||||
else:
|
||||
result = await PluginManage.unblock_group_task(name, gid)
|
||||
result = await PluginManager.unblock_group_task(name, group_id)
|
||||
logger.info(
|
||||
f"开启群组被动 {name}", arparma.header_result, session=session
|
||||
)
|
||||
elif session.id1 in bot.config.superusers and default_status.result:
|
||||
elif session.user.id in bot.config.superusers and default_status.result:
|
||||
"""单个插件的进群默认修改"""
|
||||
result = await PluginManage.set_default_status(name, True)
|
||||
result = await PluginManager.set_default_status(name, True)
|
||||
logger.info(
|
||||
f"超级用户开启 {name} 功能进群默认开关",
|
||||
arparma.header_result,
|
||||
@ -137,8 +143,8 @@ async def _(
|
||||
)
|
||||
elif all.result:
|
||||
"""所有插件"""
|
||||
result = await PluginManage.set_all_plugin_status(
|
||||
True, default_status.result, gid
|
||||
result = await PluginManager.set_all_plugin_status(
|
||||
True, default_status.result, group_id
|
||||
)
|
||||
logger.info(
|
||||
"开启群组中全部功能",
|
||||
@ -146,22 +152,24 @@ async def _(
|
||||
session=session,
|
||||
)
|
||||
else:
|
||||
result = await PluginManage.unblock_group_plugin(name, gid)
|
||||
result = await PluginManager.unblock_group_plugin(name, group_id)
|
||||
logger.info(f"开启功能 {name}", arparma.header_result, session=session)
|
||||
delete_help_image(gid)
|
||||
delete_help_image(group_id)
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
elif session.id1 in bot.config.superusers:
|
||||
elif session.user.id in bot.config.superusers:
|
||||
"""私聊"""
|
||||
group_id = group.result if group.available else None
|
||||
if all.result:
|
||||
if task.result:
|
||||
"""关闭全局或指定群全部被动"""
|
||||
if group_id:
|
||||
result = await PluginManage.unblock_group_all_task(group_id)
|
||||
result = await PluginManager.unblock_group_all_task(group_id)
|
||||
else:
|
||||
result = await PluginManage.unblock_global_all_task()
|
||||
result = await PluginManager.unblock_global_all_task(
|
||||
default_status.result
|
||||
)
|
||||
else:
|
||||
result = await PluginManage.set_all_plugin_status(
|
||||
result = await PluginManager.set_all_plugin_status(
|
||||
True, default_status.result, group_id
|
||||
)
|
||||
logger.info(
|
||||
@ -171,8 +179,8 @@ async def _(
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
if default_status.result:
|
||||
result = await PluginManage.set_default_status(name, True)
|
||||
if default_status.result and not task.result:
|
||||
result = await PluginManager.set_default_status(name, True)
|
||||
logger.info(
|
||||
f"超级用户开启 {name} 功能进群默认开关",
|
||||
arparma.header_result,
|
||||
@ -186,7 +194,7 @@ async def _(
|
||||
name = split_list[0]
|
||||
group_id = split_list[1]
|
||||
if group_id:
|
||||
result = await PluginManage.superuser_task_handle(name, group_id, True)
|
||||
result = await PluginManager.superuser_task_handle(name, group_id, True)
|
||||
logger.info(
|
||||
f"超级用户开启被动技能 {name}",
|
||||
arparma.header_result,
|
||||
@ -194,14 +202,16 @@ async def _(
|
||||
target=group_id,
|
||||
)
|
||||
else:
|
||||
result = await PluginManage.unblock_global_task(name)
|
||||
result = await PluginManager.unblock_global_task(
|
||||
name, default_status.result
|
||||
)
|
||||
logger.info(
|
||||
f"超级用户开启全局被动技能 {name}",
|
||||
arparma.header_result,
|
||||
session=session,
|
||||
)
|
||||
else:
|
||||
result = await PluginManage.superuser_unblock(name, None, group_id)
|
||||
result = await PluginManager.superuser_unblock(name, None, group_id)
|
||||
logger.info(
|
||||
f"超级用户开启功能 {name}",
|
||||
arparma.header_result,
|
||||
@ -215,7 +225,7 @@ async def _(
|
||||
@_status_matcher.assign("close")
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: EventSession,
|
||||
session: Uninfo,
|
||||
arparma: Arparma,
|
||||
plugin_name: Match[str],
|
||||
block_type: Match[str],
|
||||
@ -225,22 +235,23 @@ async def _(
|
||||
all: Query[bool] = AlconnaQuery("all.value", False),
|
||||
):
|
||||
if not all.result and not plugin_name.available:
|
||||
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
|
||||
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
|
||||
name = plugin_name.result
|
||||
if gid := session.id3 or session.id2:
|
||||
if session.group:
|
||||
group_id = session.group.id
|
||||
"""修改当前群组的数据"""
|
||||
if task.result:
|
||||
if all.result:
|
||||
result = await PluginManage.block_group_all_task(gid)
|
||||
result = await PluginManager.block_group_all_task(group_id)
|
||||
logger.info("开启所有群组被动", arparma.header_result, session=session)
|
||||
else:
|
||||
result = await PluginManage.block_group_task(name, gid)
|
||||
result = await PluginManager.block_group_task(name, group_id)
|
||||
logger.info(
|
||||
f"关闭群组被动 {name}", arparma.header_result, session=session
|
||||
)
|
||||
elif session.id1 in bot.config.superusers and default_status.result:
|
||||
elif session.user.id in bot.config.superusers and default_status.result:
|
||||
"""单个插件的进群默认修改"""
|
||||
result = await PluginManage.set_default_status(name, False)
|
||||
result = await PluginManager.set_default_status(name, False)
|
||||
logger.info(
|
||||
f"超级用户开启 {name} 功能进群默认开关",
|
||||
arparma.header_result,
|
||||
@ -248,26 +259,28 @@ async def _(
|
||||
)
|
||||
elif all.result:
|
||||
"""所有插件"""
|
||||
result = await PluginManage.set_all_plugin_status(
|
||||
False, default_status.result, gid
|
||||
result = await PluginManager.set_all_plugin_status(
|
||||
False, default_status.result, group_id
|
||||
)
|
||||
logger.info("关闭群组中全部功能", arparma.header_result, session=session)
|
||||
else:
|
||||
result = await PluginManage.block_group_plugin(name, gid)
|
||||
result = await PluginManager.block_group_plugin(name, group_id)
|
||||
logger.info(f"关闭功能 {name}", arparma.header_result, session=session)
|
||||
delete_help_image(gid)
|
||||
delete_help_image(group_id)
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
elif session.id1 in bot.config.superusers:
|
||||
elif session.user.id in bot.config.superusers:
|
||||
group_id = group.result if group.available else None
|
||||
if all.result:
|
||||
if task.result:
|
||||
"""关闭全局或指定群全部被动"""
|
||||
if group_id:
|
||||
result = await PluginManage.block_group_all_task(group_id)
|
||||
result = await PluginManager.block_group_all_task(group_id)
|
||||
else:
|
||||
result = await PluginManage.block_global_all_task()
|
||||
result = await PluginManager.block_global_all_task(
|
||||
default_status.result
|
||||
)
|
||||
else:
|
||||
result = await PluginManage.set_all_plugin_status(
|
||||
result = await PluginManager.set_all_plugin_status(
|
||||
False, default_status.result, group_id
|
||||
)
|
||||
logger.info(
|
||||
@ -277,8 +290,8 @@ async def _(
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
if default_status.result:
|
||||
result = await PluginManage.set_default_status(name, False)
|
||||
if default_status.result and not task.result:
|
||||
result = await PluginManager.set_default_status(name, False)
|
||||
logger.info(
|
||||
f"超级用户关闭 {name} 功能进群默认开关",
|
||||
arparma.header_result,
|
||||
@ -292,7 +305,9 @@ async def _(
|
||||
name = split_list[0]
|
||||
group_id = split_list[1]
|
||||
if group_id:
|
||||
result = await PluginManage.superuser_task_handle(name, group_id, False)
|
||||
result = await PluginManager.superuser_task_handle(
|
||||
name, group_id, False
|
||||
)
|
||||
logger.info(
|
||||
f"超级用户关闭被动技能 {name}",
|
||||
arparma.header_result,
|
||||
@ -300,7 +315,9 @@ async def _(
|
||||
target=group_id,
|
||||
)
|
||||
else:
|
||||
result = await PluginManage.block_global_task(name)
|
||||
result = await PluginManager.block_global_task(
|
||||
name, default_status.result
|
||||
)
|
||||
logger.info(
|
||||
f"超级用户关闭全局被动技能 {name}",
|
||||
arparma.header_result,
|
||||
@ -314,7 +331,7 @@ async def _(
|
||||
elif block_type.result in ["g", "group"]:
|
||||
if block_type.available:
|
||||
_type = BlockType.GROUP
|
||||
result = await PluginManage.superuser_block(name, _type, group_id)
|
||||
result = await PluginManager.superuser_block(name, _type, group_id)
|
||||
logger.info(
|
||||
f"超级用户关闭功能 {name}, 禁用类型: {_type}",
|
||||
arparma.header_result,
|
||||
@ -327,19 +344,20 @@ async def _(
|
||||
|
||||
@_group_status_matcher.handle()
|
||||
async def _(
|
||||
session: EventSession,
|
||||
session: Uninfo,
|
||||
arparma: Arparma,
|
||||
status: str,
|
||||
):
|
||||
if gid := session.id3 or session.id2:
|
||||
if session.group:
|
||||
group_id = session.group.id
|
||||
if status == "sleep":
|
||||
await PluginManage.sleep(gid)
|
||||
await PluginManager.sleep(group_id)
|
||||
logger.info("进行休眠", arparma.header_result, session=session)
|
||||
await MessageUtils.build_message("那我先睡觉了...").finish()
|
||||
else:
|
||||
if await PluginManage.is_wake(gid):
|
||||
if await PluginManager.is_wake(group_id):
|
||||
await MessageUtils.build_message("我还醒着呢!").finish()
|
||||
await PluginManage.wake(gid)
|
||||
await PluginManager.wake(group_id)
|
||||
logger.info("醒来", arparma.header_result, session=session)
|
||||
await MessageUtils.build_message("呜..醒来了...").finish()
|
||||
return MessageUtils.build_message("群组id为空...").send()
|
||||
@ -347,10 +365,10 @@ async def _(
|
||||
|
||||
@_status_matcher.assign("task")
|
||||
async def _(
|
||||
session: EventSession,
|
||||
session: Uninfo,
|
||||
arparma: Arparma,
|
||||
):
|
||||
image = await build_task(session.id3 or session.id2)
|
||||
image = await build_task(session.group.id if session.group else None)
|
||||
if image:
|
||||
logger.info("查看群被动列表", arparma.header_result, session=session)
|
||||
await MessageUtils.build_message(image).finish(reply_to=True)
|
||||
|
||||
@@ -156,7 +156,7 @@ async def build_task(group_id: str | None) -> BuildImage:
     )


-class PluginManage:
+class PluginManager:
     @classmethod
     async def set_default_status(cls, plugin_name: str, status: bool) -> str:
         """设置插件进群默认状态
@@ -350,17 +350,21 @@ class PluginManage:
         return await cls._change_group_task("", group_id, True, True)

     @classmethod
-    async def block_global_all_task(cls) -> str:
+    async def block_global_all_task(cls, is_default: bool) -> str:
         """禁用全局被动技能

         返回:
             str: 返回信息
         """
-        await TaskInfo.all().update(status=False)
-        return "已全局禁用所有被动状态"
+        if is_default:
+            await TaskInfo.all().update(default_status=False)
+            return "已禁用所有被动进群默认状态"
+        else:
+            await TaskInfo.all().update(status=False)
+            return "已全局禁用所有被动状态"

     @classmethod
-    async def block_global_task(cls, name: str) -> str:
+    async def block_global_task(cls, name: str, is_default: bool = False) -> str:
         """禁用全局被动技能

         参数:
@@ -369,31 +373,47 @@ class PluginManage:
         返回:
             str: 返回信息
         """
-        await TaskInfo.filter(name=name).update(status=False)
-        return f"已全局禁用被动状态 {name}"
+        if is_default:
+            await TaskInfo.filter(name=name).update(default_status=False)
+            return f"已禁用被动进群默认状态 {name}"
+        else:
+            await TaskInfo.filter(name=name).update(status=False)
+            return f"已全局禁用被动状态 {name}"

     @classmethod
-    async def unblock_global_all_task(cls) -> str:
+    async def unblock_global_all_task(cls, is_default: bool) -> str:
         """开启全局被动技能

+        参数:
+            is_default: 是否为默认状态
+
         返回:
             str: 返回信息
         """
-        await TaskInfo.all().update(status=True)
-        return "已全局开启所有被动状态"
+        if is_default:
+            await TaskInfo.all().update(default_status=True)
+            return "已开启所有被动进群默认状态"
+        else:
+            await TaskInfo.all().update(status=True)
+            return "已全局开启所有被动状态"

     @classmethod
-    async def unblock_global_task(cls, name: str) -> str:
+    async def unblock_global_task(cls, name: str, is_default: bool = False) -> str:
         """开启全局被动技能

         参数:
             name: 被动技能名称
+            is_default: 是否为默认状态

         返回:
             str: 返回信息
         """
-        await TaskInfo.filter(name=name).update(status=True)
-        return f"已全局开启被动状态 {name}"
+        if is_default:
+            await TaskInfo.filter(name=name).update(default_status=True)
+            return f"已开启被动进群默认状态 {name}"
+        else:
+            await TaskInfo.filter(name=name).update(status=True)
+            return f"已全局开启被动状态 {name}"

     @classmethod
     async def unblock_group_plugin(cls, plugin_name: str, group_id: str) -> str:
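Added note (illustration, not part of the commit): the new is_default flag decides whether a global toggle touches TaskInfo.status (the live switch) or TaskInfo.default_status (what newly joined groups inherit). A minimal usage sketch; the relative import mirrors how the plugin's __init__ imports the class above:

    from ._data_source import PluginManager  # relative import, as in the plugin

    async def demo() -> None:
        # disable the join-group default for every registered passive task
        msg = await PluginManager.block_global_all_task(is_default=True)

        # re-enable one task's live status only; is_default defaults to False
        msg = await PluginManager.unblock_global_task("早晚安")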
@@ -58,6 +58,19 @@ _status_matcher.shortcut(
     prefix=True,
 )

+_status_matcher.shortcut(
+    r"开启(所有|全部)默认群被动",
+    command="switch",
+    arguments=["open", "--task", "--all", "-df"],
+    prefix=True,
+)
+
+_status_matcher.shortcut(
+    r"关闭(所有|全部)默认群被动",
+    command="switch",
+    arguments=["close", "--task", "--all", "-df"],
+    prefix=True,
+)

 _status_matcher.shortcut(
     r"开启群被动\s*(?P<name>.+)",
@@ -73,6 +86,20 @@ _status_matcher.shortcut(
     prefix=True,
 )

+_status_matcher.shortcut(
+    r"开启默认群被动\s*(?P<name>.+)",
+    command="switch",
+    arguments=["open", "{name}", "--task", "-df"],
+    prefix=True,
+)
+
+_status_matcher.shortcut(
+    r"关闭默认群被动\s*(?P<name>.+)",
+    command="switch",
+    arguments=["close", "{name}", "--task", "-df"],
+    prefix=True,
+)
+

 _status_matcher.shortcut(
     r"开启(所有|全部)群被动",
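Added note (illustration, not part of the commit): each shortcut rewrites a natural-language trigger into arguments for the underlying switch command, so the -df flag is how the default-status option reaches the handler. A small sketch of the expansion, assuming the option names used elsewhere in this plugin:

    # "关闭默认群被动 早晚安" is expanded into the equivalent of:
    #     switch close 早晚安 --task -df
    shortcut_args = ["close", "{name}", "--task", "-df"]
    filled = [a.format(name="早晚安") for a in shortcut_args]
    # filled == ["close", "早晚安", "--task", "-df"]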
@@ -11,7 +11,7 @@ from nonebot_plugin_alconna import (
     on_alconna,
     store_true,
 )
-from nonebot_plugin_session import EventSession
+from nonebot_plugin_uninfo import Uninfo

 from zhenxun.configs.utils import PluginExtraData
 from zhenxun.services.log import logger
@@ -22,7 +22,7 @@ from zhenxun.utils.manager.resource_manager import (
 )
 from zhenxun.utils.message import MessageUtils

-from ._data_source import UpdateManage
+from ._data_source import UpdateManager

 __plugin_meta__ = PluginMetadata(
     name="自动更新",
@@ -32,16 +32,18 @@ __plugin_meta__ = PluginMetadata(
     检查更新真寻最新版本,包括了自动更新
     资源文件大小一般在130mb左右,除非必须更新一般仅更新代码文件
     指令:
-        检查更新 [main|release|resource] ?[-r]
+        检查更新 [main|release|resource|webui] ?[-r]
         main: main分支
         release: 最新release
         resource: 资源文件
+        webui: webui文件
         -r: 下载资源文件,一般在更新main或release时使用
     示例:
         检查更新 main
         检查更新 main -r
         检查更新 release -r
         检查更新 resource
+        检查更新 webui
     """.strip(),
     extra=PluginExtraData(
         author="HibiKier",
@@ -53,7 +55,7 @@ __plugin_meta__ = PluginMetadata(
 _matcher = on_alconna(
     Alconna(
         "检查更新",
-        Args["ver_type?", ["main", "release", "resource"]],
+        Args["ver_type?", ["main", "release", "resource", "webui"]],
         Option("-r|--resource", action=store_true, help_text="下载资源文件"),
     ),
     priority=1,
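Added note (illustration, not part of the commit): with the optional ver_type argument and the -r option, the matcher now accepts the webui target alongside the existing ones. A hedged sketch of the values the handler ends up seeing for typical inputs, assuming nonebot_plugin_alconna's usual Match/Query injection:

    examples = {
        "检查更新 webui": {"ver_type": "webui", "resource": False},
        "检查更新 main -r": {"ver_type": "main", "resource": True},
        "检查更新": {"ver_type": None, "resource": False},  # no target: handler just reports the version
    }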
@@ -66,23 +68,24 @@ _matcher = on_alconna(
 @_matcher.handle()
 async def _(
     bot: Bot,
-    session: EventSession,
+    session: Uninfo,
     ver_type: Match[str],
     resource: Query[bool] = Query("resource", False),
 ):
-    if not session.id1:
-        await MessageUtils.build_message("用户id为空...").finish()
     result = ""
     await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True)
     if not ver_type.available:
-        result = await UpdateManage.check_version()
+        result = await UpdateManager.check_version()
         logger.info("查看当前版本...", "检查更新", session=session)
         await MessageUtils.build_message(result).finish()
+    if ver_type.result in {"main", "release"}:
         try:
-            result = await UpdateManage.update(bot, session.id1, ver_type.result)
+            result = await UpdateManager.update(bot, session.user.id, ver_type.result)
         except Exception as e:
             logger.error("版本更新失败...", "检查更新", session=session, e=e)
             await MessageUtils.build_message(f"更新版本失败...e: {e}").finish()
+    elif ver_type.result == "webui":
+        result = await UpdateManager.update_webui()
     if resource.result or ver_type.result == "resource":
         try:
             await ResourceManager.init_resources(True)

@ -7,6 +7,7 @@ import zipfile
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.utils import run_sync
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.github_utils import GithubUtils
|
||||
from zhenxun.utils.github_utils.models import RepoInfo
|
||||
@ -17,6 +18,7 @@ from .config import (
|
||||
BACKUP_PATH,
|
||||
BASE_PATH,
|
||||
BASE_PATH_STRING,
|
||||
COMMAND,
|
||||
DEFAULT_GITHUB_URL,
|
||||
DOWNLOAD_GZ_FILE,
|
||||
DOWNLOAD_ZIP_FILE,
|
||||
@ -38,7 +40,7 @@ def install_requirement():
|
||||
|
||||
if not requirement_path.exists():
|
||||
logger.debug(
|
||||
f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", "插件管理"
|
||||
f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", COMMAND
|
||||
)
|
||||
return
|
||||
try:
|
||||
@ -48,9 +50,9 @@ def install_requirement():
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", "插件管理")
|
||||
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", COMMAND)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", "插件管理", e=e)
|
||||
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", COMMAND, e=e)
|
||||
|
||||
|
||||
@run_sync
|
||||
@ -61,7 +63,7 @@ def _file_handle(latest_version: str | None):
|
||||
latest_version: 版本号
|
||||
"""
|
||||
BACKUP_PATH.mkdir(exist_ok=True, parents=True)
|
||||
logger.debug("开始解压文件压缩包...", "检查更新")
|
||||
logger.debug("开始解压文件压缩包...", COMMAND)
|
||||
download_file = DOWNLOAD_GZ_FILE
|
||||
if DOWNLOAD_GZ_FILE.exists():
|
||||
tf = tarfile.open(DOWNLOAD_GZ_FILE)
|
||||
@ -69,7 +71,7 @@ def _file_handle(latest_version: str | None):
|
||||
download_file = DOWNLOAD_ZIP_FILE
|
||||
tf = zipfile.ZipFile(DOWNLOAD_ZIP_FILE)
|
||||
tf.extractall(TMP_PATH)
|
||||
logger.debug("解压文件压缩包完成...", "检查更新")
|
||||
logger.debug("解压文件压缩包完成...", COMMAND)
|
||||
download_file_path = TMP_PATH / next(
|
||||
x for x in os.listdir(TMP_PATH) if (TMP_PATH / x).is_dir()
|
||||
)
|
||||
@ -79,52 +81,52 @@ def _file_handle(latest_version: str | None):
|
||||
extract_path = download_file_path / BASE_PATH_STRING
|
||||
target_path = BASE_PATH
|
||||
if PYPROJECT_FILE.exists():
|
||||
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", "检查更新")
|
||||
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", COMMAND)
|
||||
shutil.move(PYPROJECT_FILE, BACKUP_PATH / PYPROJECT_FILE_STRING)
|
||||
if PYPROJECT_LOCK_FILE.exists():
|
||||
logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", "检查更新")
|
||||
logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", COMMAND)
|
||||
shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / PYPROJECT_LOCK_FILE_STRING)
|
||||
if REQ_TXT_FILE.exists():
|
||||
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", "检查更新")
|
||||
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", COMMAND)
|
||||
shutil.move(REQ_TXT_FILE, BACKUP_PATH / REQ_TXT_FILE_STRING)
|
||||
if _pyproject.exists():
|
||||
logger.debug("移动文件: pyproject.toml", "检查更新")
|
||||
logger.debug("移动文件: pyproject.toml", COMMAND)
|
||||
shutil.move(_pyproject, PYPROJECT_FILE)
|
||||
if _lock_file.exists():
|
||||
logger.debug("移动文件: poetry.lock", "检查更新")
|
||||
logger.debug("移动文件: poetry.lock", COMMAND)
|
||||
shutil.move(_lock_file, PYPROJECT_LOCK_FILE)
|
||||
if _req_file.exists():
|
||||
logger.debug("移动文件: requirements.txt", "检查更新")
|
||||
logger.debug("移动文件: requirements.txt", COMMAND)
|
||||
shutil.move(_req_file, REQ_TXT_FILE)
|
||||
for folder in REPLACE_FOLDERS:
|
||||
"""移动指定文件夹"""
|
||||
_dir = BASE_PATH / folder
|
||||
_backup_dir = BACKUP_PATH / folder
|
||||
if _backup_dir.exists():
|
||||
logger.debug(f"删除备份文件夹 {_backup_dir}", "检查更新")
|
||||
logger.debug(f"删除备份文件夹 {_backup_dir}", COMMAND)
|
||||
shutil.rmtree(_backup_dir)
|
||||
if _dir.exists():
|
||||
logger.debug(f"移动旧文件夹 {_dir}", "检查更新")
|
||||
logger.debug(f"移动旧文件夹 {_dir}", COMMAND)
|
||||
shutil.move(_dir, _backup_dir)
|
||||
else:
|
||||
logger.warning(f"文件夹 {_dir} 不存在,跳过删除", "检查更新")
|
||||
logger.warning(f"文件夹 {_dir} 不存在,跳过删除", COMMAND)
|
||||
for folder in REPLACE_FOLDERS:
|
||||
src_folder_path = extract_path / folder
|
||||
dest_folder_path = target_path / folder
|
||||
if src_folder_path.exists():
|
||||
logger.debug(
|
||||
f"移动文件夹: {src_folder_path} -> {dest_folder_path}", "检查更新"
|
||||
f"移动文件夹: {src_folder_path} -> {dest_folder_path}", COMMAND
|
||||
)
|
||||
shutil.move(src_folder_path, dest_folder_path)
|
||||
else:
|
||||
logger.debug(f"源文件夹不存在: {src_folder_path}", "检查更新")
|
||||
logger.debug(f"源文件夹不存在: {src_folder_path}", COMMAND)
|
||||
if tf:
|
||||
tf.close()
|
||||
if download_file.exists():
|
||||
logger.debug(f"删除下载文件: {download_file}", "检查更新")
|
||||
logger.debug(f"删除下载文件: {download_file}", COMMAND)
|
||||
download_file.unlink()
|
||||
if extract_path.exists():
|
||||
logger.debug(f"删除解压文件夹: {extract_path}", "检查更新")
|
||||
logger.debug(f"删除解压文件夹: {extract_path}", COMMAND)
|
||||
shutil.rmtree(extract_path)
|
||||
if TMP_PATH.exists():
|
||||
shutil.rmtree(TMP_PATH)
|
||||
@@ -134,7 +136,35 @@ def _file_handle(latest_version: str | None):
     install_requirement()


-class UpdateManage:
+class UpdateManager:
+    @classmethod
+    async def update_webui(cls) -> str:
+        from zhenxun.builtin_plugins.web_ui.public.data_source import (
+            update_webui_assets,
+        )
+
+        WEBUI_PATH = DATA_PATH / "web_ui" / "public"
+        BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public"
+        if WEBUI_PATH.exists():
+            if BACKUP_PATH.exists():
+                logger.debug(f"删除旧的备份webui文件夹 {BACKUP_PATH}", COMMAND)
+                shutil.rmtree(BACKUP_PATH)
+            WEBUI_PATH.rename(BACKUP_PATH)
+        try:
+            await update_webui_assets()
+            logger.info("更新webui成功...", COMMAND)
+            if BACKUP_PATH.exists():
+                logger.debug(f"删除旧的webui文件夹 {BACKUP_PATH}", COMMAND)
+                shutil.rmtree(BACKUP_PATH)
+            return "Webui更新成功!"
+        except Exception as e:
+            logger.error("更新webui失败...", COMMAND, e=e)
+            if BACKUP_PATH.exists():
+                logger.debug(f"恢复旧的webui文件夹 {BACKUP_PATH}", COMMAND)
+                BACKUP_PATH.rename(WEBUI_PATH)
+            raise e
+        return ""
+
     @classmethod
     async def check_version(cls) -> str:
         """检查更新版本
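Added note (illustration, not part of the commit): update_webui follows a backup-and-swap pattern: rename the live directory aside, download fresh assets, delete the backup on success, and restore the backup on failure. The same pattern in isolation, with a hypothetical fetch coroutine standing in for update_webui_assets:

    import shutil
    from pathlib import Path

    async def safe_swap(live: Path, backup: Path, fetch) -> None:
        # move the current assets out of the way so a failed download can be undone
        if live.exists():
            if backup.exists():
                shutil.rmtree(backup)
            live.rename(backup)
        try:
            await fetch()               # writes fresh assets into `live`
            if backup.exists():
                shutil.rmtree(backup)   # success: the backup is no longer needed
        except Exception:
            if backup.exists():
                backup.rename(live)     # failure: restore the previous assets
            raise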
@ -166,7 +196,7 @@ class UpdateManage:
|
||||
返回:
|
||||
str | None: 返回消息
|
||||
"""
|
||||
logger.info("开始下载真寻最新版文件....", "检查更新")
|
||||
logger.info("开始下载真寻最新版文件....", COMMAND)
|
||||
cur_version = cls.__get_version()
|
||||
url = None
|
||||
new_version = None
|
||||
@ -186,11 +216,11 @@ class UpdateManage:
|
||||
if not url:
|
||||
return "获取版本下载链接失败..."
|
||||
if TMP_PATH.exists():
|
||||
logger.debug(f"删除临时文件夹 {TMP_PATH}", "检查更新")
|
||||
logger.debug(f"删除临时文件夹 {TMP_PATH}", COMMAND)
|
||||
shutil.rmtree(TMP_PATH)
|
||||
logger.debug(
|
||||
f"开始更新版本:{cur_version} -> {new_version} | 下载链接:{url}",
|
||||
"检查更新",
|
||||
COMMAND,
|
||||
)
|
||||
await PlatformUtils.send_superuser(
|
||||
bot,
|
||||
@ -201,7 +231,7 @@ class UpdateManage:
|
||||
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
|
||||
)
|
||||
if await AsyncHttpx.download_file(url, download_file, stream=True):
|
||||
logger.debug("下载真寻最新版文件完成...", "检查更新")
|
||||
logger.debug("下载真寻最新版文件完成...", COMMAND)
|
||||
await _file_handle(new_version)
|
||||
result = "版本更新完成"
|
||||
return (
|
||||
@ -210,7 +240,7 @@ class UpdateManage:
|
||||
"请重新启动真寻以完成更新!"
|
||||
)
|
||||
else:
|
||||
logger.debug("下载真寻最新版文件失败...", "检查更新")
|
||||
logger.debug("下载真寻最新版文件失败...", COMMAND)
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -34,3 +34,5 @@ REPLACE_FOLDERS = [
     "models",
     "configs",
 ]
+
+COMMAND = "检查更新"

@@ -54,22 +54,6 @@ __plugin_meta__ = PluginMetadata(
             default_value=5,
             type=int,
         ),
-        RegisterConfig(
-            module="_task",
-            key="DEFAULT_GROUP_WELCOME",
-            value=True,
-            help="被动 进群欢迎 进群默认开关状态",
-            default_value=True,
-            type=bool,
-        ),
-        RegisterConfig(
-            module="_task",
-            key="DEFAULT_REFUND_GROUP_REMIND",
-            value=True,
-            help="被动 退群提醒 进群默认开关状态",
-            default_value=True,
-            type=bool,
-        ),
     ],
     tasks=[
         Task(

@@ -1,9 +1,11 @@
 from nonebot.plugin import PluginMetadata

-from zhenxun.configs.utils import PluginExtraData
+from zhenxun.configs.utils import PluginExtraData, RegisterConfig
 from zhenxun.utils.enum import PluginType

-from . import command  # noqa: F401
+from . import commands, handlers
+
+__all__ = ["commands", "handlers"]

 __plugin_meta__ = PluginMetadata(
     name="定时任务管理",
@@ -27,6 +29,8 @@ __plugin_meta__ = PluginMetadata(
     定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
     定时任务 执行 <任务ID>
     定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
+    # [修改] 增加说明
+    • 说明: -p 选项可单独使用,用于操作指定插件的所有任务

     📝 时间选项 (三选一):
     --cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
@@ -47,5 +51,35 @@ __plugin_meta__ = PluginMetadata(
         version="0.1.2",
         plugin_type=PluginType.SUPERUSER,
         is_show=False,
+        configs=[
+            RegisterConfig(
+                module="SchedulerManager",
+                key="ALL_GROUPS_CONCURRENCY_LIMIT",
+                value=5,
+                help="“所有群组”类型定时任务的并发执行数量限制",
+                type=int,
+            ),
+            RegisterConfig(
+                module="SchedulerManager",
+                key="JOB_MAX_RETRIES",
+                value=2,
+                help="定时任务执行失败时的最大重试次数",
+                type=int,
+            ),
+            RegisterConfig(
+                module="SchedulerManager",
+                key="JOB_RETRY_DELAY",
+                value=10,
+                help="定时任务执行重试的间隔时间(秒)",
+                type=int,
+            ),
+            RegisterConfig(
+                module="SchedulerManager",
+                key="SCHEDULER_TIMEZONE",
+                value="Asia/Shanghai",
+                help="定时任务使用的时区,默认为 Asia/Shanghai",
+                type=str,
+            ),
+        ],
     ).to_dict(),
 )

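Added note (illustration, not part of the commit): RegisterConfig entries declared in PluginExtraData become runtime-readable configuration values. A hedged sketch of reading the retry settings; the accessor name is an assumption based on zhenxun's usual Config helper and should be checked against the codebase:

    from zhenxun.configs.config import Config

    # accessor name assumed; verify against zhenxun.configs.config
    max_retries = Config.get_config("SchedulerManager", "JOB_MAX_RETRIES")
    retry_delay = Config.get_config("SchedulerManager", "JOB_RETRY_DELAY")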
@ -1,836 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaMatch,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.manager.schedule_manager import scheduler_manager
|
||||
|
||||
|
||||
def _get_type_name(annotation) -> str:
|
||||
"""获取类型注解的名称"""
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
elif hasattr(annotation, "_name"):
|
||||
return annotation._name
|
||||
else:
|
||||
return str(annotation)
|
||||
|
||||
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
|
||||
def _format_trigger(schedule_status: dict) -> str:
|
||||
"""将触发器配置格式化为人类可读的字符串"""
|
||||
trigger_type = schedule_status["trigger_type"]
|
||||
config = schedule_status["trigger_config"]
|
||||
|
||||
if trigger_type == "cron":
|
||||
minute = config.get("minute", "*")
|
||||
hour = config.get("hour", "*")
|
||||
day = config.get("day", "*")
|
||||
month = config.get("month", "*")
|
||||
day_of_week = config.get("day_of_week", "*")
|
||||
|
||||
if day == "*" and month == "*" and day_of_week == "*":
|
||||
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
|
||||
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
|
||||
return f"每天 {formatted_hour}:{formatted_minute}"
|
||||
else:
|
||||
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
|
||||
elif trigger_type == "interval":
|
||||
seconds = config.get("seconds", 0)
|
||||
minutes = config.get("minutes", 0)
|
||||
hours = config.get("hours", 0)
|
||||
days = config.get("days", 0)
|
||||
if days:
|
||||
trigger_str = f"每 {days} 天"
|
||||
elif hours:
|
||||
trigger_str = f"每 {hours} 小时"
|
||||
elif minutes:
|
||||
trigger_str = f"每 {minutes} 分钟"
|
||||
else:
|
||||
trigger_str = f"每 {seconds} 秒"
|
||||
elif trigger_type == "date":
|
||||
run_date = config.get("run_date", "未知时间")
|
||||
trigger_str = f"在 {run_date}"
|
||||
else:
|
||||
trigger_str = f"{trigger_type}: {config}"
|
||||
|
||||
return trigger_str
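Added note (illustration, not part of the removed file): _format_trigger collapses a plain daily cron config into a "每天 HH:MM" string and falls back to the raw cron fields otherwise. Example inputs and the strings the code above would produce:

    daily = {
        "trigger_type": "cron",
        "trigger_config": {"minute": "30", "hour": "8", "day": "*", "month": "*", "day_of_week": "*"},
    }
    # _format_trigger(daily) -> "每天 08:30"

    weekly = {
        "trigger_type": "cron",
        "trigger_config": {"minute": "0", "hour": "9", "day": "*", "month": "*", "day_of_week": "1"},
    }
    # _format_trigger(weekly) -> "Cron: 0 9 * * 1"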
|
||||
|
||||
|
||||
def _format_params(schedule_status: dict) -> str:
|
||||
"""将任务参数格式化为人类可读的字符串"""
|
||||
if kwargs := schedule_status.get("job_kwargs"):
|
||||
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
||||
return kwargs_str
|
||||
return "-"
|
||||
|
||||
|
||||
def _parse_interval(interval_str: str) -> dict:
|
||||
"""增强版解析器,支持 d(天)"""
|
||||
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
||||
if not match:
|
||||
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
||||
|
||||
value, unit = int(match.group(1)), match.group(2)
|
||||
if unit == "s":
|
||||
return {"seconds": value}
|
||||
if unit == "m":
|
||||
return {"minutes": value}
|
||||
if unit == "h":
|
||||
return {"hours": value}
|
||||
if unit == "d":
|
||||
return {"days": value}
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_daily_time(time_str: str) -> dict:
|
||||
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
|
||||
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
||||
hour, minute, second = match.groups()
|
||||
hour, minute = int(hour), int(minute)
|
||||
|
||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||
raise ValueError("小时或分钟数值超出范围。")
|
||||
|
||||
cron_config = {
|
||||
"minute": str(minute),
|
||||
"hour": str(hour),
|
||||
"day": "*",
|
||||
"month": "*",
|
||||
"day_of_week": "*",
|
||||
}
|
||||
if second is not None:
|
||||
if not (0 <= int(second) <= 59):
|
||||
raise ValueError("秒数值超出范围。")
|
||||
cron_config["second"] = str(second)
|
||||
|
||||
return cron_config
|
||||
else:
|
||||
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
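Added note (illustration, not part of the removed file): the two parsers above turn user-facing shorthand into APScheduler-style trigger kwargs. What they return for typical inputs, assuming the function definitions above:

    assert _parse_interval("30m") == {"minutes": 30}
    assert _parse_interval("2h") == {"hours": 2}
    assert _parse_interval("1d") == {"days": 1}
    assert _parse_daily_time("08:20") == {
        "minute": "20",
        "hour": "8",
        "day": "*",
        "month": "*",
        "day_of_week": "*",
    }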
|
||||
|
||||
|
||||
async def GetBotId(
|
||||
bot: Bot,
|
||||
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
|
||||
) -> str:
|
||||
"""获取要操作的Bot ID"""
|
||||
if bot_id_match.available:
|
||||
return bot_id_match.result
|
||||
return bot.self_id
|
||||
|
||||
|
||||
class ScheduleTarget:
|
||||
"""定时任务操作目标的基类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TargetByID(ScheduleTarget):
|
||||
"""按任务ID操作"""
|
||||
|
||||
def __init__(self, id: int):
|
||||
self.id = id
|
||||
|
||||
|
||||
class TargetByPlugin(ScheduleTarget):
|
||||
"""按插件名操作"""
|
||||
|
||||
def __init__(
|
||||
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
||||
):
|
||||
self.plugin = plugin
|
||||
self.group_id = group_id
|
||||
self.all_groups = all_groups
|
||||
|
||||
|
||||
class TargetAll(ScheduleTarget):
|
||||
"""操作所有任务"""
|
||||
|
||||
def __init__(self, for_group: str | None = None):
|
||||
self.for_group = for_group
|
||||
|
||||
|
||||
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
||||
|
||||
|
||||
def create_target_parser(subcommand_name: str):
|
||||
"""
|
||||
创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
event: Event,
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
||||
) -> TargetScope:
|
||||
if schedule_id.available:
|
||||
return TargetByID(schedule_id.result)
|
||||
|
||||
if plugin_name.available:
|
||||
p_name = plugin_name.result
|
||||
if all_enabled.available:
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
elif group_id.available:
|
||||
gid = group_id.result
|
||||
if gid.lower() == "all":
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
return TargetByPlugin(plugin=p_name, group_id=gid)
|
||||
else:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if current_group_id:
|
||||
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
|
||||
)
|
||||
|
||||
if all_enabled.available:
|
||||
return TargetAll(for_group=group_id.result if group_id.available else None)
|
||||
|
||||
return None
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
schedule_cmd = on_alconna(
|
||||
Alconna(
|
||||
"定时任务",
|
||||
Subcommand(
|
||||
"查看",
|
||||
Option("-g", Args["target_group_id", str]),
|
||||
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
||||
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
||||
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
||||
alias=["ls", "list"],
|
||||
help_text="查看定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"设置",
|
||||
Args["plugin_name", str],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="设置每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
||||
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["add", "开启"],
|
||||
help_text="设置/开启一个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"删除",
|
||||
Args["schedule_id?", int],
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
||||
Option("-all", help_text="对所有群生效"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["del", "rm", "remove", "关闭", "取消"],
|
||||
help_text="删除一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"暂停",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["pause"],
|
||||
help_text="暂停一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"恢复",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["resume"],
|
||||
help_text="恢复一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"执行",
|
||||
Args["schedule_id", int],
|
||||
alias=["trigger", "run"],
|
||||
help_text="立即执行一次任务",
|
||||
),
|
||||
Subcommand(
|
||||
"更新",
|
||||
Args["schedule_id", int],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="更新每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
||||
alias=["update", "modify", "修改"],
|
||||
help_text="更新任务配置",
|
||||
),
|
||||
Subcommand(
|
||||
"状态",
|
||||
Args["schedule_id", int],
|
||||
alias=["status", "info"],
|
||||
help_text="查看单个任务的详细状态",
|
||||
),
|
||||
Subcommand(
|
||||
"插件列表",
|
||||
alias=["plugins"],
|
||||
help_text="列出所有可用的插件",
|
||||
),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
rule=admin_check(1),
|
||||
)
|
||||
|
||||
schedule_cmd.shortcut(
|
||||
"任务状态",
|
||||
command="定时任务",
|
||||
arguments=["状态", "{%0}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.handle()
|
||||
async def _handle_time_options_mutex(arp: Arparma):
|
||||
time_options = ["cron", "interval", "date", "daily"]
|
||||
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
||||
if len(provided_options) > 1:
|
||||
await schedule_cmd.finish(
|
||||
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("查看")
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
||||
all_groups: Query[bool] = Query("查看.all"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
page: Match[int] = AlconnaMatch("page"),
|
||||
):
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
schedules = []
|
||||
title = ""
|
||||
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
||||
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
||||
|
||||
if all_groups.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
||||
schedules = await scheduler_manager.get_all_schedules()
|
||||
title = "所有群组的定时任务"
|
||||
elif target_group_id.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
||||
gid = target_group_id.result
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = f"群 {gid} 的定时任务"
|
||||
else:
|
||||
gid = str(current_group_id)
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = "本群的定时任务"
|
||||
|
||||
if plugin_name.available:
|
||||
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
|
||||
title += f" [插件: {plugin_name.result}]"
|
||||
|
||||
if not schedules:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
page_size = 15
|
||||
current_page = page.result
|
||||
total_items = len(schedules)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
start_index = (current_page - 1) * page_size
|
||||
end_index = start_index + page_size
|
||||
paginated_schedules = schedules[start_index:end_index]
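Added note (illustration, not part of the removed file): the paging above is ordinary ceiling division plus list slicing. A worked example with the same page_size of 15:

    schedules = list(range(47))          # pretend 47 schedules came back
    page_size, current_page = 15, 2
    total_pages = (len(schedules) + page_size - 1) // page_size   # -> 4
    start_index = (current_page - 1) * page_size                  # -> 15
    paginated = schedules[start_index:start_index + page_size]    # items 15..29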
|
||||
|
||||
if not paginated_schedules:
|
||||
await schedule_cmd.finish("这一页没有内容了哦~")
|
||||
|
||||
status_tasks = [
|
||||
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
||||
]
|
||||
all_statuses = await asyncio.gather(*status_tasks)
|
||||
data_list = [
|
||||
[
|
||||
s["id"],
|
||||
s["plugin_name"],
|
||||
s.get("bot_id") or "N/A",
|
||||
s["group_id"] or "全局",
|
||||
s["next_run_time"],
|
||||
_format_trigger(s),
|
||||
_format_params(s),
|
||||
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
|
||||
]
|
||||
for s in all_statuses
|
||||
if s
|
||||
]
|
||||
|
||||
if not data_list:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
img = await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
||||
column_name=[
|
||||
"ID",
|
||||
"插件",
|
||||
"Bot ID",
|
||||
"群组/目标",
|
||||
"下次运行",
|
||||
"触发规则",
|
||||
"参数",
|
||||
"状态",
|
||||
],
|
||||
data_list=data_list,
|
||||
column_space=20,
|
||||
)
|
||||
await MessageUtils.build_message(img).send(reply_to=True)
|
||||
|
||||
|
||||
@schedule_cmd.assign("设置")
|
||||
async def _(
|
||||
event: Event,
|
||||
plugin_name: str,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
group_id: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
all_enabled: Query[bool] = Query("设置.all"),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if plugin_name not in scheduler_manager._registered_tasks:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
|
||||
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
|
||||
)
|
||||
|
||||
trigger_type = ""
|
||||
trigger_config = {}
|
||||
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||
)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = {}
|
||||
if kwargs_str:
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
if not params_model:
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
)
|
||||
return
|
||||
|
||||
target_group_id: str | None
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
|
||||
if group_id and group_id.lower() == "all":
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif all_enabled.available:
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif group_id:
|
||||
target_group_id = group_id
|
||||
elif current_group_id:
|
||||
target_group_id = str(current_group_id)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
||||
)
|
||||
return
|
||||
|
||||
success, msg = await scheduler_manager.add_schedule(
|
||||
plugin_name,
|
||||
target_group_id,
|
||||
trigger_type,
|
||||
trigger_config,
|
||||
job_kwargs,
|
||||
bot_id=bot_id_to_operate,
|
||||
)
|
||||
|
||||
if target_group_id == "__ALL_GROUPS__":
|
||||
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
|
||||
elif target_group_id is None:
|
||||
target_desc = "全局"
|
||||
else:
|
||||
target_desc = f"群组 {target_group_id}"
|
||||
|
||||
if success:
|
||||
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
|
||||
else:
|
||||
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
|
||||
|
||||
|
||||
@schedule_cmd.assign("删除")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("删除")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
removed_count = await scheduler_manager.remove_schedule_for_all(
|
||||
p_name, bot_id=bot_id_to_operate
|
||||
)
|
||||
message = (
|
||||
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
|
||||
if removed_count > 0
|
||||
else f"没有找到插件 '{p_name}' 的定时任务。"
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_schedule(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.remove_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("暂停")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("暂停")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.pause_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.pause_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("恢复")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("恢复")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.resume_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.resume_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("执行")
|
||||
async def _(schedule_id: int):
|
||||
_, message = await scheduler_manager.trigger_now(schedule_id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("更新")
|
||||
async def _(
|
||||
schedule_id: int,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
):
|
||||
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
|
||||
await schedule_cmd.finish(
|
||||
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
||||
)
|
||||
|
||||
trigger_config = None
|
||||
trigger_type = None
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = None
|
||||
if kwargs_str:
|
||||
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
|
||||
|
||||
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
|
||||
if not task_meta or not (params_model := task_meta.get("model")):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
|
||||
)
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
|
||||
)
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump(exclude_unset=True)
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
|
||||
return
|
||||
|
||||
_, message = await scheduler_manager.update_schedule(
|
||||
schedule_id, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("插件列表")
|
||||
async def _():
|
||||
registered_plugins = scheduler_manager.get_registered_plugins()
|
||||
if not registered_plugins:
|
||||
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
|
||||
|
||||
message_parts = ["📋 已注册的定时任务插件:"]
|
||||
for i, plugin_name in enumerate(registered_plugins, 1):
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
|
||||
if not params_model:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
continue
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
|
||||
continue
|
||||
|
||||
model_fields = getattr(params_model, "model_fields", None)
|
||||
if model_fields:
|
||||
param_info = ", ".join(
|
||||
f"{field_name}({_get_type_name(field_info.annotation)})"
|
||||
for field_name, field_info in model_fields.items()
|
||||
)
|
||||
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
|
||||
else:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
|
||||
await schedule_cmd.finish("\n".join(message_parts))
|
||||
|
||||
|
||||
@schedule_cmd.assign("状态")
|
||||
async def _(schedule_id: int):
|
||||
status = await scheduler_manager.get_schedule_status(schedule_id)
|
||||
if not status:
|
||||
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
|
||||
|
||||
info_lines = [
|
||||
f"📋 定时任务详细信息 (ID: {schedule_id})",
|
||||
"--------------------",
|
||||
f"▫️ 插件: {status['plugin_name']}",
|
||||
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
||||
f"▫️ 目标: {status['group_id'] or '全局'}",
|
||||
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
||||
f"▫️ 下次运行: {status['next_run_time']}",
|
||||
f"▫️ 触发规则: {_format_trigger(status)}",
|
||||
f"▫️ 任务参数: {_format_params(status)}",
|
||||
]
|
||||
await schedule_cmd.finish("\n".join(info_lines))
|
||||
298
zhenxun/builtin_plugins/scheduler_admin/commands.py
Normal file
@ -0,0 +1,298 @@
|
||||
import re
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaMatch,
|
||||
Args,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
)
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.services.scheduler import scheduler_manager
|
||||
from zhenxun.services.scheduler.targeter import ScheduleTargeter
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
schedule_cmd = on_alconna(
|
||||
Alconna(
|
||||
"定时任务",
|
||||
Subcommand(
|
||||
"查看",
|
||||
Option("-g", Args["target_group_id", str]),
|
||||
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
||||
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
||||
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
||||
alias=["ls", "list"],
|
||||
help_text="查看定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"设置",
|
||||
Args["plugin_name", str],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="设置每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
||||
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["add", "开启"],
|
||||
help_text="设置/开启一个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"删除",
|
||||
Args["schedule_id?", int],
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
||||
Option("-all", help_text="对所有群生效"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["del", "rm", "remove", "关闭", "取消"],
|
||||
help_text="删除一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"暂停",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["pause"],
|
||||
help_text="暂停一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"恢复",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["resume"],
|
||||
help_text="恢复一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"执行",
|
||||
Args["schedule_id", int],
|
||||
alias=["trigger", "run"],
|
||||
help_text="立即执行一次任务",
|
||||
),
|
||||
Subcommand(
|
||||
"更新",
|
||||
Args["schedule_id", int],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="更新每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
||||
alias=["update", "modify", "修改"],
|
||||
help_text="更新任务配置",
|
||||
),
|
||||
Subcommand(
|
||||
"状态",
|
||||
Args["schedule_id", int],
|
||||
alias=["status", "info"],
|
||||
help_text="查看单个任务的详细状态",
|
||||
),
|
||||
Subcommand(
|
||||
"插件列表",
|
||||
alias=["plugins"],
|
||||
help_text="列出所有可用的插件",
|
||||
),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
rule=admin_check(1),
|
||||
)
|
||||
|
||||
schedule_cmd.shortcut(
|
||||
"任务状态",
|
||||
command="定时任务",
|
||||
arguments=["状态", "{%0}"],
|
||||
prefix=True,
|
||||
)
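# Illustrative invocations (editor's sketch; plugin name below is hypothetical, not from this repo):
#   定时任务 设置 morning_news --daily 08:20 -g 123456789
#   定时任务 查看 -p morning_news --page 2
#   定时任务 暂停 -all
#   任务状态 42   (expands to「定时任务 状态 42」via the shortcut above)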
|
||||
|
||||
|
||||
class ScheduleTarget:
|
||||
pass
|
||||
|
||||
|
||||
class TargetByID(ScheduleTarget):
|
||||
def __init__(self, id: int):
|
||||
self.id = id
|
||||
|
||||
|
||||
class TargetByPlugin(ScheduleTarget):
|
||||
def __init__(
|
||||
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
||||
):
|
||||
self.plugin = plugin
|
||||
self.group_id = group_id
|
||||
self.all_groups = all_groups
|
||||
|
||||
|
||||
class TargetAll(ScheduleTarget):
|
||||
def __init__(self, for_group: str | None = None):
|
||||
self.for_group = for_group
|
||||
|
||||
|
||||
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
||||
|
||||
|
||||
def create_target_parser(subcommand_name: str):
|
||||
async def dependency(
|
||||
event: Event,
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
||||
) -> TargetScope:
|
||||
if schedule_id.available:
|
||||
return TargetByID(schedule_id.result)
|
||||
|
||||
if plugin_name.available:
|
||||
p_name = plugin_name.result
|
||||
if all_enabled.available:
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
elif group_id.available:
|
||||
gid = group_id.result
|
||||
if gid.lower() == "all":
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
return TargetByPlugin(plugin=p_name, group_id=gid)
|
||||
else:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
return TargetByPlugin(
|
||||
plugin=p_name,
|
||||
group_id=str(current_group_id) if current_group_id else None,
|
||||
)
|
||||
|
||||
if all_enabled.available:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if not current_group_id:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中单独使用 -all 选项时,必须使用 -g <群号> 指定目标。"
|
||||
)
|
||||
return TargetAll(for_group=str(current_group_id))
|
||||
|
||||
return None
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
def parse_interval(interval_str: str) -> dict:
|
||||
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
||||
if not match:
|
||||
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
||||
value, unit = int(match.group(1)), match.group(2)
|
||||
if unit == "s":
|
||||
return {"seconds": value}
|
||||
if unit == "m":
|
||||
return {"minutes": value}
|
||||
if unit == "h":
|
||||
return {"hours": value}
|
||||
if unit == "d":
|
||||
return {"days": value}
|
||||
return {}
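# Illustrative results (editor's example): parse_interval("30m") -> {"minutes": 30},
# parse_interval("2h") -> {"hours": 2}; a string that does not match "<number><s|m|h|d>" raises ValueError.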
|
||||
|
||||
|
||||
def parse_daily_time(time_str: str) -> dict:
|
||||
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
||||
hour, minute, second = match.groups()
|
||||
hour, minute = int(hour), int(minute)
|
||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||
raise ValueError("小时或分钟数值超出范围。")
|
||||
cron_config = {
|
||||
"minute": str(minute),
|
||||
"hour": str(hour),
|
||||
"day": "*",
|
||||
"month": "*",
|
||||
"day_of_week": "*",
|
||||
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
|
||||
}
|
||||
if second is not None:
|
||||
if not (0 <= int(second) <= 59):
|
||||
raise ValueError("秒数值超出范围。")
|
||||
cron_config["second"] = str(second)
|
||||
return cron_config
|
||||
else:
|
||||
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
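# Illustrative result (editor's example): parse_daily_time("08:20") returns a cron config with
# minute="20", hour="8", day/month/day_of_week="*", plus the configured SCHEDULER_TIMEZONE.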
|
||||
|
||||
|
||||
async def GetBotId(bot: Bot, bot_id_match: Match[str] = AlconnaMatch("bot_id")) -> str:
|
||||
if bot_id_match.available:
|
||||
return bot_id_match.result
|
||||
return bot.self_id
|
||||
|
||||
|
||||
def GetTargeter(subcommand: str):
|
||||
"""
|
||||
依赖注入函数,用于解析命令参数并返回一个配置好的 ScheduleTargeter 实例。
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
all_enabled: Query[bool] = Query(f"{subcommand}.all"),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
) -> ScheduleTargeter:
|
||||
if schedule_id.available:
|
||||
return scheduler_manager.target(id=schedule_id.result)
|
||||
|
||||
if plugin_name.available:
|
||||
if all_enabled.available:
|
||||
return scheduler_manager.target(plugin_name=plugin_name.result)
|
||||
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
gid = group_id.result if group_id.available else current_group_id
|
||||
return scheduler_manager.target(
|
||||
plugin_name=plugin_name.result,
|
||||
group_id=str(gid) if gid else None,
|
||||
bot_id=bot_id_to_operate,
|
||||
)
|
||||
|
||||
if all_enabled.available:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
gid = group_id.result if group_id.available else current_group_id
|
||||
is_su = await SUPERUSER(bot, event)
|
||||
if not gid and not is_su:
|
||||
await schedule_cmd.finish(
|
||||
f"在私聊中对所有任务进行'{subcommand}'操作需要超级用户权限。"
|
||||
)
|
||||
|
||||
if (gid and str(gid).lower() == "all") or (not gid and is_su):
|
||||
return scheduler_manager.target()
|
||||
|
||||
return scheduler_manager.target(
|
||||
group_id=str(gid) if gid else None, bot_id=bot_id_to_operate
|
||||
)
|
||||
|
||||
await schedule_cmd.finish(
|
||||
f"'{subcommand}'操作失败:请提供任务ID,"
|
||||
f"或通过 -p <插件名> 或 -all 指定要操作的任务。"
|
||||
)
|
||||
|
||||
return Depends(dependency)
|
||||
380
zhenxun/builtin_plugins/scheduler_admin/handlers.py
Normal file
@ -0,0 +1,380 @@
|
||||
from datetime import datetime
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import AlconnaMatch, Arparma, Match, Query
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.scheduler import scheduler_manager
|
||||
from zhenxun.services.scheduler.targeter import ScheduleTargeter
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from . import presenters
|
||||
from .commands import (
|
||||
GetBotId,
|
||||
GetTargeter,
|
||||
parse_daily_time,
|
||||
parse_interval,
|
||||
schedule_cmd,
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.handle()
|
||||
async def _handle_time_options_mutex(arp: Arparma):
|
||||
time_options = ["cron", "interval", "date", "daily"]
|
||||
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
||||
if len(provided_options) > 1:
|
||||
await schedule_cmd.finish(
|
||||
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("查看")
|
||||
async def handle_view(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
||||
all_groups: Query[bool] = Query("查看.all"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
page: Match[int] = AlconnaMatch("page"),
|
||||
):
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
title = ""
|
||||
gid_filter = None
|
||||
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
||||
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
||||
|
||||
if all_groups.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
||||
title = "所有群组的定时任务"
|
||||
elif target_group_id.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
||||
gid_filter = target_group_id.result
|
||||
title = f"群 {gid_filter} 的定时任务"
|
||||
else:
|
||||
gid_filter = str(current_group_id)
|
||||
title = "本群的定时任务"
|
||||
|
||||
p_name_filter = plugin_name.result if plugin_name.available else None
|
||||
|
||||
schedules = await scheduler_manager.get_schedules(
|
||||
plugin_name=p_name_filter, group_id=gid_filter
|
||||
)
|
||||
|
||||
if p_name_filter:
|
||||
title += f" [插件: {p_name_filter}]"
|
||||
|
||||
if not schedules:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
img = await presenters.format_schedule_list_as_image(
|
||||
schedules=schedules, title=title, current_page=page.result
|
||||
)
|
||||
await MessageUtils.build_message(img).send(reply_to=True)
|
||||
|
||||
|
||||
@schedule_cmd.assign("设置")
|
||||
async def handle_set(
|
||||
event: Event,
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
|
||||
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
|
||||
date_expr: Match[str] = AlconnaMatch("date_expr"),
|
||||
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
|
||||
all_enabled: Query[bool] = Query("设置.all"),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if not plugin_name.available:
|
||||
await schedule_cmd.finish("设置任务时必须提供插件名称。")
|
||||
|
||||
has_time_option = any(
|
||||
[
|
||||
cron_expr.available,
|
||||
interval_expr.available,
|
||||
date_expr.available,
|
||||
daily_expr.available,
|
||||
]
|
||||
)
|
||||
if not has_time_option:
|
||||
await schedule_cmd.finish(
|
||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||
)
|
||||
|
||||
p_name = plugin_name.result
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{p_name}' 没有注册可用的定时任务。\n"
|
||||
f"可用插件: {list(scheduler_manager.get_registered_plugins())}"
|
||||
)
|
||||
|
||||
trigger_type, trigger_config = "", {}
|
||||
try:
|
||||
if cron_expr.available:
|
||||
trigger_type, trigger_config = (
|
||||
"cron",
|
||||
dict(
|
||||
zip(
|
||||
["minute", "hour", "day", "month", "day_of_week"],
|
||||
cron_expr.result.split(),
|
||||
)
|
||||
),
|
||||
)
|
||||
elif interval_expr.available:
|
||||
trigger_type, trigger_config = (
|
||||
"interval",
|
||||
parse_interval(interval_expr.result),
|
||||
)
|
||||
elif date_expr.available:
|
||||
trigger_type, trigger_config = (
|
||||
"date",
|
||||
{"run_date": datetime.fromisoformat(date_expr.result)},
|
||||
)
|
||||
elif daily_expr.available:
|
||||
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||
)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = {}
|
||||
if kwargs_str.available:
|
||||
task_meta = scheduler_manager._registered_tasks[p_name]
|
||||
params_model = task_meta.get("model")
|
||||
if not (
|
||||
params_model
|
||||
and isinstance(params_model, type)
|
||||
and issubclass(params_model, BaseModel)
|
||||
):
|
||||
await schedule_cmd.finish(f"插件 '{p_name}' 不支持或配置了无效的参数模型。")
|
||||
try:
|
||||
raw_kwargs = dict(
|
||||
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
|
||||
)
|
||||
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持验证")
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持导出")
|
||||
|
||||
job_kwargs = model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{p_name}' 的任务参数验证失败:\n" + "\n".join(errors)
|
||||
)
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
gid_str = group_id.result if group_id.available else None
|
||||
target_group_id = (
|
||||
scheduler_manager.ALL_GROUPS
|
||||
if (gid_str and gid_str.lower() == "all") or all_enabled.available
|
||||
else gid_str or getattr(event, "group_id", None)
|
||||
)
|
||||
if not target_group_id:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
||||
)
|
||||
|
||||
schedule = await scheduler_manager.add_schedule(
|
||||
p_name,
|
||||
str(target_group_id),
|
||||
trigger_type,
|
||||
trigger_config,
|
||||
job_kwargs,
|
||||
bot_id=bot_id_to_operate,
|
||||
)
|
||||
|
||||
target_desc = (
|
||||
f"所有群组 (Bot: {bot_id_to_operate})"
|
||||
if target_group_id == scheduler_manager.ALL_GROUPS
|
||||
else f"群组 {target_group_id}"
|
||||
)
|
||||
|
||||
if schedule:
|
||||
await schedule_cmd.finish(
|
||||
f"为 [{target_desc}] 已成功设置插件 '{p_name}' 的定时任务 "
|
||||
f"(ID: {schedule.id})。"
|
||||
)
|
||||
else:
|
||||
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("删除")
|
||||
async def handle_delete(targeter: ScheduleTargeter = GetTargeter("删除")):
|
||||
schedules_to_remove: list[ScheduleInfo] = await targeter._get_schedules()
|
||||
if not schedules_to_remove:
|
||||
await schedule_cmd.finish("没有找到可删除的任务。")
|
||||
|
||||
count, _ = await targeter.remove()
|
||||
|
||||
if count > 0 and schedules_to_remove:
|
||||
if len(schedules_to_remove) == 1:
|
||||
message = presenters.format_remove_success(schedules_to_remove[0])
|
||||
else:
|
||||
target_desc = targeter._generate_target_description()
|
||||
message = f"✅ 成功移除了{target_desc} {count} 个任务。"
|
||||
else:
|
||||
message = "没有任务被移除。"
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("暂停")
|
||||
async def handle_pause(targeter: ScheduleTargeter = GetTargeter("暂停")):
|
||||
schedules_to_pause: list[ScheduleInfo] = await targeter._get_schedules()
|
||||
if not schedules_to_pause:
|
||||
await schedule_cmd.finish("没有找到可暂停的任务。")
|
||||
|
||||
count, _ = await targeter.pause()
|
||||
|
||||
if count > 0 and schedules_to_pause:
|
||||
if len(schedules_to_pause) == 1:
|
||||
message = presenters.format_pause_success(schedules_to_pause[0])
|
||||
else:
|
||||
target_desc = targeter._generate_target_description()
|
||||
message = f"✅ 成功暂停了{target_desc} {count} 个任务。"
|
||||
else:
|
||||
message = "没有任务被暂停。"
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("恢复")
|
||||
async def handle_resume(targeter: ScheduleTargeter = GetTargeter("恢复")):
|
||||
schedules_to_resume: list[ScheduleInfo] = await targeter._get_schedules()
|
||||
if not schedules_to_resume:
|
||||
await schedule_cmd.finish("没有找到可恢复的任务。")
|
||||
|
||||
count, _ = await targeter.resume()
|
||||
|
||||
if count > 0 and schedules_to_resume:
|
||||
if len(schedules_to_resume) == 1:
|
||||
message = presenters.format_resume_success(schedules_to_resume[0])
|
||||
else:
|
||||
target_desc = targeter._generate_target_description()
|
||||
message = f"✅ 成功恢复了{target_desc} {count} 个任务。"
|
||||
else:
|
||||
message = "没有任务被恢复。"
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("执行")
|
||||
async def handle_trigger(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
|
||||
from zhenxun.services.scheduler.repository import ScheduleRepository
|
||||
|
||||
schedule_info = await ScheduleRepository.get_by_id(schedule_id.result)
|
||||
if not schedule_info:
|
||||
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id.result} 的任务。")
|
||||
|
||||
success, message = await scheduler_manager.trigger_now(schedule_id.result)
|
||||
|
||||
if success:
|
||||
final_message = presenters.format_trigger_success(schedule_info)
|
||||
else:
|
||||
final_message = f"❌ 手动触发失败: {message}"
|
||||
await schedule_cmd.finish(final_message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("更新")
|
||||
async def handle_update(
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
|
||||
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
|
||||
date_expr: Match[str] = AlconnaMatch("date_expr"),
|
||||
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
|
||||
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
|
||||
):
|
||||
if not any(
|
||||
[
|
||||
cron_expr.available,
|
||||
interval_expr.available,
|
||||
date_expr.available,
|
||||
daily_expr.available,
|
||||
kwargs_str.available,
|
||||
]
|
||||
):
|
||||
await schedule_cmd.finish(
|
||||
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
||||
)
|
||||
|
||||
trigger_type, trigger_config, job_kwargs = None, None, None
|
||||
try:
|
||||
if cron_expr.available:
|
||||
trigger_type, trigger_config = (
|
||||
"cron",
|
||||
dict(
|
||||
zip(
|
||||
["minute", "hour", "day", "month", "day_of_week"],
|
||||
cron_expr.result.split(),
|
||||
)
|
||||
),
|
||||
)
|
||||
elif interval_expr.available:
|
||||
trigger_type, trigger_config = (
|
||||
"interval",
|
||||
parse_interval(interval_expr.result),
|
||||
)
|
||||
elif date_expr.available:
|
||||
trigger_type, trigger_config = (
|
||||
"date",
|
||||
{"run_date": datetime.fromisoformat(date_expr.result)},
|
||||
)
|
||||
elif daily_expr.available:
|
||||
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
if kwargs_str.available:
|
||||
job_kwargs = dict(
|
||||
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
|
||||
)
|
||||
|
||||
success, message = await scheduler_manager.update_schedule(
|
||||
schedule_id.result, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
|
||||
if success:
|
||||
from zhenxun.services.scheduler.repository import ScheduleRepository
|
||||
|
||||
updated_schedule = await ScheduleRepository.get_by_id(schedule_id.result)
|
||||
if updated_schedule:
|
||||
final_message = presenters.format_update_success(updated_schedule)
|
||||
else:
|
||||
final_message = "✅ 更新成功,但无法获取更新后的任务详情。"
|
||||
else:
|
||||
final_message = f"❌ 更新失败: {message}"
|
||||
|
||||
await schedule_cmd.finish(final_message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("插件列表")
|
||||
async def handle_plugins_list():
|
||||
message = await presenters.format_plugins_list()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("状态")
|
||||
async def handle_status(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
|
||||
status = await scheduler_manager.get_schedule_status(schedule_id.result)
|
||||
if not status:
|
||||
await schedule_cmd.finish(f"未找到ID为 {schedule_id.result} 的定时任务。")
|
||||
|
||||
message = presenters.format_single_status_message(status)
|
||||
await schedule_cmd.finish(message)
|
||||
274
zhenxun/builtin_plugins/scheduler_admin/presenters.py
Normal file
@ -0,0 +1,274 @@
|
||||
import asyncio
|
||||
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.scheduler import scheduler_manager
|
||||
from zhenxun.utils._image_template import ImageTemplate, RowStyle
|
||||
|
||||
|
||||
def _get_type_name(annotation) -> str:
|
||||
"""获取类型注解的名称"""
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
elif hasattr(annotation, "_name"):
|
||||
return annotation._name
|
||||
else:
|
||||
return str(annotation)
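# Illustrative (editor's example): _get_type_name(int) -> "int"; annotations without a usable
# name fall back to their str() representation.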
|
||||
|
||||
|
||||
def _format_trigger(schedule: dict) -> str:
|
||||
"""格式化触发器信息为可读字符串"""
|
||||
trigger_type = schedule.get("trigger_type")
|
||||
config = schedule.get("trigger_config")
|
||||
|
||||
if not isinstance(config, dict):
|
||||
return f"配置错误: {config}"
|
||||
|
||||
if trigger_type == "cron":
|
||||
hour = config.get("hour", "??")
|
||||
minute = config.get("minute", "??")
|
||||
try:
|
||||
hour_int = int(hour)
|
||||
minute_int = int(minute)
|
||||
return f"每天 {hour_int:02d}:{minute_int:02d}"
|
||||
except (ValueError, TypeError):
|
||||
return f"每天 {hour}:{minute}"
|
||||
elif trigger_type == "interval":
|
||||
units = {
|
||||
"weeks": "周",
|
||||
"days": "天",
|
||||
"hours": "小时",
|
||||
"minutes": "分钟",
|
||||
"seconds": "秒",
|
||||
}
|
||||
for unit, unit_name in units.items():
|
||||
if value := config.get(unit):
|
||||
return f"每 {value} {unit_name}"
|
||||
return "未知间隔"
|
||||
elif trigger_type == "date":
|
||||
run_date = config.get("run_date", "N/A")
|
||||
return f"特定时间 {run_date}"
|
||||
else:
|
||||
return f"未知触发器类型: {trigger_type}"
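# Illustrative outputs (editor's examples): a cron config {"hour": "8", "minute": "30"} renders as
# "每天 08:30"; an interval config {"minutes": 30} renders as "每 30 分钟".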
|
||||
|
||||
|
||||
def _format_trigger_for_card(schedule_info: ScheduleInfo | dict) -> str:
|
||||
"""为信息卡片格式化触发器规则"""
|
||||
trigger_type = (
|
||||
schedule_info.get("trigger_type")
|
||||
if isinstance(schedule_info, dict)
|
||||
else schedule_info.trigger_type
|
||||
)
|
||||
config = (
|
||||
schedule_info.get("trigger_config")
|
||||
if isinstance(schedule_info, dict)
|
||||
else schedule_info.trigger_config
|
||||
)
|
||||
|
||||
if not isinstance(config, dict):
|
||||
return f"配置错误: {config}"
|
||||
|
||||
if trigger_type == "cron":
|
||||
hour = config.get("hour", "??")
|
||||
minute = config.get("minute", "??")
|
||||
try:
|
||||
hour_int = int(hour)
|
||||
minute_int = int(minute)
|
||||
return f"每天 {hour_int:02d}:{minute_int:02d}"
|
||||
except (ValueError, TypeError):
|
||||
return f"每天 {hour}:{minute}"
|
||||
elif trigger_type == "interval":
|
||||
units = {
|
||||
"weeks": "周",
|
||||
"days": "天",
|
||||
"hours": "小时",
|
||||
"minutes": "分钟",
|
||||
"seconds": "秒",
|
||||
}
|
||||
for unit, unit_name in units.items():
|
||||
if value := config.get(unit):
|
||||
return f"每 {value} {unit_name}"
|
||||
return "未知间隔"
|
||||
elif trigger_type == "date":
|
||||
run_date = config.get("run_date", "N/A")
|
||||
return f"特定时间 {run_date}"
|
||||
else:
|
||||
return f"未知规则: {trigger_type}"
|
||||
|
||||
|
||||
def _format_operation_result_card(
|
||||
title: str, schedule_info: ScheduleInfo, extra_info: list[str] | None = None
|
||||
) -> str:
|
||||
"""
|
||||
生成一个标准的操作结果信息卡片。
|
||||
|
||||
参数:
|
||||
title: 卡片的标题 (例如 "✅ 成功暂停定时任务!")
|
||||
schedule_info: 相关的 ScheduleInfo 对象
|
||||
extra_info: (可选) 额外的补充信息行
|
||||
"""
|
||||
target_desc = (
|
||||
f"群组 {schedule_info.group_id}"
|
||||
if schedule_info.group_id
|
||||
and schedule_info.group_id != scheduler_manager.ALL_GROUPS
|
||||
else "所有群组"
|
||||
if schedule_info.group_id == scheduler_manager.ALL_GROUPS
|
||||
else "全局"
|
||||
)
|
||||
|
||||
info_lines = [
|
||||
title,
|
||||
f"✓ 任务 ID: {schedule_info.id}",
|
||||
f"🖋 插件: {schedule_info.plugin_name}",
|
||||
f"🎯 目标: {target_desc}",
|
||||
f"⏰ 时间: {_format_trigger_for_card(schedule_info)}",
|
||||
]
|
||||
if extra_info:
|
||||
info_lines.extend(extra_info)
|
||||
|
||||
return "\n".join(info_lines)
|
||||
|
||||
|
||||
def format_pause_success(schedule_info: ScheduleInfo) -> str:
|
||||
"""格式化暂停成功的消息"""
|
||||
return _format_operation_result_card("✅ 成功暂停定时任务!", schedule_info)
|
||||
|
||||
|
||||
def format_resume_success(schedule_info: ScheduleInfo) -> str:
|
||||
"""格式化恢复成功的消息"""
|
||||
return _format_operation_result_card("▶️ 成功恢复定时任务!", schedule_info)
|
||||
|
||||
|
||||
def format_remove_success(schedule_info: ScheduleInfo) -> str:
|
||||
"""格式化删除成功的消息"""
|
||||
return _format_operation_result_card("❌ 成功删除定时任务!", schedule_info)
|
||||
|
||||
|
||||
def format_trigger_success(schedule_info: ScheduleInfo) -> str:
|
||||
"""格式化手动触发成功的消息"""
|
||||
return _format_operation_result_card("🚀 成功手动触发定时任务!", schedule_info)
|
||||
|
||||
|
||||
def format_update_success(schedule_info: ScheduleInfo) -> str:
|
||||
"""格式化更新成功的消息"""
|
||||
return _format_operation_result_card("🔄️ 成功更新定时任务配置!", schedule_info)
|
||||
|
||||
|
||||
def _status_row_style(column: str, text: str) -> RowStyle:
|
||||
"""为状态列设置颜色"""
|
||||
style = RowStyle()
|
||||
if column == "状态":
|
||||
if text == "启用":
|
||||
style.font_color = "#67C23A"
|
||||
elif text == "暂停":
|
||||
style.font_color = "#F56C6C"
|
||||
elif text == "运行中":
|
||||
style.font_color = "#409EFF"
|
||||
return style
|
||||
|
||||
|
||||
def _format_params(schedule_status: dict) -> str:
|
||||
"""将任务参数格式化为人类可读的字符串"""
|
||||
if kwargs := schedule_status.get("job_kwargs"):
|
||||
return " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
||||
return "-"
|
||||
|
||||
|
||||
async def format_schedule_list_as_image(
|
||||
schedules: list[ScheduleInfo], title: str, current_page: int
|
||||
):
|
||||
"""将任务列表格式化为图片"""
|
||||
page_size = 15
|
||||
total_items = len(schedules)
|
||||
total_pages = (total_items + page_size - 1) // page_size
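# Ceiling division (editor's note): e.g. 31 tasks with page_size 15 -> 3 pages.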
|
||||
start_index = (current_page - 1) * page_size
|
||||
end_index = start_index + page_size
|
||||
paginated_schedules = schedules[start_index:end_index]
|
||||
|
||||
if not paginated_schedules:
|
||||
return "这一页没有内容了哦~"
|
||||
|
||||
status_tasks = [
|
||||
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
||||
]
|
||||
all_statuses = await asyncio.gather(*status_tasks)
|
||||
|
||||
def get_status_text(status_value):
|
||||
if isinstance(status_value, bool):
|
||||
return "启用" if status_value else "暂停"
|
||||
return str(status_value)
|
||||
|
||||
data_list = [
|
||||
[
|
||||
s["id"],
|
||||
s["plugin_name"],
|
||||
s.get("bot_id") or "N/A",
|
||||
s["group_id"] or "全局",
|
||||
s["next_run_time"],
|
||||
_format_trigger(s),
|
||||
_format_params(s),
|
||||
get_status_text(s["is_enabled"]),
|
||||
]
|
||||
for s in all_statuses
|
||||
if s
|
||||
]
|
||||
|
||||
if not data_list:
|
||||
return "没有找到任何相关的定时任务。"
|
||||
|
||||
return await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
||||
column_name=["ID", "插件", "Bot", "目标", "下次运行", "规则", "参数", "状态"],
|
||||
data_list=data_list,
|
||||
column_space=20,
|
||||
text_style=_status_row_style,
|
||||
)
|
||||
|
||||
|
||||
def format_single_status_message(status: dict) -> str:
|
||||
"""格式化单个任务状态为文本消息"""
|
||||
info_lines = [
|
||||
f"📋 定时任务详细信息 (ID: {status['id']})",
|
||||
"--------------------",
|
||||
f"▫️ 插件: {status['plugin_name']}",
|
||||
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
||||
f"▫️ 目标: {status['group_id'] or '全局'}",
|
||||
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
||||
f"▫️ 下次运行: {status['next_run_time']}",
|
||||
f"▫️ 触发规则: {_format_trigger(status)}",
|
||||
f"▫️ 任务参数: {_format_params(status)}",
|
||||
]
|
||||
return "\n".join(info_lines)
|
||||
|
||||
|
||||
async def format_plugins_list() -> str:
|
||||
"""格式化可用插件列表为文本消息"""
|
||||
from pydantic import BaseModel
|
||||
|
||||
registered_plugins = scheduler_manager.get_registered_plugins()
|
||||
if not registered_plugins:
|
||||
return "当前没有已注册的定时任务插件。"
|
||||
|
||||
message_parts = ["📋 已注册的定时任务插件:"]
|
||||
for i, plugin_name in enumerate(registered_plugins, 1):
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
|
||||
param_info_str = "无参数"
|
||||
if (
|
||||
params_model
|
||||
and isinstance(params_model, type)
|
||||
and issubclass(params_model, BaseModel)
|
||||
):
|
||||
model_fields = getattr(params_model, "model_fields", None)
|
||||
if model_fields:
|
||||
param_info_str = "参数: " + ", ".join(
|
||||
f"{field_name}({_get_type_name(field_info.annotation)})"
|
||||
for field_name, field_info in model_fields.items()
|
||||
)
|
||||
elif params_model:
|
||||
param_info_str = "⚠️ 参数模型配置错误"
|
||||
|
||||
message_parts.append(f"{i}. {plugin_name} - {param_info_str}")
|
||||
|
||||
return "\n".join(message_parts)
|
||||
@ -1,5 +1,3 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from tortoise.functions import Count
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
@ -10,6 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
|
||||
from zhenxun.utils.echart_utils.models import Barh
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.image_utils import BuildImage
|
||||
from zhenxun.utils.utils import TimeUtils
|
||||
|
||||
|
||||
class StatisticsManage:
|
||||
@ -66,8 +65,7 @@ class StatisticsManage:
|
||||
if plugin_name:
|
||||
query = query.filter(plugin_name=plugin_name)
|
||||
if day:
|
||||
time = datetime.now() - timedelta(days=day)
|
||||
query = query.filter(create_time__gte=time)
|
||||
query = query.filter(create_time__gte=TimeUtils.get_day_start())
|
||||
data_list = (
|
||||
await query.annotate(count=Count("id"))
|
||||
.group_by("plugin_name")
|
||||
@ -87,8 +85,7 @@ class StatisticsManage:
|
||||
if group_id:
|
||||
query = query.filter(group_id=group_id)
|
||||
if day:
|
||||
time = datetime.now() - timedelta(days=day)
|
||||
query = query.filter(create_time__gte=time)
|
||||
query = query.filter(create_time__gte=TimeUtils.get_day_start())
|
||||
data_list = (
|
||||
await query.annotate(count=Count("id"))
|
||||
.group_by("plugin_name")
|
||||
@ -104,8 +101,7 @@ class StatisticsManage:
|
||||
async def get_group_statistics(cls, group_id: str, day: int | None, title: str):
|
||||
query = Statistics.filter(group_id=group_id)
|
||||
if day:
|
||||
time = datetime.now() - timedelta(days=day)
|
||||
query = query.filter(create_time__gte=time)
|
||||
query = query.filter(create_time__gte=TimeUtils.get_day_start())
|
||||
data_list = (
|
||||
await query.annotate(count=Count("id"))
|
||||
.group_by("plugin_name")
|
||||
|
||||
@ -28,7 +28,7 @@ from nonebot_plugin_alconna.uniseg.segment import (
|
||||
)
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
|
||||
from zhenxun.configs.utils import PluginExtraData, Task
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
@ -73,16 +73,6 @@ __plugin_meta__ = PluginMetadata(
|
||||
author="HibiKier",
|
||||
version="1.2",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
module="_task",
|
||||
key="DEFAULT_BROADCAST",
|
||||
value=True,
|
||||
help="被动 广播 进群默认开关状态",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
],
|
||||
tasks=[Task(module="broadcast", name="广播")],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse
|
||||
from zhenxun.models.plugin_info import PluginInfo as DbPluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
|
||||
|
||||
from ....base_model import Result
|
||||
from ....utils import authentication, clear_help_image
|
||||
@ -11,6 +12,7 @@ from .data_source import ApiDataSource
|
||||
from .model import (
|
||||
BatchUpdatePlugins,
|
||||
BatchUpdateResult,
|
||||
InstallDependenciesPayload,
|
||||
PluginCount,
|
||||
PluginDetail,
|
||||
PluginInfo,
|
||||
@ -162,9 +164,9 @@ async def _(module: str) -> Result[PluginDetail]:
|
||||
dependencies=[authentication()],
|
||||
response_model=Result[BatchUpdateResult],
|
||||
response_class=JSONResponse,
|
||||
summary="批量更新插件配置",
|
||||
description="批量更新插件配置",
|
||||
)
|
||||
async def batch_update_plugin_config_api(
|
||||
async def _(
|
||||
params: BatchUpdatePlugins,
|
||||
) -> Result[BatchUpdateResult]:
|
||||
"""批量更新插件配置,如开关、类型等"""
|
||||
@ -187,9 +189,9 @@ async def batch_update_plugin_config_api(
|
||||
"/menu_type/rename",
|
||||
dependencies=[authentication()],
|
||||
response_model=Result,
|
||||
summary="重命名菜单类型",
|
||||
description="重命名菜单类型",
|
||||
)
|
||||
async def rename_menu_type_api(payload: RenameMenuTypePayload) -> Result:
|
||||
async def _(payload: RenameMenuTypePayload) -> Result[str]:
|
||||
try:
|
||||
result = await ApiDataSource.rename_menu_type(
|
||||
old_name=payload.old_name, new_name=payload.new_name
|
||||
@ -213,3 +215,24 @@ async def rename_menu_type_api(payload: RenameMenuTypePayload) -> Result:
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/menu_type/rename 调用错误", "WebUi", e=e)
|
||||
return Result.fail(info=f"发生未知错误: {type(e).__name__}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/install_dependencies",
|
||||
dependencies=[authentication()],
|
||||
response_model=Result,
|
||||
response_class=JSONResponse,
|
||||
description="安装/卸载依赖",
|
||||
)
|
||||
async def _(payload: InstallDependenciesPayload) -> Result:
|
||||
try:
|
||||
if not payload.dependencies:
|
||||
return Result.fail("依赖列表不能为空")
|
||||
if payload.handle_type == "install":
|
||||
result = VirtualEnvPackageManager.install(payload.dependencies)
|
||||
else:
|
||||
result = VirtualEnvPackageManager.uninstall(payload.dependencies)
|
||||
return Result.ok(result)
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/install_dependencies 调用错误", "WebUi", e=e)
|
||||
return Result.fail(f"发生了一点错误捏 {type(e)}: {e}")
|
||||
|
||||
@ -167,7 +167,7 @@ class ApiDataSource:
|
||||
)
|
||||
|
||||
return {
|
||||
"success": len(errors) == 0,
|
||||
"success": not errors,
|
||||
"updated_count": updated_count + bulk_updated_count,
|
||||
"errors": errors,
|
||||
}
|
||||
@ -184,19 +184,24 @@ class ApiDataSource:
|
||||
config: ConfigGroup
|
||||
|
||||
返回:
|
||||
lPluginConfig: 配置数据
|
||||
PluginConfig: 配置数据
|
||||
"""
|
||||
type_str = ""
|
||||
type_inner = None
|
||||
if r := re.search(r"<class '(.*)'>", str(config.configs[cfg].type)):
|
||||
ct = str(config.configs[cfg].type)
|
||||
if r := re.search(r"<class '(.*)'>", ct):
|
||||
type_str = r[1]
|
||||
elif r := re.search(r"typing\.(.*)\[(.*)\]", str(config.configs[cfg].type)):
|
||||
elif (r := re.search(r"typing\.(.*)\[(.*)\]", ct)) or (
|
||||
r := re.search(r"(.*)\[(.*)\]", ct)
|
||||
):
|
||||
type_str = r[1]
|
||||
if type_str:
|
||||
type_str = type_str.lower()
|
||||
type_inner = r[2]
|
||||
if type_inner:
|
||||
type_inner = [x.strip() for x in type_inner.split(",")]
|
||||
else:
|
||||
type_str = ct
|
||||
return PluginConfig(
|
||||
module=module,
|
||||
key=cfg,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -162,3 +162,15 @@ class BatchUpdateResult(BaseModel):
|
||||
default_factory=list, description="错误信息列表"
|
||||
)
|
||||
"""错误信息列表"""
|
||||
|
||||
|
||||
class InstallDependenciesPayload(BaseModel):
|
||||
"""
|
||||
安装依赖
|
||||
"""
|
||||
|
||||
handle_type: Literal["install", "uninstall"] = Field(..., description="处理类型")
|
||||
"""处理类型"""
|
||||
|
||||
dependencies: list[str] = Field(..., description="依赖列表")
|
||||
"""依赖列表"""
|
||||
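# Illustrative request body (editor's example; package names are hypothetical):
# {"handle_type": "install", "dependencies": ["httpx", "pillow"]}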
|
||||
@ -106,21 +106,34 @@ class ConfigGroup(BaseModel):
|
||||
if value_to_process is None:
|
||||
return default
|
||||
|
||||
if cfg.type:
|
||||
if _is_pydantic_type(cfg.type):
|
||||
if build_model:
|
||||
try:
|
||||
return parse_as(cfg.type, value_to_process)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e
|
||||
)
|
||||
if cfg.arg_parser:
|
||||
try:
|
||||
return cattrs.structure(value_to_process, cfg.type)
|
||||
return cfg.arg_parser(value_to_process)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
|
||||
logger.debug(
|
||||
f"配置项类型转换 MODULE: [<u><y>{self.module}</y></u>] | "
|
||||
f"KEY: [<u><y>{key}</y></u>] 的自定义解析器失败,将使用原始值",
|
||||
e=e,
|
||||
)
|
||||
return value_to_process
|
||||
|
||||
return value_to_process
|
||||
if not build_model or not cfg.type:
|
||||
return value_to_process
|
||||
|
||||
try:
|
||||
if _is_pydantic_type(cfg.type):
|
||||
parsed_value = parse_as(cfg.type, value_to_process)
|
||||
return parsed_value
|
||||
else:
|
||||
structured_value = cattrs.structure(value_to_process, cfg.type)
|
||||
return structured_value
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"❌ 配置项 '{self.module}.{key}' 自动类型转换失败 "
|
||||
f"(目标类型: {cfg.type}),将返回原始值。请检查配置文件格式。错误: {e}",
|
||||
e=e,
|
||||
)
|
||||
return value_to_process
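# Illustrative (editor's example): with cfg.type = list[int] and a raw YAML value of ["1", "2"],
# cattrs.structure coerces it to [1, 2]; if conversion fails, the raw value is returned unchanged.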
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
@ -167,6 +180,57 @@ class ConfigsManager:
|
||||
if data := self._data.get(module):
|
||||
data.name = name
|
||||
|
||||
def _merge_dicts(self, new_data: dict, original_data: dict) -> dict:
|
||||
"""合并两个字典,只进行key值的新增和删除操作,不修改原有key的值
|
||||
|
||||
递归处理嵌套字典,确保所有层级的key保持一致
|
||||
|
||||
参数:
|
||||
new_data: 新数据字典
|
||||
original_data: 原数据字典
|
||||
|
||||
返回:
|
||||
合并后的字典
|
||||
"""
|
||||
result = dict(original_data)
|
||||
|
||||
# 遍历新数据的键
|
||||
for key, value in new_data.items():
|
||||
# 如果键不在原数据中,添加它
|
||||
if key not in original_data:
|
||||
result[key] = value
|
||||
# 如果两边都是字典,递归处理
|
||||
elif isinstance(value, dict) and isinstance(original_data[key], dict):
|
||||
result[key] = self._merge_dicts(value, original_data[key])
|
||||
# 如果键已存在,保留原值,不覆盖
|
||||
# (不做任何操作,保持原值)
|
||||
|
||||
return result
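# Illustrative behaviour (editor's example):
# _merge_dicts({"a": 1, "b": {"x": 1}}, {"b": {"y": 2}}) -> {"b": {"y": 2, "x": 1}, "a": 1}
# new keys "a" and "b.x" are added, while the existing "b.y" keeps its original value.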
|
||||
|
||||
def _normalize_config_data(self, value: Any, original_value: Any = None) -> Any:
|
||||
"""标准化配置数据,处理BaseModel和字典的情况
|
||||
|
||||
参数:
|
||||
value: 要标准化的值
|
||||
original_value: 原始值,用于合并字典
|
||||
|
||||
返回:
|
||||
标准化后的值
|
||||
"""
|
||||
# 处理BaseModel
|
||||
processed_value = _dump_pydantic_obj(value)
|
||||
|
||||
# 如果处理后的值是字典,且原始值也存在
|
||||
if isinstance(processed_value, dict) and original_value is not None:
|
||||
# 处理原始值
|
||||
processed_original = _dump_pydantic_obj(original_value)
|
||||
|
||||
# 如果原始值也是字典,合并它们
|
||||
if isinstance(processed_original, dict):
|
||||
return self._merge_dicts(processed_value, processed_original)
|
||||
|
||||
return processed_value
|
||||
|
||||
def add_plugin_config(
|
||||
self,
|
||||
module: str,
|
||||
@ -195,16 +259,18 @@ class ConfigsManager:
|
||||
ValueError: module和key不能为空
|
||||
ValueError: 填写错误
|
||||
"""
|
||||
|
||||
key = key.upper()
|
||||
if not module or not key:
|
||||
raise ValueError("add_plugin_config: module和key不能为空")
|
||||
if isinstance(value, BaseModel):
|
||||
value = model_dump(value)
|
||||
if isinstance(default_value, BaseModel):
|
||||
default_value = model_dump(default_value)
|
||||
|
||||
processed_value = _dump_pydantic_obj(value)
|
||||
processed_default_value = _dump_pydantic_obj(default_value)
|
||||
# 获取现有配置值(如果存在)
|
||||
existing_value = None
|
||||
if module in self._data and (config := self._data[module].configs.get(key)):
|
||||
existing_value = config.value
|
||||
|
||||
# 标准化值和默认值
|
||||
processed_value = self._normalize_config_data(value, existing_value)
|
||||
processed_default_value = self._normalize_config_data(default_value)
|
||||
|
||||
self.add_module.append(f"{module}:{key}".lower())
|
||||
if module in self._data and (config := self._data[module].configs.get(key)):
|
||||
@ -338,14 +404,13 @@ class ConfigsManager:
|
||||
with open(self._simple_file, "w", encoding="utf8") as f:
|
||||
_yaml.dump(self._simple_data, f)
|
||||
path = path or self.file
|
||||
save_data = {}
|
||||
for module, config_group in self._data.items():
|
||||
save_data[module] = {}
|
||||
for config_key, config_model in config_group.configs.items():
|
||||
save_data[module][config_key] = model_dump(
|
||||
config_model, exclude={"type", "arg_parser"}
|
||||
)
|
||||
|
||||
save_data = {
|
||||
module: {
|
||||
config_key: model_dump(config_model, exclude={"type", "arg_parser"})
|
||||
for config_key, config_model in config_group.configs.items()
|
||||
}
|
||||
for module, config_group in self._data.items()
|
||||
}
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
_yaml.dump(save_data, f)
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ class RegisterConfig(BaseModel):
|
||||
"""配置注解"""
|
||||
default_value: Any | None = None
|
||||
"""默认值"""
|
||||
type: Any = None
|
||||
type: Any = str
|
||||
"""参数类型"""
|
||||
arg_parser: Callable | None = None
|
||||
"""参数解析"""
|
||||
|
||||
@ -49,7 +49,8 @@ class ChatHistory(Model):
|
||||
o = "-" if order == "DESC" else ""
|
||||
query = cls.filter(group_id=gid) if gid else cls
|
||||
if date_scope:
|
||||
query = query.filter(create_time__range=date_scope)
|
||||
filter_scope = (date_scope[0].isoformat(" "), date_scope[1].isoformat(" "))
|
||||
query = query.filter(create_time__range=filter_scope)
|
||||
return list(
|
||||
await query.annotate(count=Count("user_id"))
|
||||
.order_by(f"{o}count")
|
||||
|
||||
@ -99,6 +99,8 @@ class LevelUser(Model):
|
||||
返回:
|
||||
bool: 是否大于level
|
||||
"""
|
||||
if level == 0:
|
||||
return True
|
||||
if group_id:
|
||||
if user := await cls.get_or_none(user_id=user_id, group_id=group_id):
|
||||
return user.user_level >= level
|
||||
|
||||
@ -1,3 +1,14 @@
|
||||
"""
|
||||
Zhenxun Bot - 核心服务模块
|
||||
|
||||
主要服务包括:
|
||||
- 数据库上下文 (db_context): 提供数据库模型基类和连接管理。
|
||||
- 日志服务 (log): 提供增强的、带上下文的日志记录器。
|
||||
- LLM服务 (llm): 提供与大语言模型交互的统一API。
|
||||
- 插件生命周期管理 (plugin_init): 支持插件安装和卸载时的钩子函数。
|
||||
- 定时任务调度器 (scheduler): 提供持久化的、可管理的定时任务服务。
|
||||
"""
|
||||
|
||||
from nonebot import require
|
||||
|
||||
require("nonebot_plugin_apscheduler")
|
||||
@ -6,3 +17,33 @@ require("nonebot_plugin_session")
|
||||
require("nonebot_plugin_htmlrender")
|
||||
require("nonebot_plugin_uninfo")
|
||||
require("nonebot_plugin_waiter")
|
||||
|
||||
from .db_context import Model, disconnect
|
||||
from .llm import (
|
||||
AI,
|
||||
LLMContentPart,
|
||||
LLMException,
|
||||
LLMMessage,
|
||||
get_model_instance,
|
||||
list_available_models,
|
||||
tool_registry,
|
||||
)
|
||||
from .log import logger
|
||||
from .plugin_init import PluginInit, PluginInitManager
|
||||
from .scheduler import scheduler_manager
|
||||
|
||||
__all__ = [
|
||||
"AI",
|
||||
"LLMContentPart",
|
||||
"LLMException",
|
||||
"LLMMessage",
|
||||
"Model",
|
||||
"PluginInit",
|
||||
"PluginInitManager",
|
||||
"disconnect",
|
||||
"get_model_instance",
|
||||
"list_available_models",
|
||||
"logger",
|
||||
"scheduler_manager",
|
||||
"tool_registry",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
@ -10,10 +10,10 @@ from .api import (
|
||||
TaskType,
|
||||
analyze,
|
||||
analyze_multimodal,
|
||||
analyze_with_images,
|
||||
chat,
|
||||
code,
|
||||
embed,
|
||||
pipeline_chat,
|
||||
search,
|
||||
search_multimodal,
|
||||
)
|
||||
@ -35,6 +35,7 @@ from .manager import (
|
||||
list_model_identifiers,
|
||||
set_global_default_model_name,
|
||||
)
|
||||
from .tools import tool_registry
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
LLMContentPart,
|
||||
@ -43,6 +44,7 @@ from .types import (
|
||||
LLMMessage,
|
||||
LLMResponse,
|
||||
LLMTool,
|
||||
MCPCompatible,
|
||||
ModelDetail,
|
||||
ModelInfo,
|
||||
ModelProvider,
|
||||
@ -51,7 +53,7 @@ from .types import (
|
||||
ToolMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from .utils import create_multimodal_message, unimsg_to_llm_parts
|
||||
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
|
||||
|
||||
__all__ = [
|
||||
"AI",
|
||||
@ -65,6 +67,7 @@ __all__ = [
|
||||
"LLMMessage",
|
||||
"LLMResponse",
|
||||
"LLMTool",
|
||||
"MCPCompatible",
|
||||
"ModelDetail",
|
||||
"ModelInfo",
|
||||
"ModelName",
|
||||
@ -76,7 +79,6 @@ __all__ = [
|
||||
"UsageInfo",
|
||||
"analyze",
|
||||
"analyze_multimodal",
|
||||
"analyze_with_images",
|
||||
"chat",
|
||||
"clear_model_cache",
|
||||
"code",
|
||||
@ -88,9 +90,12 @@ __all__ = [
|
||||
"list_available_models",
|
||||
"list_embedding_models",
|
||||
"list_model_identifiers",
|
||||
"message_to_unimessage",
|
||||
"pipeline_chat",
|
||||
"register_llm_configs",
|
||||
"search",
|
||||
"search_multimodal",
|
||||
"set_global_default_model_name",
|
||||
"tool_registry",
|
||||
"unimsg_to_llm_parts",
|
||||
]
|
||||
|
||||
@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
|
||||
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
|
||||
from .gemini import GeminiAdapter
|
||||
from .openai import OpenAIAdapter
|
||||
from .zhipu import ZhipuAdapter
|
||||
|
||||
LLMAdapterFactory.initialize()
|
||||
|
||||
@ -20,7 +19,6 @@ __all__ = [
|
||||
"OpenAICompatAdapter",
|
||||
"RequestData",
|
||||
"ResponseData",
|
||||
"ZhipuAdapter",
|
||||
"get_adapter_for_api_type",
|
||||
"register_adapter",
|
||||
]
|
||||
|
||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||
from ..service import LLMModel
|
||||
from ..types.content import LLMMessage
|
||||
from ..types.enums import EmbeddingTaskType
|
||||
from ..types.models import LLMTool
|
||||
|
||||
|
||||
class RequestData(BaseModel):
|
||||
@ -60,7 +61,7 @@ class BaseAdapter(ABC):
|
||||
"""支持的API类型列表"""
|
||||
pass
|
||||
|
||||
def prepare_simple_request(
|
||||
async def prepare_simple_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
@ -86,7 +87,7 @@ class BaseAdapter(ABC):
|
||||
|
||||
config = model._generation_config
|
||||
|
||||
return self.prepare_advanced_request(
|
||||
return await self.prepare_advanced_request(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
@ -96,13 +97,13 @@ class BaseAdapter(ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def prepare_advanced_request(
|
||||
async def prepare_advanced_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
messages: list["LLMMessage"],
|
||||
config: "LLMGenerationConfig | None" = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: list["LLMTool"] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备高级请求"""
|
||||
@ -238,6 +239,9 @@ class BaseAdapter(ABC):
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content", "")
|
||||
|
||||
if content:
|
||||
content = content.strip()
|
||||
|
||||
parsed_tool_calls: list[LLMToolCall] | None = None
|
||||
if message_tool_calls := message.get("tool_calls"):
|
||||
from ..types.models import LLMToolFunction
|
||||
@ -375,7 +379,7 @@ class BaseAdapter(ABC):
|
||||
if model.temperature is not None:
|
||||
base_config["temperature"] = model.temperature
|
||||
if model.max_tokens is not None:
|
||||
if model.api_type in ["gemini", "gemini_native"]:
|
||||
if model.api_type == "gemini":
|
||||
base_config["maxOutputTokens"] = model.max_tokens
|
||||
else:
|
||||
base_config["max_tokens"] = model.max_tokens
|
||||
@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_endpoint(self) -> str:
|
||||
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||
"""子类必须实现,返回 chat completions 的端点"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_embedding_endpoint(self) -> str:
|
||||
def get_embedding_endpoint(self, model: "LLMModel") -> str:
|
||||
"""子类必须实现,返回 embeddings 的端点"""
|
||||
pass
|
||||
|
||||
def prepare_advanced_request(
|
||||
async def prepare_simple_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
prompt: str,
|
||||
history: list[dict[str, str]] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备简单文本生成请求 - OpenAI兼容API的通用实现"""
|
||||
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||
headers = self.get_base_headers(api_key)
|
||||
|
||||
messages = []
|
||||
if history:
|
||||
messages.extend(history)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
body = {
|
||||
"model": model.model_name,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
body = self.apply_config_override(model, body)
|
||||
|
||||
return RequestData(url=url, headers=headers, body=body)
|
||||
|
||||
async def prepare_advanced_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
messages: list["LLMMessage"],
|
||||
config: "LLMGenerationConfig | None" = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: list["LLMTool"] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备高级请求 - OpenAI兼容格式"""
|
||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
||||
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||
headers = self.get_base_headers(api_key)
|
||||
openai_messages = self.convert_messages_to_openai_format(messages)
|
||||
|
||||
@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
|
||||
}
|
||||
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
if tool.type == "function" and tool.function:
|
||||
openai_tools.append({"type": "function", "function": tool.function})
|
||||
elif tool.type == "mcp" and tool.mcp_session:
|
||||
if callable(tool.mcp_session):
|
||||
raise ValueError(
|
||||
"适配器接收到未激活的 MCP 会话工厂。"
|
||||
"会话工厂应该在 LLMModel.generate_response 中被激活。"
|
||||
)
|
||||
openai_tools.append(
|
||||
tool.mcp_session.to_api_tool(api_type=self.api_type)
|
||||
)
|
||||
if openai_tools:
|
||||
body["tools"] = openai_tools
|
||||
if tool_choice:
|
||||
body["tool_choice"] = tool_choice
|
||||
|
||||
@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
|
||||
is_advanced: bool = False,
|
||||
) -> ResponseData:
|
||||
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
|
||||
_ = model, is_advanced # 未使用的参数
|
||||
_ = model, is_advanced
|
||||
return self.parse_openai_response(response_json)
|
||||
|
||||
def prepare_embedding_request(
|
||||
@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
|
||||
**kwargs: Any,
|
||||
) -> RequestData:
|
||||
"""准备嵌入请求 - OpenAI兼容格式"""
|
||||
_ = task_type # 未使用的参数
|
||||
url = self.get_api_url(model, self.get_embedding_endpoint())
|
||||
_ = task_type
|
||||
url = self.get_api_url(model, self.get_embedding_endpoint(model))
|
||||
headers = self.get_base_headers(api_key)
|
||||
|
||||
body = {
|
||||
@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
|
||||
"input": texts,
|
||||
}
|
||||
|
||||
# 应用额外的配置参数
|
||||
if kwargs:
|
||||
body.update(kwargs)
|
||||
|
||||
|
||||
@ -22,10 +22,8 @@ class LLMAdapterFactory:
|
||||
|
||||
from .gemini import GeminiAdapter
|
||||
from .openai import OpenAIAdapter
|
||||
from .zhipu import ZhipuAdapter
|
||||
|
||||
cls.register_adapter(OpenAIAdapter())
|
||||
cls.register_adapter(ZhipuAdapter())
|
||||
cls.register_adapter(GeminiAdapter())
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
from ..service import LLMModel
|
||||
from ..types.content import LLMMessage
|
||||
from ..types.enums import EmbeddingTaskType
|
||||
from ..types.models import LLMToolCall
|
||||
from ..types.models import LLMTool, LLMToolCall
|
||||
|
||||
|
||||
class GeminiAdapter(BaseAdapter):
|
||||
@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):
|
||||
|
||||
return headers
|
||||
|
||||
def prepare_advanced_request(
|
||||
async def prepare_advanced_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
messages: list["LLMMessage"],
|
||||
config: "LLMGenerationConfig | None" = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: list["LLMTool"] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备高级请求"""
|
||||
return self._prepare_request(
|
||||
model, api_key, messages, config, tools, tool_choice
|
||||
)
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
messages: list["LLMMessage"],
|
||||
config: "LLMGenerationConfig | None" = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备 Gemini API 请求 - 支持所有高级功能"""
|
||||
effective_config = config if config is not None else model._generation_config
|
||||
|
||||
endpoint = self._get_gemini_endpoint(model, effective_config)
|
||||
@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
|
||||
system_instruction_parts = [{"text": msg.content}]
|
||||
elif isinstance(msg.content, list):
|
||||
system_instruction_parts = [
|
||||
part.convert_for_api("gemini") for part in msg.content
|
||||
await part.convert_for_api_async("gemini")
|
||||
for part in msg.content
|
||||
]
|
||||
continue
|
||||
|
||||
@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
|
||||
current_parts.append({"text": msg.content})
|
||||
elif isinstance(msg.content, list):
|
||||
for part_obj in msg.content:
|
||||
current_parts.append(part_obj.convert_for_api("gemini"))
|
||||
current_parts.append(
|
||||
await part_obj.convert_for_api_async("gemini")
|
||||
)
|
||||
gemini_contents.append({"role": "user", "parts": current_parts})
|
||||
|
||||
elif msg.role == "assistant" or msg.role == "model":
|
||||
@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
|
||||
current_parts.append({"text": msg.content})
|
||||
elif isinstance(msg.content, list):
|
||||
for part_obj in msg.content:
|
||||
current_parts.append(part_obj.convert_for_api("gemini"))
|
||||
current_parts.append(
|
||||
await part_obj.convert_for_api_async("gemini")
|
||||
)
|
||||
|
||||
if msg.tool_calls:
|
||||
import json
|
||||
@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):
|
||||
|
||||
all_tools_for_request = []
|
||||
if tools:
|
||||
for tool_item in tools:
|
||||
if isinstance(tool_item, dict):
|
||||
if "name" in tool_item and "description" in tool_item:
|
||||
all_tools_for_request.append(
|
||||
{"functionDeclarations": [tool_item]}
|
||||
for tool in tools:
|
||||
if tool.type == "function" and tool.function:
|
||||
all_tools_for_request.append(
|
||||
{"functionDeclarations": [tool.function]}
|
||||
)
|
||||
elif tool.type == "mcp" and tool.mcp_session:
|
||||
if callable(tool.mcp_session):
|
||||
raise ValueError(
|
||||
"适配器接收到未激活的 MCP 会话工厂。"
|
||||
"会话工厂应该在 LLMModel.generate_response 中被激活。"
|
||||
)
|
||||
else:
|
||||
all_tools_for_request.append(tool_item)
|
||||
else:
|
||||
all_tools_for_request.append(tool_item)
|
||||
all_tools_for_request.append(
|
||||
tool.mcp_session.to_api_tool(api_type=self.api_type)
|
||||
)
|
||||
elif tool.type == "google_search":
|
||||
all_tools_for_request.append({"googleSearch": {}})
|
||||
|
||||
if effective_config:
|
||||
if getattr(effective_config, "enable_grounding", False):
|
||||
@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
|
||||
logger.debug("隐式启用代码执行工具。")
|
||||
|
||||
if all_tools_for_request:
|
||||
gemini_api_tools = self._convert_tools_to_gemini_format(
|
||||
all_tools_for_request
|
||||
)
|
||||
if gemini_api_tools:
|
||||
body["tools"] = gemini_api_tools
|
||||
body["tools"] = all_tools_for_request
|
||||
|
||||
final_tool_choice = tool_choice
|
||||
if final_tool_choice is None and effective_config:
|
||||
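For reference, a standalone sketch of the `tools` payload shape the Gemini request now carries after this change, with function declarations and googleSearch passed through directly (values are illustrative; the real request is assembled by the adapter above):

from typing import Any

def build_gemini_tools(
    function_decls: list[dict[str, Any]], enable_search: bool = False
) -> list[dict[str, Any]]:
    """Assemble the Gemini-style tools list: functionDeclarations plus googleSearch."""
    tools: list[dict[str, Any]] = []
    if function_decls:
        tools.append({"functionDeclarations": function_decls})
    if enable_search:
        tools.append({"googleSearch": {}})
    return tools

body_fragment = {
    "tools": build_gemini_tools(
        [{"name": "get_weather", "description": "查询天气", "parameters": {}}],
        enable_search=True,
    )
}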
@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):
|
||||
|
||||
return f"/v1beta/models/{model.model_name}:generateContent"
|
||||
|
||||
def _convert_tools_to_gemini_format(
|
||||
self, tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""转换工具格式为Gemini格式"""
|
||||
gemini_tools = []
|
||||
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
gemini_tool = {
|
||||
"functionDeclarations": [
|
||||
{
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
]
|
||||
}
|
||||
gemini_tools.append(gemini_tool)
|
||||
elif tool.get("type") == "code_execution":
|
||||
gemini_tools.append(
|
||||
{"codeExecution": {"language": tool.get("language", "python")}}
|
||||
)
|
||||
elif tool.get("type") == "google_search":
|
||||
gemini_tools.append({"googleSearch": {}})
|
||||
elif "googleSearch" in tool:
|
||||
gemini_tools.append({"googleSearch": tool["googleSearch"]})
|
||||
elif "codeExecution" in tool:
|
||||
gemini_tools.append({"codeExecution": tool["codeExecution"]})
|
||||
|
||||
return gemini_tools
|
||||
|
||||
def _convert_tool_choice_to_gemini(
|
||||
self, tool_choice_value: str | dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
|
||||
for category, threshold in custom_safety_settings.items():
|
||||
safety_settings.append({"category": category, "threshold": threshold})
|
||||
else:
|
||||
from ..config.providers import get_gemini_safety_threshold
|
||||
|
||||
threshold = get_gemini_safety_threshold()
|
||||
for category in safety_categories:
|
||||
safety_settings.append(
|
||||
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
|
||||
)
|
||||
safety_settings.append({"category": category, "threshold": threshold})
|
||||
|
||||
return safety_settings if safety_settings else None
|
||||
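A small sketch of the safetySettings list built above, with the threshold now read from configuration instead of being hard-coded; the categories are the ones listed in the gemini_safe preset further down, and the function name is illustrative:

def build_safety_settings(threshold: str) -> list[dict[str, str]]:
    """One entry per harm category, all sharing the configured threshold."""
    categories = [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    ]
    return [{"category": c, "threshold": threshold} for c in categories]

assert build_safety_settings("BLOCK_ONLY_HIGH")[0] == {
    "category": "HARM_CATEGORY_HARASSMENT",
    "threshold": "BLOCK_ONLY_HIGH",
}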
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
"""
|
||||
OpenAI API 适配器
|
||||
|
||||
支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。
|
||||
支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import OpenAICompatAdapter, RequestData
|
||||
from .base import OpenAICompatAdapter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..service import LLMModel
|
||||
@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> list[str]:
|
||||
return ["openai", "deepseek", "general_openai_compat"]
|
||||
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
|
||||
|
||||
def get_chat_endpoint(self) -> str:
|
||||
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||
"""返回聊天完成端点"""
|
||||
if model.api_type == "ark":
|
||||
return "/api/v3/chat/completions"
|
||||
if model.api_type == "zhipu":
|
||||
return "/api/paas/v4/chat/completions"
|
||||
return "/v1/chat/completions"
|
||||
|
||||
def get_embedding_endpoint(self) -> str:
|
||||
"""返回嵌入端点"""
|
||||
def get_embedding_endpoint(self, model: "LLMModel") -> str:
|
||||
"""根据API类型返回嵌入端点"""
|
||||
if model.api_type == "zhipu":
|
||||
return "/v4/embeddings"
|
||||
return "/v1/embeddings"
|
||||
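A compact sketch of the per-provider endpoint routing introduced here; the api_type strings are taken from the hunk, and anything else falls back to the standard OpenAI paths:

def chat_endpoint(api_type: str) -> str:
    """Mirror of the routing above: ark and zhipu use vendor-specific paths."""
    if api_type == "ark":
        return "/api/v3/chat/completions"
    if api_type == "zhipu":
        return "/api/paas/v4/chat/completions"
    return "/v1/chat/completions"

def embedding_endpoint(api_type: str) -> str:
    return "/v4/embeddings" if api_type == "zhipu" else "/v1/embeddings"

assert chat_endpoint("zhipu") == "/api/paas/v4/chat/completions"
assert embedding_endpoint("deepseek") == "/v1/embeddings"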
|
||||
def prepare_simple_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
prompt: str,
|
||||
history: list[dict[str, str]] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备简单文本生成请求 - OpenAI优化实现"""
|
||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
||||
headers = self.get_base_headers(api_key)
|
||||
|
||||
messages = []
|
||||
if history:
|
||||
messages.extend(history)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
body = {
|
||||
"model": model.model_name,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
body = self.apply_config_override(model, body)
|
||||
|
||||
return RequestData(url=url, headers=headers, body=body)
|
||||
|
||||
@ -1,57 +0,0 @@
|
||||
"""
|
||||
智谱 AI API 适配器
|
||||
|
||||
支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import OpenAICompatAdapter, RequestData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..service import LLMModel
|
||||
|
||||
|
||||
class ZhipuAdapter(OpenAICompatAdapter):
|
||||
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
|
||||
|
||||
@property
|
||||
def api_type(self) -> str:
|
||||
return "zhipu"
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> list[str]:
|
||||
return ["zhipu"]
|
||||
|
||||
def get_chat_endpoint(self) -> str:
|
||||
"""返回智谱AI聊天完成端点"""
|
||||
return "/api/paas/v4/chat/completions"
|
||||
|
||||
def get_embedding_endpoint(self) -> str:
|
||||
"""返回智谱AI嵌入端点"""
|
||||
return "/v4/embeddings"
|
||||
|
||||
def prepare_simple_request(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
api_key: str,
|
||||
prompt: str,
|
||||
history: list[dict[str, str]] | None = None,
|
||||
) -> RequestData:
|
||||
"""准备简单文本生成请求 - 智谱AI优化实现"""
|
||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
||||
headers = self.get_base_headers(api_key)
|
||||
|
||||
messages = []
|
||||
if history:
|
||||
messages.extend(history)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
body = {
|
||||
"model": model.model_name,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
body = self.apply_config_override(model, body)
|
||||
|
||||
return RequestData(url=url, headers=headers, body=body)
|
||||
@ -2,6 +2,7 @@
|
||||
LLM 服务的高级 API 接口
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
@ -14,6 +15,7 @@ from zhenxun.services.log import logger
|
||||
from .config import CommonOverrides, LLMGenerationConfig
|
||||
from .config.providers import get_ai_config
|
||||
from .manager import get_global_default_model_name, get_model_instance
|
||||
from .tools import tool_registry
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
LLMContentPart,
|
||||
@ -56,6 +58,7 @@ class AIConfig:
|
||||
enable_gemini_safe_mode: bool = False
|
||||
enable_gemini_multimodal: bool = False
|
||||
enable_gemini_grounding: bool = False
|
||||
default_preserve_media_in_history: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后从配置中读取默认值"""
|
||||
@ -81,7 +84,7 @@ class AI:
|
||||
"""
|
||||
初始化AI服务
|
||||
|
||||
Args:
|
||||
参数:
|
||||
config: AI 配置.
|
||||
history: 可选的初始对话历史.
|
||||
"""
|
||||
@ -93,16 +96,65 @@ class AI:
|
||||
self.history = []
|
||||
logger.info("AI session history cleared.")
|
||||
|
||||
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||
"""
|
||||
净化用于存入历史记录的消息。
|
||||
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||
"""
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
sanitized_message = copy.deepcopy(message)
|
||||
content_list = sanitized_message.content
|
||||
if not isinstance(content_list, list):
|
||||
return sanitized_message
|
||||
|
||||
new_content_parts: list[LLMContentPart] = []
|
||||
has_multimodal_content = False
|
||||
|
||||
for part in content_list:
|
||||
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||
new_content_parts.append(part)
|
||||
else:
|
||||
has_multimodal_content = True
|
||||
|
||||
if has_multimodal_content:
|
||||
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||
text_part_found = False
|
||||
for part in new_content_parts:
|
||||
if part.type == "text":
|
||||
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||
text_part_found = True
|
||||
break
|
||||
if not text_part_found:
|
||||
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||
|
||||
sanitized_message.content = new_content_parts
|
||||
return sanitized_message
|
||||
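An illustrative, self-contained version of the sanitization rule above: text parts survive, and any media part is folded into a single text placeholder (the Part class is a reduced stand-in for LLMContentPart):

from dataclasses import dataclass

@dataclass
class Part:
    """Stand-in for LLMContentPart, reduced to the fields this sketch needs."""
    type: str
    text: str | None = None

def sanitize(parts: list[Part], placeholder: str = "[媒体内容已在首次分析时处理]") -> list[Part]:
    text_parts = [p for p in parts if p.type == "text"]
    if len(text_parts) != len(parts):  # at least one non-text (media) part was dropped
        if text_parts:
            text_parts[0].text = f"{placeholder} {text_parts[0].text or ''}".strip()
        else:
            text_parts.insert(0, Part(type="text", text=placeholder))
    return text_parts

print(sanitize([Part("image"), Part("text", "请描述这张图")]))
# [Part(type='text', text='[媒体内容已在首次分析时处理] 请描述这张图')]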
|
||||
async def chat(
|
||||
self,
|
||||
message: str | LLMMessage | list[LLMContentPart],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
preserve_media_in_history: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
进行一次聊天对话。
|
||||
此方法会自动使用和更新会话内的历史记录。
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 本次对话要使用的模型。
|
||||
preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
|
||||
- True: 保留,用于深度多轮媒体分析。
|
||||
- False: 不保留,替换为占位符,提高效率。
|
||||
- None (默认): 使用AI实例配置的默认值。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str: 模型的文本响应。
|
||||
"""
|
||||
current_message: LLMMessage
|
||||
if isinstance(message, str):
|
||||
@ -127,7 +179,20 @@ class AI:
|
||||
final_messages, model, "聊天失败", kwargs
|
||||
)
|
||||
|
||||
self.history.append(current_message)
|
||||
should_preserve = (
|
||||
preserve_media_in_history
|
||||
if preserve_media_in_history is not None
|
||||
else self.config.default_preserve_media_in_history
|
||||
)
|
||||
|
||||
if should_preserve:
|
||||
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
|
||||
self.history.append(current_message)
|
||||
else:
|
||||
logger.debug("高效模式:净化历史记录中的多模态消息。")
|
||||
sanitized_user_message = self._sanitize_message_for_history(current_message)
|
||||
self.history.append(sanitized_user_message)
|
||||
|
||||
self.history.append(LLMMessage.assistant_text_response(response.text))
|
||||
|
||||
return response.text
|
||||
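A hedged usage sketch of the new preserve_media_in_history switch; `ai` is assumed to be an AI() instance from this module, and the prompts are illustrative:

async def chat_demo(ai) -> None:
    # Efficient mode (default): multimodal content is replaced by a text
    # placeholder before it enters history, so it is not re-uploaded later.
    await ai.chat("帮我分析一下刚才发的图片")
    # Deep-analysis mode: keep the raw media in history so follow-up turns
    # can still reference the same image.
    await ai.chat("图里第二个物体是什么颜色?", preserve_media_in_history=True)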
@ -140,7 +205,18 @@ class AI:
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""代码执行"""
|
||||
"""
|
||||
代码执行
|
||||
|
||||
参数:
|
||||
prompt: 代码执行的提示词。
|
||||
model: 要使用的模型名称。
|
||||
timeout: 代码执行超时时间(秒)。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含执行结果的字典,含 text、code_executions 和 success 字段。
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
|
||||
config = CommonOverrides.gemini_code_execution()
|
||||
@ -168,7 +244,18 @@ class AI:
|
||||
instruction: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""信息搜索 - 支持多模态输入"""
|
||||
"""
|
||||
信息搜索 - 支持多模态输入
|
||||
|
||||
参数:
|
||||
query: 搜索查询内容,支持文本或多模态消息。
|
||||
model: 要使用的模型名称。
|
||||
instruction: 搜索指令。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含搜索结果的字典,含 text、sources、queries 和 success 字段。
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
config = CommonOverrides.gemini_grounding()
|
||||
|
||||
@ -217,63 +304,69 @@ class AI:
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
message: UniMessage,
|
||||
message: UniMessage | None,
|
||||
*,
|
||||
instruction: str = "",
|
||||
model: ModelName = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
use_tools: list[str] | None = None,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
activated_tools: list[LLMTool] | None = None,
|
||||
history: list[LLMMessage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str | LLMResponse:
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
内容分析 - 接收 UniMessage 对象进行多模态分析和工具调用。
|
||||
这是处理复杂交互的主要方法。
|
||||
|
||||
参数:
|
||||
message: 要分析的消息内容(支持多模态)。
|
||||
instruction: 分析指令。
|
||||
model: 要使用的模型名称。
|
||||
use_tools: 要使用的工具名称列表。
|
||||
tool_config: 工具配置。
|
||||
activated_tools: 已激活的工具列表。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应结果。
|
||||
"""
|
||||
content_parts = await unimsg_to_llm_parts(message)
|
||||
content_parts = await unimsg_to_llm_parts(message or UniMessage())
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if history:
|
||||
final_messages.extend(history)
|
||||
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.system(instruction))
|
||||
if not any(msg.role == "system" for msg in final_messages):
|
||||
final_messages.insert(0, LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction:
|
||||
if instruction and not history:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
else:
|
||||
elif not history:
|
||||
raise LLMException(
|
||||
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
llm_tools = None
|
||||
if tools:
|
||||
llm_tools = []
|
||||
for tool_dict in tools:
|
||||
if isinstance(tool_dict, dict):
|
||||
if "name" in tool_dict and "description" in tool_dict:
|
||||
llm_tool = LLMTool(
|
||||
type="function",
|
||||
function={
|
||||
"name": tool_dict["name"],
|
||||
"description": tool_dict["description"],
|
||||
"parameters": tool_dict.get("parameters", {}),
|
||||
},
|
||||
)
|
||||
llm_tools.append(llm_tool)
|
||||
else:
|
||||
llm_tools.append(LLMTool(**tool_dict))
|
||||
else:
|
||||
llm_tools.append(tool_dict)
|
||||
llm_tools: list[LLMTool] | None = activated_tools
|
||||
if not llm_tools and use_tools:
|
||||
try:
|
||||
llm_tools = tool_registry.get_tools(use_tools)
|
||||
logger.debug(f"已从注册表加载工具定义: {use_tools}")
|
||||
except ValueError as e:
|
||||
raise LLMException(
|
||||
f"加载工具定义失败: {e}",
|
||||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
tool_choice = None
|
||||
if tool_config:
|
||||
mode = tool_config.get("mode", "auto")
|
||||
if mode == "auto":
|
||||
tool_choice = "auto"
|
||||
elif mode == "any":
|
||||
tool_choice = "any"
|
||||
elif mode == "none":
|
||||
tool_choice = "none"
|
||||
if mode in ["auto", "any", "none"]:
|
||||
tool_choice = mode
|
||||
|
||||
response = await self._execute_generation(
|
||||
final_messages,
|
||||
@ -284,9 +377,7 @@ class AI:
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
if response.tool_calls:
|
||||
return response
|
||||
return response.text
|
||||
return response
|
||||
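A usage sketch for the reworked analyze(): it now always returns an LLMResponse, so callers branch on tool_calls themselves. The tool name is illustrative; UniMessage comes from nonebot_plugin_alconna:

from nonebot_plugin_alconna import UniMessage

async def analyze_demo(ai) -> str:
    response = await ai.analyze(
        UniMessage("帮我规划从北京到上海的路线"),
        instruction="你是出行助手",
        use_tools=["baidu-map"],        # resolved via tool_registry.get_tools()
        tool_config={"mode": "auto"},   # maps directly to tool_choice="auto"
    )
    if response.tool_calls:
        return f"模型请求调用 {len(response.tool_calls)} 个工具"
    return response.text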
|
||||
async def _execute_generation(
|
||||
self,
|
||||
@ -298,7 +389,7 @@ class AI:
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> LLMResponse:
|
||||
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
|
||||
"""通用的生成执行方法,封装模型获取和单次API调用"""
|
||||
try:
|
||||
resolved_model_name = self._resolve_model_name(
|
||||
model_name or self.config.model
|
||||
@ -311,7 +402,9 @@ class AI:
|
||||
resolved_model_name, override_config=final_config_dict
|
||||
) as model_instance:
|
||||
return await model_instance.generate_response(
|
||||
messages, tools=llm_tools, tool_choice=tool_choice
|
||||
messages,
|
||||
tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
@ -380,7 +473,18 @@ class AI:
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""生成文本嵌入向量"""
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 要生成嵌入向量的文本或文本列表。
|
||||
model: 要使用的嵌入模型名称。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 文本的嵌入向量列表。
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
if not texts:
|
||||
@ -420,7 +524,17 @@ async def chat(
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""聊天对话便捷函数"""
|
||||
"""
|
||||
聊天对话便捷函数
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 要使用的模型名称。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str: 模型的文本响应。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.chat(message, model=model, **kwargs)
|
||||
|
||||
@ -432,7 +546,18 @@ async def code(
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""代码执行便捷函数"""
|
||||
"""
|
||||
代码执行便捷函数
|
||||
|
||||
参数:
|
||||
prompt: 代码执行的提示词。
|
||||
model: 要使用的模型名称。
|
||||
timeout: 代码执行超时时间(秒)。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含执行结果的字典。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
|
||||
|
||||
@ -444,45 +569,56 @@ async def search(
|
||||
instruction: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""信息搜索便捷函数"""
|
||||
"""
|
||||
信息搜索便捷函数
|
||||
|
||||
参数:
|
||||
query: 搜索查询内容。
|
||||
model: 要使用的模型名称。
|
||||
instruction: 搜索指令。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含搜索结果的字典。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.search(query, model=model, instruction=instruction, **kwargs)
|
||||
|
||||
|
||||
async def analyze(
|
||||
message: UniMessage,
|
||||
message: UniMessage | None,
|
||||
*,
|
||||
instruction: str = "",
|
||||
model: ModelName = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
use_tools: list[str] | None = None,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str | LLMResponse:
|
||||
"""内容分析便捷函数"""
|
||||
"""
|
||||
内容分析便捷函数
|
||||
|
||||
参数:
|
||||
message: 要分析的消息内容。
|
||||
instruction: 分析指令。
|
||||
model: 要使用的模型名称。
|
||||
use_tools: 要使用的工具名称列表。
|
||||
tool_config: 工具配置。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str | LLMResponse: 分析结果。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.analyze(
|
||||
message,
|
||||
instruction=instruction,
|
||||
model=model,
|
||||
tools=tools,
|
||||
use_tools=use_tools,
|
||||
tool_config=tool_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def analyze_with_images(
|
||||
text: str,
|
||||
images: list[str | Path | bytes] | str | Path | bytes,
|
||||
*,
|
||||
instruction: str = "",
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> str | LLMResponse:
|
||||
"""图片分析便捷函数"""
|
||||
message = create_multimodal_message(text=text, images=images)
|
||||
return await analyze(message, instruction=instruction, model=model, **kwargs)
|
||||
|
||||
|
||||
async def analyze_multimodal(
|
||||
text: str | None = None,
|
||||
images: list[str | Path | bytes] | str | Path | bytes | None = None,
|
||||
@ -493,7 +629,21 @@ async def analyze_multimodal(
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> str | LLMResponse:
|
||||
"""多模态分析便捷函数"""
|
||||
"""
|
||||
多模态分析便捷函数
|
||||
|
||||
参数:
|
||||
text: 文本内容。
|
||||
images: 图片文件路径、字节数据或列表。
|
||||
videos: 视频文件路径、字节数据或列表。
|
||||
audios: 音频文件路径、字节数据或列表。
|
||||
instruction: 分析指令。
|
||||
model: 要使用的模型名称。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str | LLMResponse: 分析结果。
|
||||
"""
|
||||
message = create_multimodal_message(
|
||||
text=text, images=images, videos=videos, audios=audios
|
||||
)
|
||||
@ -510,7 +660,21 @@ async def search_multimodal(
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""多模态搜索便捷函数"""
|
||||
"""
|
||||
多模态搜索便捷函数
|
||||
|
||||
参数:
|
||||
text: 文本内容。
|
||||
images: 图片文件路径、字节数据或列表。
|
||||
videos: 视频文件路径、字节数据或列表。
|
||||
audios: 音频文件路径、字节数据或列表。
|
||||
instruction: 搜索指令。
|
||||
model: 要使用的模型名称。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含搜索结果的字典。
|
||||
"""
|
||||
message = create_multimodal_message(
|
||||
text=text, images=images, videos=videos, audios=audios
|
||||
)
|
||||
@ -525,6 +689,101 @@ async def embed(
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""文本嵌入便捷函数"""
|
||||
"""
|
||||
文本嵌入便捷函数
|
||||
|
||||
参数:
|
||||
texts: 要生成嵌入向量的文本或文本列表。
|
||||
model: 要使用的嵌入模型名称。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 文本的嵌入向量列表。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
|
||||
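A hedged usage example of the embed() helper; the model identifier is illustrative only:

async def index_documents(docs: list[str]) -> dict[str, list[float]]:
    vectors = await embed(
        docs,
        model="Gemini/text-embedding-004",   # illustrative ProviderName/ModelName
        task_type="RETRIEVAL_DOCUMENT",
    )
    return dict(zip(docs, vectors))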
|
||||
|
||||
async def pipeline_chat(
|
||||
message: UniMessage | str | list[LLMContentPart],
|
||||
model_chain: list[ModelName],
|
||||
*,
|
||||
initial_instruction: str = "",
|
||||
final_instruction: str = "",
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
AI模型链式调用,前一个模型的输出作为下一个模型的输入。
|
||||
|
||||
参数:
|
||||
message: 初始输入消息(支持多模态)
|
||||
model_chain: 模型名称列表
|
||||
initial_instruction: 第一个模型的系统指令
|
||||
final_instruction: 最后一个模型的系统指令
|
||||
**kwargs: 传递给模型实例的其他参数
|
||||
|
||||
返回:
|
||||
LLMResponse: 最后一个模型的响应结果
|
||||
"""
|
||||
if not model_chain:
|
||||
raise ValueError("模型链`model_chain`不能为空。")
|
||||
|
||||
current_content: str | list[LLMContentPart]
|
||||
if isinstance(message, str):
|
||||
current_content = message
|
||||
elif isinstance(message, list):
|
||||
current_content = message
|
||||
else:
|
||||
current_content = await unimsg_to_llm_parts(message)
|
||||
|
||||
final_response: LLMResponse | None = None
|
||||
|
||||
for i, model_name in enumerate(model_chain):
|
||||
if not model_name:
|
||||
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
|
||||
|
||||
is_first_step = i == 0
|
||||
is_last_step = i == len(model_chain) - 1
|
||||
|
||||
messages_for_step: list[LLMMessage] = []
|
||||
instruction_for_step = ""
|
||||
if is_first_step and initial_instruction:
|
||||
instruction_for_step = initial_instruction
|
||||
elif is_last_step and final_instruction:
|
||||
instruction_for_step = final_instruction
|
||||
|
||||
if instruction_for_step:
|
||||
messages_for_step.append(LLMMessage.system(instruction_for_step))
|
||||
|
||||
messages_for_step.append(LLMMessage.user(current_content))
|
||||
|
||||
logger.info(
|
||||
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
|
||||
f"使用模型 '{model_name}' 进行处理..."
|
||||
)
|
||||
try:
|
||||
async with await get_model_instance(model_name, **kwargs) as model:
|
||||
response = await model.generate_response(messages_for_step)
|
||||
final_response = response
|
||||
current_content = response.text.strip()
|
||||
if not current_content and not is_last_step:
|
||||
logger.warning(
|
||||
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
|
||||
raise LLMException(
|
||||
f"流水线在模型 '{model_name}' 处执行失败: {e}",
|
||||
code=LLMErrorCode.GENERATION_FAILED,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
if final_response is None:
|
||||
raise LLMException(
|
||||
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
|
||||
)
|
||||
|
||||
return final_response
|
||||
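Example of chaining two models with pipeline_chat(); the model identifiers are placeholders:

async def summarize_then_translate(text: str) -> str:
    response = await pipeline_chat(
        text,
        model_chain=["DeepSeek/deepseek-chat", "Gemini/gemini-2.0-flash"],
        initial_instruction="提炼这段文字的要点",
        final_instruction="把上一步的要点翻译成英文",
    )
    return response.text  # output of the last model in the chain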
|
||||
@ -14,6 +14,8 @@ from .generation import (
|
||||
from .presets import CommonOverrides
|
||||
from .providers import (
|
||||
LLMConfig,
|
||||
ToolConfig,
|
||||
get_gemini_safety_threshold,
|
||||
get_llm_config,
|
||||
register_llm_configs,
|
||||
set_default_model,
|
||||
@ -25,8 +27,10 @@ __all__ = [
|
||||
"LLMConfig",
|
||||
"LLMGenerationConfig",
|
||||
"ModelConfigOverride",
|
||||
"ToolConfig",
|
||||
"apply_api_specific_mappings",
|
||||
"create_generation_config_from_kwargs",
|
||||
"get_gemini_safety_threshold",
|
||||
"get_llm_config",
|
||||
"register_llm_configs",
|
||||
"set_default_model",
|
||||
|
||||
@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
|
||||
params["temperature"] = self.temperature
|
||||
|
||||
if self.max_tokens is not None:
|
||||
if api_type in ["gemini", "gemini_native"]:
|
||||
if api_type == "gemini":
|
||||
params["maxOutputTokens"] = self.max_tokens
|
||||
else:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
|
||||
if api_type in ["gemini", "gemini_native"]:
|
||||
if api_type == "gemini":
|
||||
if self.top_k is not None:
|
||||
params["topK"] = self.top_k
|
||||
if self.top_p is not None:
|
||||
@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
|
||||
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
||||
params["response_format"] = {"type": "json_object"}
|
||||
logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
|
||||
elif api_type in ["gemini", "gemini_native"]:
|
||||
elif api_type == "gemini":
|
||||
params["responseMimeType"] = "application/json"
|
||||
if self.response_schema:
|
||||
params["responseSchema"] = self.response_schema
|
||||
logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
|
||||
|
||||
if api_type in ["gemini", "gemini_native"]:
|
||||
if api_type == "gemini":
|
||||
if (
|
||||
self.response_format != ResponseFormat.JSON
|
||||
and self.response_mime_type is not None
|
||||
@ -214,7 +214,7 @@ def apply_api_specific_mappings(
|
||||
"""应用API特定的参数映射"""
|
||||
mapped_params = params.copy()
|
||||
|
||||
if api_type in ["gemini", "gemini_native"]:
|
||||
if api_type == "gemini":
|
||||
if "max_tokens" in mapped_params:
|
||||
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
|
||||
if "top_k" in mapped_params:
|
||||
|
||||
@ -71,14 +71,17 @@ class CommonOverrides:
|
||||
|
||||
@staticmethod
|
||||
def gemini_safe() -> LLMGenerationConfig:
|
||||
"""Gemini 安全模式:严格安全设置"""
|
||||
"""Gemini 安全模式:使用配置的安全设置"""
|
||||
from .providers import get_gemini_safety_threshold
|
||||
|
||||
threshold = get_gemini_safety_threshold()
|
||||
return LLMGenerationConfig(
|
||||
temperature=0.5,
|
||||
safety_settings={
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"HARM_CATEGORY_HARASSMENT": threshold,
|
||||
"HARM_CATEGORY_HATE_SPEECH": threshold,
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -4,15 +4,33 @@ LLM 提供商配置管理
|
||||
负责注册和管理 AI 服务提供商的配置项。
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.configs.utils import parse_as
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from ..types.models import ModelDetail, ProviderConfig
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""MCP类型工具的配置定义"""
|
||||
|
||||
type: str = "mcp"
|
||||
name: str = Field(..., description="工具的唯一名称标识")
|
||||
description: str | None = Field(None, description="工具功能的描述")
|
||||
mcp_config: dict[str, Any] | BaseModel = Field(
|
||||
..., description="MCP服务器的特定配置"
|
||||
)
|
||||
|
||||
|
||||
AI_CONFIG_GROUP = "AI"
|
||||
PROVIDERS_CONFIG_KEY = "PROVIDERS"
|
||||
|
||||
@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
|
||||
providers: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
|
||||
)
|
||||
mcp_tools: list[ToolConfig] = Field(
|
||||
default_factory=list, description="配置可用的外部MCP工具"
|
||||
)
|
||||
|
||||
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
|
||||
"""根据名称获取提供商配置
|
||||
@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"name": "DeepSeek",
|
||||
"api_key": "sk-******",
|
||||
"api_key": "YOUR_ARK_API_KEY",
|
||||
"api_base": "https://api.deepseek.com",
|
||||
"api_type": "openai",
|
||||
"models": [
|
||||
@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "ARK",
|
||||
"api_key": "YOUR_ARK_API_KEY",
|
||||
"api_base": "https://ark.cn-beijing.volces.com",
|
||||
"api_type": "ark",
|
||||
"models": [
|
||||
{"model_name": "deepseek-r1-250528"},
|
||||
{"model_name": "doubao-seed-1-6-250615"},
|
||||
{"model_name": "doubao-seed-1-6-flash-250615"},
|
||||
{"model_name": "doubao-seed-1-6-thinking-250615"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "siliconflow",
|
||||
"api_key": "YOUR_ARK_API_KEY",
|
||||
"api_base": "https://api.siliconflow.cn",
|
||||
"api_type": "openai",
|
||||
"models": [
|
||||
{"model_name": "deepseek-ai/DeepSeek-V3"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "GLM",
|
||||
"api_key": "",
|
||||
"api_key": "YOUR_ARK_API_KEY",
|
||||
"api_base": "https://open.bigmodel.cn",
|
||||
"api_type": "zhipu",
|
||||
"models": [
|
||||
@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
|
||||
"api_type": "gemini",
|
||||
"models": [
|
||||
{"model_name": "gemini-2.0-flash"},
|
||||
{"model_name": "gemini-2.5-flash-preview-05-20"},
|
||||
{"model_name": "gemini-2.5-flash"},
|
||||
{"model_name": "gemini-2.5-pro"},
|
||||
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_default_mcp_tools() -> dict[str, Any]:
|
||||
"""
|
||||
获取默认的MCP工具配置,用于在文件不存在时创建。
|
||||
包含 baidu-map、Context7 和 sequential-thinking。

|
||||
"""
|
||||
return {
|
||||
"mcpServers": {
|
||||
"baidu-map": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
|
||||
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
|
||||
"description": "百度地图工具,提供地理编码、路线规划等功能。",
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
|
||||
},
|
||||
"Context7": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@upstash/context7-mcp@latest"],
|
||||
"description": "Upstash 提供的上下文管理和记忆工具。",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def register_llm_configs():
|
||||
"""注册 LLM 服务的配置项"""
|
||||
logger.info("注册 LLM 服务的配置项")
|
||||
@ -214,6 +285,19 @@ def register_llm_configs():
|
||||
help="LLM服务请求重试的基础延迟时间(秒)",
|
||||
type=int,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
"gemini_safety_threshold",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
help=(
|
||||
"Gemini 安全过滤阈值 "
|
||||
"(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
|
||||
"BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
|
||||
"BLOCK_ONLY_HIGH: 只阻止高级别, "
|
||||
"BLOCK_NONE: 不阻止)"
|
||||
),
|
||||
type=str,
|
||||
)
|
||||
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
@ -225,24 +309,111 @@ def register_llm_configs():
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_llm_config() -> LLMConfig:
|
||||
"""获取 LLM 配置实例
|
||||
|
||||
返回:
|
||||
LLMConfig: LLM 配置实例
|
||||
"""
|
||||
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
|
||||
ai_config = get_ai_config()
|
||||
|
||||
llm_data_path = DATA_PATH / "llm"
|
||||
mcp_tools_path = llm_data_path / "mcp_tools.json"
|
||||
|
||||
mcp_tools_list = []
|
||||
mcp_servers_dict = {}
|
||||
|
||||
if not mcp_tools_path.exists():
|
||||
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
|
||||
llm_data_path.mkdir(parents=True, exist_ok=True)
|
||||
default_mcp_config = get_default_mcp_tools()
|
||||
try:
|
||||
with mcp_tools_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
|
||||
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
|
||||
except Exception as e:
|
||||
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
|
||||
mcp_servers_dict = {}
|
||||
else:
|
||||
try:
|
||||
with mcp_tools_path.open("r", encoding="utf-8") as f:
|
||||
mcp_data = json.load(f)
|
||||
mcp_servers_dict = mcp_data.get("mcpServers", {})
|
||||
if not isinstance(mcp_servers_dict, dict):
|
||||
logger.warning(
|
||||
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
|
||||
f"将使用空配置。"
|
||||
)
|
||||
mcp_servers_dict = {}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
|
||||
mcp_servers_dict = {}
|
||||
|
||||
if sys.platform == "win32":
|
||||
logger.debug("检测到Windows平台,正在调整MCP工具的npx命令...")
|
||||
for name, config in mcp_servers_dict.items():
|
||||
if isinstance(config, dict) and config.get("command") == "npx":
|
||||
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
|
||||
original_args = config.get("args", [])
|
||||
config["command"] = "cmd"
|
||||
config["args"] = ["/c", "npx", *original_args]
|
||||
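A standalone illustration of the Windows adjustment just above: npx is a .cmd shim on Windows, so the tool command is re-launched through cmd /c:

def wrap_npx_for_windows(config: dict) -> dict:
    """Return a copy with `npx ...` rewritten as `cmd /c npx ...`."""
    if config.get("command") != "npx":
        return config
    return {**config, "command": "cmd", "args": ["/c", "npx", *config.get("args", [])]}

print(wrap_npx_for_windows({"command": "npx", "args": ["-y", "@baidumap/mcp-server-baidu-map"]}))
# {'command': 'cmd', 'args': ['/c', 'npx', '-y', '@baidumap/mcp-server-baidu-map']}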
|
||||
if mcp_servers_dict:
|
||||
mcp_tools_list = [
|
||||
{
|
||||
"name": name,
|
||||
"type": "mcp",
|
||||
"description": config.get("description", f"MCP tool for {name}"),
|
||||
"mcp_config": config,
|
||||
}
|
||||
for name, config in mcp_servers_dict.items()
|
||||
if isinstance(config, dict)
|
||||
]
|
||||
|
||||
from ..tools.registry import tool_registry
|
||||
|
||||
for tool_dict in mcp_tools_list:
|
||||
if isinstance(tool_dict, dict):
|
||||
tool_name = tool_dict.get("name")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
config_model = tool_registry.get_mcp_config_model(tool_name)
|
||||
if not config_model:
|
||||
logger.debug(
|
||||
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
|
||||
f"将跳过特定配置验证,直接使用原始配置字典。"
|
||||
)
|
||||
continue
|
||||
|
||||
mcp_config_data = tool_dict.get("mcp_config", {})
|
||||
try:
|
||||
parsed_mcp_config = parse_as(config_model, mcp_config_data)
|
||||
tool_dict["mcp_config"] = parsed_mcp_config
|
||||
except Exception as e:
|
||||
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
|
||||
|
||||
config_data = {
|
||||
"default_model_name": ai_config.get("default_model_name"),
|
||||
"proxy": ai_config.get("proxy"),
|
||||
"timeout": ai_config.get("timeout", 180),
|
||||
"max_retries_llm": ai_config.get("max_retries_llm", 3),
|
||||
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
|
||||
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
|
||||
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
|
||||
"mcp_tools": mcp_tools_list,
|
||||
}
|
||||
|
||||
return LLMConfig(**config_data)
|
||||
return parse_as(LLMConfig, config_data)
|
||||
|
||||
|
||||
def get_gemini_safety_threshold() -> str:
|
||||
"""获取 Gemini 安全过滤阈值配置
|
||||
|
||||
返回:
|
||||
str: 安全过滤阈值
|
||||
"""
|
||||
ai_config = get_ai_config()
|
||||
return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")
|
||||
|
||||
|
||||
def validate_llm_config() -> tuple[bool, list[str]]:
|
||||
@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
|
||||
logger.info("默认模型已清除")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=10)
|
||||
async def _init_llm_config_on_startup():
|
||||
"""
|
||||
在服务启动时主动调用一次 get_llm_config,
|
||||
以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
|
||||
"""
|
||||
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
|
||||
try:
|
||||
get_llm_config()
|
||||
logger.info("LLM 配置初始化完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
|
||||
|
||||
@ -49,12 +49,36 @@ class LLMHttpClient:
|
||||
max_keepalive_connections=self.config.max_keepalive_connections,
|
||||
)
|
||||
timeout = httpx.Timeout(self.config.timeout)
|
||||
|
||||
client_kwargs = {}
|
||||
if self.config.proxy:
|
||||
try:
|
||||
version_parts = httpx.__version__.split(".")
|
||||
major = int(
|
||||
"".join(c for c in version_parts[0] if c.isdigit())
|
||||
)
|
||||
minor = (
|
||||
int("".join(c for c in version_parts[1] if c.isdigit()))
|
||||
if len(version_parts) > 1
|
||||
else 0
|
||||
)
|
||||
if (major, minor) >= (0, 28):
|
||||
client_kwargs["proxy"] = self.config.proxy
|
||||
else:
|
||||
client_kwargs["proxies"] = self.config.proxy
|
||||
except (ValueError, IndexError):
|
||||
client_kwargs["proxies"] = self.config.proxy
|
||||
logger.warning(
|
||||
f"无法解析 httpx 版本 '{httpx.__version__}',"
|
||||
"LLM模块将默认使用旧版 'proxies' 参数语法。"
|
||||
)
|
||||
|
||||
self._client = httpx.AsyncClient(
|
||||
headers=headers,
|
||||
limits=limits,
|
||||
timeout=timeout,
|
||||
proxies=self.config.proxy,
|
||||
follow_redirects=True,
|
||||
**client_kwargs,
|
||||
)
|
||||
if self._client is None:
|
||||
raise LLMException(
|
||||
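A self-contained sketch of the version gate above: httpx 0.28 removed the `proxies` keyword in favour of `proxy`, so the kwarg name is chosen from the installed release, with the same legacy fallback the hunk uses when the version string cannot be parsed:

import httpx

def proxy_kwargs(proxy: str | None) -> dict:
    """Pick the proxy keyword accepted by the installed httpx version."""
    if not proxy:
        return {}
    parts = httpx.__version__.split(".")
    try:
        major = int("".join(c for c in parts[0] if c.isdigit()))
        minor = int("".join(c for c in parts[1] if c.isdigit())) if len(parts) > 1 else 0
    except (ValueError, IndexError):
        return {"proxies": proxy}  # unparsable version string: keep the legacy spelling
    return {"proxy": proxy} if (major, minor) >= (0, 28) else {"proxies": proxy}

kwargs = proxy_kwargs("http://127.0.0.1:7890")
# httpx >= 0.28 -> {"proxy": "http://127.0.0.1:7890"}; older releases -> {"proxies": ...}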
@ -156,7 +180,16 @@ async def create_llm_http_client(
|
||||
timeout: int = 180,
|
||||
proxy: str | None = None,
|
||||
) -> LLMHttpClient:
|
||||
"""创建LLM HTTP客户端"""
|
||||
"""
|
||||
创建LLM HTTP客户端
|
||||
|
||||
参数:
|
||||
timeout: 超时时间(秒)。
|
||||
proxy: 代理服务器地址。
|
||||
|
||||
返回:
|
||||
LLMHttpClient: HTTP客户端实例。
|
||||
"""
|
||||
config = HttpClientConfig(timeout=timeout, proxy=proxy)
|
||||
return LLMHttpClient(config)
|
||||
|
||||
@ -185,7 +218,20 @@ async def with_smart_retry(
|
||||
provider_name: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""智能重试装饰器 - 支持Key轮询和错误分类"""
|
||||
"""
|
||||
智能重试装饰器 - 支持Key轮询和错误分类
|
||||
|
||||
参数:
|
||||
func: 要重试的异步函数。
|
||||
*args: 传递给函数的位置参数。
|
||||
retry_config: 重试配置。
|
||||
key_store: API密钥状态存储。
|
||||
provider_name: 提供商名称。
|
||||
**kwargs: 传递给函数的关键字参数。
|
||||
|
||||
返回:
|
||||
Any: 函数执行结果。
|
||||
"""
|
||||
config = retry_config or RetryConfig()
|
||||
last_exception: Exception | None = None
|
||||
failed_keys: set[str] = set()
|
||||
@ -294,7 +340,17 @@ class KeyStatusStore:
|
||||
api_keys: list[str],
|
||||
exclude_keys: set[str] | None = None,
|
||||
) -> str | None:
|
||||
"""获取下一个可用的API密钥(轮询策略)"""
|
||||
"""
|
||||
获取下一个可用的API密钥(轮询策略)
|
||||
|
||||
参数:
|
||||
provider_name: 提供商名称。
|
||||
api_keys: API密钥列表。
|
||||
exclude_keys: 要排除的密钥集合。
|
||||
|
||||
返回:
|
||||
str | None: 可用的API密钥,如果没有可用密钥则返回None。
|
||||
"""
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
@ -338,7 +394,13 @@ class KeyStatusStore:
|
||||
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
|
||||
|
||||
async def record_failure(self, api_key: str, status_code: int | None):
|
||||
"""记录失败使用"""
|
||||
"""
|
||||
记录失败使用
|
||||
|
||||
参数:
|
||||
api_key: API密钥。
|
||||
status_code: HTTP状态码。
|
||||
"""
|
||||
key_id = self._get_key_id(api_key)
|
||||
async with self._lock:
|
||||
if status_code in [401, 403]:
|
||||
@ -356,7 +418,15 @@ class KeyStatusStore:
|
||||
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
||||
|
||||
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
||||
"""获取密钥使用统计"""
|
||||
"""
|
||||
获取密钥使用统计
|
||||
|
||||
参数:
|
||||
api_keys: API密钥列表。
|
||||
|
||||
返回:
|
||||
dict[str, dict]: 密钥统计信息字典。
|
||||
"""
|
||||
stats = {}
|
||||
async with self._lock:
|
||||
for key in api_keys:
|
||||
|
||||
@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
|
||||
from .core import http_client_manager, key_store
|
||||
from .service import LLMModel
|
||||
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
|
||||
from .types.capabilities import get_model_capabilities
|
||||
|
||||
DEFAULT_MODEL_NAME_KEY = "default_model_name"
|
||||
PROXY_KEY = "proxy"
|
||||
@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
|
||||
|
||||
|
||||
def get_configured_providers() -> list[ProviderConfig]:
|
||||
"""从配置中获取Provider列表 - 简化版本"""
|
||||
"""从配置中获取Provider列表 - 简化和修正版本"""
|
||||
ai_config = get_ai_config()
|
||||
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
|
||||
if not isinstance(providers_raw, list):
|
||||
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
|
||||
|
||||
if not isinstance(providers, list):
|
||||
logger.error(
|
||||
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
|
||||
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
|
||||
f"将使用空列表。"
|
||||
)
|
||||
return []
|
||||
|
||||
valid_providers = []
|
||||
for i, item in enumerate(providers_raw):
|
||||
if not isinstance(item, dict):
|
||||
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
|
||||
continue
|
||||
|
||||
try:
|
||||
if not item.get("name"):
|
||||
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
|
||||
continue
|
||||
|
||||
if not item.get("api_key"):
|
||||
logger.warning(
|
||||
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
|
||||
)
|
||||
continue
|
||||
|
||||
if "api_type" not in item or not item["api_type"]:
|
||||
provider_name = item.get("name", "").lower()
|
||||
if "glm" in provider_name or "zhipu" in provider_name:
|
||||
item["api_type"] = "zhipu"
|
||||
elif "gemini" in provider_name or "google" in provider_name:
|
||||
item["api_type"] = "gemini"
|
||||
else:
|
||||
item["api_type"] = "openai"
|
||||
|
||||
if "api_base" not in item or not item["api_base"]:
|
||||
api_type = item.get("api_type")
|
||||
if api_type:
|
||||
default_api_base = get_default_api_base_for_type(api_type)
|
||||
if default_api_base:
|
||||
item["api_base"] = default_api_base
|
||||
|
||||
if "models" not in item:
|
||||
item["models"] = [{"model_name": item.get("name", "default")}]
|
||||
|
||||
provider_conf = ProviderConfig(**item)
|
||||
valid_providers.append(provider_conf)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
|
||||
for i, item in enumerate(providers):
|
||||
if isinstance(item, ProviderConfig):
|
||||
if not item.api_base:
|
||||
default_api_base = get_default_api_base_for_type(item.api_type)
|
||||
if default_api_base:
|
||||
item.api_base = default_api_base
|
||||
valid_providers.append(item)
|
||||
else:
|
||||
logger.warning(
|
||||
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
|
||||
f"已跳过。实际类型: {type(item)}"
|
||||
)
|
||||
|
||||
return valid_providers
|
||||
|
||||
@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
|
||||
def find_model_config(
|
||||
provider_name: str, model_name: str
|
||||
) -> tuple[ProviderConfig, ModelDetail] | None:
|
||||
"""在配置中查找指定的 Provider 和 ModelDetail
|
||||
"""
|
||||
在配置中查找指定的 Provider 和 ModelDetail
|
||||
|
||||
Args:
|
||||
参数:
|
||||
provider_name: 提供商名称
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None
|
||||
返回:
|
||||
tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None
|
||||
"""
|
||||
providers = get_configured_providers()
|
||||
|
||||
@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis
|
||||
|
||||
|
||||
def list_model_identifiers() -> dict[str, list[str]]:
|
||||
"""列出所有模型的可用标识符
|
||||
"""
|
||||
列出所有模型的可用标识符
|
||||
|
||||
Returns:
|
||||
字典,键为模型的完整名称,值为该模型的所有可用标识符列表
|
||||
返回:
|
||||
dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表
|
||||
"""
|
||||
providers = get_configured_providers()
|
||||
result = {}
|
||||
@ -248,7 +224,16 @@ async def get_model_instance(
|
||||
provider_model_name: str | None = None,
|
||||
override_config: dict[str, Any] | None = None,
|
||||
) -> LLMModel:
|
||||
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
|
||||
"""
|
||||
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
|
||||
|
||||
参数:
|
||||
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
|
||||
override_config: 覆盖配置字典。
|
||||
|
||||
返回:
|
||||
LLMModel: 模型实例。
|
||||
"""
|
||||
cache_key = _make_cache_key(provider_model_name, override_config)
|
||||
cached_model = _get_cached_model(cache_key)
|
||||
if cached_model:
|
||||
@ -292,6 +277,10 @@ async def get_model_instance(
|
||||
|
||||
provider_config_found, model_detail_found = config_tuple_found
|
||||
|
||||
capabilities = get_model_capabilities(model_detail_found.model_name)
|
||||
|
||||
model_detail_found.is_embedding_model = capabilities.is_embedding_model
|
||||
|
||||
ai_config = get_ai_config()
|
||||
global_proxy_setting = ai_config.get(PROXY_KEY)
|
||||
default_timeout = (
|
||||
@ -322,6 +311,7 @@ async def get_model_instance(
|
||||
model_detail=model_detail_found,
|
||||
key_store=key_store,
|
||||
http_client=shared_http_client,
|
||||
capabilities=capabilities,
|
||||
)
|
||||
|
||||
if override_config:
|
||||
@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:
|
||||
|
||||
|
||||
def set_global_default_model_name(provider_model_name: str | None) -> bool:
|
||||
"""设置全局默认模型名称"""
|
||||
"""
|
||||
设置全局默认模型名称
|
||||
|
||||
参数:
|
||||
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
|
||||
|
||||
返回:
|
||||
bool: 设置是否成功。
|
||||
"""
|
||||
if provider_model_name:
|
||||
prov_name, mod_name = parse_provider_model_string(provider_model_name)
|
||||
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
|
||||
@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
|
||||
|
||||
|
||||
async def get_key_usage_stats() -> dict[str, Any]:
|
||||
"""获取所有Provider的Key使用统计"""
|
||||
"""
|
||||
获取所有Provider的Key使用统计
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含所有Provider的Key使用统计信息。
|
||||
"""
|
||||
providers = get_configured_providers()
|
||||
stats = {}
|
||||
|
||||
@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
|
||||
|
||||
|
||||
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
|
||||
"""重置指定Provider的Key状态"""
|
||||
"""
|
||||
重置指定Provider的Key状态
|
||||
|
||||
参数:
|
||||
provider_name: 提供商名称。
|
||||
api_key: 要重置的特定API密钥,如果为None则重置所有密钥。
|
||||
|
||||
返回:
|
||||
bool: 重置是否成功。
|
||||
"""
|
||||
providers = get_configured_providers()
|
||||
target_provider = None
|
||||
|
||||
|
||||
@ -6,11 +6,13 @@ LLM 模型实现类
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .adapters.base import RequestData
|
||||
from .config import LLMGenerationConfig
|
||||
from .config.providers import get_ai_config
|
||||
from .core import (
|
||||
@ -30,6 +32,8 @@ from .types import (
|
||||
ModelDetail,
|
||||
ProviderConfig,
|
||||
)
|
||||
from .types.capabilities import ModelCapabilities, ModelModality
|
||||
from .utils import _sanitize_request_body_for_logging
|
||||
|
||||
|
||||
class LLMModelBase(ABC):
|
||||
@ -42,7 +46,17 @@ class LLMModelBase(ABC):
|
||||
history: list[dict[str, str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""生成文本"""
|
||||
"""
|
||||
生成文本
|
||||
|
||||
参数:
|
||||
prompt: 输入提示词。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
str: 生成的文本。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -54,7 +68,19 @@ class LLMModelBase(ABC):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""生成高级响应"""
|
||||
"""
|
||||
生成高级响应
|
||||
|
||||
参数:
|
||||
messages: 消息列表。
|
||||
config: 生成配置。
|
||||
tools: 工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型响应。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -64,7 +90,17 @@ class LLMModelBase(ABC):
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""生成文本嵌入向量"""
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 文本列表。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 嵌入向量列表。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
|
||||
model_detail: ModelDetail,
|
||||
key_store: KeyStatusStore,
|
||||
http_client: LLMHttpClient,
|
||||
capabilities: ModelCapabilities,
|
||||
config_override: LLMGenerationConfig | None = None,
|
||||
):
|
||||
self.provider_config = provider_config
|
||||
self.model_detail = model_detail
|
||||
self.key_store = key_store
|
||||
self.http_client: LLMHttpClient = http_client
|
||||
self.capabilities = capabilities
|
||||
self._generation_config = config_override
|
||||
|
||||
self.provider_name = provider_config.name
|
||||
@ -99,6 +137,34 @@ class LLMModel(LLMModelBase):
|
||||
|
||||
self._is_closed = False
|
||||
|
||||
def can_process_images(self) -> bool:
|
||||
"""检查模型是否支持图片作为输入。"""
|
||||
return ModelModality.IMAGE in self.capabilities.input_modalities
|
||||
|
||||
def can_process_video(self) -> bool:
|
||||
"""检查模型是否支持视频作为输入。"""
|
||||
return ModelModality.VIDEO in self.capabilities.input_modalities
|
||||
|
||||
def can_process_audio(self) -> bool:
|
||||
"""检查模型是否支持音频作为输入。"""
|
||||
return ModelModality.AUDIO in self.capabilities.input_modalities
|
||||
|
||||
def can_generate_images(self) -> bool:
|
||||
"""检查模型是否支持生成图片。"""
|
||||
return ModelModality.IMAGE in self.capabilities.output_modalities
|
||||
|
||||
def can_generate_audio(self) -> bool:
|
||||
"""检查模型是否支持生成音频 (TTS)。"""
|
||||
return ModelModality.AUDIO in self.capabilities.output_modalities
|
||||
|
||||
def can_use_tools(self) -> bool:
|
||||
"""检查模型是否支持工具调用/函数调用。"""
|
||||
return self.capabilities.supports_tool_calling
|
||||
|
||||
def is_embedding_model(self) -> bool:
|
||||
"""检查这是否是一个嵌入模型。"""
|
||||
return self.capabilities.is_embedding_model
|
||||
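A self-contained illustration of the capability gate these helpers enable; the enum and dataclass below are reduced stand-ins for ModelModality / ModelCapabilities, not the project's actual types:

from dataclasses import dataclass, field
from enum import Enum

class Modality(Enum):
    TEXT = "text"
    IMAGE = "image"

@dataclass
class Capabilities:
    input_modalities: set[Modality] = field(default_factory=lambda: {Modality.TEXT})

def can_process_images(caps: Capabilities) -> bool:
    return Modality.IMAGE in caps.input_modalities

assert not can_process_images(Capabilities())
assert can_process_images(Capabilities({Modality.TEXT, Modality.IMAGE}))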
|
||||
async def _get_http_client(self) -> LLMHttpClient:
|
||||
"""获取HTTP客户端"""
|
||||
if self.http_client.is_closed:
|
||||
@ -135,24 +201,54 @@ class LLMModel(LLMModelBase):
|
||||
|
||||
return selected_key
|
||||
|
||||
async def _execute_embedding_request(
|
||||
async def _perform_api_call(
|
||||
self,
|
||||
adapter,
|
||||
texts: list[str],
|
||||
task_type: EmbeddingTaskType | str,
|
||||
http_client: LLMHttpClient,
|
||||
prepare_request_func: Callable[[str], Awaitable["RequestData"]],
|
||||
parse_response_func: Callable[[dict[str, Any]], Any],
|
||||
http_client: "LLMHttpClient",
|
||||
failed_keys: set[str] | None = None,
|
||||
) -> list[list[float]]:
|
||||
"""执行单次嵌入请求 - 供重试机制调用"""
|
||||
log_context: str = "API",
|
||||
) -> Any:
|
||||
"""
|
||||
执行API调用的通用核心方法。
|
||||
|
||||
该方法封装了以下通用逻辑:
|
||||
1. 选择API密钥。
|
||||
2. 准备和记录请求。
|
||||
3. 发送HTTP POST请求。
|
||||
4. 处理HTTP错误和API特定错误。
|
||||
5. 记录密钥使用状态。
|
||||
6. 解析成功的响应。
|
||||
|
||||
参数:
|
||||
prepare_request_func: 准备请求的函数。
|
||||
parse_response_func: 解析响应的函数。
|
||||
http_client: HTTP客户端。
|
||||
failed_keys: 失败的密钥集合。
|
||||
log_context: 日志上下文。
|
||||
|
||||
返回:
|
||||
Any: 解析后的响应数据。
|
||||
"""
|
||||
api_key = await self._select_api_key(failed_keys)
|
||||
|
||||
try:
|
||||
request_data = adapter.prepare_embedding_request(
|
||||
model=self,
|
||||
api_key=api_key,
|
||||
texts=texts,
|
||||
task_type=task_type,
|
||||
request_data = await prepare_request_func(api_key)
|
||||
|
||||
logger.info(
|
||||
f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
|
||||
f"[{log_context}]"
|
||||
)
|
||||
logger.debug(f"📡 请求URL: {request_data.url}")
|
||||
masked_key = (
|
||||
f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
|
||||
)
|
||||
logger.debug(f"🔑 API密钥: {masked_key}")
|
||||
logger.debug(f"📋 请求头: {dict(request_data.headers)}")
|
||||
|
||||
sanitized_body = _sanitize_request_body_for_logging(request_data.body)
|
||||
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"📦 请求体: {request_body_str}")
|
||||
|
||||
http_response = await http_client.post(
|
||||
request_data.url,
|
||||
@ -160,121 +256,16 @@ class LLMModel(LLMModelBase):
|
||||
json=request_data.body,
|
||||
)
|
||||
|
||||
if http_response.status_code != 200:
|
||||
error_text = http_response.text
|
||||
logger.error(
|
||||
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
|
||||
)
|
||||
await self.key_store.record_failure(api_key, http_response.status_code)
|
||||
|
||||
error_code = LLMErrorCode.API_REQUEST_FAILED
|
||||
if http_response.status_code in [401, 403]:
|
||||
error_code = LLMErrorCode.API_KEY_INVALID
|
||||
elif http_response.status_code == 429:
|
||||
error_code = LLMErrorCode.API_RATE_LIMITED
|
||||
|
||||
raise LLMException(
|
||||
f"HTTP嵌入请求失败: {http_response.status_code}",
|
||||
code=error_code,
|
||||
details={
|
||||
"status_code": http_response.status_code,
|
||||
"response": error_text,
|
||||
"api_key": api_key,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = http_response.json()
|
||||
adapter.validate_embedding_response(response_json)
|
||||
embeddings = adapter.parse_embedding_response(response_json)
|
||||
except Exception as e:
|
||||
logger.error(f"解析嵌入响应失败: {e}", e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
if isinstance(e, LLMException):
|
||||
raise
|
||||
else:
|
||||
raise LLMException(
|
||||
f"解析API嵌入响应失败: {e}",
|
||||
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
await self.key_store.record_success(api_key)
|
||||
return embeddings
|
||||
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
raise LLMException(
|
||||
f"生成嵌入失败: {e}",
|
||||
code=LLMErrorCode.EMBEDDING_FAILED,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def _execute_with_smart_retry(
|
||||
self,
|
||||
adapter,
|
||||
messages: list[LLMMessage],
|
||||
config: LLMGenerationConfig | None,
|
||||
tools_dict: list[dict[str, Any]] | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
http_client: LLMHttpClient,
|
||||
):
|
||||
"""智能重试机制 - 使用统一的重试装饰器"""
|
||||
ai_config = get_ai_config()
|
||||
max_retries = ai_config.get("max_retries_llm", 3)
|
||||
retry_delay = ai_config.get("retry_delay_llm", 2)
|
||||
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
return await with_smart_retry(
|
||||
self._execute_single_request,
|
||||
adapter,
|
||||
messages,
|
||||
config,
|
||||
tools_dict,
|
||||
tool_choice,
|
||||
http_client,
|
||||
retry_config=retry_config,
|
||||
key_store=self.key_store,
|
||||
provider_name=self.provider_name,
|
||||
)
|
||||
|
||||
async def _execute_single_request(
|
||||
self,
|
||||
adapter,
|
||||
messages: list[LLMMessage],
|
||||
config: LLMGenerationConfig | None,
|
||||
tools_dict: list[dict[str, Any]] | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
http_client: LLMHttpClient,
|
||||
failed_keys: set[str] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
|
||||
api_key = await self._select_api_key(failed_keys)
|
||||
|
||||
try:
|
||||
request_data = adapter.prepare_advanced_request(
|
||||
model=self,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
config=config,
|
||||
tools=tools_dict,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
http_response = await http_client.post(
|
||||
request_data.url,
|
||||
headers=request_data.headers,
|
||||
json=request_data.body,
|
||||
)
|
||||
logger.debug(f"📥 响应状态码: {http_response.status_code}")
|
||||
logger.debug(f"📄 响应头: {dict(http_response.headers)}")
|
||||
|
||||
if http_response.status_code != 200:
|
||||
error_text = http_response.text
|
||||
logger.error(
|
||||
f"HTTP请求失败: {http_response.status_code} - {error_text}"
|
||||
f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
|
||||
f"[{log_context}]"
|
||||
)
|
||||
logger.debug(f"💥 完整错误响应: {error_text}")
|
||||
|
||||
await self.key_store.record_failure(api_key, http_response.status_code)
|
||||
|
||||
@ -299,69 +290,165 @@ class LLMModel(LLMModelBase):
|
||||
|
||||
try:
|
||||
response_json = http_response.json()
|
||||
response_data = adapter.parse_response(
|
||||
model=self,
|
||||
response_json=response_json,
|
||||
is_advanced=True,
|
||||
)
|
||||
|
||||
from .types.models import LLMToolCall
|
||||
|
||||
response_tool_calls = []
|
||||
if response_data.tool_calls:
|
||||
for tc_data in response_data.tool_calls:
|
||||
if isinstance(tc_data, LLMToolCall):
|
||||
response_tool_calls.append(tc_data)
|
||||
elif isinstance(tc_data, dict):
|
||||
try:
|
||||
response_tool_calls.append(LLMToolCall(**tc_data))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
|
||||
f"error: {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"工具调用数据格式未知: {tc_data}")
|
||||
|
||||
llm_response = LLMResponse(
|
||||
text=response_data.text,
|
||||
usage_info=response_data.usage_info,
|
||||
raw_response=response_data.raw_response,
|
||||
tool_calls=response_tool_calls if response_tool_calls else None,
|
||||
code_executions=response_data.code_executions,
|
||||
grounding_metadata=response_data.grounding_metadata,
|
||||
cache_info=response_data.cache_info,
|
||||
response_json_str = json.dumps(
|
||||
response_json, ensure_ascii=False, indent=2
|
||||
)
|
||||
logger.debug(f"📋 响应JSON: {response_json_str}")
|
||||
parsed_data = parse_response_func(response_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析响应失败: {e}", e=e)
|
||||
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
|
||||
if isinstance(e, LLMException):
|
||||
raise
|
||||
else:
|
||||
raise LLMException(
|
||||
f"解析API响应失败: {e}",
|
||||
f"解析API {log_context} 响应失败: {e}",
|
||||
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
await self.key_store.record_success(api_key)
|
||||
|
||||
return llm_response
|
||||
logger.debug(f"✅ API密钥使用成功: {masked_key}")
|
||||
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
|
||||
return parsed_data
|
||||
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"生成响应时发生未预期错误: {e}", e=e)
|
||||
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
|
||||
logger.error(error_log_msg, e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
|
||||
raise LLMException(
|
||||
f"生成响应失败: {e}",
|
||||
code=LLMErrorCode.GENERATION_FAILED,
|
||||
error_log_msg,
|
||||
code=LLMErrorCode.GENERATION_FAILED
|
||||
if log_context == "Generation"
|
||||
else LLMErrorCode.EMBEDDING_FAILED,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def _execute_embedding_request(
|
||||
self,
|
||||
adapter,
|
||||
texts: list[str],
|
||||
task_type: EmbeddingTaskType | str,
|
||||
http_client: LLMHttpClient,
|
||||
failed_keys: set[str] | None = None,
|
||||
) -> list[list[float]]:
|
||||
"""执行单次嵌入请求 - 供重试机制调用"""
|
||||
|
||||
async def prepare_request(api_key: str) -> RequestData:
|
||||
return adapter.prepare_embedding_request(
|
||||
model=self,
|
||||
api_key=api_key,
|
||||
texts=texts,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
|
||||
adapter.validate_embedding_response(response_json)
|
||||
return adapter.parse_embedding_response(response_json)
|
||||
|
||||
return await self._perform_api_call(
|
||||
prepare_request_func=prepare_request,
|
||||
parse_response_func=parse_response,
|
||||
http_client=http_client,
|
||||
failed_keys=failed_keys,
|
||||
log_context="Embedding",
|
||||
)
|
||||
|
||||
async def _execute_with_smart_retry(
|
||||
self,
|
||||
adapter,
|
||||
messages: list[LLMMessage],
|
||||
config: LLMGenerationConfig | None,
|
||||
tools: list[LLMTool] | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
http_client: LLMHttpClient,
|
||||
):
|
||||
"""智能重试机制 - 使用统一的重试装饰器"""
|
||||
ai_config = get_ai_config()
|
||||
max_retries = ai_config.get("max_retries_llm", 3)
|
||||
retry_delay = ai_config.get("retry_delay_llm", 2)
|
||||
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
return await with_smart_retry(
|
||||
self._execute_single_request,
|
||||
adapter,
|
||||
messages,
|
||||
config,
|
||||
tools,
|
||||
tool_choice,
|
||||
http_client,
|
||||
retry_config=retry_config,
|
||||
key_store=self.key_store,
|
||||
provider_name=self.provider_name,
|
||||
)
|
||||
|
||||
async def _execute_single_request(
|
||||
self,
|
||||
adapter,
|
||||
messages: list[LLMMessage],
|
||||
config: LLMGenerationConfig | None,
|
||||
tools: list[LLMTool] | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
http_client: LLMHttpClient,
|
||||
failed_keys: set[str] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
|
||||
|
||||
async def prepare_request(api_key: str) -> RequestData:
|
||||
return await adapter.prepare_advanced_request(
|
||||
model=self,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
config=config,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
def parse_response(response_json: dict[str, Any]) -> LLMResponse:
|
||||
response_data = adapter.parse_response(
|
||||
model=self,
|
||||
response_json=response_json,
|
||||
is_advanced=True,
|
||||
)
|
||||
from .types.models import LLMToolCall
|
||||
|
||||
response_tool_calls = []
|
||||
if response_data.tool_calls:
|
||||
for tc_data in response_data.tool_calls:
|
||||
if isinstance(tc_data, LLMToolCall):
|
||||
response_tool_calls.append(tc_data)
|
||||
elif isinstance(tc_data, dict):
|
||||
try:
|
||||
response_tool_calls.append(LLMToolCall(**tc_data))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
|
||||
f"error: {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"工具调用数据格式未知: {tc_data}")
|
||||
|
||||
return LLMResponse(
|
||||
text=response_data.text,
|
||||
usage_info=response_data.usage_info,
|
||||
raw_response=response_data.raw_response,
|
||||
tool_calls=response_tool_calls if response_tool_calls else None,
|
||||
code_executions=response_data.code_executions,
|
||||
grounding_metadata=response_data.grounding_metadata,
|
||||
cache_info=response_data.cache_info,
|
||||
)
|
||||
|
||||
return await self._perform_api_call(
|
||||
prepare_request_func=prepare_request,
|
||||
parse_response_func=parse_response,
|
||||
http_client=http_client,
|
||||
failed_keys=failed_keys,
|
||||
log_context="Generation",
|
||||
)
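
A minimal sketch of the callback-injection pattern used above: the generation and embedding paths only differ in how they build the request and parse the response, so both hand those two closures to a shared `_perform_api_call`. The helper below assumes a simplified signature (key already selected, no key-store bookkeeping); the real method also records key successes/failures and emits the debug logs shown earlier.

from collections.abc import Awaitable, Callable
from typing import Any


async def perform_api_call(
    prepare_request_func: Callable[[str], Awaitable[Any]],
    parse_response_func: Callable[[dict[str, Any]], Any],
    http_client: Any,
    api_key: str,
    log_context: str,
) -> Any:
    # Build the provider-specific request (URL, headers, JSON body).
    request_data = await prepare_request_func(api_key)
    http_response = await http_client.post(
        request_data.url, headers=request_data.headers, json=request_data.body
    )
    if http_response.status_code != 200:
        raise RuntimeError(f"{log_context} request failed: {http_response.status_code}")
    # "Generation" returns an LLMResponse, "Embedding" returns list[list[float]].
    return parse_response_func(http_response.json())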
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
标记模型实例的当前使用周期结束。
|
||||
@ -400,7 +487,17 @@ class LLMModel(LLMModelBase):
|
||||
history: list[dict[str, str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""生成文本 - 通过 generate_response 实现"""
|
||||
"""
|
||||
生成文本 - 通过 generate_response 实现
|
||||
|
||||
参数:
|
||||
prompt: 输入提示词。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
str: 生成的文本。
|
||||
"""
|
||||
self._check_not_closed()
|
||||
|
||||
messages: list[LLMMessage] = []
|
||||
@ -439,11 +536,21 @@ class LLMModel(LLMModelBase):
|
||||
config: LLMGenerationConfig | None = None,
|
||||
tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
|
||||
max_tool_iterations: int = 5,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""生成高级响应 - 实现完整的工具调用循环"""
|
||||
"""
|
||||
生成高级响应
|
||||
|
||||
参数:
|
||||
messages: 消息列表。
|
||||
config: 生成配置。
|
||||
tools: 工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型响应。
|
||||
"""
|
||||
self._check_not_closed()
|
||||
|
||||
from .adapters import get_adapter_for_api_type
|
||||
@ -468,109 +575,43 @@ class LLMModel(LLMModelBase):
|
||||
merged_dict.update(config.to_dict())
|
||||
final_request_config = LLMGenerationConfig(**merged_dict)
|
||||
|
||||
tools_dict: list[dict[str, Any]] | None = None
|
||||
if tools:
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "model_dump"):
|
||||
model_dump_func = getattr(tool, "model_dump")
|
||||
tools_dict.append(model_dump_func(exclude_none=True))
|
||||
elif isinstance(tool, dict):
|
||||
tools_dict.append(tool)
|
||||
else:
|
||||
try:
|
||||
tools_dict.append(dict(tool))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
|
||||
|
||||
http_client = await self._get_http_client()
|
||||
current_messages = list(messages)
|
||||
|
||||
for iteration in range(max_tool_iterations):
|
||||
logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
|
||||
async with AsyncExitStack() as stack:
|
||||
activated_tools = []
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if tool.type == "mcp" and callable(tool.mcp_session):
|
||||
func_obj = getattr(tool.mcp_session, "func", None)
|
||||
tool_name = (
|
||||
getattr(func_obj, "__name__", "unknown")
|
||||
if func_obj
|
||||
else "unknown"
|
||||
)
|
||||
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
|
||||
|
||||
active_session = await stack.enter_async_context(
|
||||
tool.mcp_session()
|
||||
)
|
||||
|
||||
activated_tools.append(
|
||||
LLMTool.from_mcp_session(
|
||||
session=active_session, annotations=tool.annotations
|
||||
)
|
||||
)
|
||||
else:
|
||||
activated_tools.append(tool)
|
||||
|
||||
llm_response = await self._execute_with_smart_retry(
|
||||
adapter,
|
||||
current_messages,
|
||||
messages,
|
||||
final_request_config,
|
||||
tools_dict if iteration == 0 else None,
|
||||
tool_choice if iteration == 0 else None,
|
||||
activated_tools if activated_tools else None,
|
||||
tool_choice,
|
||||
http_client,
|
||||
)
|
||||
|
||||
response_tool_calls = llm_response.tool_calls or []
|
||||
|
||||
if not response_tool_calls or not tool_executor:
|
||||
logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
|
||||
return llm_response
|
||||
|
||||
logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
|
||||
|
||||
assistant_message_content = llm_response.text if llm_response.text else ""
|
||||
current_messages.append(
|
||||
LLMMessage.assistant_tool_calls(
|
||||
content=assistant_message_content, tool_calls=response_tool_calls
|
||||
)
|
||||
)
|
||||
|
||||
tool_response_messages: list[LLMMessage] = []
|
||||
for tool_call in response_tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_call.function.arguments)
|
||||
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
|
||||
|
||||
tool_result = await tool_executor(tool_name, tool_args_dict)
|
||||
logger.debug(
|
||||
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
|
||||
)
|
||||
|
||||
tool_response_messages.append(
|
||||
LLMMessage.tool_response(
|
||||
tool_call_id=tool_call.id,
|
||||
function_name=tool_name,
|
||||
result=tool_result,
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"工具 '{tool_name}' 参数JSON解析失败: "
|
||||
f"{tool_call.function.arguments}, 错误: {e}"
|
||||
)
|
||||
tool_response_messages.append(
|
||||
LLMMessage.tool_response(
|
||||
tool_call_id=tool_call.id,
|
||||
function_name=tool_name,
|
||||
result={
|
||||
"error": "Argument JSON parsing failed",
|
||||
"details": str(e),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
|
||||
tool_response_messages.append(
|
||||
LLMMessage.tool_response(
|
||||
tool_call_id=tool_call.id,
|
||||
function_name=tool_name,
|
||||
result={
|
||||
"error": "Tool execution failed",
|
||||
"details": str(e),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
current_messages.extend(tool_response_messages)
|
||||
|
||||
logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
|
||||
raise LLMException(
|
||||
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
|
||||
code=LLMErrorCode.GENERATION_FAILED,
|
||||
details={
|
||||
"iterations": max_tool_iterations,
|
||||
"last_messages": current_messages[-2:],
|
||||
},
|
||||
)
|
||||
return llm_response
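
A minimal sketch of a `tool_executor` callable as expected by the loop above: it receives the tool name and the already-parsed argument dict and returns any JSON-serializable result. The tool name and payload here are purely illustrative.

from typing import Any


async def my_tool_executor(tool_name: str, tool_args: dict[str, Any]) -> Any:
    """Dispatch a tool call requested by the model; names below are illustrative."""
    if tool_name == "get_weather":
        city = tool_args.get("city", "unknown")
        return {"city": city, "weather": "sunny"}  # would normally call a real API here
    return {"error": f"unknown tool: {tool_name}"}

# Passed in as:
# response = await model.generate_response(messages, tools=tools, tool_executor=my_tool_executor)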
|
||||
|
||||
async def generate_embeddings(
|
||||
self,
|
||||
@ -578,7 +619,17 @@ class LLMModel(LLMModelBase):
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""生成文本嵌入向量"""
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 文本列表。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 嵌入向量列表。
|
||||
"""
|
||||
self._check_not_closed()
|
||||
if not texts:
|
||||
return []
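
A short usage sketch of the embedding entry point defined above, assuming `model` is an already-initialized, embedding-capable LLMModel instance:

from zhenxun.services.llm.types import EmbeddingTaskType


async def embed_texts(model) -> list[list[float]]:
    # Returns one vector per input text; an empty input list short-circuits to [].
    return await model.generate_embeddings(
        ["你好,世界", "hello world"],
        task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    )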
|
||||
|
||||
7
zhenxun/services/llm/tools/__init__.py
Normal file
7
zhenxun/services/llm/tools/__init__.py
Normal file
@ -0,0 +1,7 @@
"""
工具模块导出
"""

from .registry import tool_registry

__all__ = ["tool_registry"]
|
||||
181
zhenxun/services/llm/tools/registry.py
Normal file
181
zhenxun/services/llm/tools/registry.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""
|
||||
工具注册表
|
||||
|
||||
负责加载、管理和实例化来自配置的工具。
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from ..types import LLMTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config.providers import ToolConfig
|
||||
from ..types.protocols import MCPCompatible
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表,用于管理和实例化配置的工具。"""
|
||||
|
||||
def __init__(self):
|
||||
self._function_tools: dict[str, LLMTool] = {}
|
||||
|
||||
self._mcp_config_models: dict[str, type[BaseModel]] = {}
|
||||
if TYPE_CHECKING:
|
||||
self._mcp_factories: dict[
|
||||
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
|
||||
] = {}
|
||||
else:
|
||||
self._mcp_factories: dict[str, Callable] = {}
|
||||
|
||||
self._tool_configs: dict[str, "ToolConfig"] | None = None
|
||||
self._tool_cache: dict[str, "LLMTool"] = {}
|
||||
|
||||
def _load_configs_if_needed(self):
|
||||
"""如果尚未加载,则从主配置中加载MCP工具定义。"""
|
||||
if self._tool_configs is None:
|
||||
logger.debug("首次访问,正在加载MCP工具配置...")
|
||||
from ..config.providers import get_llm_config
|
||||
|
||||
llm_config = get_llm_config()
|
||||
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
|
||||
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
|
||||
|
||||
def function_tool(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict,
|
||||
required: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
装饰器:在代码中注册一个简单的、无状态的函数工具。
|
||||
|
||||
参数:
|
||||
name: 工具的唯一名称。
|
||||
description: 工具功能的描述。
|
||||
parameters: OpenAPI格式的函数参数schema的properties部分。
|
||||
required: 必需的参数列表。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
if name in self._function_tools or name in self._mcp_factories:
|
||||
logger.warning(f"正在覆盖已注册的工具: {name}")
|
||||
|
||||
tool_definition = LLMTool.create(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
required=required,
|
||||
)
|
||||
self._function_tools[name] = tool_definition
|
||||
logger.info(f"已在代码中注册函数工具: '{name}'")
|
||||
tool_definition.annotations = tool_definition.annotations or {}
|
||||
tool_definition.annotations["executable"] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def mcp_tool(self, name: str, config_model: type[BaseModel]):
|
||||
"""
|
||||
装饰器:注册一个MCP工具及其配置模型。
|
||||
|
||||
参数:
|
||||
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
||||
config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。
|
||||
"""
|
||||
|
||||
def decorator(factory_func: Callable):
|
||||
if name in self._mcp_factories:
|
||||
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
||||
self._mcp_factories[name] = factory_func
|
||||
self._mcp_config_models[name] = config_model
|
||||
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
|
||||
return factory_func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
|
||||
"""根据名称获取MCP工具的配置模型。"""
|
||||
return self._mcp_config_models.get(name)
|
||||
|
||||
def register_mcp_factory(
|
||||
self,
|
||||
name: str,
|
||||
factory: Callable,
|
||||
):
|
||||
"""
|
||||
在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。
|
||||
|
||||
参数:
|
||||
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
||||
factory: 一个返回异步生成器的可调用对象(会话工厂)。
|
||||
"""
|
||||
if name in self._mcp_factories:
|
||||
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
||||
self._mcp_factories[name] = factory
|
||||
logger.info(f"已注册 MCP 会话工厂: '{name}'")
|
||||
|
||||
def get_tool(self, name: str) -> "LLMTool":
|
||||
"""
|
||||
根据名称获取一个 LLMTool 定义。
|
||||
对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
|
||||
而不是一个已激活的会话。
|
||||
"""
|
||||
logger.debug(f"🔍 请求获取工具定义: {name}")
|
||||
|
||||
if name in self._tool_cache:
|
||||
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
|
||||
return self._tool_cache[name]
|
||||
|
||||
if name in self._function_tools:
|
||||
logger.debug(f"🛠️ 获取函数工具定义: {name}")
|
||||
tool = self._function_tools[name]
|
||||
self._tool_cache[name] = tool
|
||||
return tool
|
||||
|
||||
self._load_configs_if_needed()
|
||||
if self._tool_configs is None or name not in self._tool_configs:
|
||||
known_tools = list(self._function_tools.keys()) + (
|
||||
list(self._tool_configs.keys()) if self._tool_configs else []
|
||||
)
|
||||
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
|
||||
logger.debug(f"📋 可用工具定义列表: {known_tools}")
|
||||
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
|
||||
|
||||
config = self._tool_configs[name]
|
||||
tool: "LLMTool"
|
||||
|
||||
if name not in self._mcp_factories:
|
||||
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
|
||||
available_factories = list(self._mcp_factories.keys())
|
||||
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
|
||||
raise ValueError(
|
||||
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
|
||||
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
|
||||
)
|
||||
|
||||
logger.info(f"🔧 创建MCP工具定义: {name}")
|
||||
factory = self._mcp_factories[name]
|
||||
typed_mcp_config = config.mcp_config
|
||||
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
|
||||
|
||||
configured_factory = partial(factory, config=typed_mcp_config)
|
||||
tool = LLMTool.from_mcp_session(session=configured_factory)
|
||||
|
||||
self._tool_cache[name] = tool
|
||||
logger.debug(f"💾 MCP工具定义已缓存: {name}")
|
||||
return tool
|
||||
|
||||
def get_tools(self, names: list[str]) -> list["LLMTool"]:
|
||||
"""根据名称列表获取多个 LLMTool 实例。"""
|
||||
return [self.get_tool(name) for name in names]
|
||||
|
||||
|
||||
tool_registry = ToolRegistry()
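
A hedged usage sketch of the two registration paths defined above. The tool name, parameter schema, and config model are illustrative and not part of the repository.

from pydantic import BaseModel

from zhenxun.services.llm.tools import tool_registry


@tool_registry.function_tool(
    name="echo",  # illustrative tool name
    description="原样返回输入文本",
    parameters={"text": {"type": "string", "description": "要回显的文本"}},
    required=["text"],
)
async def echo(text: str) -> str:
    return text


class MyMCPConfig(BaseModel):  # illustrative config model validated against mcp_config
    endpoint: str


@tool_registry.mcp_tool(name="my_mcp_tool", config_model=MyMCPConfig)
def my_mcp_factory(config: MyMCPConfig):
    # Expected to return an async context manager yielding an MCPCompatible session.
    ...


# Later, LLMTool definitions can be fetched by name:
# tools = tool_registry.get_tools(["echo", "my_mcp_tool"])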
|
||||
@ -4,6 +4,7 @@ LLM 类型定义模块
|
||||
统一导出所有核心类型、协议和异常定义。
|
||||
"""
|
||||
|
||||
from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
|
||||
from .content import (
|
||||
LLMContentPart,
|
||||
LLMMessage,
|
||||
@ -26,6 +27,7 @@ from .models import (
|
||||
ToolMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from .protocols import MCPCompatible
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingTaskType",
|
||||
@ -41,8 +43,11 @@ __all__ = [
|
||||
"LLMTool",
|
||||
"LLMToolCall",
|
||||
"LLMToolFunction",
|
||||
"MCPCompatible",
|
||||
"ModelCapabilities",
|
||||
"ModelDetail",
|
||||
"ModelInfo",
|
||||
"ModelModality",
|
||||
"ModelName",
|
||||
"ModelProvider",
|
||||
"ProviderConfig",
|
||||
@ -50,5 +55,6 @@ __all__ = [
|
||||
"ToolCategory",
|
||||
"ToolMetadata",
|
||||
"UsageInfo",
|
||||
"get_model_capabilities",
|
||||
"get_user_friendly_error_message",
|
||||
]
|
||||
|
||||
128
zhenxun/services/llm/types/capabilities.py
Normal file
128
zhenxun/services/llm/types/capabilities.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
LLM 模型能力定义模块
|
||||
|
||||
定义模型的输入输出模态、工具调用支持等核心能力。
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import fnmatch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelModality(str, Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
class ModelCapabilities(BaseModel):
|
||||
"""定义一个模型的核心、稳定能力。"""
|
||||
|
||||
input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
|
||||
output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
|
||||
supports_tool_calling: bool = False
|
||||
is_embedding_model: bool = False
|
||||
|
||||
|
||||
STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
GEMINI_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={
|
||||
ModelModality.TEXT,
|
||||
ModelModality.IMAGE,
|
||||
ModelModality.AUDIO,
|
||||
ModelModality.VIDEO,
|
||||
},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
|
||||
MODEL_ALIAS_MAPPING: dict[str, str] = {
|
||||
"deepseek-v3*": "deepseek-chat",
|
||||
"deepseek-ai/DeepSeek-V3": "deepseek-chat",
|
||||
"deepseek-r1*": "deepseek-reasoner",
|
||||
}
|
||||
|
||||
|
||||
MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
|
||||
"gemini-*-tts": ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT},
|
||||
output_modalities={ModelModality.AUDIO},
|
||||
),
|
||||
"gemini-*-native-audio-*": ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
|
||||
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
|
||||
supports_tool_calling=True,
|
||||
),
|
||||
"gemini-2.0-flash-preview-image-generation": ModelCapabilities(
|
||||
input_modalities={
|
||||
ModelModality.TEXT,
|
||||
ModelModality.IMAGE,
|
||||
ModelModality.AUDIO,
|
||||
ModelModality.VIDEO,
|
||||
},
|
||||
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
supports_tool_calling=True,
|
||||
),
|
||||
"gemini-embedding-exp": ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT},
|
||||
output_modalities={ModelModality.EMBEDDING},
|
||||
is_embedding_model=True,
|
||||
),
|
||||
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
|
||||
"gemini-2.0-flash*": GEMINI_CAPABILITIES,
|
||||
"gemini-1.5-flash*": GEMINI_CAPABILITIES,
|
||||
"GLM-4V-Flash": ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
supports_tool_calling=True,
|
||||
),
|
||||
"GLM-4V-Plus*": ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
supports_tool_calling=True,
|
||||
),
|
||||
"glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||
"glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||
"doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
|
||||
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
|
||||
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||
}
|
||||
|
||||
|
||||
def get_model_capabilities(model_name: str) -> ModelCapabilities:
|
||||
"""
|
||||
从注册表获取模型能力,支持别名映射和通配符匹配。
|
||||
查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
|
||||
"""
|
||||
canonical_name = model_name
|
||||
for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
|
||||
if fnmatch.fnmatch(model_name, alias_pattern):
|
||||
canonical_name = c_name
|
||||
break
|
||||
|
||||
if canonical_name in MODEL_CAPABILITIES_REGISTRY:
|
||||
return MODEL_CAPABILITIES_REGISTRY[canonical_name]
|
||||
|
||||
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
|
||||
if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
|
||||
return capabilities
|
||||
|
||||
return ModelCapabilities()
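
A short sketch of how the lookup order above plays out (alias, then exact match, then wildcard, then the text-only default); the model names mirror entries already present in the registry.

from zhenxun.services.llm.types import ModelModality, get_model_capabilities

caps = get_model_capabilities("gemini-2.0-flash-001")  # wildcard "gemini-2.0-flash*" matches
assert caps.supports_tool_calling and ModelModality.IMAGE in caps.input_modalities

caps = get_model_capabilities("deepseek-v3-0324")  # alias "deepseek-v3*" -> "deepseek-chat"
assert caps.supports_tool_calling

caps = get_model_capabilities("some-unknown-model")  # falls back to the default capabilities
assert caps.input_modalities == {ModelModality.TEXT} and not caps.supports_tool_calling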
|
||||
@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
|
||||
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
|
||||
return None
|
||||
|
||||
def convert_for_api(self, api_type: str) -> dict[str, Any]:
|
||||
async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
|
||||
"""根据API类型转换多模态内容格式"""
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
|
||||
if self.type == "text":
|
||||
if api_type == "openai":
|
||||
return {"type": "text", "text": self.text}
|
||||
@ -248,20 +250,23 @@ class LLMContentPart(BaseModel):
|
||||
mime_type, data = base64_info
|
||||
return {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||
else:
|
||||
# 如果无法解析 Base64 数据,抛出异常
|
||||
raise ValueError(
|
||||
f"无法解析Base64图像数据: {self.image_source[:50]}..."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
|
||||
)
|
||||
return {
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": self.image_source,
|
||||
elif self.is_image_url():
|
||||
logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
|
||||
try:
|
||||
image_bytes = await AsyncHttpx.get_content(self.image_source)
|
||||
mime_type = self.mime_type or "image/jpeg"
|
||||
base64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return {
|
||||
"inlineData": {"mimeType": mime_type, "data": base64_data}
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载或编码URL图片失败: {e}", e=e)
|
||||
raise ValueError(f"无法处理图片URL: {e}")
|
||||
else:
|
||||
raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
|
||||
else:
|
||||
return {"type": "image_url", "image_url": {"url": self.image_source}}
|
||||
|
||||
|
||||
@ -4,13 +4,25 @@ LLM 数据模型定义
|
||||
包含模型信息、配置、工具定义和响应数据的模型类。
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .enums import ModelProvider, ToolCategory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protocols import MCPCompatible
|
||||
|
||||
MCPSessionType = (
|
||||
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
|
||||
)
|
||||
else:
|
||||
MCPCompatible = object
|
||||
MCPSessionType = Any
|
||||
|
||||
ModelName = str | None
|
||||
|
||||
|
||||
@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
|
||||
class LLMTool(BaseModel):
|
||||
"""LLM 工具定义(支持 MCP 风格)"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
type: str = "function"
|
||||
function: dict[str, Any]
|
||||
function: dict[str, Any] | None = None
|
||||
mcp_session: MCPSessionType = None
|
||||
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
|
||||
|
||||
def model_post_init(self, /, __context: Any) -> None:
|
||||
"""验证工具定义的有效性"""
|
||||
_ = __context
|
||||
if self.type == "function" and self.function is None:
|
||||
raise ValueError("函数类型的工具必须包含 'function' 字段。")
|
||||
if self.type == "mcp" and self.mcp_session is None:
|
||||
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@ -111,7 +134,7 @@ class LLMTool(BaseModel):
|
||||
required: list[str] | None = None,
|
||||
annotations: dict[str, Any] | None = None,
|
||||
) -> "LLMTool":
|
||||
"""创建工具"""
|
||||
"""创建函数工具"""
|
||||
function_def = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
@ -123,6 +146,15 @@ class LLMTool(BaseModel):
|
||||
}
|
||||
return cls(type="function", function=function_def, annotations=annotations)
|
||||
|
||||
@classmethod
|
||||
def from_mcp_session(
|
||||
cls,
|
||||
session: Any,
|
||||
annotations: dict[str, Any] | None = None,
|
||||
) -> "LLMTool":
|
||||
"""从 MCP 会话创建工具"""
|
||||
return cls(type="mcp", mcp_session=session, annotations=annotations)
|
||||
|
||||
|
||||
class LLMCodeExecution(BaseModel):
|
||||
"""代码执行结果"""
|
||||
|
||||
24
zhenxun/services/llm/types/protocols.py
Normal file
24
zhenxun/services/llm/types/protocols.py
Normal file
@ -0,0 +1,24 @@
"""
LLM 模块的协议定义
"""

from typing import Any, Protocol


class MCPCompatible(Protocol):
    """
    一个协议,定义了与LLM模块兼容的MCP会话对象应具备的行为。
    任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible。
    """

    def to_api_tool(self, api_type: str) -> dict[str, Any]:
        """
        将此MCP会话转换为特定LLM提供商API所需的工具格式。

        参数:
            api_type: 目标API的类型 (例如 'gemini', 'openai')。

        返回:
            dict[str, Any]: 一个字典,代表可以在API请求中使用的工具定义。
        """
        ...
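
A minimal sketch of a class satisfying this protocol; the tool payload is illustrative.

from typing import Any


class DummyMCPSession:
    """Illustrative MCPCompatible implementation: any object with to_api_tool qualifies."""

    def to_api_tool(self, api_type: str) -> dict[str, Any]:
        if api_type == "openai":
            return {"type": "function", "function": {"name": "dummy", "parameters": {}}}
        return {"name": "dummy"}  # simplified payload for other providers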
|
||||
@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
|
||||
"""
|
||||
|
||||
import base64
|
||||
import copy
|
||||
from pathlib import Path
|
||||
|
||||
from nonebot.adapters import Message as PlatformMessage
|
||||
from nonebot_plugin_alconna.uniseg import (
|
||||
At,
|
||||
File,
|
||||
@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
|
||||
)
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
|
||||
from .types import LLMContentPart
|
||||
|
||||
@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
|
||||
"""
|
||||
将 UniMessage 实例转换为一个 LLMContentPart 列表。
|
||||
这是处理多模态输入的核心转换逻辑。
|
||||
|
||||
参数:
|
||||
message: 要转换的UniMessage实例。
|
||||
|
||||
返回:
|
||||
list[LLMContentPart]: 转换后的内容部分列表。
|
||||
"""
|
||||
parts: list[LLMContentPart] = []
|
||||
for seg in message:
|
||||
@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
|
||||
if seg.path:
|
||||
part = await LLMContentPart.from_path(seg.path)
|
||||
elif seg.url:
|
||||
logger.warning(
|
||||
f"直接使用 URL 的 {type(seg).__name__} 段,"
|
||||
f"API 可能不支持: {seg.url}"
|
||||
)
|
||||
part = LLMContentPart.text_part(
|
||||
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
|
||||
)
|
||||
elif hasattr(seg, "raw") and seg.raw:
|
||||
try:
|
||||
logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
|
||||
media_bytes = await AsyncHttpx.get_content(seg.url)
|
||||
|
||||
new_seg = copy.copy(seg)
|
||||
new_seg.raw = media_bytes
|
||||
seg = new_seg
|
||||
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
|
||||
except Exception as e:
|
||||
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
|
||||
part = LLMContentPart.text_part(
|
||||
f"[下载媒体失败: {seg.name or seg.url}]"
|
||||
)
|
||||
|
||||
if part:
|
||||
parts.append(part)
|
||||
continue
|
||||
|
||||
if hasattr(seg, "raw") and seg.raw:
|
||||
mime_type = getattr(seg, "mimetype", None)
|
||||
if isinstance(seg.raw, bytes):
|
||||
b64_data = base64.b64encode(seg.raw).decode("utf-8")
|
||||
@ -127,50 +147,19 @@ def create_multimodal_message(
|
||||
audio_mimetypes: list[str] | str | None = None,
|
||||
) -> UniMessage:
|
||||
"""
|
||||
创建多模态消息的便捷函数,方便第三方调用。
|
||||
创建多模态消息的便捷函数
|
||||
|
||||
Args:
|
||||
参数:
|
||||
text: 文本内容
|
||||
images: 图片数据,支持路径、字节数据或URL
|
||||
videos: 视频数据,支持路径、字节数据或URL
|
||||
audios: 音频数据,支持路径、字节数据或URL
|
||||
image_mimetypes: 图片MIME类型,当images为bytes时需要指定
|
||||
video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
|
||||
audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
|
||||
videos: 视频数据
|
||||
audios: 音频数据
|
||||
image_mimetypes: 图片MIME类型,bytes数据时需要指定
|
||||
video_mimetypes: 视频MIME类型,bytes数据时需要指定
|
||||
audio_mimetypes: 音频MIME类型,bytes数据时需要指定
|
||||
|
||||
Returns:
|
||||
返回:
|
||||
UniMessage: 构建好的多模态消息
|
||||
|
||||
Examples:
|
||||
# 纯文本
|
||||
msg = create_multimodal_message("请分析这段文字")
|
||||
|
||||
# 文本 + 单张图片(路径)
|
||||
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
|
||||
|
||||
# 文本 + 多张图片
|
||||
msg = create_multimodal_message(
|
||||
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
|
||||
)
|
||||
|
||||
# 文本 + 图片字节数据
|
||||
msg = create_multimodal_message(
|
||||
"分析", images=image_data, image_mimetypes="image/jpeg"
|
||||
)
|
||||
|
||||
# 文本 + 视频
|
||||
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
|
||||
|
||||
# 文本 + 音频
|
||||
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
|
||||
|
||||
# 混合多模态
|
||||
msg = create_multimodal_message(
|
||||
"分析这些媒体文件",
|
||||
images="/path/to/image.jpg",
|
||||
videos="/path/to/video.mp4",
|
||||
audios="/path/to/audio.wav"
|
||||
)
|
||||
"""
|
||||
message = UniMessage()
|
||||
|
||||
@ -196,7 +185,7 @@ def _add_media_to_message(
|
||||
media_class: type,
|
||||
default_mimetype: str,
|
||||
) -> None:
|
||||
"""添加媒体文件到 UniMessage 的辅助函数"""
|
||||
"""添加媒体文件到 UniMessage"""
|
||||
if not isinstance(media_items, list):
|
||||
media_items = [media_items]
|
||||
|
||||
@ -216,3 +205,80 @@ def _add_media_to_message(
|
||||
elif isinstance(item, bytes):
|
||||
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
|
||||
message.append(media_class(raw=item, mimetype=mimetype))
|
||||
|
||||
|
||||
def message_to_unimessage(message: PlatformMessage) -> UniMessage:
|
||||
"""
|
||||
将平台特定的 Message 对象转换为通用的 UniMessage。
|
||||
主要用于处理引用消息等未被自动转换的消息体。
|
||||
|
||||
参数:
|
||||
message: 平台特定的Message对象。
|
||||
|
||||
返回:
|
||||
UniMessage: 转换后的通用消息对象。
|
||||
"""
|
||||
uni_segments = []
|
||||
for seg in message:
|
||||
if seg.type == "text":
|
||||
uni_segments.append(Text(seg.data.get("text", "")))
|
||||
elif seg.type == "image":
|
||||
uni_segments.append(Image(url=seg.data.get("url")))
|
||||
elif seg.type == "record":
|
||||
uni_segments.append(Voice(url=seg.data.get("url")))
|
||||
elif seg.type == "video":
|
||||
uni_segments.append(Video(url=seg.data.get("url")))
|
||||
elif seg.type == "at":
|
||||
uni_segments.append(At("user", str(seg.data.get("qq", ""))))
|
||||
else:
|
||||
logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
|
||||
|
||||
return UniMessage(uni_segments)
|
||||
|
||||
|
||||
def _sanitize_request_body_for_logging(body: dict) -> dict:
|
||||
"""
|
||||
净化请求体用于日志记录,移除大数据字段并添加摘要信息
|
||||
|
||||
参数:
|
||||
body: 原始请求体字典。
|
||||
|
||||
返回:
|
||||
dict: 净化后的请求体字典。
|
||||
"""
|
||||
try:
|
||||
sanitized_body = copy.deepcopy(body)
|
||||
|
||||
if "contents" in sanitized_body and isinstance(
|
||||
sanitized_body["contents"], list
|
||||
):
|
||||
for content_item in sanitized_body["contents"]:
|
||||
if "parts" in content_item and isinstance(content_item["parts"], list):
|
||||
media_summary = []
|
||||
new_parts = []
|
||||
for part in content_item["parts"]:
|
||||
if "inlineData" in part and isinstance(
|
||||
part["inlineData"], dict
|
||||
):
|
||||
data = part["inlineData"].get("data")
|
||||
if isinstance(data, str):
|
||||
mime_type = part["inlineData"].get(
|
||||
"mimeType", "unknown"
|
||||
)
|
||||
media_summary.append(f"{mime_type} ({len(data)} chars)")
|
||||
continue
|
||||
new_parts.append(part)
|
||||
|
||||
if media_summary:
|
||||
summary_text = (
|
||||
f"[多模态内容: {len(media_summary)}个文件 - "
|
||||
f"{', '.join(media_summary)}]"
|
||||
)
|
||||
new_parts.insert(0, {"text": summary_text})
|
||||
|
||||
content_item["parts"] = new_parts
|
||||
|
||||
return sanitized_body
|
||||
except Exception as e:
|
||||
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
|
||||
return body
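
A quick illustration of the sanitization defined above on a Gemini-style body (the values are made up): the large inlineData blob is dropped and replaced by a text summary so the debug log stays readable.

body = {
    "contents": [
        {
            "parts": [
                {"text": "描述这张图片"},
                {"inlineData": {"mimeType": "image/png", "data": "A" * 100_000}},
            ]
        }
    ]
}
sanitized = _sanitize_request_body_for_logging(body)
# sanitized["contents"][0]["parts"] now begins with a summary part like
# {"text": "[多模态内容: 1个文件 - image/png (100000 chars)]"} followed by the original text part.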
|
||||
|
||||
12
zhenxun/services/scheduler/__init__.py
Normal file
12
zhenxun/services/scheduler/__init__.py
Normal file
@ -0,0 +1,12 @@
"""
定时调度服务模块

提供一个统一的、持久化的定时任务管理器,供所有插件使用。
"""

from .lifecycle import _load_schedules_from_db
from .service import scheduler_manager

_ = _load_schedules_from_db

__all__ = ["scheduler_manager"]
|
||||
102
zhenxun/services/scheduler/adapter.py
Normal file
102
zhenxun/services/scheduler/adapter.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""
|
||||
引擎适配层 (Adapter)
|
||||
|
||||
封装所有对具体调度器引擎 (APScheduler) 的操作,
|
||||
使上层服务与调度器实现解耦。
|
||||
"""
|
||||
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .job import _execute_job
|
||||
|
||||
JOB_PREFIX = "zhenxun_schedule_"
|
||||
|
||||
|
||||
class APSchedulerAdapter:
|
||||
"""封装对 APScheduler 的操作"""
|
||||
|
||||
@staticmethod
|
||||
def _get_job_id(schedule_id: int) -> str:
|
||||
"""生成 APScheduler 的 Job ID"""
|
||||
return f"{JOB_PREFIX}{schedule_id}"
|
||||
|
||||
@staticmethod
|
||||
def add_or_reschedule_job(schedule: ScheduleInfo):
|
||||
"""根据 ScheduleInfo 添加或重新调度一个 APScheduler 任务"""
|
||||
job_id = APSchedulerAdapter._get_job_id(schedule.id)
|
||||
|
||||
if not isinstance(schedule.trigger_config, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
|
||||
f"{type(schedule.trigger_config)}"
|
||||
)
|
||||
return
|
||||
|
||||
job = scheduler.get_job(job_id)
|
||||
if job:
|
||||
scheduler.reschedule_job(
|
||||
job_id, trigger=schedule.trigger_type, **schedule.trigger_config
|
||||
)
|
||||
logger.debug(f"已更新APScheduler任务: {job_id}")
|
||||
else:
|
||||
scheduler.add_job(
|
||||
_execute_job,
|
||||
trigger=schedule.trigger_type,
|
||||
id=job_id,
|
||||
misfire_grace_time=300,
|
||||
args=[schedule.id],
|
||||
**schedule.trigger_config,
|
||||
)
|
||||
logger.debug(f"已添加新的APScheduler任务: {job_id}")
|
||||
|
||||
@staticmethod
|
||||
def remove_job(schedule_id: int):
|
||||
"""移除一个 APScheduler 任务"""
|
||||
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
logger.debug(f"已从APScheduler中移除任务: {job_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def pause_job(schedule_id: int):
|
||||
"""暂停一个 APScheduler 任务"""
|
||||
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||
try:
|
||||
scheduler.pause_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def resume_job(schedule_id: int):
|
||||
"""恢复一个 APScheduler 任务"""
|
||||
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||
try:
|
||||
scheduler.resume_job(job_id)
|
||||
except Exception:
|
||||
import asyncio
|
||||
|
||||
from .repository import ScheduleRepository
|
||||
|
||||
async def _re_add_job():
|
||||
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||
if schedule:
|
||||
APSchedulerAdapter.add_or_reschedule_job(schedule)
|
||||
|
||||
asyncio.create_task(_re_add_job()) # noqa: RUF006
|
||||
|
||||
@staticmethod
|
||||
def get_job_status(schedule_id: int) -> dict:
|
||||
"""获取 APScheduler Job 的状态"""
|
||||
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||
job = scheduler.get_job(job_id)
|
||||
return {
|
||||
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if job and job.next_run_time
|
||||
else "N/A",
|
||||
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
|
||||
}
|
||||
192
zhenxun/services/scheduler/job.py
Normal file
192
zhenxun/services/scheduler/job.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""
|
||||
定时任务的执行逻辑
|
||||
|
||||
包含被 APScheduler 实际调度的函数,以及处理不同目标(单个、所有群组)的执行策略。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import random
|
||||
|
||||
import nonebot
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.decorator.retry import Retry
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
|
||||
|
||||
|
||||
async def _execute_job(schedule_id: int):
|
||||
"""
|
||||
APScheduler 调度的入口函数。
|
||||
根据 schedule_id 处理特定任务、所有群组任务或全局任务。
|
||||
"""
|
||||
from .repository import ScheduleRepository
|
||||
from .service import scheduler_manager
|
||||
|
||||
scheduler_manager._running_tasks.add(schedule_id)
|
||||
try:
|
||||
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
|
||||
return
|
||||
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
task_meta = scheduler_manager._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
logger.error(
|
||||
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
|
||||
)
|
||||
schedule.is_enabled = False
|
||||
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
|
||||
from .adapter import APSchedulerAdapter
|
||||
|
||||
APSchedulerAdapter.remove_job(schedule.id)
|
||||
return
|
||||
|
||||
try:
|
||||
if schedule.bot_id:
|
||||
bot = nonebot.get_bot(schedule.bot_id)
|
||||
else:
|
||||
bot = nonebot.get_bot()
|
||||
logger.debug(
|
||||
f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}"
|
||||
)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
|
||||
f"不在线,本次执行跳过。"
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。")
|
||||
return
|
||||
|
||||
if schedule.group_id == scheduler_manager.ALL_GROUPS:
|
||||
await _execute_for_all_groups(schedule, task_meta, bot)
|
||||
else:
|
||||
await _execute_for_single_target(schedule, task_meta, bot)
|
||||
finally:
|
||||
scheduler_manager._running_tasks.discard(schedule_id)
|
||||
|
||||
|
||||
async def _execute_for_all_groups(schedule: ScheduleInfo, task_meta: dict, bot):
|
||||
"""为所有群组执行任务,并处理优先级覆盖。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
concurrency_limit = Config.get_config(
|
||||
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
|
||||
)
|
||||
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
|
||||
logger.warning(
|
||||
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
|
||||
)
|
||||
concurrency_limit = 5
|
||||
|
||||
logger.info(
|
||||
f"开始执行针对 [所有群组] 的任务 "
|
||||
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id}),"
|
||||
f"并发限制: {concurrency_limit}"
|
||||
)
|
||||
|
||||
all_gids = set()
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
all_gids.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
return
|
||||
|
||||
specific_tasks_gids = set(
|
||||
await ScheduleInfo.filter(
|
||||
plugin_name=plugin_name, group_id__in=list(all_gids)
|
||||
).values_list("group_id", flat=True)
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def worker(gid: str):
|
||||
"""使用 Semaphore 包装单个群组的任务执行"""
|
||||
await asyncio.sleep(random.uniform(0, 59))
|
||||
async with semaphore:
|
||||
temp_schedule = copy.deepcopy(schedule)
|
||||
temp_schedule.group_id = gid
|
||||
await _execute_for_single_target(temp_schedule, task_meta, bot)
|
||||
await asyncio.sleep(random.uniform(0.1, 0.5))
|
||||
|
||||
tasks_to_run = []
|
||||
for gid in all_gids:
|
||||
if gid in specific_tasks_gids:
|
||||
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
|
||||
continue
|
||||
tasks_to_run.append(worker(gid))
|
||||
|
||||
if tasks_to_run:
|
||||
await asyncio.gather(*tasks_to_run)
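
The worker/semaphore pattern above, isolated as a self-contained sketch; the jitter sleeps mirror the ones used here to spread the per-group work over time.

import asyncio
import random


async def run_bounded(items: list[str], limit: int = 5) -> None:
    semaphore = asyncio.Semaphore(limit)

    async def worker(item: str) -> None:
        await asyncio.sleep(random.uniform(0, 1))  # jitter so workers do not start together
        async with semaphore:                      # at most `limit` workers run concurrently
            print(f"processing {item}")

    await asyncio.gather(*(worker(item) for item in items))


# asyncio.run(run_bounded([f"group-{i}" for i in range(20)], limit=5))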
|
||||
|
||||
|
||||
async def _execute_for_single_target(schedule: ScheduleInfo, task_meta: dict, bot):
|
||||
"""为单个目标(具体群组或全局)执行任务。"""
|
||||
|
||||
plugin_name = schedule.plugin_name
|
||||
group_id = schedule.group_id
|
||||
|
||||
try:
|
||||
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
|
||||
if is_blocked:
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
|
||||
"因功能被禁用而跳过执行。"
|
||||
)
|
||||
return
|
||||
|
||||
max_retries = Config.get_config("SchedulerManager", "JOB_MAX_RETRIES", 2)
|
||||
retry_delay = Config.get_config("SchedulerManager", "JOB_RETRY_DELAY", 10)
|
||||
|
||||
@Retry.simple(
|
||||
stop_max_attempt=max_retries + 1,
|
||||
wait_fixed_seconds=retry_delay,
|
||||
log_name=f"定时任务执行:{schedule.plugin_name}",
|
||||
)
|
||||
async def _execute_task_with_retry():
|
||||
task_func = task_meta["func"]
|
||||
job_kwargs = schedule.job_kwargs
|
||||
if not isinstance(job_kwargs, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
|
||||
)
|
||||
return
|
||||
|
||||
sig = inspect.signature(task_func)
|
||||
if "bot" in sig.parameters:
|
||||
job_kwargs["bot"] = bot
|
||||
|
||||
await task_func(group_id, **job_kwargs)
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"插件 '{schedule.plugin_name}' 开始为目标 "
|
||||
f"[{schedule.group_id or '全局'}] 执行定时任务 (ID: {schedule.id})。"
|
||||
)
|
||||
await _execute_task_with_retry()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"执行定时任务 (ID: {schedule.id}, 插件: {schedule.plugin_name}, "
|
||||
f"目标: {schedule.group_id or '全局'}) 在所有重试后最终失败",
|
||||
e=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
|
||||
f"目标: {group_id or '全局'}) 时发生异常",
|
||||
e=e,
|
||||
)
|
||||
62
zhenxun/services/scheduler/lifecycle.py
Normal file
62
zhenxun/services/scheduler/lifecycle.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""
|
||||
定时任务的生命周期管理
|
||||
|
||||
包含在机器人启动时加载和调度数据库中保存的任务的逻辑。
|
||||
"""
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from .adapter import APSchedulerAdapter
|
||||
from .repository import ScheduleRepository
|
||||
from .service import scheduler_manager
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=90)
|
||||
async def _load_schedules_from_db():
|
||||
"""在服务启动时从数据库加载并调度所有任务。"""
|
||||
logger.info("正在从数据库加载并调度所有定时任务...")
|
||||
schedules = await ScheduleRepository.get_all_enabled()
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
if schedule.plugin_name in scheduler_manager._registered_tasks:
|
||||
APSchedulerAdapter.add_or_reschedule_job(schedule)
|
||||
count += 1
|
||||
else:
|
||||
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
|
||||
logger.info(f"数据库定时任务加载完成,共成功加载 {count} 个任务。")
|
||||
|
||||
logger.info("正在检查并注册声明式默认任务...")
|
||||
declared_count = 0
|
||||
for task_info in scheduler_manager._declared_tasks:
|
||||
plugin_name = task_info["plugin_name"]
|
||||
group_id = task_info["group_id"]
|
||||
bot_id = task_info["bot_id"]
|
||||
|
||||
query_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
"bot_id": bot_id,
|
||||
}
|
||||
exists = await ScheduleRepository.exists(**query_kwargs)
|
||||
|
||||
if not exists:
|
||||
logger.info(f"为插件 '{plugin_name}' 注册新的默认定时任务...")
|
||||
schedule = await scheduler_manager.add_schedule(
|
||||
plugin_name=plugin_name,
|
||||
group_id=group_id,
|
||||
trigger_type=task_info["trigger_type"],
|
||||
trigger_config=task_info["trigger_config"],
|
||||
job_kwargs=task_info["job_kwargs"],
|
||||
bot_id=bot_id,
|
||||
)
|
||||
if schedule:
|
||||
declared_count += 1
|
||||
logger.debug(f"默认任务 '{plugin_name}' 注册成功 (ID: {schedule.id})")
|
||||
else:
|
||||
logger.error(f"默认任务 '{plugin_name}' 注册失败")
|
||||
else:
|
||||
logger.debug(f"插件 '{plugin_name}' 的默认任务已存在于数据库中,跳过注册。")
|
||||
|
||||
if declared_count > 0:
|
||||
logger.info(f"声明式任务检查完成,新注册了 {declared_count} 个默认任务。")
|
||||
79
zhenxun/services/scheduler/repository.py
Normal file
79
zhenxun/services/scheduler/repository.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""
|
||||
数据持久层 (Repository)
|
||||
|
||||
封装所有对 ScheduleInfo 模型的数据库操作,将数据访问逻辑与业务逻辑分离。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from tortoise.queryset import QuerySet
|
||||
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
|
||||
|
||||
class ScheduleRepository:
|
||||
"""封装 ScheduleInfo 模型的数据库操作"""
|
||||
|
||||
@staticmethod
|
||||
async def get_by_id(schedule_id: int) -> ScheduleInfo | None:
|
||||
"""通过ID获取任务"""
|
||||
return await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_all_enabled() -> list[ScheduleInfo]:
|
||||
"""获取所有启用的任务"""
|
||||
return await ScheduleInfo.filter(is_enabled=True).all()
|
||||
|
||||
@staticmethod
|
||||
async def get_all(plugin_name: str | None = None) -> list[ScheduleInfo]:
|
||||
"""获取所有任务,可按插件名过滤"""
|
||||
if plugin_name:
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
|
||||
return await ScheduleInfo.all()
|
||||
|
||||
@staticmethod
|
||||
async def save(schedule: ScheduleInfo, update_fields: list[str] | None = None):
|
||||
"""保存任务"""
|
||||
await schedule.save(update_fields=update_fields)
|
||||
|
||||
@staticmethod
|
||||
async def exists(**kwargs: Any) -> bool:
|
||||
"""检查任务是否存在"""
|
||||
return await ScheduleInfo.exists(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def get_by_plugin_and_group(
|
||||
plugin_name: str, group_ids: list[str]
|
||||
) -> list[ScheduleInfo]:
|
||||
"""根据插件和群组ID列表获取任务"""
|
||||
return await ScheduleInfo.filter(
|
||||
plugin_name=plugin_name, group_id__in=group_ids
|
||||
).all()
|
||||
|
||||
@staticmethod
|
||||
async def update_or_create(
|
||||
defaults: dict, **kwargs: Any
|
||||
) -> tuple[ScheduleInfo, bool]:
|
||||
"""更新或创建任务"""
|
||||
return await ScheduleInfo.update_or_create(defaults=defaults, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def query_schedules(**filters: Any) -> list[ScheduleInfo]:
|
||||
"""
|
||||
根据任意条件查询任务列表
|
||||
|
||||
参数:
|
||||
**filters: 过滤条件,如 group_id="123", plugin_name="abc"
|
||||
|
||||
返回:
|
||||
list[ScheduleInfo]: 任务列表
|
||||
"""
|
||||
cleaned_filters = {k: v for k, v in filters.items() if v is not None}
|
||||
if not cleaned_filters:
|
||||
return await ScheduleInfo.all()
|
||||
return await ScheduleInfo.filter(**cleaned_filters).all()
|
||||
|
||||
@staticmethod
|
||||
def filter(**kwargs: Any) -> QuerySet[ScheduleInfo]:
|
||||
"""提供一个通用的过滤查询接口,供Targeter使用"""
|
||||
return ScheduleInfo.filter(**kwargs)
|
||||
448
zhenxun/services/scheduler/service.py
Normal file
448
zhenxun/services/scheduler/service.py
Normal file
@ -0,0 +1,448 @@
|
||||
"""
|
||||
服务层 (Service)
|
||||
|
||||
定义 SchedulerManager 类作为定时任务服务的公共 API 入口。
|
||||
它负责编排业务逻辑,并调用 Repository 和 Adapter 层来完成具体工作。
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import nonebot
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .adapter import APSchedulerAdapter
|
||||
from .job import _execute_job
|
||||
from .repository import ScheduleRepository
|
||||
from .targeter import ScheduleTargeter
|
||||
|
||||
|
||||
class SchedulerManager:
|
||||
ALL_GROUPS: ClassVar[str] = "__ALL_GROUPS__"
|
||||
_registered_tasks: ClassVar[
|
||||
dict[str, dict[str, Callable | type[BaseModel] | None]]
|
||||
] = {}
|
||||
_declared_tasks: ClassVar[list[dict[str, Any]]] = []
|
||||
_running_tasks: ClassVar[set] = set()
|
||||
|
||||
def target(self, **filters: Any) -> ScheduleTargeter:
|
||||
"""
|
||||
创建目标选择器以执行批量操作
|
||||
|
||||
参数:
|
||||
**filters: 过滤条件,支持plugin_name、group_id、bot_id等字段。
|
||||
|
||||
返回:
|
||||
ScheduleTargeter: 目标选择器对象,可用于批量操作。
|
||||
"""
|
||||
return ScheduleTargeter(self, **filters)
|
||||
|
||||
def task(
|
||||
self,
|
||||
trigger: str,
|
||||
group_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
**trigger_kwargs,
|
||||
):
|
||||
"""
|
||||
声明式定时任务装饰器
|
||||
|
||||
参数:
|
||||
trigger: 触发器类型,如'cron'、'interval'等。
|
||||
group_id: 目标群组ID,None表示全局任务。
|
||||
bot_id: 目标Bot ID,None表示使用默认Bot。
|
||||
**trigger_kwargs: 触发器配置参数。
|
||||
|
||||
返回:
|
||||
Callable: 装饰器函数。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||
try:
|
||||
plugin = nonebot.get_plugin_by_module_name(func.__module__)
|
||||
if not plugin:
|
||||
raise ValueError(f"函数 {func.__name__} 不在任何已加载的插件中。")
|
||||
plugin_name = plugin.name
|
||||
|
||||
task_declaration = {
|
||||
"plugin_name": plugin_name,
|
||||
"func": func,
|
||||
"group_id": group_id,
|
||||
"bot_id": bot_id,
|
||||
"trigger_type": trigger,
|
||||
"trigger_config": trigger_kwargs,
|
||||
"job_kwargs": {},
|
||||
}
|
||||
self._declared_tasks.append(task_declaration)
|
||||
logger.debug(
|
||||
f"发现声明式定时任务 '{plugin_name}',将在启动时进行注册。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"注册声明式定时任务失败: {func.__name__}, 错误: {e}")
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def register(
|
||||
self, plugin_name: str, params_model: type[BaseModel] | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
注册可调度的任务函数
|
||||
|
||||
参数:
|
||||
plugin_name: 插件名称,用于标识任务。
|
||||
params_model: 参数验证模型,继承自BaseModel的类。
|
||||
|
||||
返回:
|
||||
Callable: 装饰器函数。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||
if plugin_name in self._registered_tasks:
|
||||
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
|
||||
self._registered_tasks[plugin_name] = {
|
||||
"func": func,
|
||||
"model": params_model,
|
||||
}
|
||||
model_name = params_model.__name__ if params_model else "无"
|
||||
logger.debug(
|
||||
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_registered_plugins(self) -> list[str]:
|
||||
"""
|
||||
获取已注册插件列表
|
||||
|
||||
返回:
|
||||
list[str]: 已注册的插件名称列表。
|
||||
"""
|
||||
return list(self._registered_tasks.keys())
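
A hedged sketch of how a plugin would use the two decorators above; the plugin name, params model, and schedule are illustrative. Note that the declarative `task` decorator only records a default schedule (the plugin name is derived from the defining plugin), so the plugin still needs a task function registered via `register` for the schedule to be created at startup.

from pydantic import BaseModel

from zhenxun.services.scheduler import scheduler_manager


class MorningParams(BaseModel):  # illustrative params model validated by add_schedule
    greeting: str = "早上好"


@scheduler_manager.register("morning_greet", params_model=MorningParams)
async def send_morning_greeting(group_id: str | None, greeting: str, bot=None):
    # Called as task_func(group_id, **job_kwargs); `bot` is injected when the signature asks for it.
    ...


# Declarative default schedule, registered at startup if missing from the database.
@scheduler_manager.task("cron", hour=8, minute=0)
async def default_morning_job(group_id: str | None):
    ...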
|
||||
|
||||
async def add_daily_task(
|
||||
self,
|
||||
plugin_name: str,
|
||||
group_id: str | None,
|
||||
hour: int,
|
||||
minute: int,
|
||||
second: int = 0,
|
||||
job_kwargs: dict | None = None,
|
||||
bot_id: str | None = None,
|
||||
) -> "ScheduleInfo | None":
|
||||
"""
|
||||
添加每日定时任务
|
||||
|
||||
参数:
|
||||
plugin_name: 插件名称。
|
||||
group_id: 目标群组ID,None表示全局任务。
|
||||
hour: 执行小时(0-23)。
|
||||
minute: 执行分钟(0-59)。
|
||||
second: 执行秒数(0-59),默认为0。
|
||||
job_kwargs: 任务参数字典。
|
||||
bot_id: 目标Bot ID,None表示使用默认Bot。
|
||||
|
||||
返回:
|
||||
ScheduleInfo | None: 创建的任务信息,失败时返回None。
|
||||
"""
|
||||
trigger_config = {
|
||||
"hour": hour,
|
||||
"minute": minute,
|
||||
"second": second,
|
||||
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
|
||||
}
|
||||
return await self.add_schedule(
|
||||
plugin_name,
|
||||
group_id,
|
||||
"cron",
|
||||
trigger_config,
|
||||
job_kwargs=job_kwargs,
|
||||
bot_id=bot_id,
|
||||
)
|
||||
|
||||
async def add_interval_task(
|
||||
self,
|
||||
plugin_name: str,
|
||||
group_id: str | None,
|
||||
*,
|
||||
weeks: int = 0,
|
||||
days: int = 0,
|
||||
hours: int = 0,
|
||||
minutes: int = 0,
|
||||
seconds: int = 0,
|
||||
start_date: str | datetime | None = None,
|
||||
job_kwargs: dict | None = None,
|
||||
bot_id: str | None = None,
|
||||
) -> "ScheduleInfo | None":
|
||||
"""添加间隔性定时任务"""
|
||||
trigger_config = {
|
||||
"weeks": weeks,
|
||||
"days": days,
|
||||
"hours": hours,
|
||||
"minutes": minutes,
|
||||
"seconds": seconds,
|
||||
"start_date": start_date,
|
||||
}
|
||||
trigger_config = {k: v for k, v in trigger_config.items() if v}
|
||||
return await self.add_schedule(
|
||||
plugin_name,
|
||||
group_id,
|
||||
"interval",
|
||||
trigger_config,
|
||||
job_kwargs=job_kwargs,
|
||||
bot_id=bot_id,
|
||||
)
|
||||
|
||||
    def _validate_and_prepare_kwargs(
        self, plugin_name: str, job_kwargs: dict | None
    ) -> tuple[bool, str | dict]:
        """验证并准备任务参数,应用默认值"""
        from pydantic import ValidationError

        task_meta = self._registered_tasks.get(plugin_name)
        if not task_meta:
            return False, f"插件 '{plugin_name}' 未注册。"

        params_model = task_meta.get("model")
        job_kwargs = job_kwargs if job_kwargs is not None else {}

        if not params_model:
            if job_kwargs:
                logger.warning(
                    f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
                )
            return True, job_kwargs

        if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
            logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
            return False, f"插件 '{plugin_name}' 的参数模型配置错误"

        try:
            model_validate = getattr(params_model, "model_validate", None)
            if not model_validate:
                return False, f"插件 '{plugin_name}' 的参数模型不支持验证"

            validated_model = model_validate(job_kwargs)

            model_dump = getattr(validated_model, "model_dump", None)
            if not model_dump:
                return False, f"插件 '{plugin_name}' 的参数模型不支持导出"

            return True, model_dump()
        except ValidationError as e:
            errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
            error_str = "\n".join(errors)
            msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
            return False, msg
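A short sketch of how this validation step behaves with the illustrative GreetingParams model: defaults are filled in on success, and per-field errors come back as a message on failure.

ok, prepared = scheduler_manager._validate_and_prepare_kwargs("morning_greeting", {})
# ok is True, prepared == {"message": "早安"}

ok, error_msg = scheduler_manager._validate_and_prepare_kwargs(
    "morning_greeting", {"message": 123}  # wrong type -> ValidationError
)
# ok is False, error_msg lists " - message: ..." per offending field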
    async def add_schedule(
        self,
        plugin_name: str,
        group_id: str | None,
        trigger_type: str,
        trigger_config: dict,
        job_kwargs: dict | None = None,
        bot_id: str | None = None,
    ) -> "ScheduleInfo | None":
        """
        添加定时任务(通用方法)

        参数:
            plugin_name: 插件名称。
            group_id: 目标群组ID,None表示全局任务。
            trigger_type: 触发器类型,如'cron'、'interval'等。
            trigger_config: 触发器配置字典。
            job_kwargs: 任务参数字典。
            bot_id: 目标Bot ID,None表示使用默认Bot。

        返回:
            ScheduleInfo | None: 创建的任务信息,失败时返回None。
        """
        if plugin_name not in self._registered_tasks:
            logger.error(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
            return None

        is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
        if not is_valid:
            logger.error(f"任务参数校验失败: {result}")
            return None

        search_kwargs = {"plugin_name": plugin_name, "group_id": group_id}
        if bot_id and group_id == self.ALL_GROUPS:
            search_kwargs["bot_id"] = bot_id
        else:
            search_kwargs["bot_id__isnull"] = True

        defaults = {
            "trigger_type": trigger_type,
            "trigger_config": trigger_config,
            "job_kwargs": result,
            "is_enabled": True,
        }

        schedule, created = await ScheduleRepository.update_or_create(
            defaults, **search_kwargs
        )
        APSchedulerAdapter.add_or_reschedule_job(schedule)

        action = "设置" if created else "更新"
        logger.info(
            f"已成功{action}插件 '{plugin_name}' 的定时任务 (ID: {schedule.id})。"
        )
        return schedule
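A hedged example of the generic entry point, binding an "all groups" schedule to one specific bot; ALL_GROUPS is the manager's sentinel group id (its literal value is defined elsewhere in this module):

await scheduler_manager.add_schedule(
    "morning_greeting",
    scheduler_manager.ALL_GROUPS,
    "cron",
    {"hour": 8, "minute": 0},
    bot_id="123456",  # only kept in the lookup key for ALL_GROUPS schedules
)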
    async def get_all_schedules(self) -> list[ScheduleInfo]:
        """
        获取所有定时任务信息
        """
        return await self.get_schedules()

    async def get_schedules(
        self,
        plugin_name: str | None = None,
        group_id: str | None = None,
        bot_id: str | None = None,
    ) -> list[ScheduleInfo]:
        """
        根据条件获取定时任务列表

        参数:
            plugin_name: 插件名称,None表示不限制。
            group_id: 群组ID,None表示不限制。
            bot_id: Bot ID,None表示不限制。

        返回:
            list[ScheduleInfo]: 符合条件的任务信息列表。
        """
        return await ScheduleRepository.query_schedules(
            plugin_name=plugin_name, group_id=group_id, bot_id=bot_id
        )

    async def update_schedule(
        self,
        schedule_id: int,
        trigger_type: str | None = None,
        trigger_config: dict | None = None,
        job_kwargs: dict | None = None,
    ) -> tuple[bool, str]:
        """
        更新定时任务配置

        参数:
            schedule_id: 任务ID。
            trigger_type: 新的触发器类型,None表示不更新。
            trigger_config: 新的触发器配置,None表示不更新。
            job_kwargs: 新的任务参数,None表示不更新。

        返回:
            tuple[bool, str]: (是否成功, 结果消息)。
        """
        schedule = await ScheduleRepository.get_by_id(schedule_id)
        if not schedule:
            return False, f"未找到 ID 为 {schedule_id} 的任务。"

        updated_fields = []
        if trigger_config is not None:
            schedule.trigger_config = trigger_config
            updated_fields.append("trigger_config")
        if trigger_type is not None and schedule.trigger_type != trigger_type:
            schedule.trigger_type = trigger_type
            updated_fields.append("trigger_type")

        if job_kwargs is not None:
            existing_kwargs = (
                schedule.job_kwargs.copy()
                if isinstance(schedule.job_kwargs, dict)
                else {}
            )
            existing_kwargs.update(job_kwargs)

            is_valid, result = self._validate_and_prepare_kwargs(
                schedule.plugin_name, existing_kwargs
            )
            if not is_valid:
                return False, str(result)

            assert isinstance(result, dict), "验证成功时 result 应该是字典类型"
            schedule.job_kwargs = result
            updated_fields.append("job_kwargs")

        if not updated_fields:
            return True, "没有任何需要更新的配置。"

        await ScheduleRepository.save(schedule, update_fields=updated_fields)
        APSchedulerAdapter.add_or_reschedule_job(schedule)
        return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
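Updating only part of an existing schedule keeps the other fields; for example, changing just the firing time (the id is a placeholder):

ok, msg = await scheduler_manager.update_schedule(
    schedule_id=42,
    trigger_config={"hour": 9, "minute": 0},
)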
    async def get_schedule_status(self, schedule_id: int) -> dict | None:
        """获取定时任务的详细状态信息"""
        schedule = await ScheduleRepository.get_by_id(schedule_id)
        if not schedule:
            return None

        status_from_scheduler = APSchedulerAdapter.get_job_status(schedule.id)

        status_text = (
            "运行中"
            if schedule_id in self._running_tasks
            else ("启用" if schedule.is_enabled else "暂停")
        )

        return {
            "id": schedule.id,
            "bot_id": schedule.bot_id,
            "plugin_name": schedule.plugin_name,
            "group_id": schedule.group_id,
            "is_enabled": status_text,
            "trigger_type": schedule.trigger_type,
            "trigger_config": schedule.trigger_config,
            "job_kwargs": schedule.job_kwargs,
            **status_from_scheduler,
        }

    async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
        """暂停指定的定时任务"""
        schedule = await ScheduleRepository.get_by_id(schedule_id)
        if not schedule or not schedule.is_enabled:
            return False, "任务不存在或已暂停。"

        schedule.is_enabled = False
        await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
        APSchedulerAdapter.pause_job(schedule_id)
        return True, f"已暂停任务 (ID: {schedule.id})。"

    async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
        """恢复指定的定时任务"""
        schedule = await ScheduleRepository.get_by_id(schedule_id)
        if not schedule or schedule.is_enabled:
            return False, "任务不存在或已启用。"

        schedule.is_enabled = True
        await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
        APSchedulerAdapter.resume_job(schedule_id)
        return True, f"已恢复任务 (ID: {schedule.id})。"

    async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
        """立即手动触发指定的定时任务"""
        schedule = await ScheduleRepository.get_by_id(schedule_id)
        if not schedule:
            return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
        if schedule.plugin_name not in self._registered_tasks:
            return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"

        try:
            await _execute_job(schedule.id)
            return True, f"已手动触发任务 (ID: {schedule.id})。"
        except Exception as e:
            logger.error(f"手动触发任务失败: {e}")
            return False, f"手动触发任务失败: {e}"


scheduler_manager = SchedulerManager()
109
zhenxun/services/scheduler/targeter.py
Normal file
109
zhenxun/services/scheduler/targeter.py
Normal file
@ -0,0 +1,109 @@
"""
目标选择器 (Targeter)

提供链式API,用于构建和执行对多个定时任务的批量操作。
"""

from collections.abc import Callable, Coroutine
from typing import Any

from .adapter import APSchedulerAdapter
from .repository import ScheduleRepository


class ScheduleTargeter:
    """
    一个用于构建和执行定时任务批量操作的目标选择器。
    """

    def __init__(self, manager: Any, **filters: Any):
        """初始化目标选择器"""
        self._manager = manager
        self._filters = {k: v for k, v in filters.items() if v is not None}

    async def _get_schedules(self):
        """根据过滤器获取任务"""
        query = ScheduleRepository.filter(**self._filters)
        return await query.all()

    def _generate_target_description(self) -> str:
        """根据过滤条件生成友好的目标描述"""
        if "id" in self._filters:
            return f"任务 ID {self._filters['id']} 的"

        parts = []
        if "group_id" in self._filters:
            group_id = self._filters["group_id"]
            if group_id == self._manager.ALL_GROUPS:
                parts.append("所有群组中")
            else:
                parts.append(f"群 {group_id} 中")

        if "plugin_name" in self._filters:
            parts.append(f"插件 '{self._filters['plugin_name']}' 的")

        if not parts:
            return "所有"

        return "".join(parts)

    async def _apply_operation(
        self,
        operation_func: Callable[[int], Coroutine[Any, Any, tuple[bool, str]]],
        operation_name: str,
    ) -> tuple[int, str]:
        """通用的操作应用模板"""
        schedules = await self._get_schedules()
        if not schedules:
            target_desc = self._generate_target_description()
            return 0, f"没有找到{target_desc}可供{operation_name}的任务。"

        success_count = 0
        for schedule in schedules:
            success, _ = await operation_func(schedule.id)
            if success:
                success_count += 1

        target_desc = self._generate_target_description()
        return (
            success_count,
            f"成功{operation_name}了{target_desc} {success_count} 个任务。",
        )

    async def pause(self) -> tuple[int, str]:
        """
        暂停匹配的定时任务

        返回:
            tuple[int, str]: (成功暂停的任务数量, 操作结果消息)。
        """
        return await self._apply_operation(self._manager.pause_schedule, "暂停")

    async def resume(self) -> tuple[int, str]:
        """
        恢复匹配的定时任务

        返回:
            tuple[int, str]: (成功恢复的任务数量, 操作结果消息)。
        """
        return await self._apply_operation(self._manager.resume_schedule, "恢复")

    async def remove(self) -> tuple[int, str]:
        """
        移除匹配的定时任务

        返回:
            tuple[int, str]: (成功移除的任务数量, 操作结果消息)。
        """
        schedules = await self._get_schedules()
        if not schedules:
            target_desc = self._generate_target_description()
            return 0, f"没有找到{target_desc}可供移除的任务。"

        for schedule in schedules:
            APSchedulerAdapter.remove_job(schedule.id)

        query = ScheduleRepository.filter(**self._filters)
        count = await query.delete()
        target_desc = self._generate_target_description()
        return count, f"成功移除了{target_desc} {count} 个任务。"
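A hedged sketch of the chained batch API above; the targeter is built from the manager plus filter kwargs, and each operation returns (affected_count, message). Names and ids are placeholders:

targeter = ScheduleTargeter(
    scheduler_manager, plugin_name="morning_greeting", group_id="123456789"
)
paused, msg = await targeter.pause()
resumed, msg = await targeter.resume()
removed, msg = await targeter.remove()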
@ -137,19 +137,13 @@ def get_async_client(
|
||||
|
||||
class AsyncHttpx:
|
||||
"""
|
||||
一个高级的、健壮的异步HTTP客户端工具类。
|
||||
高性能异步HTTP客户端工具类。
|
||||
|
||||
设计理念:
|
||||
- **全局共享客户端**: 默认情况下,所有请求都通过一个在应用启动时初始化的全局
|
||||
`httpx.AsyncClient` 实例发出。这个实例共享连接池,提高了效率和性能。
|
||||
- **向后兼容与灵活性**: 完全兼容旧的API,同时提供了两种方式来处理需要
|
||||
特殊网络配置(如不同代理、超时)的请求:
|
||||
1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时,直接传入 `proxies`,
|
||||
`timeout` 等参数,将为该次请求创建一个临时的、独立的客户端。
|
||||
2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器,可以
|
||||
获取一个独立的、可配置的客户端,用于执行一系列需要相同特殊配置的请求。
|
||||
- **健壮性**: 内置了自动重试、多镜像URL回退(fallback)机制,并提供了便捷的
|
||||
JSON解析和文件下载方法。
|
||||
特性:
|
||||
- 全局共享连接池,提升性能
|
||||
- 支持临时客户端配置(代理、超时等)
|
||||
- 内置重试机制和多URL回退
|
||||
- 提供JSON解析和文件下载功能
|
||||
"""
|
||||
|
||||
CLIENT_KEY: ClassVar[list[str]] = [
|
||||
@ -157,7 +151,6 @@ class AsyncHttpx:
|
||||
"proxies",
|
||||
"proxy",
|
||||
"verify",
|
||||
"headers",
|
||||
]
|
||||
|
||||
default_proxy: ClassVar[dict[str, str] | None] = (
|
||||
@ -290,15 +283,6 @@ class AsyncHttpx:
|
||||
) -> Response:
|
||||
"""发送 GET 请求,并返回第一个成功的响应。
|
||||
|
||||
说明:
|
||||
本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的
|
||||
客户端管理。如果提供 URL 列表,它将依次尝试直到成功为止。
|
||||
|
||||
用法建议:
|
||||
- **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端。
|
||||
- **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)`
|
||||
将为本次请求创建一个独立的临时客户端。
|
||||
|
||||
参数:
|
||||
url: 单个请求 URL 或一个 URL 列表。
|
||||
follow_redirects: 是否跟随重定向。
|
||||
@ -312,7 +296,7 @@ class AsyncHttpx:
|
||||
返回:
|
||||
Response: httpx 的响应对象。
|
||||
|
||||
Raises:
|
||||
异常:
|
||||
AllURIsFailedError: 当所有提供的URL都请求失败时抛出。
|
||||
"""
|
||||
|
||||
@ -373,10 +357,11 @@ class AsyncHttpx:
|
||||
"""
|
||||
[私有] 执行单个HTTP请求并解析JSON,用于内部统一处理。
|
||||
"""
|
||||
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
|
||||
|
||||
async with cls._get_active_client_context(
|
||||
client=client, **kwargs
|
||||
client=client, **client_kwargs
|
||||
) as active_client:
|
||||
_, request_kwargs = cls._split_kwargs(kwargs)
|
||||
response = await active_client.request(method, url, **request_kwargs)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@ -394,11 +379,6 @@ class AsyncHttpx:
|
||||
"""
|
||||
发送GET请求并自动解析为JSON,支持重试和多链接尝试。
|
||||
|
||||
说明:
|
||||
这是一个高度便捷的方法,封装了请求、重试、JSON解析和错误处理。
|
||||
它会在网络错误或JSON解析错误时自动重试。
|
||||
如果所有尝试都失败,它会安全地返回一个默认值。
|
||||
|
||||
参数:
|
||||
url: 单个请求 URL 或一个备用 URL 列表。
|
||||
default: (可选) 当所有尝试都失败时返回的默认值,默认为None。
|
||||
@ -411,7 +391,7 @@ class AsyncHttpx:
|
||||
返回:
|
||||
Any: 解析后的JSON数据,或在失败时返回 `default` 值。
|
||||
|
||||
Raises:
|
||||
异常:
|
||||
AllURIsFailedError: 当 `raise_on_failure` 为 True 且所有URL都请求失败时抛出
|
||||
"""
|
||||
|
||||
@ -490,25 +470,33 @@ class AsyncHttpx:
|
||||
"""
|
||||
执行单个流式下载的私有方法,被重试装饰器包裹。
|
||||
"""
|
||||
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
|
||||
show_progress = request_kwargs.pop("show_progress", False)
|
||||
|
||||
async with cls._get_active_client_context(
|
||||
client=client, **kwargs
|
||||
client=client, **client_kwargs
|
||||
) as active_client:
|
||||
async with active_client.stream("GET", url, **kwargs) as response:
|
||||
async with active_client.stream("GET", url, **request_kwargs) as response:
|
||||
response.raise_for_status()
|
||||
total = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
with Progress(
|
||||
TextColumn(path.name),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
BarColumn(bar_width=None),
|
||||
DownloadColumn(),
|
||||
TransferSpeedColumn(),
|
||||
) as progress:
|
||||
task_id = progress.add_task("Download", total=total)
|
||||
if show_progress:
|
||||
with Progress(
|
||||
TextColumn(path.name),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
BarColumn(bar_width=None),
|
||||
DownloadColumn(),
|
||||
TransferSpeedColumn(),
|
||||
) as progress:
|
||||
task_id = progress.add_task("Download", total=total)
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
await f.write(chunk)
|
||||
progress.update(task_id, advance=len(chunk))
|
||||
else:
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
await f.write(chunk)
|
||||
progress.update(task_id, advance=len(chunk))
|
||||
|
||||
@classmethod
|
||||
async def download_file(
|
||||
@ -517,6 +505,7 @@ class AsyncHttpx:
|
||||
path: str | Path,
|
||||
*,
|
||||
stream: bool = False,
|
||||
show_progress: bool = False,
|
||||
client: AsyncClient | None = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
@ -529,6 +518,7 @@ class AsyncHttpx:
|
||||
url: 单个文件 URL 或一个备用 URL 列表。
|
||||
path: 文件保存的本地路径。
|
||||
stream: (可选) 是否使用流式下载,适用于大文件,默认为 False。
|
||||
show_progress: (可选) 当 stream=True 时,是否显示下载进度条。默认为 False。
|
||||
client: (可选) 指定的HTTP客户端。
|
||||
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数。
|
||||
|
||||
@ -544,7 +534,9 @@ class AsyncHttpx:
|
||||
async with aiofiles.open(path, "wb") as f:
|
||||
await f.write(content)
|
||||
else:
|
||||
await cls._stream_download(current_url, path, **worker_kwargs)
|
||||
await cls._stream_download(
|
||||
current_url, path, show_progress=show_progress, **worker_kwargs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"下载 {current_url} 成功 -> {path.absolute()}",
|
||||
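A hedged example of the show_progress switch introduced above for streamed downloads (URL and path are placeholders):

ok = await AsyncHttpx.download_file(
    "https://example.com/big.zip",
    Path("data/big.zip"),
    stream=True,          # stream the response body to disk
    show_progress=True,   # render the rich progress bar while streaming
)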
@ -573,10 +565,6 @@ class AsyncHttpx:
|
||||
) -> list[bool]:
|
||||
"""并发下载多个文件,支持为每个文件提供备用镜像链接。
|
||||
|
||||
说明:
|
||||
使用 asyncio.Semaphore 来控制并发请求的数量。
|
||||
对于 url_list 中的每个元素,如果它是一个列表,则会依次尝试直到下载成功。
|
||||
|
||||
参数:
|
||||
url_list: 包含所有文件下载任务的列表。每个元素可以是:
|
||||
- 一个字符串 (str): 代表该任务的唯一URL。
|
||||
@ -625,9 +613,6 @@ class AsyncHttpx:
|
||||
async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]:
|
||||
"""测试并返回最快的镜像地址。
|
||||
|
||||
说明:
|
||||
通过并发发送 HEAD 请求来测试每个 URL 的响应时间和可用性,并按响应速度排序。
|
||||
|
||||
参数:
|
||||
url_list: 需要测试的镜像 URL 列表。
|
||||
|
||||
@ -671,23 +656,12 @@ class AsyncHttpx:
|
||||
"""
|
||||
创建一个临时的、可配置的HTTP客户端上下文,并直接返回该客户端实例。
|
||||
|
||||
此方法返回一个标准的 `httpx.AsyncClient`,它不使用全局连接池,
|
||||
拥有独立的配置(如代理、headers、超时等),并在退出上下文后自动关闭。
|
||||
适用于需要用一套特殊网络配置执行一系列请求的场景。
|
||||
|
||||
用法:
|
||||
async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client:
|
||||
# client 是一个标准的 httpx.AsyncClient 实例
|
||||
response1 = await client.get("http://some.internal.api/1")
|
||||
response2 = await client.get("http://some.internal.api/2")
|
||||
data = response2.json()
|
||||
|
||||
参数:
|
||||
**kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数。
|
||||
例如: `proxies`, `headers`, `verify`, `timeout`,
|
||||
`follow_redirects`。
|
||||
|
||||
Yields:
|
||||
返回:
|
||||
httpx.AsyncClient: 一个配置好的、临时的客户端实例。
|
||||
"""
|
||||
async with get_async_client(**kwargs) as client:
|
||||
|
||||
@ -1,810 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
import copy
|
||||
import inspect
|
||||
import random
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot import get_bots
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
|
||||
|
||||
|
||||
class SchedulerManager:
|
||||
"""
|
||||
一个通用的、持久化的定时任务管理器,供所有插件使用。
|
||||
"""
|
||||
|
||||
_registered_tasks: ClassVar[
|
||||
dict[str, dict[str, Callable | type[BaseModel] | None]]
|
||||
] = {}
|
||||
_JOB_PREFIX = "zhenxun_schedule_"
|
||||
_running_tasks: ClassVar[set] = set()
|
||||
|
||||
def register(
|
||||
self, plugin_name: str, params_model: type[BaseModel] | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
注册一个可调度的任务函数。
|
||||
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的唯一名称 (通常是模块名)。
|
||||
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类,
|
||||
用于定义和验证任务函数接受的额外参数。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||
if plugin_name in self._registered_tasks:
|
||||
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
|
||||
self._registered_tasks[plugin_name] = {
|
||||
"func": func,
|
||||
"model": params_model,
|
||||
}
|
||||
model_name = params_model.__name__ if params_model else "无"
|
||||
logger.debug(
|
||||
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_registered_plugins(self) -> list[str]:
|
||||
"""获取所有已注册定时任务的插件列表。"""
|
||||
return list(self._registered_tasks.keys())
|
||||
|
||||
def _get_job_id(self, schedule_id: int) -> str:
|
||||
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
|
||||
return f"{self._JOB_PREFIX}{schedule_id}"
|
||||
|
||||
async def _execute_job(self, schedule_id: int):
|
||||
"""
|
||||
APScheduler 调度的入口函数。
|
||||
根据 schedule_id 处理特定任务、所有群组任务或全局任务。
|
||||
"""
|
||||
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
|
||||
return
|
||||
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
task_meta = self._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
logger.error(
|
||||
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
|
||||
)
|
||||
schedule.is_enabled = False
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
self._remove_aps_job(schedule.id)
|
||||
return
|
||||
|
||||
try:
|
||||
if schedule.bot_id:
|
||||
bot = nonebot.get_bot(schedule.bot_id)
|
||||
else:
|
||||
bot = nonebot.get_bot()
|
||||
logger.debug(
|
||||
f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}"
|
||||
)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
|
||||
f"不在线,本次执行跳过。"
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。")
|
||||
return
|
||||
|
||||
if schedule.group_id == "__ALL_GROUPS__":
|
||||
await self._execute_for_all_groups(schedule, task_meta, bot)
|
||||
else:
|
||||
await self._execute_for_single_target(schedule, task_meta, bot)
|
||||
|
||||
async def _execute_for_all_groups(
|
||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
||||
):
|
||||
"""为所有群组执行任务,并处理优先级覆盖。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
concurrency_limit = Config.get_config(
|
||||
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
|
||||
)
|
||||
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
|
||||
logger.warning(
|
||||
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
|
||||
)
|
||||
concurrency_limit = 5
|
||||
|
||||
logger.info(
|
||||
f"开始执行针对 [所有群组] 的任务 "
|
||||
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id}),"
|
||||
f"并发限制: {concurrency_limit}"
|
||||
)
|
||||
|
||||
all_gids = set()
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
all_gids.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
return
|
||||
|
||||
specific_tasks_gids = set(
|
||||
await ScheduleInfo.filter(
|
||||
plugin_name=plugin_name, group_id__in=list(all_gids)
|
||||
).values_list("group_id", flat=True)
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def worker(gid: str):
|
||||
"""使用 Semaphore 包装单个群组的任务执行"""
|
||||
async with semaphore:
|
||||
temp_schedule = copy.deepcopy(schedule)
|
||||
temp_schedule.group_id = gid
|
||||
await self._execute_for_single_target(temp_schedule, task_meta, bot)
|
||||
await asyncio.sleep(random.uniform(0.1, 0.5))
|
||||
|
||||
tasks_to_run = []
|
||||
for gid in all_gids:
|
||||
if gid in specific_tasks_gids:
|
||||
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
|
||||
continue
|
||||
tasks_to_run.append(worker(gid))
|
||||
|
||||
if tasks_to_run:
|
||||
await asyncio.gather(*tasks_to_run)
|
||||
|
||||
async def _execute_for_single_target(
|
||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
||||
):
|
||||
"""为单个目标(具体群组或全局)执行任务。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
group_id = schedule.group_id
|
||||
|
||||
try:
|
||||
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
|
||||
if is_blocked:
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
|
||||
"因功能被禁用而跳过执行。"
|
||||
)
|
||||
return
|
||||
|
||||
task_func = task_meta["func"]
|
||||
job_kwargs = schedule.job_kwargs
|
||||
if not isinstance(job_kwargs, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
|
||||
)
|
||||
return
|
||||
|
||||
sig = inspect.signature(task_func)
|
||||
if "bot" in sig.parameters:
|
||||
job_kwargs["bot"] = bot
|
||||
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
|
||||
f"执行定时任务 (ID: {schedule.id})。"
|
||||
)
|
||||
task = asyncio.create_task(task_func(group_id, **job_kwargs))
|
||||
self._running_tasks.add(task)
|
||||
task.add_done_callback(self._running_tasks.discard)
|
||||
await task
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
|
||||
f"目标: {group_id or '全局'}) 时发生异常",
|
||||
e=e,
|
||||
)
|
||||
|
||||
def _validate_and_prepare_kwargs(
|
||||
self, plugin_name: str, job_kwargs: dict | None
|
||||
) -> tuple[bool, str | dict]:
|
||||
"""验证并准备任务参数,应用默认值"""
|
||||
task_meta = self._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
return False, f"插件 '{plugin_name}' 未注册。"
|
||||
|
||||
params_model = task_meta.get("model")
|
||||
job_kwargs = job_kwargs if job_kwargs is not None else {}
|
||||
|
||||
if not params_model:
|
||||
if job_kwargs:
|
||||
logger.warning(
|
||||
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
|
||||
)
|
||||
return True, job_kwargs
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
|
||||
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
|
||||
|
||||
validated_model = model_validate(job_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
|
||||
|
||||
return True, model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
return False, msg
|
||||
|
||||
def _add_aps_job(self, schedule: ScheduleInfo):
|
||||
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not isinstance(schedule.trigger_config, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
|
||||
f"{type(schedule.trigger_config)}"
|
||||
)
|
||||
return
|
||||
|
||||
scheduler.add_job(
|
||||
self._execute_job,
|
||||
trigger=schedule.trigger_type,
|
||||
id=job_id,
|
||||
misfire_grace_time=300,
|
||||
args=[schedule.id],
|
||||
**schedule.trigger_config,
|
||||
)
|
||||
logger.debug(
|
||||
f"已在 APScheduler 中添加/更新任务: {job_id} "
|
||||
f"with trigger: {schedule.trigger_config}"
|
||||
)
|
||||
|
||||
def _remove_aps_job(self, schedule_id: int):
|
||||
"""移除一个 APScheduler 任务。"""
|
||||
job_id = self._get_job_id(schedule_id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def add_schedule(
|
||||
self,
|
||||
plugin_name: str,
|
||||
group_id: str | None,
|
||||
trigger_type: str,
|
||||
trigger_config: dict,
|
||||
job_kwargs: dict | None = None,
|
||||
bot_id: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
添加或更新一个定时任务。
|
||||
"""
|
||||
if plugin_name not in self._registered_tasks:
|
||||
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
|
||||
|
||||
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
|
||||
if not is_valid:
|
||||
return False, str(result)
|
||||
|
||||
validated_job_kwargs = result
|
||||
|
||||
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
|
||||
|
||||
search_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
}
|
||||
if effective_bot_id:
|
||||
search_kwargs["bot_id"] = effective_bot_id
|
||||
else:
|
||||
search_kwargs["bot_id__isnull"] = True
|
||||
|
||||
defaults = {
|
||||
"trigger_type": trigger_type,
|
||||
"trigger_config": trigger_config,
|
||||
"job_kwargs": validated_job_kwargs,
|
||||
"is_enabled": True,
|
||||
}
|
||||
|
||||
schedule = await ScheduleInfo.filter(**search_kwargs).first()
|
||||
created = False
|
||||
|
||||
if schedule:
|
||||
for key, value in defaults.items():
|
||||
setattr(schedule, key, value)
|
||||
await schedule.save()
|
||||
else:
|
||||
creation_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
"bot_id": effective_bot_id,
|
||||
**defaults,
|
||||
}
|
||||
schedule = await ScheduleInfo.create(**creation_kwargs)
|
||||
created = True
|
||||
self._add_aps_job(schedule)
|
||||
action = "设置" if created else "更新"
|
||||
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
|
||||
|
||||
async def add_schedule_for_all(
|
||||
self,
|
||||
plugin_name: str,
|
||||
trigger_type: str,
|
||||
trigger_config: dict,
|
||||
job_kwargs: dict | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""为所有机器人所在的群组添加定时任务。"""
|
||||
if plugin_name not in self._registered_tasks:
|
||||
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
|
||||
|
||||
groups = set()
|
||||
for bot in get_bots().values():
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
groups.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
for gid in groups:
|
||||
try:
|
||||
success, _ = await self.add_schedule(
|
||||
plugin_name, gid, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
|
||||
fail_count += 1
|
||||
await asyncio.sleep(0.05)
|
||||
return success_count, fail_count
|
||||
|
||||
async def update_schedule(
|
||||
self,
|
||||
schedule_id: int,
|
||||
trigger_type: str | None = None,
|
||||
trigger_config: dict | None = None,
|
||||
job_kwargs: dict | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""部分更新一个已存在的定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的任务。"
|
||||
|
||||
updated_fields = []
|
||||
if trigger_config is not None:
|
||||
schedule.trigger_config = trigger_config
|
||||
updated_fields.append("trigger_config")
|
||||
|
||||
if trigger_type is not None and schedule.trigger_type != trigger_type:
|
||||
schedule.trigger_type = trigger_type
|
||||
updated_fields.append("trigger_type")
|
||||
|
||||
if job_kwargs is not None:
|
||||
if not isinstance(schedule.job_kwargs, dict):
|
||||
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
|
||||
|
||||
merged_kwargs = schedule.job_kwargs.copy()
|
||||
merged_kwargs.update(job_kwargs)
|
||||
|
||||
is_valid, result = self._validate_and_prepare_kwargs(
|
||||
schedule.plugin_name, merged_kwargs
|
||||
)
|
||||
if not is_valid:
|
||||
return False, str(result)
|
||||
|
||||
schedule.job_kwargs = result # type: ignore
|
||||
updated_fields.append("job_kwargs")
|
||||
|
||||
if not updated_fields:
|
||||
return True, "没有任何需要更新的配置。"
|
||||
|
||||
await schedule.save(update_fields=updated_fields)
|
||||
self._add_aps_job(schedule)
|
||||
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
|
||||
|
||||
async def remove_schedule(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""移除指定插件和群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
msg = (
|
||||
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
|
||||
f"的插件 '{plugin_name}' 定时任务。"
|
||||
)
|
||||
return (False, msg)
|
||||
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
msg = (
|
||||
f"已取消 Bot {bot_id} 在 [{target_desc}] "
|
||||
f"的插件 '{plugin_name}' 所有定时任务。"
|
||||
)
|
||||
return (True, msg)
|
||||
|
||||
async def remove_schedule_for_all(
|
||||
self, plugin_name: str, bot_id: str | None = None
|
||||
) -> int:
|
||||
"""移除指定插件在所有群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules_to_delete = await ScheduleInfo.filter(**query).all()
|
||||
if not schedules_to_delete:
|
||||
return 0
|
||||
|
||||
for schedule in schedules_to_delete:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return len(schedules_to_delete)
|
||||
|
||||
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
|
||||
"""移除指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id)
|
||||
if not schedules:
|
||||
return False, f"群 {group_id} 没有任何定时任务。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
||||
"""暂停指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
||||
"""恢复指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
||||
"""暂停指定插件在所有群组的定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return (
|
||||
count,
|
||||
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
||||
"""恢复指定插件在所有群组的定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return (
|
||||
count,
|
||||
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def pause_schedule_by_plugin_group(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""暂停指定插件在指定群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
return (
|
||||
False,
|
||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
|
||||
)
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def resume_schedule_by_plugin_group(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""恢复指定插件在指定群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
return (
|
||||
False,
|
||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
|
||||
)
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def remove_all_schedules(self) -> tuple[int, str]:
|
||||
"""移除所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.all()
|
||||
if not schedules:
|
||||
return 0, "当前没有任何定时任务。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功移除所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def pause_all_schedules(self) -> tuple[int, str]:
|
||||
"""暂停所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, "当前没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def resume_all_schedules(self) -> tuple[int, str]:
|
||||
"""恢复所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, "当前没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""通过ID移除指定的定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
|
||||
"""通过ID获取定时任务信息。"""
|
||||
return await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
|
||||
async def get_schedules(
|
||||
self, plugin_name: str, group_id: str | None
|
||||
) -> list[ScheduleInfo]:
|
||||
"""获取特定群组特定插件的所有定时任务。"""
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
|
||||
|
||||
async def get_schedule(
|
||||
self, plugin_name: str, group_id: str | None
|
||||
) -> ScheduleInfo | None:
|
||||
"""获取特定群组的定时任务信息。"""
|
||||
return await ScheduleInfo.get_or_none(
|
||||
plugin_name=plugin_name, group_id=group_id
|
||||
)
|
||||
|
||||
async def get_all_schedules(
|
||||
self, plugin_name: str | None = None
|
||||
) -> list[ScheduleInfo]:
|
||||
"""获取所有定时任务信息,可按插件名过滤。"""
|
||||
if plugin_name:
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
|
||||
return await ScheduleInfo.all()
|
||||
|
||||
async def get_schedule_status(self, schedule_id: int) -> dict | None:
|
||||
"""获取任务的详细状态。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return None
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
job = scheduler.get_job(job_id)
|
||||
|
||||
status = {
|
||||
"id": schedule.id,
|
||||
"bot_id": schedule.bot_id,
|
||||
"plugin_name": schedule.plugin_name,
|
||||
"group_id": schedule.group_id,
|
||||
"is_enabled": schedule.is_enabled,
|
||||
"trigger_type": schedule.trigger_type,
|
||||
"trigger_config": schedule.trigger_config,
|
||||
"job_kwargs": schedule.job_kwargs,
|
||||
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if job and job.next_run_time
|
||||
else "N/A",
|
||||
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
|
||||
}
|
||||
return status
|
||||
|
||||
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""暂停一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
return False, "任务不存在或已暂停。"
|
||||
|
||||
schedule.is_enabled = False
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.pause_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""恢复一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule or schedule.is_enabled:
|
||||
return False, "任务不存在或已启用。"
|
||||
|
||||
schedule.is_enabled = True
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.resume_job(job_id)
|
||||
except Exception:
|
||||
self._add_aps_job(schedule)
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""手动触发一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||
|
||||
if schedule.plugin_name not in self._registered_tasks:
|
||||
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
|
||||
|
||||
try:
|
||||
await self._execute_job(schedule.id)
|
||||
return (
|
||||
True,
|
||||
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"手动触发任务失败: {e}")
|
||||
return False, f"手动触发任务失败: {e}"
|
||||
|
||||
|
||||
scheduler_manager = SchedulerManager()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=90)
|
||||
async def _load_schedules_from_db():
|
||||
"""在服务启动时从数据库加载并调度所有任务。"""
|
||||
Config.add_plugin_config(
|
||||
"SchedulerManager",
|
||||
SCHEDULE_CONCURRENCY_KEY,
|
||||
5,
|
||||
help="“所有群组”类型定时任务的并发执行数量限制",
|
||||
type=int,
|
||||
)
|
||||
|
||||
logger.info("正在从数据库加载并调度所有定时任务...")
|
||||
schedules = await ScheduleInfo.filter(is_enabled=True).all()
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
if schedule.plugin_name in scheduler_manager._registered_tasks:
|
||||
scheduler_manager._add_aps_job(schedule)
|
||||
count += 1
|
||||
else:
|
||||
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
|
||||
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")
|
||||
@ -31,7 +31,9 @@ class VirtualEnvPackageManager:
|
||||
def __get_command(cls) -> list[str]:
|
||||
if path := Config.get_config("virtualenv", "python_path"):
|
||||
return [path, "-m", "pip"]
|
||||
return cls.WIN_COMMAND if BAT_FILE.exists() else cls.DEFAULT_COMMAND
|
||||
return (
|
||||
cls.WIN_COMMAND.copy() if BAT_FILE.exists() else cls.DEFAULT_COMMAND.copy()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def install(cls, package: list[str] | str):
|
||||
@ -57,8 +59,10 @@ class VirtualEnvPackageManager:
|
||||
f"安装虚拟环境包指令执行完成: {result.stdout}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
return result.stdout
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"安装虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND)
|
||||
return e.stderr
|
||||
|
||||
@classmethod
|
||||
def uninstall(cls, package: list[str] | str):
|
||||
@ -72,6 +76,7 @@ class VirtualEnvPackageManager:
|
||||
try:
|
||||
command = cls.__get_command()
|
||||
command.append("uninstall")
|
||||
command.append("-y")
|
||||
command.append(" ".join(package))
|
||||
logger.info(f"执行虚拟环境卸载包指令: {command}", LOG_COMMAND)
|
||||
result = subprocess.run(
|
||||
@ -84,8 +89,10 @@ class VirtualEnvPackageManager:
|
||||
f"卸载虚拟环境包指令执行完成: {result.stdout}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
return result.stdout
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"卸载虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND)
|
||||
return e.stderr
|
||||
|
||||
@classmethod
|
||||
def update(cls, package: list[str] | str):
|
||||
@ -109,8 +116,10 @@ class VirtualEnvPackageManager:
|
||||
text=True,
|
||||
)
|
||||
logger.debug(f"更新虚拟环境包指令执行完成: {result.stdout}", LOG_COMMAND)
|
||||
return result.stdout
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"更新虚拟环境包指令执行失败: {e.stderr}.", LOG_COMMAND)
|
||||
return e.stderr
|
||||
|
||||
@classmethod
|
||||
def install_requirement(cls, requirement_file: Path):
|
||||
@ -140,11 +149,13 @@ class VirtualEnvPackageManager:
|
||||
f"安装虚拟环境依赖文件指令执行完成: {result.stdout}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
return result.stdout
|
||||
except CalledProcessError as e:
|
||||
logger.error(
|
||||
f"安装虚拟环境依赖文件指令执行失败: {e.stderr}.",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
return e.stderr
|
||||
|
||||
@classmethod
|
||||
def list(cls) -> str:
|
||||
|
||||
@ -80,14 +80,14 @@ class PlatformUtils:
|
||||
@classmethod
|
||||
async def send_superuser(
|
||||
cls,
|
||||
bot: Bot,
|
||||
bot: Bot | None,
|
||||
message: UniMessage | str,
|
||||
superuser_id: str | None = None,
|
||||
) -> list[tuple[str, Receipt]]:
|
||||
"""发送消息给超级用户
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
bot: Bot,没有传入时使用get_bot随机获取
|
||||
message: 消息
|
||||
superuser_id: 指定超级用户id.
|
||||
|
||||
@ -97,6 +97,8 @@ class PlatformUtils:
|
||||
返回:
|
||||
Receipt | None: Receipt
|
||||
"""
|
||||
if not bot:
|
||||
bot = nonebot.get_bot()
|
||||
superuser_ids = []
|
||||
if superuser_id:
|
||||
superuser_ids.append(superuser_id)
|
||||
|
||||
@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from datetime import date, datetime
import os
from pathlib import Path
import time
@ -277,3 +277,20 @@ def is_number(text: str) -> bool:
        return True
    except ValueError:
        return False


class TimeUtils:
    @classmethod
    def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
        """获取某天的0点时间

        返回:
            datetime: 某天的0点时间
        """
        if not target_date:
            target_date = datetime.now()
        return (
            target_date.replace(hour=0, minute=0, second=0, microsecond=0)
            if isinstance(target_date, datetime)
            else datetime.combine(target_date, datetime.min.time())
        )
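A quick hedged illustration of the new helper:

from datetime import date

TimeUtils.get_day_start()                  # today at 00:00:00
TimeUtils.get_day_start(date(2025, 1, 1))  # datetime(2025, 1, 1, 0, 0)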