mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
Merge branch 'main' into feature/fix-config
This commit is contained in:
commit
1992e478b5
13
.gitignore
vendored
13
.gitignore
vendored
@ -139,22 +139,9 @@ dmypy.json
|
|||||||
# Cython debug symbols
|
# Cython debug symbols
|
||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
demo.py
|
|
||||||
test.py
|
|
||||||
server_ip.py
|
|
||||||
member_activity_handle.py
|
|
||||||
Yu-Gi-Oh/
|
|
||||||
csgo/
|
|
||||||
fantasy_card/
|
|
||||||
data/
|
data/
|
||||||
log/
|
log/
|
||||||
backup/
|
backup/
|
||||||
extensive_plugin/
|
|
||||||
test/
|
|
||||||
bot.py
|
|
||||||
.idea/
|
.idea/
|
||||||
resources/
|
resources/
|
||||||
/configs/config.py
|
|
||||||
configs/config.yaml
|
|
||||||
.vscode/launch.json
|
.vscode/launch.json
|
||||||
plugins_/
|
|
||||||
1889
data/anime.json
1889
data/anime.json
File diff suppressed because it is too large
Load Diff
@ -116,6 +116,7 @@ async def app(app: App, tmp_path: Path, mocker: MockerFixture):
|
|||||||
await init()
|
await init()
|
||||||
# await driver._lifespan.startup()
|
# await driver._lifespan.startup()
|
||||||
os.environ["AIOCACHE_DISABLE"] = "1"
|
os.environ["AIOCACHE_DISABLE"] = "1"
|
||||||
|
os.environ["PYTEST_CURRENT_TEST"] = "1"
|
||||||
|
|
||||||
yield app
|
yield app
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
from nonebot.plugin import PluginMetadata
|
from nonebot.plugin import PluginMetadata
|
||||||
from nonebot_plugin_alconna import AlconnaQuery, Arparma, Match, Query
|
from nonebot_plugin_alconna import AlconnaQuery, Arparma, Match, Query
|
||||||
from nonebot_plugin_session import EventSession
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||||
@ -9,7 +9,7 @@ from zhenxun.services.log import logger
|
|||||||
from zhenxun.utils.enum import BlockType, PluginType
|
from zhenxun.utils.enum import BlockType, PluginType
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
|
||||||
from ._data_source import PluginManage, build_plugin, build_task, delete_help_image
|
from ._data_source import PluginManager, build_plugin, build_task, delete_help_image
|
||||||
from .command import _group_status_matcher, _status_matcher
|
from .command import _group_status_matcher, _status_matcher
|
||||||
|
|
||||||
base_config = Config.get("plugin_switch")
|
base_config = Config.get("plugin_switch")
|
||||||
@ -57,6 +57,11 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
关闭群被动早晚安
|
关闭群被动早晚安
|
||||||
关闭群被动早晚安 -g 12355555
|
关闭群被动早晚安 -g 12355555
|
||||||
|
|
||||||
|
开启/关闭默认群被动 [被动名称]
|
||||||
|
私聊下: 开启/关闭群被动默认状态
|
||||||
|
示例:
|
||||||
|
关闭默认群被动 早晚安
|
||||||
|
|
||||||
开启/关闭所有群被动 ?[-g [group_id]]
|
开启/关闭所有群被动 ?[-g [group_id]]
|
||||||
私聊中: 开启/关闭全局或指定群组被动状态
|
私聊中: 开启/关闭全局或指定群组被动状态
|
||||||
示例:
|
示例:
|
||||||
@ -87,10 +92,10 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
@_status_matcher.assign("$main")
|
@_status_matcher.assign("$main")
|
||||||
async def _(
|
async def _(
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
session: EventSession,
|
session: Uninfo,
|
||||||
arparma: Arparma,
|
arparma: Arparma,
|
||||||
):
|
):
|
||||||
if session.id1 in bot.config.superusers:
|
if session.user.id in bot.config.superusers:
|
||||||
image = await build_plugin()
|
image = await build_plugin()
|
||||||
logger.info(
|
logger.info(
|
||||||
"查看功能列表",
|
"查看功能列表",
|
||||||
@ -105,7 +110,7 @@ async def _(
|
|||||||
@_status_matcher.assign("open")
|
@_status_matcher.assign("open")
|
||||||
async def _(
|
async def _(
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
session: EventSession,
|
session: Uninfo,
|
||||||
arparma: Arparma,
|
arparma: Arparma,
|
||||||
plugin_name: Match[str],
|
plugin_name: Match[str],
|
||||||
group: Match[str],
|
group: Match[str],
|
||||||
@ -114,22 +119,23 @@ async def _(
|
|||||||
all: Query[bool] = AlconnaQuery("all.value", False),
|
all: Query[bool] = AlconnaQuery("all.value", False),
|
||||||
):
|
):
|
||||||
if not all.result and not plugin_name.available:
|
if not all.result and not plugin_name.available:
|
||||||
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
|
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
|
||||||
name = plugin_name.result
|
name = plugin_name.result
|
||||||
if gid := session.id3 or session.id2:
|
if session.group:
|
||||||
|
group_id = session.group.id
|
||||||
"""修改当前群组的数据"""
|
"""修改当前群组的数据"""
|
||||||
if task.result:
|
if task.result:
|
||||||
if all.result:
|
if all.result:
|
||||||
result = await PluginManage.unblock_group_all_task(gid)
|
result = await PluginManager.unblock_group_all_task(group_id)
|
||||||
logger.info("开启所有群组被动", arparma.header_result, session=session)
|
logger.info("开启所有群组被动", arparma.header_result, session=session)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.unblock_group_task(name, gid)
|
result = await PluginManager.unblock_group_task(name, group_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"开启群组被动 {name}", arparma.header_result, session=session
|
f"开启群组被动 {name}", arparma.header_result, session=session
|
||||||
)
|
)
|
||||||
elif session.id1 in bot.config.superusers and default_status.result:
|
elif session.user.id in bot.config.superusers and default_status.result:
|
||||||
"""单个插件的进群默认修改"""
|
"""单个插件的进群默认修改"""
|
||||||
result = await PluginManage.set_default_status(name, True)
|
result = await PluginManager.set_default_status(name, True)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户开启 {name} 功能进群默认开关",
|
f"超级用户开启 {name} 功能进群默认开关",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -137,8 +143,8 @@ async def _(
|
|||||||
)
|
)
|
||||||
elif all.result:
|
elif all.result:
|
||||||
"""所有插件"""
|
"""所有插件"""
|
||||||
result = await PluginManage.set_all_plugin_status(
|
result = await PluginManager.set_all_plugin_status(
|
||||||
True, default_status.result, gid
|
True, default_status.result, group_id
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"开启群组中全部功能",
|
"开启群组中全部功能",
|
||||||
@ -146,22 +152,24 @@ async def _(
|
|||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.unblock_group_plugin(name, gid)
|
result = await PluginManager.unblock_group_plugin(name, group_id)
|
||||||
logger.info(f"开启功能 {name}", arparma.header_result, session=session)
|
logger.info(f"开启功能 {name}", arparma.header_result, session=session)
|
||||||
delete_help_image(gid)
|
delete_help_image(group_id)
|
||||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||||
elif session.id1 in bot.config.superusers:
|
elif session.user.id in bot.config.superusers:
|
||||||
"""私聊"""
|
"""私聊"""
|
||||||
group_id = group.result if group.available else None
|
group_id = group.result if group.available else None
|
||||||
if all.result:
|
if all.result:
|
||||||
if task.result:
|
if task.result:
|
||||||
"""关闭全局或指定群全部被动"""
|
"""关闭全局或指定群全部被动"""
|
||||||
if group_id:
|
if group_id:
|
||||||
result = await PluginManage.unblock_group_all_task(group_id)
|
result = await PluginManager.unblock_group_all_task(group_id)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.unblock_global_all_task()
|
result = await PluginManager.unblock_global_all_task(
|
||||||
|
default_status.result
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.set_all_plugin_status(
|
result = await PluginManager.set_all_plugin_status(
|
||||||
True, default_status.result, group_id
|
True, default_status.result, group_id
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -171,8 +179,8 @@ async def _(
|
|||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||||
if default_status.result:
|
if default_status.result and not task.result:
|
||||||
result = await PluginManage.set_default_status(name, True)
|
result = await PluginManager.set_default_status(name, True)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户开启 {name} 功能进群默认开关",
|
f"超级用户开启 {name} 功能进群默认开关",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -186,7 +194,7 @@ async def _(
|
|||||||
name = split_list[0]
|
name = split_list[0]
|
||||||
group_id = split_list[1]
|
group_id = split_list[1]
|
||||||
if group_id:
|
if group_id:
|
||||||
result = await PluginManage.superuser_task_handle(name, group_id, True)
|
result = await PluginManager.superuser_task_handle(name, group_id, True)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户开启被动技能 {name}",
|
f"超级用户开启被动技能 {name}",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -194,14 +202,16 @@ async def _(
|
|||||||
target=group_id,
|
target=group_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.unblock_global_task(name)
|
result = await PluginManager.unblock_global_task(
|
||||||
|
name, default_status.result
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户开启全局被动技能 {name}",
|
f"超级用户开启全局被动技能 {name}",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.superuser_unblock(name, None, group_id)
|
result = await PluginManager.superuser_unblock(name, None, group_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户开启功能 {name}",
|
f"超级用户开启功能 {name}",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -215,7 +225,7 @@ async def _(
|
|||||||
@_status_matcher.assign("close")
|
@_status_matcher.assign("close")
|
||||||
async def _(
|
async def _(
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
session: EventSession,
|
session: Uninfo,
|
||||||
arparma: Arparma,
|
arparma: Arparma,
|
||||||
plugin_name: Match[str],
|
plugin_name: Match[str],
|
||||||
block_type: Match[str],
|
block_type: Match[str],
|
||||||
@ -225,22 +235,23 @@ async def _(
|
|||||||
all: Query[bool] = AlconnaQuery("all.value", False),
|
all: Query[bool] = AlconnaQuery("all.value", False),
|
||||||
):
|
):
|
||||||
if not all.result and not plugin_name.available:
|
if not all.result and not plugin_name.available:
|
||||||
await MessageUtils.build_message("请输入功能名称").finish(reply_to=True)
|
await MessageUtils.build_message("请输入功能/被动名称").finish(reply_to=True)
|
||||||
name = plugin_name.result
|
name = plugin_name.result
|
||||||
if gid := session.id3 or session.id2:
|
if session.group:
|
||||||
|
group_id = session.group.id
|
||||||
"""修改当前群组的数据"""
|
"""修改当前群组的数据"""
|
||||||
if task.result:
|
if task.result:
|
||||||
if all.result:
|
if all.result:
|
||||||
result = await PluginManage.block_group_all_task(gid)
|
result = await PluginManager.block_group_all_task(group_id)
|
||||||
logger.info("开启所有群组被动", arparma.header_result, session=session)
|
logger.info("开启所有群组被动", arparma.header_result, session=session)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.block_group_task(name, gid)
|
result = await PluginManager.block_group_task(name, group_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"关闭群组被动 {name}", arparma.header_result, session=session
|
f"关闭群组被动 {name}", arparma.header_result, session=session
|
||||||
)
|
)
|
||||||
elif session.id1 in bot.config.superusers and default_status.result:
|
elif session.user.id in bot.config.superusers and default_status.result:
|
||||||
"""单个插件的进群默认修改"""
|
"""单个插件的进群默认修改"""
|
||||||
result = await PluginManage.set_default_status(name, False)
|
result = await PluginManager.set_default_status(name, False)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户开启 {name} 功能进群默认开关",
|
f"超级用户开启 {name} 功能进群默认开关",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -248,26 +259,28 @@ async def _(
|
|||||||
)
|
)
|
||||||
elif all.result:
|
elif all.result:
|
||||||
"""所有插件"""
|
"""所有插件"""
|
||||||
result = await PluginManage.set_all_plugin_status(
|
result = await PluginManager.set_all_plugin_status(
|
||||||
False, default_status.result, gid
|
False, default_status.result, group_id
|
||||||
)
|
)
|
||||||
logger.info("关闭群组中全部功能", arparma.header_result, session=session)
|
logger.info("关闭群组中全部功能", arparma.header_result, session=session)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.block_group_plugin(name, gid)
|
result = await PluginManager.block_group_plugin(name, group_id)
|
||||||
logger.info(f"关闭功能 {name}", arparma.header_result, session=session)
|
logger.info(f"关闭功能 {name}", arparma.header_result, session=session)
|
||||||
delete_help_image(gid)
|
delete_help_image(group_id)
|
||||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||||
elif session.id1 in bot.config.superusers:
|
elif session.user.id in bot.config.superusers:
|
||||||
group_id = group.result if group.available else None
|
group_id = group.result if group.available else None
|
||||||
if all.result:
|
if all.result:
|
||||||
if task.result:
|
if task.result:
|
||||||
"""关闭全局或指定群全部被动"""
|
"""关闭全局或指定群全部被动"""
|
||||||
if group_id:
|
if group_id:
|
||||||
result = await PluginManage.block_group_all_task(group_id)
|
result = await PluginManager.block_group_all_task(group_id)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.block_global_all_task()
|
result = await PluginManager.block_global_all_task(
|
||||||
|
default_status.result
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.set_all_plugin_status(
|
result = await PluginManager.set_all_plugin_status(
|
||||||
False, default_status.result, group_id
|
False, default_status.result, group_id
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -277,8 +290,8 @@ async def _(
|
|||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||||
if default_status.result:
|
if default_status.result and not task.result:
|
||||||
result = await PluginManage.set_default_status(name, False)
|
result = await PluginManager.set_default_status(name, False)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户关闭 {name} 功能进群默认开关",
|
f"超级用户关闭 {name} 功能进群默认开关",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -292,7 +305,9 @@ async def _(
|
|||||||
name = split_list[0]
|
name = split_list[0]
|
||||||
group_id = split_list[1]
|
group_id = split_list[1]
|
||||||
if group_id:
|
if group_id:
|
||||||
result = await PluginManage.superuser_task_handle(name, group_id, False)
|
result = await PluginManager.superuser_task_handle(
|
||||||
|
name, group_id, False
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户关闭被动技能 {name}",
|
f"超级用户关闭被动技能 {name}",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -300,7 +315,9 @@ async def _(
|
|||||||
target=group_id,
|
target=group_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await PluginManage.block_global_task(name)
|
result = await PluginManager.block_global_task(
|
||||||
|
name, default_status.result
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户关闭全局被动技能 {name}",
|
f"超级用户关闭全局被动技能 {name}",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -314,7 +331,7 @@ async def _(
|
|||||||
elif block_type.result in ["g", "group"]:
|
elif block_type.result in ["g", "group"]:
|
||||||
if block_type.available:
|
if block_type.available:
|
||||||
_type = BlockType.GROUP
|
_type = BlockType.GROUP
|
||||||
result = await PluginManage.superuser_block(name, _type, group_id)
|
result = await PluginManager.superuser_block(name, _type, group_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"超级用户关闭功能 {name}, 禁用类型: {_type}",
|
f"超级用户关闭功能 {name}, 禁用类型: {_type}",
|
||||||
arparma.header_result,
|
arparma.header_result,
|
||||||
@ -327,19 +344,20 @@ async def _(
|
|||||||
|
|
||||||
@_group_status_matcher.handle()
|
@_group_status_matcher.handle()
|
||||||
async def _(
|
async def _(
|
||||||
session: EventSession,
|
session: Uninfo,
|
||||||
arparma: Arparma,
|
arparma: Arparma,
|
||||||
status: str,
|
status: str,
|
||||||
):
|
):
|
||||||
if gid := session.id3 or session.id2:
|
if session.group:
|
||||||
|
group_id = session.group.id
|
||||||
if status == "sleep":
|
if status == "sleep":
|
||||||
await PluginManage.sleep(gid)
|
await PluginManager.sleep(group_id)
|
||||||
logger.info("进行休眠", arparma.header_result, session=session)
|
logger.info("进行休眠", arparma.header_result, session=session)
|
||||||
await MessageUtils.build_message("那我先睡觉了...").finish()
|
await MessageUtils.build_message("那我先睡觉了...").finish()
|
||||||
else:
|
else:
|
||||||
if await PluginManage.is_wake(gid):
|
if await PluginManager.is_wake(group_id):
|
||||||
await MessageUtils.build_message("我还醒着呢!").finish()
|
await MessageUtils.build_message("我还醒着呢!").finish()
|
||||||
await PluginManage.wake(gid)
|
await PluginManager.wake(group_id)
|
||||||
logger.info("醒来", arparma.header_result, session=session)
|
logger.info("醒来", arparma.header_result, session=session)
|
||||||
await MessageUtils.build_message("呜..醒来了...").finish()
|
await MessageUtils.build_message("呜..醒来了...").finish()
|
||||||
return MessageUtils.build_message("群组id为空...").send()
|
return MessageUtils.build_message("群组id为空...").send()
|
||||||
@ -347,10 +365,10 @@ async def _(
|
|||||||
|
|
||||||
@_status_matcher.assign("task")
|
@_status_matcher.assign("task")
|
||||||
async def _(
|
async def _(
|
||||||
session: EventSession,
|
session: Uninfo,
|
||||||
arparma: Arparma,
|
arparma: Arparma,
|
||||||
):
|
):
|
||||||
image = await build_task(session.id3 or session.id2)
|
image = await build_task(session.group.id if session.group else None)
|
||||||
if image:
|
if image:
|
||||||
logger.info("查看群被动列表", arparma.header_result, session=session)
|
logger.info("查看群被动列表", arparma.header_result, session=session)
|
||||||
await MessageUtils.build_message(image).finish(reply_to=True)
|
await MessageUtils.build_message(image).finish(reply_to=True)
|
||||||
|
|||||||
@ -155,7 +155,7 @@ async def build_task(group_id: str | None) -> BuildImage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PluginManage:
|
class PluginManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_default_status(cls, plugin_name: str, status: bool) -> str:
|
async def set_default_status(cls, plugin_name: str, status: bool) -> str:
|
||||||
"""设置插件进群默认状态
|
"""设置插件进群默认状态
|
||||||
@ -342,17 +342,21 @@ class PluginManage:
|
|||||||
return await cls._change_group_task("", group_id, True, True)
|
return await cls._change_group_task("", group_id, True, True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def block_global_all_task(cls) -> str:
|
async def block_global_all_task(cls, is_default: bool) -> str:
|
||||||
"""禁用全局被动技能
|
"""禁用全局被动技能
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
str: 返回信息
|
str: 返回信息
|
||||||
"""
|
"""
|
||||||
|
if is_default:
|
||||||
|
await TaskInfo.all().update(default_status=False)
|
||||||
|
return "已禁用所有被动进群默认状态"
|
||||||
|
else:
|
||||||
await TaskInfo.all().update(status=False)
|
await TaskInfo.all().update(status=False)
|
||||||
return "已全局禁用所有被动状态"
|
return "已全局禁用所有被动状态"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def block_global_task(cls, name: str) -> str:
|
async def block_global_task(cls, name: str, is_default: bool = False) -> str:
|
||||||
"""禁用全局被动技能
|
"""禁用全局被动技能
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@ -361,29 +365,45 @@ class PluginManage:
|
|||||||
返回:
|
返回:
|
||||||
str: 返回信息
|
str: 返回信息
|
||||||
"""
|
"""
|
||||||
|
if is_default:
|
||||||
|
await TaskInfo.filter(name=name).update(default_status=False)
|
||||||
|
return f"已禁用被动进群默认状态 {name}"
|
||||||
|
else:
|
||||||
await TaskInfo.filter(name=name).update(status=False)
|
await TaskInfo.filter(name=name).update(status=False)
|
||||||
return f"已全局禁用被动状态 {name}"
|
return f"已全局禁用被动状态 {name}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def unblock_global_all_task(cls) -> str:
|
async def unblock_global_all_task(cls, is_default: bool) -> str:
|
||||||
"""开启全局被动技能
|
"""开启全局被动技能
|
||||||
|
|
||||||
|
参数:
|
||||||
|
is_default: 是否为默认状态
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
str: 返回信息
|
str: 返回信息
|
||||||
"""
|
"""
|
||||||
|
if is_default:
|
||||||
|
await TaskInfo.all().update(default_status=True)
|
||||||
|
return "已开启所有被动进群默认状态"
|
||||||
|
else:
|
||||||
await TaskInfo.all().update(status=True)
|
await TaskInfo.all().update(status=True)
|
||||||
return "已全局开启所有被动状态"
|
return "已全局开启所有被动状态"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def unblock_global_task(cls, name: str) -> str:
|
async def unblock_global_task(cls, name: str, is_default: bool = False) -> str:
|
||||||
"""开启全局被动技能
|
"""开启全局被动技能
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
name: 被动技能名称
|
name: 被动技能名称
|
||||||
|
is_default: 是否为默认状态
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
str: 返回信息
|
str: 返回信息
|
||||||
"""
|
"""
|
||||||
|
if is_default:
|
||||||
|
await TaskInfo.filter(name=name).update(default_status=True)
|
||||||
|
return f"已开启被动进群默认状态 {name}"
|
||||||
|
else:
|
||||||
await TaskInfo.filter(name=name).update(status=True)
|
await TaskInfo.filter(name=name).update(status=True)
|
||||||
return f"已全局开启被动状态 {name}"
|
return f"已全局开启被动状态 {name}"
|
||||||
|
|
||||||
|
|||||||
@ -58,6 +58,19 @@ _status_matcher.shortcut(
|
|||||||
prefix=True,
|
prefix=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_status_matcher.shortcut(
|
||||||
|
r"开启(所有|全部)默认群被动",
|
||||||
|
command="switch",
|
||||||
|
arguments=["open", "--task", "--all", "-df"],
|
||||||
|
prefix=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_status_matcher.shortcut(
|
||||||
|
r"关闭(所有|全部)默认群被动",
|
||||||
|
command="switch",
|
||||||
|
arguments=["close", "--task", "--all", "-df"],
|
||||||
|
prefix=True,
|
||||||
|
)
|
||||||
|
|
||||||
_status_matcher.shortcut(
|
_status_matcher.shortcut(
|
||||||
r"开启群被动\s*(?P<name>.+)",
|
r"开启群被动\s*(?P<name>.+)",
|
||||||
@ -73,6 +86,20 @@ _status_matcher.shortcut(
|
|||||||
prefix=True,
|
prefix=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_status_matcher.shortcut(
|
||||||
|
r"开启默认群被动\s*(?P<name>.+)",
|
||||||
|
command="switch",
|
||||||
|
arguments=["open", "{name}", "--task", "-df"],
|
||||||
|
prefix=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_status_matcher.shortcut(
|
||||||
|
r"关闭默认群被动\s*(?P<name>.+)",
|
||||||
|
command="switch",
|
||||||
|
arguments=["close", "{name}", "--task", "-df"],
|
||||||
|
prefix=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_status_matcher.shortcut(
|
_status_matcher.shortcut(
|
||||||
r"开启(所有|全部)群被动",
|
r"开启(所有|全部)群被动",
|
||||||
|
|||||||
@ -54,22 +54,6 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
default_value=5,
|
default_value=5,
|
||||||
type=int,
|
type=int,
|
||||||
),
|
),
|
||||||
RegisterConfig(
|
|
||||||
module="_task",
|
|
||||||
key="DEFAULT_GROUP_WELCOME",
|
|
||||||
value=True,
|
|
||||||
help="被动 进群欢迎 进群默认开关状态",
|
|
||||||
default_value=True,
|
|
||||||
type=bool,
|
|
||||||
),
|
|
||||||
RegisterConfig(
|
|
||||||
module="_task",
|
|
||||||
key="DEFAULT_REFUND_GROUP_REMIND",
|
|
||||||
value=True,
|
|
||||||
help="被动 退群提醒 进群默认开关状态",
|
|
||||||
default_value=True,
|
|
||||||
type=bool,
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
tasks=[
|
tasks=[
|
||||||
Task(
|
Task(
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
from nonebot.plugin import PluginMetadata
|
from nonebot.plugin import PluginMetadata
|
||||||
|
|
||||||
from zhenxun.configs.utils import PluginExtraData
|
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import PluginType
|
||||||
|
|
||||||
from . import command # noqa: F401
|
from . import commands, handlers
|
||||||
|
|
||||||
|
__all__ = ["commands", "handlers"]
|
||||||
|
|
||||||
__plugin_meta__ = PluginMetadata(
|
__plugin_meta__ = PluginMetadata(
|
||||||
name="定时任务管理",
|
name="定时任务管理",
|
||||||
@ -27,6 +29,8 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
|
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||||
定时任务 执行 <任务ID>
|
定时任务 执行 <任务ID>
|
||||||
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
|
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
|
||||||
|
# [修改] 增加说明
|
||||||
|
• 说明: -p 选项可单独使用,用于操作指定插件的所有任务
|
||||||
|
|
||||||
📝 时间选项 (三选一):
|
📝 时间选项 (三选一):
|
||||||
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
|
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
|
||||||
@ -47,5 +51,35 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
version="0.1.2",
|
version="0.1.2",
|
||||||
plugin_type=PluginType.SUPERUSER,
|
plugin_type=PluginType.SUPERUSER,
|
||||||
is_show=False,
|
is_show=False,
|
||||||
|
configs=[
|
||||||
|
RegisterConfig(
|
||||||
|
module="SchedulerManager",
|
||||||
|
key="ALL_GROUPS_CONCURRENCY_LIMIT",
|
||||||
|
value=5,
|
||||||
|
help="“所有群组”类型定时任务的并发执行数量限制",
|
||||||
|
type=int,
|
||||||
|
),
|
||||||
|
RegisterConfig(
|
||||||
|
module="SchedulerManager",
|
||||||
|
key="JOB_MAX_RETRIES",
|
||||||
|
value=2,
|
||||||
|
help="定时任务执行失败时的最大重试次数",
|
||||||
|
type=int,
|
||||||
|
),
|
||||||
|
RegisterConfig(
|
||||||
|
module="SchedulerManager",
|
||||||
|
key="JOB_RETRY_DELAY",
|
||||||
|
value=10,
|
||||||
|
help="定时任务执行重试的间隔时间(秒)",
|
||||||
|
type=int,
|
||||||
|
),
|
||||||
|
RegisterConfig(
|
||||||
|
module="SchedulerManager",
|
||||||
|
key="SCHEDULER_TIMEZONE",
|
||||||
|
value="Asia/Shanghai",
|
||||||
|
help="定时任务使用的时区,默认为 Asia/Shanghai",
|
||||||
|
type=str,
|
||||||
|
),
|
||||||
|
],
|
||||||
).to_dict(),
|
).to_dict(),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,836 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from datetime import datetime
|
|
||||||
import re
|
|
||||||
|
|
||||||
from nonebot.adapters import Event
|
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
|
||||||
from nonebot.params import Depends
|
|
||||||
from nonebot.permission import SUPERUSER
|
|
||||||
from nonebot_plugin_alconna import (
|
|
||||||
Alconna,
|
|
||||||
AlconnaMatch,
|
|
||||||
Args,
|
|
||||||
Arparma,
|
|
||||||
Match,
|
|
||||||
Option,
|
|
||||||
Query,
|
|
||||||
Subcommand,
|
|
||||||
on_alconna,
|
|
||||||
)
|
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
|
|
||||||
from zhenxun.utils._image_template import ImageTemplate
|
|
||||||
from zhenxun.utils.manager.schedule_manager import scheduler_manager
|
|
||||||
|
|
||||||
|
|
||||||
def _get_type_name(annotation) -> str:
|
|
||||||
"""获取类型注解的名称"""
|
|
||||||
if hasattr(annotation, "__name__"):
|
|
||||||
return annotation.__name__
|
|
||||||
elif hasattr(annotation, "_name"):
|
|
||||||
return annotation._name
|
|
||||||
else:
|
|
||||||
return str(annotation)
|
|
||||||
|
|
||||||
|
|
||||||
from zhenxun.utils.message import MessageUtils
|
|
||||||
from zhenxun.utils.rules import admin_check
|
|
||||||
|
|
||||||
|
|
||||||
def _format_trigger(schedule_status: dict) -> str:
|
|
||||||
"""将触发器配置格式化为人类可读的字符串"""
|
|
||||||
trigger_type = schedule_status["trigger_type"]
|
|
||||||
config = schedule_status["trigger_config"]
|
|
||||||
|
|
||||||
if trigger_type == "cron":
|
|
||||||
minute = config.get("minute", "*")
|
|
||||||
hour = config.get("hour", "*")
|
|
||||||
day = config.get("day", "*")
|
|
||||||
month = config.get("month", "*")
|
|
||||||
day_of_week = config.get("day_of_week", "*")
|
|
||||||
|
|
||||||
if day == "*" and month == "*" and day_of_week == "*":
|
|
||||||
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
|
|
||||||
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
|
|
||||||
return f"每天 {formatted_hour}:{formatted_minute}"
|
|
||||||
else:
|
|
||||||
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
|
|
||||||
elif trigger_type == "interval":
|
|
||||||
seconds = config.get("seconds", 0)
|
|
||||||
minutes = config.get("minutes", 0)
|
|
||||||
hours = config.get("hours", 0)
|
|
||||||
days = config.get("days", 0)
|
|
||||||
if days:
|
|
||||||
trigger_str = f"每 {days} 天"
|
|
||||||
elif hours:
|
|
||||||
trigger_str = f"每 {hours} 小时"
|
|
||||||
elif minutes:
|
|
||||||
trigger_str = f"每 {minutes} 分钟"
|
|
||||||
else:
|
|
||||||
trigger_str = f"每 {seconds} 秒"
|
|
||||||
elif trigger_type == "date":
|
|
||||||
run_date = config.get("run_date", "未知时间")
|
|
||||||
trigger_str = f"在 {run_date}"
|
|
||||||
else:
|
|
||||||
trigger_str = f"{trigger_type}: {config}"
|
|
||||||
|
|
||||||
return trigger_str
|
|
||||||
|
|
||||||
|
|
||||||
def _format_params(schedule_status: dict) -> str:
|
|
||||||
"""将任务参数格式化为人类可读的字符串"""
|
|
||||||
if kwargs := schedule_status.get("job_kwargs"):
|
|
||||||
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
|
||||||
return kwargs_str
|
|
||||||
return "-"
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_interval(interval_str: str) -> dict:
|
|
||||||
"""增强版解析器,支持 d(天)"""
|
|
||||||
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
|
||||||
if not match:
|
|
||||||
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
|
||||||
|
|
||||||
value, unit = int(match.group(1)), match.group(2)
|
|
||||||
if unit == "s":
|
|
||||||
return {"seconds": value}
|
|
||||||
if unit == "m":
|
|
||||||
return {"minutes": value}
|
|
||||||
if unit == "h":
|
|
||||||
return {"hours": value}
|
|
||||||
if unit == "d":
|
|
||||||
return {"days": value}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_daily_time(time_str: str) -> dict:
|
|
||||||
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
|
|
||||||
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
|
||||||
hour, minute, second = match.groups()
|
|
||||||
hour, minute = int(hour), int(minute)
|
|
||||||
|
|
||||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
|
||||||
raise ValueError("小时或分钟数值超出范围。")
|
|
||||||
|
|
||||||
cron_config = {
|
|
||||||
"minute": str(minute),
|
|
||||||
"hour": str(hour),
|
|
||||||
"day": "*",
|
|
||||||
"month": "*",
|
|
||||||
"day_of_week": "*",
|
|
||||||
}
|
|
||||||
if second is not None:
|
|
||||||
if not (0 <= int(second) <= 59):
|
|
||||||
raise ValueError("秒数值超出范围。")
|
|
||||||
cron_config["second"] = str(second)
|
|
||||||
|
|
||||||
return cron_config
|
|
||||||
else:
|
|
||||||
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
|
|
||||||
|
|
||||||
|
|
||||||
async def GetBotId(
|
|
||||||
bot: Bot,
|
|
||||||
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
|
|
||||||
) -> str:
|
|
||||||
"""获取要操作的Bot ID"""
|
|
||||||
if bot_id_match.available:
|
|
||||||
return bot_id_match.result
|
|
||||||
return bot.self_id
|
|
||||||
|
|
||||||
|
|
||||||
class ScheduleTarget:
|
|
||||||
"""定时任务操作目标的基类"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TargetByID(ScheduleTarget):
|
|
||||||
"""按任务ID操作"""
|
|
||||||
|
|
||||||
def __init__(self, id: int):
|
|
||||||
self.id = id
|
|
||||||
|
|
||||||
|
|
||||||
class TargetByPlugin(ScheduleTarget):
|
|
||||||
"""按插件名操作"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
|
||||||
):
|
|
||||||
self.plugin = plugin
|
|
||||||
self.group_id = group_id
|
|
||||||
self.all_groups = all_groups
|
|
||||||
|
|
||||||
|
|
||||||
class TargetAll(ScheduleTarget):
|
|
||||||
"""操作所有任务"""
|
|
||||||
|
|
||||||
def __init__(self, for_group: str | None = None):
|
|
||||||
self.for_group = for_group
|
|
||||||
|
|
||||||
|
|
||||||
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
|
||||||
|
|
||||||
|
|
||||||
def create_target_parser(subcommand_name: str):
|
|
||||||
"""
|
|
||||||
创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def dependency(
|
|
||||||
event: Event,
|
|
||||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
|
||||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
|
||||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
|
||||||
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
|
||||||
) -> TargetScope:
|
|
||||||
if schedule_id.available:
|
|
||||||
return TargetByID(schedule_id.result)
|
|
||||||
|
|
||||||
if plugin_name.available:
|
|
||||||
p_name = plugin_name.result
|
|
||||||
if all_enabled.available:
|
|
||||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
|
||||||
elif group_id.available:
|
|
||||||
gid = group_id.result
|
|
||||||
if gid.lower() == "all":
|
|
||||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
|
||||||
return TargetByPlugin(plugin=p_name, group_id=gid)
|
|
||||||
else:
|
|
||||||
current_group_id = getattr(event, "group_id", None)
|
|
||||||
if current_group_id:
|
|
||||||
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
|
|
||||||
)
|
|
||||||
|
|
||||||
if all_enabled.available:
|
|
||||||
return TargetAll(for_group=group_id.result if group_id.available else None)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
return dependency
|
|
||||||
|
|
||||||
|
|
||||||
schedule_cmd = on_alconna(
|
|
||||||
Alconna(
|
|
||||||
"定时任务",
|
|
||||||
Subcommand(
|
|
||||||
"查看",
|
|
||||||
Option("-g", Args["target_group_id", str]),
|
|
||||||
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
|
||||||
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
|
||||||
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
|
||||||
alias=["ls", "list"],
|
|
||||||
help_text="查看定时任务",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"设置",
|
|
||||||
Args["plugin_name", str],
|
|
||||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
|
||||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
|
||||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
|
||||||
Option(
|
|
||||||
"--daily",
|
|
||||||
Args["daily_expr", str],
|
|
||||||
help_text="设置每天执行的时间 (如 08:20)",
|
|
||||||
),
|
|
||||||
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
|
||||||
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
|
||||||
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
|
||||||
Option(
|
|
||||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
|
||||||
),
|
|
||||||
alias=["add", "开启"],
|
|
||||||
help_text="设置/开启一个定时任务",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"删除",
|
|
||||||
Args["schedule_id?", int],
|
|
||||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
|
||||||
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
|
||||||
Option("-all", help_text="对所有群生效"),
|
|
||||||
Option(
|
|
||||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
|
||||||
),
|
|
||||||
alias=["del", "rm", "remove", "关闭", "取消"],
|
|
||||||
help_text="删除一个或多个定时任务",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"暂停",
|
|
||||||
Args["schedule_id?", int],
|
|
||||||
Option("-all", help_text="对当前群所有任务生效"),
|
|
||||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
|
||||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
|
||||||
Option(
|
|
||||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
|
||||||
),
|
|
||||||
alias=["pause"],
|
|
||||||
help_text="暂停一个或多个定时任务",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"恢复",
|
|
||||||
Args["schedule_id?", int],
|
|
||||||
Option("-all", help_text="对当前群所有任务生效"),
|
|
||||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
|
||||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
|
||||||
Option(
|
|
||||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
|
||||||
),
|
|
||||||
alias=["resume"],
|
|
||||||
help_text="恢复一个或多个定时任务",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"执行",
|
|
||||||
Args["schedule_id", int],
|
|
||||||
alias=["trigger", "run"],
|
|
||||||
help_text="立即执行一次任务",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"更新",
|
|
||||||
Args["schedule_id", int],
|
|
||||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
|
||||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
|
||||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
|
||||||
Option(
|
|
||||||
"--daily",
|
|
||||||
Args["daily_expr", str],
|
|
||||||
help_text="更新每天执行的时间 (如 08:20)",
|
|
||||||
),
|
|
||||||
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
|
||||||
alias=["update", "modify", "修改"],
|
|
||||||
help_text="更新任务配置",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"状态",
|
|
||||||
Args["schedule_id", int],
|
|
||||||
alias=["status", "info"],
|
|
||||||
help_text="查看单个任务的详细状态",
|
|
||||||
),
|
|
||||||
Subcommand(
|
|
||||||
"插件列表",
|
|
||||||
alias=["plugins"],
|
|
||||||
help_text="列出所有可用的插件",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
priority=5,
|
|
||||||
block=True,
|
|
||||||
rule=admin_check(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
schedule_cmd.shortcut(
|
|
||||||
"任务状态",
|
|
||||||
command="定时任务",
|
|
||||||
arguments=["状态", "{%0}"],
|
|
||||||
prefix=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.handle()
|
|
||||||
async def _handle_time_options_mutex(arp: Arparma):
|
|
||||||
time_options = ["cron", "interval", "date", "daily"]
|
|
||||||
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
|
||||||
if len(provided_options) > 1:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("查看")
|
|
||||||
async def _(
|
|
||||||
bot: Bot,
|
|
||||||
event: Event,
|
|
||||||
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
|
||||||
all_groups: Query[bool] = Query("查看.all"),
|
|
||||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
|
||||||
page: Match[int] = AlconnaMatch("page"),
|
|
||||||
):
|
|
||||||
is_superuser = await SUPERUSER(bot, event)
|
|
||||||
schedules = []
|
|
||||||
title = ""
|
|
||||||
|
|
||||||
current_group_id = getattr(event, "group_id", None)
|
|
||||||
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
|
||||||
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
|
||||||
|
|
||||||
if all_groups.available:
|
|
||||||
if not is_superuser:
|
|
||||||
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
|
||||||
schedules = await scheduler_manager.get_all_schedules()
|
|
||||||
title = "所有群组的定时任务"
|
|
||||||
elif target_group_id.available:
|
|
||||||
if not is_superuser:
|
|
||||||
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
|
||||||
gid = target_group_id.result
|
|
||||||
schedules = [
|
|
||||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
|
||||||
]
|
|
||||||
title = f"群 {gid} 的定时任务"
|
|
||||||
else:
|
|
||||||
gid = str(current_group_id)
|
|
||||||
schedules = [
|
|
||||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
|
||||||
]
|
|
||||||
title = "本群的定时任务"
|
|
||||||
|
|
||||||
if plugin_name.available:
|
|
||||||
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
|
|
||||||
title += f" [插件: {plugin_name.result}]"
|
|
||||||
|
|
||||||
if not schedules:
|
|
||||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
|
||||||
|
|
||||||
page_size = 15
|
|
||||||
current_page = page.result
|
|
||||||
total_items = len(schedules)
|
|
||||||
total_pages = (total_items + page_size - 1) // page_size
|
|
||||||
start_index = (current_page - 1) * page_size
|
|
||||||
end_index = start_index + page_size
|
|
||||||
paginated_schedules = schedules[start_index:end_index]
|
|
||||||
|
|
||||||
if not paginated_schedules:
|
|
||||||
await schedule_cmd.finish("这一页没有内容了哦~")
|
|
||||||
|
|
||||||
status_tasks = [
|
|
||||||
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
|
||||||
]
|
|
||||||
all_statuses = await asyncio.gather(*status_tasks)
|
|
||||||
data_list = [
|
|
||||||
[
|
|
||||||
s["id"],
|
|
||||||
s["plugin_name"],
|
|
||||||
s.get("bot_id") or "N/A",
|
|
||||||
s["group_id"] or "全局",
|
|
||||||
s["next_run_time"],
|
|
||||||
_format_trigger(s),
|
|
||||||
_format_params(s),
|
|
||||||
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
|
|
||||||
]
|
|
||||||
for s in all_statuses
|
|
||||||
if s
|
|
||||||
]
|
|
||||||
|
|
||||||
if not data_list:
|
|
||||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
|
||||||
|
|
||||||
img = await ImageTemplate.table_page(
|
|
||||||
head_text=title,
|
|
||||||
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
|
||||||
column_name=[
|
|
||||||
"ID",
|
|
||||||
"插件",
|
|
||||||
"Bot ID",
|
|
||||||
"群组/目标",
|
|
||||||
"下次运行",
|
|
||||||
"触发规则",
|
|
||||||
"参数",
|
|
||||||
"状态",
|
|
||||||
],
|
|
||||||
data_list=data_list,
|
|
||||||
column_space=20,
|
|
||||||
)
|
|
||||||
await MessageUtils.build_message(img).send(reply_to=True)
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("设置")
|
|
||||||
async def _(
|
|
||||||
event: Event,
|
|
||||||
plugin_name: str,
|
|
||||||
cron_expr: str | None = None,
|
|
||||||
interval_expr: str | None = None,
|
|
||||||
date_expr: str | None = None,
|
|
||||||
daily_expr: str | None = None,
|
|
||||||
group_id: str | None = None,
|
|
||||||
kwargs_str: str | None = None,
|
|
||||||
all_enabled: Query[bool] = Query("设置.all"),
|
|
||||||
bot_id_to_operate: str = Depends(GetBotId),
|
|
||||||
):
|
|
||||||
if plugin_name not in scheduler_manager._registered_tasks:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
|
|
||||||
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
trigger_type = ""
|
|
||||||
trigger_config = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if cron_expr:
|
|
||||||
trigger_type = "cron"
|
|
||||||
parts = cron_expr.split()
|
|
||||||
if len(parts) != 5:
|
|
||||||
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
|
|
||||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
|
||||||
trigger_config = dict(zip(cron_keys, parts))
|
|
||||||
elif interval_expr:
|
|
||||||
trigger_type = "interval"
|
|
||||||
trigger_config = _parse_interval(interval_expr)
|
|
||||||
elif date_expr:
|
|
||||||
trigger_type = "date"
|
|
||||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
|
||||||
elif daily_expr:
|
|
||||||
trigger_type = "cron"
|
|
||||||
trigger_config = _parse_daily_time(daily_expr)
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
|
||||||
|
|
||||||
job_kwargs = {}
|
|
||||||
if kwargs_str:
|
|
||||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
|
||||||
params_model = task_meta.get("model")
|
|
||||||
if not params_model:
|
|
||||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
|
|
||||||
|
|
||||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
|
||||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
|
|
||||||
|
|
||||||
raw_kwargs = {}
|
|
||||||
try:
|
|
||||||
for item in kwargs_str.split(","):
|
|
||||||
key, value = item.strip().split("=", 1)
|
|
||||||
raw_kwargs[key.strip()] = value
|
|
||||||
except Exception as e:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_validate = getattr(params_model, "model_validate", None)
|
|
||||||
if not model_validate:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{plugin_name}' 的参数模型不支持验证。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
validated_model = model_validate(raw_kwargs)
|
|
||||||
|
|
||||||
model_dump = getattr(validated_model, "model_dump", None)
|
|
||||||
if not model_dump:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{plugin_name}' 的参数模型不支持导出。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
job_kwargs = model_dump()
|
|
||||||
except ValidationError as e:
|
|
||||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
|
||||||
error_str = "\n".join(errors)
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
target_group_id: str | None
|
|
||||||
current_group_id = getattr(event, "group_id", None)
|
|
||||||
|
|
||||||
if group_id and group_id.lower() == "all":
|
|
||||||
target_group_id = "__ALL_GROUPS__"
|
|
||||||
elif all_enabled.available:
|
|
||||||
target_group_id = "__ALL_GROUPS__"
|
|
||||||
elif group_id:
|
|
||||||
target_group_id = group_id
|
|
||||||
elif current_group_id:
|
|
||||||
target_group_id = str(current_group_id)
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
success, msg = await scheduler_manager.add_schedule(
|
|
||||||
plugin_name,
|
|
||||||
target_group_id,
|
|
||||||
trigger_type,
|
|
||||||
trigger_config,
|
|
||||||
job_kwargs,
|
|
||||||
bot_id=bot_id_to_operate,
|
|
||||||
)
|
|
||||||
|
|
||||||
if target_group_id == "__ALL_GROUPS__":
|
|
||||||
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
|
|
||||||
elif target_group_id is None:
|
|
||||||
target_desc = "全局"
|
|
||||||
else:
|
|
||||||
target_desc = f"群组 {target_group_id}"
|
|
||||||
|
|
||||||
if success:
|
|
||||||
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("删除")
|
|
||||||
async def _(
|
|
||||||
target: TargetScope = Depends(create_target_parser("删除")),
|
|
||||||
bot_id_to_operate: str = Depends(GetBotId),
|
|
||||||
):
|
|
||||||
if isinstance(target, TargetByID):
|
|
||||||
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
elif isinstance(target, TargetByPlugin):
|
|
||||||
p_name = target.plugin
|
|
||||||
if p_name not in scheduler_manager.get_registered_plugins():
|
|
||||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
|
||||||
|
|
||||||
if target.all_groups:
|
|
||||||
removed_count = await scheduler_manager.remove_schedule_for_all(
|
|
||||||
p_name, bot_id=bot_id_to_operate
|
|
||||||
)
|
|
||||||
message = (
|
|
||||||
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
|
|
||||||
if removed_count > 0
|
|
||||||
else f"没有找到插件 '{p_name}' 的定时任务。"
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
else:
|
|
||||||
_, message = await scheduler_manager.remove_schedule(
|
|
||||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
elif isinstance(target, TargetAll):
|
|
||||||
if target.for_group:
|
|
||||||
_, message = await scheduler_manager.remove_schedules_by_group(
|
|
||||||
target.for_group
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
else:
|
|
||||||
_, message = await scheduler_manager.remove_all_schedules()
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
"删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("暂停")
|
|
||||||
async def _(
|
|
||||||
target: TargetScope = Depends(create_target_parser("暂停")),
|
|
||||||
bot_id_to_operate: str = Depends(GetBotId),
|
|
||||||
):
|
|
||||||
if isinstance(target, TargetByID):
|
|
||||||
_, message = await scheduler_manager.pause_schedule(target.id)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
elif isinstance(target, TargetByPlugin):
|
|
||||||
p_name = target.plugin
|
|
||||||
if p_name not in scheduler_manager.get_registered_plugins():
|
|
||||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
|
||||||
|
|
||||||
if target.all_groups:
|
|
||||||
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
else:
|
|
||||||
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
|
|
||||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
elif isinstance(target, TargetAll):
|
|
||||||
if target.for_group:
|
|
||||||
_, message = await scheduler_manager.pause_schedules_by_group(
|
|
||||||
target.for_group
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
else:
|
|
||||||
_, message = await scheduler_manager.pause_all_schedules()
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("恢复")
|
|
||||||
async def _(
|
|
||||||
target: TargetScope = Depends(create_target_parser("恢复")),
|
|
||||||
bot_id_to_operate: str = Depends(GetBotId),
|
|
||||||
):
|
|
||||||
if isinstance(target, TargetByID):
|
|
||||||
_, message = await scheduler_manager.resume_schedule(target.id)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
elif isinstance(target, TargetByPlugin):
|
|
||||||
p_name = target.plugin
|
|
||||||
if p_name not in scheduler_manager.get_registered_plugins():
|
|
||||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
|
||||||
|
|
||||||
if target.all_groups:
|
|
||||||
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
else:
|
|
||||||
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
|
|
||||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
elif isinstance(target, TargetAll):
|
|
||||||
if target.for_group:
|
|
||||||
_, message = await scheduler_manager.resume_schedules_by_group(
|
|
||||||
target.for_group
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
else:
|
|
||||||
_, message = await scheduler_manager.resume_all_schedules()
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("执行")
|
|
||||||
async def _(schedule_id: int):
|
|
||||||
_, message = await scheduler_manager.trigger_now(schedule_id)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("更新")
|
|
||||||
async def _(
|
|
||||||
schedule_id: int,
|
|
||||||
cron_expr: str | None = None,
|
|
||||||
interval_expr: str | None = None,
|
|
||||||
date_expr: str | None = None,
|
|
||||||
daily_expr: str | None = None,
|
|
||||||
kwargs_str: str | None = None,
|
|
||||||
):
|
|
||||||
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
|
||||||
)
|
|
||||||
|
|
||||||
trigger_config = None
|
|
||||||
trigger_type = None
|
|
||||||
try:
|
|
||||||
if cron_expr:
|
|
||||||
trigger_type = "cron"
|
|
||||||
parts = cron_expr.split()
|
|
||||||
if len(parts) != 5:
|
|
||||||
raise ValueError("Cron 表达式必须有5个部分")
|
|
||||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
|
||||||
trigger_config = dict(zip(cron_keys, parts))
|
|
||||||
elif interval_expr:
|
|
||||||
trigger_type = "interval"
|
|
||||||
trigger_config = _parse_interval(interval_expr)
|
|
||||||
elif date_expr:
|
|
||||||
trigger_type = "date"
|
|
||||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
|
||||||
elif daily_expr:
|
|
||||||
trigger_type = "cron"
|
|
||||||
trigger_config = _parse_daily_time(daily_expr)
|
|
||||||
except ValueError as e:
|
|
||||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
|
||||||
|
|
||||||
job_kwargs = None
|
|
||||||
if kwargs_str:
|
|
||||||
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule:
|
|
||||||
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
|
|
||||||
|
|
||||||
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
|
|
||||||
if not task_meta or not (params_model := task_meta.get("model")):
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
|
|
||||||
)
|
|
||||||
|
|
||||||
raw_kwargs = {}
|
|
||||||
try:
|
|
||||||
for item in kwargs_str.split(","):
|
|
||||||
key, value = item.strip().split("=", 1)
|
|
||||||
raw_kwargs[key.strip()] = value
|
|
||||||
except Exception as e:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_validate = getattr(params_model, "model_validate", None)
|
|
||||||
if not model_validate:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
validated_model = model_validate(raw_kwargs)
|
|
||||||
|
|
||||||
model_dump = getattr(validated_model, "model_dump", None)
|
|
||||||
if not model_dump:
|
|
||||||
await schedule_cmd.finish(
|
|
||||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
job_kwargs = model_dump(exclude_unset=True)
|
|
||||||
except ValidationError as e:
|
|
||||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
|
||||||
error_str = "\n".join(errors)
|
|
||||||
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
|
|
||||||
return
|
|
||||||
|
|
||||||
_, message = await scheduler_manager.update_schedule(
|
|
||||||
schedule_id, trigger_type, trigger_config, job_kwargs
|
|
||||||
)
|
|
||||||
await schedule_cmd.finish(message)
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("插件列表")
|
|
||||||
async def _():
|
|
||||||
registered_plugins = scheduler_manager.get_registered_plugins()
|
|
||||||
if not registered_plugins:
|
|
||||||
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
|
|
||||||
|
|
||||||
message_parts = ["📋 已注册的定时任务插件:"]
|
|
||||||
for i, plugin_name in enumerate(registered_plugins, 1):
|
|
||||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
|
||||||
params_model = task_meta.get("model")
|
|
||||||
|
|
||||||
if not params_model:
|
|
||||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
|
||||||
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_fields = getattr(params_model, "model_fields", None)
|
|
||||||
if model_fields:
|
|
||||||
param_info = ", ".join(
|
|
||||||
f"{field_name}({_get_type_name(field_info.annotation)})"
|
|
||||||
for field_name, field_info in model_fields.items()
|
|
||||||
)
|
|
||||||
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
|
|
||||||
else:
|
|
||||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
|
||||||
|
|
||||||
await schedule_cmd.finish("\n".join(message_parts))
|
|
||||||
|
|
||||||
|
|
||||||
@schedule_cmd.assign("状态")
|
|
||||||
async def _(schedule_id: int):
|
|
||||||
status = await scheduler_manager.get_schedule_status(schedule_id)
|
|
||||||
if not status:
|
|
||||||
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
|
|
||||||
|
|
||||||
info_lines = [
|
|
||||||
f"📋 定时任务详细信息 (ID: {schedule_id})",
|
|
||||||
"--------------------",
|
|
||||||
f"▫️ 插件: {status['plugin_name']}",
|
|
||||||
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
|
||||||
f"▫️ 目标: {status['group_id'] or '全局'}",
|
|
||||||
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
|
||||||
f"▫️ 下次运行: {status['next_run_time']}",
|
|
||||||
f"▫️ 触发规则: {_format_trigger(status)}",
|
|
||||||
f"▫️ 任务参数: {_format_params(status)}",
|
|
||||||
]
|
|
||||||
await schedule_cmd.finish("\n".join(info_lines))
|
|
||||||
298
zhenxun/builtin_plugins/scheduler_admin/commands.py
Normal file
298
zhenxun/builtin_plugins/scheduler_admin/commands.py
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.permission import SUPERUSER
|
||||||
|
from nonebot_plugin_alconna import (
|
||||||
|
Alconna,
|
||||||
|
AlconnaMatch,
|
||||||
|
Args,
|
||||||
|
Match,
|
||||||
|
Option,
|
||||||
|
Query,
|
||||||
|
Subcommand,
|
||||||
|
on_alconna,
|
||||||
|
)
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.services.scheduler import scheduler_manager
|
||||||
|
from zhenxun.services.scheduler.targeter import ScheduleTargeter
|
||||||
|
from zhenxun.utils.rules import admin_check
|
||||||
|
|
||||||
|
schedule_cmd = on_alconna(
|
||||||
|
Alconna(
|
||||||
|
"定时任务",
|
||||||
|
Subcommand(
|
||||||
|
"查看",
|
||||||
|
Option("-g", Args["target_group_id", str]),
|
||||||
|
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
||||||
|
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
||||||
|
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
||||||
|
alias=["ls", "list"],
|
||||||
|
help_text="查看定时任务",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"设置",
|
||||||
|
Args["plugin_name", str],
|
||||||
|
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||||
|
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||||
|
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||||
|
Option(
|
||||||
|
"--daily",
|
||||||
|
Args["daily_expr", str],
|
||||||
|
help_text="设置每天执行的时间 (如 08:20)",
|
||||||
|
),
|
||||||
|
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
||||||
|
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
||||||
|
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
||||||
|
Option(
|
||||||
|
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||||
|
),
|
||||||
|
alias=["add", "开启"],
|
||||||
|
help_text="设置/开启一个定时任务",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"删除",
|
||||||
|
Args["schedule_id?", int],
|
||||||
|
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||||
|
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
||||||
|
Option("-all", help_text="对所有群生效"),
|
||||||
|
Option(
|
||||||
|
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||||
|
),
|
||||||
|
alias=["del", "rm", "remove", "关闭", "取消"],
|
||||||
|
help_text="删除一个或多个定时任务",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"暂停",
|
||||||
|
Args["schedule_id?", int],
|
||||||
|
Option("-all", help_text="对当前群所有任务生效"),
|
||||||
|
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||||
|
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||||
|
Option(
|
||||||
|
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||||
|
),
|
||||||
|
alias=["pause"],
|
||||||
|
help_text="暂停一个或多个定时任务",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"恢复",
|
||||||
|
Args["schedule_id?", int],
|
||||||
|
Option("-all", help_text="对当前群所有任务生效"),
|
||||||
|
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||||
|
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||||
|
Option(
|
||||||
|
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||||
|
),
|
||||||
|
alias=["resume"],
|
||||||
|
help_text="恢复一个或多个定时任务",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"执行",
|
||||||
|
Args["schedule_id", int],
|
||||||
|
alias=["trigger", "run"],
|
||||||
|
help_text="立即执行一次任务",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"更新",
|
||||||
|
Args["schedule_id", int],
|
||||||
|
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||||
|
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||||
|
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||||
|
Option(
|
||||||
|
"--daily",
|
||||||
|
Args["daily_expr", str],
|
||||||
|
help_text="更新每天执行的时间 (如 08:20)",
|
||||||
|
),
|
||||||
|
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
||||||
|
alias=["update", "modify", "修改"],
|
||||||
|
help_text="更新任务配置",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"状态",
|
||||||
|
Args["schedule_id", int],
|
||||||
|
alias=["status", "info"],
|
||||||
|
help_text="查看单个任务的详细状态",
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"插件列表",
|
||||||
|
alias=["plugins"],
|
||||||
|
help_text="列出所有可用的插件",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
priority=5,
|
||||||
|
block=True,
|
||||||
|
rule=admin_check(1),
|
||||||
|
)
|
||||||
|
|
||||||
|
schedule_cmd.shortcut(
|
||||||
|
"任务状态",
|
||||||
|
command="定时任务",
|
||||||
|
arguments=["状态", "{%0}"],
|
||||||
|
prefix=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleTarget:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TargetByID(ScheduleTarget):
|
||||||
|
def __init__(self, id: int):
|
||||||
|
self.id = id
|
||||||
|
|
||||||
|
|
||||||
|
class TargetByPlugin(ScheduleTarget):
|
||||||
|
def __init__(
|
||||||
|
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
||||||
|
):
|
||||||
|
self.plugin = plugin
|
||||||
|
self.group_id = group_id
|
||||||
|
self.all_groups = all_groups
|
||||||
|
|
||||||
|
|
||||||
|
class TargetAll(ScheduleTarget):
|
||||||
|
def __init__(self, for_group: str | None = None):
|
||||||
|
self.for_group = for_group
|
||||||
|
|
||||||
|
|
||||||
|
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
||||||
|
|
||||||
|
|
||||||
|
def create_target_parser(subcommand_name: str):
|
||||||
|
async def dependency(
|
||||||
|
event: Event,
|
||||||
|
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||||
|
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||||
|
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||||
|
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
||||||
|
) -> TargetScope:
|
||||||
|
if schedule_id.available:
|
||||||
|
return TargetByID(schedule_id.result)
|
||||||
|
|
||||||
|
if plugin_name.available:
|
||||||
|
p_name = plugin_name.result
|
||||||
|
if all_enabled.available:
|
||||||
|
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||||
|
elif group_id.available:
|
||||||
|
gid = group_id.result
|
||||||
|
if gid.lower() == "all":
|
||||||
|
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||||
|
return TargetByPlugin(plugin=p_name, group_id=gid)
|
||||||
|
else:
|
||||||
|
current_group_id = getattr(event, "group_id", None)
|
||||||
|
return TargetByPlugin(
|
||||||
|
plugin=p_name,
|
||||||
|
group_id=str(current_group_id) if current_group_id else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if all_enabled.available:
|
||||||
|
current_group_id = getattr(event, "group_id", None)
|
||||||
|
if not current_group_id:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
"私聊中单独使用 -all 选项时,必须使用 -g <群号> 指定目标。"
|
||||||
|
)
|
||||||
|
return TargetAll(for_group=str(current_group_id))
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
|
||||||
|
def parse_interval(interval_str: str) -> dict:
|
||||||
|
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
||||||
|
if not match:
|
||||||
|
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
||||||
|
value, unit = int(match.group(1)), match.group(2)
|
||||||
|
if unit == "s":
|
||||||
|
return {"seconds": value}
|
||||||
|
if unit == "m":
|
||||||
|
return {"minutes": value}
|
||||||
|
if unit == "h":
|
||||||
|
return {"hours": value}
|
||||||
|
if unit == "d":
|
||||||
|
return {"days": value}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_daily_time(time_str: str) -> dict:
|
||||||
|
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
||||||
|
hour, minute, second = match.groups()
|
||||||
|
hour, minute = int(hour), int(minute)
|
||||||
|
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||||
|
raise ValueError("小时或分钟数值超出范围。")
|
||||||
|
cron_config = {
|
||||||
|
"minute": str(minute),
|
||||||
|
"hour": str(hour),
|
||||||
|
"day": "*",
|
||||||
|
"month": "*",
|
||||||
|
"day_of_week": "*",
|
||||||
|
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
|
||||||
|
}
|
||||||
|
if second is not None:
|
||||||
|
if not (0 <= int(second) <= 59):
|
||||||
|
raise ValueError("秒数值超出范围。")
|
||||||
|
cron_config["second"] = str(second)
|
||||||
|
return cron_config
|
||||||
|
else:
|
||||||
|
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
|
||||||
|
|
||||||
|
|
||||||
|
async def GetBotId(bot: Bot, bot_id_match: Match[str] = AlconnaMatch("bot_id")) -> str:
|
||||||
|
if bot_id_match.available:
|
||||||
|
return bot_id_match.result
|
||||||
|
return bot.self_id
|
||||||
|
|
||||||
|
|
||||||
|
def GetTargeter(subcommand: str):
|
||||||
|
"""
|
||||||
|
依赖注入函数,用于解析命令参数并返回一个配置好的 ScheduleTargeter 实例。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
event: Event,
|
||||||
|
bot: Bot,
|
||||||
|
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||||
|
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||||
|
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||||
|
all_enabled: Query[bool] = Query(f"{subcommand}.all"),
|
||||||
|
bot_id_to_operate: str = Depends(GetBotId),
|
||||||
|
) -> ScheduleTargeter:
|
||||||
|
if schedule_id.available:
|
||||||
|
return scheduler_manager.target(id=schedule_id.result)
|
||||||
|
|
||||||
|
if plugin_name.available:
|
||||||
|
if all_enabled.available:
|
||||||
|
return scheduler_manager.target(plugin_name=plugin_name.result)
|
||||||
|
|
||||||
|
current_group_id = getattr(event, "group_id", None)
|
||||||
|
gid = group_id.result if group_id.available else current_group_id
|
||||||
|
return scheduler_manager.target(
|
||||||
|
plugin_name=plugin_name.result,
|
||||||
|
group_id=str(gid) if gid else None,
|
||||||
|
bot_id=bot_id_to_operate,
|
||||||
|
)
|
||||||
|
|
||||||
|
if all_enabled.available:
|
||||||
|
current_group_id = getattr(event, "group_id", None)
|
||||||
|
gid = group_id.result if group_id.available else current_group_id
|
||||||
|
is_su = await SUPERUSER(bot, event)
|
||||||
|
if not gid and not is_su:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"在私聊中对所有任务进行'{subcommand}'操作需要超级用户权限。"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (gid and str(gid).lower() == "all") or (not gid and is_su):
|
||||||
|
return scheduler_manager.target()
|
||||||
|
|
||||||
|
return scheduler_manager.target(
|
||||||
|
group_id=str(gid) if gid else None, bot_id=bot_id_to_operate
|
||||||
|
)
|
||||||
|
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"'{subcommand}'操作失败:请提供任务ID,"
|
||||||
|
f"或通过 -p <插件名> 或 -all 指定要操作的任务。"
|
||||||
|
)
|
||||||
|
|
||||||
|
return Depends(dependency)
|
||||||
380
zhenxun/builtin_plugins/scheduler_admin/handlers.py
Normal file
380
zhenxun/builtin_plugins/scheduler_admin/handlers.py
Normal file
@ -0,0 +1,380 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.permission import SUPERUSER
|
||||||
|
from nonebot_plugin_alconna import AlconnaMatch, Arparma, Match, Query
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from zhenxun.models.schedule_info import ScheduleInfo
|
||||||
|
from zhenxun.services.scheduler import scheduler_manager
|
||||||
|
from zhenxun.services.scheduler.targeter import ScheduleTargeter
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
|
||||||
|
from . import presenters
|
||||||
|
from .commands import (
|
||||||
|
GetBotId,
|
||||||
|
GetTargeter,
|
||||||
|
parse_daily_time,
|
||||||
|
parse_interval,
|
||||||
|
schedule_cmd,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.handle()
|
||||||
|
async def _handle_time_options_mutex(arp: Arparma):
|
||||||
|
time_options = ["cron", "interval", "date", "daily"]
|
||||||
|
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
||||||
|
if len(provided_options) > 1:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("查看")
|
||||||
|
async def handle_view(
|
||||||
|
bot: Bot,
|
||||||
|
event: Event,
|
||||||
|
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
||||||
|
all_groups: Query[bool] = Query("查看.all"),
|
||||||
|
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||||
|
page: Match[int] = AlconnaMatch("page"),
|
||||||
|
):
|
||||||
|
is_superuser = await SUPERUSER(bot, event)
|
||||||
|
title = ""
|
||||||
|
gid_filter = None
|
||||||
|
|
||||||
|
current_group_id = getattr(event, "group_id", None)
|
||||||
|
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
||||||
|
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
||||||
|
|
||||||
|
if all_groups.available:
|
||||||
|
if not is_superuser:
|
||||||
|
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
||||||
|
title = "所有群组的定时任务"
|
||||||
|
elif target_group_id.available:
|
||||||
|
if not is_superuser:
|
||||||
|
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
||||||
|
gid_filter = target_group_id.result
|
||||||
|
title = f"群 {gid_filter} 的定时任务"
|
||||||
|
else:
|
||||||
|
gid_filter = str(current_group_id)
|
||||||
|
title = "本群的定时任务"
|
||||||
|
|
||||||
|
p_name_filter = plugin_name.result if plugin_name.available else None
|
||||||
|
|
||||||
|
schedules = await scheduler_manager.get_schedules(
|
||||||
|
plugin_name=p_name_filter, group_id=gid_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
if p_name_filter:
|
||||||
|
title += f" [插件: {p_name_filter}]"
|
||||||
|
|
||||||
|
if not schedules:
|
||||||
|
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||||
|
|
||||||
|
img = await presenters.format_schedule_list_as_image(
|
||||||
|
schedules=schedules, title=title, current_page=page.result
|
||||||
|
)
|
||||||
|
await MessageUtils.build_message(img).send(reply_to=True)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("设置")
|
||||||
|
async def handle_set(
|
||||||
|
event: Event,
|
||||||
|
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||||
|
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
|
||||||
|
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
|
||||||
|
date_expr: Match[str] = AlconnaMatch("date_expr"),
|
||||||
|
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
|
||||||
|
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||||
|
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
|
||||||
|
all_enabled: Query[bool] = Query("设置.all"),
|
||||||
|
bot_id_to_operate: str = Depends(GetBotId),
|
||||||
|
):
|
||||||
|
if not plugin_name.available:
|
||||||
|
await schedule_cmd.finish("设置任务时必须提供插件名称。")
|
||||||
|
|
||||||
|
has_time_option = any(
|
||||||
|
[
|
||||||
|
cron_expr.available,
|
||||||
|
interval_expr.available,
|
||||||
|
date_expr.available,
|
||||||
|
daily_expr.available,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if not has_time_option:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||||
|
)
|
||||||
|
|
||||||
|
p_name = plugin_name.result
|
||||||
|
if p_name not in scheduler_manager.get_registered_plugins():
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"插件 '{p_name}' 没有注册可用的定时任务。\n"
|
||||||
|
f"可用插件: {list(scheduler_manager.get_registered_plugins())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger_type, trigger_config = "", {}
|
||||||
|
try:
|
||||||
|
if cron_expr.available:
|
||||||
|
trigger_type, trigger_config = (
|
||||||
|
"cron",
|
||||||
|
dict(
|
||||||
|
zip(
|
||||||
|
["minute", "hour", "day", "month", "day_of_week"],
|
||||||
|
cron_expr.result.split(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif interval_expr.available:
|
||||||
|
trigger_type, trigger_config = (
|
||||||
|
"interval",
|
||||||
|
parse_interval(interval_expr.result),
|
||||||
|
)
|
||||||
|
elif date_expr.available:
|
||||||
|
trigger_type, trigger_config = (
|
||||||
|
"date",
|
||||||
|
{"run_date": datetime.fromisoformat(date_expr.result)},
|
||||||
|
)
|
||||||
|
elif daily_expr.available:
|
||||||
|
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
|
||||||
|
else:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||||
|
|
||||||
|
job_kwargs = {}
|
||||||
|
if kwargs_str.available:
|
||||||
|
task_meta = scheduler_manager._registered_tasks[p_name]
|
||||||
|
params_model = task_meta.get("model")
|
||||||
|
if not (
|
||||||
|
params_model
|
||||||
|
and isinstance(params_model, type)
|
||||||
|
and issubclass(params_model, BaseModel)
|
||||||
|
):
|
||||||
|
await schedule_cmd.finish(f"插件 '{p_name}' 不支持或配置了无效的参数模型。")
|
||||||
|
try:
|
||||||
|
raw_kwargs = dict(
|
||||||
|
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
|
||||||
|
)
|
||||||
|
|
||||||
|
model_validate = getattr(params_model, "model_validate", None)
|
||||||
|
if not model_validate:
|
||||||
|
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持验证")
|
||||||
|
|
||||||
|
validated_model = model_validate(raw_kwargs)
|
||||||
|
|
||||||
|
model_dump = getattr(validated_model, "model_dump", None)
|
||||||
|
if not model_dump:
|
||||||
|
await schedule_cmd.finish(f"插件 '{p_name}' 的参数模型不支持导出")
|
||||||
|
|
||||||
|
job_kwargs = model_dump()
|
||||||
|
except ValidationError as e:
|
||||||
|
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"插件 '{p_name}' 的任务参数验证失败:\n" + "\n".join(errors)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
gid_str = group_id.result if group_id.available else None
|
||||||
|
target_group_id = (
|
||||||
|
scheduler_manager.ALL_GROUPS
|
||||||
|
if (gid_str and gid_str.lower() == "all") or all_enabled.available
|
||||||
|
else gid_str or getattr(event, "group_id", None)
|
||||||
|
)
|
||||||
|
if not target_group_id:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
||||||
|
)
|
||||||
|
|
||||||
|
schedule = await scheduler_manager.add_schedule(
|
||||||
|
p_name,
|
||||||
|
str(target_group_id),
|
||||||
|
trigger_type,
|
||||||
|
trigger_config,
|
||||||
|
job_kwargs,
|
||||||
|
bot_id=bot_id_to_operate,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_desc = (
|
||||||
|
f"所有群组 (Bot: {bot_id_to_operate})"
|
||||||
|
if target_group_id == scheduler_manager.ALL_GROUPS
|
||||||
|
else f"群组 {target_group_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if schedule:
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
f"为 [{target_desc}] 已成功设置插件 '{p_name}' 的定时任务 "
|
||||||
|
f"(ID: {schedule.id})。"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败。")
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("删除")
|
||||||
|
async def handle_delete(targeter: ScheduleTargeter = GetTargeter("删除")):
|
||||||
|
schedules_to_remove: list[ScheduleInfo] = await targeter._get_schedules()
|
||||||
|
if not schedules_to_remove:
|
||||||
|
await schedule_cmd.finish("没有找到可删除的任务。")
|
||||||
|
|
||||||
|
count, _ = await targeter.remove()
|
||||||
|
|
||||||
|
if count > 0 and schedules_to_remove:
|
||||||
|
if len(schedules_to_remove) == 1:
|
||||||
|
message = presenters.format_remove_success(schedules_to_remove[0])
|
||||||
|
else:
|
||||||
|
target_desc = targeter._generate_target_description()
|
||||||
|
message = f"✅ 成功移除了{target_desc} {count} 个任务。"
|
||||||
|
else:
|
||||||
|
message = "没有任务被移除。"
|
||||||
|
await schedule_cmd.finish(message)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("暂停")
|
||||||
|
async def handle_pause(targeter: ScheduleTargeter = GetTargeter("暂停")):
|
||||||
|
schedules_to_pause: list[ScheduleInfo] = await targeter._get_schedules()
|
||||||
|
if not schedules_to_pause:
|
||||||
|
await schedule_cmd.finish("没有找到可暂停的任务。")
|
||||||
|
|
||||||
|
count, _ = await targeter.pause()
|
||||||
|
|
||||||
|
if count > 0 and schedules_to_pause:
|
||||||
|
if len(schedules_to_pause) == 1:
|
||||||
|
message = presenters.format_pause_success(schedules_to_pause[0])
|
||||||
|
else:
|
||||||
|
target_desc = targeter._generate_target_description()
|
||||||
|
message = f"✅ 成功暂停了{target_desc} {count} 个任务。"
|
||||||
|
else:
|
||||||
|
message = "没有任务被暂停。"
|
||||||
|
await schedule_cmd.finish(message)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("恢复")
|
||||||
|
async def handle_resume(targeter: ScheduleTargeter = GetTargeter("恢复")):
|
||||||
|
schedules_to_resume: list[ScheduleInfo] = await targeter._get_schedules()
|
||||||
|
if not schedules_to_resume:
|
||||||
|
await schedule_cmd.finish("没有找到可恢复的任务。")
|
||||||
|
|
||||||
|
count, _ = await targeter.resume()
|
||||||
|
|
||||||
|
if count > 0 and schedules_to_resume:
|
||||||
|
if len(schedules_to_resume) == 1:
|
||||||
|
message = presenters.format_resume_success(schedules_to_resume[0])
|
||||||
|
else:
|
||||||
|
target_desc = targeter._generate_target_description()
|
||||||
|
message = f"✅ 成功恢复了{target_desc} {count} 个任务。"
|
||||||
|
else:
|
||||||
|
message = "没有任务被恢复。"
|
||||||
|
await schedule_cmd.finish(message)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("执行")
|
||||||
|
async def handle_trigger(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
|
||||||
|
from zhenxun.services.scheduler.repository import ScheduleRepository
|
||||||
|
|
||||||
|
schedule_info = await ScheduleRepository.get_by_id(schedule_id.result)
|
||||||
|
if not schedule_info:
|
||||||
|
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id.result} 的任务。")
|
||||||
|
|
||||||
|
success, message = await scheduler_manager.trigger_now(schedule_id.result)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
final_message = presenters.format_trigger_success(schedule_info)
|
||||||
|
else:
|
||||||
|
final_message = f"❌ 手动触发失败: {message}"
|
||||||
|
await schedule_cmd.finish(final_message)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("更新")
|
||||||
|
async def handle_update(
|
||||||
|
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||||
|
cron_expr: Match[str] = AlconnaMatch("cron_expr"),
|
||||||
|
interval_expr: Match[str] = AlconnaMatch("interval_expr"),
|
||||||
|
date_expr: Match[str] = AlconnaMatch("date_expr"),
|
||||||
|
daily_expr: Match[str] = AlconnaMatch("daily_expr"),
|
||||||
|
kwargs_str: Match[str] = AlconnaMatch("kwargs_str"),
|
||||||
|
):
|
||||||
|
if not any(
|
||||||
|
[
|
||||||
|
cron_expr.available,
|
||||||
|
interval_expr.available,
|
||||||
|
date_expr.available,
|
||||||
|
daily_expr.available,
|
||||||
|
kwargs_str.available,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
await schedule_cmd.finish(
|
||||||
|
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger_type, trigger_config, job_kwargs = None, None, None
|
||||||
|
try:
|
||||||
|
if cron_expr.available:
|
||||||
|
trigger_type, trigger_config = (
|
||||||
|
"cron",
|
||||||
|
dict(
|
||||||
|
zip(
|
||||||
|
["minute", "hour", "day", "month", "day_of_week"],
|
||||||
|
cron_expr.result.split(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif interval_expr.available:
|
||||||
|
trigger_type, trigger_config = (
|
||||||
|
"interval",
|
||||||
|
parse_interval(interval_expr.result),
|
||||||
|
)
|
||||||
|
elif date_expr.available:
|
||||||
|
trigger_type, trigger_config = (
|
||||||
|
"date",
|
||||||
|
{"run_date": datetime.fromisoformat(date_expr.result)},
|
||||||
|
)
|
||||||
|
elif daily_expr.available:
|
||||||
|
trigger_type, trigger_config = "cron", parse_daily_time(daily_expr.result)
|
||||||
|
except ValueError as e:
|
||||||
|
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||||
|
|
||||||
|
if kwargs_str.available:
|
||||||
|
job_kwargs = dict(
|
||||||
|
item.strip().split("=", 1) for item in kwargs_str.result.split(",")
|
||||||
|
)
|
||||||
|
|
||||||
|
success, message = await scheduler_manager.update_schedule(
|
||||||
|
schedule_id.result, trigger_type, trigger_config, job_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
from zhenxun.services.scheduler.repository import ScheduleRepository
|
||||||
|
|
||||||
|
updated_schedule = await ScheduleRepository.get_by_id(schedule_id.result)
|
||||||
|
if updated_schedule:
|
||||||
|
final_message = presenters.format_update_success(updated_schedule)
|
||||||
|
else:
|
||||||
|
final_message = "✅ 更新成功,但无法获取更新后的任务详情。"
|
||||||
|
else:
|
||||||
|
final_message = f"❌ 更新失败: {message}"
|
||||||
|
|
||||||
|
await schedule_cmd.finish(final_message)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("插件列表")
|
||||||
|
async def handle_plugins_list():
|
||||||
|
message = await presenters.format_plugins_list()
|
||||||
|
await schedule_cmd.finish(message)
|
||||||
|
|
||||||
|
|
||||||
|
@schedule_cmd.assign("状态")
|
||||||
|
async def handle_status(schedule_id: Match[int] = AlconnaMatch("schedule_id")):
|
||||||
|
status = await scheduler_manager.get_schedule_status(schedule_id.result)
|
||||||
|
if not status:
|
||||||
|
await schedule_cmd.finish(f"未找到ID为 {schedule_id.result} 的定时任务。")
|
||||||
|
|
||||||
|
message = presenters.format_single_status_message(status)
|
||||||
|
await schedule_cmd.finish(message)
|
||||||
274
zhenxun/builtin_plugins/scheduler_admin/presenters.py
Normal file
274
zhenxun/builtin_plugins/scheduler_admin/presenters.py
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from zhenxun.models.schedule_info import ScheduleInfo
|
||||||
|
from zhenxun.services.scheduler import scheduler_manager
|
||||||
|
from zhenxun.utils._image_template import ImageTemplate, RowStyle
|
||||||
|
|
||||||
|
|
||||||
|
def _get_type_name(annotation) -> str:
|
||||||
|
"""获取类型注解的名称"""
|
||||||
|
if hasattr(annotation, "__name__"):
|
||||||
|
return annotation.__name__
|
||||||
|
elif hasattr(annotation, "_name"):
|
||||||
|
return annotation._name
|
||||||
|
else:
|
||||||
|
return str(annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_trigger(schedule: dict) -> str:
|
||||||
|
"""格式化触发器信息为可读字符串"""
|
||||||
|
trigger_type = schedule.get("trigger_type")
|
||||||
|
config = schedule.get("trigger_config")
|
||||||
|
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return f"配置错误: {config}"
|
||||||
|
|
||||||
|
if trigger_type == "cron":
|
||||||
|
hour = config.get("hour", "??")
|
||||||
|
minute = config.get("minute", "??")
|
||||||
|
try:
|
||||||
|
hour_int = int(hour)
|
||||||
|
minute_int = int(minute)
|
||||||
|
return f"每天 {hour_int:02d}:{minute_int:02d}"
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return f"每天 {hour}:{minute}"
|
||||||
|
elif trigger_type == "interval":
|
||||||
|
units = {
|
||||||
|
"weeks": "周",
|
||||||
|
"days": "天",
|
||||||
|
"hours": "小时",
|
||||||
|
"minutes": "分钟",
|
||||||
|
"seconds": "秒",
|
||||||
|
}
|
||||||
|
for unit, unit_name in units.items():
|
||||||
|
if value := config.get(unit):
|
||||||
|
return f"每 {value} {unit_name}"
|
||||||
|
return "未知间隔"
|
||||||
|
elif trigger_type == "date":
|
||||||
|
run_date = config.get("run_date", "N/A")
|
||||||
|
return f"特定时间 {run_date}"
|
||||||
|
else:
|
||||||
|
return f"未知触发器类型: {trigger_type}"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_trigger_for_card(schedule_info: ScheduleInfo | dict) -> str:
|
||||||
|
"""为信息卡片格式化触发器规则"""
|
||||||
|
trigger_type = (
|
||||||
|
schedule_info.get("trigger_type")
|
||||||
|
if isinstance(schedule_info, dict)
|
||||||
|
else schedule_info.trigger_type
|
||||||
|
)
|
||||||
|
config = (
|
||||||
|
schedule_info.get("trigger_config")
|
||||||
|
if isinstance(schedule_info, dict)
|
||||||
|
else schedule_info.trigger_config
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return f"配置错误: {config}"
|
||||||
|
|
||||||
|
if trigger_type == "cron":
|
||||||
|
hour = config.get("hour", "??")
|
||||||
|
minute = config.get("minute", "??")
|
||||||
|
try:
|
||||||
|
hour_int = int(hour)
|
||||||
|
minute_int = int(minute)
|
||||||
|
return f"每天 {hour_int:02d}:{minute_int:02d}"
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return f"每天 {hour}:{minute}"
|
||||||
|
elif trigger_type == "interval":
|
||||||
|
units = {
|
||||||
|
"weeks": "周",
|
||||||
|
"days": "天",
|
||||||
|
"hours": "小时",
|
||||||
|
"minutes": "分钟",
|
||||||
|
"seconds": "秒",
|
||||||
|
}
|
||||||
|
for unit, unit_name in units.items():
|
||||||
|
if value := config.get(unit):
|
||||||
|
return f"每 {value} {unit_name}"
|
||||||
|
return "未知间隔"
|
||||||
|
elif trigger_type == "date":
|
||||||
|
run_date = config.get("run_date", "N/A")
|
||||||
|
return f"特定时间 {run_date}"
|
||||||
|
else:
|
||||||
|
return f"未知规则: {trigger_type}"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_operation_result_card(
|
||||||
|
title: str, schedule_info: ScheduleInfo, extra_info: list[str] | None = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
生成一个标准的操作结果信息卡片。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
title: 卡片的标题 (例如 "✅ 成功暂停定时任务!")
|
||||||
|
schedule_info: 相关的 ScheduleInfo 对象
|
||||||
|
extra_info: (可选) 额外的补充信息行
|
||||||
|
"""
|
||||||
|
target_desc = (
|
||||||
|
f"群组 {schedule_info.group_id}"
|
||||||
|
if schedule_info.group_id
|
||||||
|
and schedule_info.group_id != scheduler_manager.ALL_GROUPS
|
||||||
|
else "所有群组"
|
||||||
|
if schedule_info.group_id == scheduler_manager.ALL_GROUPS
|
||||||
|
else "全局"
|
||||||
|
)
|
||||||
|
|
||||||
|
info_lines = [
|
||||||
|
title,
|
||||||
|
f"✓ 任务 ID: {schedule_info.id}",
|
||||||
|
f"🖋 插件: {schedule_info.plugin_name}",
|
||||||
|
f"🎯 目标: {target_desc}",
|
||||||
|
f"⏰ 时间: {_format_trigger_for_card(schedule_info)}",
|
||||||
|
]
|
||||||
|
if extra_info:
|
||||||
|
info_lines.extend(extra_info)
|
||||||
|
|
||||||
|
return "\n".join(info_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def format_pause_success(schedule_info: ScheduleInfo) -> str:
|
||||||
|
"""格式化暂停成功的消息"""
|
||||||
|
return _format_operation_result_card("✅ 成功暂停定时任务!", schedule_info)
|
||||||
|
|
||||||
|
|
||||||
|
def format_resume_success(schedule_info: ScheduleInfo) -> str:
|
||||||
|
"""格式化恢复成功的消息"""
|
||||||
|
return _format_operation_result_card("▶️ 成功恢复定时任务!", schedule_info)
|
||||||
|
|
||||||
|
|
||||||
|
def format_remove_success(schedule_info: ScheduleInfo) -> str:
|
||||||
|
"""格式化删除成功的消息"""
|
||||||
|
return _format_operation_result_card("❌ 成功删除定时任务!", schedule_info)
|
||||||
|
|
||||||
|
|
||||||
|
def format_trigger_success(schedule_info: ScheduleInfo) -> str:
|
||||||
|
"""格式化手动触发成功的消息"""
|
||||||
|
return _format_operation_result_card("🚀 成功手动触发定时任务!", schedule_info)
|
||||||
|
|
||||||
|
|
||||||
|
def format_update_success(schedule_info: ScheduleInfo) -> str:
|
||||||
|
"""格式化更新成功的消息"""
|
||||||
|
return _format_operation_result_card("🔄️ 成功更新定时任务配置!", schedule_info)
|
||||||
|
|
||||||
|
|
||||||
|
def _status_row_style(column: str, text: str) -> RowStyle:
|
||||||
|
"""为状态列设置颜色"""
|
||||||
|
style = RowStyle()
|
||||||
|
if column == "状态":
|
||||||
|
if text == "启用":
|
||||||
|
style.font_color = "#67C23A"
|
||||||
|
elif text == "暂停":
|
||||||
|
style.font_color = "#F56C6C"
|
||||||
|
elif text == "运行中":
|
||||||
|
style.font_color = "#409EFF"
|
||||||
|
return style
|
||||||
|
|
||||||
|
|
||||||
|
def _format_params(schedule_status: dict) -> str:
|
||||||
|
"""将任务参数格式化为人类可读的字符串"""
|
||||||
|
if kwargs := schedule_status.get("job_kwargs"):
|
||||||
|
return " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
||||||
|
return "-"
|
||||||
|
|
||||||
|
|
||||||
|
async def format_schedule_list_as_image(
|
||||||
|
schedules: list[ScheduleInfo], title: str, current_page: int
|
||||||
|
):
|
||||||
|
"""将任务列表格式化为图片"""
|
||||||
|
page_size = 15
|
||||||
|
total_items = len(schedules)
|
||||||
|
total_pages = (total_items + page_size - 1) // page_size
|
||||||
|
start_index = (current_page - 1) * page_size
|
||||||
|
end_index = start_index + page_size
|
||||||
|
paginated_schedules = schedules[start_index:end_index]
|
||||||
|
|
||||||
|
if not paginated_schedules:
|
||||||
|
return "这一页没有内容了哦~"
|
||||||
|
|
||||||
|
status_tasks = [
|
||||||
|
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
||||||
|
]
|
||||||
|
all_statuses = await asyncio.gather(*status_tasks)
|
||||||
|
|
||||||
|
def get_status_text(status_value):
|
||||||
|
if isinstance(status_value, bool):
|
||||||
|
return "启用" if status_value else "暂停"
|
||||||
|
return str(status_value)
|
||||||
|
|
||||||
|
data_list = [
|
||||||
|
[
|
||||||
|
s["id"],
|
||||||
|
s["plugin_name"],
|
||||||
|
s.get("bot_id") or "N/A",
|
||||||
|
s["group_id"] or "全局",
|
||||||
|
s["next_run_time"],
|
||||||
|
_format_trigger(s),
|
||||||
|
_format_params(s),
|
||||||
|
get_status_text(s["is_enabled"]),
|
||||||
|
]
|
||||||
|
for s in all_statuses
|
||||||
|
if s
|
||||||
|
]
|
||||||
|
|
||||||
|
if not data_list:
|
||||||
|
return "没有找到任何相关的定时任务。"
|
||||||
|
|
||||||
|
return await ImageTemplate.table_page(
|
||||||
|
head_text=title,
|
||||||
|
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
||||||
|
column_name=["ID", "插件", "Bot", "目标", "下次运行", "规则", "参数", "状态"],
|
||||||
|
data_list=data_list,
|
||||||
|
column_space=20,
|
||||||
|
text_style=_status_row_style,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_single_status_message(status: dict) -> str:
|
||||||
|
"""格式化单个任务状态为文本消息"""
|
||||||
|
info_lines = [
|
||||||
|
f"📋 定时任务详细信息 (ID: {status['id']})",
|
||||||
|
"--------------------",
|
||||||
|
f"▫️ 插件: {status['plugin_name']}",
|
||||||
|
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
||||||
|
f"▫️ 目标: {status['group_id'] or '全局'}",
|
||||||
|
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
||||||
|
f"▫️ 下次运行: {status['next_run_time']}",
|
||||||
|
f"▫️ 触发规则: {_format_trigger(status)}",
|
||||||
|
f"▫️ 任务参数: {_format_params(status)}",
|
||||||
|
]
|
||||||
|
return "\n".join(info_lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def format_plugins_list() -> str:
|
||||||
|
"""格式化可用插件列表为文本消息"""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
registered_plugins = scheduler_manager.get_registered_plugins()
|
||||||
|
if not registered_plugins:
|
||||||
|
return "当前没有已注册的定时任务插件。"
|
||||||
|
|
||||||
|
message_parts = ["📋 已注册的定时任务插件:"]
|
||||||
|
for i, plugin_name in enumerate(registered_plugins, 1):
|
||||||
|
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||||
|
params_model = task_meta.get("model")
|
||||||
|
|
||||||
|
param_info_str = "无参数"
|
||||||
|
if (
|
||||||
|
params_model
|
||||||
|
and isinstance(params_model, type)
|
||||||
|
and issubclass(params_model, BaseModel)
|
||||||
|
):
|
||||||
|
model_fields = getattr(params_model, "model_fields", None)
|
||||||
|
if model_fields:
|
||||||
|
param_info_str = "参数: " + ", ".join(
|
||||||
|
f"{field_name}({_get_type_name(field_info.annotation)})"
|
||||||
|
for field_name, field_info in model_fields.items()
|
||||||
|
)
|
||||||
|
elif params_model:
|
||||||
|
param_info_str = "⚠️ 参数模型配置错误"
|
||||||
|
|
||||||
|
message_parts.append(f"{i}. {plugin_name} - {param_info_str}")
|
||||||
|
|
||||||
|
return "\n".join(message_parts)
|
||||||
@ -1,5 +1,3 @@
|
|||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from tortoise.functions import Count
|
from tortoise.functions import Count
|
||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
@ -10,6 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
|
|||||||
from zhenxun.utils.echart_utils.models import Barh
|
from zhenxun.utils.echart_utils.models import Barh
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import PluginType
|
||||||
from zhenxun.utils.image_utils import BuildImage
|
from zhenxun.utils.image_utils import BuildImage
|
||||||
|
from zhenxun.utils.utils import TimeUtils
|
||||||
|
|
||||||
|
|
||||||
class StatisticsManage:
|
class StatisticsManage:
|
||||||
@ -68,8 +67,7 @@ class StatisticsManage:
|
|||||||
if plugin_name:
|
if plugin_name:
|
||||||
query = query.filter(plugin_name=plugin_name)
|
query = query.filter(plugin_name=plugin_name)
|
||||||
if day:
|
if day:
|
||||||
time = datetime.now() - timedelta(days=day)
|
query = query.filter(create_time__gte=TimeUtils.get_day_start())
|
||||||
query = query.filter(create_time__gte=time)
|
|
||||||
data_list = (
|
data_list = (
|
||||||
await query.annotate(count=Count("id"))
|
await query.annotate(count=Count("id"))
|
||||||
.group_by("plugin_name")
|
.group_by("plugin_name")
|
||||||
@ -89,8 +87,7 @@ class StatisticsManage:
|
|||||||
if group_id:
|
if group_id:
|
||||||
query = query.filter(group_id=group_id)
|
query = query.filter(group_id=group_id)
|
||||||
if day:
|
if day:
|
||||||
time = datetime.now() - timedelta(days=day)
|
query = query.filter(create_time__gte=TimeUtils.get_day_start())
|
||||||
query = query.filter(create_time__gte=time)
|
|
||||||
data_list = (
|
data_list = (
|
||||||
await query.annotate(count=Count("id"))
|
await query.annotate(count=Count("id"))
|
||||||
.group_by("plugin_name")
|
.group_by("plugin_name")
|
||||||
@ -106,8 +103,7 @@ class StatisticsManage:
|
|||||||
async def get_group_statistics(cls, group_id: str, day: int | None, title: str):
|
async def get_group_statistics(cls, group_id: str, day: int | None, title: str):
|
||||||
query = Statistics.filter(group_id=group_id)
|
query = Statistics.filter(group_id=group_id)
|
||||||
if day:
|
if day:
|
||||||
time = datetime.now() - timedelta(days=day)
|
query = query.filter(create_time__gte=TimeUtils.get_day_start())
|
||||||
query = query.filter(create_time__gte=time)
|
|
||||||
data_list = (
|
data_list = (
|
||||||
await query.annotate(count=Count("id"))
|
await query.annotate(count=Count("id"))
|
||||||
.group_by("plugin_name")
|
.group_by("plugin_name")
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from nonebot_plugin_alconna.uniseg.segment import (
|
|||||||
)
|
)
|
||||||
from nonebot_plugin_session import EventSession
|
from nonebot_plugin_session import EventSession
|
||||||
|
|
||||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
|
from zhenxun.configs.utils import PluginExtraData, Task
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import PluginType
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
|
||||||
@ -73,16 +73,6 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
author="HibiKier",
|
author="HibiKier",
|
||||||
version="1.2",
|
version="1.2",
|
||||||
plugin_type=PluginType.SUPERUSER,
|
plugin_type=PluginType.SUPERUSER,
|
||||||
configs=[
|
|
||||||
RegisterConfig(
|
|
||||||
module="_task",
|
|
||||||
key="DEFAULT_BROADCAST",
|
|
||||||
value=True,
|
|
||||||
help="被动 广播 进群默认开关状态",
|
|
||||||
default_value=True,
|
|
||||||
type=bool,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
tasks=[Task(module="broadcast", name="广播")],
|
tasks=[Task(module="broadcast", name="广播")],
|
||||||
).to_dict(),
|
).to_dict(),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -106,17 +106,33 @@ class ConfigGroup(BaseModel):
|
|||||||
if value_to_process is None:
|
if value_to_process is None:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
if cfg.type:
|
if cfg.arg_parser:
|
||||||
if build_model and _is_pydantic_type(cfg.type):
|
|
||||||
try:
|
try:
|
||||||
return parse_as(cfg.type, value_to_process)
|
return cfg.arg_parser(value_to_process)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e)
|
logger.debug(
|
||||||
try:
|
f"配置项类型转换 MODULE: [<u><y>{self.module}</y></u>] | "
|
||||||
return cattrs.structure(value_to_process, cfg.type)
|
f"KEY: [<u><y>{key}</y></u>] 的自定义解析器失败,将使用原始值",
|
||||||
except Exception as e:
|
e=e,
|
||||||
logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
|
)
|
||||||
|
return value_to_process
|
||||||
|
|
||||||
|
if not build_model or not cfg.type:
|
||||||
|
return value_to_process
|
||||||
|
|
||||||
|
try:
|
||||||
|
if _is_pydantic_type(cfg.type):
|
||||||
|
parsed_value = parse_as(cfg.type, value_to_process)
|
||||||
|
return parsed_value
|
||||||
|
else:
|
||||||
|
structured_value = cattrs.structure(value_to_process, cfg.type)
|
||||||
|
return structured_value
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"❌ 配置项 '{self.module}.{key}' 自动类型转换失败 "
|
||||||
|
f"(目标类型: {cfg.type}),将返回原始值。请检查配置文件格式。错误: {e}",
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
return value_to_process
|
return value_to_process
|
||||||
|
|
||||||
def to_dict(self, **kwargs):
|
def to_dict(self, **kwargs):
|
||||||
|
|||||||
@ -49,7 +49,8 @@ class ChatHistory(Model):
|
|||||||
o = "-" if order == "DESC" else ""
|
o = "-" if order == "DESC" else ""
|
||||||
query = cls.filter(group_id=gid) if gid else cls
|
query = cls.filter(group_id=gid) if gid else cls
|
||||||
if date_scope:
|
if date_scope:
|
||||||
query = query.filter(create_time__range=date_scope)
|
filter_scope = (date_scope[0].isoformat(" "), date_scope[1].isoformat(" "))
|
||||||
|
query = query.filter(create_time__range=filter_scope)
|
||||||
return list(
|
return list(
|
||||||
await query.annotate(count=Count("user_id"))
|
await query.annotate(count=Count("user_id"))
|
||||||
.order_by(f"{o}count")
|
.order_by(f"{o}count")
|
||||||
|
|||||||
@ -90,11 +90,12 @@ class LevelUser(Model):
|
|||||||
返回:
|
返回:
|
||||||
bool: 是否大于level
|
bool: 是否大于level
|
||||||
"""
|
"""
|
||||||
|
if level == 0:
|
||||||
|
return True
|
||||||
if group_id:
|
if group_id:
|
||||||
if user := await cls.get_or_none(user_id=user_id, group_id=group_id):
|
if user := await cls.get_or_none(user_id=user_id, group_id=group_id):
|
||||||
return user.user_level >= level
|
return user.user_level >= level
|
||||||
else:
|
elif user_list := await cls.filter(user_id=user_id).all():
|
||||||
if user_list := await cls.filter(user_id=user_id).all():
|
|
||||||
user = max(user_list, key=lambda x: x.user_level)
|
user = max(user_list, key=lambda x: x.user_level)
|
||||||
return user.user_level >= level
|
return user.user_level >= level
|
||||||
return False
|
return False
|
||||||
@ -119,8 +120,7 @@ class LevelUser(Model):
|
|||||||
return [
|
return [
|
||||||
# 将user_id改为user_id
|
# 将user_id改为user_id
|
||||||
"ALTER TABLE level_users RENAME COLUMN user_qq TO user_id;",
|
"ALTER TABLE level_users RENAME COLUMN user_qq TO user_id;",
|
||||||
"ALTER TABLE level_users "
|
"ALTER TABLE level_users ALTER COLUMN user_id TYPE character varying(255);",
|
||||||
"ALTER COLUMN user_id TYPE character varying(255);",
|
|
||||||
# 将user_id字段类型改为character varying(255)
|
# 将user_id字段类型改为character varying(255)
|
||||||
"ALTER TABLE level_users "
|
"ALTER TABLE level_users "
|
||||||
"ALTER COLUMN group_id TYPE character varying(255);",
|
"ALTER COLUMN group_id TYPE character varying(255);",
|
||||||
|
|||||||
@ -1,3 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Zhenxun Bot - 核心服务模块
|
||||||
|
|
||||||
|
主要服务包括:
|
||||||
|
- 数据库上下文 (db_context): 提供数据库模型基类和连接管理。
|
||||||
|
- 日志服务 (log): 提供增强的、带上下文的日志记录器。
|
||||||
|
- LLM服务 (llm): 提供与大语言模型交互的统一API。
|
||||||
|
- 插件生命周期管理 (plugin_init): 支持插件安装和卸载时的钩子函数。
|
||||||
|
- 定时任务调度器 (scheduler): 提供持久化的、可管理的定时任务服务。
|
||||||
|
"""
|
||||||
|
|
||||||
from nonebot import require
|
from nonebot import require
|
||||||
|
|
||||||
require("nonebot_plugin_apscheduler")
|
require("nonebot_plugin_apscheduler")
|
||||||
@ -6,3 +17,33 @@ require("nonebot_plugin_session")
|
|||||||
require("nonebot_plugin_htmlrender")
|
require("nonebot_plugin_htmlrender")
|
||||||
require("nonebot_plugin_uninfo")
|
require("nonebot_plugin_uninfo")
|
||||||
require("nonebot_plugin_waiter")
|
require("nonebot_plugin_waiter")
|
||||||
|
|
||||||
|
from .db_context import Model, disconnect
|
||||||
|
from .llm import (
|
||||||
|
AI,
|
||||||
|
LLMContentPart,
|
||||||
|
LLMException,
|
||||||
|
LLMMessage,
|
||||||
|
get_model_instance,
|
||||||
|
list_available_models,
|
||||||
|
tool_registry,
|
||||||
|
)
|
||||||
|
from .log import logger
|
||||||
|
from .plugin_init import PluginInit, PluginInitManager
|
||||||
|
from .scheduler import scheduler_manager
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AI",
|
||||||
|
"LLMContentPart",
|
||||||
|
"LLMException",
|
||||||
|
"LLMMessage",
|
||||||
|
"Model",
|
||||||
|
"PluginInit",
|
||||||
|
"PluginInitManager",
|
||||||
|
"disconnect",
|
||||||
|
"get_model_instance",
|
||||||
|
"list_available_models",
|
||||||
|
"logger",
|
||||||
|
"scheduler_manager",
|
||||||
|
"tool_registry",
|
||||||
|
]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -10,10 +10,10 @@ from .api import (
|
|||||||
TaskType,
|
TaskType,
|
||||||
analyze,
|
analyze,
|
||||||
analyze_multimodal,
|
analyze_multimodal,
|
||||||
analyze_with_images,
|
|
||||||
chat,
|
chat,
|
||||||
code,
|
code,
|
||||||
embed,
|
embed,
|
||||||
|
pipeline_chat,
|
||||||
search,
|
search,
|
||||||
search_multimodal,
|
search_multimodal,
|
||||||
)
|
)
|
||||||
@ -35,6 +35,7 @@ from .manager import (
|
|||||||
list_model_identifiers,
|
list_model_identifiers,
|
||||||
set_global_default_model_name,
|
set_global_default_model_name,
|
||||||
)
|
)
|
||||||
|
from .tools import tool_registry
|
||||||
from .types import (
|
from .types import (
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
LLMContentPart,
|
LLMContentPart,
|
||||||
@ -43,6 +44,7 @@ from .types import (
|
|||||||
LLMMessage,
|
LLMMessage,
|
||||||
LLMResponse,
|
LLMResponse,
|
||||||
LLMTool,
|
LLMTool,
|
||||||
|
MCPCompatible,
|
||||||
ModelDetail,
|
ModelDetail,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
ModelProvider,
|
ModelProvider,
|
||||||
@ -51,7 +53,7 @@ from .types import (
|
|||||||
ToolMetadata,
|
ToolMetadata,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from .utils import create_multimodal_message, unimsg_to_llm_parts
|
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AI",
|
"AI",
|
||||||
@ -65,6 +67,7 @@ __all__ = [
|
|||||||
"LLMMessage",
|
"LLMMessage",
|
||||||
"LLMResponse",
|
"LLMResponse",
|
||||||
"LLMTool",
|
"LLMTool",
|
||||||
|
"MCPCompatible",
|
||||||
"ModelDetail",
|
"ModelDetail",
|
||||||
"ModelInfo",
|
"ModelInfo",
|
||||||
"ModelName",
|
"ModelName",
|
||||||
@ -76,7 +79,6 @@ __all__ = [
|
|||||||
"UsageInfo",
|
"UsageInfo",
|
||||||
"analyze",
|
"analyze",
|
||||||
"analyze_multimodal",
|
"analyze_multimodal",
|
||||||
"analyze_with_images",
|
|
||||||
"chat",
|
"chat",
|
||||||
"clear_model_cache",
|
"clear_model_cache",
|
||||||
"code",
|
"code",
|
||||||
@ -88,9 +90,12 @@ __all__ = [
|
|||||||
"list_available_models",
|
"list_available_models",
|
||||||
"list_embedding_models",
|
"list_embedding_models",
|
||||||
"list_model_identifiers",
|
"list_model_identifiers",
|
||||||
|
"message_to_unimessage",
|
||||||
|
"pipeline_chat",
|
||||||
"register_llm_configs",
|
"register_llm_configs",
|
||||||
"search",
|
"search",
|
||||||
"search_multimodal",
|
"search_multimodal",
|
||||||
"set_global_default_model_name",
|
"set_global_default_model_name",
|
||||||
|
"tool_registry",
|
||||||
"unimsg_to_llm_parts",
|
"unimsg_to_llm_parts",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
|
|||||||
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
|
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
|
||||||
from .gemini import GeminiAdapter
|
from .gemini import GeminiAdapter
|
||||||
from .openai import OpenAIAdapter
|
from .openai import OpenAIAdapter
|
||||||
from .zhipu import ZhipuAdapter
|
|
||||||
|
|
||||||
LLMAdapterFactory.initialize()
|
LLMAdapterFactory.initialize()
|
||||||
|
|
||||||
@ -20,7 +19,6 @@ __all__ = [
|
|||||||
"OpenAICompatAdapter",
|
"OpenAICompatAdapter",
|
||||||
"RequestData",
|
"RequestData",
|
||||||
"ResponseData",
|
"ResponseData",
|
||||||
"ZhipuAdapter",
|
|
||||||
"get_adapter_for_api_type",
|
"get_adapter_for_api_type",
|
||||||
"register_adapter",
|
"register_adapter",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||||||
from ..service import LLMModel
|
from ..service import LLMModel
|
||||||
from ..types.content import LLMMessage
|
from ..types.content import LLMMessage
|
||||||
from ..types.enums import EmbeddingTaskType
|
from ..types.enums import EmbeddingTaskType
|
||||||
|
from ..types.models import LLMTool
|
||||||
|
|
||||||
|
|
||||||
class RequestData(BaseModel):
|
class RequestData(BaseModel):
|
||||||
@ -60,7 +61,7 @@ class BaseAdapter(ABC):
|
|||||||
"""支持的API类型列表"""
|
"""支持的API类型列表"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_simple_request(
|
async def prepare_simple_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
@ -86,7 +87,7 @@ class BaseAdapter(ABC):
|
|||||||
|
|
||||||
config = model._generation_config
|
config = model._generation_config
|
||||||
|
|
||||||
return self.prepare_advanced_request(
|
return await self.prepare_advanced_request(
|
||||||
model=model,
|
model=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@ -96,13 +97,13 @@ class BaseAdapter(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def prepare_advanced_request(
|
async def prepare_advanced_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
messages: list["LLMMessage"],
|
messages: list["LLMMessage"],
|
||||||
config: "LLMGenerationConfig | None" = None,
|
config: "LLMGenerationConfig | None" = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list["LLMTool"] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备高级请求"""
|
"""准备高级请求"""
|
||||||
@ -238,6 +239,9 @@ class BaseAdapter(ABC):
|
|||||||
message = choice.get("message", {})
|
message = choice.get("message", {})
|
||||||
content = message.get("content", "")
|
content = message.get("content", "")
|
||||||
|
|
||||||
|
if content:
|
||||||
|
content = content.strip()
|
||||||
|
|
||||||
parsed_tool_calls: list[LLMToolCall] | None = None
|
parsed_tool_calls: list[LLMToolCall] | None = None
|
||||||
if message_tool_calls := message.get("tool_calls"):
|
if message_tool_calls := message.get("tool_calls"):
|
||||||
from ..types.models import LLMToolFunction
|
from ..types.models import LLMToolFunction
|
||||||
@ -375,7 +379,7 @@ class BaseAdapter(ABC):
|
|||||||
if model.temperature is not None:
|
if model.temperature is not None:
|
||||||
base_config["temperature"] = model.temperature
|
base_config["temperature"] = model.temperature
|
||||||
if model.max_tokens is not None:
|
if model.max_tokens is not None:
|
||||||
if model.api_type in ["gemini", "gemini_native"]:
|
if model.api_type == "gemini":
|
||||||
base_config["maxOutputTokens"] = model.max_tokens
|
base_config["maxOutputTokens"] = model.max_tokens
|
||||||
else:
|
else:
|
||||||
base_config["max_tokens"] = model.max_tokens
|
base_config["max_tokens"] = model.max_tokens
|
||||||
@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_chat_endpoint(self) -> str:
|
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""子类必须实现,返回 chat completions 的端点"""
|
"""子类必须实现,返回 chat completions 的端点"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_embedding_endpoint(self) -> str:
|
def get_embedding_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""子类必须实现,返回 embeddings 的端点"""
|
"""子类必须实现,返回 embeddings 的端点"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_advanced_request(
|
async def prepare_simple_request(
|
||||||
|
self,
|
||||||
|
model: "LLMModel",
|
||||||
|
api_key: str,
|
||||||
|
prompt: str,
|
||||||
|
history: list[dict[str, str]] | None = None,
|
||||||
|
) -> RequestData:
|
||||||
|
"""准备简单文本生成请求 - OpenAI兼容API的通用实现"""
|
||||||
|
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||||
|
headers = self.get_base_headers(api_key)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if history:
|
||||||
|
messages.extend(history)
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
body = {
|
||||||
|
"model": model.model_name,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
body = self.apply_config_override(model, body)
|
||||||
|
|
||||||
|
return RequestData(url=url, headers=headers, body=body)
|
||||||
|
|
||||||
|
async def prepare_advanced_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
messages: list["LLMMessage"],
|
messages: list["LLMMessage"],
|
||||||
config: "LLMGenerationConfig | None" = None,
|
config: "LLMGenerationConfig | None" = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list["LLMTool"] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备高级请求 - OpenAI兼容格式"""
|
"""准备高级请求 - OpenAI兼容格式"""
|
||||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||||
headers = self.get_base_headers(api_key)
|
headers = self.get_base_headers(api_key)
|
||||||
openai_messages = self.convert_messages_to_openai_format(messages)
|
openai_messages = self.convert_messages_to_openai_format(messages)
|
||||||
|
|
||||||
@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
body["tools"] = tools
|
openai_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
if tool.type == "function" and tool.function:
|
||||||
|
openai_tools.append({"type": "function", "function": tool.function})
|
||||||
|
elif tool.type == "mcp" and tool.mcp_session:
|
||||||
|
if callable(tool.mcp_session):
|
||||||
|
raise ValueError(
|
||||||
|
"适配器接收到未激活的 MCP 会话工厂。"
|
||||||
|
"会话工厂应该在 LLMModel.generate_response 中被激活。"
|
||||||
|
)
|
||||||
|
openai_tools.append(
|
||||||
|
tool.mcp_session.to_api_tool(api_type=self.api_type)
|
||||||
|
)
|
||||||
|
if openai_tools:
|
||||||
|
body["tools"] = openai_tools
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
body["tool_choice"] = tool_choice
|
body["tool_choice"] = tool_choice
|
||||||
|
|
||||||
@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
is_advanced: bool = False,
|
is_advanced: bool = False,
|
||||||
) -> ResponseData:
|
) -> ResponseData:
|
||||||
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
|
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
|
||||||
_ = model, is_advanced # 未使用的参数
|
_ = model, is_advanced
|
||||||
return self.parse_openai_response(response_json)
|
return self.parse_openai_response(response_json)
|
||||||
|
|
||||||
def prepare_embedding_request(
|
def prepare_embedding_request(
|
||||||
@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备嵌入请求 - OpenAI兼容格式"""
|
"""准备嵌入请求 - OpenAI兼容格式"""
|
||||||
_ = task_type # 未使用的参数
|
_ = task_type
|
||||||
url = self.get_api_url(model, self.get_embedding_endpoint())
|
url = self.get_api_url(model, self.get_embedding_endpoint(model))
|
||||||
headers = self.get_base_headers(api_key)
|
headers = self.get_base_headers(api_key)
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
"input": texts,
|
"input": texts,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 应用额外的配置参数
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
body.update(kwargs)
|
body.update(kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -22,10 +22,8 @@ class LLMAdapterFactory:
|
|||||||
|
|
||||||
from .gemini import GeminiAdapter
|
from .gemini import GeminiAdapter
|
||||||
from .openai import OpenAIAdapter
|
from .openai import OpenAIAdapter
|
||||||
from .zhipu import ZhipuAdapter
|
|
||||||
|
|
||||||
cls.register_adapter(OpenAIAdapter())
|
cls.register_adapter(OpenAIAdapter())
|
||||||
cls.register_adapter(ZhipuAdapter())
|
|
||||||
cls.register_adapter(GeminiAdapter())
|
cls.register_adapter(GeminiAdapter())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
from ..service import LLMModel
|
from ..service import LLMModel
|
||||||
from ..types.content import LLMMessage
|
from ..types.content import LLMMessage
|
||||||
from ..types.enums import EmbeddingTaskType
|
from ..types.enums import EmbeddingTaskType
|
||||||
from ..types.models import LLMToolCall
|
from ..types.models import LLMTool, LLMToolCall
|
||||||
|
|
||||||
|
|
||||||
class GeminiAdapter(BaseAdapter):
|
class GeminiAdapter(BaseAdapter):
|
||||||
@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def prepare_advanced_request(
|
async def prepare_advanced_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
messages: list["LLMMessage"],
|
messages: list["LLMMessage"],
|
||||||
config: "LLMGenerationConfig | None" = None,
|
config: "LLMGenerationConfig | None" = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list["LLMTool"] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备高级请求"""
|
"""准备高级请求"""
|
||||||
return self._prepare_request(
|
|
||||||
model, api_key, messages, config, tools, tool_choice
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_request(
|
|
||||||
self,
|
|
||||||
model: "LLMModel",
|
|
||||||
api_key: str,
|
|
||||||
messages: list["LLMMessage"],
|
|
||||||
config: "LLMGenerationConfig | None" = None,
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
) -> RequestData:
|
|
||||||
"""准备 Gemini API 请求 - 支持所有高级功能"""
|
|
||||||
effective_config = config if config is not None else model._generation_config
|
effective_config = config if config is not None else model._generation_config
|
||||||
|
|
||||||
endpoint = self._get_gemini_endpoint(model, effective_config)
|
endpoint = self._get_gemini_endpoint(model, effective_config)
|
||||||
@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
system_instruction_parts = [{"text": msg.content}]
|
system_instruction_parts = [{"text": msg.content}]
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
system_instruction_parts = [
|
system_instruction_parts = [
|
||||||
part.convert_for_api("gemini") for part in msg.content
|
await part.convert_for_api_async("gemini")
|
||||||
|
for part in msg.content
|
||||||
]
|
]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
current_parts.append({"text": msg.content})
|
current_parts.append({"text": msg.content})
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
for part_obj in msg.content:
|
for part_obj in msg.content:
|
||||||
current_parts.append(part_obj.convert_for_api("gemini"))
|
current_parts.append(
|
||||||
|
await part_obj.convert_for_api_async("gemini")
|
||||||
|
)
|
||||||
gemini_contents.append({"role": "user", "parts": current_parts})
|
gemini_contents.append({"role": "user", "parts": current_parts})
|
||||||
|
|
||||||
elif msg.role == "assistant" or msg.role == "model":
|
elif msg.role == "assistant" or msg.role == "model":
|
||||||
@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
current_parts.append({"text": msg.content})
|
current_parts.append({"text": msg.content})
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
for part_obj in msg.content:
|
for part_obj in msg.content:
|
||||||
current_parts.append(part_obj.convert_for_api("gemini"))
|
current_parts.append(
|
||||||
|
await part_obj.convert_for_api_async("gemini")
|
||||||
|
)
|
||||||
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
import json
|
import json
|
||||||
@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
|
|
||||||
all_tools_for_request = []
|
all_tools_for_request = []
|
||||||
if tools:
|
if tools:
|
||||||
for tool_item in tools:
|
for tool in tools:
|
||||||
if isinstance(tool_item, dict):
|
if tool.type == "function" and tool.function:
|
||||||
if "name" in tool_item and "description" in tool_item:
|
|
||||||
all_tools_for_request.append(
|
all_tools_for_request.append(
|
||||||
{"functionDeclarations": [tool_item]}
|
{"functionDeclarations": [tool.function]}
|
||||||
)
|
)
|
||||||
else:
|
elif tool.type == "mcp" and tool.mcp_session:
|
||||||
all_tools_for_request.append(tool_item)
|
if callable(tool.mcp_session):
|
||||||
else:
|
raise ValueError(
|
||||||
all_tools_for_request.append(tool_item)
|
"适配器接收到未激活的 MCP 会话工厂。"
|
||||||
|
"会话工厂应该在 LLMModel.generate_response 中被激活。"
|
||||||
|
)
|
||||||
|
all_tools_for_request.append(
|
||||||
|
tool.mcp_session.to_api_tool(api_type=self.api_type)
|
||||||
|
)
|
||||||
|
elif tool.type == "google_search":
|
||||||
|
all_tools_for_request.append({"googleSearch": {}})
|
||||||
|
|
||||||
if effective_config:
|
if effective_config:
|
||||||
if getattr(effective_config, "enable_grounding", False):
|
if getattr(effective_config, "enable_grounding", False):
|
||||||
@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
logger.debug("隐式启用代码执行工具。")
|
logger.debug("隐式启用代码执行工具。")
|
||||||
|
|
||||||
if all_tools_for_request:
|
if all_tools_for_request:
|
||||||
gemini_api_tools = self._convert_tools_to_gemini_format(
|
body["tools"] = all_tools_for_request
|
||||||
all_tools_for_request
|
|
||||||
)
|
|
||||||
if gemini_api_tools:
|
|
||||||
body["tools"] = gemini_api_tools
|
|
||||||
|
|
||||||
final_tool_choice = tool_choice
|
final_tool_choice = tool_choice
|
||||||
if final_tool_choice is None and effective_config:
|
if final_tool_choice is None and effective_config:
|
||||||
@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
|
|
||||||
return f"/v1beta/models/{model.model_name}:generateContent"
|
return f"/v1beta/models/{model.model_name}:generateContent"
|
||||||
|
|
||||||
def _convert_tools_to_gemini_format(
|
|
||||||
self, tools: list[dict[str, Any]]
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""转换工具格式为Gemini格式"""
|
|
||||||
gemini_tools = []
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
if tool.get("type") == "function":
|
|
||||||
func = tool["function"]
|
|
||||||
gemini_tool = {
|
|
||||||
"functionDeclarations": [
|
|
||||||
{
|
|
||||||
"name": func["name"],
|
|
||||||
"description": func.get("description", ""),
|
|
||||||
"parameters": func.get("parameters", {}),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
gemini_tools.append(gemini_tool)
|
|
||||||
elif tool.get("type") == "code_execution":
|
|
||||||
gemini_tools.append(
|
|
||||||
{"codeExecution": {"language": tool.get("language", "python")}}
|
|
||||||
)
|
|
||||||
elif tool.get("type") == "google_search":
|
|
||||||
gemini_tools.append({"googleSearch": {}})
|
|
||||||
elif "googleSearch" in tool:
|
|
||||||
gemini_tools.append({"googleSearch": tool["googleSearch"]})
|
|
||||||
elif "codeExecution" in tool:
|
|
||||||
gemini_tools.append({"codeExecution": tool["codeExecution"]})
|
|
||||||
|
|
||||||
return gemini_tools
|
|
||||||
|
|
||||||
def _convert_tool_choice_to_gemini(
|
def _convert_tool_choice_to_gemini(
|
||||||
self, tool_choice_value: str | dict[str, Any]
|
self, tool_choice_value: str | dict[str, Any]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
for category, threshold in custom_safety_settings.items():
|
for category, threshold in custom_safety_settings.items():
|
||||||
safety_settings.append({"category": category, "threshold": threshold})
|
safety_settings.append({"category": category, "threshold": threshold})
|
||||||
else:
|
else:
|
||||||
|
from ..config.providers import get_gemini_safety_threshold
|
||||||
|
|
||||||
|
threshold = get_gemini_safety_threshold()
|
||||||
for category in safety_categories:
|
for category in safety_categories:
|
||||||
safety_settings.append(
|
safety_settings.append({"category": category, "threshold": threshold})
|
||||||
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
|
|
||||||
)
|
|
||||||
|
|
||||||
return safety_settings if safety_settings else None
|
return safety_settings if safety_settings else None
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
OpenAI API 适配器
|
OpenAI API 适配器
|
||||||
|
|
||||||
支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。
|
支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .base import OpenAICompatAdapter, RequestData
|
from .base import OpenAICompatAdapter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..service import LLMModel
|
from ..service import LLMModel
|
||||||
@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_api_types(self) -> list[str]:
|
def supported_api_types(self) -> list[str]:
|
||||||
return ["openai", "deepseek", "general_openai_compat"]
|
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
|
||||||
|
|
||||||
def get_chat_endpoint(self) -> str:
|
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""返回聊天完成端点"""
|
"""返回聊天完成端点"""
|
||||||
|
if model.api_type == "ark":
|
||||||
|
return "/api/v3/chat/completions"
|
||||||
|
if model.api_type == "zhipu":
|
||||||
|
return "/api/paas/v4/chat/completions"
|
||||||
return "/v1/chat/completions"
|
return "/v1/chat/completions"
|
||||||
|
|
||||||
def get_embedding_endpoint(self) -> str:
|
def get_embedding_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""返回嵌入端点"""
|
"""根据API类型返回嵌入端点"""
|
||||||
|
if model.api_type == "zhipu":
|
||||||
|
return "/v4/embeddings"
|
||||||
return "/v1/embeddings"
|
return "/v1/embeddings"
|
||||||
|
|
||||||
def prepare_simple_request(
|
|
||||||
self,
|
|
||||||
model: "LLMModel",
|
|
||||||
api_key: str,
|
|
||||||
prompt: str,
|
|
||||||
history: list[dict[str, str]] | None = None,
|
|
||||||
) -> RequestData:
|
|
||||||
"""准备简单文本生成请求 - OpenAI优化实现"""
|
|
||||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
|
||||||
headers = self.get_base_headers(api_key)
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
if history:
|
|
||||||
messages.extend(history)
|
|
||||||
messages.append({"role": "user", "content": prompt})
|
|
||||||
|
|
||||||
body = {
|
|
||||||
"model": model.model_name,
|
|
||||||
"messages": messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
body = self.apply_config_override(model, body)
|
|
||||||
|
|
||||||
return RequestData(url=url, headers=headers, body=body)
|
|
||||||
|
|||||||
@ -1,57 +0,0 @@
|
|||||||
"""
|
|
||||||
智谱 AI API 适配器
|
|
||||||
|
|
||||||
支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from .base import OpenAICompatAdapter, RequestData
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..service import LLMModel
|
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAdapter(OpenAICompatAdapter):
|
|
||||||
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def api_type(self) -> str:
|
|
||||||
return "zhipu"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_api_types(self) -> list[str]:
|
|
||||||
return ["zhipu"]
|
|
||||||
|
|
||||||
def get_chat_endpoint(self) -> str:
|
|
||||||
"""返回智谱AI聊天完成端点"""
|
|
||||||
return "/api/paas/v4/chat/completions"
|
|
||||||
|
|
||||||
def get_embedding_endpoint(self) -> str:
|
|
||||||
"""返回智谱AI嵌入端点"""
|
|
||||||
return "/v4/embeddings"
|
|
||||||
|
|
||||||
def prepare_simple_request(
|
|
||||||
self,
|
|
||||||
model: "LLMModel",
|
|
||||||
api_key: str,
|
|
||||||
prompt: str,
|
|
||||||
history: list[dict[str, str]] | None = None,
|
|
||||||
) -> RequestData:
|
|
||||||
"""准备简单文本生成请求 - 智谱AI优化实现"""
|
|
||||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
|
||||||
headers = self.get_base_headers(api_key)
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
if history:
|
|
||||||
messages.extend(history)
|
|
||||||
messages.append({"role": "user", "content": prompt})
|
|
||||||
|
|
||||||
body = {
|
|
||||||
"model": model.model_name,
|
|
||||||
"messages": messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
body = self.apply_config_override(model, body)
|
|
||||||
|
|
||||||
return RequestData(url=url, headers=headers, body=body)
|
|
||||||
@ -2,6 +2,7 @@
|
|||||||
LLM 服务的高级 API 接口
|
LLM 服务的高级 API 接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -14,6 +15,7 @@ from zhenxun.services.log import logger
|
|||||||
from .config import CommonOverrides, LLMGenerationConfig
|
from .config import CommonOverrides, LLMGenerationConfig
|
||||||
from .config.providers import get_ai_config
|
from .config.providers import get_ai_config
|
||||||
from .manager import get_global_default_model_name, get_model_instance
|
from .manager import get_global_default_model_name, get_model_instance
|
||||||
|
from .tools import tool_registry
|
||||||
from .types import (
|
from .types import (
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
LLMContentPart,
|
LLMContentPart,
|
||||||
@ -56,6 +58,7 @@ class AIConfig:
|
|||||||
enable_gemini_safe_mode: bool = False
|
enable_gemini_safe_mode: bool = False
|
||||||
enable_gemini_multimodal: bool = False
|
enable_gemini_multimodal: bool = False
|
||||||
enable_gemini_grounding: bool = False
|
enable_gemini_grounding: bool = False
|
||||||
|
default_preserve_media_in_history: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""初始化后从配置中读取默认值"""
|
"""初始化后从配置中读取默认值"""
|
||||||
@ -81,7 +84,7 @@ class AI:
|
|||||||
"""
|
"""
|
||||||
初始化AI服务
|
初始化AI服务
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
config: AI 配置.
|
config: AI 配置.
|
||||||
history: 可选的初始对话历史.
|
history: 可选的初始对话历史.
|
||||||
"""
|
"""
|
||||||
@ -93,16 +96,65 @@ class AI:
|
|||||||
self.history = []
|
self.history = []
|
||||||
logger.info("AI session history cleared.")
|
logger.info("AI session history cleared.")
|
||||||
|
|
||||||
|
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||||
|
"""
|
||||||
|
净化用于存入历史记录的消息。
|
||||||
|
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||||
|
"""
|
||||||
|
if not isinstance(message.content, list):
|
||||||
|
return message
|
||||||
|
|
||||||
|
sanitized_message = copy.deepcopy(message)
|
||||||
|
content_list = sanitized_message.content
|
||||||
|
if not isinstance(content_list, list):
|
||||||
|
return sanitized_message
|
||||||
|
|
||||||
|
new_content_parts: list[LLMContentPart] = []
|
||||||
|
has_multimodal_content = False
|
||||||
|
|
||||||
|
for part in content_list:
|
||||||
|
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||||
|
new_content_parts.append(part)
|
||||||
|
else:
|
||||||
|
has_multimodal_content = True
|
||||||
|
|
||||||
|
if has_multimodal_content:
|
||||||
|
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||||
|
text_part_found = False
|
||||||
|
for part in new_content_parts:
|
||||||
|
if part.type == "text":
|
||||||
|
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||||
|
text_part_found = True
|
||||||
|
break
|
||||||
|
if not text_part_found:
|
||||||
|
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||||
|
|
||||||
|
sanitized_message.content = new_content_parts
|
||||||
|
return sanitized_message
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str | LLMMessage | list[LLMContentPart],
|
message: str | LLMMessage | list[LLMContentPart],
|
||||||
*,
|
*,
|
||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
|
preserve_media_in_history: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
进行一次聊天对话。
|
进行一次聊天对话。
|
||||||
此方法会自动使用和更新会话内的历史记录。
|
此方法会自动使用和更新会话内的历史记录。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 用户输入的消息。
|
||||||
|
model: 本次对话要使用的模型。
|
||||||
|
preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
|
||||||
|
- True: 保留,用于深度多轮媒体分析。
|
||||||
|
- False: 不保留,替换为占位符,提高效率。
|
||||||
|
- None (默认): 使用AI实例配置的默认值。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 模型的文本响应。
|
||||||
"""
|
"""
|
||||||
current_message: LLMMessage
|
current_message: LLMMessage
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
@ -127,7 +179,20 @@ class AI:
|
|||||||
final_messages, model, "聊天失败", kwargs
|
final_messages, model, "聊天失败", kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
should_preserve = (
|
||||||
|
preserve_media_in_history
|
||||||
|
if preserve_media_in_history is not None
|
||||||
|
else self.config.default_preserve_media_in_history
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_preserve:
|
||||||
|
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
|
||||||
self.history.append(current_message)
|
self.history.append(current_message)
|
||||||
|
else:
|
||||||
|
logger.debug("高效模式:净化历史记录中的多模态消息。")
|
||||||
|
sanitized_user_message = self._sanitize_message_for_history(current_message)
|
||||||
|
self.history.append(sanitized_user_message)
|
||||||
|
|
||||||
self.history.append(LLMMessage.assistant_text_response(response.text))
|
self.history.append(LLMMessage.assistant_text_response(response.text))
|
||||||
|
|
||||||
return response.text
|
return response.text
|
||||||
@ -140,7 +205,18 @@ class AI:
|
|||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""代码执行"""
|
"""
|
||||||
|
代码执行
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prompt: 代码执行的提示词。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
timeout: 代码执行超时时间(秒)。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
|
||||||
|
"""
|
||||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
config = CommonOverrides.gemini_code_execution()
|
config = CommonOverrides.gemini_code_execution()
|
||||||
@ -168,7 +244,18 @@ class AI:
|
|||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""信息搜索 - 支持多模态输入"""
|
"""
|
||||||
|
信息搜索 - 支持多模态输入
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query: 搜索查询内容,支持文本或多模态消息。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
instruction: 搜索指令。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
|
||||||
|
"""
|
||||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||||
config = CommonOverrides.gemini_grounding()
|
config = CommonOverrides.gemini_grounding()
|
||||||
|
|
||||||
@ -217,63 +304,69 @@ class AI:
|
|||||||
|
|
||||||
async def analyze(
|
async def analyze(
|
||||||
self,
|
self,
|
||||||
message: UniMessage,
|
message: UniMessage | None,
|
||||||
*,
|
*,
|
||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
use_tools: list[str] | None = None,
|
||||||
tool_config: dict[str, Any] | None = None,
|
tool_config: dict[str, Any] | None = None,
|
||||||
|
activated_tools: list[LLMTool] | None = None,
|
||||||
|
history: list[LLMMessage] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str | LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
||||||
这是处理复杂互动的主要方法。
|
|
||||||
|
参数:
|
||||||
|
message: 要分析的消息内容(支持多模态)。
|
||||||
|
instruction: 分析指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
use_tools: 要使用的工具名称列表。
|
||||||
|
tool_config: 工具配置。
|
||||||
|
activated_tools: 已激活的工具列表。
|
||||||
|
history: 对话历史记录。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMResponse: 模型的完整响应结果。
|
||||||
"""
|
"""
|
||||||
content_parts = await unimsg_to_llm_parts(message)
|
content_parts = await unimsg_to_llm_parts(message or UniMessage())
|
||||||
|
|
||||||
final_messages: list[LLMMessage] = []
|
final_messages: list[LLMMessage] = []
|
||||||
|
if history:
|
||||||
|
final_messages.extend(history)
|
||||||
|
|
||||||
if instruction:
|
if instruction:
|
||||||
final_messages.append(LLMMessage.system(instruction))
|
if not any(msg.role == "system" for msg in final_messages):
|
||||||
|
final_messages.insert(0, LLMMessage.system(instruction))
|
||||||
|
|
||||||
if not content_parts:
|
if not content_parts:
|
||||||
if instruction:
|
if instruction and not history:
|
||||||
final_messages.append(LLMMessage.user(instruction))
|
final_messages.append(LLMMessage.user(instruction))
|
||||||
else:
|
elif not history:
|
||||||
raise LLMException(
|
raise LLMException(
|
||||||
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_messages.append(LLMMessage.user(content_parts))
|
final_messages.append(LLMMessage.user(content_parts))
|
||||||
|
|
||||||
llm_tools = None
|
llm_tools: list[LLMTool] | None = activated_tools
|
||||||
if tools:
|
if not llm_tools and use_tools:
|
||||||
llm_tools = []
|
try:
|
||||||
for tool_dict in tools:
|
llm_tools = tool_registry.get_tools(use_tools)
|
||||||
if isinstance(tool_dict, dict):
|
logger.debug(f"已从注册表加载工具定义: {use_tools}")
|
||||||
if "name" in tool_dict and "description" in tool_dict:
|
except ValueError as e:
|
||||||
llm_tool = LLMTool(
|
raise LLMException(
|
||||||
type="function",
|
f"加载工具定义失败: {e}",
|
||||||
function={
|
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||||
"name": tool_dict["name"],
|
cause=e,
|
||||||
"description": tool_dict["description"],
|
|
||||||
"parameters": tool_dict.get("parameters", {}),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
llm_tools.append(llm_tool)
|
|
||||||
else:
|
|
||||||
llm_tools.append(LLMTool(**tool_dict))
|
|
||||||
else:
|
|
||||||
llm_tools.append(tool_dict)
|
|
||||||
|
|
||||||
tool_choice = None
|
tool_choice = None
|
||||||
if tool_config:
|
if tool_config:
|
||||||
mode = tool_config.get("mode", "auto")
|
mode = tool_config.get("mode", "auto")
|
||||||
if mode == "auto":
|
if mode in ["auto", "any", "none"]:
|
||||||
tool_choice = "auto"
|
tool_choice = mode
|
||||||
elif mode == "any":
|
|
||||||
tool_choice = "any"
|
|
||||||
elif mode == "none":
|
|
||||||
tool_choice = "none"
|
|
||||||
|
|
||||||
response = await self._execute_generation(
|
response = await self._execute_generation(
|
||||||
final_messages,
|
final_messages,
|
||||||
@ -284,9 +377,7 @@ class AI:
|
|||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.tool_calls:
|
|
||||||
return response
|
return response
|
||||||
return response.text
|
|
||||||
|
|
||||||
async def _execute_generation(
|
async def _execute_generation(
|
||||||
self,
|
self,
|
||||||
@ -298,7 +389,7 @@ class AI:
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
base_config: LLMGenerationConfig | None = None,
|
base_config: LLMGenerationConfig | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
|
"""通用的生成执行方法,封装模型获取和单次API调用"""
|
||||||
try:
|
try:
|
||||||
resolved_model_name = self._resolve_model_name(
|
resolved_model_name = self._resolve_model_name(
|
||||||
model_name or self.config.model
|
model_name or self.config.model
|
||||||
@ -311,7 +402,9 @@ class AI:
|
|||||||
resolved_model_name, override_config=final_config_dict
|
resolved_model_name, override_config=final_config_dict
|
||||||
) as model_instance:
|
) as model_instance:
|
||||||
return await model_instance.generate_response(
|
return await model_instance.generate_response(
|
||||||
messages, tools=llm_tools, tool_choice=tool_choice
|
messages,
|
||||||
|
tools=llm_tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
except LLMException:
|
except LLMException:
|
||||||
raise
|
raise
|
||||||
@ -380,7 +473,18 @@ class AI:
|
|||||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""生成文本嵌入向量"""
|
"""
|
||||||
|
生成文本嵌入向量
|
||||||
|
|
||||||
|
参数:
|
||||||
|
texts: 要生成嵌入向量的文本或文本列表。
|
||||||
|
model: 要使用的嵌入模型名称。
|
||||||
|
task_type: 嵌入任务类型。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[list[float]]: 文本的嵌入向量列表。
|
||||||
|
"""
|
||||||
if isinstance(texts, str):
|
if isinstance(texts, str):
|
||||||
texts = [texts]
|
texts = [texts]
|
||||||
if not texts:
|
if not texts:
|
||||||
@ -420,7 +524,17 @@ async def chat(
|
|||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""聊天对话便捷函数"""
|
"""
|
||||||
|
聊天对话便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 用户输入的消息。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 模型的文本响应。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.chat(message, model=model, **kwargs)
|
return await ai.chat(message, model=model, **kwargs)
|
||||||
|
|
||||||
@ -432,7 +546,18 @@ async def code(
|
|||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""代码执行便捷函数"""
|
"""
|
||||||
|
代码执行便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prompt: 代码执行的提示词。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
timeout: 代码执行超时时间(秒)。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含执行结果的字典。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
|
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
|
||||||
|
|
||||||
@ -444,45 +569,56 @@ async def search(
|
|||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""信息搜索便捷函数"""
|
"""
|
||||||
|
信息搜索便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query: 搜索查询内容。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
instruction: 搜索指令。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含搜索结果的字典。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.search(query, model=model, instruction=instruction, **kwargs)
|
return await ai.search(query, model=model, instruction=instruction, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def analyze(
|
async def analyze(
|
||||||
message: UniMessage,
|
message: UniMessage | None,
|
||||||
*,
|
*,
|
||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
use_tools: list[str] | None = None,
|
||||||
tool_config: dict[str, Any] | None = None,
|
tool_config: dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str | LLMResponse:
|
) -> str | LLMResponse:
|
||||||
"""内容分析便捷函数"""
|
"""
|
||||||
|
内容分析便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 要分析的消息内容。
|
||||||
|
instruction: 分析指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
use_tools: 要使用的工具名称列表。
|
||||||
|
tool_config: 工具配置。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | LLMResponse: 分析结果。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.analyze(
|
return await ai.analyze(
|
||||||
message,
|
message,
|
||||||
instruction=instruction,
|
instruction=instruction,
|
||||||
model=model,
|
model=model,
|
||||||
tools=tools,
|
use_tools=use_tools,
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def analyze_with_images(
|
|
||||||
text: str,
|
|
||||||
images: list[str | Path | bytes] | str | Path | bytes,
|
|
||||||
*,
|
|
||||||
instruction: str = "",
|
|
||||||
model: ModelName = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str | LLMResponse:
|
|
||||||
"""图片分析便捷函数"""
|
|
||||||
message = create_multimodal_message(text=text, images=images)
|
|
||||||
return await analyze(message, instruction=instruction, model=model, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
async def analyze_multimodal(
|
async def analyze_multimodal(
|
||||||
text: str | None = None,
|
text: str | None = None,
|
||||||
images: list[str | Path | bytes] | str | Path | bytes | None = None,
|
images: list[str | Path | bytes] | str | Path | bytes | None = None,
|
||||||
@ -493,7 +629,21 @@ async def analyze_multimodal(
|
|||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str | LLMResponse:
|
) -> str | LLMResponse:
|
||||||
"""多模态分析便捷函数"""
|
"""
|
||||||
|
多模态分析便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
text: 文本内容。
|
||||||
|
images: 图片文件路径、字节数据或列表。
|
||||||
|
videos: 视频文件路径、字节数据或列表。
|
||||||
|
audios: 音频文件路径、字节数据或列表。
|
||||||
|
instruction: 分析指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | LLMResponse: 分析结果。
|
||||||
|
"""
|
||||||
message = create_multimodal_message(
|
message = create_multimodal_message(
|
||||||
text=text, images=images, videos=videos, audios=audios
|
text=text, images=images, videos=videos, audios=audios
|
||||||
)
|
)
|
||||||
@ -510,7 +660,21 @@ async def search_multimodal(
|
|||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""多模态搜索便捷函数"""
|
"""
|
||||||
|
多模态搜索便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
text: 文本内容。
|
||||||
|
images: 图片文件路径、字节数据或列表。
|
||||||
|
videos: 视频文件路径、字节数据或列表。
|
||||||
|
audios: 音频文件路径、字节数据或列表。
|
||||||
|
instruction: 搜索指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含搜索结果的字典。
|
||||||
|
"""
|
||||||
message = create_multimodal_message(
|
message = create_multimodal_message(
|
||||||
text=text, images=images, videos=videos, audios=audios
|
text=text, images=images, videos=videos, audios=audios
|
||||||
)
|
)
|
||||||
@ -525,6 +689,101 @@ async def embed(
|
|||||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""文本嵌入便捷函数"""
|
"""
|
||||||
|
文本嵌入便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
texts: 要生成嵌入向量的文本或文本列表。
|
||||||
|
model: 要使用的嵌入模型名称。
|
||||||
|
task_type: 嵌入任务类型。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[list[float]]: 文本的嵌入向量列表。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
|
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def pipeline_chat(
|
||||||
|
message: UniMessage | str | list[LLMContentPart],
|
||||||
|
model_chain: list[ModelName],
|
||||||
|
*,
|
||||||
|
initial_instruction: str = "",
|
||||||
|
final_instruction: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""
|
||||||
|
AI模型链式调用,前一个模型的输出作为下一个模型的输入。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 初始输入消息(支持多模态)
|
||||||
|
model_chain: 模型名称列表
|
||||||
|
initial_instruction: 第一个模型的系统指令
|
||||||
|
final_instruction: 最后一个模型的系统指令
|
||||||
|
**kwargs: 传递给模型实例的其他参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMResponse: 最后一个模型的响应结果
|
||||||
|
"""
|
||||||
|
if not model_chain:
|
||||||
|
raise ValueError("模型链`model_chain`不能为空。")
|
||||||
|
|
||||||
|
current_content: str | list[LLMContentPart]
|
||||||
|
if isinstance(message, str):
|
||||||
|
current_content = message
|
||||||
|
elif isinstance(message, list):
|
||||||
|
current_content = message
|
||||||
|
else:
|
||||||
|
current_content = await unimsg_to_llm_parts(message)
|
||||||
|
|
||||||
|
final_response: LLMResponse | None = None
|
||||||
|
|
||||||
|
for i, model_name in enumerate(model_chain):
|
||||||
|
if not model_name:
|
||||||
|
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
|
||||||
|
|
||||||
|
is_first_step = i == 0
|
||||||
|
is_last_step = i == len(model_chain) - 1
|
||||||
|
|
||||||
|
messages_for_step: list[LLMMessage] = []
|
||||||
|
instruction_for_step = ""
|
||||||
|
if is_first_step and initial_instruction:
|
||||||
|
instruction_for_step = initial_instruction
|
||||||
|
elif is_last_step and final_instruction:
|
||||||
|
instruction_for_step = final_instruction
|
||||||
|
|
||||||
|
if instruction_for_step:
|
||||||
|
messages_for_step.append(LLMMessage.system(instruction_for_step))
|
||||||
|
|
||||||
|
messages_for_step.append(LLMMessage.user(current_content))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
|
||||||
|
f"使用模型 '{model_name}' 进行处理..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with await get_model_instance(model_name, **kwargs) as model:
|
||||||
|
response = await model.generate_response(messages_for_step)
|
||||||
|
final_response = response
|
||||||
|
current_content = response.text.strip()
|
||||||
|
if not current_content and not is_last_step:
|
||||||
|
logger.warning(
|
||||||
|
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
|
||||||
|
raise LLMException(
|
||||||
|
f"流水线在模型 '{model_name}' 处执行失败: {e}",
|
||||||
|
code=LLMErrorCode.GENERATION_FAILED,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_response is None:
|
||||||
|
raise LLMException(
|
||||||
|
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_response
|
||||||
|
|||||||
@ -14,6 +14,8 @@ from .generation import (
|
|||||||
from .presets import CommonOverrides
|
from .presets import CommonOverrides
|
||||||
from .providers import (
|
from .providers import (
|
||||||
LLMConfig,
|
LLMConfig,
|
||||||
|
ToolConfig,
|
||||||
|
get_gemini_safety_threshold,
|
||||||
get_llm_config,
|
get_llm_config,
|
||||||
register_llm_configs,
|
register_llm_configs,
|
||||||
set_default_model,
|
set_default_model,
|
||||||
@ -25,8 +27,10 @@ __all__ = [
|
|||||||
"LLMConfig",
|
"LLMConfig",
|
||||||
"LLMGenerationConfig",
|
"LLMGenerationConfig",
|
||||||
"ModelConfigOverride",
|
"ModelConfigOverride",
|
||||||
|
"ToolConfig",
|
||||||
"apply_api_specific_mappings",
|
"apply_api_specific_mappings",
|
||||||
"create_generation_config_from_kwargs",
|
"create_generation_config_from_kwargs",
|
||||||
|
"get_gemini_safety_threshold",
|
||||||
"get_llm_config",
|
"get_llm_config",
|
||||||
"register_llm_configs",
|
"register_llm_configs",
|
||||||
"set_default_model",
|
"set_default_model",
|
||||||
|
|||||||
@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
|
|||||||
params["temperature"] = self.temperature
|
params["temperature"] = self.temperature
|
||||||
|
|
||||||
if self.max_tokens is not None:
|
if self.max_tokens is not None:
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
params["maxOutputTokens"] = self.max_tokens
|
params["maxOutputTokens"] = self.max_tokens
|
||||||
else:
|
else:
|
||||||
params["max_tokens"] = self.max_tokens
|
params["max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
if self.top_k is not None:
|
if self.top_k is not None:
|
||||||
params["topK"] = self.top_k
|
params["topK"] = self.top_k
|
||||||
if self.top_p is not None:
|
if self.top_p is not None:
|
||||||
@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
|
|||||||
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
||||||
params["response_format"] = {"type": "json_object"}
|
params["response_format"] = {"type": "json_object"}
|
||||||
logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
|
logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
|
||||||
elif api_type in ["gemini", "gemini_native"]:
|
elif api_type == "gemini":
|
||||||
params["responseMimeType"] = "application/json"
|
params["responseMimeType"] = "application/json"
|
||||||
if self.response_schema:
|
if self.response_schema:
|
||||||
params["responseSchema"] = self.response_schema
|
params["responseSchema"] = self.response_schema
|
||||||
logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
|
logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
|
||||||
|
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
if (
|
if (
|
||||||
self.response_format != ResponseFormat.JSON
|
self.response_format != ResponseFormat.JSON
|
||||||
and self.response_mime_type is not None
|
and self.response_mime_type is not None
|
||||||
@ -214,7 +214,7 @@ def apply_api_specific_mappings(
|
|||||||
"""应用API特定的参数映射"""
|
"""应用API特定的参数映射"""
|
||||||
mapped_params = params.copy()
|
mapped_params = params.copy()
|
||||||
|
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
if "max_tokens" in mapped_params:
|
if "max_tokens" in mapped_params:
|
||||||
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
|
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
|
||||||
if "top_k" in mapped_params:
|
if "top_k" in mapped_params:
|
||||||
|
|||||||
@ -71,14 +71,17 @@ class CommonOverrides:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gemini_safe() -> LLMGenerationConfig:
|
def gemini_safe() -> LLMGenerationConfig:
|
||||||
"""Gemini 安全模式:严格安全设置"""
|
"""Gemini 安全模式:使用配置的安全设置"""
|
||||||
|
from .providers import get_gemini_safety_threshold
|
||||||
|
|
||||||
|
threshold = get_gemini_safety_threshold()
|
||||||
return LLMGenerationConfig(
|
return LLMGenerationConfig(
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
safety_settings={
|
safety_settings={
|
||||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_HARASSMENT": threshold,
|
||||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_HATE_SPEECH": threshold,
|
||||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
|
||||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -4,15 +4,33 @@ LLM 提供商配置管理
|
|||||||
负责注册和管理 AI 服务提供商的配置项。
|
负责注册和管理 AI 服务提供商的配置项。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.configs.path_config import DATA_PATH
|
||||||
|
from zhenxun.configs.utils import parse_as
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
|
||||||
from ..types.models import ModelDetail, ProviderConfig
|
from ..types.models import ModelDetail, ProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ToolConfig(BaseModel):
|
||||||
|
"""MCP类型工具的配置定义"""
|
||||||
|
|
||||||
|
type: str = "mcp"
|
||||||
|
name: str = Field(..., description="工具的唯一名称标识")
|
||||||
|
description: str | None = Field(None, description="工具功能的描述")
|
||||||
|
mcp_config: dict[str, Any] | BaseModel = Field(
|
||||||
|
..., description="MCP服务器的特定配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AI_CONFIG_GROUP = "AI"
|
AI_CONFIG_GROUP = "AI"
|
||||||
PROVIDERS_CONFIG_KEY = "PROVIDERS"
|
PROVIDERS_CONFIG_KEY = "PROVIDERS"
|
||||||
|
|
||||||
@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
|
|||||||
providers: list[ProviderConfig] = Field(
|
providers: list[ProviderConfig] = Field(
|
||||||
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
|
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
|
||||||
)
|
)
|
||||||
|
mcp_tools: list[ToolConfig] = Field(
|
||||||
|
default_factory=list, description="配置可用的外部MCP工具"
|
||||||
|
)
|
||||||
|
|
||||||
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
|
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
|
||||||
"""根据名称获取提供商配置
|
"""根据名称获取提供商配置
|
||||||
@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
|
|||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"name": "DeepSeek",
|
"name": "DeepSeek",
|
||||||
"api_key": "sk-******",
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
"api_base": "https://api.deepseek.com",
|
"api_base": "https://api.deepseek.com",
|
||||||
"api_type": "openai",
|
"api_type": "openai",
|
||||||
"models": [
|
"models": [
|
||||||
@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ARK",
|
||||||
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
|
"api_base": "https://ark.cn-beijing.volces.com",
|
||||||
|
"api_type": "ark",
|
||||||
|
"models": [
|
||||||
|
{"model_name": "deepseek-r1-250528"},
|
||||||
|
{"model_name": "doubao-seed-1-6-250615"},
|
||||||
|
{"model_name": "doubao-seed-1-6-flash-250615"},
|
||||||
|
{"model_name": "doubao-seed-1-6-thinking-250615"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "siliconflow",
|
||||||
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
|
"api_base": "https://api.siliconflow.cn",
|
||||||
|
"api_type": "openai",
|
||||||
|
"models": [
|
||||||
|
{"model_name": "deepseek-ai/DeepSeek-V3"},
|
||||||
|
],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "GLM",
|
"name": "GLM",
|
||||||
"api_key": "",
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
"api_base": "https://open.bigmodel.cn",
|
"api_base": "https://open.bigmodel.cn",
|
||||||
"api_type": "zhipu",
|
"api_type": "zhipu",
|
||||||
"models": [
|
"models": [
|
||||||
@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
|
|||||||
"api_type": "gemini",
|
"api_type": "gemini",
|
||||||
"models": [
|
"models": [
|
||||||
{"model_name": "gemini-2.0-flash"},
|
{"model_name": "gemini-2.0-flash"},
|
||||||
{"model_name": "gemini-2.5-flash-preview-05-20"},
|
{"model_name": "gemini-2.5-flash"},
|
||||||
|
{"model_name": "gemini-2.5-pro"},
|
||||||
|
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_mcp_tools() -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取默认的MCP工具配置,用于在文件不存在时创建。
|
||||||
|
包含了 baidu-map, Context7, 和 sequential-thinking.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"mcpServers": {
|
||||||
|
"baidu-map": {
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
|
||||||
|
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
|
||||||
|
"description": "百度地图工具,提供地理编码、路线规划等功能。",
|
||||||
|
},
|
||||||
|
"sequential-thinking": {
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||||
|
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
|
||||||
|
},
|
||||||
|
"Context7": {
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@upstash/context7-mcp@latest"],
|
||||||
|
"description": "Upstash 提供的上下文管理和记忆工具。",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def register_llm_configs():
|
def register_llm_configs():
|
||||||
"""注册 LLM 服务的配置项"""
|
"""注册 LLM 服务的配置项"""
|
||||||
logger.info("注册 LLM 服务的配置项")
|
logger.info("注册 LLM 服务的配置项")
|
||||||
@ -214,6 +285,19 @@ def register_llm_configs():
|
|||||||
help="LLM服务请求重试的基础延迟时间(秒)",
|
help="LLM服务请求重试的基础延迟时间(秒)",
|
||||||
type=int,
|
type=int,
|
||||||
)
|
)
|
||||||
|
Config.add_plugin_config(
|
||||||
|
AI_CONFIG_GROUP,
|
||||||
|
"gemini_safety_threshold",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
help=(
|
||||||
|
"Gemini 安全过滤阈值 "
|
||||||
|
"(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
|
||||||
|
"BLOCK_ONLY_HIGH: 只阻止高级别, "
|
||||||
|
"BLOCK_NONE: 不阻止)"
|
||||||
|
),
|
||||||
|
type=str,
|
||||||
|
)
|
||||||
|
|
||||||
Config.add_plugin_config(
|
Config.add_plugin_config(
|
||||||
AI_CONFIG_GROUP,
|
AI_CONFIG_GROUP,
|
||||||
@ -225,24 +309,111 @@ def register_llm_configs():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
def get_llm_config() -> LLMConfig:
|
def get_llm_config() -> LLMConfig:
|
||||||
"""获取 LLM 配置实例
|
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
|
||||||
|
|
||||||
返回:
|
|
||||||
LLMConfig: LLM 配置实例
|
|
||||||
"""
|
|
||||||
ai_config = get_ai_config()
|
ai_config = get_ai_config()
|
||||||
|
|
||||||
|
llm_data_path = DATA_PATH / "llm"
|
||||||
|
mcp_tools_path = llm_data_path / "mcp_tools.json"
|
||||||
|
|
||||||
|
mcp_tools_list = []
|
||||||
|
mcp_servers_dict = {}
|
||||||
|
|
||||||
|
if not mcp_tools_path.exists():
|
||||||
|
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
|
||||||
|
llm_data_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
default_mcp_config = get_default_mcp_tools()
|
||||||
|
try:
|
||||||
|
with mcp_tools_path.open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
|
||||||
|
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
|
||||||
|
mcp_servers_dict = {}
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
with mcp_tools_path.open("r", encoding="utf-8") as f:
|
||||||
|
mcp_data = json.load(f)
|
||||||
|
mcp_servers_dict = mcp_data.get("mcpServers", {})
|
||||||
|
if not isinstance(mcp_servers_dict, dict):
|
||||||
|
logger.warning(
|
||||||
|
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
|
||||||
|
f"将使用空配置。"
|
||||||
|
)
|
||||||
|
mcp_servers_dict = {}
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
|
||||||
|
mcp_servers_dict = {}
|
||||||
|
|
||||||
|
if sys.platform == "win32":
|
||||||
|
logger.debug("检测到Windows平台,正在调整MCP工具的npx命令...")
|
||||||
|
for name, config in mcp_servers_dict.items():
|
||||||
|
if isinstance(config, dict) and config.get("command") == "npx":
|
||||||
|
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
|
||||||
|
original_args = config.get("args", [])
|
||||||
|
config["command"] = "cmd"
|
||||||
|
config["args"] = ["/c", "npx", *original_args]
|
||||||
|
|
||||||
|
if mcp_servers_dict:
|
||||||
|
mcp_tools_list = [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"type": "mcp",
|
||||||
|
"description": config.get("description", f"MCP tool for {name}"),
|
||||||
|
"mcp_config": config,
|
||||||
|
}
|
||||||
|
for name, config in mcp_servers_dict.items()
|
||||||
|
if isinstance(config, dict)
|
||||||
|
]
|
||||||
|
|
||||||
|
from ..tools.registry import tool_registry
|
||||||
|
|
||||||
|
for tool_dict in mcp_tools_list:
|
||||||
|
if isinstance(tool_dict, dict):
|
||||||
|
tool_name = tool_dict.get("name")
|
||||||
|
if not tool_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
config_model = tool_registry.get_mcp_config_model(tool_name)
|
||||||
|
if not config_model:
|
||||||
|
logger.debug(
|
||||||
|
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
|
||||||
|
f"将跳过特定配置验证,直接使用原始配置字典。"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
mcp_config_data = tool_dict.get("mcp_config", {})
|
||||||
|
try:
|
||||||
|
parsed_mcp_config = parse_as(config_model, mcp_config_data)
|
||||||
|
tool_dict["mcp_config"] = parsed_mcp_config
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
|
||||||
|
|
||||||
config_data = {
|
config_data = {
|
||||||
"default_model_name": ai_config.get("default_model_name"),
|
"default_model_name": ai_config.get("default_model_name"),
|
||||||
"proxy": ai_config.get("proxy"),
|
"proxy": ai_config.get("proxy"),
|
||||||
"timeout": ai_config.get("timeout", 180),
|
"timeout": ai_config.get("timeout", 180),
|
||||||
"max_retries_llm": ai_config.get("max_retries_llm", 3),
|
"max_retries_llm": ai_config.get("max_retries_llm", 3),
|
||||||
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
|
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
|
||||||
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
|
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
|
||||||
|
"mcp_tools": mcp_tools_list,
|
||||||
}
|
}
|
||||||
|
|
||||||
return LLMConfig(**config_data)
|
return parse_as(LLMConfig, config_data)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gemini_safety_threshold() -> str:
|
||||||
|
"""获取 Gemini 安全过滤阈值配置
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 安全过滤阈值
|
||||||
|
"""
|
||||||
|
ai_config = get_ai_config()
|
||||||
|
return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")
|
||||||
|
|
||||||
|
|
||||||
def validate_llm_config() -> tuple[bool, list[str]]:
|
def validate_llm_config() -> tuple[bool, list[str]]:
|
||||||
@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
|
|||||||
logger.info("默认模型已清除")
|
logger.info("默认模型已清除")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@PriorityLifecycle.on_startup(priority=10)
|
||||||
|
async def _init_llm_config_on_startup():
|
||||||
|
"""
|
||||||
|
在服务启动时主动调用一次 get_llm_config,
|
||||||
|
以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
|
||||||
|
"""
|
||||||
|
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
|
||||||
|
try:
|
||||||
|
get_llm_config()
|
||||||
|
logger.info("LLM 配置初始化完成。")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
|
||||||
|
|||||||
@ -49,12 +49,36 @@ class LLMHttpClient:
|
|||||||
max_keepalive_connections=self.config.max_keepalive_connections,
|
max_keepalive_connections=self.config.max_keepalive_connections,
|
||||||
)
|
)
|
||||||
timeout = httpx.Timeout(self.config.timeout)
|
timeout = httpx.Timeout(self.config.timeout)
|
||||||
|
|
||||||
|
client_kwargs = {}
|
||||||
|
if self.config.proxy:
|
||||||
|
try:
|
||||||
|
version_parts = httpx.__version__.split(".")
|
||||||
|
major = int(
|
||||||
|
"".join(c for c in version_parts[0] if c.isdigit())
|
||||||
|
)
|
||||||
|
minor = (
|
||||||
|
int("".join(c for c in version_parts[1] if c.isdigit()))
|
||||||
|
if len(version_parts) > 1
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
if (major, minor) >= (0, 28):
|
||||||
|
client_kwargs["proxy"] = self.config.proxy
|
||||||
|
else:
|
||||||
|
client_kwargs["proxies"] = self.config.proxy
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
client_kwargs["proxies"] = self.config.proxy
|
||||||
|
logger.warning(
|
||||||
|
f"无法解析 httpx 版本 '{httpx.__version__}',"
|
||||||
|
"LLM模块将默认使用旧版 'proxies' 参数语法。"
|
||||||
|
)
|
||||||
|
|
||||||
self._client = httpx.AsyncClient(
|
self._client = httpx.AsyncClient(
|
||||||
headers=headers,
|
headers=headers,
|
||||||
limits=limits,
|
limits=limits,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
proxies=self.config.proxy,
|
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
|
**client_kwargs,
|
||||||
)
|
)
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
raise LLMException(
|
raise LLMException(
|
||||||
@ -156,7 +180,16 @@ async def create_llm_http_client(
|
|||||||
timeout: int = 180,
|
timeout: int = 180,
|
||||||
proxy: str | None = None,
|
proxy: str | None = None,
|
||||||
) -> LLMHttpClient:
|
) -> LLMHttpClient:
|
||||||
"""创建LLM HTTP客户端"""
|
"""
|
||||||
|
创建LLM HTTP客户端
|
||||||
|
|
||||||
|
参数:
|
||||||
|
timeout: 超时时间(秒)。
|
||||||
|
proxy: 代理服务器地址。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMHttpClient: HTTP客户端实例。
|
||||||
|
"""
|
||||||
config = HttpClientConfig(timeout=timeout, proxy=proxy)
|
config = HttpClientConfig(timeout=timeout, proxy=proxy)
|
||||||
return LLMHttpClient(config)
|
return LLMHttpClient(config)
|
||||||
|
|
||||||
@ -185,7 +218,20 @@ async def with_smart_retry(
|
|||||||
provider_name: str | None = None,
|
provider_name: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""智能重试装饰器 - 支持Key轮询和错误分类"""
|
"""
|
||||||
|
智能重试装饰器 - 支持Key轮询和错误分类
|
||||||
|
|
||||||
|
参数:
|
||||||
|
func: 要重试的异步函数。
|
||||||
|
*args: 传递给函数的位置参数。
|
||||||
|
retry_config: 重试配置。
|
||||||
|
key_store: API密钥状态存储。
|
||||||
|
provider_name: 提供商名称。
|
||||||
|
**kwargs: 传递给函数的关键字参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 函数执行结果。
|
||||||
|
"""
|
||||||
config = retry_config or RetryConfig()
|
config = retry_config or RetryConfig()
|
||||||
last_exception: Exception | None = None
|
last_exception: Exception | None = None
|
||||||
failed_keys: set[str] = set()
|
failed_keys: set[str] = set()
|
||||||
@ -294,7 +340,17 @@ class KeyStatusStore:
|
|||||||
api_keys: list[str],
|
api_keys: list[str],
|
||||||
exclude_keys: set[str] | None = None,
|
exclude_keys: set[str] | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""获取下一个可用的API密钥(轮询策略)"""
|
"""
|
||||||
|
获取下一个可用的API密钥(轮询策略)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
provider_name: 提供商名称。
|
||||||
|
api_keys: API密钥列表。
|
||||||
|
exclude_keys: 要排除的密钥集合。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 可用的API密钥,如果没有可用密钥则返回None。
|
||||||
|
"""
|
||||||
if not api_keys:
|
if not api_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -338,7 +394,13 @@ class KeyStatusStore:
|
|||||||
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
|
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
|
||||||
|
|
||||||
async def record_failure(self, api_key: str, status_code: int | None):
|
async def record_failure(self, api_key: str, status_code: int | None):
|
||||||
"""记录失败使用"""
|
"""
|
||||||
|
记录失败使用
|
||||||
|
|
||||||
|
参数:
|
||||||
|
api_key: API密钥。
|
||||||
|
status_code: HTTP状态码。
|
||||||
|
"""
|
||||||
key_id = self._get_key_id(api_key)
|
key_id = self._get_key_id(api_key)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
if status_code in [401, 403]:
|
if status_code in [401, 403]:
|
||||||
@ -356,7 +418,15 @@ class KeyStatusStore:
|
|||||||
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
||||||
|
|
||||||
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
||||||
"""获取密钥使用统计"""
|
"""
|
||||||
|
获取密钥使用统计
|
||||||
|
|
||||||
|
参数:
|
||||||
|
api_keys: API密钥列表。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, dict]: 密钥统计信息字典。
|
||||||
|
"""
|
||||||
stats = {}
|
stats = {}
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
for key in api_keys:
|
for key in api_keys:
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
|
|||||||
from .core import http_client_manager, key_store
|
from .core import http_client_manager, key_store
|
||||||
from .service import LLMModel
|
from .service import LLMModel
|
||||||
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
|
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
|
||||||
|
from .types.capabilities import get_model_capabilities
|
||||||
|
|
||||||
DEFAULT_MODEL_NAME_KEY = "default_model_name"
|
DEFAULT_MODEL_NAME_KEY = "default_model_name"
|
||||||
PROXY_KEY = "proxy"
|
PROXY_KEY = "proxy"
|
||||||
@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def get_configured_providers() -> list[ProviderConfig]:
|
def get_configured_providers() -> list[ProviderConfig]:
|
||||||
"""从配置中获取Provider列表 - 简化版本"""
|
"""从配置中获取Provider列表 - 简化和修正版本"""
|
||||||
ai_config = get_ai_config()
|
ai_config = get_ai_config()
|
||||||
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
|
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
|
||||||
if not isinstance(providers_raw, list):
|
|
||||||
|
if not isinstance(providers, list):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
|
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
|
||||||
f"将使用空列表。"
|
f"将使用空列表。"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
valid_providers = []
|
valid_providers = []
|
||||||
for i, item in enumerate(providers_raw):
|
for i, item in enumerate(providers):
|
||||||
if not isinstance(item, dict):
|
if isinstance(item, ProviderConfig):
|
||||||
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
|
if not item.api_base:
|
||||||
continue
|
default_api_base = get_default_api_base_for_type(item.api_type)
|
||||||
|
|
||||||
try:
|
|
||||||
if not item.get("name"):
|
|
||||||
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not item.get("api_key"):
|
|
||||||
logger.warning(
|
|
||||||
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "api_type" not in item or not item["api_type"]:
|
|
||||||
provider_name = item.get("name", "").lower()
|
|
||||||
if "glm" in provider_name or "zhipu" in provider_name:
|
|
||||||
item["api_type"] = "zhipu"
|
|
||||||
elif "gemini" in provider_name or "google" in provider_name:
|
|
||||||
item["api_type"] = "gemini"
|
|
||||||
else:
|
|
||||||
item["api_type"] = "openai"
|
|
||||||
|
|
||||||
if "api_base" not in item or not item["api_base"]:
|
|
||||||
api_type = item.get("api_type")
|
|
||||||
if api_type:
|
|
||||||
default_api_base = get_default_api_base_for_type(api_type)
|
|
||||||
if default_api_base:
|
if default_api_base:
|
||||||
item["api_base"] = default_api_base
|
item.api_base = default_api_base
|
||||||
|
valid_providers.append(item)
|
||||||
if "models" not in item:
|
else:
|
||||||
item["models"] = [{"model_name": item.get("name", "default")}]
|
logger.warning(
|
||||||
|
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
|
||||||
provider_conf = ProviderConfig(**item)
|
f"已跳过。实际类型: {type(item)}"
|
||||||
valid_providers.append(provider_conf)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
|
|
||||||
|
|
||||||
return valid_providers
|
return valid_providers
|
||||||
|
|
||||||
@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
|
|||||||
def find_model_config(
|
def find_model_config(
|
||||||
provider_name: str, model_name: str
|
provider_name: str, model_name: str
|
||||||
) -> tuple[ProviderConfig, ModelDetail] | None:
|
) -> tuple[ProviderConfig, ModelDetail] | None:
|
||||||
"""在配置中查找指定的 Provider 和 ModelDetail
|
"""
|
||||||
|
在配置中查找指定的 Provider 和 ModelDetail
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
provider_name: 提供商名称
|
provider_name: 提供商名称
|
||||||
model_name: 模型名称
|
model_name: 模型名称
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None
|
tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None
|
||||||
"""
|
"""
|
||||||
providers = get_configured_providers()
|
providers = get_configured_providers()
|
||||||
|
|
||||||
@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis
|
|||||||
|
|
||||||
|
|
||||||
def list_model_identifiers() -> dict[str, list[str]]:
|
def list_model_identifiers() -> dict[str, list[str]]:
|
||||||
"""列出所有模型的可用标识符
|
"""
|
||||||
|
列出所有模型的可用标识符
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
字典,键为模型的完整名称,值为该模型的所有可用标识符列表
|
dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表
|
||||||
"""
|
"""
|
||||||
providers = get_configured_providers()
|
providers = get_configured_providers()
|
||||||
result = {}
|
result = {}
|
||||||
@ -248,7 +224,16 @@ async def get_model_instance(
|
|||||||
provider_model_name: str | None = None,
|
provider_model_name: str | None = None,
|
||||||
override_config: dict[str, Any] | None = None,
|
override_config: dict[str, Any] | None = None,
|
||||||
) -> LLMModel:
|
) -> LLMModel:
|
||||||
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
|
"""
|
||||||
|
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
|
||||||
|
override_config: 覆盖配置字典。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMModel: 模型实例。
|
||||||
|
"""
|
||||||
cache_key = _make_cache_key(provider_model_name, override_config)
|
cache_key = _make_cache_key(provider_model_name, override_config)
|
||||||
cached_model = _get_cached_model(cache_key)
|
cached_model = _get_cached_model(cache_key)
|
||||||
if cached_model:
|
if cached_model:
|
||||||
@ -292,6 +277,10 @@ async def get_model_instance(
|
|||||||
|
|
||||||
provider_config_found, model_detail_found = config_tuple_found
|
provider_config_found, model_detail_found = config_tuple_found
|
||||||
|
|
||||||
|
capabilities = get_model_capabilities(model_detail_found.model_name)
|
||||||
|
|
||||||
|
model_detail_found.is_embedding_model = capabilities.is_embedding_model
|
||||||
|
|
||||||
ai_config = get_ai_config()
|
ai_config = get_ai_config()
|
||||||
global_proxy_setting = ai_config.get(PROXY_KEY)
|
global_proxy_setting = ai_config.get(PROXY_KEY)
|
||||||
default_timeout = (
|
default_timeout = (
|
||||||
@ -322,6 +311,7 @@ async def get_model_instance(
|
|||||||
model_detail=model_detail_found,
|
model_detail=model_detail_found,
|
||||||
key_store=key_store,
|
key_store=key_store,
|
||||||
http_client=shared_http_client,
|
http_client=shared_http_client,
|
||||||
|
capabilities=capabilities,
|
||||||
)
|
)
|
||||||
|
|
||||||
if override_config:
|
if override_config:
|
||||||
@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def set_global_default_model_name(provider_model_name: str | None) -> bool:
|
def set_global_default_model_name(provider_model_name: str | None) -> bool:
|
||||||
"""设置全局默认模型名称"""
|
"""
|
||||||
|
设置全局默认模型名称
|
||||||
|
|
||||||
|
参数:
|
||||||
|
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 设置是否成功。
|
||||||
|
"""
|
||||||
if provider_model_name:
|
if provider_model_name:
|
||||||
prov_name, mod_name = parse_provider_model_string(provider_model_name)
|
prov_name, mod_name = parse_provider_model_string(provider_model_name)
|
||||||
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
|
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
|
||||||
@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
async def get_key_usage_stats() -> dict[str, Any]:
|
async def get_key_usage_stats() -> dict[str, Any]:
|
||||||
"""获取所有Provider的Key使用统计"""
|
"""
|
||||||
|
获取所有Provider的Key使用统计
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含所有Provider的Key使用统计信息。
|
||||||
|
"""
|
||||||
providers = get_configured_providers()
|
providers = get_configured_providers()
|
||||||
stats = {}
|
stats = {}
|
||||||
|
|
||||||
@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
|
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
|
||||||
"""重置指定Provider的Key状态"""
|
"""
|
||||||
|
重置指定Provider的Key状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
provider_name: 提供商名称。
|
||||||
|
api_key: 要重置的特定API密钥,如果为None则重置所有密钥。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 重置是否成功。
|
||||||
|
"""
|
||||||
providers = get_configured_providers()
|
providers = get_configured_providers()
|
||||||
target_provider = None
|
target_provider = None
|
||||||
|
|
||||||
|
|||||||
@ -6,11 +6,13 @@ LLM 模型实现类
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .adapters.base import RequestData
|
||||||
from .config import LLMGenerationConfig
|
from .config import LLMGenerationConfig
|
||||||
from .config.providers import get_ai_config
|
from .config.providers import get_ai_config
|
||||||
from .core import (
|
from .core import (
|
||||||
@ -30,6 +32,8 @@ from .types import (
|
|||||||
ModelDetail,
|
ModelDetail,
|
||||||
ProviderConfig,
|
ProviderConfig,
|
||||||
)
|
)
|
||||||
|
from .types.capabilities import ModelCapabilities, ModelModality
|
||||||
|
from .utils import _sanitize_request_body_for_logging
|
||||||
|
|
||||||
|
|
||||||
class LLMModelBase(ABC):
|
class LLMModelBase(ABC):
|
||||||
@ -42,7 +46,17 @@ class LLMModelBase(ABC):
|
|||||||
history: list[dict[str, str]] | None = None,
|
history: list[dict[str, str]] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""生成文本"""
|
"""
|
||||||
|
生成文本
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prompt: 输入提示词。
|
||||||
|
history: 对话历史记录。
|
||||||
|
**kwargs: 其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 生成的文本。
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -54,7 +68,19 @@ class LLMModelBase(ABC):
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""生成高级响应"""
|
"""
|
||||||
|
生成高级响应
|
||||||
|
|
||||||
|
参数:
|
||||||
|
messages: 消息列表。
|
||||||
|
config: 生成配置。
|
||||||
|
tools: 工具列表。
|
||||||
|
tool_choice: 工具选择策略。
|
||||||
|
**kwargs: 其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMResponse: 模型响应。
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -64,7 +90,17 @@ class LLMModelBase(ABC):
|
|||||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""生成文本嵌入向量"""
|
"""
|
||||||
|
生成文本嵌入向量
|
||||||
|
|
||||||
|
参数:
|
||||||
|
texts: 文本列表。
|
||||||
|
task_type: 嵌入任务类型。
|
||||||
|
**kwargs: 其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[list[float]]: 嵌入向量列表。
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
|
|||||||
model_detail: ModelDetail,
|
model_detail: ModelDetail,
|
||||||
key_store: KeyStatusStore,
|
key_store: KeyStatusStore,
|
||||||
http_client: LLMHttpClient,
|
http_client: LLMHttpClient,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
config_override: LLMGenerationConfig | None = None,
|
config_override: LLMGenerationConfig | None = None,
|
||||||
):
|
):
|
||||||
self.provider_config = provider_config
|
self.provider_config = provider_config
|
||||||
self.model_detail = model_detail
|
self.model_detail = model_detail
|
||||||
self.key_store = key_store
|
self.key_store = key_store
|
||||||
self.http_client: LLMHttpClient = http_client
|
self.http_client: LLMHttpClient = http_client
|
||||||
|
self.capabilities = capabilities
|
||||||
self._generation_config = config_override
|
self._generation_config = config_override
|
||||||
|
|
||||||
self.provider_name = provider_config.name
|
self.provider_name = provider_config.name
|
||||||
@ -99,6 +137,34 @@ class LLMModel(LLMModelBase):
|
|||||||
|
|
||||||
self._is_closed = False
|
self._is_closed = False
|
||||||
|
|
||||||
|
def can_process_images(self) -> bool:
|
||||||
|
"""检查模型是否支持图片作为输入。"""
|
||||||
|
return ModelModality.IMAGE in self.capabilities.input_modalities
|
||||||
|
|
||||||
|
def can_process_video(self) -> bool:
|
||||||
|
"""检查模型是否支持视频作为输入。"""
|
||||||
|
return ModelModality.VIDEO in self.capabilities.input_modalities
|
||||||
|
|
||||||
|
def can_process_audio(self) -> bool:
|
||||||
|
"""检查模型是否支持音频作为输入。"""
|
||||||
|
return ModelModality.AUDIO in self.capabilities.input_modalities
|
||||||
|
|
||||||
|
def can_generate_images(self) -> bool:
|
||||||
|
"""检查模型是否支持生成图片。"""
|
||||||
|
return ModelModality.IMAGE in self.capabilities.output_modalities
|
||||||
|
|
||||||
|
def can_generate_audio(self) -> bool:
|
||||||
|
"""检查模型是否支持生成音频 (TTS)。"""
|
||||||
|
return ModelModality.AUDIO in self.capabilities.output_modalities
|
||||||
|
|
||||||
|
def can_use_tools(self) -> bool:
|
||||||
|
"""检查模型是否支持工具调用/函数调用。"""
|
||||||
|
return self.capabilities.supports_tool_calling
|
||||||
|
|
||||||
|
def is_embedding_model(self) -> bool:
|
||||||
|
"""检查这是否是一个嵌入模型。"""
|
||||||
|
return self.capabilities.is_embedding_model
|
||||||
|
|
||||||
async def _get_http_client(self) -> LLMHttpClient:
|
async def _get_http_client(self) -> LLMHttpClient:
|
||||||
"""获取HTTP客户端"""
|
"""获取HTTP客户端"""
|
||||||
if self.http_client.is_closed:
|
if self.http_client.is_closed:
|
||||||
@ -135,24 +201,54 @@ class LLMModel(LLMModelBase):
|
|||||||
|
|
||||||
return selected_key
|
return selected_key
|
||||||
|
|
||||||
async def _execute_embedding_request(
|
async def _perform_api_call(
|
||||||
self,
|
self,
|
||||||
adapter,
|
prepare_request_func: Callable[[str], Awaitable["RequestData"]],
|
||||||
texts: list[str],
|
parse_response_func: Callable[[dict[str, Any]], Any],
|
||||||
task_type: EmbeddingTaskType | str,
|
http_client: "LLMHttpClient",
|
||||||
http_client: LLMHttpClient,
|
|
||||||
failed_keys: set[str] | None = None,
|
failed_keys: set[str] | None = None,
|
||||||
) -> list[list[float]]:
|
log_context: str = "API",
|
||||||
"""执行单次嵌入请求 - 供重试机制调用"""
|
) -> Any:
|
||||||
|
"""
|
||||||
|
执行API调用的通用核心方法。
|
||||||
|
|
||||||
|
该方法封装了以下通用逻辑:
|
||||||
|
1. 选择API密钥。
|
||||||
|
2. 准备和记录请求。
|
||||||
|
3. 发送HTTP POST请求。
|
||||||
|
4. 处理HTTP错误和API特定错误。
|
||||||
|
5. 记录密钥使用状态。
|
||||||
|
6. 解析成功的响应。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prepare_request_func: 准备请求的函数。
|
||||||
|
parse_response_func: 解析响应的函数。
|
||||||
|
http_client: HTTP客户端。
|
||||||
|
failed_keys: 失败的密钥集合。
|
||||||
|
log_context: 日志上下文。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 解析后的响应数据。
|
||||||
|
"""
|
||||||
api_key = await self._select_api_key(failed_keys)
|
api_key = await self._select_api_key(failed_keys)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request_data = adapter.prepare_embedding_request(
|
request_data = await prepare_request_func(api_key)
|
||||||
model=self,
|
|
||||||
api_key=api_key,
|
logger.info(
|
||||||
texts=texts,
|
f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
|
||||||
task_type=task_type,
|
f"[{log_context}]"
|
||||||
)
|
)
|
||||||
|
logger.debug(f"📡 请求URL: {request_data.url}")
|
||||||
|
masked_key = (
|
||||||
|
f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
|
||||||
|
)
|
||||||
|
logger.debug(f"🔑 API密钥: {masked_key}")
|
||||||
|
logger.debug(f"📋 请求头: {dict(request_data.headers)}")
|
||||||
|
|
||||||
|
sanitized_body = _sanitize_request_body_for_logging(request_data.body)
|
||||||
|
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
|
||||||
|
logger.debug(f"📦 请求体: {request_body_str}")
|
||||||
|
|
||||||
http_response = await http_client.post(
|
http_response = await http_client.post(
|
||||||
request_data.url,
|
request_data.url,
|
||||||
@ -160,121 +256,16 @@ class LLMModel(LLMModelBase):
|
|||||||
json=request_data.body,
|
json=request_data.body,
|
||||||
)
|
)
|
||||||
|
|
||||||
if http_response.status_code != 200:
|
logger.debug(f"📥 响应状态码: {http_response.status_code}")
|
||||||
error_text = http_response.text
|
logger.debug(f"📄 响应头: {dict(http_response.headers)}")
|
||||||
logger.error(
|
|
||||||
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
|
|
||||||
)
|
|
||||||
await self.key_store.record_failure(api_key, http_response.status_code)
|
|
||||||
|
|
||||||
error_code = LLMErrorCode.API_REQUEST_FAILED
|
|
||||||
if http_response.status_code in [401, 403]:
|
|
||||||
error_code = LLMErrorCode.API_KEY_INVALID
|
|
||||||
elif http_response.status_code == 429:
|
|
||||||
error_code = LLMErrorCode.API_RATE_LIMITED
|
|
||||||
|
|
||||||
raise LLMException(
|
|
||||||
f"HTTP嵌入请求失败: {http_response.status_code}",
|
|
||||||
code=error_code,
|
|
||||||
details={
|
|
||||||
"status_code": http_response.status_code,
|
|
||||||
"response": error_text,
|
|
||||||
"api_key": api_key,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response_json = http_response.json()
|
|
||||||
adapter.validate_embedding_response(response_json)
|
|
||||||
embeddings = adapter.parse_embedding_response(response_json)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"解析嵌入响应失败: {e}", e=e)
|
|
||||||
await self.key_store.record_failure(api_key, None)
|
|
||||||
if isinstance(e, LLMException):
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
raise LLMException(
|
|
||||||
f"解析API嵌入响应失败: {e}",
|
|
||||||
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
|
|
||||||
cause=e,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.key_store.record_success(api_key)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
except LLMException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
|
|
||||||
await self.key_store.record_failure(api_key, None)
|
|
||||||
raise LLMException(
|
|
||||||
f"生成嵌入失败: {e}",
|
|
||||||
code=LLMErrorCode.EMBEDDING_FAILED,
|
|
||||||
cause=e,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _execute_with_smart_retry(
|
|
||||||
self,
|
|
||||||
adapter,
|
|
||||||
messages: list[LLMMessage],
|
|
||||||
config: LLMGenerationConfig | None,
|
|
||||||
tools_dict: list[dict[str, Any]] | None,
|
|
||||||
tool_choice: str | dict[str, Any] | None,
|
|
||||||
http_client: LLMHttpClient,
|
|
||||||
):
|
|
||||||
"""智能重试机制 - 使用统一的重试装饰器"""
|
|
||||||
ai_config = get_ai_config()
|
|
||||||
max_retries = ai_config.get("max_retries_llm", 3)
|
|
||||||
retry_delay = ai_config.get("retry_delay_llm", 2)
|
|
||||||
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
|
|
||||||
|
|
||||||
return await with_smart_retry(
|
|
||||||
self._execute_single_request,
|
|
||||||
adapter,
|
|
||||||
messages,
|
|
||||||
config,
|
|
||||||
tools_dict,
|
|
||||||
tool_choice,
|
|
||||||
http_client,
|
|
||||||
retry_config=retry_config,
|
|
||||||
key_store=self.key_store,
|
|
||||||
provider_name=self.provider_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _execute_single_request(
|
|
||||||
self,
|
|
||||||
adapter,
|
|
||||||
messages: list[LLMMessage],
|
|
||||||
config: LLMGenerationConfig | None,
|
|
||||||
tools_dict: list[dict[str, Any]] | None,
|
|
||||||
tool_choice: str | dict[str, Any] | None,
|
|
||||||
http_client: LLMHttpClient,
|
|
||||||
failed_keys: set[str] | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
|
|
||||||
api_key = await self._select_api_key(failed_keys)
|
|
||||||
|
|
||||||
try:
|
|
||||||
request_data = adapter.prepare_advanced_request(
|
|
||||||
model=self,
|
|
||||||
api_key=api_key,
|
|
||||||
messages=messages,
|
|
||||||
config=config,
|
|
||||||
tools=tools_dict,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
)
|
|
||||||
|
|
||||||
http_response = await http_client.post(
|
|
||||||
request_data.url,
|
|
||||||
headers=request_data.headers,
|
|
||||||
json=request_data.body,
|
|
||||||
)
|
|
||||||
|
|
||||||
if http_response.status_code != 200:
|
if http_response.status_code != 200:
|
||||||
error_text = http_response.text
|
error_text = http_response.text
|
||||||
logger.error(
|
logger.error(
|
||||||
f"HTTP请求失败: {http_response.status_code} - {error_text}"
|
f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
|
||||||
|
f"[{log_context}]"
|
||||||
)
|
)
|
||||||
|
logger.debug(f"💥 完整错误响应: {error_text}")
|
||||||
|
|
||||||
await self.key_store.record_failure(api_key, http_response.status_code)
|
await self.key_store.record_failure(api_key, http_response.status_code)
|
||||||
|
|
||||||
@ -299,12 +290,129 @@ class LLMModel(LLMModelBase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response_json = http_response.json()
|
response_json = http_response.json()
|
||||||
|
response_json_str = json.dumps(
|
||||||
|
response_json, ensure_ascii=False, indent=2
|
||||||
|
)
|
||||||
|
logger.debug(f"📋 响应JSON: {response_json_str}")
|
||||||
|
parsed_data = parse_response_func(response_json)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
|
||||||
|
await self.key_store.record_failure(api_key, None)
|
||||||
|
if isinstance(e, LLMException):
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
raise LLMException(
|
||||||
|
f"解析API {log_context} 响应失败: {e}",
|
||||||
|
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.key_store.record_success(api_key)
|
||||||
|
logger.debug(f"✅ API密钥使用成功: {masked_key}")
|
||||||
|
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
|
||||||
|
return parsed_data
|
||||||
|
|
||||||
|
except LLMException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
|
||||||
|
logger.error(error_log_msg, e=e)
|
||||||
|
await self.key_store.record_failure(api_key, None)
|
||||||
|
raise LLMException(
|
||||||
|
error_log_msg,
|
||||||
|
code=LLMErrorCode.GENERATION_FAILED
|
||||||
|
if log_context == "Generation"
|
||||||
|
else LLMErrorCode.EMBEDDING_FAILED,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_embedding_request(
|
||||||
|
self,
|
||||||
|
adapter,
|
||||||
|
texts: list[str],
|
||||||
|
task_type: EmbeddingTaskType | str,
|
||||||
|
http_client: LLMHttpClient,
|
||||||
|
failed_keys: set[str] | None = None,
|
||||||
|
) -> list[list[float]]:
|
||||||
|
"""执行单次嵌入请求 - 供重试机制调用"""
|
||||||
|
|
||||||
|
async def prepare_request(api_key: str) -> RequestData:
|
||||||
|
return adapter.prepare_embedding_request(
|
||||||
|
model=self,
|
||||||
|
api_key=api_key,
|
||||||
|
texts=texts,
|
||||||
|
task_type=task_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
|
||||||
|
adapter.validate_embedding_response(response_json)
|
||||||
|
return adapter.parse_embedding_response(response_json)
|
||||||
|
|
||||||
|
return await self._perform_api_call(
|
||||||
|
prepare_request_func=prepare_request,
|
||||||
|
parse_response_func=parse_response,
|
||||||
|
http_client=http_client,
|
||||||
|
failed_keys=failed_keys,
|
||||||
|
log_context="Embedding",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_with_smart_retry(
|
||||||
|
self,
|
||||||
|
adapter,
|
||||||
|
messages: list[LLMMessage],
|
||||||
|
config: LLMGenerationConfig | None,
|
||||||
|
tools: list[LLMTool] | None,
|
||||||
|
tool_choice: str | dict[str, Any] | None,
|
||||||
|
http_client: LLMHttpClient,
|
||||||
|
):
|
||||||
|
"""智能重试机制 - 使用统一的重试装饰器"""
|
||||||
|
ai_config = get_ai_config()
|
||||||
|
max_retries = ai_config.get("max_retries_llm", 3)
|
||||||
|
retry_delay = ai_config.get("retry_delay_llm", 2)
|
||||||
|
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
|
||||||
|
|
||||||
|
return await with_smart_retry(
|
||||||
|
self._execute_single_request,
|
||||||
|
adapter,
|
||||||
|
messages,
|
||||||
|
config,
|
||||||
|
tools,
|
||||||
|
tool_choice,
|
||||||
|
http_client,
|
||||||
|
retry_config=retry_config,
|
||||||
|
key_store=self.key_store,
|
||||||
|
provider_name=self.provider_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_single_request(
|
||||||
|
self,
|
||||||
|
adapter,
|
||||||
|
messages: list[LLMMessage],
|
||||||
|
config: LLMGenerationConfig | None,
|
||||||
|
tools: list[LLMTool] | None,
|
||||||
|
tool_choice: str | dict[str, Any] | None,
|
||||||
|
http_client: LLMHttpClient,
|
||||||
|
failed_keys: set[str] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
|
||||||
|
|
||||||
|
async def prepare_request(api_key: str) -> RequestData:
|
||||||
|
return await adapter.prepare_advanced_request(
|
||||||
|
model=self,
|
||||||
|
api_key=api_key,
|
||||||
|
messages=messages,
|
||||||
|
config=config,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_response(response_json: dict[str, Any]) -> LLMResponse:
|
||||||
response_data = adapter.parse_response(
|
response_data = adapter.parse_response(
|
||||||
model=self,
|
model=self,
|
||||||
response_json=response_json,
|
response_json=response_json,
|
||||||
is_advanced=True,
|
is_advanced=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .types.models import LLMToolCall
|
from .types.models import LLMToolCall
|
||||||
|
|
||||||
response_tool_calls = []
|
response_tool_calls = []
|
||||||
@ -323,7 +431,7 @@ class LLMModel(LLMModelBase):
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"工具调用数据格式未知: {tc_data}")
|
logger.warning(f"工具调用数据格式未知: {tc_data}")
|
||||||
|
|
||||||
llm_response = LLMResponse(
|
return LLMResponse(
|
||||||
text=response_data.text,
|
text=response_data.text,
|
||||||
usage_info=response_data.usage_info,
|
usage_info=response_data.usage_info,
|
||||||
raw_response=response_data.raw_response,
|
raw_response=response_data.raw_response,
|
||||||
@ -333,33 +441,12 @@ class LLMModel(LLMModelBase):
|
|||||||
cache_info=response_data.cache_info,
|
cache_info=response_data.cache_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
return await self._perform_api_call(
|
||||||
logger.error(f"解析响应失败: {e}", e=e)
|
prepare_request_func=prepare_request,
|
||||||
await self.key_store.record_failure(api_key, None)
|
parse_response_func=parse_response,
|
||||||
|
http_client=http_client,
|
||||||
if isinstance(e, LLMException):
|
failed_keys=failed_keys,
|
||||||
raise
|
log_context="Generation",
|
||||||
else:
|
|
||||||
raise LLMException(
|
|
||||||
f"解析API响应失败: {e}",
|
|
||||||
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
|
|
||||||
cause=e,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.key_store.record_success(api_key)
|
|
||||||
|
|
||||||
return llm_response
|
|
||||||
|
|
||||||
except LLMException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"生成响应时发生未预期错误: {e}", e=e)
|
|
||||||
await self.key_store.record_failure(api_key, None)
|
|
||||||
|
|
||||||
raise LLMException(
|
|
||||||
f"生成响应失败: {e}",
|
|
||||||
code=LLMErrorCode.GENERATION_FAILED,
|
|
||||||
cause=e,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
@ -400,7 +487,17 @@ class LLMModel(LLMModelBase):
|
|||||||
history: list[dict[str, str]] | None = None,
|
history: list[dict[str, str]] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""生成文本 - 通过 generate_response 实现"""
|
"""
|
||||||
|
生成文本 - 通过 generate_response 实现
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prompt: 输入提示词。
|
||||||
|
history: 对话历史记录。
|
||||||
|
**kwargs: 其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 生成的文本。
|
||||||
|
"""
|
||||||
self._check_not_closed()
|
self._check_not_closed()
|
||||||
|
|
||||||
messages: list[LLMMessage] = []
|
messages: list[LLMMessage] = []
|
||||||
@ -439,11 +536,21 @@ class LLMModel(LLMModelBase):
|
|||||||
config: LLMGenerationConfig | None = None,
|
config: LLMGenerationConfig | None = None,
|
||||||
tools: list[LLMTool] | None = None,
|
tools: list[LLMTool] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
|
|
||||||
max_tool_iterations: int = 5,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""生成高级响应 - 实现完整的工具调用循环"""
|
"""
|
||||||
|
生成高级响应
|
||||||
|
|
||||||
|
参数:
|
||||||
|
messages: 消息列表。
|
||||||
|
config: 生成配置。
|
||||||
|
tools: 工具列表。
|
||||||
|
tool_choice: 工具选择策略。
|
||||||
|
**kwargs: 其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMResponse: 模型响应。
|
||||||
|
"""
|
||||||
self._check_not_closed()
|
self._check_not_closed()
|
||||||
|
|
||||||
from .adapters import get_adapter_for_api_type
|
from .adapters import get_adapter_for_api_type
|
||||||
@ -468,117 +575,61 @@ class LLMModel(LLMModelBase):
|
|||||||
merged_dict.update(config.to_dict())
|
merged_dict.update(config.to_dict())
|
||||||
final_request_config = LLMGenerationConfig(**merged_dict)
|
final_request_config = LLMGenerationConfig(**merged_dict)
|
||||||
|
|
||||||
tools_dict: list[dict[str, Any]] | None = None
|
|
||||||
if tools:
|
|
||||||
tools_dict = []
|
|
||||||
for tool in tools:
|
|
||||||
if hasattr(tool, "model_dump"):
|
|
||||||
model_dump_func = getattr(tool, "model_dump")
|
|
||||||
tools_dict.append(model_dump_func(exclude_none=True))
|
|
||||||
elif isinstance(tool, dict):
|
|
||||||
tools_dict.append(tool)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
tools_dict.append(dict(tool))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
|
|
||||||
|
|
||||||
http_client = await self._get_http_client()
|
http_client = await self._get_http_client()
|
||||||
current_messages = list(messages)
|
|
||||||
|
|
||||||
for iteration in range(max_tool_iterations):
|
async with AsyncExitStack() as stack:
|
||||||
logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
|
activated_tools = []
|
||||||
|
if tools:
|
||||||
|
for tool in tools:
|
||||||
|
if tool.type == "mcp" and callable(tool.mcp_session):
|
||||||
|
func_obj = getattr(tool.mcp_session, "func", None)
|
||||||
|
tool_name = (
|
||||||
|
getattr(func_obj, "__name__", "unknown")
|
||||||
|
if func_obj
|
||||||
|
else "unknown"
|
||||||
|
)
|
||||||
|
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
|
||||||
|
|
||||||
|
active_session = await stack.enter_async_context(
|
||||||
|
tool.mcp_session()
|
||||||
|
)
|
||||||
|
|
||||||
|
activated_tools.append(
|
||||||
|
LLMTool.from_mcp_session(
|
||||||
|
session=active_session, annotations=tool.annotations
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
activated_tools.append(tool)
|
||||||
|
|
||||||
llm_response = await self._execute_with_smart_retry(
|
llm_response = await self._execute_with_smart_retry(
|
||||||
adapter,
|
adapter,
|
||||||
current_messages,
|
messages,
|
||||||
final_request_config,
|
final_request_config,
|
||||||
tools_dict if iteration == 0 else None,
|
activated_tools if activated_tools else None,
|
||||||
tool_choice if iteration == 0 else None,
|
tool_choice,
|
||||||
http_client,
|
http_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_tool_calls = llm_response.tool_calls or []
|
|
||||||
|
|
||||||
if not response_tool_calls or not tool_executor:
|
|
||||||
logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
|
|
||||||
|
|
||||||
assistant_message_content = llm_response.text if llm_response.text else ""
|
|
||||||
current_messages.append(
|
|
||||||
LLMMessage.assistant_tool_calls(
|
|
||||||
content=assistant_message_content, tool_calls=response_tool_calls
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_response_messages: list[LLMMessage] = []
|
|
||||||
for tool_call in response_tool_calls:
|
|
||||||
tool_name = tool_call.function.name
|
|
||||||
try:
|
|
||||||
tool_args_dict = json.loads(tool_call.function.arguments)
|
|
||||||
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
|
|
||||||
|
|
||||||
tool_result = await tool_executor(tool_name, tool_args_dict)
|
|
||||||
logger.debug(
|
|
||||||
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_response_messages.append(
|
|
||||||
LLMMessage.tool_response(
|
|
||||||
tool_call_id=tool_call.id,
|
|
||||||
function_name=tool_name,
|
|
||||||
result=tool_result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(
|
|
||||||
f"工具 '{tool_name}' 参数JSON解析失败: "
|
|
||||||
f"{tool_call.function.arguments}, 错误: {e}"
|
|
||||||
)
|
|
||||||
tool_response_messages.append(
|
|
||||||
LLMMessage.tool_response(
|
|
||||||
tool_call_id=tool_call.id,
|
|
||||||
function_name=tool_name,
|
|
||||||
result={
|
|
||||||
"error": "Argument JSON parsing failed",
|
|
||||||
"details": str(e),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
|
|
||||||
tool_response_messages.append(
|
|
||||||
LLMMessage.tool_response(
|
|
||||||
tool_call_id=tool_call.id,
|
|
||||||
function_name=tool_name,
|
|
||||||
result={
|
|
||||||
"error": "Tool execution failed",
|
|
||||||
"details": str(e),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
current_messages.extend(tool_response_messages)
|
|
||||||
|
|
||||||
logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
|
|
||||||
raise LLMException(
|
|
||||||
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
|
|
||||||
code=LLMErrorCode.GENERATION_FAILED,
|
|
||||||
details={
|
|
||||||
"iterations": max_tool_iterations,
|
|
||||||
"last_messages": current_messages[-2:],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def generate_embeddings(
|
async def generate_embeddings(
|
||||||
self,
|
self,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""生成文本嵌入向量"""
|
"""
|
||||||
|
生成文本嵌入向量
|
||||||
|
|
||||||
|
参数:
|
||||||
|
texts: 文本列表。
|
||||||
|
task_type: 嵌入任务类型。
|
||||||
|
**kwargs: 其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[list[float]]: 嵌入向量列表。
|
||||||
|
"""
|
||||||
self._check_not_closed()
|
self._check_not_closed()
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|||||||
7
zhenxun/services/llm/tools/__init__.py
Normal file
7
zhenxun/services/llm/tools/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
工具模块导出
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .registry import tool_registry
|
||||||
|
|
||||||
|
__all__ = ["tool_registry"]
|
||||||
181
zhenxun/services/llm/tools/registry.py
Normal file
181
zhenxun/services/llm/tools/registry.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
"""
|
||||||
|
工具注册表
|
||||||
|
|
||||||
|
负责加载、管理和实例化来自配置的工具。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import AbstractAsyncContextManager
|
||||||
|
from functools import partial
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from ..types import LLMTool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..config.providers import ToolConfig
|
||||||
|
from ..types.protocols import MCPCompatible
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""工具注册表,用于管理和实例化配置的工具。"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._function_tools: dict[str, LLMTool] = {}
|
||||||
|
|
||||||
|
self._mcp_config_models: dict[str, type[BaseModel]] = {}
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
self._mcp_factories: dict[
|
||||||
|
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
|
||||||
|
] = {}
|
||||||
|
else:
|
||||||
|
self._mcp_factories: dict[str, Callable] = {}
|
||||||
|
|
||||||
|
self._tool_configs: dict[str, "ToolConfig"] | None = None
|
||||||
|
self._tool_cache: dict[str, "LLMTool"] = {}
|
||||||
|
|
||||||
|
def _load_configs_if_needed(self):
|
||||||
|
"""如果尚未加载,则从主配置中加载MCP工具定义。"""
|
||||||
|
if self._tool_configs is None:
|
||||||
|
logger.debug("首次访问,正在加载MCP工具配置...")
|
||||||
|
from ..config.providers import get_llm_config
|
||||||
|
|
||||||
|
llm_config = get_llm_config()
|
||||||
|
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
|
||||||
|
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
|
||||||
|
|
||||||
|
def function_tool(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: dict,
|
||||||
|
required: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
装饰器:在代码中注册一个简单的、无状态的函数工具。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 工具的唯一名称。
|
||||||
|
description: 工具功能的描述。
|
||||||
|
parameters: OpenAPI格式的函数参数schema的properties部分。
|
||||||
|
required: 必需的参数列表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable):
|
||||||
|
if name in self._function_tools or name in self._mcp_factories:
|
||||||
|
logger.warning(f"正在覆盖已注册的工具: {name}")
|
||||||
|
|
||||||
|
tool_definition = LLMTool.create(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parameters=parameters,
|
||||||
|
required=required,
|
||||||
|
)
|
||||||
|
self._function_tools[name] = tool_definition
|
||||||
|
logger.info(f"已在代码中注册函数工具: '{name}'")
|
||||||
|
tool_definition.annotations = tool_definition.annotations or {}
|
||||||
|
tool_definition.annotations["executable"] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def mcp_tool(self, name: str, config_model: type[BaseModel]):
|
||||||
|
"""
|
||||||
|
装饰器:注册一个MCP工具及其配置模型。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
||||||
|
config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(factory_func: Callable):
|
||||||
|
if name in self._mcp_factories:
|
||||||
|
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
||||||
|
self._mcp_factories[name] = factory_func
|
||||||
|
self._mcp_config_models[name] = config_model
|
||||||
|
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
|
||||||
|
return factory_func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
|
||||||
|
"""根据名称获取MCP工具的配置模型。"""
|
||||||
|
return self._mcp_config_models.get(name)
|
||||||
|
|
||||||
|
def register_mcp_factory(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
factory: Callable,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
||||||
|
factory: 一个返回异步生成器的可调用对象(会话工厂)。
|
||||||
|
"""
|
||||||
|
if name in self._mcp_factories:
|
||||||
|
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
||||||
|
self._mcp_factories[name] = factory
|
||||||
|
logger.info(f"已注册 MCP 会话工厂: '{name}'")
|
||||||
|
|
||||||
|
def get_tool(self, name: str) -> "LLMTool":
|
||||||
|
"""
|
||||||
|
根据名称获取一个 LLMTool 定义。
|
||||||
|
对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
|
||||||
|
而不是一个已激活的会话。
|
||||||
|
"""
|
||||||
|
logger.debug(f"🔍 请求获取工具定义: {name}")
|
||||||
|
|
||||||
|
if name in self._tool_cache:
|
||||||
|
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
|
||||||
|
return self._tool_cache[name]
|
||||||
|
|
||||||
|
if name in self._function_tools:
|
||||||
|
logger.debug(f"🛠️ 获取函数工具定义: {name}")
|
||||||
|
tool = self._function_tools[name]
|
||||||
|
self._tool_cache[name] = tool
|
||||||
|
return tool
|
||||||
|
|
||||||
|
self._load_configs_if_needed()
|
||||||
|
if self._tool_configs is None or name not in self._tool_configs:
|
||||||
|
known_tools = list(self._function_tools.keys()) + (
|
||||||
|
list(self._tool_configs.keys()) if self._tool_configs else []
|
||||||
|
)
|
||||||
|
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
|
||||||
|
logger.debug(f"📋 可用工具定义列表: {known_tools}")
|
||||||
|
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
|
||||||
|
|
||||||
|
config = self._tool_configs[name]
|
||||||
|
tool: "LLMTool"
|
||||||
|
|
||||||
|
if name not in self._mcp_factories:
|
||||||
|
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
|
||||||
|
available_factories = list(self._mcp_factories.keys())
|
||||||
|
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
|
||||||
|
raise ValueError(
|
||||||
|
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
|
||||||
|
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"🔧 创建MCP工具定义: {name}")
|
||||||
|
factory = self._mcp_factories[name]
|
||||||
|
typed_mcp_config = config.mcp_config
|
||||||
|
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
|
||||||
|
|
||||||
|
configured_factory = partial(factory, config=typed_mcp_config)
|
||||||
|
tool = LLMTool.from_mcp_session(session=configured_factory)
|
||||||
|
|
||||||
|
self._tool_cache[name] = tool
|
||||||
|
logger.debug(f"💾 MCP工具定义已缓存: {name}")
|
||||||
|
return tool
|
||||||
|
|
||||||
|
def get_tools(self, names: list[str]) -> list["LLMTool"]:
|
||||||
|
"""根据名称列表获取多个 LLMTool 实例。"""
|
||||||
|
return [self.get_tool(name) for name in names]
|
||||||
|
|
||||||
|
|
||||||
|
tool_registry = ToolRegistry()
|
||||||
@ -4,6 +4,7 @@ LLM 类型定义模块
|
|||||||
统一导出所有核心类型、协议和异常定义。
|
统一导出所有核心类型、协议和异常定义。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
|
||||||
from .content import (
|
from .content import (
|
||||||
LLMContentPart,
|
LLMContentPart,
|
||||||
LLMMessage,
|
LLMMessage,
|
||||||
@ -26,6 +27,7 @@ from .models import (
|
|||||||
ToolMetadata,
|
ToolMetadata,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
|
from .protocols import MCPCompatible
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EmbeddingTaskType",
|
"EmbeddingTaskType",
|
||||||
@ -41,8 +43,11 @@ __all__ = [
|
|||||||
"LLMTool",
|
"LLMTool",
|
||||||
"LLMToolCall",
|
"LLMToolCall",
|
||||||
"LLMToolFunction",
|
"LLMToolFunction",
|
||||||
|
"MCPCompatible",
|
||||||
|
"ModelCapabilities",
|
||||||
"ModelDetail",
|
"ModelDetail",
|
||||||
"ModelInfo",
|
"ModelInfo",
|
||||||
|
"ModelModality",
|
||||||
"ModelName",
|
"ModelName",
|
||||||
"ModelProvider",
|
"ModelProvider",
|
||||||
"ProviderConfig",
|
"ProviderConfig",
|
||||||
@ -50,5 +55,6 @@ __all__ = [
|
|||||||
"ToolCategory",
|
"ToolCategory",
|
||||||
"ToolMetadata",
|
"ToolMetadata",
|
||||||
"UsageInfo",
|
"UsageInfo",
|
||||||
|
"get_model_capabilities",
|
||||||
"get_user_friendly_error_message",
|
"get_user_friendly_error_message",
|
||||||
]
|
]
|
||||||
|
|||||||
128
zhenxun/services/llm/types/capabilities.py
Normal file
128
zhenxun/services/llm/types/capabilities.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
LLM 模型能力定义模块
|
||||||
|
|
||||||
|
定义模型的输入输出模态、工具调用支持等核心能力。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ModelModality(str, Enum):
|
||||||
|
TEXT = "text"
|
||||||
|
IMAGE = "image"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VIDEO = "video"
|
||||||
|
EMBEDDING = "embedding"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCapabilities(BaseModel):
|
||||||
|
"""定义一个模型的核心、稳定能力。"""
|
||||||
|
|
||||||
|
input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
|
||||||
|
output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
|
||||||
|
supports_tool_calling: bool = False
|
||||||
|
is_embedding_model: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT},
|
||||||
|
output_modalities={ModelModality.TEXT},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
GEMINI_CAPABILITIES = ModelCapabilities(
|
||||||
|
input_modalities={
|
||||||
|
ModelModality.TEXT,
|
||||||
|
ModelModality.IMAGE,
|
||||||
|
ModelModality.AUDIO,
|
||||||
|
ModelModality.VIDEO,
|
||||||
|
},
|
||||||
|
output_modalities={ModelModality.TEXT},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||||
|
output_modalities={ModelModality.TEXT},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_ALIAS_MAPPING: dict[str, str] = {
|
||||||
|
"deepseek-v3*": "deepseek-chat",
|
||||||
|
"deepseek-ai/DeepSeek-V3": "deepseek-chat",
|
||||||
|
"deepseek-r1*": "deepseek-reasoner",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
|
||||||
|
"gemini-*-tts": ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT},
|
||||||
|
output_modalities={ModelModality.AUDIO},
|
||||||
|
),
|
||||||
|
"gemini-*-native-audio-*": ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
|
||||||
|
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
),
|
||||||
|
"gemini-2.0-flash-preview-image-generation": ModelCapabilities(
|
||||||
|
input_modalities={
|
||||||
|
ModelModality.TEXT,
|
||||||
|
ModelModality.IMAGE,
|
||||||
|
ModelModality.AUDIO,
|
||||||
|
ModelModality.VIDEO,
|
||||||
|
},
|
||||||
|
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
),
|
||||||
|
"gemini-embedding-exp": ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT},
|
||||||
|
output_modalities={ModelModality.EMBEDDING},
|
||||||
|
is_embedding_model=True,
|
||||||
|
),
|
||||||
|
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
|
||||||
|
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
|
||||||
|
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
|
||||||
|
"gemini-2.0-flash*": GEMINI_CAPABILITIES,
|
||||||
|
"gemini-1.5-flash*": GEMINI_CAPABILITIES,
|
||||||
|
"GLM-4V-Flash": ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||||
|
output_modalities={ModelModality.TEXT},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
),
|
||||||
|
"GLM-4V-Plus*": ModelCapabilities(
|
||||||
|
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||||
|
output_modalities={ModelModality.TEXT},
|
||||||
|
supports_tool_calling=True,
|
||||||
|
),
|
||||||
|
"glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||||
|
"glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||||
|
"doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
|
||||||
|
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
|
||||||
|
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||||
|
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_capabilities(model_name: str) -> ModelCapabilities:
|
||||||
|
"""
|
||||||
|
从注册表获取模型能力,支持别名映射和通配符匹配。
|
||||||
|
查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
|
||||||
|
"""
|
||||||
|
canonical_name = model_name
|
||||||
|
for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
|
||||||
|
if fnmatch.fnmatch(model_name, alias_pattern):
|
||||||
|
canonical_name = c_name
|
||||||
|
break
|
||||||
|
|
||||||
|
if canonical_name in MODEL_CAPABILITIES_REGISTRY:
|
||||||
|
return MODEL_CAPABILITIES_REGISTRY[canonical_name]
|
||||||
|
|
||||||
|
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
|
||||||
|
if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
|
||||||
|
return capabilities
|
||||||
|
|
||||||
|
return ModelCapabilities()
|
||||||
@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
|
|||||||
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
|
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def convert_for_api(self, api_type: str) -> dict[str, Any]:
|
async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
|
||||||
"""根据API类型转换多模态内容格式"""
|
"""根据API类型转换多模态内容格式"""
|
||||||
|
from zhenxun.utils.http_utils import AsyncHttpx
|
||||||
|
|
||||||
if self.type == "text":
|
if self.type == "text":
|
||||||
if api_type == "openai":
|
if api_type == "openai":
|
||||||
return {"type": "text", "text": self.text}
|
return {"type": "text", "text": self.text}
|
||||||
@ -248,20 +250,23 @@ class LLMContentPart(BaseModel):
|
|||||||
mime_type, data = base64_info
|
mime_type, data = base64_info
|
||||||
return {"inlineData": {"mimeType": mime_type, "data": data}}
|
return {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||||
else:
|
else:
|
||||||
# 如果无法解析 Base64 数据,抛出异常
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"无法解析Base64图像数据: {self.image_source[:50]}..."
|
f"无法解析Base64图像数据: {self.image_source[:50]}..."
|
||||||
)
|
)
|
||||||
else:
|
elif self.is_image_url():
|
||||||
logger.warning(
|
logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
|
||||||
f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
|
try:
|
||||||
)
|
image_bytes = await AsyncHttpx.get_content(self.image_source)
|
||||||
|
mime_type = self.mime_type or "image/jpeg"
|
||||||
|
base64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
return {
|
return {
|
||||||
"inlineData": {
|
"inlineData": {"mimeType": mime_type, "data": base64_data}
|
||||||
"mimeType": "image/jpeg",
|
|
||||||
"data": self.image_source,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载或编码URL图片失败: {e}", e=e)
|
||||||
|
raise ValueError(f"无法处理图片URL: {e}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
|
||||||
else:
|
else:
|
||||||
return {"type": "image_url", "image_url": {"url": self.image_source}}
|
return {"type": "image_url", "image_url": {"url": self.image_source}}
|
||||||
|
|
||||||
|
|||||||
@ -4,13 +4,25 @@ LLM 数据模型定义
|
|||||||
包含模型信息、配置、工具定义和响应数据的模型类。
|
包含模型信息、配置、工具定义和响应数据的模型类。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import AbstractAsyncContextManager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .enums import ModelProvider, ToolCategory
|
from .enums import ModelProvider, ToolCategory
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .protocols import MCPCompatible
|
||||||
|
|
||||||
|
MCPSessionType = (
|
||||||
|
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
MCPCompatible = object
|
||||||
|
MCPSessionType = Any
|
||||||
|
|
||||||
ModelName = str | None
|
ModelName = str | None
|
||||||
|
|
||||||
|
|
||||||
@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
|
|||||||
class LLMTool(BaseModel):
|
class LLMTool(BaseModel):
|
||||||
"""LLM 工具定义(支持 MCP 风格)"""
|
"""LLM 工具定义(支持 MCP 风格)"""
|
||||||
|
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
type: str = "function"
|
type: str = "function"
|
||||||
function: dict[str, Any]
|
function: dict[str, Any] | None = None
|
||||||
|
mcp_session: MCPSessionType = None
|
||||||
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
|
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
|
||||||
|
|
||||||
|
def model_post_init(self, /, __context: Any) -> None:
|
||||||
|
"""验证工具定义的有效性"""
|
||||||
|
_ = __context
|
||||||
|
if self.type == "function" and self.function is None:
|
||||||
|
raise ValueError("函数类型的工具必须包含 'function' 字段。")
|
||||||
|
if self.type == "mcp" and self.mcp_session is None:
|
||||||
|
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
@ -111,7 +134,7 @@ class LLMTool(BaseModel):
|
|||||||
required: list[str] | None = None,
|
required: list[str] | None = None,
|
||||||
annotations: dict[str, Any] | None = None,
|
annotations: dict[str, Any] | None = None,
|
||||||
) -> "LLMTool":
|
) -> "LLMTool":
|
||||||
"""创建工具"""
|
"""创建函数工具"""
|
||||||
function_def = {
|
function_def = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"description": description,
|
"description": description,
|
||||||
@ -123,6 +146,15 @@ class LLMTool(BaseModel):
|
|||||||
}
|
}
|
||||||
return cls(type="function", function=function_def, annotations=annotations)
|
return cls(type="function", function=function_def, annotations=annotations)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_mcp_session(
|
||||||
|
cls,
|
||||||
|
session: Any,
|
||||||
|
annotations: dict[str, Any] | None = None,
|
||||||
|
) -> "LLMTool":
|
||||||
|
"""从 MCP 会话创建工具"""
|
||||||
|
return cls(type="mcp", mcp_session=session, annotations=annotations)
|
||||||
|
|
||||||
|
|
||||||
class LLMCodeExecution(BaseModel):
|
class LLMCodeExecution(BaseModel):
|
||||||
"""代码执行结果"""
|
"""代码执行结果"""
|
||||||
|
|||||||
24
zhenxun/services/llm/types/protocols.py
Normal file
24
zhenxun/services/llm/types/protocols.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
LLM 模块的协议定义
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class MCPCompatible(Protocol):
|
||||||
|
"""
|
||||||
|
一个协议,定义了与LLM模块兼容的MCP会话对象应具备的行为。
|
||||||
|
任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def to_api_tool(self, api_type: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
将此MCP会话转换为特定LLM提供商API所需的工具格式。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
api_type: 目标API的类型 (例如 'gemini', 'openai')。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 一个字典,代表可以在API请求中使用的工具定义。
|
||||||
|
"""
|
||||||
|
...
|
||||||
@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from nonebot.adapters import Message as PlatformMessage
|
||||||
from nonebot_plugin_alconna.uniseg import (
|
from nonebot_plugin_alconna.uniseg import (
|
||||||
At,
|
At,
|
||||||
File,
|
File,
|
||||||
@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.http_utils import AsyncHttpx
|
||||||
|
|
||||||
from .types import LLMContentPart
|
from .types import LLMContentPart
|
||||||
|
|
||||||
@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
|
|||||||
"""
|
"""
|
||||||
将 UniMessage 实例转换为一个 LLMContentPart 列表。
|
将 UniMessage 实例转换为一个 LLMContentPart 列表。
|
||||||
这是处理多模态输入的核心转换逻辑。
|
这是处理多模态输入的核心转换逻辑。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 要转换的UniMessage实例。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[LLMContentPart]: 转换后的内容部分列表。
|
||||||
"""
|
"""
|
||||||
parts: list[LLMContentPart] = []
|
parts: list[LLMContentPart] = []
|
||||||
for seg in message:
|
for seg in message:
|
||||||
@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
|
|||||||
if seg.path:
|
if seg.path:
|
||||||
part = await LLMContentPart.from_path(seg.path)
|
part = await LLMContentPart.from_path(seg.path)
|
||||||
elif seg.url:
|
elif seg.url:
|
||||||
logger.warning(
|
try:
|
||||||
f"直接使用 URL 的 {type(seg).__name__} 段,"
|
logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
|
||||||
f"API 可能不支持: {seg.url}"
|
media_bytes = await AsyncHttpx.get_content(seg.url)
|
||||||
)
|
|
||||||
|
new_seg = copy.copy(seg)
|
||||||
|
new_seg.raw = media_bytes
|
||||||
|
seg = new_seg
|
||||||
|
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
|
||||||
part = LLMContentPart.text_part(
|
part = LLMContentPart.text_part(
|
||||||
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
|
f"[下载媒体失败: {seg.name or seg.url}]"
|
||||||
)
|
)
|
||||||
elif hasattr(seg, "raw") and seg.raw:
|
|
||||||
|
if part:
|
||||||
|
parts.append(part)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hasattr(seg, "raw") and seg.raw:
|
||||||
mime_type = getattr(seg, "mimetype", None)
|
mime_type = getattr(seg, "mimetype", None)
|
||||||
if isinstance(seg.raw, bytes):
|
if isinstance(seg.raw, bytes):
|
||||||
b64_data = base64.b64encode(seg.raw).decode("utf-8")
|
b64_data = base64.b64encode(seg.raw).decode("utf-8")
|
||||||
@ -127,50 +147,19 @@ def create_multimodal_message(
|
|||||||
audio_mimetypes: list[str] | str | None = None,
|
audio_mimetypes: list[str] | str | None = None,
|
||||||
) -> UniMessage:
|
) -> UniMessage:
|
||||||
"""
|
"""
|
||||||
创建多模态消息的便捷函数,方便第三方调用。
|
创建多模态消息的便捷函数
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
text: 文本内容
|
text: 文本内容
|
||||||
images: 图片数据,支持路径、字节数据或URL
|
images: 图片数据,支持路径、字节数据或URL
|
||||||
videos: 视频数据,支持路径、字节数据或URL
|
videos: 视频数据
|
||||||
audios: 音频数据,支持路径、字节数据或URL
|
audios: 音频数据
|
||||||
image_mimetypes: 图片MIME类型,当images为bytes时需要指定
|
image_mimetypes: 图片MIME类型,bytes数据时需要指定
|
||||||
video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
|
video_mimetypes: 视频MIME类型,bytes数据时需要指定
|
||||||
audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
|
audio_mimetypes: 音频MIME类型,bytes数据时需要指定
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
UniMessage: 构建好的多模态消息
|
UniMessage: 构建好的多模态消息
|
||||||
|
|
||||||
Examples:
|
|
||||||
# 纯文本
|
|
||||||
msg = create_multimodal_message("请分析这段文字")
|
|
||||||
|
|
||||||
# 文本 + 单张图片(路径)
|
|
||||||
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
|
|
||||||
|
|
||||||
# 文本 + 多张图片
|
|
||||||
msg = create_multimodal_message(
|
|
||||||
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 文本 + 图片字节数据
|
|
||||||
msg = create_multimodal_message(
|
|
||||||
"分析", images=image_data, image_mimetypes="image/jpeg"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 文本 + 视频
|
|
||||||
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
|
|
||||||
|
|
||||||
# 文本 + 音频
|
|
||||||
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
|
|
||||||
|
|
||||||
# 混合多模态
|
|
||||||
msg = create_multimodal_message(
|
|
||||||
"分析这些媒体文件",
|
|
||||||
images="/path/to/image.jpg",
|
|
||||||
videos="/path/to/video.mp4",
|
|
||||||
audios="/path/to/audio.wav"
|
|
||||||
)
|
|
||||||
"""
|
"""
|
||||||
message = UniMessage()
|
message = UniMessage()
|
||||||
|
|
||||||
@ -196,7 +185,7 @@ def _add_media_to_message(
|
|||||||
media_class: type,
|
media_class: type,
|
||||||
default_mimetype: str,
|
default_mimetype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""添加媒体文件到 UniMessage 的辅助函数"""
|
"""添加媒体文件到 UniMessage"""
|
||||||
if not isinstance(media_items, list):
|
if not isinstance(media_items, list):
|
||||||
media_items = [media_items]
|
media_items = [media_items]
|
||||||
|
|
||||||
@ -216,3 +205,80 @@ def _add_media_to_message(
|
|||||||
elif isinstance(item, bytes):
|
elif isinstance(item, bytes):
|
||||||
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
|
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
|
||||||
message.append(media_class(raw=item, mimetype=mimetype))
|
message.append(media_class(raw=item, mimetype=mimetype))
|
||||||
|
|
||||||
|
|
||||||
|
def message_to_unimessage(message: PlatformMessage) -> UniMessage:
|
||||||
|
"""
|
||||||
|
将平台特定的 Message 对象转换为通用的 UniMessage。
|
||||||
|
主要用于处理引用消息等未被自动转换的消息体。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 平台特定的Message对象。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
UniMessage: 转换后的通用消息对象。
|
||||||
|
"""
|
||||||
|
uni_segments = []
|
||||||
|
for seg in message:
|
||||||
|
if seg.type == "text":
|
||||||
|
uni_segments.append(Text(seg.data.get("text", "")))
|
||||||
|
elif seg.type == "image":
|
||||||
|
uni_segments.append(Image(url=seg.data.get("url")))
|
||||||
|
elif seg.type == "record":
|
||||||
|
uni_segments.append(Voice(url=seg.data.get("url")))
|
||||||
|
elif seg.type == "video":
|
||||||
|
uni_segments.append(Video(url=seg.data.get("url")))
|
||||||
|
elif seg.type == "at":
|
||||||
|
uni_segments.append(At("user", str(seg.data.get("qq", ""))))
|
||||||
|
else:
|
||||||
|
logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
|
||||||
|
|
||||||
|
return UniMessage(uni_segments)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_request_body_for_logging(body: dict) -> dict:
|
||||||
|
"""
|
||||||
|
净化请求体用于日志记录,移除大数据字段并添加摘要信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
body: 原始请求体字典。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict: 净化后的请求体字典。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sanitized_body = copy.deepcopy(body)
|
||||||
|
|
||||||
|
if "contents" in sanitized_body and isinstance(
|
||||||
|
sanitized_body["contents"], list
|
||||||
|
):
|
||||||
|
for content_item in sanitized_body["contents"]:
|
||||||
|
if "parts" in content_item and isinstance(content_item["parts"], list):
|
||||||
|
media_summary = []
|
||||||
|
new_parts = []
|
||||||
|
for part in content_item["parts"]:
|
||||||
|
if "inlineData" in part and isinstance(
|
||||||
|
part["inlineData"], dict
|
||||||
|
):
|
||||||
|
data = part["inlineData"].get("data")
|
||||||
|
if isinstance(data, str):
|
||||||
|
mime_type = part["inlineData"].get(
|
||||||
|
"mimeType", "unknown"
|
||||||
|
)
|
||||||
|
media_summary.append(f"{mime_type} ({len(data)} chars)")
|
||||||
|
continue
|
||||||
|
new_parts.append(part)
|
||||||
|
|
||||||
|
if media_summary:
|
||||||
|
summary_text = (
|
||||||
|
f"[多模态内容: {len(media_summary)}个文件 - "
|
||||||
|
f"{', '.join(media_summary)}]"
|
||||||
|
)
|
||||||
|
new_parts.insert(0, {"text": summary_text})
|
||||||
|
|
||||||
|
content_item["parts"] = new_parts
|
||||||
|
|
||||||
|
return sanitized_body
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
|
||||||
|
return body
|
||||||
|
|||||||
12
zhenxun/services/scheduler/__init__.py
Normal file
12
zhenxun/services/scheduler/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
定时调度服务模块
|
||||||
|
|
||||||
|
提供一个统一的、持久化的定时任务管理器,供所有插件使用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .lifecycle import _load_schedules_from_db
|
||||||
|
from .service import scheduler_manager
|
||||||
|
|
||||||
|
_ = _load_schedules_from_db
|
||||||
|
|
||||||
|
__all__ = ["scheduler_manager"]
|
||||||
102
zhenxun/services/scheduler/adapter.py
Normal file
102
zhenxun/services/scheduler/adapter.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
引擎适配层 (Adapter)
|
||||||
|
|
||||||
|
封装所有对具体调度器引擎 (APScheduler) 的操作,
|
||||||
|
使上层服务与调度器实现解耦。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from nonebot_plugin_apscheduler import scheduler
|
||||||
|
|
||||||
|
from zhenxun.models.schedule_info import ScheduleInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .job import _execute_job
|
||||||
|
|
||||||
|
JOB_PREFIX = "zhenxun_schedule_"
|
||||||
|
|
||||||
|
|
||||||
|
class APSchedulerAdapter:
|
||||||
|
"""封装对 APScheduler 的操作"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_job_id(schedule_id: int) -> str:
|
||||||
|
"""生成 APScheduler 的 Job ID"""
|
||||||
|
return f"{JOB_PREFIX}{schedule_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_or_reschedule_job(schedule: ScheduleInfo):
|
||||||
|
"""根据 ScheduleInfo 添加或重新调度一个 APScheduler 任务"""
|
||||||
|
job_id = APSchedulerAdapter._get_job_id(schedule.id)
|
||||||
|
|
||||||
|
if not isinstance(schedule.trigger_config, dict):
|
||||||
|
logger.error(
|
||||||
|
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
|
||||||
|
f"{type(schedule.trigger_config)}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
job = scheduler.get_job(job_id)
|
||||||
|
if job:
|
||||||
|
scheduler.reschedule_job(
|
||||||
|
job_id, trigger=schedule.trigger_type, **schedule.trigger_config
|
||||||
|
)
|
||||||
|
logger.debug(f"已更新APScheduler任务: {job_id}")
|
||||||
|
else:
|
||||||
|
scheduler.add_job(
|
||||||
|
_execute_job,
|
||||||
|
trigger=schedule.trigger_type,
|
||||||
|
id=job_id,
|
||||||
|
misfire_grace_time=300,
|
||||||
|
args=[schedule.id],
|
||||||
|
**schedule.trigger_config,
|
||||||
|
)
|
||||||
|
logger.debug(f"已添加新的APScheduler任务: {job_id}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def remove_job(schedule_id: int):
|
||||||
|
"""移除一个 APScheduler 任务"""
|
||||||
|
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||||
|
try:
|
||||||
|
scheduler.remove_job(job_id)
|
||||||
|
logger.debug(f"已从APScheduler中移除任务: {job_id}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pause_job(schedule_id: int):
|
||||||
|
"""暂停一个 APScheduler 任务"""
|
||||||
|
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||||
|
try:
|
||||||
|
scheduler.pause_job(job_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resume_job(schedule_id: int):
|
||||||
|
"""恢复一个 APScheduler 任务"""
|
||||||
|
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||||
|
try:
|
||||||
|
scheduler.resume_job(job_id)
|
||||||
|
except Exception:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from .repository import ScheduleRepository
|
||||||
|
|
||||||
|
async def _re_add_job():
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if schedule:
|
||||||
|
APSchedulerAdapter.add_or_reschedule_job(schedule)
|
||||||
|
|
||||||
|
asyncio.create_task(_re_add_job()) # noqa: RUF006
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_job_status(schedule_id: int) -> dict:
|
||||||
|
"""获取 APScheduler Job 的状态"""
|
||||||
|
job_id = APSchedulerAdapter._get_job_id(schedule_id)
|
||||||
|
job = scheduler.get_job(job_id)
|
||||||
|
return {
|
||||||
|
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
if job and job.next_run_time
|
||||||
|
else "N/A",
|
||||||
|
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
|
||||||
|
}
|
||||||
192
zhenxun/services/scheduler/job.py
Normal file
192
zhenxun/services/scheduler/job.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
定时任务的执行逻辑
|
||||||
|
|
||||||
|
包含被 APScheduler 实际调度的函数,以及处理不同目标(单个、所有群组)的执行策略。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import inspect
|
||||||
|
import random
|
||||||
|
|
||||||
|
import nonebot
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.models.schedule_info import ScheduleInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
|
from zhenxun.utils.decorator.retry import Retry
|
||||||
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
|
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_job(schedule_id: int):
|
||||||
|
"""
|
||||||
|
APScheduler 调度的入口函数。
|
||||||
|
根据 schedule_id 处理特定任务、所有群组任务或全局任务。
|
||||||
|
"""
|
||||||
|
from .repository import ScheduleRepository
|
||||||
|
from .service import scheduler_manager
|
||||||
|
|
||||||
|
scheduler_manager._running_tasks.add(schedule_id)
|
||||||
|
try:
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if not schedule or not schedule.is_enabled:
|
||||||
|
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
|
||||||
|
return
|
||||||
|
|
||||||
|
plugin_name = schedule.plugin_name
|
||||||
|
|
||||||
|
task_meta = scheduler_manager._registered_tasks.get(plugin_name)
|
||||||
|
if not task_meta:
|
||||||
|
logger.error(
|
||||||
|
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
|
||||||
|
)
|
||||||
|
schedule.is_enabled = False
|
||||||
|
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
|
||||||
|
from .adapter import APSchedulerAdapter
|
||||||
|
|
||||||
|
APSchedulerAdapter.remove_job(schedule.id)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if schedule.bot_id:
|
||||||
|
bot = nonebot.get_bot(schedule.bot_id)
|
||||||
|
else:
|
||||||
|
bot = nonebot.get_bot()
|
||||||
|
logger.debug(
|
||||||
|
f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}"
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(
|
||||||
|
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
|
||||||
|
f"不在线,本次执行跳过。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。")
|
||||||
|
return
|
||||||
|
|
||||||
|
if schedule.group_id == scheduler_manager.ALL_GROUPS:
|
||||||
|
await _execute_for_all_groups(schedule, task_meta, bot)
|
||||||
|
else:
|
||||||
|
await _execute_for_single_target(schedule, task_meta, bot)
|
||||||
|
finally:
|
||||||
|
scheduler_manager._running_tasks.discard(schedule_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_for_all_groups(schedule: ScheduleInfo, task_meta: dict, bot):
|
||||||
|
"""为所有群组执行任务,并处理优先级覆盖。"""
|
||||||
|
plugin_name = schedule.plugin_name
|
||||||
|
|
||||||
|
concurrency_limit = Config.get_config(
|
||||||
|
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
|
||||||
|
)
|
||||||
|
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
|
||||||
|
logger.warning(
|
||||||
|
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
|
||||||
|
)
|
||||||
|
concurrency_limit = 5
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"开始执行针对 [所有群组] 的任务 "
|
||||||
|
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id}),"
|
||||||
|
f"并发限制: {concurrency_limit}"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_gids = set()
|
||||||
|
try:
|
||||||
|
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||||
|
all_gids.update(
|
||||||
|
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||||
|
return
|
||||||
|
|
||||||
|
specific_tasks_gids = set(
|
||||||
|
await ScheduleInfo.filter(
|
||||||
|
plugin_name=plugin_name, group_id__in=list(all_gids)
|
||||||
|
).values_list("group_id", flat=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||||
|
|
||||||
|
async def worker(gid: str):
|
||||||
|
"""使用 Semaphore 包装单个群组的任务执行"""
|
||||||
|
await asyncio.sleep(random.uniform(0, 59))
|
||||||
|
async with semaphore:
|
||||||
|
temp_schedule = copy.deepcopy(schedule)
|
||||||
|
temp_schedule.group_id = gid
|
||||||
|
await _execute_for_single_target(temp_schedule, task_meta, bot)
|
||||||
|
await asyncio.sleep(random.uniform(0.1, 0.5))
|
||||||
|
|
||||||
|
tasks_to_run = []
|
||||||
|
for gid in all_gids:
|
||||||
|
if gid in specific_tasks_gids:
|
||||||
|
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
|
||||||
|
continue
|
||||||
|
tasks_to_run.append(worker(gid))
|
||||||
|
|
||||||
|
if tasks_to_run:
|
||||||
|
await asyncio.gather(*tasks_to_run)
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_for_single_target(schedule: ScheduleInfo, task_meta: dict, bot):
|
||||||
|
"""为单个目标(具体群组或全局)执行任务。"""
|
||||||
|
|
||||||
|
plugin_name = schedule.plugin_name
|
||||||
|
group_id = schedule.group_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
|
||||||
|
if is_blocked:
|
||||||
|
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||||
|
logger.info(
|
||||||
|
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
|
||||||
|
"因功能被禁用而跳过执行。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
max_retries = Config.get_config("SchedulerManager", "JOB_MAX_RETRIES", 2)
|
||||||
|
retry_delay = Config.get_config("SchedulerManager", "JOB_RETRY_DELAY", 10)
|
||||||
|
|
||||||
|
@Retry.simple(
|
||||||
|
stop_max_attempt=max_retries + 1,
|
||||||
|
wait_fixed_seconds=retry_delay,
|
||||||
|
log_name=f"定时任务执行:{schedule.plugin_name}",
|
||||||
|
)
|
||||||
|
async def _execute_task_with_retry():
|
||||||
|
task_func = task_meta["func"]
|
||||||
|
job_kwargs = schedule.job_kwargs
|
||||||
|
if not isinstance(job_kwargs, dict):
|
||||||
|
logger.error(
|
||||||
|
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
sig = inspect.signature(task_func)
|
||||||
|
if "bot" in sig.parameters:
|
||||||
|
job_kwargs["bot"] = bot
|
||||||
|
|
||||||
|
await task_func(group_id, **job_kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"插件 '{schedule.plugin_name}' 开始为目标 "
|
||||||
|
f"[{schedule.group_id or '全局'}] 执行定时任务 (ID: {schedule.id})。"
|
||||||
|
)
|
||||||
|
await _execute_task_with_retry()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"执行定时任务 (ID: {schedule.id}, 插件: {schedule.plugin_name}, "
|
||||||
|
f"目标: {schedule.group_id or '全局'}) 在所有重试后最终失败",
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
|
||||||
|
f"目标: {group_id or '全局'}) 时发生异常",
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
62
zhenxun/services/scheduler/lifecycle.py
Normal file
62
zhenxun/services/scheduler/lifecycle.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
定时任务的生命周期管理
|
||||||
|
|
||||||
|
包含在机器人启动时加载和调度数据库中保存的任务的逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
|
||||||
|
from .adapter import APSchedulerAdapter
|
||||||
|
from .repository import ScheduleRepository
|
||||||
|
from .service import scheduler_manager
|
||||||
|
|
||||||
|
|
||||||
|
@PriorityLifecycle.on_startup(priority=90)
|
||||||
|
async def _load_schedules_from_db():
|
||||||
|
"""在服务启动时从数据库加载并调度所有任务。"""
|
||||||
|
logger.info("正在从数据库加载并调度所有定时任务...")
|
||||||
|
schedules = await ScheduleRepository.get_all_enabled()
|
||||||
|
count = 0
|
||||||
|
for schedule in schedules:
|
||||||
|
if schedule.plugin_name in scheduler_manager._registered_tasks:
|
||||||
|
APSchedulerAdapter.add_or_reschedule_job(schedule)
|
||||||
|
count += 1
|
||||||
|
else:
|
||||||
|
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
|
||||||
|
logger.info(f"数据库定时任务加载完成,共成功加载 {count} 个任务。")
|
||||||
|
|
||||||
|
logger.info("正在检查并注册声明式默认任务...")
|
||||||
|
declared_count = 0
|
||||||
|
for task_info in scheduler_manager._declared_tasks:
|
||||||
|
plugin_name = task_info["plugin_name"]
|
||||||
|
group_id = task_info["group_id"]
|
||||||
|
bot_id = task_info["bot_id"]
|
||||||
|
|
||||||
|
query_kwargs = {
|
||||||
|
"plugin_name": plugin_name,
|
||||||
|
"group_id": group_id,
|
||||||
|
"bot_id": bot_id,
|
||||||
|
}
|
||||||
|
exists = await ScheduleRepository.exists(**query_kwargs)
|
||||||
|
|
||||||
|
if not exists:
|
||||||
|
logger.info(f"为插件 '{plugin_name}' 注册新的默认定时任务...")
|
||||||
|
schedule = await scheduler_manager.add_schedule(
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
group_id=group_id,
|
||||||
|
trigger_type=task_info["trigger_type"],
|
||||||
|
trigger_config=task_info["trigger_config"],
|
||||||
|
job_kwargs=task_info["job_kwargs"],
|
||||||
|
bot_id=bot_id,
|
||||||
|
)
|
||||||
|
if schedule:
|
||||||
|
declared_count += 1
|
||||||
|
logger.debug(f"默认任务 '{plugin_name}' 注册成功 (ID: {schedule.id})")
|
||||||
|
else:
|
||||||
|
logger.error(f"默认任务 '{plugin_name}' 注册失败")
|
||||||
|
else:
|
||||||
|
logger.debug(f"插件 '{plugin_name}' 的默认任务已存在于数据库中,跳过注册。")
|
||||||
|
|
||||||
|
if declared_count > 0:
|
||||||
|
logger.info(f"声明式任务检查完成,新注册了 {declared_count} 个默认任务。")
|
||||||
79
zhenxun/services/scheduler/repository.py
Normal file
79
zhenxun/services/scheduler/repository.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
数据持久层 (Repository)
|
||||||
|
|
||||||
|
封装所有对 ScheduleInfo 模型的数据库操作,将数据访问逻辑与业务逻辑分离。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from tortoise.queryset import QuerySet
|
||||||
|
|
||||||
|
from zhenxun.models.schedule_info import ScheduleInfo
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleRepository:
|
||||||
|
"""封装 ScheduleInfo 模型的数据库操作"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_by_id(schedule_id: int) -> ScheduleInfo | None:
|
||||||
|
"""通过ID获取任务"""
|
||||||
|
return await ScheduleInfo.get_or_none(id=schedule_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_all_enabled() -> list[ScheduleInfo]:
|
||||||
|
"""获取所有启用的任务"""
|
||||||
|
return await ScheduleInfo.filter(is_enabled=True).all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_all(plugin_name: str | None = None) -> list[ScheduleInfo]:
|
||||||
|
"""获取所有任务,可按插件名过滤"""
|
||||||
|
if plugin_name:
|
||||||
|
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
|
||||||
|
return await ScheduleInfo.all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def save(schedule: ScheduleInfo, update_fields: list[str] | None = None):
|
||||||
|
"""保存任务"""
|
||||||
|
await schedule.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def exists(**kwargs: Any) -> bool:
|
||||||
|
"""检查任务是否存在"""
|
||||||
|
return await ScheduleInfo.exists(**kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_by_plugin_and_group(
|
||||||
|
plugin_name: str, group_ids: list[str]
|
||||||
|
) -> list[ScheduleInfo]:
|
||||||
|
"""根据插件和群组ID列表获取任务"""
|
||||||
|
return await ScheduleInfo.filter(
|
||||||
|
plugin_name=plugin_name, group_id__in=group_ids
|
||||||
|
).all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_or_create(
|
||||||
|
defaults: dict, **kwargs: Any
|
||||||
|
) -> tuple[ScheduleInfo, bool]:
|
||||||
|
"""更新或创建任务"""
|
||||||
|
return await ScheduleInfo.update_or_create(defaults=defaults, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def query_schedules(**filters: Any) -> list[ScheduleInfo]:
|
||||||
|
"""
|
||||||
|
根据任意条件查询任务列表
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**filters: 过滤条件,如 group_id="123", plugin_name="abc"
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[ScheduleInfo]: 任务列表
|
||||||
|
"""
|
||||||
|
cleaned_filters = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
if not cleaned_filters:
|
||||||
|
return await ScheduleInfo.all()
|
||||||
|
return await ScheduleInfo.filter(**cleaned_filters).all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def filter(**kwargs: Any) -> QuerySet[ScheduleInfo]:
|
||||||
|
"""提供一个通用的过滤查询接口,供Targeter使用"""
|
||||||
|
return ScheduleInfo.filter(**kwargs)
|
||||||
448
zhenxun/services/scheduler/service.py
Normal file
448
zhenxun/services/scheduler/service.py
Normal file
@ -0,0 +1,448 @@
|
|||||||
|
"""
|
||||||
|
服务层 (Service)
|
||||||
|
|
||||||
|
定义 SchedulerManager 类作为定时任务服务的公共 API 入口。
|
||||||
|
它负责编排业务逻辑,并调用 Repository 和 Adapter 层来完成具体工作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
import nonebot
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.models.schedule_info import ScheduleInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .adapter import APSchedulerAdapter
|
||||||
|
from .job import _execute_job
|
||||||
|
from .repository import ScheduleRepository
|
||||||
|
from .targeter import ScheduleTargeter
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerManager:
|
||||||
|
ALL_GROUPS: ClassVar[str] = "__ALL_GROUPS__"
|
||||||
|
_registered_tasks: ClassVar[
|
||||||
|
dict[str, dict[str, Callable | type[BaseModel] | None]]
|
||||||
|
] = {}
|
||||||
|
_declared_tasks: ClassVar[list[dict[str, Any]]] = []
|
||||||
|
_running_tasks: ClassVar[set] = set()
|
||||||
|
|
||||||
|
def target(self, **filters: Any) -> ScheduleTargeter:
|
||||||
|
"""
|
||||||
|
创建目标选择器以执行批量操作
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**filters: 过滤条件,支持plugin_name、group_id、bot_id等字段。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
ScheduleTargeter: 目标选择器对象,可用于批量操作。
|
||||||
|
"""
|
||||||
|
return ScheduleTargeter(self, **filters)
|
||||||
|
|
||||||
|
def task(
|
||||||
|
self,
|
||||||
|
trigger: str,
|
||||||
|
group_id: str | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
**trigger_kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
声明式定时任务装饰器
|
||||||
|
|
||||||
|
参数:
|
||||||
|
trigger: 触发器类型,如'cron'、'interval'等。
|
||||||
|
group_id: 目标群组ID,None表示全局任务。
|
||||||
|
bot_id: 目标Bot ID,None表示使用默认Bot。
|
||||||
|
**trigger_kwargs: 触发器配置参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Callable: 装饰器函数。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||||
|
try:
|
||||||
|
plugin = nonebot.get_plugin_by_module_name(func.__module__)
|
||||||
|
if not plugin:
|
||||||
|
raise ValueError(f"函数 {func.__name__} 不在任何已加载的插件中。")
|
||||||
|
plugin_name = plugin.name
|
||||||
|
|
||||||
|
task_declaration = {
|
||||||
|
"plugin_name": plugin_name,
|
||||||
|
"func": func,
|
||||||
|
"group_id": group_id,
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"trigger_type": trigger,
|
||||||
|
"trigger_config": trigger_kwargs,
|
||||||
|
"job_kwargs": {},
|
||||||
|
}
|
||||||
|
self._declared_tasks.append(task_declaration)
|
||||||
|
logger.debug(
|
||||||
|
f"发现声明式定时任务 '{plugin_name}',将在启动时进行注册。"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"注册声明式定时任务失败: {func.__name__}, 错误: {e}")
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self, plugin_name: str, params_model: type[BaseModel] | None = None
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
注册可调度的任务函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin_name: 插件名称,用于标识任务。
|
||||||
|
params_model: 参数验证模型,继承自BaseModel的类。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Callable: 装饰器函数。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||||
|
if plugin_name in self._registered_tasks:
|
||||||
|
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
|
||||||
|
self._registered_tasks[plugin_name] = {
|
||||||
|
"func": func,
|
||||||
|
"model": params_model,
|
||||||
|
}
|
||||||
|
model_name = params_model.__name__ if params_model else "无"
|
||||||
|
logger.debug(
|
||||||
|
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
|
||||||
|
)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get_registered_plugins(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
获取已注册插件列表
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[str]: 已注册的插件名称列表。
|
||||||
|
"""
|
||||||
|
return list(self._registered_tasks.keys())
|
||||||
|
|
||||||
|
async def add_daily_task(
|
||||||
|
self,
|
||||||
|
plugin_name: str,
|
||||||
|
group_id: str | None,
|
||||||
|
hour: int,
|
||||||
|
minute: int,
|
||||||
|
second: int = 0,
|
||||||
|
job_kwargs: dict | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
) -> "ScheduleInfo | None":
|
||||||
|
"""
|
||||||
|
添加每日定时任务
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin_name: 插件名称。
|
||||||
|
group_id: 目标群组ID,None表示全局任务。
|
||||||
|
hour: 执行小时(0-23)。
|
||||||
|
minute: 执行分钟(0-59)。
|
||||||
|
second: 执行秒数(0-59),默认为0。
|
||||||
|
job_kwargs: 任务参数字典。
|
||||||
|
bot_id: 目标Bot ID,None表示使用默认Bot。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
ScheduleInfo | None: 创建的任务信息,失败时返回None。
|
||||||
|
"""
|
||||||
|
trigger_config = {
|
||||||
|
"hour": hour,
|
||||||
|
"minute": minute,
|
||||||
|
"second": second,
|
||||||
|
"timezone": Config.get_config("SchedulerManager", "SCHEDULER_TIMEZONE"),
|
||||||
|
}
|
||||||
|
return await self.add_schedule(
|
||||||
|
plugin_name,
|
||||||
|
group_id,
|
||||||
|
"cron",
|
||||||
|
trigger_config,
|
||||||
|
job_kwargs=job_kwargs,
|
||||||
|
bot_id=bot_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_interval_task(
|
||||||
|
self,
|
||||||
|
plugin_name: str,
|
||||||
|
group_id: str | None,
|
||||||
|
*,
|
||||||
|
weeks: int = 0,
|
||||||
|
days: int = 0,
|
||||||
|
hours: int = 0,
|
||||||
|
minutes: int = 0,
|
||||||
|
seconds: int = 0,
|
||||||
|
start_date: str | datetime | None = None,
|
||||||
|
job_kwargs: dict | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
) -> "ScheduleInfo | None":
|
||||||
|
"""添加间隔性定时任务"""
|
||||||
|
trigger_config = {
|
||||||
|
"weeks": weeks,
|
||||||
|
"days": days,
|
||||||
|
"hours": hours,
|
||||||
|
"minutes": minutes,
|
||||||
|
"seconds": seconds,
|
||||||
|
"start_date": start_date,
|
||||||
|
}
|
||||||
|
trigger_config = {k: v for k, v in trigger_config.items() if v}
|
||||||
|
return await self.add_schedule(
|
||||||
|
plugin_name,
|
||||||
|
group_id,
|
||||||
|
"interval",
|
||||||
|
trigger_config,
|
||||||
|
job_kwargs=job_kwargs,
|
||||||
|
bot_id=bot_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_and_prepare_kwargs(
|
||||||
|
self, plugin_name: str, job_kwargs: dict | None
|
||||||
|
) -> tuple[bool, str | dict]:
|
||||||
|
"""验证并准备任务参数,应用默认值"""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
task_meta = self._registered_tasks.get(plugin_name)
|
||||||
|
if not task_meta:
|
||||||
|
return False, f"插件 '{plugin_name}' 未注册。"
|
||||||
|
|
||||||
|
params_model = task_meta.get("model")
|
||||||
|
job_kwargs = job_kwargs if job_kwargs is not None else {}
|
||||||
|
|
||||||
|
if not params_model:
|
||||||
|
if job_kwargs:
|
||||||
|
logger.warning(
|
||||||
|
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
|
||||||
|
)
|
||||||
|
return True, job_kwargs
|
||||||
|
|
||||||
|
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||||
|
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
|
||||||
|
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_validate = getattr(params_model, "model_validate", None)
|
||||||
|
if not model_validate:
|
||||||
|
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
|
||||||
|
|
||||||
|
validated_model = model_validate(job_kwargs)
|
||||||
|
|
||||||
|
model_dump = getattr(validated_model, "model_dump", None)
|
||||||
|
if not model_dump:
|
||||||
|
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
|
||||||
|
|
||||||
|
return True, model_dump()
|
||||||
|
except ValidationError as e:
|
||||||
|
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||||
|
error_str = "\n".join(errors)
|
||||||
|
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||||
|
return False, msg
|
||||||
|
|
||||||
|
async def add_schedule(
|
||||||
|
self,
|
||||||
|
plugin_name: str,
|
||||||
|
group_id: str | None,
|
||||||
|
trigger_type: str,
|
||||||
|
trigger_config: dict,
|
||||||
|
job_kwargs: dict | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
) -> "ScheduleInfo | None":
|
||||||
|
"""
|
||||||
|
添加定时任务(通用方法)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin_name: 插件名称。
|
||||||
|
group_id: 目标群组ID,None表示全局任务。
|
||||||
|
trigger_type: 触发器类型,如'cron'、'interval'等。
|
||||||
|
trigger_config: 触发器配置字典。
|
||||||
|
job_kwargs: 任务参数字典。
|
||||||
|
bot_id: 目标Bot ID,None表示使用默认Bot。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
ScheduleInfo | None: 创建的任务信息,失败时返回None。
|
||||||
|
"""
|
||||||
|
if plugin_name not in self._registered_tasks:
|
||||||
|
logger.error(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
|
||||||
|
return None
|
||||||
|
|
||||||
|
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
|
||||||
|
if not is_valid:
|
||||||
|
logger.error(f"任务参数校验失败: {result}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
search_kwargs = {"plugin_name": plugin_name, "group_id": group_id}
|
||||||
|
if bot_id and group_id == self.ALL_GROUPS:
|
||||||
|
search_kwargs["bot_id"] = bot_id
|
||||||
|
else:
|
||||||
|
search_kwargs["bot_id__isnull"] = True
|
||||||
|
|
||||||
|
defaults = {
|
||||||
|
"trigger_type": trigger_type,
|
||||||
|
"trigger_config": trigger_config,
|
||||||
|
"job_kwargs": result,
|
||||||
|
"is_enabled": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
schedule, created = await ScheduleRepository.update_or_create(
|
||||||
|
defaults, **search_kwargs
|
||||||
|
)
|
||||||
|
APSchedulerAdapter.add_or_reschedule_job(schedule)
|
||||||
|
|
||||||
|
action = "设置" if created else "更新"
|
||||||
|
logger.info(
|
||||||
|
f"已成功{action}插件 '{plugin_name}' 的定时任务 (ID: {schedule.id})。"
|
||||||
|
)
|
||||||
|
return schedule
|
||||||
|
|
||||||
|
async def get_all_schedules(self) -> list[ScheduleInfo]:
|
||||||
|
"""
|
||||||
|
获取所有定时任务信息
|
||||||
|
"""
|
||||||
|
return await self.get_schedules()
|
||||||
|
|
||||||
|
async def get_schedules(
|
||||||
|
self,
|
||||||
|
plugin_name: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
) -> list[ScheduleInfo]:
|
||||||
|
"""
|
||||||
|
根据条件获取定时任务列表
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin_name: 插件名称,None表示不限制。
|
||||||
|
group_id: 群组ID,None表示不限制。
|
||||||
|
bot_id: Bot ID,None表示不限制。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[ScheduleInfo]: 符合条件的任务信息列表。
|
||||||
|
"""
|
||||||
|
return await ScheduleRepository.query_schedules(
|
||||||
|
plugin_name=plugin_name, group_id=group_id, bot_id=bot_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_schedule(
|
||||||
|
self,
|
||||||
|
schedule_id: int,
|
||||||
|
trigger_type: str | None = None,
|
||||||
|
trigger_config: dict | None = None,
|
||||||
|
job_kwargs: dict | None = None,
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
更新定时任务配置
|
||||||
|
|
||||||
|
参数:
|
||||||
|
schedule_id: 任务ID。
|
||||||
|
trigger_type: 新的触发器类型,None表示不更新。
|
||||||
|
trigger_config: 新的触发器配置,None表示不更新。
|
||||||
|
job_kwargs: 新的任务参数,None表示不更新。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[bool, str]: (是否成功, 结果消息)。
|
||||||
|
"""
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if not schedule:
|
||||||
|
return False, f"未找到 ID 为 {schedule_id} 的任务。"
|
||||||
|
|
||||||
|
updated_fields = []
|
||||||
|
if trigger_config is not None:
|
||||||
|
schedule.trigger_config = trigger_config
|
||||||
|
updated_fields.append("trigger_config")
|
||||||
|
if trigger_type is not None and schedule.trigger_type != trigger_type:
|
||||||
|
schedule.trigger_type = trigger_type
|
||||||
|
updated_fields.append("trigger_type")
|
||||||
|
|
||||||
|
if job_kwargs is not None:
|
||||||
|
existing_kwargs = (
|
||||||
|
schedule.job_kwargs.copy()
|
||||||
|
if isinstance(schedule.job_kwargs, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
existing_kwargs.update(job_kwargs)
|
||||||
|
|
||||||
|
is_valid, result = self._validate_and_prepare_kwargs(
|
||||||
|
schedule.plugin_name, existing_kwargs
|
||||||
|
)
|
||||||
|
if not is_valid:
|
||||||
|
return False, str(result)
|
||||||
|
|
||||||
|
assert isinstance(result, dict), "验证成功时 result 应该是字典类型"
|
||||||
|
schedule.job_kwargs = result
|
||||||
|
updated_fields.append("job_kwargs")
|
||||||
|
|
||||||
|
if not updated_fields:
|
||||||
|
return True, "没有任何需要更新的配置。"
|
||||||
|
|
||||||
|
await ScheduleRepository.save(schedule, update_fields=updated_fields)
|
||||||
|
APSchedulerAdapter.add_or_reschedule_job(schedule)
|
||||||
|
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
|
||||||
|
|
||||||
|
async def get_schedule_status(self, schedule_id: int) -> dict | None:
|
||||||
|
"""获取定时任务的详细状态信息"""
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if not schedule:
|
||||||
|
return None
|
||||||
|
|
||||||
|
status_from_scheduler = APSchedulerAdapter.get_job_status(schedule.id)
|
||||||
|
|
||||||
|
status_text = (
|
||||||
|
"运行中"
|
||||||
|
if schedule_id in self._running_tasks
|
||||||
|
else ("启用" if schedule.is_enabled else "暂停")
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": schedule.id,
|
||||||
|
"bot_id": schedule.bot_id,
|
||||||
|
"plugin_name": schedule.plugin_name,
|
||||||
|
"group_id": schedule.group_id,
|
||||||
|
"is_enabled": status_text,
|
||||||
|
"trigger_type": schedule.trigger_type,
|
||||||
|
"trigger_config": schedule.trigger_config,
|
||||||
|
"job_kwargs": schedule.job_kwargs,
|
||||||
|
**status_from_scheduler,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||||
|
"""暂停指定的定时任务"""
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if not schedule or not schedule.is_enabled:
|
||||||
|
return False, "任务不存在或已暂停。"
|
||||||
|
|
||||||
|
schedule.is_enabled = False
|
||||||
|
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
|
||||||
|
APSchedulerAdapter.pause_job(schedule_id)
|
||||||
|
return True, f"已暂停任务 (ID: {schedule.id})。"
|
||||||
|
|
||||||
|
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||||
|
"""恢复指定的定时任务"""
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if not schedule or schedule.is_enabled:
|
||||||
|
return False, "任务不存在或已启用。"
|
||||||
|
|
||||||
|
schedule.is_enabled = True
|
||||||
|
await ScheduleRepository.save(schedule, update_fields=["is_enabled"])
|
||||||
|
APSchedulerAdapter.resume_job(schedule_id)
|
||||||
|
return True, f"已恢复任务 (ID: {schedule.id})。"
|
||||||
|
|
||||||
|
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
|
||||||
|
"""立即手动触发指定的定时任务"""
|
||||||
|
schedule = await ScheduleRepository.get_by_id(schedule_id)
|
||||||
|
if not schedule:
|
||||||
|
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||||
|
if schedule.plugin_name not in self._registered_tasks:
|
||||||
|
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _execute_job(schedule.id)
|
||||||
|
return True, f"已手动触发任务 (ID: {schedule.id})。"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"手动触发任务失败: {e}")
|
||||||
|
return False, f"手动触发任务失败: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
scheduler_manager = SchedulerManager()
|
||||||
109
zhenxun/services/scheduler/targeter.py
Normal file
109
zhenxun/services/scheduler/targeter.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""
|
||||||
|
目标选择器 (Targeter)
|
||||||
|
|
||||||
|
提供链式API,用于构建和执行对多个定时任务的批量操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .adapter import APSchedulerAdapter
|
||||||
|
from .repository import ScheduleRepository
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleTargeter:
|
||||||
|
"""
|
||||||
|
一个用于构建和执行定时任务批量操作的目标选择器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, manager: Any, **filters: Any):
|
||||||
|
"""初始化目标选择器"""
|
||||||
|
self._manager = manager
|
||||||
|
self._filters = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
|
||||||
|
async def _get_schedules(self):
|
||||||
|
"""根据过滤器获取任务"""
|
||||||
|
query = ScheduleRepository.filter(**self._filters)
|
||||||
|
return await query.all()
|
||||||
|
|
||||||
|
def _generate_target_description(self) -> str:
|
||||||
|
"""根据过滤条件生成友好的目标描述"""
|
||||||
|
if "id" in self._filters:
|
||||||
|
return f"任务 ID {self._filters['id']} 的"
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if "group_id" in self._filters:
|
||||||
|
group_id = self._filters["group_id"]
|
||||||
|
if group_id == self._manager.ALL_GROUPS:
|
||||||
|
parts.append("所有群组中")
|
||||||
|
else:
|
||||||
|
parts.append(f"群 {group_id} 中")
|
||||||
|
|
||||||
|
if "plugin_name" in self._filters:
|
||||||
|
parts.append(f"插件 '{self._filters['plugin_name']}' 的")
|
||||||
|
|
||||||
|
if not parts:
|
||||||
|
return "所有"
|
||||||
|
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
async def _apply_operation(
|
||||||
|
self,
|
||||||
|
operation_func: Callable[[int], Coroutine[Any, Any, tuple[bool, str]]],
|
||||||
|
operation_name: str,
|
||||||
|
) -> tuple[int, str]:
|
||||||
|
"""通用的操作应用模板"""
|
||||||
|
schedules = await self._get_schedules()
|
||||||
|
if not schedules:
|
||||||
|
target_desc = self._generate_target_description()
|
||||||
|
return 0, f"没有找到{target_desc}可供{operation_name}的任务。"
|
||||||
|
|
||||||
|
success_count = 0
|
||||||
|
for schedule in schedules:
|
||||||
|
success, _ = await operation_func(schedule.id)
|
||||||
|
if success:
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
target_desc = self._generate_target_description()
|
||||||
|
return (
|
||||||
|
success_count,
|
||||||
|
f"成功{operation_name}了{target_desc} {success_count} 个任务。",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def pause(self) -> tuple[int, str]:
|
||||||
|
"""
|
||||||
|
暂停匹配的定时任务
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[int, str]: (成功暂停的任务数量, 操作结果消息)。
|
||||||
|
"""
|
||||||
|
return await self._apply_operation(self._manager.pause_schedule, "暂停")
|
||||||
|
|
||||||
|
async def resume(self) -> tuple[int, str]:
|
||||||
|
"""
|
||||||
|
恢复匹配的定时任务
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[int, str]: (成功恢复的任务数量, 操作结果消息)。
|
||||||
|
"""
|
||||||
|
return await self._apply_operation(self._manager.resume_schedule, "恢复")
|
||||||
|
|
||||||
|
async def remove(self) -> tuple[int, str]:
|
||||||
|
"""
|
||||||
|
移除匹配的定时任务
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[int, str]: (成功移除的任务数量, 操作结果消息)。
|
||||||
|
"""
|
||||||
|
schedules = await self._get_schedules()
|
||||||
|
if not schedules:
|
||||||
|
target_desc = self._generate_target_description()
|
||||||
|
return 0, f"没有找到{target_desc}可供移除的任务。"
|
||||||
|
|
||||||
|
for schedule in schedules:
|
||||||
|
APSchedulerAdapter.remove_job(schedule.id)
|
||||||
|
|
||||||
|
query = ScheduleRepository.filter(**self._filters)
|
||||||
|
count = await query.delete()
|
||||||
|
target_desc = self._generate_target_description()
|
||||||
|
return count, f"成功移除了{target_desc} {count} 个任务。"
|
||||||
@ -1,91 +1,94 @@
|
|||||||
import os
|
from collections.abc import AsyncGenerator
|
||||||
import sys
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from nonebot import get_driver
|
from nonebot_plugin_alconna import UniMessage
|
||||||
from playwright.__main__ import main
|
from nonebot_plugin_htmlrender import get_browser
|
||||||
from playwright.async_api import Browser, Playwright, async_playwright
|
from playwright.async_api import Page
|
||||||
|
|
||||||
from zhenxun.configs.config import BotConfig
|
from zhenxun.utils.message import MessageUtils
|
||||||
from zhenxun.services.log import logger
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
|
|
||||||
_playwright: Playwright | None = None
|
|
||||||
_browser: Browser | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# @driver.on_startup
|
class BrowserIsNone(Exception):
|
||||||
# async def start_browser():
|
pass
|
||||||
# global _playwright
|
|
||||||
# global _browser
|
|
||||||
# install()
|
|
||||||
# await check_playwright_env()
|
|
||||||
# _playwright = await async_playwright().start()
|
|
||||||
# _browser = await _playwright.chromium.launch()
|
|
||||||
|
|
||||||
|
|
||||||
# @driver.on_shutdown
|
class AsyncPlaywright:
|
||||||
# async def shutdown_browser():
|
@classmethod
|
||||||
# if _browser:
|
@asynccontextmanager
|
||||||
# await _browser.close()
|
async def new_page(
|
||||||
# if _playwright:
|
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
|
||||||
# await _playwright.stop() # type: ignore
|
) -> AsyncGenerator[Page, None]:
|
||||||
|
"""获取一个新页面
|
||||||
|
|
||||||
|
参数:
|
||||||
# def get_browser() -> Browser:
|
cookies: cookies
|
||||||
# if not _browser:
|
"""
|
||||||
# raise RuntimeError("playwright is not initalized")
|
browser = await get_browser()
|
||||||
# return _browser
|
ctx = await browser.new_context(**kwargs)
|
||||||
|
if cookies:
|
||||||
|
if isinstance(cookies, dict):
|
||||||
def install():
|
cookies = [cookies]
|
||||||
"""自动安装、更新 Chromium"""
|
await ctx.add_cookies(cookies) # type: ignore
|
||||||
|
page = await ctx.new_page()
|
||||||
def set_env_variables():
|
|
||||||
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = (
|
|
||||||
"https://npmmirror.com/mirrors/playwright/"
|
|
||||||
)
|
|
||||||
if BotConfig.system_proxy:
|
|
||||||
os.environ["HTTPS_PROXY"] = BotConfig.system_proxy
|
|
||||||
|
|
||||||
def restore_env_variables():
|
|
||||||
os.environ.pop("PLAYWRIGHT_DOWNLOAD_HOST", None)
|
|
||||||
if BotConfig.system_proxy:
|
|
||||||
os.environ.pop("HTTPS_PROXY", None)
|
|
||||||
if original_proxy is not None:
|
|
||||||
os.environ["HTTPS_PROXY"] = original_proxy
|
|
||||||
|
|
||||||
def try_install_chromium():
|
|
||||||
try:
|
try:
|
||||||
sys.argv = ["", "install", "chromium"]
|
yield page
|
||||||
main()
|
finally:
|
||||||
except SystemExit as e:
|
await page.close()
|
||||||
return e.code == 0
|
await ctx.close()
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info("检查 Chromium 更新")
|
@classmethod
|
||||||
|
async def screenshot(
|
||||||
|
cls,
|
||||||
|
url: str,
|
||||||
|
path: Path | str,
|
||||||
|
element: str | list[str],
|
||||||
|
*,
|
||||||
|
wait_time: int | None = None,
|
||||||
|
viewport_size: dict[str, int] | None = None,
|
||||||
|
wait_until: (
|
||||||
|
Literal["domcontentloaded", "load", "networkidle"] | None
|
||||||
|
) = "networkidle",
|
||||||
|
timeout: float | None = None,
|
||||||
|
type_: Literal["jpeg", "png"] | None = None,
|
||||||
|
user_agent: str | None = None,
|
||||||
|
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> UniMessage | None:
|
||||||
|
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
|
||||||
|
|
||||||
original_proxy = os.environ.get("HTTPS_PROXY")
|
参数:
|
||||||
set_env_variables()
|
url: 网址
|
||||||
|
path: 存储路径
|
||||||
success = try_install_chromium()
|
element: 元素选择
|
||||||
|
wait_time: 等待截取超时时间
|
||||||
if not success:
|
viewport_size: 窗口大小
|
||||||
logger.info("Chromium 更新失败,尝试从原始仓库下载,速度较慢")
|
wait_until: 等待类型
|
||||||
os.environ["PLAYWRIGHT_DOWNLOAD_HOST"] = ""
|
timeout: 超时限制
|
||||||
success = try_install_chromium()
|
type_: 保存类型
|
||||||
|
user_agent: user_agent
|
||||||
restore_env_variables()
|
cookies: cookies
|
||||||
|
"""
|
||||||
if not success:
|
if viewport_size is None:
|
||||||
raise RuntimeError("未知错误,Chromium 下载失败")
|
viewport_size = {"width": 2560, "height": 1080}
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
async def check_playwright_env():
|
wait_time = wait_time * 1000 if wait_time else None
|
||||||
"""检查 Playwright 依赖"""
|
element_list = [element] if isinstance(element, str) else element
|
||||||
logger.info("检查 Playwright 依赖")
|
async with cls.new_page(
|
||||||
try:
|
cookies,
|
||||||
async with async_playwright() as p:
|
viewport=viewport_size,
|
||||||
await p.chromium.launch()
|
user_agent=user_agent,
|
||||||
except Exception as e:
|
**kwargs,
|
||||||
raise ImportError("加载失败,Playwright 依赖不全,") from e
|
) as page:
|
||||||
|
await page.goto(url, timeout=timeout, wait_until=wait_until)
|
||||||
|
card = page
|
||||||
|
for e in element_list:
|
||||||
|
if not card:
|
||||||
|
return None
|
||||||
|
card = await card.wait_for_selector(e, timeout=wait_time)
|
||||||
|
if card:
|
||||||
|
await card.screenshot(path=path, timeout=timeout, type=type_)
|
||||||
|
return MessageUtils.build_message(path)
|
||||||
|
return None
|
||||||
|
|||||||
@ -1,24 +1,226 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from functools import partial, wraps
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from anyio import EndOfStream
|
from anyio import EndOfStream
|
||||||
from httpx import ConnectError, HTTPStatusError, TimeoutException
|
from httpx import (
|
||||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
ConnectError,
|
||||||
|
HTTPStatusError,
|
||||||
|
RemoteProtocolError,
|
||||||
|
StreamError,
|
||||||
|
TimeoutException,
|
||||||
|
)
|
||||||
|
from nonebot.utils import is_coroutine_callable
|
||||||
|
from tenacity import (
|
||||||
|
RetryCallState,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
retry_if_result,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
wait_fixed,
|
||||||
|
)
|
||||||
|
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
LOG_COMMAND = "RetryDecorator"
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def _log_before_sleep(log_name: str | None, retry_state: RetryCallState):
|
||||||
|
"""
|
||||||
|
tenacity 重试前的日志记录回调函数。
|
||||||
|
"""
|
||||||
|
func_name = retry_state.fn.__name__ if retry_state.fn else "unknown_function"
|
||||||
|
log_context = f"函数 '{func_name}'"
|
||||||
|
if log_name:
|
||||||
|
log_context = f"操作 '{log_name}' ({log_context})"
|
||||||
|
|
||||||
|
reason = ""
|
||||||
|
if retry_state.outcome:
|
||||||
|
if exc := retry_state.outcome.exception():
|
||||||
|
reason = f"触发异常: {exc.__class__.__name__}({exc})"
|
||||||
|
else:
|
||||||
|
reason = f"不满足结果条件: result={retry_state.outcome.result()}"
|
||||||
|
|
||||||
|
wait_time = (
|
||||||
|
getattr(retry_state.next_action, "sleep", 0) if retry_state.next_action else 0
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"{log_context} 第 {retry_state.attempt_number} 次重试... "
|
||||||
|
f"等待 {wait_time:.2f} 秒. {reason}",
|
||||||
|
LOG_COMMAND,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Retry:
|
class Retry:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def api(
|
def simple(
|
||||||
retry_count: int = 3, wait: int = 1, exception: tuple[type[Exception], ...] = ()
|
stop_max_attempt: int = 3,
|
||||||
|
wait_fixed_seconds: int = 2,
|
||||||
|
exception: tuple[type[Exception], ...] = (),
|
||||||
|
*,
|
||||||
|
log_name: str | None = None,
|
||||||
|
on_failure: Callable[[Exception], Any] | None = None,
|
||||||
|
return_on_failure: Any = _SENTINEL,
|
||||||
):
|
):
|
||||||
"""接口调用重试"""
|
"""
|
||||||
|
一个简单的、用于通用网络请求的重试装饰器预设。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
stop_max_attempt: 最大重试次数。
|
||||||
|
wait_fixed_seconds: 固定等待策略的等待秒数。
|
||||||
|
exception: 额外需要重试的异常类型元组。
|
||||||
|
log_name: 用于日志记录的操作名称。
|
||||||
|
on_failure: (可选) 所有重试失败后的回调。
|
||||||
|
return_on_failure: (可选) 所有重试失败后的返回值。
|
||||||
|
"""
|
||||||
|
return Retry.api(
|
||||||
|
stop_max_attempt=stop_max_attempt,
|
||||||
|
wait_fixed_seconds=wait_fixed_seconds,
|
||||||
|
exception=exception,
|
||||||
|
strategy="fixed",
|
||||||
|
log_name=log_name,
|
||||||
|
on_failure=on_failure,
|
||||||
|
return_on_failure=return_on_failure,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def download(
|
||||||
|
stop_max_attempt: int = 3,
|
||||||
|
exception: tuple[type[Exception], ...] = (),
|
||||||
|
*,
|
||||||
|
wait_exp_multiplier: int = 2,
|
||||||
|
wait_exp_max: int = 15,
|
||||||
|
log_name: str | None = None,
|
||||||
|
on_failure: Callable[[Exception], Any] | None = None,
|
||||||
|
return_on_failure: Any = _SENTINEL,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
一个适用于文件下载的重试装饰器预设,使用指数退避策略。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
stop_max_attempt: 最大重试次数。
|
||||||
|
exception: 额外需要重试的异常类型元组。
|
||||||
|
wait_exp_multiplier: 指数退避的乘数。
|
||||||
|
wait_exp_max: 指数退避的最大等待时间。
|
||||||
|
log_name: 用于日志记录的操作名称。
|
||||||
|
on_failure: (可选) 所有重试失败后的回调。
|
||||||
|
return_on_failure: (可选) 所有重试失败后的返回值。
|
||||||
|
"""
|
||||||
|
return Retry.api(
|
||||||
|
stop_max_attempt=stop_max_attempt,
|
||||||
|
exception=exception,
|
||||||
|
strategy="exponential",
|
||||||
|
wait_exp_multiplier=wait_exp_multiplier,
|
||||||
|
wait_exp_max=wait_exp_max,
|
||||||
|
log_name=log_name,
|
||||||
|
on_failure=on_failure,
|
||||||
|
return_on_failure=return_on_failure,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def api(
|
||||||
|
stop_max_attempt: int = 3,
|
||||||
|
wait_fixed_seconds: int = 1,
|
||||||
|
exception: tuple[type[Exception], ...] = (),
|
||||||
|
*,
|
||||||
|
strategy: Literal["fixed", "exponential"] = "fixed",
|
||||||
|
retry_on_result: Callable[[Any], bool] | None = None,
|
||||||
|
wait_exp_multiplier: int = 1,
|
||||||
|
wait_exp_max: int = 10,
|
||||||
|
log_name: str | None = None,
|
||||||
|
on_failure: Callable[[Exception], Any] | None = None,
|
||||||
|
return_on_failure: Any = _SENTINEL,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
通用、可配置的API调用重试装饰器。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
stop_max_attempt: 最大重试次数。
|
||||||
|
wait_fixed_seconds: 固定等待策略的等待秒数。
|
||||||
|
exception: 额外需要重试的异常类型元组。
|
||||||
|
strategy: 重试等待策略, 'fixed' (固定) 或 'exponential' (指数退避)。
|
||||||
|
retry_on_result: 一个回调函数,接收函数返回值。如果返回 True,则触发重试。
|
||||||
|
例如 `lambda r: r.status_code != 200`
|
||||||
|
wait_exp_multiplier: 指数退避的乘数。
|
||||||
|
wait_exp_max: 指数退避的最大等待时间。
|
||||||
|
log_name: 用于日志记录的操作名称,方便区分不同的重试场景。
|
||||||
|
on_failure: (可选) 当所有重试都失败后,在抛出异常或返回默认值之前,
|
||||||
|
会调用此函数,并将最终的异常实例作为参数传入。
|
||||||
|
return_on_failure: (可选) 如果设置了此参数,当所有重试失败后,
|
||||||
|
将不再抛出异常,而是返回此参数指定的值。
|
||||||
|
"""
|
||||||
base_exceptions = (
|
base_exceptions = (
|
||||||
TimeoutException,
|
TimeoutException,
|
||||||
ConnectError,
|
ConnectError,
|
||||||
HTTPStatusError,
|
HTTPStatusError,
|
||||||
|
StreamError,
|
||||||
|
RemoteProtocolError,
|
||||||
EndOfStream,
|
EndOfStream,
|
||||||
*exception,
|
*exception,
|
||||||
)
|
)
|
||||||
return retry(
|
|
||||||
reraise=True,
|
def decorator(func: Callable) -> Callable:
|
||||||
stop=stop_after_attempt(retry_count),
|
if strategy == "exponential":
|
||||||
wait=wait_fixed(wait),
|
wait_strategy = wait_exponential(
|
||||||
retry=retry_if_exception_type(base_exceptions),
|
multiplier=wait_exp_multiplier, max=wait_exp_max
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
wait_strategy = wait_fixed(wait_fixed_seconds)
|
||||||
|
|
||||||
|
retry_conditions = retry_if_exception_type(base_exceptions)
|
||||||
|
if retry_on_result:
|
||||||
|
retry_conditions |= retry_if_result(retry_on_result)
|
||||||
|
|
||||||
|
log_callback = partial(_log_before_sleep, log_name)
|
||||||
|
|
||||||
|
tenacity_retry_decorator = retry(
|
||||||
|
stop=stop_after_attempt(stop_max_attempt),
|
||||||
|
wait=wait_strategy,
|
||||||
|
retry=retry_conditions,
|
||||||
|
before_sleep=log_callback,
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
decorated_func = tenacity_retry_decorator(func)
|
||||||
|
|
||||||
|
if return_on_failure is _SENTINEL:
|
||||||
|
return decorated_func
|
||||||
|
|
||||||
|
if is_coroutine_callable(func):
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return await decorated_func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if on_failure:
|
||||||
|
if is_coroutine_callable(on_failure):
|
||||||
|
await on_failure(e)
|
||||||
|
else:
|
||||||
|
on_failure(e)
|
||||||
|
return return_on_failure
|
||||||
|
|
||||||
|
return async_wrapper
|
||||||
|
else:
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return decorated_func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if on_failure:
|
||||||
|
if is_coroutine_callable(on_failure):
|
||||||
|
logger.error(
|
||||||
|
f"不能在同步函数 '{func.__name__}' 中调用异步的 "
|
||||||
|
f"on_failure 回调。",
|
||||||
|
LOG_COMMAND,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
on_failure(e)
|
||||||
|
return return_on_failure
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|||||||
@ -64,3 +64,23 @@ class GoodsNotFound(Exception):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AllURIsFailedError(Exception):
|
||||||
|
"""
|
||||||
|
当所有备用URL都尝试失败后抛出此异常
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, urls: list[str], exceptions: list[Exception]):
|
||||||
|
self.urls = urls
|
||||||
|
self.exceptions = exceptions
|
||||||
|
super().__init__(
|
||||||
|
f"All {len(urls)} URIs failed. Last exception: {exceptions[-1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
exc_info = "\n".join(
|
||||||
|
f" - {url}: {exc.__class__.__name__}({exc})"
|
||||||
|
for url, exc in zip(self.urls, self.exceptions)
|
||||||
|
)
|
||||||
|
return f"All {len(self.urls)} URIs failed:\n{exc_info}"
|
||||||
|
|||||||
@ -1,16 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Sequence
|
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import Any, ClassVar, Literal, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import httpx
|
import httpx
|
||||||
from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response
|
from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Proxy, Response
|
||||||
from nonebot_plugin_alconna import UniMessage
|
import nonebot
|
||||||
from nonebot_plugin_htmlrender import get_browser
|
|
||||||
from playwright.async_api import Page
|
|
||||||
from rich.progress import (
|
from rich.progress import (
|
||||||
BarColumn,
|
BarColumn,
|
||||||
DownloadColumn,
|
DownloadColumn,
|
||||||
@ -18,13 +17,84 @@ from rich.progress import (
|
|||||||
TextColumn,
|
TextColumn,
|
||||||
TransferSpeedColumn,
|
TransferSpeedColumn,
|
||||||
)
|
)
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
from zhenxun.configs.config import BotConfig
|
from zhenxun.configs.config import BotConfig
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.decorator.retry import Retry
|
||||||
|
from zhenxun.utils.exception import AllURIsFailedError
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
from zhenxun.utils.user_agent import get_user_agent
|
from zhenxun.utils.user_agent import get_user_agent
|
||||||
|
|
||||||
CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"]
|
from .browser import AsyncPlaywright, BrowserIsNone # noqa: F401
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
_client: AsyncClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@PriorityLifecycle.on_startup(priority=0)
|
||||||
|
async def _():
|
||||||
|
"""
|
||||||
|
在Bot启动时初始化全局httpx客户端。
|
||||||
|
"""
|
||||||
|
global _client
|
||||||
|
client_kwargs = {}
|
||||||
|
if proxy_url := BotConfig.system_proxy or None:
|
||||||
|
try:
|
||||||
|
version_parts = httpx.__version__.split(".")
|
||||||
|
major = int("".join(c for c in version_parts[0] if c.isdigit()))
|
||||||
|
minor = (
|
||||||
|
int("".join(c for c in version_parts[1] if c.isdigit()))
|
||||||
|
if len(version_parts) > 1
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
if (major, minor) >= (0, 28):
|
||||||
|
client_kwargs["proxy"] = proxy_url
|
||||||
|
else:
|
||||||
|
client_kwargs["proxies"] = proxy_url
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
client_kwargs["proxy"] = proxy_url
|
||||||
|
logger.warning(
|
||||||
|
f"无法解析 httpx 版本 '{httpx.__version__}',"
|
||||||
|
"将默认使用新版 'proxy' 参数语法。"
|
||||||
|
)
|
||||||
|
|
||||||
|
_client = httpx.AsyncClient(
|
||||||
|
headers=get_user_agent(),
|
||||||
|
follow_redirects=True,
|
||||||
|
**client_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("全局 httpx.AsyncClient 已启动。", "HTTPClient")
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
|
async def _():
|
||||||
|
"""
|
||||||
|
在Bot关闭时关闭全局httpx客户端。
|
||||||
|
"""
|
||||||
|
if _client:
|
||||||
|
await _client.aclose()
|
||||||
|
logger.info("全局 httpx.AsyncClient 已关闭。", "HTTPClient")
|
||||||
|
|
||||||
|
|
||||||
|
def get_client() -> AsyncClient:
|
||||||
|
"""
|
||||||
|
获取全局 httpx.AsyncClient 实例。
|
||||||
|
"""
|
||||||
|
global _client
|
||||||
|
if not _client:
|
||||||
|
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||||
|
raise RuntimeError("全局 httpx.AsyncClient 未初始化,请检查启动流程。")
|
||||||
|
# 在测试环境中创建临时客户端
|
||||||
|
logger.warning("在测试环境中创建临时HTTP客户端", "HTTPClient")
|
||||||
|
_client = httpx.AsyncClient(
|
||||||
|
headers=get_user_agent(),
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
def get_async_client(
|
def get_async_client(
|
||||||
@ -33,6 +103,10 @@ def get_async_client(
|
|||||||
verify: bool = False,
|
verify: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> httpx.AsyncClient:
|
) -> httpx.AsyncClient:
|
||||||
|
"""
|
||||||
|
[向后兼容] 创建 httpx.AsyncClient 实例的工厂函数。
|
||||||
|
此函数完全保留了旧版本的接口,确保现有代码无需修改即可使用。
|
||||||
|
"""
|
||||||
transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
|
transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
|
||||||
if proxies:
|
if proxies:
|
||||||
http_proxy = proxies.get("http://")
|
http_proxy = proxies.get("http://")
|
||||||
@ -62,6 +136,30 @@ def get_async_client(
|
|||||||
|
|
||||||
|
|
||||||
class AsyncHttpx:
|
class AsyncHttpx:
|
||||||
|
"""
|
||||||
|
一个高级的、健壮的异步HTTP客户端工具类。
|
||||||
|
|
||||||
|
设计理念:
|
||||||
|
- **全局共享客户端**: 默认情况下,所有请求都通过一个在应用启动时初始化的全局
|
||||||
|
`httpx.AsyncClient` 实例发出。这个实例共享连接池,提高了效率和性能。
|
||||||
|
- **向后兼容与灵活性**: 完全兼容旧的API,同时提供了两种方式来处理需要
|
||||||
|
特殊网络配置(如不同代理、超时)的请求:
|
||||||
|
1. **单次请求覆盖**: 在调用 `get`, `post` 等方法时,直接传入 `proxies`,
|
||||||
|
`timeout` 等参数,将为该次请求创建一个临时的、独立的客户端。
|
||||||
|
2. **临时客户端上下文**: 使用 `temporary_client()` 上下文管理器,可以
|
||||||
|
获取一个独立的、可配置的客户端,用于执行一系列需要相同特殊配置的请求。
|
||||||
|
- **健壮性**: 内置了自动重试、多镜像URL回退(fallback)机制,并提供了便捷的
|
||||||
|
JSON解析和文件下载方法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
CLIENT_KEY: ClassVar[list[str]] = [
|
||||||
|
"use_proxy",
|
||||||
|
"proxies",
|
||||||
|
"proxy",
|
||||||
|
"verify",
|
||||||
|
"headers",
|
||||||
|
]
|
||||||
|
|
||||||
default_proxy: ClassVar[dict[str, str] | None] = (
|
default_proxy: ClassVar[dict[str, str] | None] = (
|
||||||
{
|
{
|
||||||
"http://": BotConfig.system_proxy,
|
"http://": BotConfig.system_proxy,
|
||||||
@ -72,82 +170,157 @@ class AsyncHttpx:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asynccontextmanager
|
def _prepare_temporary_client_config(cls, client_kwargs: dict) -> dict:
|
||||||
async def _create_client(
|
"""
|
||||||
cls,
|
[向后兼容] 处理旧式的客户端kwargs,将其转换为get_async_client可用的配置。
|
||||||
*,
|
主要负责处理 use_proxy 标志,这是为了兼容旧版本代码中使用的 use_proxy 参数。
|
||||||
use_proxy: bool = True,
|
"""
|
||||||
proxies: dict[str, str] | None = None,
|
final_config = client_kwargs.copy()
|
||||||
proxy: str | None = None,
|
|
||||||
headers: dict[str, str] | None = None,
|
|
||||||
verify: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
|
||||||
"""创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。
|
|
||||||
|
|
||||||
说明:
|
use_proxy = final_config.pop("use_proxy", True)
|
||||||
此方法用于内部统一创建客户端,处理代理和请求头逻辑,减少代码重复。
|
|
||||||
|
if "proxies" not in final_config and "proxy" not in final_config:
|
||||||
|
final_config["proxies"] = cls.default_proxy if use_proxy else None
|
||||||
|
return final_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _split_kwargs(cls, kwargs: dict) -> tuple[dict, dict]:
|
||||||
|
"""[优化] 分离客户端配置和请求参数,使逻辑更清晰。"""
|
||||||
|
client_kwargs = {k: v for k, v in kwargs.items() if k in cls.CLIENT_KEY}
|
||||||
|
request_kwargs = {k: v for k, v in kwargs.items() if k not in cls.CLIENT_KEY}
|
||||||
|
return client_kwargs, request_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _get_active_client_context(
|
||||||
|
cls, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""
|
||||||
|
内部辅助方法,根据 kwargs 决定并提供一个活动的 HTTP 客户端。
|
||||||
|
- 如果 kwargs 中有客户端配置,则创建并返回一个临时客户端。
|
||||||
|
- 否则,返回传入的 client 或全局客户端。
|
||||||
|
- 自动处理临时客户端的关闭。
|
||||||
|
"""
|
||||||
|
if kwargs:
|
||||||
|
logger.debug(f"为单次请求创建临时客户端,配置: {kwargs}")
|
||||||
|
temp_client_config = cls._prepare_temporary_client_config(kwargs)
|
||||||
|
async with get_async_client(**temp_client_config) as temp_client:
|
||||||
|
yield temp_client
|
||||||
|
else:
|
||||||
|
yield client or get_client()
|
||||||
|
|
||||||
|
@Retry.simple(log_name="内部HTTP请求")
|
||||||
|
async def _execute_request_inner(
|
||||||
|
self, client: AsyncClient, method: str, url: str, **kwargs
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
[内部] 执行单次HTTP请求的私有核心方法,被重试装饰器包裹。
|
||||||
|
"""
|
||||||
|
return await client.request(method, url, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _single_request(
|
||||||
|
cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
执行单次HTTP请求的私有方法,内置了默认的重试逻辑。
|
||||||
|
"""
|
||||||
|
client_kwargs, request_kwargs = cls._split_kwargs(kwargs)
|
||||||
|
|
||||||
|
async with cls._get_active_client_context(
|
||||||
|
client=client, **client_kwargs
|
||||||
|
) as active_client:
|
||||||
|
response = await cls()._execute_request_inner(
|
||||||
|
active_client, method, url, **request_kwargs
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _execute_with_fallbacks(
|
||||||
|
cls,
|
||||||
|
urls: str | list[str],
|
||||||
|
worker: Callable[..., Awaitable[Any]],
|
||||||
|
*,
|
||||||
|
client: AsyncClient | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
通用执行器,按顺序尝试多个URL,直到成功。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
use_proxy: 是否使用在类中定义的默认代理。
|
urls: 单个URL或URL列表。
|
||||||
proxies: 手动指定的代理,会覆盖默认代理。
|
worker: 一个接受单个URL和其他kwargs并执行请求的协程函数。
|
||||||
proxy: 单个代理,用于兼容旧版本,不再使用
|
client: 可选的HTTP客户端。
|
||||||
headers: 需要合并到客户端的自定义请求头。
|
**kwargs: 传递给worker的额外参数。
|
||||||
verify: 是否验证 SSL 证书。
|
|
||||||
**kwargs: 其他所有传递给 httpx.AsyncClient 的参数。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
AsyncGenerator[httpx.AsyncClient, None]: 生成器。
|
|
||||||
"""
|
"""
|
||||||
proxies_to_use = proxies or (cls.default_proxy if use_proxy else None)
|
url_list = [urls] if isinstance(urls, str) else urls
|
||||||
|
exceptions = []
|
||||||
|
|
||||||
final_headers = get_user_agent()
|
for i, url in enumerate(url_list):
|
||||||
if headers:
|
try:
|
||||||
final_headers.update(headers)
|
result = await worker(url, client=client, **kwargs)
|
||||||
|
if i > 0:
|
||||||
|
logger.info(
|
||||||
|
f"成功从镜像 '{url}' 获取资源 "
|
||||||
|
f"(在尝试了 {i} 个失败的镜像之后)。",
|
||||||
|
"AsyncHttpx:FallbackExecutor",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
exceptions.append(e)
|
||||||
|
if url != url_list[-1]:
|
||||||
|
logger.warning(
|
||||||
|
f"Worker '{worker.__name__}' on {url} failed, trying next. "
|
||||||
|
f"Error: {e.__class__.__name__}",
|
||||||
|
"AsyncHttpx:FallbackExecutor",
|
||||||
|
)
|
||||||
|
|
||||||
async with get_async_client(
|
raise AllURIsFailedError(url_list, exceptions)
|
||||||
proxies=proxies_to_use,
|
|
||||||
proxy=proxy,
|
|
||||||
verify=verify,
|
|
||||||
headers=final_headers,
|
|
||||||
**kwargs,
|
|
||||||
) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(
|
async def get(
|
||||||
cls,
|
cls,
|
||||||
url: str | list[str],
|
url: str | list[str],
|
||||||
*,
|
*,
|
||||||
|
follow_redirects: bool = True,
|
||||||
check_status_code: int | None = None,
|
check_status_code: int | None = None,
|
||||||
|
client: AsyncClient | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Response: # sourcery skip: use-assigned-variable
|
) -> Response:
|
||||||
"""发送 GET 请求,并返回第一个成功的响应。
|
"""发送 GET 请求,并返回第一个成功的响应。
|
||||||
|
|
||||||
说明:
|
说明:
|
||||||
本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的代理管理。
|
本方法是 httpx.get 的高级包装,增加了多链接尝试、自动重试和统一的
|
||||||
如果提供 URL 列表,它将依次尝试直到成功为止。
|
客户端管理。如果提供 URL 列表,它将依次尝试直到成功为止。
|
||||||
|
|
||||||
|
用法建议:
|
||||||
|
- **常规使用**: `await AsyncHttpx.get(url)` 将使用全局客户端。
|
||||||
|
- **单次覆盖配置**: `await AsyncHttpx.get(url, timeout=5, proxies=None)`
|
||||||
|
将为本次请求创建一个独立的临时客户端。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
url: 单个请求 URL 或一个 URL 列表。
|
url: 单个请求 URL 或一个 URL 列表。
|
||||||
|
follow_redirects: 是否跟随重定向。
|
||||||
check_status_code: (可选) 若提供,将检查响应状态码是否匹配,否则抛出异常。
|
check_status_code: (可选) 若提供,将检查响应状态码是否匹配,否则抛出异常。
|
||||||
**kwargs: 其他所有传递给 httpx.get 的参数
|
client: (可选) 指定一个活动的HTTP客户端实例。若提供,则忽略
|
||||||
(如 `params`, `headers`, `timeout`等)。
|
`**kwargs`中的客户端配置。
|
||||||
|
**kwargs: 其他所有传递给 httpx.get 的参数 (如 `params`, `headers`,
|
||||||
|
`timeout`)。如果包含 `proxies`, `verify` 等客户端配置参数,
|
||||||
|
将创建一个临时客户端。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Response: Response
|
Response: httpx 的响应对象。
|
||||||
"""
|
|
||||||
urls = [url] if isinstance(url, str) else url
|
|
||||||
last_exception = None
|
|
||||||
for current_url in urls:
|
|
||||||
try:
|
|
||||||
logger.info(f"开始获取 {current_url}..")
|
|
||||||
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
|
|
||||||
for key in CLIENT_KEY:
|
|
||||||
kwargs.pop(key, None)
|
|
||||||
async with cls._create_client(**client_kwargs) as client:
|
|
||||||
response = await client.get(current_url, **kwargs)
|
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AllURIsFailedError: 当所有提供的URL都请求失败时抛出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def worker(current_url: str, **worker_kwargs) -> Response:
|
||||||
|
logger.info(f"开始获取 {current_url}..", "AsyncHttpx:get")
|
||||||
|
response = await cls._single_request(
|
||||||
|
"GET", current_url, follow_redirects=follow_redirects, **worker_kwargs
|
||||||
|
)
|
||||||
if check_status_code and response.status_code != check_status_code:
|
if check_status_code and response.status_code != check_status_code:
|
||||||
raise HTTPStatusError(
|
raise HTTPStatusError(
|
||||||
f"状态码错误: {response.status_code}!={check_status_code}",
|
f"状态码错误: {response.status_code}!={check_status_code}",
|
||||||
@ -155,117 +328,172 @@ class AsyncHttpx:
|
|||||||
response=response,
|
response=response,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
if current_url != urls[-1]:
|
|
||||||
logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e)
|
|
||||||
|
|
||||||
raise last_exception or Exception("所有URL都获取失败")
|
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def head(cls, url: str, **kwargs) -> Response:
|
async def head(
|
||||||
"""发送 HEAD 请求。
|
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> Response:
|
||||||
|
"""发送 HEAD 请求,并返回第一个成功的响应。"""
|
||||||
|
|
||||||
说明:
|
async def worker(current_url: str, **worker_kwargs) -> Response:
|
||||||
本方法是对 httpx.head 的封装,通常用于检查资源的元信息(如大小、类型)。
|
return await cls._single_request("HEAD", current_url, **worker_kwargs)
|
||||||
|
|
||||||
参数:
|
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
|
||||||
url: 请求的 URL。
|
|
||||||
**kwargs: 其他所有传递给 httpx.head 的参数
|
|
||||||
(如 `headers`, `timeout`, `allow_redirects`)。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
Response: Response
|
|
||||||
"""
|
|
||||||
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
|
|
||||||
for key in CLIENT_KEY:
|
|
||||||
kwargs.pop(key, None)
|
|
||||||
async with cls._create_client(**client_kwargs) as client:
|
|
||||||
return await client.head(url, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def post(cls, url: str, **kwargs) -> Response:
|
async def post(
|
||||||
"""发送 POST 请求。
|
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> Response:
|
||||||
|
"""发送 POST 请求,并返回第一个成功的响应。"""
|
||||||
|
|
||||||
说明:
|
async def worker(current_url: str, **worker_kwargs) -> Response:
|
||||||
本方法是对 httpx.post 的封装,提供了统一的代理和客户端管理。
|
return await cls._single_request("POST", current_url, **worker_kwargs)
|
||||||
|
|
||||||
参数:
|
return await cls._execute_with_fallbacks(url, worker, client=client, **kwargs)
|
||||||
url: 请求的 URL。
|
|
||||||
**kwargs: 其他所有传递给 httpx.post 的参数
|
|
||||||
(如 `data`, `json`, `content` 等)。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
Response: Response。
|
|
||||||
"""
|
|
||||||
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
|
|
||||||
for key in CLIENT_KEY:
|
|
||||||
kwargs.pop(key, None)
|
|
||||||
async with cls._create_client(**client_kwargs) as client:
|
|
||||||
return await client.post(url, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_content(cls, url: str, **kwargs) -> bytes:
|
async def get_content(
|
||||||
"""获取指定 URL 的二进制内容。
|
cls, url: str | list[str], *, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> bytes:
|
||||||
说明:
|
"""获取指定 URL 的二进制内容。"""
|
||||||
这是一个便捷方法,等同于调用 get() 后再访问 .content 属性。
|
res = await cls.get(url, client=client, **kwargs)
|
||||||
|
|
||||||
参数:
|
|
||||||
url: 请求的 URL。
|
|
||||||
**kwargs: 所有传递给 get() 方法的参数。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
bytes: 响应内容的二进制字节流 (bytes)。
|
|
||||||
"""
|
|
||||||
res = await cls.get(url, **kwargs)
|
|
||||||
return res.content
|
return res.content
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def download_file(
|
@Retry.api(
|
||||||
|
log_name="JSON请求",
|
||||||
|
exception=(json.JSONDecodeError,),
|
||||||
|
return_on_failure=_SENTINEL,
|
||||||
|
)
|
||||||
|
async def _request_and_parse_json(
|
||||||
|
cls, method: str, url: str, *, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
[私有] 执行单个HTTP请求并解析JSON,用于内部统一处理。
|
||||||
|
"""
|
||||||
|
async with cls._get_active_client_context(
|
||||||
|
client=client, **kwargs
|
||||||
|
) as active_client:
|
||||||
|
_, request_kwargs = cls._split_kwargs(kwargs)
|
||||||
|
response = await active_client.request(method, url, **request_kwargs)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_json(
|
||||||
cls,
|
cls,
|
||||||
url: str | list[str],
|
url: str | list[str],
|
||||||
path: str | Path,
|
|
||||||
*,
|
*,
|
||||||
stream: bool = False,
|
default: Any = None,
|
||||||
|
raise_on_failure: bool = False,
|
||||||
|
client: AsyncClient | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> bool:
|
) -> Any:
|
||||||
"""下载文件到指定路径。
|
"""
|
||||||
|
发送GET请求并自动解析为JSON,支持重试和多链接尝试。
|
||||||
|
|
||||||
说明:
|
说明:
|
||||||
支持多链接尝试和流式下载(带进度条)。
|
这是一个高度便捷的方法,封装了请求、重试、JSON解析和错误处理。
|
||||||
|
它会在网络错误或JSON解析错误时自动重试。
|
||||||
|
如果所有尝试都失败,它会安全地返回一个默认值。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
url: 单个文件 URL 或一个备用 URL 列表。
|
url: 单个请求 URL 或一个备用 URL 列表。
|
||||||
path: 文件保存的本地路径。
|
default: (可选) 当所有尝试都失败时返回的默认值,默认为None。
|
||||||
stream: (可选) 是否使用流式下载,适用于大文件,默认为 False。
|
raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出
|
||||||
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数。
|
`AllURIsFailedError` 异常, 默认为 False.
|
||||||
|
client: (可选) 指定的HTTP客户端。
|
||||||
|
**kwargs: 其他所有传递给 httpx.get 的参数。
|
||||||
|
例如 `params`, `headers`, `timeout`等。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
bool: 是否下载成功。
|
Any: 解析后的JSON数据,或在失败时返回 `default` 值。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AllURIsFailedError: 当 `raise_on_failure` 为 True 且所有URL都请求失败时抛出
|
||||||
"""
|
"""
|
||||||
path = Path(path)
|
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
urls = [url] if isinstance(url, str) else url
|
async def worker(current_url: str, **worker_kwargs):
|
||||||
|
logger.debug(f"开始GET JSON: {current_url}", "AsyncHttpx:get_json")
|
||||||
|
return await cls._request_and_parse_json(
|
||||||
|
"GET", current_url, **worker_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
for current_url in urls:
|
|
||||||
try:
|
try:
|
||||||
if not stream:
|
result = await cls._execute_with_fallbacks(
|
||||||
response = await cls.get(current_url, **kwargs)
|
url, worker, client=client, **kwargs
|
||||||
response.raise_for_status()
|
)
|
||||||
async with aiofiles.open(path, "wb") as f:
|
return default if result is _SENTINEL else result
|
||||||
await f.write(response.content)
|
except AllURIsFailedError as e:
|
||||||
else:
|
logger.error(f"所有URL的JSON GET均失败: {e}", "AsyncHttpx:get_json")
|
||||||
async with cls._create_client(**kwargs) as client:
|
if raise_on_failure:
|
||||||
stream_kwargs = {
|
raise e
|
||||||
k: v
|
return default
|
||||||
for k, v in kwargs.items()
|
|
||||||
if k not in ["use_proxy", "proxy", "verify"]
|
@classmethod
|
||||||
}
|
async def post_json(
|
||||||
async with client.stream(
|
cls,
|
||||||
"GET", current_url, **stream_kwargs
|
url: str | list[str],
|
||||||
) as response:
|
*,
|
||||||
|
json: Any = None,
|
||||||
|
data: Any = None,
|
||||||
|
default: Any = None,
|
||||||
|
raise_on_failure: bool = False,
|
||||||
|
client: AsyncClient | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
发送POST请求并自动解析为JSON,功能与 get_json 类似。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
url: 单个请求 URL 或一个备用 URL 列表。
|
||||||
|
json: (可选) 作为请求体发送的JSON数据。
|
||||||
|
data: (可选) 作为请求体发送的表单数据。
|
||||||
|
default: (可选) 当所有尝试都失败时返回的默认值,默认为None。
|
||||||
|
raise_on_failure: (可选) 如果为 True, 当所有尝试失败时将抛出
|
||||||
|
AllURIsFailedError 异常, 默认为 False.
|
||||||
|
client: (可选) 指定的HTTP客户端。
|
||||||
|
**kwargs: 其他所有传递给 httpx.post 的参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 解析后的JSON数据,或在失败时返回 `default` 值。
|
||||||
|
"""
|
||||||
|
if json is not None:
|
||||||
|
kwargs["json"] = json
|
||||||
|
if data is not None:
|
||||||
|
kwargs["data"] = data
|
||||||
|
|
||||||
|
async def worker(current_url: str, **worker_kwargs):
|
||||||
|
logger.debug(f"开始POST JSON: {current_url}", "AsyncHttpx:post_json")
|
||||||
|
return await cls._request_and_parse_json(
|
||||||
|
"POST", current_url, **worker_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await cls._execute_with_fallbacks(
|
||||||
|
url, worker, client=client, **kwargs
|
||||||
|
)
|
||||||
|
return default if result is _SENTINEL else result
|
||||||
|
except AllURIsFailedError as e:
|
||||||
|
logger.error(f"所有URL的JSON POST均失败: {e}", "AsyncHttpx:post_json")
|
||||||
|
if raise_on_failure:
|
||||||
|
raise e
|
||||||
|
return default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@Retry.api(log_name="文件下载(流式)")
|
||||||
|
async def _stream_download(
|
||||||
|
cls, url: str, path: Path, *, client: AsyncClient | None = None, **kwargs
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
执行单个流式下载的私有方法,被重试装饰器包裹。
|
||||||
|
"""
|
||||||
|
async with cls._get_active_client_context(
|
||||||
|
client=client, **kwargs
|
||||||
|
) as active_client:
|
||||||
|
async with active_client.stream("GET", url, **kwargs) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
total = int(response.headers.get("Content-Length", 0))
|
total = int(response.headers.get("Content-Length", 0))
|
||||||
|
|
||||||
@ -282,13 +510,56 @@ class AsyncHttpx:
|
|||||||
await f.write(chunk)
|
await f.write(chunk)
|
||||||
progress.update(task_id, advance=len(chunk))
|
progress.update(task_id, advance=len(chunk))
|
||||||
|
|
||||||
logger.info(f"下载 {current_url} 成功 -> {path.absolute()}")
|
@classmethod
|
||||||
|
async def download_file(
|
||||||
|
cls,
|
||||||
|
url: str | list[str],
|
||||||
|
path: str | Path,
|
||||||
|
*,
|
||||||
|
stream: bool = False,
|
||||||
|
client: AsyncClient | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> bool:
|
||||||
|
"""下载文件到指定路径。
|
||||||
|
|
||||||
|
说明:
|
||||||
|
支持多链接尝试和流式下载(带进度条)。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
url: 单个文件 URL 或一个备用 URL 列表。
|
||||||
|
path: 文件保存的本地路径。
|
||||||
|
stream: (可选) 是否使用流式下载,适用于大文件,默认为 False。
|
||||||
|
client: (可选) 指定的HTTP客户端。
|
||||||
|
**kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否下载成功。
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
async def worker(current_url: str, **worker_kwargs) -> bool:
|
||||||
|
if not stream:
|
||||||
|
content = await cls.get_content(current_url, **worker_kwargs)
|
||||||
|
async with aiofiles.open(path, "wb") as f:
|
||||||
|
await f.write(content)
|
||||||
|
else:
|
||||||
|
await cls._stream_download(current_url, path, **worker_kwargs)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"下载 {current_url} 成功 -> {path.absolute()}",
|
||||||
|
"AsyncHttpx:download",
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
try:
|
||||||
logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}")
|
return await cls._execute_with_fallbacks(
|
||||||
|
url, worker, client=client, **kwargs
|
||||||
logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}")
|
)
|
||||||
|
except AllURIsFailedError:
|
||||||
|
logger.error(
|
||||||
|
f"所有URL下载均失败 -> {path.absolute()}", "AsyncHttpx:download"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -346,7 +617,6 @@ class AsyncHttpx:
|
|||||||
logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result)
|
logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result)
|
||||||
final_results.append(False)
|
final_results.append(False)
|
||||||
else:
|
else:
|
||||||
# download_file 返回的是 bool,可以直接附加
|
|
||||||
final_results.append(cast(bool, result))
|
final_results.append(cast(bool, result))
|
||||||
|
|
||||||
return final_results
|
return final_results
|
||||||
@ -395,86 +665,30 @@ class AsyncHttpx:
|
|||||||
_results = sorted(iter(_results), key=lambda r: r["elapsed_time"])
|
_results = sorted(iter(_results), key=lambda r: r["elapsed_time"])
|
||||||
return [result["url"] for result in _results]
|
return [result["url"] for result in _results]
|
||||||
|
|
||||||
|
|
||||||
class AsyncPlaywright:
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def new_page(
|
async def temporary_client(cls, **kwargs) -> AsyncGenerator[AsyncClient, None]:
|
||||||
cls, cookies: list[dict[str, Any]] | dict[str, Any] | None = None, **kwargs
|
"""
|
||||||
) -> AsyncGenerator[Page, None]:
|
创建一个临时的、可配置的HTTP客户端上下文,并直接返回该客户端实例。
|
||||||
"""获取一个新页面
|
|
||||||
|
此方法返回一个标准的 `httpx.AsyncClient`,它不使用全局连接池,
|
||||||
|
拥有独立的配置(如代理、headers、超时等),并在退出上下文后自动关闭。
|
||||||
|
适用于需要用一套特殊网络配置执行一系列请求的场景。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
async with AsyncHttpx.temporary_client(proxies=None, timeout=5) as client:
|
||||||
|
# client 是一个标准的 httpx.AsyncClient 实例
|
||||||
|
response1 = await client.get("http://some.internal.api/1")
|
||||||
|
response2 = await client.get("http://some.internal.api/2")
|
||||||
|
data = response2.json()
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
cookies: cookies
|
**kwargs: 所有传递给 `httpx.AsyncClient` 构造函数的参数。
|
||||||
|
例如: `proxies`, `headers`, `verify`, `timeout`,
|
||||||
|
`follow_redirects`。
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
httpx.AsyncClient: 一个配置好的、临时的客户端实例。
|
||||||
"""
|
"""
|
||||||
browser = await get_browser()
|
async with get_async_client(**kwargs) as client:
|
||||||
ctx = await browser.new_context(**kwargs)
|
yield client
|
||||||
if cookies:
|
|
||||||
if isinstance(cookies, dict):
|
|
||||||
cookies = [cookies]
|
|
||||||
await ctx.add_cookies(cookies) # type: ignore
|
|
||||||
page = await ctx.new_page()
|
|
||||||
try:
|
|
||||||
yield page
|
|
||||||
finally:
|
|
||||||
await page.close()
|
|
||||||
await ctx.close()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def screenshot(
|
|
||||||
cls,
|
|
||||||
url: str,
|
|
||||||
path: Path | str,
|
|
||||||
element: str | list[str],
|
|
||||||
*,
|
|
||||||
wait_time: int | None = None,
|
|
||||||
viewport_size: dict[str, int] | None = None,
|
|
||||||
wait_until: (
|
|
||||||
Literal["domcontentloaded", "load", "networkidle"] | None
|
|
||||||
) = "networkidle",
|
|
||||||
timeout: float | None = None,
|
|
||||||
type_: Literal["jpeg", "png"] | None = None,
|
|
||||||
user_agent: str | None = None,
|
|
||||||
cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> UniMessage | None:
|
|
||||||
"""截图,该方法仅用于简单快捷截图,复杂截图请操作 page
|
|
||||||
|
|
||||||
参数:
|
|
||||||
url: 网址
|
|
||||||
path: 存储路径
|
|
||||||
element: 元素选择
|
|
||||||
wait_time: 等待截取超时时间
|
|
||||||
viewport_size: 窗口大小
|
|
||||||
wait_until: 等待类型
|
|
||||||
timeout: 超时限制
|
|
||||||
type_: 保存类型
|
|
||||||
user_agent: user_agent
|
|
||||||
cookies: cookies
|
|
||||||
"""
|
|
||||||
if viewport_size is None:
|
|
||||||
viewport_size = {"width": 2560, "height": 1080}
|
|
||||||
if isinstance(path, str):
|
|
||||||
path = Path(path)
|
|
||||||
wait_time = wait_time * 1000 if wait_time else None
|
|
||||||
element_list = [element] if isinstance(element, str) else element
|
|
||||||
async with cls.new_page(
|
|
||||||
cookies,
|
|
||||||
viewport=viewport_size,
|
|
||||||
user_agent=user_agent,
|
|
||||||
**kwargs,
|
|
||||||
) as page:
|
|
||||||
await page.goto(url, timeout=timeout, wait_until=wait_until)
|
|
||||||
card = page
|
|
||||||
for e in element_list:
|
|
||||||
if not card:
|
|
||||||
return None
|
|
||||||
card = await card.wait_for_selector(e, timeout=wait_time)
|
|
||||||
if card:
|
|
||||||
await card.screenshot(path=path, timeout=timeout, type=type_)
|
|
||||||
return MessageUtils.build_message(path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class BrowserIsNone(Exception):
|
|
||||||
pass
|
|
||||||
|
|||||||
@ -1,810 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from collections.abc import Callable, Coroutine
|
|
||||||
import copy
|
|
||||||
import inspect
|
|
||||||
import random
|
|
||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
import nonebot
|
|
||||||
from nonebot import get_bots
|
|
||||||
from nonebot_plugin_apscheduler import scheduler
|
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
|
||||||
from zhenxun.models.schedule_info import ScheduleInfo
|
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.common_utils import CommonUtils
|
|
||||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
|
||||||
|
|
||||||
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerManager:
|
|
||||||
"""
|
|
||||||
一个通用的、持久化的定时任务管理器,供所有插件使用。
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registered_tasks: ClassVar[
|
|
||||||
dict[str, dict[str, Callable | type[BaseModel] | None]]
|
|
||||||
] = {}
|
|
||||||
_JOB_PREFIX = "zhenxun_schedule_"
|
|
||||||
_running_tasks: ClassVar[set] = set()
|
|
||||||
|
|
||||||
def register(
|
|
||||||
self, plugin_name: str, params_model: type[BaseModel] | None = None
|
|
||||||
) -> Callable:
|
|
||||||
"""
|
|
||||||
注册一个可调度的任务函数。
|
|
||||||
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 插件的唯一名称 (通常是模块名)。
|
|
||||||
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类,
|
|
||||||
用于定义和验证任务函数接受的额外参数。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
|
||||||
if plugin_name in self._registered_tasks:
|
|
||||||
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
|
|
||||||
self._registered_tasks[plugin_name] = {
|
|
||||||
"func": func,
|
|
||||||
"model": params_model,
|
|
||||||
}
|
|
||||||
model_name = params_model.__name__ if params_model else "无"
|
|
||||||
logger.debug(
|
|
||||||
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
|
|
||||||
)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def get_registered_plugins(self) -> list[str]:
|
|
||||||
"""获取所有已注册定时任务的插件列表。"""
|
|
||||||
return list(self._registered_tasks.keys())
|
|
||||||
|
|
||||||
def _get_job_id(self, schedule_id: int) -> str:
|
|
||||||
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
|
|
||||||
return f"{self._JOB_PREFIX}{schedule_id}"
|
|
||||||
|
|
||||||
async def _execute_job(self, schedule_id: int):
|
|
||||||
"""
|
|
||||||
APScheduler 调度的入口函数。
|
|
||||||
根据 schedule_id 处理特定任务、所有群组任务或全局任务。
|
|
||||||
"""
|
|
||||||
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
|
|
||||||
if not schedule or not schedule.is_enabled:
|
|
||||||
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
|
|
||||||
return
|
|
||||||
|
|
||||||
plugin_name = schedule.plugin_name
|
|
||||||
|
|
||||||
task_meta = self._registered_tasks.get(plugin_name)
|
|
||||||
if not task_meta:
|
|
||||||
logger.error(
|
|
||||||
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
|
|
||||||
)
|
|
||||||
schedule.is_enabled = False
|
|
||||||
await schedule.save(update_fields=["is_enabled"])
|
|
||||||
self._remove_aps_job(schedule.id)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if schedule.bot_id:
|
|
||||||
bot = nonebot.get_bot(schedule.bot_id)
|
|
||||||
else:
|
|
||||||
bot = nonebot.get_bot()
|
|
||||||
logger.debug(
|
|
||||||
f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}"
|
|
||||||
)
|
|
||||||
except KeyError:
|
|
||||||
logger.warning(
|
|
||||||
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
|
|
||||||
f"不在线,本次执行跳过。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
except ValueError:
|
|
||||||
logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。")
|
|
||||||
return
|
|
||||||
|
|
||||||
if schedule.group_id == "__ALL_GROUPS__":
|
|
||||||
await self._execute_for_all_groups(schedule, task_meta, bot)
|
|
||||||
else:
|
|
||||||
await self._execute_for_single_target(schedule, task_meta, bot)
|
|
||||||
|
|
||||||
async def _execute_for_all_groups(
|
|
||||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
|
||||||
):
|
|
||||||
"""为所有群组执行任务,并处理优先级覆盖。"""
|
|
||||||
plugin_name = schedule.plugin_name
|
|
||||||
|
|
||||||
concurrency_limit = Config.get_config(
|
|
||||||
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
|
|
||||||
)
|
|
||||||
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
|
|
||||||
logger.warning(
|
|
||||||
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
|
|
||||||
)
|
|
||||||
concurrency_limit = 5
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"开始执行针对 [所有群组] 的任务 "
|
|
||||||
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id}),"
|
|
||||||
f"并发限制: {concurrency_limit}"
|
|
||||||
)
|
|
||||||
|
|
||||||
all_gids = set()
|
|
||||||
try:
|
|
||||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
|
||||||
all_gids.update(
|
|
||||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
|
|
||||||
return
|
|
||||||
|
|
||||||
specific_tasks_gids = set(
|
|
||||||
await ScheduleInfo.filter(
|
|
||||||
plugin_name=plugin_name, group_id__in=list(all_gids)
|
|
||||||
).values_list("group_id", flat=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
|
||||||
|
|
||||||
async def worker(gid: str):
|
|
||||||
"""使用 Semaphore 包装单个群组的任务执行"""
|
|
||||||
async with semaphore:
|
|
||||||
temp_schedule = copy.deepcopy(schedule)
|
|
||||||
temp_schedule.group_id = gid
|
|
||||||
await self._execute_for_single_target(temp_schedule, task_meta, bot)
|
|
||||||
await asyncio.sleep(random.uniform(0.1, 0.5))
|
|
||||||
|
|
||||||
tasks_to_run = []
|
|
||||||
for gid in all_gids:
|
|
||||||
if gid in specific_tasks_gids:
|
|
||||||
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
|
|
||||||
continue
|
|
||||||
tasks_to_run.append(worker(gid))
|
|
||||||
|
|
||||||
if tasks_to_run:
|
|
||||||
await asyncio.gather(*tasks_to_run)
|
|
||||||
|
|
||||||
async def _execute_for_single_target(
|
|
||||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
|
||||||
):
|
|
||||||
"""为单个目标(具体群组或全局)执行任务。"""
|
|
||||||
plugin_name = schedule.plugin_name
|
|
||||||
group_id = schedule.group_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
|
|
||||||
if is_blocked:
|
|
||||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
|
||||||
logger.info(
|
|
||||||
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
|
|
||||||
"因功能被禁用而跳过执行。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
task_func = task_meta["func"]
|
|
||||||
job_kwargs = schedule.job_kwargs
|
|
||||||
if not isinstance(job_kwargs, dict):
|
|
||||||
logger.error(
|
|
||||||
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
sig = inspect.signature(task_func)
|
|
||||||
if "bot" in sig.parameters:
|
|
||||||
job_kwargs["bot"] = bot
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
|
|
||||||
f"执行定时任务 (ID: {schedule.id})。"
|
|
||||||
)
|
|
||||||
task = asyncio.create_task(task_func(group_id, **job_kwargs))
|
|
||||||
self._running_tasks.add(task)
|
|
||||||
task.add_done_callback(self._running_tasks.discard)
|
|
||||||
await task
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
|
|
||||||
f"目标: {group_id or '全局'}) 时发生异常",
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _validate_and_prepare_kwargs(
|
|
||||||
self, plugin_name: str, job_kwargs: dict | None
|
|
||||||
) -> tuple[bool, str | dict]:
|
|
||||||
"""验证并准备任务参数,应用默认值"""
|
|
||||||
task_meta = self._registered_tasks.get(plugin_name)
|
|
||||||
if not task_meta:
|
|
||||||
return False, f"插件 '{plugin_name}' 未注册。"
|
|
||||||
|
|
||||||
params_model = task_meta.get("model")
|
|
||||||
job_kwargs = job_kwargs if job_kwargs is not None else {}
|
|
||||||
|
|
||||||
if not params_model:
|
|
||||||
if job_kwargs:
|
|
||||||
logger.warning(
|
|
||||||
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
|
|
||||||
)
|
|
||||||
return True, job_kwargs
|
|
||||||
|
|
||||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
|
||||||
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
|
|
||||||
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_validate = getattr(params_model, "model_validate", None)
|
|
||||||
if not model_validate:
|
|
||||||
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
|
|
||||||
|
|
||||||
validated_model = model_validate(job_kwargs)
|
|
||||||
|
|
||||||
model_dump = getattr(validated_model, "model_dump", None)
|
|
||||||
if not model_dump:
|
|
||||||
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
|
|
||||||
|
|
||||||
return True, model_dump()
|
|
||||||
except ValidationError as e:
|
|
||||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
|
||||||
error_str = "\n".join(errors)
|
|
||||||
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
|
||||||
return False, msg
|
|
||||||
|
|
||||||
def _add_aps_job(self, schedule: ScheduleInfo):
|
|
||||||
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
|
|
||||||
job_id = self._get_job_id(schedule.id)
|
|
||||||
try:
|
|
||||||
scheduler.remove_job(job_id)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not isinstance(schedule.trigger_config, dict):
|
|
||||||
logger.error(
|
|
||||||
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
|
|
||||||
f"{type(schedule.trigger_config)}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
scheduler.add_job(
|
|
||||||
self._execute_job,
|
|
||||||
trigger=schedule.trigger_type,
|
|
||||||
id=job_id,
|
|
||||||
misfire_grace_time=300,
|
|
||||||
args=[schedule.id],
|
|
||||||
**schedule.trigger_config,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"已在 APScheduler 中添加/更新任务: {job_id} "
|
|
||||||
f"with trigger: {schedule.trigger_config}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _remove_aps_job(self, schedule_id: int):
|
|
||||||
"""移除一个 APScheduler 任务。"""
|
|
||||||
job_id = self._get_job_id(schedule_id)
|
|
||||||
try:
|
|
||||||
scheduler.remove_job(job_id)
|
|
||||||
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def add_schedule(
|
|
||||||
self,
|
|
||||||
plugin_name: str,
|
|
||||||
group_id: str | None,
|
|
||||||
trigger_type: str,
|
|
||||||
trigger_config: dict,
|
|
||||||
job_kwargs: dict | None = None,
|
|
||||||
bot_id: str | None = None,
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""
|
|
||||||
添加或更新一个定时任务。
|
|
||||||
"""
|
|
||||||
if plugin_name not in self._registered_tasks:
|
|
||||||
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
|
|
||||||
|
|
||||||
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
|
|
||||||
if not is_valid:
|
|
||||||
return False, str(result)
|
|
||||||
|
|
||||||
validated_job_kwargs = result
|
|
||||||
|
|
||||||
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
|
|
||||||
|
|
||||||
search_kwargs = {
|
|
||||||
"plugin_name": plugin_name,
|
|
||||||
"group_id": group_id,
|
|
||||||
}
|
|
||||||
if effective_bot_id:
|
|
||||||
search_kwargs["bot_id"] = effective_bot_id
|
|
||||||
else:
|
|
||||||
search_kwargs["bot_id__isnull"] = True
|
|
||||||
|
|
||||||
defaults = {
|
|
||||||
"trigger_type": trigger_type,
|
|
||||||
"trigger_config": trigger_config,
|
|
||||||
"job_kwargs": validated_job_kwargs,
|
|
||||||
"is_enabled": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
schedule = await ScheduleInfo.filter(**search_kwargs).first()
|
|
||||||
created = False
|
|
||||||
|
|
||||||
if schedule:
|
|
||||||
for key, value in defaults.items():
|
|
||||||
setattr(schedule, key, value)
|
|
||||||
await schedule.save()
|
|
||||||
else:
|
|
||||||
creation_kwargs = {
|
|
||||||
"plugin_name": plugin_name,
|
|
||||||
"group_id": group_id,
|
|
||||||
"bot_id": effective_bot_id,
|
|
||||||
**defaults,
|
|
||||||
}
|
|
||||||
schedule = await ScheduleInfo.create(**creation_kwargs)
|
|
||||||
created = True
|
|
||||||
self._add_aps_job(schedule)
|
|
||||||
action = "设置" if created else "更新"
|
|
||||||
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
|
|
||||||
|
|
||||||
async def add_schedule_for_all(
|
|
||||||
self,
|
|
||||||
plugin_name: str,
|
|
||||||
trigger_type: str,
|
|
||||||
trigger_config: dict,
|
|
||||||
job_kwargs: dict | None = None,
|
|
||||||
) -> tuple[int, int]:
|
|
||||||
"""为所有机器人所在的群组添加定时任务。"""
|
|
||||||
if plugin_name not in self._registered_tasks:
|
|
||||||
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
|
|
||||||
|
|
||||||
groups = set()
|
|
||||||
for bot in get_bots().values():
|
|
||||||
try:
|
|
||||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
|
||||||
groups.update(
|
|
||||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
|
|
||||||
|
|
||||||
success_count = 0
|
|
||||||
fail_count = 0
|
|
||||||
for gid in groups:
|
|
||||||
try:
|
|
||||||
success, _ = await self.add_schedule(
|
|
||||||
plugin_name, gid, trigger_type, trigger_config, job_kwargs
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
success_count += 1
|
|
||||||
else:
|
|
||||||
fail_count += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
|
|
||||||
fail_count += 1
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
return success_count, fail_count
|
|
||||||
|
|
||||||
async def update_schedule(
|
|
||||||
self,
|
|
||||||
schedule_id: int,
|
|
||||||
trigger_type: str | None = None,
|
|
||||||
trigger_config: dict | None = None,
|
|
||||||
job_kwargs: dict | None = None,
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""部分更新一个已存在的定时任务。"""
|
|
||||||
schedule = await self.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule:
|
|
||||||
return False, f"未找到 ID 为 {schedule_id} 的任务。"
|
|
||||||
|
|
||||||
updated_fields = []
|
|
||||||
if trigger_config is not None:
|
|
||||||
schedule.trigger_config = trigger_config
|
|
||||||
updated_fields.append("trigger_config")
|
|
||||||
|
|
||||||
if trigger_type is not None and schedule.trigger_type != trigger_type:
|
|
||||||
schedule.trigger_type = trigger_type
|
|
||||||
updated_fields.append("trigger_type")
|
|
||||||
|
|
||||||
if job_kwargs is not None:
|
|
||||||
if not isinstance(schedule.job_kwargs, dict):
|
|
||||||
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
|
|
||||||
|
|
||||||
merged_kwargs = schedule.job_kwargs.copy()
|
|
||||||
merged_kwargs.update(job_kwargs)
|
|
||||||
|
|
||||||
is_valid, result = self._validate_and_prepare_kwargs(
|
|
||||||
schedule.plugin_name, merged_kwargs
|
|
||||||
)
|
|
||||||
if not is_valid:
|
|
||||||
return False, str(result)
|
|
||||||
|
|
||||||
schedule.job_kwargs = result # type: ignore
|
|
||||||
updated_fields.append("job_kwargs")
|
|
||||||
|
|
||||||
if not updated_fields:
|
|
||||||
return True, "没有任何需要更新的配置。"
|
|
||||||
|
|
||||||
await schedule.save(update_fields=updated_fields)
|
|
||||||
self._add_aps_job(schedule)
|
|
||||||
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
|
|
||||||
|
|
||||||
async def remove_schedule(
|
|
||||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""移除指定插件和群组的定时任务。"""
|
|
||||||
query = {"plugin_name": plugin_name, "group_id": group_id}
|
|
||||||
if bot_id:
|
|
||||||
query["bot_id"] = bot_id
|
|
||||||
|
|
||||||
schedules = await ScheduleInfo.filter(**query)
|
|
||||||
if not schedules:
|
|
||||||
msg = (
|
|
||||||
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
|
|
||||||
f"的插件 '{plugin_name}' 定时任务。"
|
|
||||||
)
|
|
||||||
return (False, msg)
|
|
||||||
|
|
||||||
for schedule in schedules:
|
|
||||||
self._remove_aps_job(schedule.id)
|
|
||||||
await schedule.delete()
|
|
||||||
|
|
||||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
|
||||||
msg = (
|
|
||||||
f"已取消 Bot {bot_id} 在 [{target_desc}] "
|
|
||||||
f"的插件 '{plugin_name}' 所有定时任务。"
|
|
||||||
)
|
|
||||||
return (True, msg)
|
|
||||||
|
|
||||||
async def remove_schedule_for_all(
|
|
||||||
self, plugin_name: str, bot_id: str | None = None
|
|
||||||
) -> int:
|
|
||||||
"""移除指定插件在所有群组的定时任务。"""
|
|
||||||
query = {"plugin_name": plugin_name}
|
|
||||||
if bot_id:
|
|
||||||
query["bot_id"] = bot_id
|
|
||||||
|
|
||||||
schedules_to_delete = await ScheduleInfo.filter(**query).all()
|
|
||||||
if not schedules_to_delete:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
for schedule in schedules_to_delete:
|
|
||||||
self._remove_aps_job(schedule.id)
|
|
||||||
await schedule.delete()
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return len(schedules_to_delete)
|
|
||||||
|
|
||||||
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
|
|
||||||
"""移除指定群组的所有定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(group_id=group_id)
|
|
||||||
if not schedules:
|
|
||||||
return False, f"群 {group_id} 没有任何定时任务。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
self._remove_aps_job(schedule.id)
|
|
||||||
await schedule.delete()
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。"
|
|
||||||
|
|
||||||
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
|
||||||
"""暂停指定群组的所有定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
|
|
||||||
if not schedules:
|
|
||||||
return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.pause_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。"
|
|
||||||
|
|
||||||
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
|
||||||
"""恢复指定群组的所有定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
|
|
||||||
if not schedules:
|
|
||||||
return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.resume_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。"
|
|
||||||
|
|
||||||
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
|
||||||
"""暂停指定插件在所有群组的定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
|
|
||||||
if not schedules:
|
|
||||||
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.pause_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return (
|
|
||||||
count,
|
|
||||||
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
|
||||||
"""恢复指定插件在所有群组的定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
|
|
||||||
if not schedules:
|
|
||||||
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.resume_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return (
|
|
||||||
count,
|
|
||||||
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def pause_schedule_by_plugin_group(
|
|
||||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""暂停指定插件在指定群组的定时任务。"""
|
|
||||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
|
|
||||||
if bot_id:
|
|
||||||
query["bot_id"] = bot_id
|
|
||||||
|
|
||||||
schedules = await ScheduleInfo.filter(**query)
|
|
||||||
if not schedules:
|
|
||||||
return (
|
|
||||||
False,
|
|
||||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
|
|
||||||
)
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.pause_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def resume_schedule_by_plugin_group(
|
|
||||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""恢复指定插件在指定群组的定时任务。"""
|
|
||||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
|
|
||||||
if bot_id:
|
|
||||||
query["bot_id"] = bot_id
|
|
||||||
|
|
||||||
schedules = await ScheduleInfo.filter(**query)
|
|
||||||
if not schedules:
|
|
||||||
return (
|
|
||||||
False,
|
|
||||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
|
|
||||||
)
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.resume_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def remove_all_schedules(self) -> tuple[int, str]:
|
|
||||||
"""移除所有群组的所有定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.all()
|
|
||||||
if not schedules:
|
|
||||||
return 0, "当前没有任何定时任务。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
self._remove_aps_job(schedule.id)
|
|
||||||
await schedule.delete()
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return count, f"已成功移除所有群组的 {count} 个定时任务。"
|
|
||||||
|
|
||||||
async def pause_all_schedules(self) -> tuple[int, str]:
|
|
||||||
"""暂停所有群组的所有定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(is_enabled=True)
|
|
||||||
if not schedules:
|
|
||||||
return 0, "当前没有正在运行的定时任务可暂停。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.pause_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
|
|
||||||
|
|
||||||
async def resume_all_schedules(self) -> tuple[int, str]:
|
|
||||||
"""恢复所有群组的所有定时任务。"""
|
|
||||||
schedules = await ScheduleInfo.filter(is_enabled=False)
|
|
||||||
if not schedules:
|
|
||||||
return 0, "当前没有已暂停的定时任务可恢复。"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
success, _ = await self.resume_schedule(schedule.id)
|
|
||||||
if success:
|
|
||||||
count += 1
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
|
|
||||||
|
|
||||||
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
|
|
||||||
"""通过ID移除指定的定时任务。"""
|
|
||||||
schedule = await self.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule:
|
|
||||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
|
||||||
|
|
||||||
self._remove_aps_job(schedule.id)
|
|
||||||
await schedule.delete()
|
|
||||||
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
|
||||||
f"的定时任务 (ID: {schedule.id})。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
|
|
||||||
"""通过ID获取定时任务信息。"""
|
|
||||||
return await ScheduleInfo.get_or_none(id=schedule_id)
|
|
||||||
|
|
||||||
async def get_schedules(
|
|
||||||
self, plugin_name: str, group_id: str | None
|
|
||||||
) -> list[ScheduleInfo]:
|
|
||||||
"""获取特定群组特定插件的所有定时任务。"""
|
|
||||||
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
|
|
||||||
|
|
||||||
async def get_schedule(
|
|
||||||
self, plugin_name: str, group_id: str | None
|
|
||||||
) -> ScheduleInfo | None:
|
|
||||||
"""获取特定群组的定时任务信息。"""
|
|
||||||
return await ScheduleInfo.get_or_none(
|
|
||||||
plugin_name=plugin_name, group_id=group_id
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_all_schedules(
|
|
||||||
self, plugin_name: str | None = None
|
|
||||||
) -> list[ScheduleInfo]:
|
|
||||||
"""获取所有定时任务信息,可按插件名过滤。"""
|
|
||||||
if plugin_name:
|
|
||||||
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
|
|
||||||
return await ScheduleInfo.all()
|
|
||||||
|
|
||||||
async def get_schedule_status(self, schedule_id: int) -> dict | None:
|
|
||||||
"""获取任务的详细状态。"""
|
|
||||||
schedule = await self.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule:
|
|
||||||
return None
|
|
||||||
|
|
||||||
job_id = self._get_job_id(schedule.id)
|
|
||||||
job = scheduler.get_job(job_id)
|
|
||||||
|
|
||||||
status = {
|
|
||||||
"id": schedule.id,
|
|
||||||
"bot_id": schedule.bot_id,
|
|
||||||
"plugin_name": schedule.plugin_name,
|
|
||||||
"group_id": schedule.group_id,
|
|
||||||
"is_enabled": schedule.is_enabled,
|
|
||||||
"trigger_type": schedule.trigger_type,
|
|
||||||
"trigger_config": schedule.trigger_config,
|
|
||||||
"job_kwargs": schedule.job_kwargs,
|
|
||||||
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
if job and job.next_run_time
|
|
||||||
else "N/A",
|
|
||||||
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
|
|
||||||
}
|
|
||||||
return status
|
|
||||||
|
|
||||||
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
|
||||||
"""暂停一个定时任务。"""
|
|
||||||
schedule = await self.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule or not schedule.is_enabled:
|
|
||||||
return False, "任务不存在或已暂停。"
|
|
||||||
|
|
||||||
schedule.is_enabled = False
|
|
||||||
await schedule.save(update_fields=["is_enabled"])
|
|
||||||
|
|
||||||
job_id = self._get_job_id(schedule.id)
|
|
||||||
try:
|
|
||||||
scheduler.pause_job(job_id)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
|
||||||
f"的定时任务 (ID: {schedule.id})。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
|
||||||
"""恢复一个定时任务。"""
|
|
||||||
schedule = await self.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule or schedule.is_enabled:
|
|
||||||
return False, "任务不存在或已启用。"
|
|
||||||
|
|
||||||
schedule.is_enabled = True
|
|
||||||
await schedule.save(update_fields=["is_enabled"])
|
|
||||||
|
|
||||||
job_id = self._get_job_id(schedule.id)
|
|
||||||
try:
|
|
||||||
scheduler.resume_job(job_id)
|
|
||||||
except Exception:
|
|
||||||
self._add_aps_job(schedule)
|
|
||||||
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
|
||||||
f"的定时任务 (ID: {schedule.id})。",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
|
|
||||||
"""手动触发一个定时任务。"""
|
|
||||||
schedule = await self.get_schedule_by_id(schedule_id)
|
|
||||||
if not schedule:
|
|
||||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
|
||||||
|
|
||||||
if schedule.plugin_name not in self._registered_tasks:
|
|
||||||
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._execute_job(schedule.id)
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
|
||||||
f"的定时任务 (ID: {schedule.id})。",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"手动触发任务失败: {e}")
|
|
||||||
return False, f"手动触发任务失败: {e}"
|
|
||||||
|
|
||||||
|
|
||||||
scheduler_manager = SchedulerManager()
|
|
||||||
|
|
||||||
|
|
||||||
@PriorityLifecycle.on_startup(priority=90)
|
|
||||||
async def _load_schedules_from_db():
|
|
||||||
"""在服务启动时从数据库加载并调度所有任务。"""
|
|
||||||
Config.add_plugin_config(
|
|
||||||
"SchedulerManager",
|
|
||||||
SCHEDULE_CONCURRENCY_KEY,
|
|
||||||
5,
|
|
||||||
help="“所有群组”类型定时任务的并发执行数量限制",
|
|
||||||
type=int,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("正在从数据库加载并调度所有定时任务...")
|
|
||||||
schedules = await ScheduleInfo.filter(is_enabled=True).all()
|
|
||||||
count = 0
|
|
||||||
for schedule in schedules:
|
|
||||||
if schedule.plugin_name in scheduler_manager._registered_tasks:
|
|
||||||
scheduler_manager._add_aps_job(schedule)
|
|
||||||
count += 1
|
|
||||||
else:
|
|
||||||
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
|
|
||||||
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")
|
|
||||||
@ -80,14 +80,14 @@ class PlatformUtils:
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def send_superuser(
|
async def send_superuser(
|
||||||
cls,
|
cls,
|
||||||
bot: Bot,
|
bot: Bot | None,
|
||||||
message: UniMessage | str,
|
message: UniMessage | str,
|
||||||
superuser_id: str | None = None,
|
superuser_id: str | None = None,
|
||||||
) -> list[tuple[str, Receipt]]:
|
) -> list[tuple[str, Receipt]]:
|
||||||
"""发送消息给超级用户
|
"""发送消息给超级用户
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
bot: Bot
|
bot: Bot,没有传入时使用get_bot随机获取
|
||||||
message: 消息
|
message: 消息
|
||||||
superuser_id: 指定超级用户id.
|
superuser_id: 指定超级用户id.
|
||||||
|
|
||||||
@ -97,6 +97,8 @@ class PlatformUtils:
|
|||||||
返回:
|
返回:
|
||||||
Receipt | None: Receipt
|
Receipt | None: Receipt
|
||||||
"""
|
"""
|
||||||
|
if not bot:
|
||||||
|
bot = nonebot.get_bot()
|
||||||
superuser_ids = []
|
superuser_ids = []
|
||||||
if superuser_id:
|
if superuser_id:
|
||||||
superuser_ids.append(superuser_id)
|
superuser_ids.append(superuser_id)
|
||||||
@ -529,9 +531,16 @@ class BroadcastEngine:
|
|||||||
try:
|
try:
|
||||||
self.bot_list.append(nonebot.get_bot(i))
|
self.bot_list.append(nonebot.get_bot(i))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.warning(f"Bot:{i} 对象未连接或不存在")
|
logger.warning(f"Bot:{i} 对象未连接或不存在", log_cmd)
|
||||||
if not self.bot_list:
|
if not self.bot_list:
|
||||||
raise ValueError("当前没有可用的Bot对象...", log_cmd)
|
try:
|
||||||
|
bot = nonebot.get_bot()
|
||||||
|
self.bot_list.append(bot)
|
||||||
|
logger.warning(
|
||||||
|
f"广播任务未传入Bot对象,使用默认Bot {bot.self_id}", log_cmd
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError("当前没有可用的Bot对象...", log_cmd) from e
|
||||||
|
|
||||||
async def call_check(self, bot: Bot, group_id: str) -> bool:
|
async def call_check(self, bot: Bot, group_id: str) -> bool:
|
||||||
"""运行发送检测函数
|
"""运行发送检测函数
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import date, datetime
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
@ -244,3 +244,20 @@ def is_number(text: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TimeUtils:
|
||||||
|
@classmethod
|
||||||
|
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
|
||||||
|
"""获取某天的0点时间
|
||||||
|
|
||||||
|
返回:
|
||||||
|
datetime: 今天某天的0点时间
|
||||||
|
"""
|
||||||
|
if not target_date:
|
||||||
|
target_date = datetime.now()
|
||||||
|
return (
|
||||||
|
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
if isinstance(target_date, datetime)
|
||||||
|
else datetime.combine(target_date, datetime.min.time())
|
||||||
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user