mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
Merge branch 'main' into db/unicode_support
This commit is contained in:
commit
357b3a11d2
14
.env.dev
14
.env.dev
@ -27,6 +27,18 @@ QBOT_ID_DATA = '{
|
||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||
DB_URL = ""
|
||||
|
||||
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
|
||||
CACHE_MODE = NONE
|
||||
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||
# REDIS地址
|
||||
# REDIS_HOST = "127.0.0.1"
|
||||
# REDIS端口
|
||||
# REDIS_PORT = 6379
|
||||
# REDIS密码
|
||||
# REDIS_PASSWORD = ""
|
||||
# REDIS过期时间
|
||||
# REDIS_EXPIRE = 600
|
||||
|
||||
# 系统代理
|
||||
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
||||
|
||||
@ -40,7 +52,7 @@ PLATFORM_SUPERUSERS = '
|
||||
DRIVER=~fastapi+~httpx+~websockets
|
||||
|
||||
|
||||
# LOG_LEVEL=DEBUG
|
||||
# LOG_LEVEL = DEBUG
|
||||
# 服务器和端口
|
||||
HOST = 127.0.0.1
|
||||
PORT = 8080
|
||||
|
||||
951
poetry.lock
generated
951
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -16,7 +16,7 @@ python = "^3.10"
|
||||
playwright = "^1.41.1"
|
||||
nonebot-adapter-onebot = "^2.3.1"
|
||||
nonebot-plugin-apscheduler = "^0.5"
|
||||
tortoise-orm = { extras = ["asyncpg"], version = "^0.20.0" }
|
||||
tortoise-orm = "^0.20.0"
|
||||
cattrs = "^23.2.3"
|
||||
ruamel-yaml = "^0.18.5"
|
||||
strenum = "^0.4.15"
|
||||
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
|
||||
bilireq = "0.2.3post0"
|
||||
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
||||
python-multipart = "^0.0.9"
|
||||
aiocache = "^0.12.2"
|
||||
aiocache = {extras = ["redis"], version = "^0.12.3"}
|
||||
py-cpuinfo = "^9.0.0"
|
||||
nonebot-plugin-alconna = "^0.54.0"
|
||||
tenacity = "^9.0.0"
|
||||
@ -47,6 +47,9 @@ nonebot-plugin-uninfo = ">0.4.1"
|
||||
nonebot-plugin-waiter = "^0.8.1"
|
||||
multidict = ">=6.0.0,!=6.3.2"
|
||||
|
||||
redis = { version = ">=5", optional = true }
|
||||
asyncpg = { version = ">=0.20.0", optional = true }
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
nonebug = "^0.4"
|
||||
pytest-cov = "^5.0.0"
|
||||
@ -57,6 +60,9 @@ respx = "^0.21.1"
|
||||
ruff = "^0.8.0"
|
||||
pre-commit = "^4.0.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
redis = ["redis"]
|
||||
postgresql = ["asyncpg"]
|
||||
|
||||
[tool.nonebot]
|
||||
plugins = [
|
||||
|
||||
@ -50,22 +50,31 @@ async def _(bot: Bot):
|
||||
|
||||
|
||||
SIGN_SQL = """
|
||||
select distinct on("user_id") t1.user_id, t1.checkin_count, t1.add_probability,
|
||||
t1.specify_probability, t1.impression
|
||||
from public.sign_group_users t1
|
||||
join (
|
||||
select user_id, max(t2.impression) as max_impression
|
||||
from public.sign_group_users t2
|
||||
group by user_id
|
||||
) t on t.user_id = t1.user_id and t.max_impression = t1.impression
|
||||
SELECT user_id, checkin_count, add_probability, specify_probability, impression
|
||||
FROM (
|
||||
SELECT
|
||||
t1.user_id,
|
||||
t1.checkin_count,
|
||||
t1.add_probability,
|
||||
t1.specify_probability,
|
||||
t1.impression,
|
||||
ROW_NUMBER() OVER(PARTITION BY t1.user_id ORDER BY t1.impression DESC) AS rn
|
||||
FROM sign_group_users t1
|
||||
INNER JOIN (
|
||||
SELECT user_id, MAX(impression) AS max_impression
|
||||
FROM sign_group_users
|
||||
GROUP BY user_id
|
||||
) t2 ON t2.user_id = t1.user_id AND t2.max_impression = t1.impression
|
||||
) t
|
||||
WHERE rn = 1
|
||||
"""
|
||||
|
||||
BAG_SQL = """
|
||||
select t1.user_id, t1.gold, t1.property
|
||||
from public.bag_users t1
|
||||
from bag_users t1
|
||||
join (
|
||||
select user_id, max(t2.gold) as max_gold
|
||||
from public.bag_users t2
|
||||
from bag_users t2
|
||||
group by user_id
|
||||
) t on t.user_id = t1.user_id and t.max_gold = t1.gold
|
||||
"""
|
||||
|
||||
@ -87,13 +87,17 @@ __plugin_meta__ = PluginMetadata(
|
||||
smart_tools=[
|
||||
AICallableTag(
|
||||
name="call_ban",
|
||||
description="某人多次(至少三次)辱骂你,调用此方法进行封禁",
|
||||
description="如果你讨厌某个人(好感度过低并让你感到困扰,或者多次辱骂你),调用此方法进行封禁,调用该方法后要告知用户被封禁和原因",
|
||||
parameters=AICallableParam(
|
||||
type="object",
|
||||
properties={
|
||||
"user_id": AICallableProperties(
|
||||
type="string", description="用户的id"
|
||||
),
|
||||
"duration": AICallableProperties(
|
||||
type="integer",
|
||||
description="封禁时长(选择的值只能是1-360),单位为分钟,如果频繁触发,按情况增加",
|
||||
),
|
||||
},
|
||||
required=["user_id"],
|
||||
),
|
||||
|
||||
@ -9,14 +9,14 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||
|
||||
|
||||
async def call_ban(user_id: str):
|
||||
async def call_ban(user_id: str, duration: int = 1):
|
||||
"""调用ban
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
"""
|
||||
await BanConsole.ban(user_id, None, 9, 60 * 12)
|
||||
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
|
||||
await BanConsole.ban(user_id, None, 9, duration * 60)
|
||||
logger.info("被讨厌了,已将用户加入黑名单...", "ban", session=user_id)
|
||||
|
||||
|
||||
class BanManage:
|
||||
@ -114,7 +114,7 @@ class BanManage:
|
||||
if not is_superuser and user_id and session.id1:
|
||||
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
||||
if idx:
|
||||
ban_data = await BanConsole.get_or_none(id=idx)
|
||||
ban_data = await BanConsole.get_ban(id=idx)
|
||||
if not ban_data:
|
||||
return False, "该用户/群组不在黑名单中捏..."
|
||||
if ban_data.ban_level > user_level:
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
from zhenxun.utils.exception import GroupInfoNotFound
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
|
||||
@ -116,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
|
||||
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
||||
group = None
|
||||
if group_id:
|
||||
group = await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
if not group:
|
||||
raise GroupInfoNotFound()
|
||||
else:
|
||||
@ -200,26 +201,26 @@ class PluginManager:
|
||||
)
|
||||
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
||||
if group_id:
|
||||
if group := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
module_list = await PluginInfo.filter(
|
||||
plugin_type=PluginType.NORMAL
|
||||
).values_list("module", flat=True)
|
||||
if status:
|
||||
for module in module_list:
|
||||
group.block_plugin = group.block_plugin.replace(
|
||||
f"<{module},", ""
|
||||
if group := await GroupConsole.get_group(group_id=group_id):
|
||||
module_list = cast(
|
||||
list[str],
|
||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
|
||||
"module", flat=True
|
||||
),
|
||||
)
|
||||
if status:
|
||||
# 开启所有功能 - 清空禁用列表
|
||||
group.block_plugin = ""
|
||||
else:
|
||||
module_list = [f"<{module}" for module in module_list]
|
||||
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
||||
# 关闭所有功能 - 将模块列表转换为禁用格式
|
||||
group.block_plugin = CommonUtils.convert_module_format(module_list)
|
||||
await group.save(update_fields=["block_plugin"])
|
||||
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
||||
return "获取群组失败..."
|
||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
||||
status=status, block_type=None if status else BlockType.ALL
|
||||
)
|
||||
await CacheRoot.invalidate_cache(CacheType.PLUGINS)
|
||||
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
|
||||
|
||||
@classmethod
|
||||
@ -232,9 +233,7 @@ class PluginManager:
|
||||
返回:
|
||||
bool: 是否醒来
|
||||
"""
|
||||
if c := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
if c := await GroupConsole.get_group(group_id=group_id):
|
||||
return c.status
|
||||
return False
|
||||
|
||||
@ -245,9 +244,11 @@ class PluginManager:
|
||||
参数:
|
||||
group_id: 群组id
|
||||
"""
|
||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
||||
status=False
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group.status = False
|
||||
await group.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def wake(cls, group_id: str):
|
||||
@ -256,9 +257,11 @@ class PluginManager:
|
||||
参数:
|
||||
group_id: 群组id
|
||||
"""
|
||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
||||
status=True
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group.status = True
|
||||
await group.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def block(cls, module: str):
|
||||
@ -267,7 +270,9 @@ class PluginManager:
|
||||
参数:
|
||||
module: 模块名
|
||||
"""
|
||||
await PluginInfo.filter(module=module).update(status=False)
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
plugin.status = False
|
||||
await plugin.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def unblock(cls, module: str):
|
||||
@ -276,7 +281,9 @@ class PluginManager:
|
||||
参数:
|
||||
module: 模块名
|
||||
"""
|
||||
await PluginInfo.filter(module=module).update(status=True)
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
plugin.status = True
|
||||
await plugin.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:
|
||||
@ -437,17 +444,18 @@ class PluginManager:
|
||||
"""
|
||||
status_str = "关闭" if status else "开启"
|
||||
if is_all:
|
||||
modules = await TaskInfo.annotate().values_list("module", flat=True)
|
||||
if modules:
|
||||
module_list = cast(
|
||||
list[str], await TaskInfo.annotate().values_list("module", flat=True)
|
||||
)
|
||||
if module_list:
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
modules = [f"<{module}" for module in modules]
|
||||
if status:
|
||||
group.block_task = ",".join(modules) + "," # type: ignore
|
||||
group.block_task = CommonUtils.convert_module_format(module_list)
|
||||
else:
|
||||
for module in modules:
|
||||
group.block_task = group.block_task.replace(f"{module},", "")
|
||||
# 开启所有模块 - 清空禁用列表
|
||||
group.block_task = ""
|
||||
await group.save(update_fields=["block_task"])
|
||||
return f"已成功{status_str}全部被动技能!"
|
||||
elif task := await TaskInfo.get_or_none(name=task_name):
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
from nonebot import on_message
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.models.chat_history import ChatHistory
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="消息存储",
|
||||
@ -37,18 +39,34 @@ def rule(message: UniMsg) -> bool:
|
||||
|
||||
chat_history = on_message(rule=rule, priority=1, block=False)
|
||||
|
||||
TEMP_LIST = []
|
||||
|
||||
|
||||
@chat_history.handle()
|
||||
async def handle_message(message: UniMsg, session: EventSession):
|
||||
"""处理消息存储"""
|
||||
try:
|
||||
await ChatHistory.create(
|
||||
user_id=session.id1,
|
||||
group_id=session.id2,
|
||||
async def _(message: UniMsg, session: Uninfo):
|
||||
entity = get_entity_ids(session)
|
||||
TEMP_LIST.append(
|
||||
ChatHistory(
|
||||
user_id=entity.user_id,
|
||||
group_id=entity.group_id,
|
||||
text=str(message),
|
||||
plain_text=message.extract_plain_text(),
|
||||
bot_id=session.bot_id,
|
||||
bot_id=session.self_id,
|
||||
platform=session.platform,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@scheduler.scheduled_job(
|
||||
"interval",
|
||||
minutes=1,
|
||||
)
|
||||
async def _():
|
||||
try:
|
||||
message_list = TEMP_LIST.copy()
|
||||
TEMP_LIST.clear()
|
||||
if message_list:
|
||||
await ChatHistory.bulk_create(message_list)
|
||||
logger.debug(f"批量添加聊天记录 {len(message_list)} 条", "定时任务")
|
||||
except Exception as e:
|
||||
logger.warning("存储聊天记录失败", "chat_history", e=e)
|
||||
|
||||
@ -45,11 +45,13 @@ async def classify_plugin(
|
||||
"""
|
||||
sort_data = await sort_type()
|
||||
classify: dict[str, list] = {}
|
||||
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
|
||||
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
||||
for menu, value in sort_data.items():
|
||||
for plugin in value:
|
||||
if not classify.get(menu):
|
||||
classify[menu] = []
|
||||
classify[menu].append(handle(bot, plugin, group, is_detail))
|
||||
for value in classify.values():
|
||||
value.sort(key=lambda x: x.id)
|
||||
return classify
|
||||
|
||||
@ -21,6 +21,8 @@ class Item(BaseModel):
|
||||
"""插件名称"""
|
||||
sta: int
|
||||
"""插件状态"""
|
||||
id: int
|
||||
"""插件id"""
|
||||
|
||||
|
||||
class PluginList(BaseModel):
|
||||
@ -80,10 +82,9 @@ def __handle_item(
|
||||
sta = 2
|
||||
if f"{plugin.module}," in group.block_plugin:
|
||||
sta = 1
|
||||
if bot:
|
||||
if f"{plugin.module}," in bot.block_plugins:
|
||||
if bot and f"{plugin.module}," in bot.block_plugins:
|
||||
sta = 2
|
||||
return Item(plugin_name=plugin.name, sta=sta)
|
||||
return Item(plugin_name=plugin.name, sta=sta, id=plugin.id)
|
||||
|
||||
|
||||
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
@ -142,7 +143,7 @@ async def build_html_image(
|
||||
template_name="zhenxun_menu.html",
|
||||
templates={"plugin_list": plugin_list},
|
||||
pages={
|
||||
"viewport": {"width": 1903, "height": 975},
|
||||
"viewport": {"width": 1903, "height": 10},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
|
||||
@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
||||
color="black" if idx % 2 else "white",
|
||||
)
|
||||
curr_h = 10
|
||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
||||
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||
for _, plugin in enumerate(plugin_list):
|
||||
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
||||
if group and f"{plugin.module}," in group.block_plugin:
|
||||
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
||||
width, height = 10, 10
|
||||
for s in [
|
||||
"目前支持的功能列表:",
|
||||
"可以通过 ‘帮助 [功能名称或功能Id]’ 来获取对应功能的使用方法",
|
||||
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
|
||||
]:
|
||||
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
||||
await result.paste(text, (width, height))
|
||||
|
||||
@ -20,6 +20,12 @@ class Item(BaseModel):
|
||||
"""插件名称"""
|
||||
commands: list[str]
|
||||
"""插件命令"""
|
||||
id: str
|
||||
"""插件id"""
|
||||
status: bool
|
||||
"""插件状态"""
|
||||
has_superuser_help: bool
|
||||
"""插件是否拥有超级用户帮助"""
|
||||
|
||||
|
||||
def __handle_item(
|
||||
@ -39,23 +45,36 @@ def __handle_item(
|
||||
返回:
|
||||
Item: Item
|
||||
"""
|
||||
status = True
|
||||
has_superuser_help = False
|
||||
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||
if nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
||||
if extra_data.superuser_help:
|
||||
has_superuser_help = True
|
||||
if not plugin.status:
|
||||
if plugin.block_type == BlockType.ALL:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif group and plugin.block_type == BlockType.GROUP:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif not group and plugin.block_type == BlockType.PRIVATE:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif group and f"{plugin.module}," in group.block_plugin:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif bot and f"{plugin.module}," in bot.block_plugins:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
commands = []
|
||||
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||
if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
||||
commands = [cmd.command for cmd in extra_data.commands]
|
||||
return Item(plugin_name=f"{plugin.id}-{plugin.name}", commands=commands)
|
||||
return Item(
|
||||
plugin_name=plugin.name,
|
||||
commands=commands,
|
||||
id=str(plugin.id),
|
||||
status=status,
|
||||
has_superuser_help=has_superuser_help,
|
||||
)
|
||||
|
||||
|
||||
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
@ -78,68 +97,10 @@ def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
}
|
||||
for menu, value in classify.items()
|
||||
]
|
||||
plugin_list = build_line_data(plugin_list)
|
||||
plugin_list.insert(
|
||||
0,
|
||||
build_plugin_line(
|
||||
menu_key if menu_key not in ["normal", "功能"] else "主要功能",
|
||||
max_data,
|
||||
30,
|
||||
100,
|
||||
True,
|
||||
),
|
||||
)
|
||||
return plugin_list
|
||||
|
||||
|
||||
def build_plugin_line(
|
||||
name: str, items: list, left: int, width: int | None = None, is_max: bool = False
|
||||
) -> dict:
|
||||
"""构造插件行数据
|
||||
|
||||
参数:
|
||||
name: 菜单名称
|
||||
items: 插件名称列表
|
||||
left: 左边距
|
||||
width: 总插件长度.
|
||||
is_max: 是否为最大长度的插件菜单
|
||||
|
||||
返回:
|
||||
dict: 插件数据
|
||||
"""
|
||||
_plugins = []
|
||||
width = width or 50
|
||||
if len(items) // 2 > 6 or is_max:
|
||||
width = 100
|
||||
plugin_list1 = []
|
||||
plugin_list2 = []
|
||||
for i in range(len(items)):
|
||||
if i % 2:
|
||||
plugin_list1.append(items[i])
|
||||
else:
|
||||
plugin_list2.append(items[i])
|
||||
_plugins = [(30, 50, plugin_list1), (0, 50, plugin_list2)]
|
||||
else:
|
||||
_plugins = [(left, 100, items)]
|
||||
return {"name": name, "items": _plugins, "width": width}
|
||||
|
||||
|
||||
def build_line_data(plugin_list: list[dict]) -> list[dict]:
|
||||
"""构造插件数据
|
||||
|
||||
参数:
|
||||
plugin_list: 插件列表
|
||||
|
||||
返回:
|
||||
list[dict]: 插件数据
|
||||
"""
|
||||
left = 30
|
||||
data = []
|
||||
plugin_list.insert(0, {"name": menu_key, "items": max_data})
|
||||
for plugin in plugin_list:
|
||||
data.append(build_plugin_line(plugin["name"], plugin["items"], left))
|
||||
if len(plugin["items"]) // 2 <= 6:
|
||||
left = 15 if left == 30 else 30
|
||||
return data
|
||||
plugin["items"].sort(key=lambda x: x.id)
|
||||
return plugin_list
|
||||
|
||||
|
||||
async def build_zhenxun_image(
|
||||
@ -160,6 +121,7 @@ async def build_zhenxun_image(
|
||||
width = int(637 * 1.5) if is_detail else 637
|
||||
title_font = int(53 * 1.5) if is_detail else 53
|
||||
tip_font = int(19 * 1.5) if is_detail else 19
|
||||
plugin_count = sum(len(plugin["items"]) for plugin in plugin_list)
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "ss_menu").absolute()),
|
||||
template_name="main.html",
|
||||
@ -170,10 +132,11 @@ async def build_zhenxun_image(
|
||||
"width": width,
|
||||
"font_size": (title_font, tip_font),
|
||||
"is_detail": is_detail,
|
||||
"plugin_count": plugin_count,
|
||||
}
|
||||
},
|
||||
pages={
|
||||
"viewport": {"width": width, "height": 453},
|
||||
"viewport": {"width": width, "height": 10},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
|
||||
@ -1,597 +0,0 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At, UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from pydantic import BaseModel
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import (
|
||||
BlockType,
|
||||
GoldHandle,
|
||||
LimitWatchType,
|
||||
PluginLimitType,
|
||||
PluginType,
|
||||
)
|
||||
from zhenxun.utils.exception import InsufficientGold
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||
|
||||
base_config = Config.get("hook")
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManage:
|
||||
add_module: ClassVar[list] = []
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
|
||||
参数:
|
||||
limit: PluginLimit
|
||||
"""
|
||||
if limit.module not in cls.add_module:
|
||||
cls.add_module.append(limit.module)
|
||||
if limit.limit_type == PluginLimitType.BLOCK:
|
||||
cls.block_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=UserBlockLimiter()
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.CD:
|
||||
cls.cd_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.COUNT:
|
||||
cls.count_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unblock(
|
||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||
):
|
||||
"""解除插件block
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
"""
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
limit = limit_model.limit
|
||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
logger.debug(
|
||||
f"解除对象: {key_type} 的block限制",
|
||||
"AuthChecker",
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
limiter.set_false(key_type)
|
||||
|
||||
@classmethod
|
||||
async def check(
|
||||
cls,
|
||||
module: str,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: EventSession,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
cls,
|
||||
limit_model: Limit | None,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: EventSession,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
limit_model: Limit
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if not limit_model:
|
||||
return
|
||||
limit = limit_model.limit
|
||||
limiter = limit_model.limiter
|
||||
is_limit = (
|
||||
LimitWatchType.ALL
|
||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||
)
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
await MessageUtils.build_message(limit.result).send()
|
||||
logger.debug(
|
||||
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{limit.module} 正在限制中...")
|
||||
else:
|
||||
logger.debug(
|
||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||
"AuthChecker",
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(key_type)
|
||||
if isinstance(limiter, UserBlockLimiter):
|
||||
limiter.set_true(key_type)
|
||||
if isinstance(limiter, CountLimiter):
|
||||
limiter.increase(key_type)
|
||||
|
||||
|
||||
class IsSuperuserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthChecker:
|
||||
"""
|
||||
权限检查
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||
|
||||
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
|
||||
"""是否发送提示消息
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
返回:
|
||||
bool: 是否发送提示消息
|
||||
"""
|
||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
if plugin.ignore_prompt:
|
||||
return False
|
||||
return self._flmt_s.check(sid)
|
||||
|
||||
async def auth(
|
||||
self,
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
session: EventSession,
|
||||
message: UniMsg,
|
||||
):
|
||||
"""权限检查
|
||||
|
||||
参数:
|
||||
matcher: matcher
|
||||
bot: bot
|
||||
session: EventSession
|
||||
message: UniMsg
|
||||
"""
|
||||
is_ignore = False
|
||||
cost_gold = 0
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
||||
"""过滤除poke外的notice"""
|
||||
return
|
||||
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
|
||||
try:
|
||||
user = await UserConsole.get_user(user_id, session.platform)
|
||||
except IntegrityError as e:
|
||||
logger.debug(
|
||||
"重复创建用户,已跳过该次权限...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
return
|
||||
if plugin := await PluginInfo.get_or_none(module_path=module_path):
|
||||
if plugin.plugin_type == PluginType.HIDDEN:
|
||||
logger.debug(
|
||||
f"插件: {plugin.name}:{plugin.module} "
|
||||
"为HIDDEN,已跳过权限检查..."
|
||||
)
|
||||
return
|
||||
try:
|
||||
cost_gold = await self.auth_cost(user, plugin, session)
|
||||
if session.id1 in bot.config.superusers:
|
||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||
raise IsSuperuserException()
|
||||
if not plugin.limit_superuser:
|
||||
cost_gold = 0
|
||||
raise IsSuperuserException()
|
||||
await self.auth_bot(plugin, bot.self_id)
|
||||
await self.auth_group(plugin, session, message)
|
||||
await self.auth_admin(plugin, session)
|
||||
await self.auth_plugin(plugin, session, event)
|
||||
await self.auth_limit(plugin, session)
|
||||
except IsSuperuserException:
|
||||
logger.debug(
|
||||
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
||||
)
|
||||
except IgnoredException:
|
||||
is_ignore = True
|
||||
LimitManage.unblock(
|
||||
matcher.plugin.name, user_id, group_id, channel_id
|
||||
)
|
||||
except AssertionError as e:
|
||||
is_ignore = True
|
||||
logger.debug("消息无法发送", session=session, e=e)
|
||||
if cost_gold and user_id:
|
||||
"""花费金币"""
|
||||
try:
|
||||
await UserConsole.reduce_gold(
|
||||
user_id,
|
||||
cost_gold,
|
||||
GoldHandle.PLUGIN,
|
||||
matcher.plugin.name if matcher.plugin else "",
|
||||
session.platform,
|
||||
)
|
||||
except InsufficientGold:
|
||||
if u := await UserConsole.get_user(user_id):
|
||||
u.gold = 0
|
||||
await u.save(update_fields=["gold"])
|
||||
logger.debug(
|
||||
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
|
||||
)
|
||||
if is_ignore:
|
||||
raise IgnoredException("权限检测 ignore")
|
||||
|
||||
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
|
||||
"""机器人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
bot_id: bot_id
|
||||
"""
|
||||
if not await BotConsole.get_bot_status(bot_id):
|
||||
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
|
||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
||||
if await BotConsole.is_block_plugin(bot_id, plugin.module):
|
||||
logger.debug(
|
||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
||||
"AuthChecker",
|
||||
)
|
||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
||||
|
||||
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
|
||||
"""插件限制
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if plugin.module not in LimitManage.add_module:
|
||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||
status=True
|
||||
).all() # type: ignore
|
||||
for limit in limit_list:
|
||||
LimitManage.add_limit(limit)
|
||||
if user_id:
|
||||
await LimitManage.check(
|
||||
plugin.module, user_id, group_id, channel_id, session
|
||||
)
|
||||
|
||||
async def auth_plugin(
|
||||
self, plugin: PluginInfo, session: EventSession, event: Event
|
||||
):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if user_id := session.id1:
|
||||
if plugin.impression > 0:
|
||||
sign_user = await SignUser.get_user(user_id)
|
||||
if float(sign_user.impression) < plugin.impression:
|
||||
if self.is_send_limit_message(plugin, user_id):
|
||||
self._flmt_s.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
f"好感度不足哦,当前功能需要好感度: {plugin.impression},"
|
||||
"请继续签到提升好感度吧!"
|
||||
).send(reply_to=True)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 用户好感度不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("好感度不足...")
|
||||
if group_id:
|
||||
sid = group_id or user_id
|
||||
if await GroupConsole.is_superuser_block_plugin(
|
||||
group_id, plugin.module
|
||||
):
|
||||
"""超级用户群组插件状态"""
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message(
|
||||
"超级管理员禁用了该群此功能..."
|
||||
).send(reply_to=True)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("超级管理员禁用了该群此功能...")
|
||||
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
|
||||
"""群组插件状态"""
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message("该群未开启此功能...").send(
|
||||
reply_to=True
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该群未开启此功能...")
|
||||
if plugin.block_type == BlockType.GROUP:
|
||||
"""全局群组禁用"""
|
||||
try:
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_c.start_cd(group_id)
|
||||
await MessageUtils.build_message(
|
||||
"该功能在群组中已被禁用..."
|
||||
).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_plugin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该插件在群组中已被禁用...")
|
||||
else:
|
||||
sid = user_id
|
||||
if plugin.block_type == BlockType.PRIVATE:
|
||||
"""全局私聊禁用"""
|
||||
try:
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_c.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
"该功能在私聊中已被禁用..."
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该插件在私聊中已被禁用...")
|
||||
if not plugin.status and plugin.block_type == BlockType.ALL:
|
||||
"""全局状态"""
|
||||
if group_id and await GroupConsole.is_super_group(group_id):
|
||||
raise IsSuperuserException()
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message("全局未开启此功能...").send()
|
||||
raise IgnoredException("全局未开启此功能...")
|
||||
|
||||
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
|
||||
"""管理员命令 个人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
user_id = session.id1
|
||||
if user_id and plugin.admin_level:
|
||||
if group_id := session.id3 or session.id2:
|
||||
if not await LevelUser.check_level(
|
||||
user_id, group_id, plugin.admin_level
|
||||
):
|
||||
try:
|
||||
if self._flmt.check(user_id):
|
||||
self._flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"你的权限不足喔,"
|
||||
f"该功能需要的权限等级: {plugin.admin_level}",
|
||||
]
|
||||
).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("管理员权限不足...")
|
||||
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
|
||||
try:
|
||||
await MessageUtils.build_message(
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("权限不足")
|
||||
|
||||
async def auth_group(
|
||||
self, plugin: PluginInfo, session: EventSession, message: UniMsg
|
||||
):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
message: UniMsg
|
||||
"""
|
||||
if not (group_id := session.id3 or session.id2):
|
||||
return
|
||||
text = message.extract_plain_text()
|
||||
group = await GroupConsole.get_group(group_id)
|
||||
if not group:
|
||||
"""群不存在"""
|
||||
logger.debug(
|
||||
"群组信息不存在...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("群不存在")
|
||||
if group.level < 0:
|
||||
"""群权限小于0"""
|
||||
logger.debug(
|
||||
"群黑名单, 群权限-1...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("群黑名单")
|
||||
if not group.status:
|
||||
"""群休眠"""
|
||||
if text.strip() != "醒来":
|
||||
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
||||
raise IgnoredException("群休眠状态")
|
||||
if plugin.level > group.level:
|
||||
"""插件等级大于群等级"""
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 群等级限制.."
|
||||
f"该功能需要的群等级: {plugin.level}..",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
||||
|
||||
async def auth_cost(
|
||||
self, user: UserConsole, plugin: PluginInfo, session: EventSession
|
||||
) -> int:
|
||||
"""检测是否满足金币条件
|
||||
|
||||
参数:
|
||||
user: UserConsole
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
|
||||
返回:
|
||||
int: 需要消耗的金币
|
||||
"""
|
||||
if user.gold < plugin.cost_gold:
|
||||
"""插件消耗金币不足"""
|
||||
try:
|
||||
await MessageUtils.build_message(
|
||||
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 金币限制.."
|
||||
f"该功能需要{plugin.cost_gold}金币..",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||
return plugin.cost_gold
|
||||
|
||||
|
||||
checker = AuthChecker()
|
||||
99
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
99
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
|
||||
|
||||
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
||||
"""管理员命令 个人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not plugin.admin_level:
|
||||
return
|
||||
|
||||
try:
|
||||
entity = get_entity_ids(session)
|
||||
level_dao = DataAccess(LevelUser)
|
||||
|
||||
# 并行查询用户权限数据
|
||||
global_user: LevelUser | None = None
|
||||
group_users: LevelUser | None = None
|
||||
|
||||
# 查询全局权限
|
||||
global_user_task = level_dao.safe_get_or_none(
|
||||
user_id=session.user.id, group_id__isnull=True
|
||||
)
|
||||
|
||||
# 如果在群组中,查询群组权限
|
||||
group_users_task = None
|
||||
if entity.group_id:
|
||||
group_users_task = level_dao.safe_get_or_none(
|
||||
user_id=session.user.id, group_id=entity.group_id
|
||||
)
|
||||
|
||||
# 等待查询完成,添加超时控制
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(global_user_task, group_users_task or asyncio.sleep(0)),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
global_user = results[0]
|
||||
group_users = results[1] if group_users_task else None
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询用户权限超时: user_id={session.user.id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
user_level = global_user.user_level if global_user else 0
|
||||
if entity.group_id and group_users:
|
||||
user_level = max(user_level, group_users.user_level)
|
||||
|
||||
if user_level < plugin.admin_level:
|
||||
await send_message(
|
||||
session,
|
||||
[
|
||||
At(flag="user", target=session.user.id),
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||
],
|
||||
entity.user_id,
|
||||
)
|
||||
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||
)
|
||||
elif global_user:
|
||||
if global_user.user_level < plugin.admin_level:
|
||||
await send_message(
|
||||
session,
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||
)
|
||||
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_admin 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
303
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
303
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
@ -0,0 +1,303 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
from .utils import freq, send_message
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"BAN_RESULT",
|
||||
"才不会给你发消息.",
|
||||
help="对被ban用户发送的消息",
|
||||
)
|
||||
|
||||
|
||||
def calculate_ban_time(ban_record: BanConsole | None) -> int:
|
||||
"""根据ban记录计算剩余ban时间
|
||||
|
||||
参数:
|
||||
ban_record: BanConsole记录
|
||||
|
||||
返回:
|
||||
int: ban剩余时长,-1时为永久ban,0表示未被ban
|
||||
"""
|
||||
if not ban_record:
|
||||
return 0
|
||||
|
||||
if ban_record.duration == -1:
|
||||
return -1
|
||||
|
||||
_time = time.time() - (ban_record.ban_time + ban_record.duration)
|
||||
return 0 if _time > 0 else int(abs(_time))
|
||||
|
||||
|
||||
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||
"""检查用户或群组是否被ban
|
||||
|
||||
参数:
|
||||
user_id: 用户ID
|
||||
group_id: 群组ID
|
||||
|
||||
返回:
|
||||
int: ban的剩余时间,0表示未被ban
|
||||
"""
|
||||
if not user_id and not group_id:
|
||||
return 0
|
||||
|
||||
start_time = time.time()
|
||||
ban_dao = DataAccess(BanConsole)
|
||||
|
||||
# 分别获取用户在群组中的ban记录和全局ban记录
|
||||
group_user = None
|
||||
user = None
|
||||
|
||||
try:
|
||||
# 并行查询用户和群组的 ban 记录
|
||||
tasks = []
|
||||
if user_id and group_id:
|
||||
tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id))
|
||||
if user_id:
|
||||
tasks.append(
|
||||
ban_dao.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
)
|
||||
|
||||
# 等待所有查询完成,添加超时控制
|
||||
if tasks:
|
||||
try:
|
||||
ban_records = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
if len(tasks) == 2:
|
||||
group_user, user = ban_records
|
||||
elif user_id and group_id:
|
||||
group_user = ban_records[0]
|
||||
else:
|
||||
user = ban_records[0]
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
# 超时时返回0,避免阻塞
|
||||
return 0
|
||||
|
||||
# 检查记录并计算ban时间
|
||||
results = []
|
||||
if group_user:
|
||||
results.append(group_user)
|
||||
if user:
|
||||
results.append(user)
|
||||
|
||||
# 如果没有找到记录,返回0
|
||||
if not results:
|
||||
return 0
|
||||
|
||||
logger.debug(f"查询到的ban记录: {results}", LOGGER_COMMAND)
|
||||
# 检查所有记录,找出最严格的ban(时间最长的)
|
||||
max_ban_time: int = 0
|
||||
for result in results:
|
||||
if result.duration > 0 or result.duration == -1:
|
||||
# 直接计算ban时间,避免再次查询数据库
|
||||
ban_time = calculate_ban_time(result)
|
||||
if ban_time == -1 or ban_time > max_ban_time:
|
||||
max_ban_time = ban_time
|
||||
|
||||
return max_ban_time
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"is_ban 耗时: {elapsed:.3f}s",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
|
||||
def check_plugin_type(matcher: Matcher) -> bool:
|
||||
"""判断插件类型是否是隐藏插件
|
||||
|
||||
参数:
|
||||
matcher: Matcher
|
||||
|
||||
返回:
|
||||
bool: 是否为隐藏插件
|
||||
"""
|
||||
if plugin := matcher.plugin:
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def format_time(time_val: float) -> str:
|
||||
"""格式化时间
|
||||
|
||||
参数:
|
||||
time_val: ban时长
|
||||
|
||||
返回:
|
||||
str: 格式化时间文本
|
||||
"""
|
||||
if time_val == -1:
|
||||
return "∞"
|
||||
time_val = abs(int(time_val))
|
||||
if time_val < 60:
|
||||
time_str = f"{time_val!s} 秒"
|
||||
else:
|
||||
minute = int(time_val / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
return time_str
|
||||
|
||||
|
||||
async def group_handle(group_id: str) -> None:
|
||||
"""群组ban检查
|
||||
|
||||
参数:
|
||||
group_id: 群组id
|
||||
|
||||
异常:
|
||||
SkipPluginException: 群组处于黑名单
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
if await is_ban(None, group_id):
|
||||
raise SkipPluginException("群组处于黑名单中...")
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"group_handle 耗时: {elapsed:.3f}s",
|
||||
LOGGER_COMMAND,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
|
||||
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
||||
"""用户ban检查
|
||||
|
||||
参数:
|
||||
module: 插件模块名
|
||||
entity: 实体ID信息
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: 用户处于黑名单
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
time_val = await is_ban(entity.user_id, entity.group_id)
|
||||
if not time_val:
|
||||
return
|
||||
time_str = format_time(time_val)
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
try:
|
||||
db_plugin = await asyncio.wait_for(
|
||||
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
raise SkipPluginException("用户处于黑名单中...")
|
||||
|
||||
if (
|
||||
db_plugin
|
||||
and not db_plugin.ignore_prompt
|
||||
and time_val != -1
|
||||
and ban_result
|
||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||
):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
session,
|
||||
[
|
||||
At(flag="user", target=entity.user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
],
|
||||
entity.user_id,
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {entity.user_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException("用户处于黑名单中...")
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"user_handle 耗时: {elapsed:.3f}s",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
|
||||
|
||||
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
||||
"""权限检查 - ban 检查
|
||||
|
||||
参数:
|
||||
matcher: Matcher
|
||||
bot: Bot
|
||||
session: Uninfo
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
if not check_plugin_type(matcher):
|
||||
return
|
||||
if not matcher.plugin_name:
|
||||
return
|
||||
entity = get_entity_ids(session)
|
||||
if entity.user_id in bot.config.superusers:
|
||||
return
|
||||
if entity.group_id:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
group_handle(entity.group_id), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"群组ban检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
|
||||
if entity.user_id:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_handle(matcher.plugin_name, entity, session),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"用户ban检查超时: {entity.user_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_ban 总耗时: {elapsed:.3f}s, plugin={matcher.plugin_name}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
55
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
55
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
||||
"""bot层面的权限检查
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
bot_id: bot id
|
||||
|
||||
异常:
|
||||
SkipPluginException: 忽略插件
|
||||
SkipPluginException: 忽略插件
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 从数据库或缓存中获取 bot 信息
|
||||
bot_dao = DataAccess(BotConsole)
|
||||
|
||||
try:
|
||||
bot: BotConsole | None = await asyncio.wait_for(
|
||||
bot_dao.safe_get_or_none(bot_id=bot_id), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询Bot信息超时: bot_id={bot_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
if not bot or not bot.status:
|
||||
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
||||
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||
raise SkipPluginException(
|
||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_bot 耗时: {elapsed:.3f}s, "
|
||||
f"bot_id={bot_id}, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
41
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
41
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
@ -0,0 +1,41 @@
|
||||
import time
|
||||
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
|
||||
|
||||
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
||||
"""检测是否满足金币条件
|
||||
|
||||
参数:
|
||||
user: UserConsole
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
|
||||
返回:
|
||||
int: 需要消耗的金币
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if user.gold < plugin.cost_gold:
|
||||
"""插件消耗金币不足"""
|
||||
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
|
||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||
return plugin.cost_gold
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_cost 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
68
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
68
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.utils import EntityIDs
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
entity: EntityIDs
|
||||
message: UniMsg
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not entity.group_id:
|
||||
return
|
||||
|
||||
try:
|
||||
text = message.extract_plain_text()
|
||||
|
||||
# 从数据库或缓存中获取群组信息
|
||||
group_dao = DataAccess(GroupConsole)
|
||||
|
||||
try:
|
||||
group: GroupConsole | None = await asyncio.wait_for(
|
||||
group_dao.safe_get_or_none(
|
||||
group_id=entity.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
if not group:
|
||||
raise SkipPluginException("群组信息不存在...")
|
||||
if group.level < 0:
|
||||
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
|
||||
if text.strip() != SwitchEnum.ENABLE and not group.status:
|
||||
raise SkipPluginException("群组休眠状态...")
|
||||
if plugin.level > group.level:
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 群等级限制,"
|
||||
f"该功能需要的群等级: {plugin.level}..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=entity.user_id,
|
||||
group_id=entity.group_id,
|
||||
)
|
||||
318
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
318
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
@ -0,0 +1,318 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import (
|
||||
CountLimiter,
|
||||
FreqLimiter,
|
||||
UserBlockLimiter,
|
||||
get_entity_ids,
|
||||
)
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""初始化限制"""
|
||||
await LimitManager.init_limit()
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManager:
|
||||
add_module: ClassVar[list] = []
|
||||
last_update_time: ClassVar[float] = 0
|
||||
update_interval: ClassVar[float] = 6000 # 1小时更新一次
|
||||
is_updating: ClassVar[bool] = False # 防止并发更新
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
# 模块限制缓存,避免频繁查询数据库
|
||||
module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
|
||||
module_cache_ttl: ClassVar[float] = 60 # 模块缓存有效期(秒)
|
||||
|
||||
@classmethod
|
||||
async def init_limit(cls):
|
||||
"""初始化限制"""
|
||||
cls.last_update_time = time.time()
|
||||
try:
|
||||
await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("初始化限制超时", LOGGER_COMMAND)
|
||||
|
||||
@classmethod
|
||||
async def update_limits(cls):
|
||||
"""更新限制信息"""
|
||||
# 防止并发更新
|
||||
if cls.is_updating:
|
||||
return
|
||||
|
||||
cls.is_updating = True
|
||||
try:
|
||||
start_time = time.time()
|
||||
try:
|
||||
limit_list = await asyncio.wait_for(
|
||||
PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("查询限制信息超时", LOGGER_COMMAND)
|
||||
cls.is_updating = False
|
||||
return
|
||||
|
||||
# 清空旧数据
|
||||
cls.add_module = []
|
||||
cls.cd_limit = {}
|
||||
cls.block_limit = {}
|
||||
cls.count_limit = {}
|
||||
# 添加新数据
|
||||
for limit in limit_list:
|
||||
cls.add_limit(limit)
|
||||
|
||||
cls.last_update_time = time.time()
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的更新
|
||||
logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
|
||||
finally:
|
||||
cls.is_updating = False
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
|
||||
参数:
|
||||
limit: PluginLimit
|
||||
"""
|
||||
if limit.module not in cls.add_module:
|
||||
cls.add_module.append(limit.module)
|
||||
if limit.limit_type == PluginLimitType.BLOCK:
|
||||
cls.block_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=UserBlockLimiter()
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.CD:
|
||||
cls.cd_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.COUNT:
|
||||
cls.count_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unblock(
|
||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||
):
|
||||
"""解除插件block
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
"""
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
limit = limit_model.limit
|
||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
logger.debug(
|
||||
f"解除对象: {key_type} 的block限制",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
limiter.set_false(key_type)
|
||||
|
||||
@classmethod
|
||||
async def get_module_limits(cls, module: str) -> list[PluginLimit]:
|
||||
"""获取模块的限制信息,使用缓存减少数据库查询
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
|
||||
返回:
|
||||
list[PluginLimit]: 限制列表
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查缓存
|
||||
if module in cls.module_limit_cache:
|
||||
cache_time, limits = cls.module_limit_cache[module]
|
||||
if current_time - cache_time < cls.module_cache_ttl:
|
||||
return limits
|
||||
|
||||
# 缓存不存在或已过期,从数据库查询
|
||||
try:
|
||||
start_time = time.time()
|
||||
limits = await asyncio.wait_for(
|
||||
PluginLimit.filter(module=module, status=True).all(),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的查询
|
||||
logger.warning(
|
||||
f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
|
||||
# 更新缓存
|
||||
cls.module_limit_cache[module] = (current_time, limits)
|
||||
return limits
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
|
||||
# 超时时返回空列表,避免阻塞
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
async def check(
|
||||
cls,
|
||||
module: str,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 定期更新全局限制信息
|
||||
if (
|
||||
time.time() - cls.last_update_time > cls.update_interval
|
||||
and not cls.is_updating
|
||||
):
|
||||
# 使用异步任务更新,避免阻塞当前请求
|
||||
asyncio.create_task(cls.update_limits()) # noqa: RUF006
|
||||
|
||||
# 如果模块不在已加载列表中,只加载该模块的限制
|
||||
if module not in cls.add_module:
|
||||
limits = await cls.get_module_limits(module)
|
||||
for limit in limits:
|
||||
cls.add_limit(limit)
|
||||
|
||||
# 检查各种限制
|
||||
try:
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
cls,
|
||||
limit_model: Limit | None,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
limit_model: Limit
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if not limit_model:
|
||||
return
|
||||
limit = limit_model.limit
|
||||
limiter = limit_model.limiter
|
||||
is_limit = (
|
||||
LimitWatchType.ALL
|
||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||
)
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
MessageUtils.build_message(limit.result).send(),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{limit.module}({limit.limit_type}) 正在限制中..."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(key_type)
|
||||
if isinstance(limiter, UserBlockLimiter):
|
||||
limiter.set_true(key_type)
|
||||
if isinstance(limiter, CountLimiter):
|
||||
limiter.increase(key_type)
|
||||
|
||||
|
||||
async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||
"""插件限制
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
entity = get_entity_ids(session)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
LimitManager.check(
|
||||
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS * 2, # 给予更长的超时时间
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
|
||||
# 超时时不抛出异常,允许继续执行
|
||||
242
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
242
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
@ -0,0 +1,242 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import IsSuperuserException, SkipPluginException
|
||||
from .utils import freq, is_poke, send_message
|
||||
|
||||
|
||||
class GroupCheck:
|
||||
def __init__(
|
||||
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.plugin = plugin
|
||||
self.group_dao = DataAccess(GroupConsole)
|
||||
self.group_data = None
|
||||
|
||||
async def check(self):
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 只查询一次数据库,使用 DataAccess 的缓存机制
|
||||
try:
|
||||
self.group_data = await asyncio.wait_for(
|
||||
self.group_dao.safe_get_or_none(
|
||||
group_id=self.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||
return # 超时时不阻塞,继续执行
|
||||
|
||||
# 检查超级用户禁用
|
||||
if (
|
||||
self.group_data
|
||||
and CommonUtils.format(self.plugin.module)
|
||||
in self.group_data.superuser_block_plugin
|
||||
):
|
||||
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
self.session,
|
||||
"超级管理员禁用了该群此功能...",
|
||||
self.group_id,
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{self.plugin.name}({self.plugin.module})"
|
||||
f" 超级管理员禁用了该群此功能..."
|
||||
)
|
||||
|
||||
# 检查普通禁用
|
||||
if (
|
||||
self.group_data
|
||||
and CommonUtils.format(self.plugin.module)
|
||||
in self.group_data.block_plugin
|
||||
):
|
||||
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
self.session, "该群未开启此功能...", self.group_id
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{self.plugin.name}({self.plugin.module}) 未开启此功能..."
|
||||
)
|
||||
|
||||
# 检查全局禁用
|
||||
if self.plugin.block_type == BlockType.GROUP:
|
||||
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
self.session, "该功能在群组中已被禁用...", self.group_id
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{self.plugin.name}({self.plugin.module})该插件在群组中已被禁用..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"GroupCheck.check 耗时: {elapsed:.3f}s, 群组: {self.group_id}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
|
||||
|
||||
class PluginCheck:
|
||||
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.group_id = group_id
|
||||
self.group_dao = DataAccess(GroupConsole)
|
||||
self.group_data = None
|
||||
|
||||
async def check_user(self, plugin: PluginInfo):
|
||||
"""全局私聊禁用检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
if plugin.block_type == BlockType.PRIVATE:
|
||||
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(self.session, "该功能在私聊中已被禁用..."),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("发送消息超时", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
|
||||
)
|
||||
|
||||
async def check_global(self, plugin: PluginInfo):
|
||||
"""全局状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||
return
|
||||
"""全局状态"""
|
||||
if self.group_id:
|
||||
# 使用 DataAccess 的缓存机制
|
||||
try:
|
||||
self.group_data = await asyncio.wait_for(
|
||||
self.group_dao.safe_get_or_none(
|
||||
group_id=self.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||
return # 超时时不阻塞,继续执行
|
||||
|
||||
if self.group_data and self.group_data.is_super:
|
||||
raise IsSuperuserException()
|
||||
|
||||
sid = self.group_id or self.session.user.id
|
||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(self.session, "全局未开启此功能...", sid),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {sid}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 全局未开启此功能..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"PluginCheck.check_global 耗时: {elapsed:.3f}s", LOGGER_COMMAND
|
||||
)
|
||||
|
||||
|
||||
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
event: Event
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
entity = get_entity_ids(session)
|
||||
is_poke_event = is_poke(event)
|
||||
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
||||
|
||||
if entity.group_id:
|
||||
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("用户检查超时", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("全局检查超时", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_plugin 总耗时: {elapsed:.3f}s, 模块: {plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
@ -0,0 +1,35 @@
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
|
||||
from .exception import SkipPluginException
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"FILTER_BOT",
|
||||
True,
|
||||
help="过滤当前连接bot(防止bot互相调用)",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
|
||||
|
||||
def bot_filter(session: Uninfo):
|
||||
"""过滤bot调用bot
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: bot互相调用
|
||||
"""
|
||||
if not Config.get_config("hook", "FILTER_BOT"):
|
||||
return
|
||||
bot_ids = list(nonebot.get_bots().keys())
|
||||
if session.user.id == session.self_id:
|
||||
return
|
||||
if session.user.id in bot_ids:
|
||||
raise SkipPluginException(
|
||||
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
|
||||
)
|
||||
16
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
16
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
from strenum import StrEnum
|
||||
|
||||
LOGGER_COMMAND = "AuthChecker"
|
||||
|
||||
|
||||
class SwitchEnum(StrEnum):
|
||||
ENABLE = "醒来"
|
||||
DISABLE = "休息吧"
|
||||
|
||||
|
||||
WARNING_THRESHOLD = 0.5 # 警告阈值(秒)
|
||||
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
@ -0,0 +1,26 @@
|
||||
class IsSuperuserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SkipPluginException(Exception):
|
||||
def __init__(self, info: str, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.info = info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.info
|
||||
|
||||
|
||||
class PermissionExemption(Exception):
|
||||
def __init__(self, info: str, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.info = info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.info
|
||||
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
@ -0,0 +1,91 @@
|
||||
import contextlib
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import FreqLimiter
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
|
||||
base_config = Config.get("hook")
|
||||
|
||||
|
||||
def is_poke(event: Event) -> bool:
|
||||
"""判断是否为poke类型
|
||||
|
||||
参数:
|
||||
event: Event
|
||||
|
||||
返回:
|
||||
bool: 是否为poke类型
|
||||
"""
|
||||
with contextlib.suppress(ImportError):
|
||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||
|
||||
return isinstance(event, PokeNotifyEvent)
|
||||
return False
|
||||
|
||||
|
||||
async def send_message(
|
||||
session: Uninfo, message: list | str, check_tag: str | None = None
|
||||
):
|
||||
"""发送消息
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
message: 消息
|
||||
check_tag: cd flag
|
||||
"""
|
||||
try:
|
||||
if not check_tag:
|
||||
await MessageUtils.build_message(message).send(reply_to=True)
|
||||
elif freq._flmt.check(check_tag):
|
||||
freq._flmt.start_cd(check_tag)
|
||||
await MessageUtils.build_message(message).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"发送消息失败",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
|
||||
|
||||
class FreqUtils:
|
||||
def __init__(self):
|
||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||
|
||||
def is_send_limit_message(
|
||||
self, plugin: PluginInfo, sid: str, is_poke: bool
|
||||
) -> bool:
|
||||
"""是否发送提示消息
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
sid: 检测键
|
||||
is_poke: 是否是戳一戳
|
||||
|
||||
返回:
|
||||
bool: 是否发送提示消息
|
||||
"""
|
||||
if is_poke:
|
||||
return False
|
||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
||||
|
||||
|
||||
freq = FreqUtils()
|
||||
375
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
375
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
@ -0,0 +1,375 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import GoldHandle, PluginType
|
||||
from zhenxun.utils.exception import InsufficientGold
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .auth.auth_admin import auth_admin
|
||||
from .auth.auth_ban import auth_ban
|
||||
from .auth.auth_bot import auth_bot
|
||||
from .auth.auth_cost import auth_cost
|
||||
from .auth.auth_group import auth_group
|
||||
from .auth.auth_limit import LimitManager, auth_limit
|
||||
from .auth.auth_plugin import auth_plugin
|
||||
from .auth.bot_filter import bot_filter
|
||||
from .auth.config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .auth.exception import (
|
||||
IsSuperuserException,
|
||||
PermissionExemption,
|
||||
SkipPluginException,
|
||||
)
|
||||
|
||||
# 超时设置(秒)
|
||||
TIMEOUT_SECONDS = 5.0
|
||||
# 熔断计数器
|
||||
CIRCUIT_BREAKERS = {
|
||||
"auth_ban": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_bot": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_group": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_admin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_plugin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_limit": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
}
|
||||
# 熔断重置时间(秒)
|
||||
CIRCUIT_RESET_TIME = 300 # 5分钟
|
||||
|
||||
|
||||
# 超时装饰器
|
||||
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
|
||||
"""带超时控制的协程执行
|
||||
|
||||
参数:
|
||||
coro: 要执行的协程
|
||||
timeout: 超时时间(秒)
|
||||
name: 操作名称,用于日志记录
|
||||
|
||||
返回:
|
||||
协程的返回值,或者在超时时抛出 TimeoutError
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(coro, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if name:
|
||||
logger.error(f"{name} 操作超时 (>{timeout}s)", LOGGER_COMMAND)
|
||||
# 更新熔断计数器
|
||||
if name in CIRCUIT_BREAKERS:
|
||||
CIRCUIT_BREAKERS[name]["failures"] += 1
|
||||
if (
|
||||
CIRCUIT_BREAKERS[name]["failures"]
|
||||
>= CIRCUIT_BREAKERS[name]["threshold"]
|
||||
and not CIRCUIT_BREAKERS[name]["active"]
|
||||
):
|
||||
CIRCUIT_BREAKERS[name]["active"] = True
|
||||
CIRCUIT_BREAKERS[name]["reset_time"] = (
|
||||
time.time() + CIRCUIT_RESET_TIME
|
||||
)
|
||||
logger.warning(
|
||||
f"{name} 熔断器已激活,将在 {CIRCUIT_RESET_TIME} 秒后重置",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# 检查熔断状态
|
||||
def check_circuit_breaker(name):
|
||||
"""检查熔断器状态
|
||||
|
||||
参数:
|
||||
name: 操作名称
|
||||
|
||||
返回:
|
||||
bool: 是否已熔断
|
||||
"""
|
||||
if name not in CIRCUIT_BREAKERS:
|
||||
return False
|
||||
|
||||
# 检查是否需要重置熔断器
|
||||
if (
|
||||
CIRCUIT_BREAKERS[name]["active"]
|
||||
and time.time() > CIRCUIT_BREAKERS[name]["reset_time"]
|
||||
):
|
||||
CIRCUIT_BREAKERS[name]["active"] = False
|
||||
CIRCUIT_BREAKERS[name]["failures"] = 0
|
||||
logger.info(f"{name} 熔断器已重置", LOGGER_COMMAND)
|
||||
|
||||
return CIRCUIT_BREAKERS[name]["active"]
|
||||
|
||||
|
||||
async def get_plugin_and_user(
|
||||
module: str, user_id: str
|
||||
) -> tuple[PluginInfo, UserConsole]:
|
||||
"""获取用户数据和插件信息
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
|
||||
异常:
|
||||
PermissionExemption: 插件数据不存在
|
||||
PermissionExemption: 插件类型为HIDDEN
|
||||
PermissionExemption: 重复创建用户
|
||||
PermissionExemption: 用户数据不存在
|
||||
|
||||
返回:
|
||||
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
||||
"""
|
||||
user_dao = DataAccess(UserConsole)
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
|
||||
# 并行查询插件和用户数据
|
||||
plugin_task = plugin_dao.safe_get_or_none(module=module)
|
||||
user_task = user_dao.safe_get_or_none(user_id=user_id)
|
||||
|
||||
try:
|
||||
plugin, user = await with_timeout(
|
||||
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# 如果并行查询超时,尝试串行查询
|
||||
logger.warning("并行查询超时,尝试串行查询", LOGGER_COMMAND)
|
||||
plugin = await with_timeout(
|
||||
plugin_dao.safe_get_or_none(module=module), name="get_plugin"
|
||||
)
|
||||
user = await with_timeout(
|
||||
user_dao.safe_get_or_none(user_id=user_id), name="get_user"
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
|
||||
if plugin.plugin_type == PluginType.HIDDEN:
|
||||
raise PermissionExemption(
|
||||
f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..."
|
||||
)
|
||||
user = None
|
||||
try:
|
||||
user = await user_dao.safe_get_or_none(user_id=user_id)
|
||||
except IntegrityError as e:
|
||||
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
|
||||
if not user:
|
||||
raise PermissionExemption("用户数据不存在,已跳过权限检查...")
|
||||
return plugin, user
|
||||
|
||||
|
||||
async def get_plugin_cost(
|
||||
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
|
||||
) -> int:
|
||||
"""获取插件费用
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
user: 用户数据
|
||||
plugin: 插件数据
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
IsSuperuserException: 超级用户
|
||||
IsSuperuserException: 超级用户
|
||||
|
||||
返回:
|
||||
int: 调用插件金币费用
|
||||
"""
|
||||
cost_gold = await with_timeout(auth_cost(user, plugin, session), name="auth_cost")
|
||||
if session.user.id in bot.config.superusers:
|
||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||
raise IsSuperuserException()
|
||||
if not plugin.limit_superuser:
|
||||
raise IsSuperuserException()
|
||||
return cost_gold
|
||||
|
||||
|
||||
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
|
||||
"""扣除用户金币
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
module: 插件模块名称
|
||||
cost_gold: 消耗金币
|
||||
session: Uninfo
|
||||
"""
|
||||
user_dao = DataAccess(UserConsole)
|
||||
try:
|
||||
await with_timeout(
|
||||
UserConsole.reduce_gold(
|
||||
user_id,
|
||||
cost_gold,
|
||||
GoldHandle.PLUGIN,
|
||||
module,
|
||||
PlatformUtils.get_platform(session),
|
||||
),
|
||||
name="reduce_gold",
|
||||
)
|
||||
except InsufficientGold:
|
||||
if u := await UserConsole.get_user(user_id):
|
||||
u.gold = 0
|
||||
await u.save(update_fields=["gold"])
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"扣除金币超时,用户: {user_id}, 金币: {cost_gold}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# 清除缓存,使下次查询时从数据库获取最新数据
|
||||
await user_dao.clear_cache(user_id=user_id)
|
||||
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
||||
|
||||
|
||||
# 辅助函数,用于记录每个 hook 的执行时间
|
||||
async def time_hook(coro, name, time_dict):
|
||||
start = time.time()
|
||||
try:
|
||||
# 检查熔断状态
|
||||
if check_circuit_breaker(name):
|
||||
logger.info(f"{name} 熔断器激活中,跳过执行", LOGGER_COMMAND)
|
||||
time_dict[name] = "熔断跳过"
|
||||
return
|
||||
|
||||
# 添加超时控制
|
||||
return await with_timeout(coro, name=name)
|
||||
except asyncio.TimeoutError:
|
||||
time_dict[name] = f"超时 (>{TIMEOUT_SECONDS}s)"
|
||||
finally:
|
||||
if name not in time_dict:
|
||||
time_dict[name] = f"{time.time() - start:.3f}s"
|
||||
|
||||
|
||||
async def auth(
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
message: UniMsg,
|
||||
):
|
||||
"""权限检查
|
||||
|
||||
参数:
|
||||
matcher: matcher
|
||||
event: Event
|
||||
bot: bot
|
||||
session: Uninfo
|
||||
message: UniMsg
|
||||
"""
|
||||
start_time = time.time()
|
||||
cost_gold = 0
|
||||
ignore_flag = False
|
||||
entity = get_entity_ids(session)
|
||||
module = matcher.plugin_name or ""
|
||||
|
||||
# 用于记录各个 hook 的执行时间
|
||||
hook_times = {}
|
||||
hooks_time = 0 # 初始化 hooks_time 变量
|
||||
|
||||
try:
|
||||
if not module:
|
||||
raise PermissionExemption("Matcher插件名称不存在...")
|
||||
|
||||
# 获取插件和用户数据
|
||||
plugin_user_start = time.time()
|
||||
try:
|
||||
plugin, user = await with_timeout(
|
||||
get_plugin_and_user(module, entity.user_id), name="get_plugin_and_user"
|
||||
)
|
||||
hook_times["get_plugin_user"] = f"{time.time() - plugin_user_start:.3f}s"
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"获取插件和用户数据超时,模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
|
||||
|
||||
# 获取插件费用
|
||||
cost_start = time.time()
|
||||
try:
|
||||
cost_gold = await with_timeout(
|
||||
get_plugin_cost(bot, user, plugin, session), name="get_plugin_cost"
|
||||
)
|
||||
hook_times["cost_gold"] = f"{time.time() - cost_start:.3f}s"
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"获取插件费用超时,模块: {module}", LOGGER_COMMAND, session=session
|
||||
)
|
||||
# 继续执行,不阻止权限检查
|
||||
|
||||
# 执行 bot_filter
|
||||
bot_filter(session)
|
||||
|
||||
# 并行执行所有 hook 检查,并记录执行时间
|
||||
hooks_start = time.time()
|
||||
|
||||
# 创建所有 hook 任务
|
||||
hook_tasks = [
|
||||
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
|
||||
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
|
||||
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
|
||||
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
|
||||
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
|
||||
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
|
||||
]
|
||||
|
||||
# 使用 gather 并行执行所有 hook,但添加总体超时控制
|
||||
try:
|
||||
await with_timeout(
|
||||
asyncio.gather(*hook_tasks),
|
||||
timeout=TIMEOUT_SECONDS * 2, # 给总体执行更多时间
|
||||
name="auth_hooks_gather",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"权限检查 hooks 总体执行超时,模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
# 不抛出异常,允许继续执行
|
||||
|
||||
hooks_time = time.time() - hooks_start
|
||||
|
||||
except SkipPluginException as e:
|
||||
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
ignore_flag = True
|
||||
except IsSuperuserException:
|
||||
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||
except PermissionExemption as e:
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
|
||||
# 扣除金币
|
||||
if not ignore_flag and cost_gold > 0:
|
||||
gold_start = time.time()
|
||||
try:
|
||||
await with_timeout(
|
||||
reduce_gold(entity.user_id, module, cost_gold, session),
|
||||
name="reduce_gold",
|
||||
)
|
||||
hook_times["reduce_gold"] = f"{time.time() - gold_start:.3f}s"
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"扣除金币超时,模块: {module}", LOGGER_COMMAND, session=session
|
||||
)
|
||||
|
||||
# 记录总执行时间
|
||||
total_time = time.time() - start_time
|
||||
if total_time > WARNING_THRESHOLD: # 如果总时间超过500ms,记录详细信息
|
||||
logger.warning(
|
||||
f"权限检查耗时过长: {total_time:.3f}s, 模块: {module}, "
|
||||
f"hooks时间: {hooks_time:.3f}s, "
|
||||
f"详情: {hook_times}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
|
||||
if ignore_flag:
|
||||
raise IgnoredException("权限检测 ignore")
|
||||
@ -1,41 +1,43 @@
|
||||
from nonebot.adapters.onebot.v11 import Bot, Event
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_postprocessor, run_preprocessor
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from ._auth_checker import LimitManage, checker
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .auth.config import LOGGER_COMMAND
|
||||
from .auth_checker import LimitManager, auth
|
||||
|
||||
|
||||
# # 权限检测
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg
|
||||
):
|
||||
await checker.auth(
|
||||
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
||||
start_time = time.time()
|
||||
await auth(
|
||||
matcher,
|
||||
event,
|
||||
bot,
|
||||
session,
|
||||
message,
|
||||
)
|
||||
logger.debug(f"权限检测耗时:{time.time() - start_time}秒", LOGGER_COMMAND)
|
||||
|
||||
|
||||
# 解除命令block阻塞
|
||||
@run_postprocessor
|
||||
async def _(
|
||||
matcher: Matcher,
|
||||
exception: Exception | None,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
async def _(matcher: Matcher, session: Uninfo):
|
||||
user_id = session.user.id
|
||||
group_id = None
|
||||
channel_id = None
|
||||
if session.group:
|
||||
if session.group.parent:
|
||||
group_id = session.group.parent.id
|
||||
channel_id = session.group.id
|
||||
else:
|
||||
group_id = session.group.id
|
||||
if user_id and matcher.plugin:
|
||||
module = matcher.plugin.name
|
||||
LimitManage.unblock(module, user_id, group_id, channel_id)
|
||||
LimitManager.unblock(module, user_id, group_id, channel_id)
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import FreqLimiter
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"BAN_RESULT",
|
||||
"才不会给你发消息.",
|
||||
help="对被ban用户发送的消息",
|
||||
)
|
||||
|
||||
_flmt = FreqLimiter(300)
|
||||
|
||||
|
||||
# 检查是否被ban
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, bot: Bot, event: Event, state: T_State, session: EventSession
|
||||
):
|
||||
extra = {}
|
||||
if plugin := matcher.plugin:
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||
return
|
||||
user_id = session.id1
|
||||
group_id = session.id3 or session.id2
|
||||
if group_id:
|
||||
if user_id in bot.config.superusers:
|
||||
return
|
||||
if await BanConsole.is_ban(None, group_id):
|
||||
logger.debug("群组处于黑名单中...", "ban_hook")
|
||||
raise IgnoredException("群组处于黑名单中...")
|
||||
if g := await GroupConsole.get_group(group_id):
|
||||
if g.level < 0:
|
||||
logger.debug("群黑名单, 群权限-1...", "ban_hook")
|
||||
raise IgnoredException("群黑名单, 群权限-1..")
|
||||
if user_id:
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
if user_id in bot.config.superusers:
|
||||
return
|
||||
if await BanConsole.is_ban(user_id, group_id):
|
||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
||||
if time == -1:
|
||||
time_str = "∞"
|
||||
else:
|
||||
time = abs(int(time))
|
||||
if time < 60:
|
||||
time_str = f"{time!s} 秒"
|
||||
else:
|
||||
minute = int(time / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
if (
|
||||
not extra.get("ignore_prompt")
|
||||
and time != -1
|
||||
and ban_result
|
||||
and _flmt.check(user_id)
|
||||
):
|
||||
_flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
]
|
||||
).send()
|
||||
logger.debug("用户处于黑名单中...", "ban_hook")
|
||||
raise IgnoredException("用户处于黑名单中...")
|
||||
@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType
|
||||
from zhenxun.utils.manager.message_manager import MessageManager
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
LOG_COMMAND = "MessageHook"
|
||||
|
||||
|
||||
def replace_message(message: Message) -> str:
|
||||
"""将消息中的at、image、record、face替换为字符串
|
||||
@ -54,11 +56,11 @@ async def handle_api_result(
|
||||
if user_id and message_id:
|
||||
MessageManager.add(str(user_id), str(message_id))
|
||||
logger.debug(
|
||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook"
|
||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", LOG_COMMAND
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
|
||||
f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e
|
||||
)
|
||||
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
||||
return
|
||||
@ -80,6 +82,6 @@ async def handle_api_result(
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"消息发送记录发生错误...data: {data}, result: {result}",
|
||||
"msg_hook",
|
||||
LOG_COMMAND,
|
||||
e=e,
|
||||
)
|
||||
|
||||
@ -4,15 +4,27 @@ import nonebot
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.cache import CacheException
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
try:
|
||||
from .__init_cache import register_cache_types
|
||||
except CacheException as e:
|
||||
raise SystemError(f"ERROR:{e}")
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
register_cache_types()
|
||||
logger.info("缓存类型注册完成")
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: Bot):
|
||||
"""将bot已存在的群组添加群认证
|
||||
|
||||
35
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
35
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
缓存初始化模块
|
||||
|
||||
负责注册各种缓存类型,实现按需缓存机制
|
||||
"""
|
||||
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheRegistry, cache_config
|
||||
from zhenxun.services.cache.config import CacheMode
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
# 注册缓存类型
|
||||
def register_cache_types():
|
||||
"""注册所有缓存类型"""
|
||||
CacheRegistry.register(CacheType.PLUGINS, PluginInfo)
|
||||
CacheRegistry.register(CacheType.GROUPS, GroupConsole)
|
||||
CacheRegistry.register(CacheType.BOT, BotConsole)
|
||||
CacheRegistry.register(CacheType.USERS, UserConsole)
|
||||
CacheRegistry.register(
|
||||
CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
|
||||
)
|
||||
CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}")
|
||||
|
||||
if cache_config.cache_mode == CacheMode.NONE:
|
||||
logger.info("缓存功能已禁用,将直接从数据库获取数据")
|
||||
else:
|
||||
logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}")
|
||||
logger.info("使用增量缓存模式,数据将按需加载到缓存中")
|
||||
@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
import aiofiles
|
||||
import nonebot
|
||||
from nonebot import get_loaded_plugins
|
||||
@ -112,12 +114,14 @@ async def _():
|
||||
await _handle_setting(plugin, plugin_list, limit_list)
|
||||
create_list = []
|
||||
update_list = []
|
||||
update_task_list = []
|
||||
for plugin in plugin_list:
|
||||
if plugin.module_path not in module2id:
|
||||
create_list.append(plugin)
|
||||
else:
|
||||
plugin.id = module2id[plugin.module_path]
|
||||
await plugin.save(
|
||||
update_task_list.append(
|
||||
plugin.save(
|
||||
update_fields=[
|
||||
"name",
|
||||
"author",
|
||||
@ -127,9 +131,12 @@ async def _():
|
||||
"is_show",
|
||||
]
|
||||
)
|
||||
)
|
||||
update_list.append(plugin)
|
||||
if create_list:
|
||||
await PluginInfo.bulk_create(create_list, 10)
|
||||
if update_task_list:
|
||||
await asyncio.gather(*update_task_list)
|
||||
# if update_list:
|
||||
# # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
|
||||
# column "superuser" does not exist
|
||||
|
||||
@ -205,7 +205,7 @@ class Manager:
|
||||
self.cd_data: dict[str, PluginCdBlock] = {}
|
||||
if self.cd_file.exists():
|
||||
with open(self.cd_file, encoding="utf8") as f:
|
||||
temp = _yaml.load(f)
|
||||
temp = _yaml.load(f) or {}
|
||||
if "PluginCdLimit" in temp.keys():
|
||||
for k, v in temp["PluginCdLimit"].items():
|
||||
if "." in k:
|
||||
@ -216,7 +216,7 @@ class Manager:
|
||||
self.block_data: dict[str, BaseBlock] = {}
|
||||
if self.block_file.exists():
|
||||
with open(self.block_file, encoding="utf8") as f:
|
||||
temp = _yaml.load(f)
|
||||
temp = _yaml.load(f) or {}
|
||||
if "PluginBlockLimit" in temp.keys():
|
||||
for k, v in temp["PluginBlockLimit"].items():
|
||||
if "." in k:
|
||||
@ -227,7 +227,7 @@ class Manager:
|
||||
self.count_data: dict[str, PluginCountBlock] = {}
|
||||
if self.count_file.exists():
|
||||
with open(self.count_file, encoding="utf8") as f:
|
||||
temp = _yaml.load(f)
|
||||
temp = _yaml.load(f) or {}
|
||||
if "PluginCountLimit" in temp.keys():
|
||||
for k, v in temp["PluginCountLimit"].items():
|
||||
if "." in k:
|
||||
|
||||
171
zhenxun/builtin_plugins/llm_manager/__init__.py
Normal file
171
zhenxun/builtin_plugins/llm_manager/__init__.py
Normal file
@ -0,0 +1,171 @@
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from .data_source import DataSource
|
||||
from .presenters import Presenters
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="LLM模型管理",
|
||||
description="查看和管理大语言模型服务。",
|
||||
usage="""
|
||||
LLM模型管理 (SUPERUSER)
|
||||
|
||||
llm list [--all]
|
||||
- 查看可用模型列表。
|
||||
- --all: 显示包括不可用在内的所有模型。
|
||||
|
||||
llm info <Provider/ModelName>
|
||||
- 查看指定模型的详细信息和能力。
|
||||
|
||||
llm default [Provider/ModelName]
|
||||
- 查看或设置全局默认模型。
|
||||
- 不带参数: 查看当前默认模型。
|
||||
- 带参数: 设置新的默认模型。
|
||||
- 例子: llm default Gemini/gemini-2.0-flash
|
||||
|
||||
llm test <Provider/ModelName>
|
||||
- 测试指定模型的连通性和API Key有效性。
|
||||
|
||||
llm keys <ProviderName>
|
||||
- 查看指定提供商的所有API Key状态。
|
||||
|
||||
llm reset-key <ProviderName> [--key <api_key>]
|
||||
- 重置提供商的所有或指定API Key的失败状态。
|
||||
""",
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="1.0.0",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
llm_cmd = on_alconna(
|
||||
Alconna(
|
||||
"llm",
|
||||
Subcommand("list", alias=["ls"], help_text="查看模型列表"),
|
||||
Subcommand("info", Args["model_name", str], help_text="查看模型详情"),
|
||||
Subcommand("default", Args["model_name?", str], help_text="查看或设置默认模型"),
|
||||
Subcommand(
|
||||
"test", Args["model_name", str], alias=["ping"], help_text="测试模型连通性"
|
||||
),
|
||||
Subcommand("keys", Args["provider_name", str], help_text="查看API密钥状态"),
|
||||
Subcommand(
|
||||
"reset-key",
|
||||
Args["provider_name", str],
|
||||
Option("--key", Args["api_key", str], help_text="指定要重置的API Key"),
|
||||
help_text="重置API Key状态",
|
||||
),
|
||||
Option("--all", action=store_true, help_text="显示所有条目"),
|
||||
),
|
||||
permission=SUPERUSER,
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
|
||||
@llm_cmd.assign("list")
|
||||
async def handle_list(arp: Arparma, show_all: Query[bool] = Query("all")):
|
||||
"""处理 'llm list' 命令"""
|
||||
logger.info("获取LLM模型列表", command="LLM Manage", session=arp.header_result)
|
||||
models = await DataSource.get_model_list(show_all=show_all.result)
|
||||
|
||||
image = await Presenters.format_model_list_as_image(models, show_all.result)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image))
|
||||
|
||||
|
||||
@llm_cmd.assign("info")
|
||||
async def handle_info(arp: Arparma, model_name: Match[str]):
|
||||
"""处理 'llm info' 命令"""
|
||||
logger.info(
|
||||
f"获取模型详情: {model_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
details = await DataSource.get_model_details(model_name.result)
|
||||
if not details:
|
||||
await llm_cmd.finish(f"未找到模型: {model_name.result}")
|
||||
|
||||
image_bytes = await Presenters.format_model_details_as_markdown_image(details)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image_bytes))
|
||||
|
||||
|
||||
@llm_cmd.assign("default")
|
||||
async def handle_default(arp: Arparma, model_name: Match[str]):
|
||||
"""处理 'llm default' 命令"""
|
||||
if model_name.available:
|
||||
logger.info(
|
||||
f"设置默认模型为: {model_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
success, message = await DataSource.set_default_model(model_name.result)
|
||||
await llm_cmd.finish(message)
|
||||
else:
|
||||
logger.info("查看默认模型", command="LLM Manage", session=arp.header_result)
|
||||
current_default = await DataSource.get_default_model()
|
||||
await llm_cmd.finish(f"当前全局默认模型为: {current_default or '未设置'}")
|
||||
|
||||
|
||||
@llm_cmd.assign("test")
|
||||
async def handle_test(arp: Arparma, model_name: Match[str]):
|
||||
"""处理 'llm test' 命令"""
|
||||
logger.info(
|
||||
f"测试模型连通性: {model_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
await llm_cmd.send(f"正在测试模型 '{model_name.result}',请稍候...")
|
||||
|
||||
success, message = await DataSource.test_model_connectivity(model_name.result)
|
||||
await llm_cmd.finish(message)
|
||||
|
||||
|
||||
@llm_cmd.assign("keys")
|
||||
async def handle_keys(arp: Arparma, provider_name: Match[str]):
|
||||
"""处理 'llm keys' 命令"""
|
||||
logger.info(
|
||||
f"查看提供商API Key状态: {provider_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
sorted_stats = await DataSource.get_key_status(provider_name.result)
|
||||
if not sorted_stats:
|
||||
await llm_cmd.finish(
|
||||
f"未找到提供商 '{provider_name.result}' 或其没有配置API Keys。"
|
||||
)
|
||||
|
||||
image = await Presenters.format_key_status_as_image(
|
||||
provider_name.result, sorted_stats
|
||||
)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image))
|
||||
|
||||
|
||||
@llm_cmd.assign("reset-key")
|
||||
async def handle_reset_key(
|
||||
arp: Arparma, provider_name: Match[str], api_key: Match[str]
|
||||
):
|
||||
"""处理 'llm reset-key' 命令"""
|
||||
key_to_reset = api_key.result if api_key.available else None
|
||||
log_msg = f"重置 {provider_name.result} 的 " + (
|
||||
"指定API Key" if key_to_reset else "所有API Keys"
|
||||
)
|
||||
logger.info(log_msg, command="LLM Manage", session=arp.header_result)
|
||||
|
||||
success, message = await DataSource.reset_key(provider_name.result, key_to_reset)
|
||||
await llm_cmd.finish(message)
|
||||
120
zhenxun/builtin_plugins/llm_manager/data_source.py
Normal file
120
zhenxun/builtin_plugins/llm_manager/data_source.py
Normal file
@ -0,0 +1,120 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from zhenxun.services.llm import (
|
||||
LLMException,
|
||||
get_global_default_model_name,
|
||||
get_model_instance,
|
||||
list_available_models,
|
||||
set_global_default_model_name,
|
||||
)
|
||||
from zhenxun.services.llm.core import KeyStatus
|
||||
from zhenxun.services.llm.manager import (
|
||||
reset_key_status,
|
||||
)
|
||||
|
||||
|
||||
class DataSource:
|
||||
"""LLM管理插件的数据源和业务逻辑"""
|
||||
|
||||
@staticmethod
|
||||
async def get_model_list(show_all: bool = False) -> list[dict[str, Any]]:
|
||||
"""获取模型列表"""
|
||||
models = list_available_models()
|
||||
if show_all:
|
||||
return models
|
||||
return [m for m in models if m.get("is_available", True)]
|
||||
|
||||
@staticmethod
|
||||
async def get_model_details(model_name_str: str) -> dict[str, Any] | None:
|
||||
"""获取指定模型的详细信息"""
|
||||
try:
|
||||
model = await get_model_instance(model_name_str)
|
||||
return {
|
||||
"provider_config": model.provider_config,
|
||||
"model_detail": model.model_detail,
|
||||
"capabilities": model.capabilities,
|
||||
}
|
||||
except LLMException:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_default_model() -> str | None:
|
||||
"""获取全局默认模型"""
|
||||
return get_global_default_model_name()
|
||||
|
||||
@staticmethod
|
||||
async def set_default_model(model_name_str: str) -> tuple[bool, str]:
|
||||
"""设置全局默认模型"""
|
||||
success = set_global_default_model_name(model_name_str)
|
||||
if success:
|
||||
return True, f"✅ 成功将默认模型设置为: {model_name_str}"
|
||||
else:
|
||||
return False, f"❌ 设置失败,模型 '{model_name_str}' 不存在或无效。"
|
||||
|
||||
@staticmethod
|
||||
async def test_model_connectivity(model_name_str: str) -> tuple[bool, str]:
|
||||
"""测试模型连通性"""
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
async with await get_model_instance(model_name_str) as model:
|
||||
await model.generate_text("你好")
|
||||
end_time = time.monotonic()
|
||||
latency = (end_time - start_time) * 1000
|
||||
return (
|
||||
True,
|
||||
f"✅ 模型 '{model_name_str}' 连接成功!\n响应延迟: {latency:.2f} ms",
|
||||
)
|
||||
except LLMException as e:
|
||||
return (
|
||||
False,
|
||||
f"❌ 模型 '{model_name_str}' 连接测试失败:\n"
|
||||
f"{e.user_friendly_message}\n错误码: {e.code.name}",
|
||||
)
|
||||
except Exception as e:
|
||||
return False, f"❌ 测试时发生未知错误: {e!s}"
|
||||
|
||||
@staticmethod
|
||||
async def get_key_status(provider_name: str) -> list[dict[str, Any]] | None:
|
||||
"""获取并排序指定提供商的API Key状态"""
|
||||
from zhenxun.services.llm.manager import get_key_usage_stats
|
||||
|
||||
all_stats = await get_key_usage_stats()
|
||||
provider_stats = all_stats.get(provider_name)
|
||||
|
||||
if not provider_stats or not provider_stats.get("key_stats"):
|
||||
return None
|
||||
|
||||
key_stats_dict = provider_stats["key_stats"]
|
||||
|
||||
stats_list = [
|
||||
{"key_id": key_id, **stats} for key_id, stats in key_stats_dict.items()
|
||||
]
|
||||
|
||||
def sort_key(item: dict[str, Any]):
|
||||
status_priority = item.get("status_enum", KeyStatus.UNUSED).value
|
||||
return (
|
||||
status_priority,
|
||||
100 - item.get("success_rate", 100.0),
|
||||
-item.get("total_calls", 0),
|
||||
)
|
||||
|
||||
sorted_stats_list = sorted(stats_list, key=sort_key)
|
||||
|
||||
return sorted_stats_list
|
||||
|
||||
@staticmethod
|
||||
async def reset_key(provider_name: str, api_key: str | None) -> tuple[bool, str]:
|
||||
"""重置API Key状态"""
|
||||
success = await reset_key_status(provider_name, api_key)
|
||||
if success:
|
||||
if api_key:
|
||||
if len(api_key) > 8:
|
||||
target = f"API Key '{api_key[:4]}...{api_key[-4:]}'"
|
||||
else:
|
||||
target = f"API Key '{api_key}'"
|
||||
else:
|
||||
target = "所有API Keys"
|
||||
return True, f"✅ 成功重置提供商 '{provider_name}' 的 {target} 的状态。"
|
||||
else:
|
||||
return False, "❌ 重置失败,请检查提供商名称或API Key是否正确。"
|
||||
204
zhenxun/builtin_plugins/llm_manager/presenters.py
Normal file
204
zhenxun/builtin_plugins/llm_manager/presenters.py
Normal file
@ -0,0 +1,204 @@
|
||||
from typing import Any
|
||||
|
||||
from zhenxun.services.llm.core import KeyStatus
|
||||
from zhenxun.services.llm.types import ModelModality
|
||||
from zhenxun.utils._build_image import BuildImage
|
||||
from zhenxun.utils._image_template import ImageTemplate, Markdown, RowStyle
|
||||
|
||||
|
||||
def _format_seconds(seconds: int) -> str:
|
||||
"""将秒数格式化为 'Xm Ys' 或 'Xh Ym' 的形式"""
|
||||
if seconds <= 0:
|
||||
return "0s"
|
||||
if seconds < 60:
|
||||
return f"{seconds}s"
|
||||
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
if minutes < 60:
|
||||
return f"{minutes}m {seconds}s"
|
||||
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return f"{hours}h {minutes}m"
|
||||
|
||||
|
||||
class Presenters:
|
||||
"""格式化LLM管理插件的输出 (图片格式)"""
|
||||
|
||||
@staticmethod
|
||||
async def format_model_list_as_image(
|
||||
models: list[dict[str, Any]], show_all: bool
|
||||
) -> BuildImage:
|
||||
"""将模型列表格式化为表格图片"""
|
||||
title = "📋 LLM模型列表" + (" (所有已配置模型)" if show_all else " (仅可用)")
|
||||
|
||||
if not models:
|
||||
return await BuildImage.build_text_image(
|
||||
f"{title}\n\n当前没有配置任何LLM模型。"
|
||||
)
|
||||
|
||||
column_name = ["提供商", "模型名称", "API类型", "状态"]
|
||||
data_list = []
|
||||
for model in models:
|
||||
status_text = "✅ 可用" if model.get("is_available", True) else "❌ 不可用"
|
||||
embed_tag = " (Embed)" if model.get("is_embedding_model", False) else ""
|
||||
data_list.append(
|
||||
[
|
||||
model.get("provider_name", "N/A"),
|
||||
f"{model.get('model_name', 'N/A')}{embed_tag}",
|
||||
model.get("api_type", "N/A"),
|
||||
status_text,
|
||||
]
|
||||
)
|
||||
|
||||
return await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text="使用 `llm info <Provider/ModelName>` 查看详情",
|
||||
column_name=column_name,
|
||||
data_list=data_list,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def format_model_details_as_markdown_image(details: dict[str, Any]) -> bytes:
|
||||
"""将模型详情格式化为Markdown图片"""
|
||||
provider = details["provider_config"]
|
||||
model = details["model_detail"]
|
||||
caps = details["capabilities"]
|
||||
|
||||
cap_list = []
|
||||
if ModelModality.IMAGE in caps.input_modalities:
|
||||
cap_list.append("视觉")
|
||||
if ModelModality.VIDEO in caps.input_modalities:
|
||||
cap_list.append("视频")
|
||||
if ModelModality.AUDIO in caps.input_modalities:
|
||||
cap_list.append("音频")
|
||||
if caps.supports_tool_calling:
|
||||
cap_list.append("工具调用")
|
||||
if caps.is_embedding_model:
|
||||
cap_list.append("文本嵌入")
|
||||
|
||||
md = Markdown()
|
||||
md.head(f"🔎 模型详情: {provider.name}/{model.model_name}", level=1)
|
||||
md.text("---")
|
||||
md.head("提供商信息", level=2)
|
||||
md.list(
|
||||
[
|
||||
f"**名称**: {provider.name}",
|
||||
f"**API 类型**: {provider.api_type}",
|
||||
f"**API Base**: {provider.api_base or '默认'}",
|
||||
]
|
||||
)
|
||||
md.head("模型详情", level=2)
|
||||
|
||||
temp_value = model.temperature or provider.temperature or "未设置"
|
||||
token_value = model.max_tokens or provider.max_tokens or "未设置"
|
||||
|
||||
md.list(
|
||||
[
|
||||
f"**名称**: {model.model_name}",
|
||||
f"**默认温度**: {temp_value}",
|
||||
f"**最大Token**: {token_value}",
|
||||
f"**核心能力**: {', '.join(cap_list) or '纯文本'}",
|
||||
]
|
||||
)
|
||||
|
||||
return await md.build()
|
||||
|
||||
@staticmethod
|
||||
async def format_key_status_as_image(
|
||||
provider_name: str, sorted_stats: list[dict[str, Any]]
|
||||
) -> BuildImage:
|
||||
"""将已排序的、详细的API Key状态格式化为表格图片"""
|
||||
title = f"🔑 '{provider_name}' API Key 状态"
|
||||
|
||||
if not sorted_stats:
|
||||
return await BuildImage.build_text_image(
|
||||
f"{title}\n\n该提供商没有配置API Keys。"
|
||||
)
|
||||
|
||||
def _status_row_style(column: str, text: str) -> RowStyle:
|
||||
style = RowStyle()
|
||||
if column == "状态":
|
||||
if "✅ 健康" in text:
|
||||
style.font_color = "#67C23A"
|
||||
elif "⚠️ 告警" in text:
|
||||
style.font_color = "#E6A23C"
|
||||
elif "❌ 错误" in text or "🚫" in text:
|
||||
style.font_color = "#F56C6C"
|
||||
elif "❄️ 冷却中" in text:
|
||||
style.font_color = "#409EFF"
|
||||
elif column == "成功率":
|
||||
try:
|
||||
if text != "N/A":
|
||||
rate = float(text.replace("%", ""))
|
||||
if rate < 80:
|
||||
style.font_color = "#F56C6C"
|
||||
elif rate < 95:
|
||||
style.font_color = "#E6A23C"
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return style
|
||||
|
||||
column_name = [
|
||||
"Key (部分)",
|
||||
"状态",
|
||||
"总调用",
|
||||
"成功率",
|
||||
"平均延迟(s)",
|
||||
"上次错误",
|
||||
"建议操作",
|
||||
]
|
||||
data_list = []
|
||||
|
||||
for key_info in sorted_stats:
|
||||
status_enum: KeyStatus = key_info["status_enum"]
|
||||
|
||||
if status_enum == KeyStatus.COOLDOWN:
|
||||
cooldown_seconds = int(key_info["cooldown_seconds_left"])
|
||||
formatted_time = _format_seconds(cooldown_seconds)
|
||||
status_text = f"❄️ 冷却中({formatted_time})"
|
||||
else:
|
||||
status_text = {
|
||||
KeyStatus.DISABLED: "🚫 永久禁用",
|
||||
KeyStatus.ERROR: "❌ 错误",
|
||||
KeyStatus.WARNING: "⚠️ 告警",
|
||||
KeyStatus.HEALTHY: "✅ 健康",
|
||||
KeyStatus.UNUSED: "⚪️ 未使用",
|
||||
}.get(status_enum, "❔ 未知")
|
||||
|
||||
total_calls = key_info["total_calls"]
|
||||
total_calls_text = (
|
||||
f"{key_info['success_count']}/{total_calls}"
|
||||
if total_calls > 0
|
||||
else "0/0"
|
||||
)
|
||||
|
||||
success_rate = key_info["success_rate"]
|
||||
success_rate_text = f"{success_rate:.1f}%" if total_calls > 0 else "N/A"
|
||||
|
||||
avg_latency = key_info["avg_latency"]
|
||||
avg_latency_text = f"{avg_latency / 1000:.2f}" if avg_latency > 0 else "N/A"
|
||||
|
||||
last_error = key_info.get("last_error") or "-"
|
||||
if len(last_error) > 25:
|
||||
last_error = last_error[:22] + "..."
|
||||
|
||||
data_list.append(
|
||||
[
|
||||
key_info["key_id"],
|
||||
status_text,
|
||||
total_calls_text,
|
||||
success_rate_text,
|
||||
avg_latency_text,
|
||||
last_error,
|
||||
key_info["suggested_action"],
|
||||
]
|
||||
)
|
||||
|
||||
return await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text="使用 `llm reset-key <Provider>` 重置Key状态",
|
||||
column_name=column_name,
|
||||
data_list=data_list,
|
||||
text_style=_status_row_style,
|
||||
column_space=15,
|
||||
)
|
||||
@ -55,15 +55,17 @@ class GroupManager:
|
||||
if plugin_list := await PluginInfo.filter(default_status=False).all():
|
||||
for plugin in plugin_list:
|
||||
block_plugin += f"<{plugin.module},"
|
||||
group_info = await bot.get_group_info(group_id=group_id, no_cache=True)
|
||||
await GroupConsole.create(
|
||||
group_info = await bot.get_group_info(group_id=group_id)
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=group_info["group_id"],
|
||||
group_name=group_info["group_name"],
|
||||
max_member_count=group_info["max_member_count"],
|
||||
member_count=group_info["member_count"],
|
||||
group_flag=1,
|
||||
block_plugin=block_plugin,
|
||||
platform="qq",
|
||||
defaults={
|
||||
"group_name": group_info["group_name"],
|
||||
"max_member_count": group_info["max_member_count"],
|
||||
"member_count": group_info["member_count"],
|
||||
"group_flag": 1,
|
||||
"block_plugin": block_plugin,
|
||||
"platform": "qq",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -145,7 +147,7 @@ class GroupManager:
|
||||
e=e,
|
||||
)
|
||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||
await GroupConsole.filter(group_id=group_id).delete()
|
||||
# await GroupConsole.filter(group_id=group_id).delete()
|
||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||
else:
|
||||
await cls.__handle_add_group(bot, group_id, group)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot import on_message
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.friend_user import FriendUser
|
||||
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
@run_preprocessor
|
||||
async def do_something(session: Uninfo):
|
||||
def rule(session: Uninfo) -> bool:
|
||||
return PlatformUtils.is_qbot(session)
|
||||
|
||||
|
||||
_matcher = on_message(priority=999, block=False, rule=rule)
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(session: Uninfo):
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
if session.group:
|
||||
if not await GroupConsole.exists(group_id=session.group.id):
|
||||
await GroupConsole.create(group_id=session.group.id)
|
||||
logger.info("添加当前群组ID信息" "", session=session)
|
||||
|
||||
if not await GroupInfoUser.exists(
|
||||
user_id=session.user.id, group_id=session.group.id
|
||||
):
|
||||
await GroupInfoUser.create(
|
||||
user_id=session.user.id, group_id=session.group.id, platform=platform
|
||||
logger.info("添加当前群组ID信息", session=session)
|
||||
await GroupInfoUser.update_or_create(
|
||||
user_id=session.user.id,
|
||||
group_id=session.group.id,
|
||||
platform=PlatformUtils.get_platform(session),
|
||||
)
|
||||
logger.info("添加当前用户群组ID信息", "", session=session)
|
||||
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
||||
try:
|
||||
await FriendUser.create(user_id=session.user.id, platform=platform)
|
||||
await FriendUser.create(
|
||||
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||
)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
except Exception as e:
|
||||
logger.error("添加当前好友用户信息失败", session=session, e=e)
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""开启/禁用插件格式修改"""
|
||||
_, is_create = await GroupConsole.get_or_create(group_id=133133133)
|
||||
"""标记"""
|
||||
if is_create:
|
||||
data_list = []
|
||||
for group in await GroupConsole.all():
|
||||
if group.block_plugin:
|
||||
if modules := group.block_plugin.split(","):
|
||||
block_plugin = "".join(
|
||||
(f"{module}," if module.startswith("<") else f"<{module},")
|
||||
for module in modules
|
||||
if module.strip()
|
||||
)
|
||||
group.block_plugin = block_plugin.replace("<,", "")
|
||||
if group.block_task:
|
||||
if modules := group.block_task.split(","):
|
||||
block_task = "".join(
|
||||
(f"{module}," if module.startswith("<") else f"<{module},")
|
||||
for module in modules
|
||||
if module.strip()
|
||||
)
|
||||
group.block_task = block_task.replace("<,", "")
|
||||
data_list.append(group)
|
||||
await GroupConsole.bulk_update(data_list, ["block_plugin", "block_task"], 10)
|
||||
@ -44,9 +44,7 @@ class StatisticsManage:
|
||||
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
||||
elif group_id:
|
||||
"""查群组"""
|
||||
group = await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
||||
else:
|
||||
title = "功能调用统计"
|
||||
|
||||
@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
|
||||
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
gid = state["group_id"]
|
||||
group = await GroupConsole.get_or_none(group_id=gid)
|
||||
group = await GroupConsole.get_group(group_id=gid)
|
||||
if not group:
|
||||
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
||||
s = "删除" if arparma.find("delete") else "添加"
|
||||
@ -177,7 +177,9 @@ async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
gid = state["group_id"]
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=gid, defaults={"group_flag": 0 if arparma.find("delete") else 1}
|
||||
group_id=gid,
|
||||
channel_id__isnull=True,
|
||||
defaults={"group_flag": 0 if arparma.find("delete") else 1},
|
||||
)
|
||||
s = "删除" if arparma.find("delete") else "添加"
|
||||
await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)
|
||||
|
||||
@ -119,7 +119,7 @@ class ApiDataSource:
|
||||
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e)
|
||||
logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e)
|
||||
select_bot.group_count = 0
|
||||
select_bot.friend_count = 0
|
||||
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)
|
||||
|
||||
@ -250,7 +250,7 @@ class ApiDataSource:
|
||||
返回:
|
||||
GroupDetail | None: 群组详情数据
|
||||
"""
|
||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
if not group:
|
||||
return None
|
||||
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
||||
|
||||
@ -45,6 +45,7 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
|
||||
mtime=file_path.stat().st_mtime,
|
||||
)
|
||||
)
|
||||
data_list.sort(key=lambda f: f.name)
|
||||
return Result.ok(data_list)
|
||||
except Exception as e:
|
||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||
|
||||
@ -13,8 +13,8 @@ class BotSetting(BaseModel):
|
||||
"""回复时NICKNAME"""
|
||||
system_proxy: str | None = None
|
||||
"""系统代理"""
|
||||
db_url: str = ""
|
||||
"""数据库链接"""
|
||||
db_url: str = "sqlite:data/zhenxun.db"
|
||||
"""数据库链接, 默认值为sqlite:data/zhenxun.db"""
|
||||
platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
|
||||
"""平台超级用户"""
|
||||
qbot_id_data: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@ -155,8 +155,6 @@ class AICallableProperties(BaseModel):
|
||||
"""参数类型"""
|
||||
description: str
|
||||
"""参数描述"""
|
||||
enums: list[str] | None = None
|
||||
"""参数枚举"""
|
||||
|
||||
|
||||
class AICallableParam(BaseModel):
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import time
|
||||
from typing import ClassVar
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType, DbLockType
|
||||
from zhenxun.utils.exception import UserAndGroupIsNone
|
||||
|
||||
|
||||
@ -27,6 +29,15 @@ class BanConsole(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "ban_console"
|
||||
table_description = "封禁人员/群组数据表"
|
||||
unique_together = ("user_id", "group_id")
|
||||
indexes = [("user_id",), ("group_id",)] # noqa: RUF012
|
||||
|
||||
cache_type = CacheType.BAN
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("user_id", "group_id")
|
||||
"""缓存键字段"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
||||
@ -46,12 +57,12 @@ class BanConsole(Model):
|
||||
raise UserAndGroupIsNone()
|
||||
if user_id:
|
||||
return (
|
||||
await cls.get_or_none(user_id=user_id, group_id=group_id)
|
||||
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
|
||||
if group_id
|
||||
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
)
|
||||
else:
|
||||
return await cls.get_or_none(user_id="", group_id=group_id)
|
||||
return await cls.safe_get_or_none(user_id="", group_id=group_id)
|
||||
|
||||
@classmethod
|
||||
async def check_ban_level(
|
||||
@ -167,3 +178,32 @@ class BanConsole(Model):
|
||||
await user.delete()
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_ban(
|
||||
cls,
|
||||
*,
|
||||
id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> Self | None:
|
||||
"""安全地获取ban记录
|
||||
|
||||
参数:
|
||||
id: 记录id
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
|
||||
返回:
|
||||
Self | None: ban记录
|
||||
"""
|
||||
if id is not None:
|
||||
return await cls.safe_get_or_none(id=id)
|
||||
return await cls._get_data(user_id, group_id)
|
||||
|
||||
@classmethod
|
||||
async def _run_script(cls):
|
||||
return [
|
||||
"CREATE INDEX idx_ban_console_user_id ON ban_console(user_id);",
|
||||
"CREATE INDEX idx_ban_console_group_id ON ban_console(group_id);",
|
||||
]
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Literal, overload
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class BotConsole(Model):
|
||||
@ -29,6 +30,11 @@ class BotConsole(Model):
|
||||
table = "bot_console"
|
||||
table_description = "Bot数据表"
|
||||
|
||||
cache_type = CacheType.BOT
|
||||
"""缓存类型"""
|
||||
cache_key_field = "bot_id"
|
||||
"""缓存键字段"""
|
||||
|
||||
@staticmethod
|
||||
def format(name: str) -> str:
|
||||
return f"<{name},"
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, cast, overload
|
||||
from typing import Any, ClassVar, cast, overload
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
@ -6,8 +6,9 @@ from tortoise.backends.base.client import BaseDBAsyncClient
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||
|
||||
|
||||
def add_disable_marker(name: str) -> str:
|
||||
@ -86,6 +87,16 @@ class GroupConsole(Model):
|
||||
table = "group_console"
|
||||
table_description = "群组信息表"
|
||||
unique_together = ("group_id", "channel_id")
|
||||
indexes = [ # noqa: RUF012
|
||||
("group_id",)
|
||||
]
|
||||
|
||||
cache_type = CacheType.GROUPS
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("group_id", "channel_id")
|
||||
"""缓存键字段"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
|
||||
@ -116,6 +127,18 @@ class GroupConsole(Model):
|
||||
).values_list("module", flat=True),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _update_cache(cls, instance):
|
||||
"""更新缓存
|
||||
|
||||
参数:
|
||||
instance: 需要更新缓存的实例
|
||||
"""
|
||||
if cache_type := cls.get_cache_type():
|
||||
key = cls.get_cache_key(instance)
|
||||
if key is not None:
|
||||
await CacheRoot.invalidate_cache(cache_type, key)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||
@ -129,6 +152,9 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
return group
|
||||
|
||||
@classmethod
|
||||
@ -180,6 +206,10 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
# 更新缓存
|
||||
if is_create:
|
||||
await cls._update_cache(group)
|
||||
|
||||
return group, is_create
|
||||
|
||||
@classmethod
|
||||
@ -202,24 +232,39 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
return group, is_create
|
||||
|
||||
@classmethod
|
||||
async def get_group(
|
||||
cls, group_id: str, channel_id: str | None = None
|
||||
cls,
|
||||
group_id: str,
|
||||
channel_id: str | None = None,
|
||||
clean_duplicates: bool = True,
|
||||
) -> Self | None:
|
||||
"""获取群组
|
||||
|
||||
参数:
|
||||
group_id: 群组id
|
||||
channel_id: 频道id.
|
||||
channel_id: 频道id
|
||||
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||
|
||||
返回:
|
||||
Self: GroupConsole
|
||||
"""
|
||||
if channel_id:
|
||||
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
|
||||
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
|
||||
return await cls.safe_get_or_none(
|
||||
group_id=group_id,
|
||||
channel_id=channel_id,
|
||||
clean_duplicates=clean_duplicates,
|
||||
)
|
||||
return await cls.safe_get_or_none(
|
||||
group_id=group_id,
|
||||
channel_id__isnull=True,
|
||||
clean_duplicates=clean_duplicates,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def is_super_group(cls, group_id: str) -> bool:
|
||||
@ -303,6 +348,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
async def set_unblock_plugin(
|
||||
cls,
|
||||
@ -339,6 +387,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
async def is_normal_block_plugin(
|
||||
cls, group_id: str, module: str, channel_id: str | None = None
|
||||
@ -442,6 +493,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
async def set_unblock_task(
|
||||
cls,
|
||||
@ -476,6 +530,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
def _run_script(cls):
|
||||
return [
|
||||
@ -483,4 +540,6 @@ class GroupConsole(Model):
|
||||
" character varying(255) NOT NULL DEFAULT '';",
|
||||
"ALTER TABLE group_console ADD superuser_block_task"
|
||||
" character varying(255) NOT NULL DEFAULT '';",
|
||||
"CREATE INDEX idx_group_console_group_id ON group_console(group_id);",
|
||||
"CREATE INDEX idx_group_console_group_null_channel ON group_console(group_id) WHERE channel_id IS NULL;", # 单独创建channel为空的索引 # noqa: E501
|
||||
]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class LevelUser(Model):
|
||||
@ -20,6 +21,11 @@ class LevelUser(Model):
|
||||
table_description = "用户权限数据库"
|
||||
unique_together = ("user_id", "group_id")
|
||||
|
||||
cache_type = CacheType.LEVEL
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("user_id", "group_id")
|
||||
"""缓存键字段"""
|
||||
|
||||
@classmethod
|
||||
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
||||
"""获取用户在群内的等级
|
||||
@ -53,6 +59,9 @@ class LevelUser(Model):
|
||||
level: 权限等级
|
||||
group_flag: 是否被自动更新刷新权限 0:是, 1:否.
|
||||
"""
|
||||
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
|
||||
# 权限相同时跳过
|
||||
return
|
||||
await cls.update_or_create(
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
|
||||
@ -4,7 +4,7 @@ from tortoise import fields
|
||||
|
||||
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
|
||||
|
||||
class PluginInfo(Model):
|
||||
@ -59,6 +59,11 @@ class PluginInfo(Model):
|
||||
table = "plugin_info"
|
||||
table_description = "插件基本信息"
|
||||
|
||||
cache_type = CacheType.PLUGINS
|
||||
"""缓存类型"""
|
||||
cache_key_field = "module"
|
||||
"""缓存键字段"""
|
||||
|
||||
@classmethod
|
||||
async def get_plugin(
|
||||
cls, load_status: bool = True, filter_parent: bool = True, **kwargs
|
||||
|
||||
@ -2,7 +2,7 @@ from tortoise import fields
|
||||
|
||||
from zhenxun.models.goods_info import GoodsInfo
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import GoldHandle
|
||||
from zhenxun.utils.enum import CacheType, GoldHandle
|
||||
from zhenxun.utils.exception import GoodsNotFound, InsufficientGold
|
||||
|
||||
from .user_gold_log import UserGoldLog
|
||||
@ -29,6 +29,12 @@ class UserConsole(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "user_console"
|
||||
table_description = "用户数据表"
|
||||
indexes = [("user_id",), ("uid",)] # noqa: RUF012
|
||||
|
||||
cache_type = CacheType.USERS
|
||||
"""缓存类型"""
|
||||
cache_key_field = "user_id"
|
||||
"""缓存键字段"""
|
||||
|
||||
@classmethod
|
||||
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
||||
@ -193,3 +199,10 @@ class UserConsole(Model):
|
||||
if goods := await GoodsInfo.get_or_none(goods_name=name):
|
||||
return await cls.use_props(user_id, goods.uuid, num, platform)
|
||||
raise GoodsNotFound("未找到商品...")
|
||||
|
||||
@classmethod
|
||||
async def _run_script(cls):
|
||||
return [
|
||||
"CREATE INDEX idx_user_console_user_id ON user_console(user_id);",
|
||||
"CREATE INDEX idx_user_console_uid ON user_console(uid);",
|
||||
]
|
||||
|
||||
@ -21,11 +21,28 @@ require("nonebot_plugin_waiter")
|
||||
from .db_context import Model, disconnect
|
||||
from .llm import (
|
||||
AI,
|
||||
AIConfig,
|
||||
CommonOverrides,
|
||||
LLMContentPart,
|
||||
LLMException,
|
||||
LLMGenerationConfig,
|
||||
LLMMessage,
|
||||
analyze,
|
||||
analyze_multimodal,
|
||||
chat,
|
||||
clear_model_cache,
|
||||
code,
|
||||
create_multimodal_message,
|
||||
embed,
|
||||
generate,
|
||||
get_cache_stats,
|
||||
get_model_instance,
|
||||
list_available_models,
|
||||
list_embedding_models,
|
||||
pipeline_chat,
|
||||
search,
|
||||
search_multimodal,
|
||||
set_global_default_model_name,
|
||||
tool_registry,
|
||||
)
|
||||
from .log import logger
|
||||
@ -34,16 +51,33 @@ from .scheduler import scheduler_manager
|
||||
|
||||
__all__ = [
|
||||
"AI",
|
||||
"AIConfig",
|
||||
"CommonOverrides",
|
||||
"LLMContentPart",
|
||||
"LLMException",
|
||||
"LLMGenerationConfig",
|
||||
"LLMMessage",
|
||||
"Model",
|
||||
"PluginInit",
|
||||
"PluginInitManager",
|
||||
"analyze",
|
||||
"analyze_multimodal",
|
||||
"chat",
|
||||
"clear_model_cache",
|
||||
"code",
|
||||
"create_multimodal_message",
|
||||
"disconnect",
|
||||
"embed",
|
||||
"generate",
|
||||
"get_cache_stats",
|
||||
"get_model_instance",
|
||||
"list_available_models",
|
||||
"list_embedding_models",
|
||||
"logger",
|
||||
"pipeline_chat",
|
||||
"scheduler_manager",
|
||||
"search",
|
||||
"search_multimodal",
|
||||
"set_global_default_model_name",
|
||||
"tool_registry",
|
||||
]
|
||||
|
||||
1065
zhenxun/services/cache/__init__.py
vendored
Normal file
1065
zhenxun/services/cache/__init__.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
452
zhenxun/services/cache/cache_containers.py
vendored
Normal file
452
zhenxun/services/cache/cache_containers.py
vendored
Normal file
@ -0,0 +1,452 @@
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheData(Generic[T]):
|
||||
"""缓存数据类,存储数据和过期时间"""
|
||||
|
||||
value: T
|
||||
expire_time: float = 0 # 0表示永不过期
|
||||
|
||||
|
||||
class CacheDict:
|
||||
"""缓存字典类,提供类似普通字典的接口,数据只存储在内存中"""
|
||||
|
||||
def __init__(self, name: str, expire: int = 0):
|
||||
"""初始化缓存字典
|
||||
|
||||
参数:
|
||||
name: 字典名称
|
||||
expire: 过期时间(秒),默认为0表示永不过期
|
||||
"""
|
||||
self.name = name.upper()
|
||||
self.expire = expire
|
||||
self._data: dict[str, CacheData[Any]] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""获取字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
|
||||
返回:
|
||||
Any: 字典值
|
||||
"""
|
||||
data = self._data.get(key)
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if data.expire_time > 0 and data.expire_time < time.time():
|
||||
del self._data[key]
|
||||
return None
|
||||
|
||||
return data.value
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""设置字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
value: 字典值
|
||||
"""
|
||||
# 计算过期时间
|
||||
expire_time = 0
|
||||
if self.expire > 0:
|
||||
expire_time = time.time() + self.expire
|
||||
|
||||
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""删除字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
"""
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""检查键是否存在
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
|
||||
返回:
|
||||
bool: 是否存在
|
||||
"""
|
||||
if key not in self._data:
|
||||
return False
|
||||
|
||||
# 检查是否过期
|
||||
data = self._data[key]
|
||||
if data.expire_time > 0 and data.expire_time < time.time():
|
||||
del self._data[key]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""获取字典项,如果不存在返回默认值
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
default: 默认值
|
||||
|
||||
返回:
|
||||
Any: 字典值或默认值
|
||||
"""
|
||||
value = self[key]
|
||||
return default if value is None else value
|
||||
|
||||
def set(self, key: str, value: Any, expire: int | None = None) -> None:
|
||||
"""设置字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
value: 字典值
|
||||
expire: 过期时间(秒),为None时使用默认值
|
||||
"""
|
||||
# 计算过期时间
|
||||
expire_time = 0
|
||||
if expire is not None and expire > 0:
|
||||
expire_time = time.time() + expire
|
||||
elif self.expire > 0:
|
||||
expire_time = time.time() + self.expire
|
||||
|
||||
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||
|
||||
def pop(self, key: str, default: Any = None) -> Any:
|
||||
"""删除并返回字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
default: 默认值
|
||||
|
||||
返回:
|
||||
Any: 字典值或默认值
|
||||
"""
|
||||
if key not in self._data:
|
||||
return default
|
||||
|
||||
data = self._data.pop(key)
|
||||
|
||||
# 检查是否过期
|
||||
if data.expire_time > 0 and data.expire_time < time.time():
|
||||
return default
|
||||
|
||||
return data.value
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空字典"""
|
||||
self._data.clear()
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""获取所有键
|
||||
|
||||
返回:
|
||||
list[str]: 键列表
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return list(self._data.keys())
|
||||
|
||||
def values(self) -> list[Any]:
|
||||
"""获取所有值
|
||||
|
||||
返回:
|
||||
list[Any]: 值列表
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return [data.value for data in self._data.values()]
|
||||
|
||||
def items(self) -> list[tuple[str, Any]]:
|
||||
"""获取所有键值对
|
||||
|
||||
返回:
|
||||
list[tuple[str, Any]]: 键值对列表
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return [(key, data.value) for key, data in self._data.items()]
|
||||
|
||||
def _clean_expired(self) -> None:
|
||||
"""清理过期的键"""
|
||||
now = time.time()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, data in self._data.items()
|
||||
if data.expire_time > 0 and data.expire_time < now
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._data[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""获取字典长度
|
||||
|
||||
返回:
|
||||
int: 字典长度
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return len(self._data)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示
|
||||
|
||||
返回:
|
||||
str: 字符串表示
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return f"CacheDict({self.name}, {len(self._data)} items)"
|
||||
|
||||
|
||||
class CacheList:
|
||||
"""缓存列表类,提供类似普通列表的接口,数据只存储在内存中"""
|
||||
|
||||
def __init__(self, name: str, expire: int = 0):
|
||||
"""初始化缓存列表
|
||||
|
||||
参数:
|
||||
name: 列表名称
|
||||
expire: 过期时间(秒),默认为0表示永不过期
|
||||
"""
|
||||
self.name = name.upper()
|
||||
self.expire = expire
|
||||
self._data: list[CacheData[Any]] = []
|
||||
self._expire_time = 0
|
||||
|
||||
# 如果设置了过期时间,计算整个列表的过期时间
|
||||
if self.expire > 0:
|
||||
self._expire_time = time.time() + self.expire
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""获取列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引
|
||||
|
||||
返回:
|
||||
Any: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
if 0 <= index < len(self._data):
|
||||
return self._data[index].value
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
def __setitem__(self, index: int, value: Any) -> None:
|
||||
"""设置列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引
|
||||
value: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
# 确保索引有效
|
||||
while len(self._data) <= index:
|
||||
self._data.append(CacheData(value=None))
|
||||
self._data[index] = CacheData(value=value)
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def __delitem__(self, index: int) -> None:
|
||||
"""删除列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
if 0 <= index < len(self._data):
|
||||
del self._data[index]
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
else:
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""获取列表长度
|
||||
|
||||
返回:
|
||||
int: 列表长度
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
return len(self._data)
|
||||
|
||||
def append(self, value: Any) -> None:
|
||||
"""添加列表项
|
||||
|
||||
参数:
|
||||
value: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
self._data.append(CacheData(value=value))
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def extend(self, values: list[Any]) -> None:
|
||||
"""扩展列表
|
||||
|
||||
参数:
|
||||
values: 要添加的值列表
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
self._data.extend([CacheData(value=v) for v in values])
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def insert(self, index: int, value: Any) -> None:
|
||||
"""插入列表项
|
||||
|
||||
参数:
|
||||
index: 插入位置
|
||||
value: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
self._data.insert(index, CacheData(value=value))
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def pop(self, index: int = -1) -> Any:
|
||||
"""删除并返回列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引,默认为最后一项
|
||||
|
||||
返回:
|
||||
Any: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise IndexError("从空列表中弹出")
|
||||
|
||||
if not self._data:
|
||||
raise IndexError("从空列表中弹出")
|
||||
|
||||
item = self._data.pop(index)
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
return item.value
|
||||
|
||||
def remove(self, value: Any) -> None:
|
||||
"""删除第一个匹配的列表项
|
||||
|
||||
参数:
|
||||
value: 要删除的值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
# 查找匹配的项
|
||||
for i, item in enumerate(self._data):
|
||||
if item.value == value:
|
||||
del self._data[i]
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
return
|
||||
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空列表"""
|
||||
self._data.clear()
|
||||
# 重置过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def index(self, value: Any, start: int = 0, end: int | None = None) -> int:
|
||||
"""查找值的索引
|
||||
|
||||
参数:
|
||||
value: 要查找的值
|
||||
start: 起始索引
|
||||
end: 结束索引
|
||||
|
||||
返回:
|
||||
int: 索引位置
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
end = end if end is not None else len(self._data)
|
||||
|
||||
for i in range(start, min(end, len(self._data))):
|
||||
if self._data[i].value == value:
|
||||
return i
|
||||
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
def count(self, value: Any) -> int:
|
||||
"""计算值出现的次数
|
||||
|
||||
参数:
|
||||
value: 要计数的值
|
||||
|
||||
返回:
|
||||
int: 出现次数
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
return 0
|
||||
|
||||
return sum(1 for item in self._data if item.value == value)
|
||||
|
||||
def _is_expired(self) -> bool:
|
||||
"""检查整个列表是否过期"""
|
||||
return self._expire_time > 0 and self._expire_time < time.time()
|
||||
|
||||
def _update_expire_time(self) -> None:
|
||||
"""更新过期时间"""
|
||||
if self.expire > 0:
|
||||
self._expire_time = time.time() + self.expire
|
||||
else:
|
||||
self._expire_time = 0
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示
|
||||
|
||||
返回:
|
||||
str: 字符串表示
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
return f"CacheList({self.name}, {len(self._data)} items)"
|
||||
35
zhenxun/services/cache/config.py
vendored
Normal file
35
zhenxun/services/cache/config.py
vendored
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
缓存系统配置
|
||||
"""
|
||||
|
||||
# 日志标识
|
||||
LOG_COMMAND = "CacheRoot"
|
||||
|
||||
# 默认缓存过期时间(秒)
|
||||
DEFAULT_EXPIRE = 600
|
||||
|
||||
# 缓存键前缀
|
||||
CACHE_KEY_PREFIX = "ZHENXUN"
|
||||
|
||||
# 缓存键分隔符
|
||||
CACHE_KEY_SEPARATOR = ":"
|
||||
|
||||
# 复合键分隔符(用于分隔tuple类型的cache_key_field)
|
||||
COMPOSITE_KEY_SEPARATOR = "_"
|
||||
|
||||
|
||||
# 缓存模式
|
||||
class CacheMode:
|
||||
# 内存缓存 - 使用内存存储缓存数据
|
||||
MEMORY = "MEMORY"
|
||||
# Redis缓存 - 使用Redis服务器存储缓存数据
|
||||
REDIS = "REDIS"
|
||||
# 不使用缓存 - 将使用ttl=0的内存缓存,相当于直接从数据库获取数据
|
||||
NONE = "NONE"
|
||||
|
||||
|
||||
SPECIAL_KEY_FORMATS = {
|
||||
"LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||
"BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||
"GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}",
|
||||
}
|
||||
653
zhenxun/services/data_access.py
Normal file
653
zhenxun/services/data_access.py
Normal file
@ -0,0 +1,653 @@
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
|
||||
from zhenxun.services.cache import Cache, CacheRoot, cache_config
|
||||
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR, CacheMode
|
||||
from zhenxun.services.db_context import Model, with_db_timeout
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
|
||||
class DataAccess(Generic[T]):
|
||||
"""数据访问层,根据配置决定是否使用缓存
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
from zhenxun.services import DataAccess
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
|
||||
# 创建数据访问对象
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
|
||||
# 获取单个数据
|
||||
plugin = await plugin_dao.get(module="example_module")
|
||||
|
||||
# 获取所有数据
|
||||
all_plugins = await plugin_dao.all()
|
||||
|
||||
# 筛选数据
|
||||
enabled_plugins = await plugin_dao.filter(status=True)
|
||||
|
||||
# 创建数据
|
||||
new_plugin = await plugin_dao.create(
|
||||
module="new_module",
|
||||
name="新插件",
|
||||
status=True
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
# 添加缓存统计信息
|
||||
_cache_stats: ClassVar[dict] = {}
|
||||
# 空结果标记
|
||||
_NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"
|
||||
# 默认空结果缓存时间(秒)- 设置为5分钟,避免频繁查询数据库
|
||||
_NULL_RESULT_TTL = 300
|
||||
|
||||
@classmethod
|
||||
def set_null_result_ttl(cls, seconds: int) -> None:
|
||||
"""设置空结果缓存时间
|
||||
|
||||
参数:
|
||||
seconds: 缓存时间(秒)
|
||||
"""
|
||||
if seconds < 0:
|
||||
raise ValueError("缓存时间不能为负数")
|
||||
cls._NULL_RESULT_TTL = seconds
|
||||
logger.info(f"已设置DataAccess空结果缓存时间为 {seconds} 秒")
|
||||
|
||||
@classmethod
|
||||
def get_null_result_ttl(cls) -> int:
|
||||
"""获取空结果缓存时间
|
||||
|
||||
返回:
|
||||
int: 缓存时间(秒)
|
||||
"""
|
||||
return cls._NULL_RESULT_TTL
|
||||
|
||||
def __init__(
|
||||
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
|
||||
):
|
||||
"""初始化数据访问对象
|
||||
|
||||
参数:
|
||||
model_cls: 模型类
|
||||
key_field: 主键字段
|
||||
"""
|
||||
self.model_cls = model_cls
|
||||
self.key_field = getattr(model_cls, "cache_key_field", key_field)
|
||||
self.cache_type = getattr(model_cls, "cache_type", cache_type)
|
||||
|
||||
if not self.cache_type:
|
||||
raise ValueError("缓存类型不能为空")
|
||||
self.cache = Cache(self.cache_type)
|
||||
|
||||
# 初始化缓存统计
|
||||
if self.cache_type not in self._cache_stats:
|
||||
self._cache_stats[self.cache_type] = {
|
||||
"hits": 0, # 缓存命中次数
|
||||
"misses": 0, # 缓存未命中次数
|
||||
"null_hits": 0, # 空结果缓存命中次数
|
||||
"sets": 0, # 缓存设置次数
|
||||
"null_sets": 0, # 空结果缓存设置次数
|
||||
"deletes": 0, # 缓存删除次数
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls):
|
||||
"""获取缓存统计信息"""
|
||||
result = []
|
||||
for cache_type, stats in cls._cache_stats.items():
|
||||
hits = stats["hits"]
|
||||
null_hits = stats.get("null_hits", 0)
|
||||
misses = stats["misses"]
|
||||
total = hits + null_hits + misses
|
||||
hit_rate = ((hits + null_hits) / total * 100) if total > 0 else 0
|
||||
result.append(
|
||||
{
|
||||
"cache_type": cache_type,
|
||||
"hits": hits,
|
||||
"null_hits": null_hits,
|
||||
"misses": misses,
|
||||
"sets": stats["sets"],
|
||||
"null_sets": stats.get("null_sets", 0),
|
||||
"deletes": stats["deletes"],
|
||||
"hit_rate": f"{hit_rate:.2f}%",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def reset_cache_stats(cls):
|
||||
"""重置缓存统计信息"""
|
||||
for stats in cls._cache_stats.values():
|
||||
stats["hits"] = 0
|
||||
stats["null_hits"] = 0
|
||||
stats["misses"] = 0
|
||||
stats["sets"] = 0
|
||||
stats["null_sets"] = 0
|
||||
stats["deletes"] = 0
|
||||
|
||||
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
|
||||
"""从关键字参数构建缓存键
|
||||
|
||||
参数:
|
||||
**kwargs: 关键字参数
|
||||
|
||||
返回:
|
||||
str | None: 缓存键,如果无法构建则返回None
|
||||
"""
|
||||
if isinstance(self.key_field, tuple):
|
||||
# 多字段主键
|
||||
key_parts = []
|
||||
key_parts.extend(str(kwargs.get(field, "")) for field in self.key_field)
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||
elif self.key_field in kwargs:
|
||||
# 单字段主键
|
||||
return str(kwargs[self.key_field])
|
||||
return None
|
||||
|
||||
async def safe_get_or_none(self, *args, **kwargs) -> T | None:
|
||||
"""安全的获取单条数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Optional[T]: 查询结果,如果不存在返回None
|
||||
"""
|
||||
# 如果没有缓存类型,直接从数据库获取
|
||||
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
|
||||
return await with_db_timeout(
|
||||
self.model_cls.safe_get_or_none(*args, **kwargs),
|
||||
operation=f"{self.model_cls.__name__}.safe_get_or_none",
|
||||
)
|
||||
|
||||
# 尝试从缓存获取
|
||||
cache_key = None
|
||||
try:
|
||||
# 尝试构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
|
||||
# 如果成功构建缓存键,尝试从缓存获取
|
||||
if cache_key is not None:
|
||||
data = await self.cache.get(cache_key)
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} self.cache.get(cache_key)"
|
||||
f" 从缓存获取到的数据 {type(data)}: {data}"
|
||||
)
|
||||
if data == self._NULL_RESULT:
|
||||
# 空结果缓存命中
|
||||
self._cache_stats[self.cache_type]["null_hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
|
||||
)
|
||||
return None
|
||||
elif data:
|
||||
# 缓存命中
|
||||
self._cache_stats[self.cache_type]["hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
|
||||
)
|
||||
return cast(T, data)
|
||||
else:
|
||||
# 缓存未命中
|
||||
self._cache_stats[self.cache_type]["misses"] += 1
|
||||
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
|
||||
|
||||
# 如果缓存中没有,从数据库获取
|
||||
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
|
||||
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
|
||||
|
||||
# 如果获取到数据,存入缓存
|
||||
if data:
|
||||
try:
|
||||
# 生成缓存键
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
elif cache_key is not None:
|
||||
# 如果没有获取到数据,缓存空结果
|
||||
try:
|
||||
# 存入空结果缓存,使用较短的过期时间
|
||||
await self.cache.set(
|
||||
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
|
||||
)
|
||||
self._cache_stats[self.cache_type]["null_sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
|
||||
f" TTL={self._NULL_RESULT_TTL}秒"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def get_or_none(self, *args, **kwargs) -> T | None:
|
||||
"""获取单条数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Optional[T]: 查询结果,如果不存在返回None
|
||||
"""
|
||||
# 如果没有缓存类型,直接从数据库获取
|
||||
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
|
||||
return await with_db_timeout(
|
||||
self.model_cls.get_or_none(*args, **kwargs),
|
||||
operation=f"{self.model_cls.__name__}.get_or_none",
|
||||
)
|
||||
|
||||
# 尝试从缓存获取
|
||||
cache_key = None
|
||||
try:
|
||||
# 尝试构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
|
||||
# 如果成功构建缓存键,尝试从缓存获取
|
||||
if cache_key is not None:
|
||||
data = await self.cache.get(cache_key)
|
||||
if data == self._NULL_RESULT:
|
||||
# 空结果缓存命中
|
||||
self._cache_stats[self.cache_type]["null_hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
|
||||
)
|
||||
return None
|
||||
elif data:
|
||||
# 缓存命中
|
||||
self._cache_stats[self.cache_type]["hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
|
||||
)
|
||||
return cast(T, data)
|
||||
else:
|
||||
# 缓存未命中
|
||||
self._cache_stats[self.cache_type]["misses"] += 1
|
||||
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
|
||||
|
||||
# 如果缓存中没有,从数据库获取
|
||||
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
|
||||
data = await self.model_cls.get_or_none(*args, **kwargs)
|
||||
|
||||
# 如果获取到数据,存入缓存
|
||||
if data:
|
||||
try:
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
# 生成缓存键
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
elif cache_key is not None:
|
||||
# 如果没有获取到数据,缓存空结果
|
||||
try:
|
||||
# 存入空结果缓存,使用较短的过期时间
|
||||
await self.cache.set(
|
||||
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
|
||||
)
|
||||
self._cache_stats[self.cache_type]["null_sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
|
||||
f" TTL={self._NULL_RESULT_TTL}秒"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def clear_cache(self, **kwargs) -> bool:
|
||||
"""只清除缓存,不影响数据库数据
|
||||
|
||||
参数:
|
||||
**kwargs: 查询参数,必须包含主键字段
|
||||
|
||||
返回:
|
||||
bool: 是否成功清除缓存
|
||||
"""
|
||||
# 如果没有缓存类型,直接返回True
|
||||
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
if cache_key is None:
|
||||
if isinstance(self.key_field, tuple):
|
||||
# 如果是复合键,检查缺少哪些字段
|
||||
missing_fields = [
|
||||
field for field in self.key_field if field not in kwargs
|
||||
]
|
||||
logger.error(
|
||||
f"清除{self.model_cls.__name__}缓存失败: "
|
||||
f"缺少主键字段 {', '.join(missing_fields)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"清除{self.model_cls.__name__}缓存失败: "
|
||||
f"缺少主键字段 {self.key_field}"
|
||||
)
|
||||
return False
|
||||
|
||||
# 删除缓存
|
||||
await self.cache.delete(cache_key)
|
||||
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e)
|
||||
return False
|
||||
|
||||
def _build_composite_key(self, data: T) -> str | None:
|
||||
"""构建复合缓存键
|
||||
|
||||
参数:
|
||||
data: 数据对象
|
||||
|
||||
返回:
|
||||
str | None: 构建的缓存键,如果无法构建则返回None
|
||||
"""
|
||||
# 如果是元组,表示多个字段组成键
|
||||
if isinstance(self.key_field, tuple):
|
||||
# 构建键参数列表
|
||||
key_parts = []
|
||||
for field in self.key_field:
|
||||
value = getattr(data, field, "")
|
||||
key_parts.append(value if value is not None else "")
|
||||
|
||||
# 如果没有有效参数,返回None
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||
elif hasattr(data, self.key_field):
|
||||
value = getattr(data, self.key_field, None)
|
||||
return str(value) if value is not None else None
|
||||
|
||||
return None
|
||||
|
||||
def _build_cache_key_for_item(self, item: T) -> str | None:
|
||||
"""为数据项构建缓存键
|
||||
|
||||
参数:
|
||||
item: 数据项
|
||||
|
||||
返回:
|
||||
str | None: 缓存键,如果无法生成则返回None
|
||||
"""
|
||||
# 如果没有缓存类型,返回None
|
||||
if not self.cache_type:
|
||||
return None
|
||||
|
||||
# 获取缓存类型的配置信息
|
||||
cache_model = CacheRoot.get_model(self.cache_type)
|
||||
|
||||
if not cache_model.key_format:
|
||||
# 常规处理,使用主键作为缓存键
|
||||
return self._build_composite_key(item)
|
||||
# 构建键参数字典
|
||||
key_parts = []
|
||||
# 从格式字符串中提取所需的字段名
|
||||
import re
|
||||
|
||||
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
|
||||
|
||||
# 收集所有字段值
|
||||
for field in field_names:
|
||||
value = getattr(item, field, "")
|
||||
key_parts.append(value if value is not None else "")
|
||||
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
|
||||
|
||||
async def _cache_items(self, data_list: list[T]) -> None:
|
||||
"""将数据列表存入缓存
|
||||
|
||||
参数:
|
||||
data_list: 数据列表
|
||||
"""
|
||||
if (
|
||||
not data_list
|
||||
or not self.cache_type
|
||||
or cache_config.cache_mode == CacheMode.NONE
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# 遍历数据列表,将每条数据存入缓存
|
||||
cached_count = 0
|
||||
for item in data_list:
|
||||
cache_key = self._build_cache_key_for_item(item)
|
||||
if cache_key is not None:
|
||||
await self.cache.set(cache_key, item)
|
||||
cached_count += 1
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 批量缓存: {cached_count}/{len(data_list)}项"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} 批量缓存失败", e=e)
|
||||
|
||||
async def filter(self, *args, **kwargs) -> list[T]:
|
||||
"""筛选数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
List[T]: 查询结果列表
|
||||
"""
|
||||
# 从数据库获取数据
|
||||
logger.debug(f"{self.model_cls.__name__} filter: 从数据库查询, 参数: {kwargs}")
|
||||
data_list = await self.model_cls.filter(*args, **kwargs)
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} filter: 查询结果数量: {len(data_list)}"
|
||||
)
|
||||
|
||||
# 将数据存入缓存
|
||||
await self._cache_items(data_list)
|
||||
|
||||
return data_list
|
||||
|
||||
async def all(self) -> list[T]:
|
||||
"""获取所有数据
|
||||
|
||||
返回:
|
||||
List[T]: 所有数据列表
|
||||
"""
|
||||
# 直接从数据库获取
|
||||
logger.debug(f"{self.model_cls.__name__} all: 从数据库查询所有数据")
|
||||
data_list = await self.model_cls.all()
|
||||
logger.debug(f"{self.model_cls.__name__} all: 查询结果数量: {len(data_list)}")
|
||||
|
||||
# 将数据存入缓存
|
||||
await self._cache_items(data_list)
|
||||
|
||||
return data_list
|
||||
|
||||
async def count(self, *args, **kwargs) -> int:
|
||||
"""获取数据数量
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
int: 数据数量
|
||||
"""
|
||||
# 直接从数据库获取数量
|
||||
return await self.model_cls.filter(*args, **kwargs).count()
|
||||
|
||||
async def exists(self, *args, **kwargs) -> bool:
|
||||
"""判断数据是否存在
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
bool: 是否存在
|
||||
"""
|
||||
# 直接从数据库判断是否存在
|
||||
return await self.model_cls.filter(*args, **kwargs).exists()
|
||||
|
||||
async def create(self, **kwargs) -> T:
|
||||
"""创建数据
|
||||
|
||||
参数:
|
||||
**kwargs: 创建参数
|
||||
|
||||
返回:
|
||||
T: 创建的数据
|
||||
"""
|
||||
# 创建数据
|
||||
logger.debug(f"{self.model_cls.__name__} create: 创建数据, 参数: {kwargs}")
|
||||
data = await self.model_cls.create(**kwargs)
|
||||
|
||||
# 如果有缓存类型,将数据存入缓存
|
||||
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||
try:
|
||||
# 生成缓存键
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} create: "
|
||||
f"新创建的数据已存入缓存: {cache_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} create: 存入缓存失败,参数: {kwargs}",
|
||||
e=e,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def update_or_create(
|
||||
self, defaults: dict[str, Any] | None = None, **kwargs
|
||||
) -> tuple[T, bool]:
|
||||
"""更新或创建数据
|
||||
|
||||
参数:
|
||||
defaults: 默认值
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
tuple[T, bool]: (数据, 是否创建)
|
||||
"""
|
||||
# 更新或创建数据
|
||||
data, created = await self.model_cls.update_or_create(
|
||||
defaults=defaults, **kwargs
|
||||
)
|
||||
|
||||
# 如果有缓存类型,将数据存入缓存
|
||||
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||
try:
|
||||
# 生成缓存键
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
|
||||
|
||||
return data, created
|
||||
|
||||
async def delete(self, *args, **kwargs) -> int:
|
||||
"""删除数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
int: 删除的数据数量
|
||||
"""
|
||||
logger.debug(f"{self.model_cls.__name__} delete: 删除数据, 参数: {kwargs}")
|
||||
|
||||
# 如果有缓存类型且有key_field参数,先尝试删除缓存
|
||||
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||
try:
|
||||
# 尝试构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
|
||||
if cache_key is not None:
|
||||
# 如果成功构建缓存键,直接删除缓存
|
||||
await self.cache.delete(cache_key)
|
||||
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete: 已删除缓存: {cache_key}"
|
||||
)
|
||||
else:
|
||||
# 否则需要先查询出要删除的数据,然后删除对应的缓存
|
||||
items = await self.model_cls.filter(*args, **kwargs)
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete:"
|
||||
f" 查询到 {len(items)} 条要删除的数据"
|
||||
)
|
||||
for item in items:
|
||||
item_cache_key = self._build_cache_key_for_item(item)
|
||||
if item_cache_key is not None:
|
||||
await self.cache.delete(item_cache_key)
|
||||
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||
if items:
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete:"
|
||||
f" 已删除 {len(items)} 条数据的缓存"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} delete: 删除缓存失败", e=e)
|
||||
|
||||
# 删除数据
|
||||
result = await self.model_cls.filter(*args, **kwargs).delete()
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete: 已从数据库删除 {result} 条数据"
|
||||
)
|
||||
return result
|
||||
|
||||
def _generate_cache_key(self, data: T) -> str:
|
||||
"""根据数据对象生成缓存键
|
||||
|
||||
参数:
|
||||
data: 数据对象
|
||||
|
||||
返回:
|
||||
str: 缓存键
|
||||
"""
|
||||
# 使用新方法构建复合键
|
||||
if composite_key := self._build_composite_key(data):
|
||||
return composite_key
|
||||
|
||||
# 如果无法生成复合键,生成一个唯一键
|
||||
return f"object_{id(data)}"
|
||||
@ -1,40 +1,339 @@
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
import contextlib
|
||||
import re
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing_extensions import Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import nonebot
|
||||
from nonebot import get_driver
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from tortoise import Tortoise
|
||||
from tortoise.backends.base.client import BaseDBAsyncClient
|
||||
from tortoise.connection import connections
|
||||
from tortoise.models import Model as Model_
|
||||
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
|
||||
from tortoise.models import Model as TortoiseModel
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import DbLockType
|
||||
from zhenxun.utils.exception import HookPriorityException
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.utils import unicode_escape, unicode_unescape
|
||||
|
||||
from .log import logger
|
||||
driver = get_driver()
|
||||
|
||||
SCRIPT_METHOD = []
|
||||
MODELS: list[str] = []
|
||||
|
||||
# 数据库操作超时设置(秒)
|
||||
DB_TIMEOUT_SECONDS = 3.0
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
# 性能监控阈值(秒)
|
||||
SLOW_QUERY_THRESHOLD = 0.5
|
||||
|
||||
LOG_COMMAND = "DbContext"
|
||||
|
||||
|
||||
class UnicodeSafeMixin(Model_):
|
||||
async def with_db_timeout(
|
||||
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
|
||||
):
|
||||
"""带超时控制的数据库操作"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = await asyncio.wait_for(coro, timeout=timeout)
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > SLOW_QUERY_THRESHOLD and operation:
|
||||
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
if operation:
|
||||
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
||||
raise
|
||||
|
||||
|
||||
class Model(TortoiseModel):
|
||||
"""
|
||||
增强的ORM基类,解决锁嵌套问题
|
||||
"""
|
||||
|
||||
sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
|
||||
_current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁
|
||||
_unicode_safe_fields: list[str] = [] # noqa: RUF012
|
||||
"""需要处理的字段名列表"""
|
||||
"""需要Unicode处理的字段名列表"""
|
||||
|
||||
async def save(self, *args, **kwargs):
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if cls.__module__ not in MODELS:
|
||||
MODELS.append(cls.__module__)
|
||||
|
||||
if func := getattr(cls, "_run_script", None):
|
||||
SCRIPT_METHOD.append((cls.__module__, func))
|
||||
|
||||
@classmethod
|
||||
def get_cache_type(cls) -> str | None:
|
||||
"""获取缓存类型"""
|
||||
return getattr(cls, "cache_type", None)
|
||||
|
||||
@classmethod
|
||||
def get_cache_key_field(cls) -> str | tuple[str]:
|
||||
"""获取缓存键字段"""
|
||||
return getattr(cls, "cache_key_field", "id")
|
||||
|
||||
@classmethod
|
||||
def get_cache_key(cls, instance) -> str | None:
|
||||
"""获取缓存键
|
||||
|
||||
参数:
|
||||
instance: 模型实例
|
||||
|
||||
返回:
|
||||
str | None: 缓存键,如果无法获取则返回None
|
||||
"""
|
||||
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR
|
||||
|
||||
key_field = cls.get_cache_key_field()
|
||||
|
||||
if isinstance(key_field, tuple):
|
||||
# 多字段主键
|
||||
key_parts = []
|
||||
for field in key_field:
|
||||
if hasattr(instance, field):
|
||||
value = getattr(instance, field, None)
|
||||
key_parts.append(value if value is not None else "")
|
||||
else:
|
||||
# 如果缺少任何必要的字段,返回None
|
||||
key_parts.append("")
|
||||
|
||||
# 如果没有有效参数,返回None
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||
elif hasattr(instance, key_field):
|
||||
value = getattr(instance, key_field, None)
|
||||
return str(value) if value is not None else None
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_semaphore(cls, lock_type: DbLockType):
|
||||
enable_lock = getattr(cls, "enable_lock", None)
|
||||
if not enable_lock or lock_type not in enable_lock:
|
||||
return None
|
||||
|
||||
if cls.__name__ not in cls.sem_data:
|
||||
cls.sem_data[cls.__name__] = {}
|
||||
if lock_type not in cls.sem_data[cls.__name__]:
|
||||
cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
|
||||
return cls.sem_data[cls.__name__][lock_type]
|
||||
|
||||
@classmethod
|
||||
def _require_lock(cls, lock_type: DbLockType) -> bool:
|
||||
"""检查是否需要真正加锁"""
|
||||
task_id = id(asyncio.current_task())
|
||||
return cls._current_locks.get(task_id) != lock_type
|
||||
|
||||
@classmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lock_context(cls, lock_type: DbLockType):
|
||||
"""带重入检查的锁上下文"""
|
||||
task_id = id(asyncio.current_task())
|
||||
need_lock = cls._require_lock(lock_type)
|
||||
|
||||
if need_lock and (sem := cls.get_semaphore(lock_type)):
|
||||
cls._current_locks[task_id] = lock_type
|
||||
async with sem:
|
||||
yield
|
||||
cls._current_locks.pop(task_id, None)
|
||||
else:
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||
) -> Self:
|
||||
"""创建数据(使用CREATE锁)"""
|
||||
async with cls._lock_context(DbLockType.CREATE):
|
||||
# 直接调用父类的_create方法避免触发save的锁
|
||||
result = await super().create(using_db=using_db, **kwargs)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def get_or_create(
|
||||
cls,
|
||||
defaults: dict | None = None,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Self, bool]:
|
||||
"""获取或创建数据(无锁版本,依赖数据库约束)"""
|
||||
result = await super().get_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def update_or_create(
|
||||
cls,
|
||||
defaults: dict | None = None,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Self, bool]:
|
||||
"""更新或创建数据(使用UPSERT锁)"""
|
||||
async with cls._lock_context(DbLockType.UPSERT):
|
||||
try:
|
||||
# 先尝试更新(带行锁)
|
||||
async with in_transaction():
|
||||
if obj := await cls.filter(**kwargs).select_for_update().first():
|
||||
await obj.update_from_dict(defaults or {})
|
||||
await obj.save()
|
||||
result = (obj, False)
|
||||
else:
|
||||
# 创建时不重复加锁
|
||||
result = await cls.create(**kwargs, **(defaults or {})), True
|
||||
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.invalidate_cache(
|
||||
cache_type, cls.get_cache_key(result[0])
|
||||
)
|
||||
return result
|
||||
except IntegrityError:
|
||||
# 处理极端情况下的唯一约束冲突
|
||||
obj = await cls.get(**kwargs)
|
||||
return obj, False
|
||||
|
||||
async def save(
|
||||
self,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
update_fields: Iterable[str] | None = None,
|
||||
force_create: bool = False,
|
||||
force_update: bool = False,
|
||||
):
|
||||
"""保存数据(根据操作类型自动选择锁)"""
|
||||
lock_type = (
|
||||
DbLockType.CREATE
|
||||
if getattr(self, "id", None) is None
|
||||
else DbLockType.UPDATE
|
||||
)
|
||||
|
||||
async with self._lock_context(lock_type):
|
||||
for field_name in self._unicode_safe_fields:
|
||||
value = getattr(self, field_name)
|
||||
if isinstance(value, str):
|
||||
# 如果是新数据或数据未标记已处理
|
||||
if not getattr(self, f"_{field_name}_converted", False):
|
||||
if isinstance(value, str) and not getattr(
|
||||
self, f"_{field_name}_converted", False
|
||||
):
|
||||
setattr(self, field_name, unicode_escape(value))
|
||||
setattr(self, f"_{field_name}_converted", True)
|
||||
await super().save(*args, **kwargs)
|
||||
await super().save(
|
||||
using_db=using_db,
|
||||
update_fields=update_fields,
|
||||
force_create=force_create,
|
||||
force_update=force_update,
|
||||
)
|
||||
if cache_type := getattr(self, "cache_type", None):
|
||||
await CacheRoot.invalidate_cache(
|
||||
cache_type, self.__class__.get_cache_key(self)
|
||||
)
|
||||
|
||||
async def delete(self, using_db: BaseDBAsyncClient | None = None):
|
||||
cache_type = getattr(self, "cache_type", None)
|
||||
key = self.__class__.get_cache_key(self) if cache_type else None
|
||||
# 执行删除操作
|
||||
await super().delete(using_db=using_db)
|
||||
|
||||
# 清除缓存
|
||||
if cache_type:
|
||||
await CacheRoot.invalidate_cache(cache_type, key)
|
||||
|
||||
@classmethod
|
||||
async def safe_get_or_none(
|
||||
cls,
|
||||
*args,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
clean_duplicates: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Self | None:
|
||||
"""安全地获取一条记录或None,处理存在多个记录时返回最新的那个
|
||||
注意,默认会删除重复的记录,仅保留最新的
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
using_db: 数据库连接
|
||||
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Self | None: 查询结果,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
# 先尝试使用 get_or_none 获取单个记录
|
||||
try:
|
||||
return await with_db_timeout(
|
||||
cls.get_or_none(*args, using_db=using_db, **kwargs),
|
||||
operation=f"{cls.__name__}.get_or_none",
|
||||
)
|
||||
except MultipleObjectsReturned:
|
||||
# 如果出现多个记录的情况,进行特殊处理
|
||||
logger.warning(
|
||||
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
|
||||
# 查询所有匹配记录
|
||||
records = await with_db_timeout(
|
||||
cls.filter(*args, **kwargs).all(),
|
||||
operation=f"{cls.__name__}.filter.all",
|
||||
)
|
||||
|
||||
if not records:
|
||||
return None
|
||||
|
||||
# 如果需要清理重复记录
|
||||
if clean_duplicates and hasattr(records[0], "id"):
|
||||
# 按 id 排序
|
||||
records = sorted(
|
||||
records, key=lambda x: getattr(x, "id", 0), reverse=True
|
||||
)
|
||||
for record in records[1:]:
|
||||
try:
|
||||
await with_db_timeout(
|
||||
record.delete(),
|
||||
operation=f"{cls.__name__}.delete_duplicate",
|
||||
)
|
||||
logger.info(
|
||||
f"{cls.__name__} 删除重复记录:"
|
||||
f" id={getattr(record, 'id', None)}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
except Exception as del_e:
|
||||
logger.error(f"删除重复记录失败: {del_e}")
|
||||
return records[0]
|
||||
# 如果不需要清理或没有 id 字段,则返回最新的记录
|
||||
if hasattr(cls, "id"):
|
||||
return await with_db_timeout(
|
||||
cls.filter(*args, **kwargs).order_by("-id").first(),
|
||||
operation=f"{cls.__name__}.filter.order_by.first",
|
||||
)
|
||||
# 如果没有 id 字段,则返回第一个记录
|
||||
return await with_db_timeout(
|
||||
cls.filter(*args, **kwargs).first(),
|
||||
operation=f"{cls.__name__}.filter.first",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# 其他类型的错误则继续抛出
|
||||
logger.error(
|
||||
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def get(cls, *args, **kwargs):
|
||||
@ -62,9 +361,9 @@ class UnicodeSafeMixin(Model_):
|
||||
for obj in objects:
|
||||
for field in safe_fields:
|
||||
value = getattr(obj, field)
|
||||
if isinstance(value, str):
|
||||
# 如果是新数据或数据未标记已处理
|
||||
if not getattr(obj, f"_{field}_converted", False):
|
||||
if isinstance(value, str) and not getattr(
|
||||
obj, f"_{field}_converted", False
|
||||
):
|
||||
setattr(obj, field, unicode_escape(value))
|
||||
setattr(obj, f"_{field}_converted", True)
|
||||
|
||||
@ -86,13 +385,12 @@ class UnicodeSafeMixin(Model_):
|
||||
for obj in objects:
|
||||
for field_name in cls._unicode_safe_fields:
|
||||
value = getattr(obj, field_name)
|
||||
if isinstance(value, str):
|
||||
# 如果是新数据或数据未标记已处理
|
||||
if not getattr(obj, f"_{field_name}_converted", False):
|
||||
if isinstance(value, str) and not getattr(
|
||||
obj, f"_{field_name}_converted", False
|
||||
):
|
||||
setattr(obj, field_name, unicode_escape(value))
|
||||
setattr(obj, f"_{field_name}_converted", True)
|
||||
|
||||
# 调用原始 bulk_create 方法
|
||||
return super().bulk_create(
|
||||
objects,
|
||||
batch_size,
|
||||
@ -108,29 +406,11 @@ class UnicodeSafeMixin(Model_):
|
||||
for field_name in cls._unicode_safe_fields:
|
||||
value = getattr(instance, field_name)
|
||||
if isinstance(value, str):
|
||||
# 如果字段包含有效转义序列才处理
|
||||
if re.search(r"(?<!\\)\\u[0-9a-fA-F]{4}", value):
|
||||
setattr(instance, field_name, unicode_unescape(value))
|
||||
# 标记字段已处理
|
||||
setattr(instance, f"_{field_name}_converted", True)
|
||||
|
||||
|
||||
class Model(UnicodeSafeMixin):
|
||||
"""
|
||||
自动添加模块
|
||||
|
||||
Args:
|
||||
UnicodeSafeMixin: Model
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
if cls.__module__ not in MODELS:
|
||||
MODELS.append(cls.__module__)
|
||||
|
||||
if func := getattr(cls, "_run_script", None):
|
||||
SCRIPT_METHOD.append((cls.__module__, func))
|
||||
|
||||
|
||||
class DbUrlIsNode(HookPriorityException):
|
||||
"""
|
||||
数据库链接地址为空
|
||||
@ -147,6 +427,77 @@ class DbConnectError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
POSTGRESQL_CONFIG = {
|
||||
"max_size": 30, # 最大连接数
|
||||
"min_size": 5, # 最小保持的连接数(可选)
|
||||
}
|
||||
|
||||
|
||||
MYSQL_CONFIG = {
|
||||
"max_connections": 20, # 最大连接数
|
||||
"connect_timeout": 30, # 连接超时(可选)
|
||||
}
|
||||
|
||||
SQLITE_CONFIG = {
|
||||
"journal_mode": "WAL", # 提高并发写入性能
|
||||
"timeout": 30, # 锁等待超时(可选)
|
||||
}
|
||||
|
||||
|
||||
def get_config() -> dict:
|
||||
"""获取数据库配置"""
|
||||
parsed = urlparse(BotConfig.db_url)
|
||||
|
||||
# 基础配置
|
||||
config = {
|
||||
"connections": {
|
||||
"default": BotConfig.db_url # 默认直接使用连接字符串
|
||||
},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": MODELS,
|
||||
"default_connection": "default",
|
||||
}
|
||||
},
|
||||
"timezone": "Asia/Shanghai",
|
||||
}
|
||||
|
||||
# 根据数据库类型应用高级配置
|
||||
if parsed.scheme.startswith("postgres"):
|
||||
config["connections"]["default"] = {
|
||||
"engine": "tortoise.backends.asyncpg",
|
||||
"credentials": {
|
||||
"host": parsed.hostname,
|
||||
"port": parsed.port or 5432,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path[1:],
|
||||
},
|
||||
**POSTGRESQL_CONFIG,
|
||||
}
|
||||
elif parsed.scheme == "mysql":
|
||||
config["connections"]["default"] = {
|
||||
"engine": "tortoise.backends.mysql",
|
||||
"credentials": {
|
||||
"host": parsed.hostname,
|
||||
"port": parsed.port or 3306,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path[1:],
|
||||
},
|
||||
**MYSQL_CONFIG,
|
||||
}
|
||||
elif parsed.scheme == "sqlite":
|
||||
config["connections"]["default"] = {
|
||||
"engine": "tortoise.backends.sqlite",
|
||||
"credentials": {
|
||||
"file_path": parsed.path or ":memory:",
|
||||
},
|
||||
**SQLITE_CONFIG,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=1)
|
||||
async def init():
|
||||
if not BotConfig.db_url:
|
||||
@ -162,9 +513,7 @@ async def init():
|
||||
raise DbUrlIsNode("\n" + error.strip())
|
||||
try:
|
||||
await Tortoise.init(
|
||||
db_url=BotConfig.db_url,
|
||||
modules={"models": MODELS},
|
||||
timezone="Asia/Shanghai",
|
||||
config=get_config(),
|
||||
)
|
||||
if SCRIPT_METHOD:
|
||||
db = Tortoise.get_connection("default")
|
||||
@ -183,13 +532,17 @@ async def init():
|
||||
for sql in sql_list:
|
||||
logger.debug(f"执行SQL: {sql}")
|
||||
try:
|
||||
await db.execute_query_dict(sql)
|
||||
await asyncio.wait_for(
|
||||
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
# await TestSQL.raw(sql)
|
||||
except Exception as e:
|
||||
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
||||
if sql_list:
|
||||
logger.debug("SCRIPT_METHOD方法执行完毕!")
|
||||
logger.debug("开始生成数据库表结构...")
|
||||
await Tortoise.generate_schemas()
|
||||
logger.debug("数据库表结构生成完毕!")
|
||||
logger.info("Database loaded successfully!")
|
||||
except Exception as e:
|
||||
raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
||||
|
||||
@ -198,7 +198,7 @@ print(search_result['text'])
|
||||
当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。
|
||||
|
||||
```python
|
||||
from zhenxun.services.llm.api import AI, AIConfig
|
||||
from zhenxun.services.llm import AI, AIConfig
|
||||
|
||||
# 初始化一个AI会话,可以传入自定义配置
|
||||
ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7)
|
||||
@ -395,7 +395,7 @@ async def my_tool_factory(config: MyToolConfig):
|
||||
在 `analyze` 或 `generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。
|
||||
|
||||
```python
|
||||
from zhenxun.services.llm.api import analyze
|
||||
from zhenxun.services.llm import analyze
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
response = await analyze(
|
||||
@ -442,7 +442,6 @@ from zhenxun.services.llm.manager import (
|
||||
get_key_usage_stats,
|
||||
reset_key_status
|
||||
)
|
||||
from zhenxun.services.llm import clear_model_cache, get_cache_stats
|
||||
|
||||
# 列出所有在config.yaml中配置的可用模型
|
||||
models = list_available_models()
|
||||
|
||||
@ -5,14 +5,12 @@ LLM 服务模块 - 公共 API 入口
|
||||
"""
|
||||
|
||||
from .api import (
|
||||
AI,
|
||||
AIConfig,
|
||||
TaskType,
|
||||
analyze,
|
||||
analyze_multimodal,
|
||||
chat,
|
||||
code,
|
||||
embed,
|
||||
generate,
|
||||
pipeline_chat,
|
||||
search,
|
||||
search_multimodal,
|
||||
@ -35,6 +33,7 @@ from .manager import (
|
||||
list_model_identifiers,
|
||||
set_global_default_model_name,
|
||||
)
|
||||
from .session import AI, AIConfig
|
||||
from .tools import tool_registry
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
@ -49,6 +48,7 @@ from .types import (
|
||||
ModelInfo,
|
||||
ModelProvider,
|
||||
ResponseFormat,
|
||||
TaskType,
|
||||
ToolCategory,
|
||||
ToolMetadata,
|
||||
UsageInfo,
|
||||
@ -84,6 +84,7 @@ __all__ = [
|
||||
"code",
|
||||
"create_multimodal_message",
|
||||
"embed",
|
||||
"generate",
|
||||
"get_cache_stats",
|
||||
"get_global_default_model_name",
|
||||
"get_model_instance",
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
"""
|
||||
LLM 服务的高级 API 接口
|
||||
LLM 服务的高级 API 接口 - 便捷函数入口
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -12,10 +9,8 @@ from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import CommonOverrides, LLMGenerationConfig
|
||||
from .config.providers import get_ai_config
|
||||
from .manager import get_global_default_model_name, get_model_instance
|
||||
from .tools import tool_registry
|
||||
from .manager import get_model_instance
|
||||
from .session import AI
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
LLMContentPart,
|
||||
@ -29,514 +24,31 @@ from .types import (
|
||||
from .utils import create_multimodal_message, unimsg_to_llm_parts
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""任务类型枚举"""
|
||||
|
||||
CHAT = "chat"
|
||||
CODE = "code"
|
||||
SEARCH = "search"
|
||||
ANALYSIS = "analysis"
|
||||
GENERATION = "generation"
|
||||
MULTIMODAL = "multimodal"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""AI配置类 - 简化版本"""
|
||||
|
||||
model: ModelName = None
|
||||
default_embedding_model: ModelName = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
enable_cache: bool = False
|
||||
enable_code: bool = False
|
||||
enable_search: bool = False
|
||||
timeout: int | None = None
|
||||
|
||||
enable_gemini_json_mode: bool = False
|
||||
enable_gemini_thinking: bool = False
|
||||
enable_gemini_safe_mode: bool = False
|
||||
enable_gemini_multimodal: bool = False
|
||||
enable_gemini_grounding: bool = False
|
||||
default_preserve_media_in_history: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后从配置中读取默认值"""
|
||||
ai_config = get_ai_config()
|
||||
if self.model is None:
|
||||
self.model = ai_config.get("default_model_name")
|
||||
if self.timeout is None:
|
||||
self.timeout = ai_config.get("timeout", 180)
|
||||
|
||||
|
||||
class AI:
|
||||
"""统一的AI服务类 - 平衡设计版本
|
||||
|
||||
提供三层API:
|
||||
1. 简单方法:ai.chat(), ai.code(), ai.search()
|
||||
2. 标准方法:ai.analyze() 支持复杂参数
|
||||
3. 高级方法:通过get_model_instance()直接访问
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
|
||||
):
|
||||
"""
|
||||
初始化AI服务
|
||||
|
||||
参数:
|
||||
config: AI 配置.
|
||||
history: 可选的初始对话历史.
|
||||
"""
|
||||
self.config = config or AIConfig()
|
||||
self.history = history or []
|
||||
|
||||
def clear_history(self):
|
||||
"""清空当前会话的历史记录"""
|
||||
self.history = []
|
||||
logger.info("AI session history cleared.")
|
||||
|
||||
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||
"""
|
||||
净化用于存入历史记录的消息。
|
||||
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||
"""
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
sanitized_message = copy.deepcopy(message)
|
||||
content_list = sanitized_message.content
|
||||
if not isinstance(content_list, list):
|
||||
return sanitized_message
|
||||
|
||||
new_content_parts: list[LLMContentPart] = []
|
||||
has_multimodal_content = False
|
||||
|
||||
for part in content_list:
|
||||
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||
new_content_parts.append(part)
|
||||
else:
|
||||
has_multimodal_content = True
|
||||
|
||||
if has_multimodal_content:
|
||||
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||
text_part_found = False
|
||||
for part in new_content_parts:
|
||||
if part.type == "text":
|
||||
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||
text_part_found = True
|
||||
break
|
||||
if not text_part_found:
|
||||
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||
|
||||
sanitized_message.content = new_content_parts
|
||||
return sanitized_message
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str | LLMMessage | list[LLMContentPart],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
preserve_media_in_history: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
进行一次聊天对话。
|
||||
此方法会自动使用和更新会话内的历史记录。
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 本次对话要使用的模型。
|
||||
preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
|
||||
- True: 保留,用于深度多轮媒体分析。
|
||||
- False: 不保留,替换为占位符,提高效率。
|
||||
- None (默认): 使用AI实例配置的默认值。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str: 模型的文本响应。
|
||||
"""
|
||||
current_message: LLMMessage
|
||||
if isinstance(message, str):
|
||||
current_message = LLMMessage.user(message)
|
||||
elif isinstance(message, list) and all(
|
||||
isinstance(part, LLMContentPart) for part in message
|
||||
):
|
||||
current_message = LLMMessage.user(message)
|
||||
elif isinstance(message, LLMMessage):
|
||||
current_message = message
|
||||
else:
|
||||
raise LLMException(
|
||||
f"AI.chat 不支持的消息类型: {type(message)}. "
|
||||
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
|
||||
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
|
||||
code=LLMErrorCode.API_REQUEST_FAILED,
|
||||
)
|
||||
|
||||
final_messages = [*self.history, current_message]
|
||||
|
||||
response = await self._execute_generation(
|
||||
final_messages, model, "聊天失败", kwargs
|
||||
)
|
||||
|
||||
should_preserve = (
|
||||
preserve_media_in_history
|
||||
if preserve_media_in_history is not None
|
||||
else self.config.default_preserve_media_in_history
|
||||
)
|
||||
|
||||
if should_preserve:
|
||||
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
|
||||
self.history.append(current_message)
|
||||
else:
|
||||
logger.debug("高效模式:净化历史记录中的多模态消息。")
|
||||
sanitized_user_message = self._sanitize_message_for_history(current_message)
|
||||
self.history.append(sanitized_user_message)
|
||||
|
||||
self.history.append(LLMMessage.assistant_text_response(response.text))
|
||||
|
||||
return response.text
|
||||
|
||||
async def code(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
代码执行
|
||||
|
||||
参数:
|
||||
prompt: 代码执行的提示词。
|
||||
model: 要使用的模型名称。
|
||||
timeout: 代码执行超时时间(秒)。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
|
||||
config = CommonOverrides.gemini_code_execution()
|
||||
if timeout:
|
||||
config.custom_params = config.custom_params or {}
|
||||
config.custom_params["code_execution_timeout"] = timeout
|
||||
|
||||
messages = [LLMMessage.user(prompt)]
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages, resolved_model, "代码执行失败", kwargs, base_config=config
|
||||
)
|
||||
|
||||
return {
|
||||
"text": response.text,
|
||||
"code_executions": response.code_executions or [],
|
||||
"success": True,
|
||||
}
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str | UniMessage,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
instruction: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
信息搜索 - 支持多模态输入
|
||||
|
||||
参数:
|
||||
query: 搜索查询内容,支持文本或多模态消息。
|
||||
model: 要使用的模型名称。
|
||||
instruction: 搜索指令。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
config = CommonOverrides.gemini_grounding()
|
||||
|
||||
if isinstance(query, str):
|
||||
messages = [LLMMessage.user(query)]
|
||||
elif isinstance(query, UniMessage):
|
||||
content_parts = await unimsg_to_llm_parts(query)
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
else:
|
||||
raise LLMException(
|
||||
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
messages = final_messages
|
||||
else:
|
||||
raise LLMException(
|
||||
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
|
||||
code=LLMErrorCode.API_REQUEST_FAILED,
|
||||
)
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages, resolved_model, "信息搜索失败", kwargs, base_config=config
|
||||
)
|
||||
|
||||
result = {
|
||||
"text": response.text,
|
||||
"sources": [],
|
||||
"queries": [],
|
||||
"success": True,
|
||||
}
|
||||
|
||||
if response.grounding_metadata:
|
||||
result["sources"] = response.grounding_metadata.grounding_attributions or []
|
||||
result["queries"] = response.grounding_metadata.web_search_queries or []
|
||||
|
||||
return result
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
message: UniMessage | None,
|
||||
*,
|
||||
instruction: str = "",
|
||||
model: ModelName = None,
|
||||
use_tools: list[str] | None = None,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
activated_tools: list[LLMTool] | None = None,
|
||||
history: list[LLMMessage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
||||
|
||||
参数:
|
||||
message: 要分析的消息内容(支持多模态)。
|
||||
instruction: 分析指令。
|
||||
model: 要使用的模型名称。
|
||||
use_tools: 要使用的工具名称列表。
|
||||
tool_config: 工具配置。
|
||||
activated_tools: 已激活的工具列表。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应结果。
|
||||
"""
|
||||
content_parts = await unimsg_to_llm_parts(message or UniMessage())
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if history:
|
||||
final_messages.extend(history)
|
||||
|
||||
if instruction:
|
||||
if not any(msg.role == "system" for msg in final_messages):
|
||||
final_messages.insert(0, LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction and not history:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
elif not history:
|
||||
raise LLMException(
|
||||
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
llm_tools: list[LLMTool] | None = activated_tools
|
||||
if not llm_tools and use_tools:
|
||||
try:
|
||||
llm_tools = tool_registry.get_tools(use_tools)
|
||||
logger.debug(f"已从注册表加载工具定义: {use_tools}")
|
||||
except ValueError as e:
|
||||
raise LLMException(
|
||||
f"加载工具定义失败: {e}",
|
||||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
tool_choice = None
|
||||
if tool_config:
|
||||
mode = tool_config.get("mode", "auto")
|
||||
if mode in ["auto", "any", "none"]:
|
||||
tool_choice = mode
|
||||
|
||||
response = await self._execute_generation(
|
||||
final_messages,
|
||||
model,
|
||||
"内容分析失败",
|
||||
kwargs,
|
||||
llm_tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _execute_generation(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
model_name: ModelName,
|
||||
error_message: str,
|
||||
config_overrides: dict[str, Any],
|
||||
llm_tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> LLMResponse:
|
||||
"""通用的生成执行方法,封装模型获取和单次API调用"""
|
||||
try:
|
||||
resolved_model_name = self._resolve_model_name(
|
||||
model_name or self.config.model
|
||||
)
|
||||
final_config_dict = self._merge_config(
|
||||
config_overrides, base_config=base_config
|
||||
)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_name, override_config=final_config_dict
|
||||
) as model_instance:
|
||||
return await model_instance.generate_response(
|
||||
messages,
|
||||
tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{error_message}: {e}", e=e)
|
||||
raise LLMException(f"{error_message}: {e}", cause=e)
|
||||
|
||||
def _resolve_model_name(self, model_name: ModelName) -> str:
|
||||
"""解析模型名称"""
|
||||
if model_name:
|
||||
return model_name
|
||||
|
||||
default_model = get_global_default_model_name()
|
||||
if default_model:
|
||||
return default_model
|
||||
|
||||
raise LLMException(
|
||||
"未指定模型名称且未设置全局默认模型",
|
||||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||||
)
|
||||
|
||||
def _merge_config(
|
||||
self,
|
||||
user_config: dict[str, Any],
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""合并配置"""
|
||||
final_config = {}
|
||||
if base_config:
|
||||
final_config.update(base_config.to_dict())
|
||||
|
||||
if self.config.temperature is not None:
|
||||
final_config["temperature"] = self.config.temperature
|
||||
if self.config.max_tokens is not None:
|
||||
final_config["max_tokens"] = self.config.max_tokens
|
||||
|
||||
if self.config.enable_cache:
|
||||
final_config["enable_caching"] = True
|
||||
if self.config.enable_code:
|
||||
final_config["enable_code_execution"] = True
|
||||
if self.config.enable_search:
|
||||
final_config["enable_grounding"] = True
|
||||
|
||||
if self.config.enable_gemini_json_mode:
|
||||
final_config["response_mime_type"] = "application/json"
|
||||
if self.config.enable_gemini_thinking:
|
||||
final_config["thinking_budget"] = 0.8
|
||||
if self.config.enable_gemini_safe_mode:
|
||||
final_config["safety_settings"] = (
|
||||
CommonOverrides.gemini_safe().safety_settings
|
||||
)
|
||||
if self.config.enable_gemini_multimodal:
|
||||
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
|
||||
if self.config.enable_gemini_grounding:
|
||||
final_config["enable_grounding"] = True
|
||||
|
||||
final_config.update(user_config)
|
||||
|
||||
return final_config
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
texts: list[str] | str,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 要生成嵌入向量的文本或文本列表。
|
||||
model: 要使用的嵌入模型名称。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 文本的嵌入向量列表。
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
resolved_model_str = (
|
||||
model or self.config.default_embedding_model or self.config.model
|
||||
)
|
||||
if not resolved_model_str:
|
||||
raise LLMException(
|
||||
"使用 embed 功能时必须指定嵌入模型名称,"
|
||||
"或在 AIConfig 中配置 default_embedding_model。",
|
||||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||||
)
|
||||
resolved_model_str = self._resolve_model_name(resolved_model_str)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_str,
|
||||
override_config=None,
|
||||
) as embedding_model_instance:
|
||||
return await embedding_model_instance.generate_embeddings(
|
||||
texts, task_type=task_type, **kwargs
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"文本嵌入失败: {e}", e=e)
|
||||
raise LLMException(
|
||||
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
|
||||
)
|
||||
|
||||
|
||||
async def chat(
|
||||
message: str | LLMMessage | list[LLMContentPart],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
聊天对话便捷函数
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 要使用的模型名称。
|
||||
tools: 本次对话可用的工具列表。
|
||||
tool_choice: 强制模型使用的工具。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str: 模型的文本响应。
|
||||
LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.chat(message, model=model, **kwargs)
|
||||
return await ai.chat(
|
||||
message, model=model, tools=tools, tool_choice=tool_choice, **kwargs
|
||||
)
|
||||
|
||||
|
||||
async def code(
|
||||
@ -730,12 +242,14 @@ async def pipeline_chat(
|
||||
raise ValueError("模型链`model_chain`不能为空。")
|
||||
|
||||
current_content: str | list[LLMContentPart]
|
||||
if isinstance(message, str):
|
||||
if isinstance(message, UniMessage):
|
||||
current_content = await unimsg_to_llm_parts(message)
|
||||
elif isinstance(message, str):
|
||||
current_content = message
|
||||
elif isinstance(message, list):
|
||||
current_content = message
|
||||
else:
|
||||
current_content = await unimsg_to_llm_parts(message)
|
||||
raise TypeError(f"不支持的消息类型: {type(message)}")
|
||||
|
||||
final_response: LLMResponse | None = None
|
||||
|
||||
@ -787,3 +301,45 @@ async def pipeline_chat(
|
||||
)
|
||||
|
||||
return final_response
|
||||
|
||||
|
||||
async def generate(
|
||||
messages: list[LLMMessage],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
根据完整的消息列表(包括系统指令)生成一次性响应。
|
||||
这是一个便捷的函数,不使用或修改任何会话历史。
|
||||
|
||||
参数:
|
||||
messages: 用于生成响应的完整消息列表。
|
||||
model: 要使用的模型名称。
|
||||
tools: 可用的工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应对象。
|
||||
"""
|
||||
try:
|
||||
ai_instance = AI()
|
||||
resolved_model_name = ai_instance._resolve_model_name(model)
|
||||
final_config_dict = ai_instance._merge_config(kwargs)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_name, override_config=final_config_dict
|
||||
) as model_instance:
|
||||
return await model_instance.generate_response(
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"生成响应失败: {e}", e=e)
|
||||
raise LLMException(f"生成响应失败: {e}", cause=e)
|
||||
|
||||
@ -17,6 +17,7 @@ from zhenxun.configs.utils import parse_as
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from ..core import key_store
|
||||
from ..types.models import ModelDetail, ProviderConfig
|
||||
|
||||
|
||||
@ -502,12 +503,13 @@ def set_default_model(provider_model_name: str | None) -> bool:
|
||||
@PriorityLifecycle.on_startup(priority=10)
|
||||
async def _init_llm_config_on_startup():
|
||||
"""
|
||||
在服务启动时主动调用一次 get_llm_config,
|
||||
以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
|
||||
在服务启动时主动调用一次 get_llm_config 和 key_store.initialize,
|
||||
以触发必要的初始化操作。
|
||||
"""
|
||||
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
|
||||
logger.info("正在初始化 LLM 配置并加载密钥状态...")
|
||||
try:
|
||||
get_llm_config()
|
||||
logger.info("LLM 配置初始化完成。")
|
||||
await key_store.initialize()
|
||||
logger.info("LLM 配置和密钥状态初始化完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
|
||||
logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e)
|
||||
|
||||
@ -5,17 +5,27 @@ LLM 核心基础设施模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import IntEnum
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
import nonebot
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.user_agent import get_user_agent
|
||||
|
||||
from .types import ProviderConfig
|
||||
from .types.exceptions import LLMErrorCode, LLMException
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
class HttpClientConfig(BaseModel):
|
||||
"""HTTP客户端配置"""
|
||||
@ -194,6 +204,82 @@ async def create_llm_http_client(
|
||||
return LLMHttpClient(config)
|
||||
|
||||
|
||||
class KeyStatus(IntEnum):
|
||||
"""用于排序和展示的密钥状态枚举"""
|
||||
|
||||
DISABLED = 0
|
||||
ERROR = 1
|
||||
COOLDOWN = 2
|
||||
WARNING = 3
|
||||
HEALTHY = 4
|
||||
UNUSED = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyStats:
|
||||
"""单个API Key的详细状态和统计信息"""
|
||||
|
||||
cooldown_until: float = 0.0
|
||||
success_count: int = 0
|
||||
failure_count: int = 0
|
||||
total_latency: float = 0.0
|
||||
last_error_info: str | None = None
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""检查Key当前是否可用"""
|
||||
return time.time() >= self.cooldown_until
|
||||
|
||||
@property
|
||||
def avg_latency(self) -> float:
|
||||
"""计算平均延迟"""
|
||||
return (
|
||||
self.total_latency / self.success_count if self.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""计算成功率"""
|
||||
total = self.success_count + self.failure_count
|
||||
return self.success_count / total * 100 if total > 0 else 100.0
|
||||
|
||||
@property
|
||||
def status(self) -> KeyStatus:
|
||||
"""根据当前统计数据动态计算状态"""
|
||||
now = time.time()
|
||||
cooldown_left = max(0, self.cooldown_until - now)
|
||||
|
||||
if cooldown_left > 31536000 - 60:
|
||||
return KeyStatus.DISABLED
|
||||
if cooldown_left > 0:
|
||||
return KeyStatus.COOLDOWN
|
||||
|
||||
total_calls = self.success_count + self.failure_count
|
||||
if total_calls == 0:
|
||||
return KeyStatus.UNUSED
|
||||
|
||||
if self.success_rate < 80:
|
||||
return KeyStatus.ERROR
|
||||
|
||||
if total_calls >= 5 and self.avg_latency > 15000:
|
||||
return KeyStatus.WARNING
|
||||
|
||||
return KeyStatus.HEALTHY
|
||||
|
||||
@property
|
||||
def suggested_action(self) -> str:
|
||||
"""根据状态给出建议操作"""
|
||||
status_actions = {
|
||||
KeyStatus.DISABLED: "更换Key",
|
||||
KeyStatus.ERROR: "检查网络/重置",
|
||||
KeyStatus.COOLDOWN: "等待/重置",
|
||||
KeyStatus.WARNING: "观察",
|
||||
KeyStatus.HEALTHY: "-",
|
||||
KeyStatus.UNUSED: "-",
|
||||
}
|
||||
return status_actions.get(self.status, "未知")
|
||||
|
||||
|
||||
class RetryConfig:
|
||||
"""重试配置"""
|
||||
|
||||
@ -236,25 +322,37 @@ async def with_smart_retry(
|
||||
last_exception: Exception | None = None
|
||||
failed_keys: set[str] = set()
|
||||
|
||||
model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None)
|
||||
all_provider_keys = model_instance.api_keys if model_instance else []
|
||||
|
||||
for attempt in range(config.max_retries + 1):
|
||||
try:
|
||||
if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
|
||||
kwargs["failed_keys"] = failed_keys
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
start_time = time.monotonic()
|
||||
result = await func(*args, **kwargs)
|
||||
latency = (time.monotonic() - start_time) * 1000
|
||||
|
||||
if key_store and isinstance(result, tuple) and len(result) == 2:
|
||||
final_result, api_key_used = result
|
||||
if api_key_used:
|
||||
await key_store.record_success(api_key_used, latency)
|
||||
return final_result
|
||||
else:
|
||||
return result
|
||||
|
||||
except LLMException as e:
|
||||
last_exception = e
|
||||
api_key_in_use = e.details.get("api_key")
|
||||
|
||||
if e.code in [
|
||||
LLMErrorCode.API_KEY_INVALID,
|
||||
LLMErrorCode.API_QUOTA_EXCEEDED,
|
||||
]:
|
||||
if hasattr(e, "details") and e.details and "api_key" in e.details:
|
||||
failed_keys.add(e.details["api_key"])
|
||||
if key_store and provider_name:
|
||||
if api_key_in_use:
|
||||
failed_keys.add(api_key_in_use)
|
||||
if key_store and provider_name and len(all_provider_keys) > 1:
|
||||
status_code = e.details.get("status_code")
|
||||
error_message = f"({e.code.name}) {e.message}"
|
||||
await key_store.record_failure(
|
||||
e.details["api_key"], e.details.get("status_code")
|
||||
api_key_in_use, status_code, error_message
|
||||
)
|
||||
|
||||
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
|
||||
@ -267,7 +365,7 @@ async def with_smart_retry(
|
||||
if config.exponential_backoff:
|
||||
wait_time *= 2**attempt
|
||||
logger.warning(
|
||||
f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
|
||||
f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
@ -325,14 +423,66 @@ def _should_retry_llm_error(
|
||||
|
||||
|
||||
class KeyStatusStore:
|
||||
"""API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""
|
||||
"""API Key 状态管理存储 - 支持持久化"""
|
||||
|
||||
def __init__(self):
|
||||
self._key_status: dict[str, bool] = {}
|
||||
self._key_usage_count: dict[str, int] = {}
|
||||
self._key_last_used: dict[str, float] = {}
|
||||
self._key_stats: dict[str, KeyStats] = {}
|
||||
self._provider_key_index: dict[str, int] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._file_path = DATA_PATH / "llm" / "key_status.json"
|
||||
|
||||
async def initialize(self):
|
||||
"""从文件异步加载密钥状态,在应用启动时调用"""
|
||||
async with self._lock:
|
||||
if not self._file_path.exists():
|
||||
logger.info("未找到密钥状态文件,将使用内存状态启动。")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"正在从 {self._file_path} 加载密钥状态...")
|
||||
async with aiofiles.open(self._file_path, encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
if not content:
|
||||
logger.warning("密钥状态文件为空。")
|
||||
return
|
||||
data = json.loads(content)
|
||||
|
||||
for key, stats_dict in data.items():
|
||||
self._key_stats[key] = KeyStats(**stats_dict)
|
||||
|
||||
logger.info(f"成功加载 {len(self._key_stats)} 个密钥的状态。")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"密钥状态文件 {self._file_path} 格式错误,无法解析。")
|
||||
except Exception as e:
|
||||
logger.error(f"加载密钥状态文件时发生错误: {e}", e=e)
|
||||
|
||||
async def _save_to_file_internal(self):
|
||||
"""
|
||||
[内部方法] 将当前密钥状态安全地写入JSON文件。
|
||||
假定调用方已持有锁。
|
||||
"""
|
||||
data_to_save = {key: asdict(stats) for key, stats in self._key_stats.items()}
|
||||
|
||||
try:
|
||||
self._file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = self._file_path.with_suffix(".json.tmp")
|
||||
|
||||
async with aiofiles.open(temp_path, "w", encoding="utf-8") as f:
|
||||
await f.write(json.dumps(data_to_save, ensure_ascii=False, indent=2))
|
||||
|
||||
if self._file_path.exists():
|
||||
self._file_path.unlink()
|
||||
os.rename(temp_path, self._file_path)
|
||||
logger.debug("密钥状态已成功持久化到文件。")
|
||||
except Exception as e:
|
||||
logger.error(f"保存密钥状态到文件失败: {e}", e=e)
|
||||
|
||||
async def shutdown(self):
|
||||
"""在应用关闭时安全地保存状态"""
|
||||
async with self._lock:
|
||||
await self._save_to_file_internal()
|
||||
logger.info("KeyStatusStore 已在关闭前保存状态。")
|
||||
|
||||
async def get_next_available_key(
|
||||
self,
|
||||
@ -355,88 +505,122 @@ class KeyStatusStore:
|
||||
return None
|
||||
|
||||
exclude_keys = exclude_keys or set()
|
||||
|
||||
async with self._lock:
|
||||
for key in api_keys:
|
||||
if key not in self._key_stats:
|
||||
self._key_stats[key] = KeyStats()
|
||||
|
||||
available_keys = [
|
||||
key
|
||||
for key in api_keys
|
||||
if key not in exclude_keys and self._key_status.get(key, True)
|
||||
if key not in exclude_keys and self._key_stats[key].is_available
|
||||
]
|
||||
|
||||
if not available_keys:
|
||||
return api_keys[0] if api_keys else None
|
||||
return api_keys[0]
|
||||
|
||||
async with self._lock:
|
||||
current_index = self._provider_key_index.get(provider_name, 0)
|
||||
|
||||
selected_key = available_keys[current_index % len(available_keys)]
|
||||
self._provider_key_index[provider_name] = current_index + 1
|
||||
|
||||
self._provider_key_index[provider_name] = (current_index + 1) % len(
|
||||
available_keys
|
||||
total_usage = (
|
||||
self._key_stats[selected_key].success_count
|
||||
+ self._key_stats[selected_key].failure_count
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
self._key_usage_count[selected_key] = (
|
||||
self._key_usage_count.get(selected_key, 0) + 1
|
||||
)
|
||||
self._key_last_used[selected_key] = time.time()
|
||||
|
||||
logger.debug(
|
||||
f"轮询选择API密钥: {self._get_key_id(selected_key)} "
|
||||
f"(使用次数: {self._key_usage_count[selected_key]})"
|
||||
f"(使用次数: {total_usage})"
|
||||
)
|
||||
|
||||
return selected_key
|
||||
|
||||
async def record_success(self, api_key: str):
|
||||
"""记录成功使用"""
|
||||
async def record_success(self, api_key: str, latency: float):
|
||||
"""记录成功使用,并持久化"""
|
||||
async with self._lock:
|
||||
self._key_status[api_key] = True
|
||||
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
|
||||
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||
stats.cooldown_until = 0.0
|
||||
stats.success_count += 1
|
||||
stats.total_latency += latency
|
||||
stats.last_error_info = None
|
||||
await self._save_to_file_internal()
|
||||
logger.debug(
|
||||
f"记录API密钥成功使用: {self._get_key_id(api_key)}, 延迟: {latency:.2f}ms"
|
||||
)
|
||||
|
||||
async def record_failure(self, api_key: str, status_code: int | None):
|
||||
async def record_failure(
|
||||
self, api_key: str, status_code: int | None, error_message: str
|
||||
):
|
||||
"""
|
||||
记录失败使用
|
||||
记录失败使用,并设置冷却时间
|
||||
|
||||
参数:
|
||||
api_key: API密钥。
|
||||
status_code: HTTP状态码。
|
||||
error_message: 错误信息。
|
||||
"""
|
||||
key_id = self._get_key_id(api_key)
|
||||
async with self._lock:
|
||||
if status_code in [401, 403]:
|
||||
self._key_status[api_key] = False
|
||||
logger.warning(
|
||||
f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})"
|
||||
)
|
||||
now = time.time()
|
||||
cooldown_duration = 300
|
||||
|
||||
if status_code in [401, 403, 404]:
|
||||
cooldown_duration = 31536000
|
||||
log_level = "error"
|
||||
log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
|
||||
elif status_code == 429:
|
||||
cooldown_duration = 60
|
||||
log_level = "warning"
|
||||
log_message = f"API密钥被限流,冷却60秒: {key_id}"
|
||||
else:
|
||||
logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
|
||||
log_level = "warning"
|
||||
log_message = f"API密钥遇到临时性错误,冷却{cooldown_duration}秒: {key_id}"
|
||||
|
||||
async with self._lock:
|
||||
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||
stats.cooldown_until = now + cooldown_duration
|
||||
stats.failure_count += 1
|
||||
stats.last_error_info = error_message[:256]
|
||||
await self._save_to_file_internal()
|
||||
|
||||
getattr(logger, log_level)(log_message)
|
||||
|
||||
async def reset_key_status(self, api_key: str):
|
||||
"""重置密钥状态(用于恢复机制)"""
|
||||
"""重置密钥状态,并持久化"""
|
||||
async with self._lock:
|
||||
self._key_status[api_key] = True
|
||||
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||
stats.cooldown_until = 0.0
|
||||
stats.last_error_info = None
|
||||
await self._save_to_file_internal()
|
||||
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
||||
|
||||
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
获取密钥使用统计
|
||||
获取密钥使用统计,并计算出用于展示的派生数据。
|
||||
|
||||
参数:
|
||||
api_keys: API密钥列表。
|
||||
|
||||
返回:
|
||||
dict[str, dict]: 密钥统计信息字典。
|
||||
dict[str, dict]: 包含丰富状态和统计信息的密钥字典。
|
||||
"""
|
||||
stats = {}
|
||||
stats_dict = {}
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
for key in api_keys:
|
||||
key_id = self._get_key_id(key)
|
||||
stats[key_id] = {
|
||||
"available": self._key_status.get(key, True),
|
||||
"usage_count": self._key_usage_count.get(key, 0),
|
||||
"last_used": self._key_last_used.get(key, 0),
|
||||
stats = self._key_stats.get(key, KeyStats())
|
||||
|
||||
stats_dict[key_id] = {
|
||||
"status_enum": stats.status,
|
||||
"cooldown_seconds_left": max(0, stats.cooldown_until - now),
|
||||
"total_calls": stats.success_count + stats.failure_count,
|
||||
"success_count": stats.success_count,
|
||||
"failure_count": stats.failure_count,
|
||||
"success_rate": stats.success_rate,
|
||||
"avg_latency": stats.avg_latency,
|
||||
"last_error": stats.last_error_info,
|
||||
"suggested_action": stats.suggested_action,
|
||||
}
|
||||
return stats
|
||||
return stats_dict
|
||||
|
||||
def _get_key_id(self, api_key: str) -> str:
|
||||
"""获取API密钥的标识符(用于日志)"""
|
||||
@ -446,3 +630,8 @@ class KeyStatusStore:
|
||||
|
||||
|
||||
key_store = KeyStatusStore()
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _shutdown_key_store():
|
||||
await key_store.shutdown()
|
||||
|
||||
@ -137,8 +137,8 @@ def get_configured_providers() -> list[ProviderConfig]:
|
||||
valid_providers.append(item)
|
||||
else:
|
||||
logger.warning(
|
||||
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
|
||||
f"已跳过。实际类型: {type(item)}"
|
||||
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,已跳过。"
|
||||
f"实际类型: {type(item)}"
|
||||
)
|
||||
|
||||
return valid_providers
|
||||
|
||||
@ -46,17 +46,7 @@ class LLMModelBase(ABC):
|
||||
history: list[dict[str, str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
生成文本
|
||||
|
||||
参数:
|
||||
prompt: 输入提示词。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
str: 生成的文本。
|
||||
"""
|
||||
"""生成文本"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -68,19 +58,7 @@ class LLMModelBase(ABC):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
生成高级响应
|
||||
|
||||
参数:
|
||||
messages: 消息列表。
|
||||
config: 生成配置。
|
||||
tools: 工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型响应。
|
||||
"""
|
||||
"""生成高级响应"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -90,17 +68,7 @@ class LLMModelBase(ABC):
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 文本列表。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 嵌入向量列表。
|
||||
"""
|
||||
"""生成文本嵌入向量"""
|
||||
pass
|
||||
|
||||
|
||||
@ -208,28 +176,8 @@ class LLMModel(LLMModelBase):
|
||||
http_client: "LLMHttpClient",
|
||||
failed_keys: set[str] | None = None,
|
||||
log_context: str = "API",
|
||||
) -> Any:
|
||||
"""
|
||||
执行API调用的通用核心方法。
|
||||
|
||||
该方法封装了以下通用逻辑:
|
||||
1. 选择API密钥。
|
||||
2. 准备和记录请求。
|
||||
3. 发送HTTP POST请求。
|
||||
4. 处理HTTP错误和API特定错误。
|
||||
5. 记录密钥使用状态。
|
||||
6. 解析成功的响应。
|
||||
|
||||
参数:
|
||||
prepare_request_func: 准备请求的函数。
|
||||
parse_response_func: 解析响应的函数。
|
||||
http_client: HTTP客户端。
|
||||
failed_keys: 失败的密钥集合。
|
||||
log_context: 日志上下文。
|
||||
|
||||
返回:
|
||||
Any: 解析后的响应数据。
|
||||
"""
|
||||
) -> tuple[Any, str]:
|
||||
"""执行API调用的通用核心方法"""
|
||||
api_key = await self._select_api_key(failed_keys)
|
||||
|
||||
try:
|
||||
@ -267,7 +215,9 @@ class LLMModel(LLMModelBase):
|
||||
)
|
||||
logger.debug(f"💥 完整错误响应: {error_text}")
|
||||
|
||||
await self.key_store.record_failure(api_key, http_response.status_code)
|
||||
await self.key_store.record_failure(
|
||||
api_key, http_response.status_code, error_text
|
||||
)
|
||||
|
||||
if http_response.status_code in [401, 403]:
|
||||
error_code = LLMErrorCode.API_KEY_INVALID
|
||||
@ -298,7 +248,7 @@ class LLMModel(LLMModelBase):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
await self.key_store.record_failure(api_key, None, str(e))
|
||||
if isinstance(e, LLMException):
|
||||
raise
|
||||
else:
|
||||
@ -308,17 +258,15 @@ class LLMModel(LLMModelBase):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
await self.key_store.record_success(api_key)
|
||||
logger.debug(f"✅ API密钥使用成功: {masked_key}")
|
||||
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
|
||||
return parsed_data
|
||||
return parsed_data, api_key
|
||||
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
|
||||
logger.error(error_log_msg, e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
await self.key_store.record_failure(api_key, None, str(e))
|
||||
raise LLMException(
|
||||
error_log_msg,
|
||||
code=LLMErrorCode.GENERATION_FAILED
|
||||
@ -349,13 +297,14 @@ class LLMModel(LLMModelBase):
|
||||
adapter.validate_embedding_response(response_json)
|
||||
return adapter.parse_embedding_response(response_json)
|
||||
|
||||
return await self._perform_api_call(
|
||||
parsed_data, api_key_used = await self._perform_api_call(
|
||||
prepare_request_func=prepare_request,
|
||||
parse_response_func=parse_response,
|
||||
http_client=http_client,
|
||||
failed_keys=failed_keys,
|
||||
log_context="Embedding",
|
||||
)
|
||||
return parsed_data
|
||||
|
||||
async def _execute_with_smart_retry(
|
||||
self,
|
||||
@ -394,8 +343,8 @@ class LLMModel(LLMModelBase):
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
http_client: LLMHttpClient,
|
||||
failed_keys: set[str] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
|
||||
) -> tuple[LLMResponse, str]:
|
||||
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse 和使用的 key"""
|
||||
|
||||
async def prepare_request(api_key: str) -> RequestData:
|
||||
return await adapter.prepare_advanced_request(
|
||||
@ -441,19 +390,17 @@ class LLMModel(LLMModelBase):
|
||||
cache_info=response_data.cache_info,
|
||||
)
|
||||
|
||||
return await self._perform_api_call(
|
||||
parsed_data, api_key_used = await self._perform_api_call(
|
||||
prepare_request_func=prepare_request,
|
||||
parse_response_func=parse_response,
|
||||
http_client=http_client,
|
||||
failed_keys=failed_keys,
|
||||
log_context="Generation",
|
||||
)
|
||||
return parsed_data, api_key_used
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
标记模型实例的当前使用周期结束。
|
||||
共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。
|
||||
"""
|
||||
"""标记模型实例的当前使用周期结束"""
|
||||
if self._is_closed:
|
||||
return
|
||||
self._is_closed = True
|
||||
@ -487,17 +434,7 @@ class LLMModel(LLMModelBase):
|
||||
history: list[dict[str, str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
生成文本 - 通过 generate_response 实现
|
||||
|
||||
参数:
|
||||
prompt: 输入提示词。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
str: 生成的文本。
|
||||
"""
|
||||
"""生成文本"""
|
||||
self._check_not_closed()
|
||||
|
||||
messages: list[LLMMessage] = []
|
||||
@ -538,19 +475,7 @@ class LLMModel(LLMModelBase):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
生成高级响应
|
||||
|
||||
参数:
|
||||
messages: 消息列表。
|
||||
config: 生成配置。
|
||||
tools: 工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型响应。
|
||||
"""
|
||||
"""生成高级响应"""
|
||||
self._check_not_closed()
|
||||
|
||||
from .adapters import get_adapter_for_api_type
|
||||
@ -619,17 +544,7 @@ class LLMModel(LLMModelBase):
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 文本列表。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 嵌入向量列表。
|
||||
"""
|
||||
"""生成文本嵌入向量"""
|
||||
self._check_not_closed()
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
532
zhenxun/services/llm/session.py
Normal file
532
zhenxun/services/llm/session.py
Normal file
@ -0,0 +1,532 @@
|
||||
"""
|
||||
LLM 服务 - 会话客户端
|
||||
|
||||
提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import CommonOverrides, LLMGenerationConfig
|
||||
from .config.providers import get_ai_config
|
||||
from .manager import get_global_default_model_name, get_model_instance
|
||||
from .tools import tool_registry
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
LLMContentPart,
|
||||
LLMErrorCode,
|
||||
LLMException,
|
||||
LLMMessage,
|
||||
LLMResponse,
|
||||
LLMTool,
|
||||
ModelName,
|
||||
)
|
||||
from .utils import unimsg_to_llm_parts
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""AI配置类 - 简化版本"""
|
||||
|
||||
model: ModelName = None
|
||||
default_embedding_model: ModelName = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
enable_cache: bool = False
|
||||
enable_code: bool = False
|
||||
enable_search: bool = False
|
||||
timeout: int | None = None
|
||||
|
||||
enable_gemini_json_mode: bool = False
|
||||
enable_gemini_thinking: bool = False
|
||||
enable_gemini_safe_mode: bool = False
|
||||
enable_gemini_multimodal: bool = False
|
||||
enable_gemini_grounding: bool = False
|
||||
default_preserve_media_in_history: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后从配置中读取默认值"""
|
||||
ai_config = get_ai_config()
|
||||
if self.model is None:
|
||||
self.model = ai_config.get("default_model_name")
|
||||
if self.timeout is None:
|
||||
self.timeout = ai_config.get("timeout", 180)
|
||||
|
||||
|
||||
class AI:
|
||||
"""统一的AI服务类 - 平衡设计版本
|
||||
|
||||
提供三层API:
|
||||
1. 简单方法:ai.chat(), ai.code(), ai.search()
|
||||
2. 标准方法:ai.analyze() 支持复杂参数
|
||||
3. 高级方法:通过get_model_instance()直接访问
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
|
||||
):
|
||||
"""
|
||||
初始化AI服务
|
||||
|
||||
参数:
|
||||
config: AI 配置.
|
||||
history: 可选的初始对话历史.
|
||||
"""
|
||||
self.config = config or AIConfig()
|
||||
self.history = history or []
|
||||
|
||||
def clear_history(self):
|
||||
"""清空当前会话的历史记录"""
|
||||
self.history = []
|
||||
logger.info("AI session history cleared.")
|
||||
|
||||
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||
"""
|
||||
净化用于存入历史记录的消息。
|
||||
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||
"""
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
sanitized_message = copy.deepcopy(message)
|
||||
content_list = sanitized_message.content
|
||||
if not isinstance(content_list, list):
|
||||
return sanitized_message
|
||||
|
||||
new_content_parts: list[LLMContentPart] = []
|
||||
has_multimodal_content = False
|
||||
|
||||
for part in content_list:
|
||||
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||
new_content_parts.append(part)
|
||||
else:
|
||||
has_multimodal_content = True
|
||||
|
||||
if has_multimodal_content:
|
||||
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||
text_part_found = False
|
||||
for part in new_content_parts:
|
||||
if part.type == "text":
|
||||
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||
text_part_found = True
|
||||
break
|
||||
if not text_part_found:
|
||||
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||
|
||||
sanitized_message.content = new_content_parts
|
||||
return sanitized_message
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str | LLMMessage | list[LLMContentPart],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
preserve_media_in_history: bool | None = None,
|
||||
tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
进行一次聊天对话,支持工具调用。
|
||||
此方法会自动使用和更新会话内的历史记录。
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 本次对话要使用的模型。
|
||||
preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
|
||||
- True: 保留,用于深度多轮媒体分析。
|
||||
- False: 不保留,替换为占位符,提高效率。
|
||||
- None (默认): 使用AI实例配置的默认值。
|
||||
tools: 本次对话可用的工具列表。
|
||||
tool_choice: 强制模型使用的工具。
|
||||
**kwargs: 传递给模型的其他生成参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。
|
||||
"""
|
||||
current_message: LLMMessage
|
||||
if isinstance(message, str):
|
||||
current_message = LLMMessage.user(message)
|
||||
elif isinstance(message, list) and all(
|
||||
isinstance(part, LLMContentPart) for part in message
|
||||
):
|
||||
current_message = LLMMessage.user(message)
|
||||
elif isinstance(message, LLMMessage):
|
||||
current_message = message
|
||||
else:
|
||||
raise LLMException(
|
||||
f"AI.chat 不支持的消息类型: {type(message)}. "
|
||||
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
|
||||
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
|
||||
code=LLMErrorCode.API_REQUEST_FAILED,
|
||||
)
|
||||
|
||||
final_messages = [*self.history, current_message]
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages=final_messages,
|
||||
model_name=model,
|
||||
error_message="聊天失败",
|
||||
config_overrides=kwargs,
|
||||
llm_tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
should_preserve = (
|
||||
preserve_media_in_history
|
||||
if preserve_media_in_history is not None
|
||||
else self.config.default_preserve_media_in_history
|
||||
)
|
||||
|
||||
if should_preserve:
|
||||
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
|
||||
self.history.append(current_message)
|
||||
else:
|
||||
logger.debug("高效模式:净化历史记录中的多模态消息。")
|
||||
sanitized_user_message = self._sanitize_message_for_history(current_message)
|
||||
self.history.append(sanitized_user_message)
|
||||
|
||||
self.history.append(
|
||||
LLMMessage(
|
||||
role="assistant", content=response.text, tool_calls=response.tool_calls
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def code(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
代码执行
|
||||
|
||||
参数:
|
||||
prompt: 代码执行的提示词。
|
||||
model: 要使用的模型名称。
|
||||
timeout: 代码执行超时时间(秒)。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
|
||||
config = CommonOverrides.gemini_code_execution()
|
||||
if timeout:
|
||||
config.custom_params = config.custom_params or {}
|
||||
config.custom_params["code_execution_timeout"] = timeout
|
||||
|
||||
messages = [LLMMessage.user(prompt)]
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages=messages,
|
||||
model_name=resolved_model,
|
||||
error_message="代码执行失败",
|
||||
config_overrides=kwargs,
|
||||
base_config=config,
|
||||
)
|
||||
|
||||
return {
|
||||
"text": response.text,
|
||||
"code_executions": response.code_executions or [],
|
||||
"success": True,
|
||||
}
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str | UniMessage,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
instruction: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
信息搜索 - 支持多模态输入
|
||||
|
||||
参数:
|
||||
query: 搜索查询内容,支持文本或多模态消息。
|
||||
model: 要使用的模型名称。
|
||||
instruction: 搜索指令。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
|
||||
"""
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
config = CommonOverrides.gemini_grounding()
|
||||
|
||||
if isinstance(query, str):
|
||||
messages = [LLMMessage.user(query)]
|
||||
elif isinstance(query, UniMessage):
|
||||
content_parts = await unimsg_to_llm_parts(query)
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
else:
|
||||
raise LLMException(
|
||||
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
messages = final_messages
|
||||
else:
|
||||
raise LLMException(
|
||||
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
|
||||
code=LLMErrorCode.API_REQUEST_FAILED,
|
||||
)
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages=messages,
|
||||
model_name=resolved_model,
|
||||
error_message="信息搜索失败",
|
||||
config_overrides=kwargs,
|
||||
base_config=config,
|
||||
)
|
||||
|
||||
result = {
|
||||
"text": response.text,
|
||||
"sources": [],
|
||||
"queries": [],
|
||||
"success": True,
|
||||
}
|
||||
|
||||
if response.grounding_metadata:
|
||||
result["sources"] = response.grounding_metadata.grounding_attributions or []
|
||||
result["queries"] = response.grounding_metadata.web_search_queries or []
|
||||
|
||||
return result
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
message: UniMessage | None,
|
||||
*,
|
||||
instruction: str = "",
|
||||
model: ModelName = None,
|
||||
use_tools: list[str] | None = None,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
activated_tools: list[LLMTool] | None = None,
|
||||
history: list[LLMMessage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
||||
|
||||
参数:
|
||||
message: 要分析的消息内容(支持多模态)。
|
||||
instruction: 分析指令。
|
||||
model: 要使用的模型名称。
|
||||
use_tools: 要使用的工具名称列表。
|
||||
tool_config: 工具配置。
|
||||
activated_tools: 已激活的工具列表。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应结果。
|
||||
"""
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
content_parts = await unimsg_to_llm_parts(message or UniMessage())
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if history:
|
||||
final_messages.extend(history)
|
||||
|
||||
if instruction:
|
||||
if not any(msg.role == "system" for msg in final_messages):
|
||||
final_messages.insert(0, LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction and not history:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
elif not history:
|
||||
raise LLMException(
|
||||
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
llm_tools: list[LLMTool] | None = activated_tools
|
||||
if not llm_tools and use_tools:
|
||||
try:
|
||||
llm_tools = tool_registry.get_tools(use_tools)
|
||||
logger.debug(f"已从注册表加载工具定义: {use_tools}")
|
||||
except ValueError as e:
|
||||
raise LLMException(
|
||||
f"加载工具定义失败: {e}",
|
||||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
tool_choice = None
|
||||
if tool_config:
|
||||
mode = tool_config.get("mode", "auto")
|
||||
if mode in ["auto", "any", "none"]:
|
||||
tool_choice = mode
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages=final_messages,
|
||||
model_name=model,
|
||||
error_message="内容分析失败",
|
||||
config_overrides=kwargs,
|
||||
llm_tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _execute_generation(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
model_name: ModelName,
|
||||
error_message: str,
|
||||
config_overrides: dict[str, Any],
|
||||
llm_tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> LLMResponse:
|
||||
"""通用的生成执行方法,封装模型获取和单次API调用"""
|
||||
try:
|
||||
resolved_model_name = self._resolve_model_name(
|
||||
model_name or self.config.model
|
||||
)
|
||||
final_config_dict = self._merge_config(
|
||||
config_overrides, base_config=base_config
|
||||
)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_name, override_config=final_config_dict
|
||||
) as model_instance:
|
||||
return await model_instance.generate_response(
|
||||
messages,
|
||||
tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{error_message}: {e}", e=e)
|
||||
raise LLMException(f"{error_message}: {e}", cause=e)
|
||||
|
||||
def _resolve_model_name(self, model_name: ModelName) -> str:
|
||||
"""解析模型名称"""
|
||||
if model_name:
|
||||
return model_name
|
||||
|
||||
default_model = get_global_default_model_name()
|
||||
if default_model:
|
||||
return default_model
|
||||
|
||||
raise LLMException(
|
||||
"未指定模型名称且未设置全局默认模型",
|
||||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||||
)
|
||||
|
||||
def _merge_config(
|
||||
self,
|
||||
user_config: dict[str, Any],
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""合并配置"""
|
||||
final_config = {}
|
||||
if base_config:
|
||||
final_config.update(base_config.to_dict())
|
||||
|
||||
if self.config.temperature is not None:
|
||||
final_config["temperature"] = self.config.temperature
|
||||
if self.config.max_tokens is not None:
|
||||
final_config["max_tokens"] = self.config.max_tokens
|
||||
|
||||
if self.config.enable_cache:
|
||||
final_config["enable_caching"] = True
|
||||
if self.config.enable_code:
|
||||
final_config["enable_code_execution"] = True
|
||||
if self.config.enable_search:
|
||||
final_config["enable_grounding"] = True
|
||||
|
||||
if self.config.enable_gemini_json_mode:
|
||||
final_config["response_mime_type"] = "application/json"
|
||||
if self.config.enable_gemini_thinking:
|
||||
final_config["thinking_budget"] = 0.8
|
||||
if self.config.enable_gemini_safe_mode:
|
||||
final_config["safety_settings"] = (
|
||||
CommonOverrides.gemini_safe().safety_settings
|
||||
)
|
||||
if self.config.enable_gemini_multimodal:
|
||||
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
|
||||
if self.config.enable_gemini_grounding:
|
||||
final_config["enable_grounding"] = True
|
||||
|
||||
final_config.update(user_config)
|
||||
|
||||
return final_config
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
texts: list[str] | str,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 要生成嵌入向量的文本或文本列表。
|
||||
model: 要使用的嵌入模型名称。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 文本的嵌入向量列表。
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
resolved_model_str = (
|
||||
model or self.config.default_embedding_model or self.config.model
|
||||
)
|
||||
if not resolved_model_str:
|
||||
raise LLMException(
|
||||
"使用 embed 功能时必须指定嵌入模型名称,"
|
||||
"或在 AIConfig 中配置 default_embedding_model。",
|
||||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||||
)
|
||||
resolved_model_str = self._resolve_model_name(resolved_model_str)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_str,
|
||||
override_config=None,
|
||||
) as embedding_model_instance:
|
||||
return await embedding_model_instance.generate_embeddings(
|
||||
texts, task_type=task_type, **kwargs
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"文本嵌入失败: {e}", e=e)
|
||||
raise LLMException(
|
||||
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
|
||||
)
|
||||
@ -10,7 +10,13 @@ from .content import (
|
||||
LLMMessage,
|
||||
LLMResponse,
|
||||
)
|
||||
from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory
|
||||
from .enums import (
|
||||
EmbeddingTaskType,
|
||||
ModelProvider,
|
||||
ResponseFormat,
|
||||
TaskType,
|
||||
ToolCategory,
|
||||
)
|
||||
from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
|
||||
from .models import (
|
||||
LLMCacheInfo,
|
||||
@ -52,6 +58,7 @@ __all__ = [
|
||||
"ModelProvider",
|
||||
"ProviderConfig",
|
||||
"ResponseFormat",
|
||||
"TaskType",
|
||||
"ToolCategory",
|
||||
"ToolMetadata",
|
||||
"UsageInfo",
|
||||
|
||||
@ -45,6 +45,17 @@ class ToolCategory(Enum):
|
||||
CUSTOM = auto()
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""任务类型枚举"""
|
||||
|
||||
CHAT = "chat"
|
||||
CODE = "code"
|
||||
SEARCH = "search"
|
||||
ANALYSIS = "analysis"
|
||||
GENERATION = "generation"
|
||||
MULTIMODAL = "multimodal"
|
||||
|
||||
|
||||
class LLMErrorCode(Enum):
|
||||
"""LLM 服务相关的错误代码枚举"""
|
||||
|
||||
|
||||
@ -469,7 +469,7 @@ class Notebook:
|
||||
template_name="main.html",
|
||||
templates={"elements": self._data},
|
||||
pages={
|
||||
"viewport": {"width": 700, "height": 1000},
|
||||
"viewport": {"width": 700, "height": 10},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
|
||||
@ -53,9 +53,7 @@ class CommonUtils:
|
||||
if await GroupConsole.is_block_task(group_id, module):
|
||||
"""群组是否禁用被动"""
|
||||
return True
|
||||
if g := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
if g := await GroupConsole.get_group(group_id=group_id):
|
||||
"""群组权限是否小于0"""
|
||||
if g.level < 0:
|
||||
return True
|
||||
|
||||
@ -44,6 +44,44 @@ class EventLogType(StrEnum):
|
||||
"""主动退群"""
|
||||
|
||||
|
||||
class CacheType(StrEnum):
|
||||
"""
|
||||
缓存类型
|
||||
"""
|
||||
|
||||
PLUGINS = "GLOBAL_ALL_PLUGINS"
|
||||
"""全局全部插件"""
|
||||
GROUPS = "GLOBAL_ALL_GROUPS"
|
||||
"""全局全部群组"""
|
||||
USERS = "GLOBAL_ALL_USERS"
|
||||
"""全部用户"""
|
||||
BAN = "GLOBAL_ALL_BAN"
|
||||
"""全局ban列表"""
|
||||
BOT = "GLOBAL_BOT"
|
||||
"""全局bot信息"""
|
||||
LEVEL = "GLOBAL_USER_LEVEL"
|
||||
"""用户权限"""
|
||||
LIMIT = "GLOBAL_LIMIT"
|
||||
"""插件限制"""
|
||||
|
||||
|
||||
class DbLockType(StrEnum):
|
||||
"""
|
||||
锁类型
|
||||
"""
|
||||
|
||||
CREATE = "CREATE"
|
||||
"""创建"""
|
||||
DELETE = "DELETE"
|
||||
"""删除"""
|
||||
UPDATE = "UPDATE"
|
||||
"""更新"""
|
||||
QUERY = "QUERY"
|
||||
"""查询"""
|
||||
UPSERT = "UPSERT"
|
||||
"""创建或更新"""
|
||||
|
||||
|
||||
class GoldHandle(StrEnum):
|
||||
"""
|
||||
金币处理
|
||||
|
||||
@ -49,6 +49,9 @@ async def _():
|
||||
try:
|
||||
for priority in priority_list:
|
||||
for func in priority_data[priority]:
|
||||
logger.debug(
|
||||
f"执行优先级 [{priority}] on_startup 方法: {func.__module__}"
|
||||
)
|
||||
if is_coroutine_callable(func):
|
||||
await func()
|
||||
else:
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import httpx
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
import pypinyin
|
||||
import pytz
|
||||
|
||||
@ -14,43 +16,53 @@ from zhenxun.configs.config import Config
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIDs:
|
||||
user_id: str
|
||||
"""用户id"""
|
||||
group_id: str | None
|
||||
"""群组id"""
|
||||
channel_id: str | None
|
||||
"""频道id"""
|
||||
|
||||
|
||||
class ResourceDirManager:
|
||||
"""
|
||||
临时文件管理器
|
||||
"""
|
||||
|
||||
temp_path = [] # noqa: RUF012
|
||||
temp_path: ClassVar[set[Path]] = set()
|
||||
|
||||
@classmethod
|
||||
def __tree_append(cls, path: Path):
|
||||
"""递归添加文件夹
|
||||
|
||||
参数:
|
||||
path: 文件夹路径
|
||||
"""
|
||||
def __tree_append(cls, path: Path, deep: int = 1, current: int = 0):
|
||||
"""递归添加文件夹"""
|
||||
if current >= deep and deep != -1:
|
||||
return
|
||||
path = path.resolve() # 标准化路径
|
||||
for f in os.listdir(path):
|
||||
file = path / f
|
||||
file = (path / f).resolve() # 标准化子路径
|
||||
if file.is_dir():
|
||||
if file not in cls.temp_path:
|
||||
cls.temp_path.append(file)
|
||||
logger.debug(f"添加临时文件夹: {path}")
|
||||
cls.__tree_append(file)
|
||||
cls.temp_path.add(file)
|
||||
logger.debug(f"添加临时文件夹: {file}")
|
||||
cls.__tree_append(file, deep, current + 1)
|
||||
|
||||
@classmethod
|
||||
def add_temp_dir(cls, path: str | Path, tree: bool = False):
|
||||
def add_temp_dir(cls, path: str | Path, tree: bool = False, deep: int = 1):
|
||||
"""添加临时清理文件夹,这些文件夹会被自动清理
|
||||
|
||||
参数:
|
||||
path: 文件夹路径
|
||||
tree: 是否递归添加文件夹
|
||||
deep: 深度, -1 为无限深度
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if path not in cls.temp_path:
|
||||
cls.temp_path.append(path)
|
||||
cls.temp_path.add(path)
|
||||
logger.debug(f"添加临时文件夹: {path}")
|
||||
if tree:
|
||||
cls.__tree_append(path)
|
||||
cls.__tree_append(path, deep)
|
||||
|
||||
|
||||
class CountLimiter:
|
||||
@ -231,6 +243,27 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_entity_ids(session: Uninfo) -> EntityIDs:
|
||||
"""获取用户id,群组id,频道id
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
返回:
|
||||
EntityIDs: 用户id,群组id,频道id
|
||||
"""
|
||||
user_id = session.user.id
|
||||
group_id = None
|
||||
channel_id = None
|
||||
if session.group:
|
||||
if session.group.parent:
|
||||
group_id = session.group.parent.id
|
||||
channel_id = session.group.id
|
||||
else:
|
||||
group_id = session.group.id
|
||||
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)
|
||||
|
||||
|
||||
def is_number(text: str) -> bool:
|
||||
"""是否为数字
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user