mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
Compare commits
3 Commits
f94121080f
...
1cc18bb195
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cc18bb195 | ||
|
|
74a9f3a843 | ||
|
|
e7f3c210df |
@ -58,5 +58,14 @@ Config.add_plugin_config(
|
||||
type=bool,
|
||||
)
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"AUTH_HOOKS_CONCURRENCY_LIMIT",
|
||||
5,
|
||||
help="同步进入权限钩子最大并发数",
|
||||
default_value=5,
|
||||
type=int,
|
||||
)
|
||||
|
||||
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
@ -96,7 +96,6 @@ async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
# 超时时返回0,避免阻塞
|
||||
return 0
|
||||
|
||||
# 检查记录并计算ban时间
|
||||
@ -199,7 +198,7 @@ async def group_handle(group_id: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
||||
async def user_handle(plugin: PluginInfo, entity: EntityIDs, session: Uninfo) -> None:
|
||||
"""用户ban检查
|
||||
|
||||
参数:
|
||||
@ -217,22 +216,12 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
||||
if not time_val:
|
||||
return
|
||||
time_str = format_time(time_val)
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
try:
|
||||
db_plugin = await asyncio.wait_for(
|
||||
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
raise SkipPluginException("用户处于黑名单中...")
|
||||
|
||||
if (
|
||||
db_plugin
|
||||
and not db_plugin.ignore_prompt
|
||||
plugin
|
||||
and time_val != -1
|
||||
and ban_result
|
||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||
and freq.is_send_limit_message(plugin, entity.user_id, False)
|
||||
):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
@ -260,7 +249,9 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
||||
async def auth_ban(
|
||||
matcher: Matcher, bot: Bot, session: Uninfo, plugin: PluginInfo
|
||||
) -> None:
|
||||
"""权限检查 - ban 检查
|
||||
|
||||
参数:
|
||||
@ -289,7 +280,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
||||
if entity.user_id:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_handle(matcher.plugin_name, entity, session),
|
||||
user_handle(plugin, entity, session),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@ -1,50 +1,36 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.utils import EntityIDs
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||
async def auth_group(
|
||||
plugin: PluginInfo,
|
||||
group: GroupConsole | None,
|
||||
message: UniMsg,
|
||||
group_id: str | None,
|
||||
):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
entity: EntityIDs
|
||||
group: GroupConsole
|
||||
message: UniMsg
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not entity.group_id:
|
||||
if not group_id:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
text = message.extract_plain_text()
|
||||
|
||||
# 从数据库或缓存中获取群组信息
|
||||
group_dao = DataAccess(GroupConsole)
|
||||
|
||||
try:
|
||||
group: GroupConsole | None = await asyncio.wait_for(
|
||||
group_dao.safe_get_or_none(
|
||||
group_id=entity.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
if not group:
|
||||
raise SkipPluginException("群组信息不存在...")
|
||||
if group.level < 0:
|
||||
@ -63,6 +49,5 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||
logger.warning(
|
||||
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=entity.user_id,
|
||||
group_id=entity.group_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@ -6,12 +6,10 @@ from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import IsSuperuserException, SkipPluginException
|
||||
@ -20,30 +18,17 @@ from .utils import freq, is_poke, send_message
|
||||
|
||||
class GroupCheck:
|
||||
def __init__(
|
||||
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
||||
self, plugin: PluginInfo, group: GroupConsole, session: Uninfo, is_poke: bool
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.plugin = plugin
|
||||
self.group_dao = DataAccess(GroupConsole)
|
||||
self.group_data = None
|
||||
self.group_data = group
|
||||
self.group_id = group.group_id
|
||||
|
||||
async def check(self):
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 只查询一次数据库,使用 DataAccess 的缓存机制
|
||||
try:
|
||||
self.group_data = await asyncio.wait_for(
|
||||
self.group_dao.safe_get_or_none(
|
||||
group_id=self.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||
return # 超时时不阻塞,继续执行
|
||||
|
||||
# 检查超级用户禁用
|
||||
if (
|
||||
self.group_data
|
||||
@ -113,12 +98,13 @@ class GroupCheck:
|
||||
|
||||
|
||||
class PluginCheck:
|
||||
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
||||
def __init__(self, group: GroupConsole | None, session: Uninfo, is_poke: bool):
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.group_id = group_id
|
||||
self.group_dao = DataAccess(GroupConsole)
|
||||
self.group_data = None
|
||||
self.group_data = group
|
||||
self.group_id = None
|
||||
if group:
|
||||
self.group_id = group.group_id
|
||||
|
||||
async def check_user(self, plugin: PluginInfo):
|
||||
"""全局私聊禁用检测
|
||||
@ -156,21 +142,8 @@ class PluginCheck:
|
||||
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||
return
|
||||
"""全局状态"""
|
||||
if self.group_id:
|
||||
# 使用 DataAccess 的缓存机制
|
||||
try:
|
||||
self.group_data = await asyncio.wait_for(
|
||||
self.group_dao.safe_get_or_none(
|
||||
group_id=self.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||
return # 超时时不阻塞,继续执行
|
||||
|
||||
if self.group_data and self.group_data.is_super:
|
||||
raise IsSuperuserException()
|
||||
if self.group_data and self.group_data.is_super:
|
||||
raise IsSuperuserException()
|
||||
|
||||
sid = self.group_id or self.session.user.id
|
||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||
@ -193,7 +166,9 @@ class PluginCheck:
|
||||
)
|
||||
|
||||
|
||||
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||
async def auth_plugin(
|
||||
plugin: PluginInfo, group: GroupConsole | None, session: Uninfo, event: Event
|
||||
):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
@ -203,35 +178,23 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
entity = get_entity_ids(session)
|
||||
is_poke_event = is_poke(event)
|
||||
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
||||
user_check = PluginCheck(group, session, is_poke_event)
|
||||
|
||||
if entity.group_id:
|
||||
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
tasks = []
|
||||
if group:
|
||||
tasks.append(GroupCheck(plugin, group, session, is_poke_event).check())
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("用户检查超时", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
tasks.append(user_check.check_user(plugin))
|
||||
tasks.append(user_check.check_global(plugin))
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("全局检查超时", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
logger.error("插件用户/群组/全局检查超时...", LOGGER_COMMAND)
|
||||
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
@ -85,7 +85,7 @@ class FreqUtils:
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
||||
return False if plugin.ignore_prompt else self._flmt_s.check(sid)
|
||||
|
||||
|
||||
freq = FreqUtils()
|
||||
|
||||
@ -8,6 +8,7 @@ from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
@ -31,6 +32,7 @@ from .auth.exception import (
|
||||
PermissionExemption,
|
||||
SkipPluginException,
|
||||
)
|
||||
from .auth.utils import base_config
|
||||
|
||||
# 超时设置(秒)
|
||||
TIMEOUT_SECONDS = 5.0
|
||||
@ -46,6 +48,16 @@ CIRCUIT_BREAKERS = {
|
||||
# 熔断重置时间(秒)
|
||||
CIRCUIT_RESET_TIME = 300 # 5分钟
|
||||
|
||||
# 并发控制:限制同时进入 hooks 并行检查的协程数
|
||||
|
||||
# 默认为 6,可通过环境变量 AUTH_HOOKS_CONCURRENCY_LIMIT 调整
|
||||
HOOKS_CONCURRENCY_LIMIT = base_config.get("AUTH_HOOKS_CONCURRENCY_LIMIT")
|
||||
|
||||
# 全局信号量与计数器
|
||||
HOOKS_SEMAPHORE = asyncio.Semaphore(HOOKS_CONCURRENCY_LIMIT)
|
||||
HOOKS_ACTIVE_COUNT = 0
|
||||
HOOKS_ACTIVE_LOCK = asyncio.Lock()
|
||||
|
||||
|
||||
# 超时装饰器
|
||||
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
|
||||
@ -259,6 +271,30 @@ async def time_hook(coro, name, time_dict):
|
||||
time_dict[name] = f"{time.time() - start:.3f}s"
|
||||
|
||||
|
||||
async def _enter_hooks_section():
|
||||
"""尝试获取全局信号量并更新计数器,超时则抛出 PermissionExemption。"""
|
||||
global HOOKS_ACTIVE_COUNT
|
||||
# 队列模式:如果达到上限,协程将排队等待直到获取到信号量
|
||||
await HOOKS_SEMAPHORE.acquire()
|
||||
async with HOOKS_ACTIVE_LOCK:
|
||||
HOOKS_ACTIVE_COUNT += 1
|
||||
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
|
||||
|
||||
|
||||
async def _leave_hooks_section():
|
||||
"""释放信号量并更新计数器。"""
|
||||
global HOOKS_ACTIVE_COUNT
|
||||
from contextlib import suppress
|
||||
|
||||
with suppress(Exception):
|
||||
HOOKS_SEMAPHORE.release()
|
||||
async with HOOKS_ACTIVE_LOCK:
|
||||
HOOKS_ACTIVE_COUNT -= 1
|
||||
# 保证计数不为负
|
||||
HOOKS_ACTIVE_COUNT = max(HOOKS_ACTIVE_COUNT, 0)
|
||||
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
|
||||
|
||||
|
||||
async def auth(
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
@ -285,6 +321,9 @@ async def auth(
|
||||
hook_times = {}
|
||||
hooks_time = 0 # 初始化 hooks_time 变量
|
||||
|
||||
# 记录是否已进入 hooks 区域(用于 finally 中释放)
|
||||
entered_hooks = False
|
||||
|
||||
try:
|
||||
if not module:
|
||||
raise PermissionExemption("Matcher插件名称不存在...")
|
||||
@ -304,6 +343,10 @@ async def auth(
|
||||
)
|
||||
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
|
||||
|
||||
# 进入 hooks 并行检查区域(会在高并发时排队)
|
||||
await _enter_hooks_section()
|
||||
entered_hooks = True
|
||||
|
||||
# 获取插件费用
|
||||
cost_start = time.time()
|
||||
try:
|
||||
@ -320,16 +363,32 @@ async def auth(
|
||||
# 执行 bot_filter
|
||||
bot_filter(session)
|
||||
|
||||
group = None
|
||||
if entity.group_id:
|
||||
group_dao = DataAccess(GroupConsole)
|
||||
group = await with_timeout(
|
||||
group_dao.safe_get_or_none(
|
||||
group_id=entity.group_id, channel_id__isnull=True
|
||||
),
|
||||
name="get_group",
|
||||
)
|
||||
|
||||
# 并行执行所有 hook 检查,并记录执行时间
|
||||
hooks_start = time.time()
|
||||
|
||||
# 创建所有 hook 任务
|
||||
hook_tasks = [
|
||||
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
|
||||
time_hook(auth_ban(matcher, bot, session, plugin), "auth_ban", hook_times),
|
||||
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
|
||||
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
|
||||
time_hook(
|
||||
auth_group(plugin, group, message, entity.group_id),
|
||||
"auth_group",
|
||||
hook_times,
|
||||
),
|
||||
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
|
||||
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
|
||||
time_hook(
|
||||
auth_plugin(plugin, group, session, event), "auth_plugin", hook_times
|
||||
),
|
||||
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
|
||||
]
|
||||
|
||||
@ -358,7 +417,17 @@ async def auth(
|
||||
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||
except PermissionExemption as e:
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
|
||||
finally:
|
||||
# 如果进入过 hooks 区域,确保释放信号量(即使上层处理抛出了异常)
|
||||
if entered_hooks:
|
||||
try:
|
||||
await _leave_hooks_section()
|
||||
except Exception:
|
||||
logger.error(
|
||||
"释放 hooks 信号量时出错",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
# 扣除金币
|
||||
if not ignore_flag and cost_gold > 0:
|
||||
gold_start = time.time()
|
||||
|
||||
@ -43,18 +43,20 @@ class BanCheckLimiter:
|
||||
|
||||
def check(self, key: str | float) -> bool:
|
||||
if time.time() - self.mtime[key] > self.default_check_time:
|
||||
self.mtime[key] = time.time()
|
||||
self.mint[key] = 0
|
||||
return False
|
||||
return self._extracted_from_check_3(key, False)
|
||||
if (
|
||||
self.mint[key] >= self.default_count
|
||||
and time.time() - self.mtime[key] < self.default_check_time
|
||||
):
|
||||
self.mtime[key] = time.time()
|
||||
self.mint[key] = 0
|
||||
return True
|
||||
return self._extracted_from_check_3(key, True)
|
||||
return False
|
||||
|
||||
# TODO Rename this here and in `check`
|
||||
def _extracted_from_check_3(self, key, arg1):
|
||||
self.mtime[key] = time.time()
|
||||
self.mint[key] = 0
|
||||
return arg1
|
||||
|
||||
|
||||
_blmt = BanCheckLimiter(
|
||||
malicious_check_time,
|
||||
@ -70,16 +72,15 @@ async def _(
|
||||
module = None
|
||||
if plugin := matcher.plugin:
|
||||
module = plugin.module_name
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [
|
||||
PluginType.HIDDEN,
|
||||
PluginType.DEPENDANT,
|
||||
PluginType.ADMIN,
|
||||
PluginType.SUPERUSER,
|
||||
]:
|
||||
return
|
||||
else:
|
||||
if not (metadata := plugin.metadata):
|
||||
return
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [
|
||||
PluginType.HIDDEN,
|
||||
PluginType.DEPENDANT,
|
||||
PluginType.ADMIN,
|
||||
PluginType.SUPERUSER,
|
||||
]:
|
||||
return
|
||||
if matcher.type == "notice":
|
||||
return
|
||||
@ -88,32 +89,31 @@ async def _(
|
||||
malicious_ban_time = Config.get_config("hook", "MALICIOUS_BAN_TIME")
|
||||
if not malicious_ban_time:
|
||||
raise ValueError("模块: [hook], 配置项: [MALICIOUS_BAN_TIME] 为空或小于0")
|
||||
if user_id:
|
||||
if module:
|
||||
if _blmt.check(f"{user_id}__{module}"):
|
||||
await BanConsole.ban(
|
||||
user_id,
|
||||
group_id,
|
||||
9,
|
||||
"恶意触发命令检测",
|
||||
malicious_ban_time * 60,
|
||||
bot.self_id,
|
||||
)
|
||||
logger.info(
|
||||
f"触发了恶意触发检测: {matcher.plugin_name}",
|
||||
"HOOK",
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
"检测到恶意触发命令,您将被封禁 30 分钟",
|
||||
]
|
||||
).send()
|
||||
logger.debug(
|
||||
f"触发了恶意触发检测: {matcher.plugin_name}",
|
||||
"HOOK",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("检测到恶意触发命令")
|
||||
_blmt.add(f"{user_id}__{module}")
|
||||
if user_id and module:
|
||||
if _blmt.check(f"{user_id}__{module}"):
|
||||
await BanConsole.ban(
|
||||
user_id,
|
||||
group_id,
|
||||
9,
|
||||
"恶意触发命令检测",
|
||||
malicious_ban_time * 60,
|
||||
bot.self_id,
|
||||
)
|
||||
logger.info(
|
||||
f"触发了恶意触发检测: {matcher.plugin_name}",
|
||||
"HOOK",
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
"检测到恶意触发命令,您将被封禁 30 分钟",
|
||||
]
|
||||
).send()
|
||||
logger.debug(
|
||||
f"触发了恶意触发检测: {matcher.plugin_name}",
|
||||
"HOOK",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("检测到恶意触发命令")
|
||||
_blmt.add(f"{user_id}__{module}")
|
||||
|
||||
@ -367,7 +367,7 @@ class ShopManage:
|
||||
else:
|
||||
goods_info = await GoodsInfo.get_or_none(goods_name=goods_name)
|
||||
if not goods_info:
|
||||
return f"{goods_name} 不存在..."
|
||||
return "对应的道具不存在..."
|
||||
if goods_info.is_passive:
|
||||
return f"{goods_info.goods_name} 是被动道具, 无法使用..."
|
||||
goods = cls.uuid2goods.get(goods_info.uuid)
|
||||
|
||||
4
zhenxun/services/cache/__init__.py
vendored
4
zhenxun/services/cache/__init__.py
vendored
@ -98,6 +98,7 @@ from .cache_containers import CacheDict, CacheList
|
||||
from .config import (
|
||||
CACHE_KEY_PREFIX,
|
||||
CACHE_KEY_SEPARATOR,
|
||||
CACHE_TIMEOUT,
|
||||
DEFAULT_EXPIRE,
|
||||
LOG_COMMAND,
|
||||
SPECIAL_KEY_FORMATS,
|
||||
@ -551,7 +552,6 @@ class CacheManager:
|
||||
返回:
|
||||
Any: 缓存数据,如果不存在返回默认值
|
||||
"""
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
|
||||
# 如果缓存被禁用或缓存模式为NONE,直接返回默认值
|
||||
if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
|
||||
@ -561,7 +561,7 @@ class CacheManager:
|
||||
cache_key = self._build_key(cache_type, key)
|
||||
data = await asyncio.wait_for(
|
||||
self.cache_backend.get(cache_key), # type: ignore
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
timeout=CACHE_TIMEOUT,
|
||||
)
|
||||
|
||||
if data is None:
|
||||
|
||||
3
zhenxun/services/cache/config.py
vendored
3
zhenxun/services/cache/config.py
vendored
@ -5,6 +5,9 @@
|
||||
# 日志标识
|
||||
LOG_COMMAND = "CacheRoot"
|
||||
|
||||
# 缓存获取超时时间(秒)
|
||||
CACHE_TIMEOUT = 10
|
||||
|
||||
# 默认缓存过期时间(秒)
|
||||
DEFAULT_EXPIRE = 600
|
||||
|
||||
|
||||
@ -27,5 +27,8 @@ async def with_db_timeout(
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
if operation:
|
||||
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
||||
logger.error(
|
||||
f"数据库操作超时: {operation} (>{timeout}s) 来源: {source}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
raise
|
||||
|
||||
@ -35,7 +35,7 @@ class ResponseData(BaseModel):
|
||||
"""响应数据封装 - 支持所有高级功能"""
|
||||
|
||||
text: str
|
||||
image_bytes: bytes | None = None
|
||||
images: list[bytes] | None = None
|
||||
usage_info: dict[str, Any] | None = None
|
||||
raw_response: dict[str, Any] | None = None
|
||||
tool_calls: list[LLMToolCall] | None = None
|
||||
@ -246,17 +246,17 @@ class BaseAdapter(ABC):
|
||||
if content:
|
||||
content = content.strip()
|
||||
|
||||
image_bytes: bytes | None = None
|
||||
images_bytes: list[bytes] = []
|
||||
if content and content.startswith("{") and content.endswith("}"):
|
||||
try:
|
||||
content_json = json.loads(content)
|
||||
if "b64_json" in content_json:
|
||||
image_bytes = base64.b64decode(content_json["b64_json"])
|
||||
images_bytes.append(base64.b64decode(content_json["b64_json"]))
|
||||
content = "[图片已生成]"
|
||||
elif "data" in content_json and isinstance(
|
||||
content_json["data"], str
|
||||
):
|
||||
image_bytes = base64.b64decode(content_json["data"])
|
||||
images_bytes.append(base64.b64decode(content_json["data"]))
|
||||
content = "[图片已生成]"
|
||||
|
||||
except (json.JSONDecodeError, KeyError, binascii.Error):
|
||||
@ -273,7 +273,7 @@ class BaseAdapter(ABC):
|
||||
if url_str.startswith("data:image/png;base64,"):
|
||||
try:
|
||||
b64_data = url_str.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(b64_data)
|
||||
images_bytes.append(base64.b64decode(b64_data))
|
||||
content = content if content else "[图片已生成]"
|
||||
except (IndexError, binascii.Error) as e:
|
||||
logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")
|
||||
@ -316,7 +316,7 @@ class BaseAdapter(ABC):
|
||||
text=final_text,
|
||||
tool_calls=parsed_tool_calls,
|
||||
usage_info=usage_info,
|
||||
image_bytes=image_bytes,
|
||||
images=images_bytes if images_bytes else None,
|
||||
raw_response=response_json,
|
||||
)
|
||||
|
||||
|
||||
@ -408,7 +408,7 @@ class GeminiAdapter(BaseAdapter):
|
||||
parts = content_data.get("parts", [])
|
||||
|
||||
text_content = ""
|
||||
image_bytes: bytes | None = None
|
||||
images_bytes: list[bytes] = []
|
||||
parsed_tool_calls: list["LLMToolCall"] | None = None
|
||||
thought_summary_parts = []
|
||||
answer_parts = []
|
||||
@ -423,10 +423,7 @@ class GeminiAdapter(BaseAdapter):
|
||||
elif "inlineData" in part:
|
||||
inline_data = part["inlineData"]
|
||||
if "data" in inline_data:
|
||||
image_bytes = base64.b64decode(inline_data["data"])
|
||||
answer_parts.append(
|
||||
f"[图片已生成: {inline_data.get('mimeType', 'image')}]"
|
||||
)
|
||||
images_bytes.append(base64.b64decode(inline_data["data"]))
|
||||
|
||||
elif "functionCall" in part:
|
||||
if parsed_tool_calls is None:
|
||||
@ -494,7 +491,7 @@ class GeminiAdapter(BaseAdapter):
|
||||
return ResponseData(
|
||||
text=text_content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
image_bytes=image_bytes,
|
||||
images=images_bytes if images_bytes else None,
|
||||
usage_info=usage_info,
|
||||
raw_response=response_json,
|
||||
grounding_metadata=grounding_metadata_obj,
|
||||
|
||||
@ -339,7 +339,7 @@ async def _generate_image_from_message(
|
||||
|
||||
response = await model_instance.generate_response(messages, config=config)
|
||||
|
||||
if not response.image_bytes:
|
||||
if not response.images:
|
||||
error_text = response.text or "模型未返回图片数据。"
|
||||
logger.warning(f"图片生成调用未返回图片,返回文本内容: {error_text}")
|
||||
|
||||
|
||||
@ -5,12 +5,12 @@ LLM 模型管理器
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.pydantic_compat import dump_json_safely
|
||||
|
||||
from .config import validate_override_params
|
||||
from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
|
||||
@ -43,7 +43,7 @@ def _make_cache_key(
|
||||
) -> str:
|
||||
"""生成缓存键"""
|
||||
config_str = (
|
||||
json.dumps(override_config, sort_keys=True) if override_config else "None"
|
||||
dump_json_safely(override_config, sort_keys=True) if override_config else "None"
|
||||
)
|
||||
key_data = f"{provider_model_name}:{config_str}"
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
|
||||
@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.log_sanitizer import sanitize_for_logging
|
||||
from zhenxun.utils.pydantic_compat import dump_json_safely
|
||||
|
||||
from .adapters.base import RequestData
|
||||
from .config import LLMGenerationConfig
|
||||
@ -194,13 +195,15 @@ class LLMModel(LLMModelBase):
|
||||
sanitized_body = sanitize_for_logging(
|
||||
request_data.body, context=sanitizer_req_context
|
||||
)
|
||||
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
|
||||
request_body_str = dump_json_safely(
|
||||
sanitized_body, ensure_ascii=False, indent=2
|
||||
)
|
||||
logger.debug(f"📦 请求体: {request_body_str}")
|
||||
|
||||
http_response = await http_client.post(
|
||||
request_data.url,
|
||||
headers=request_data.headers,
|
||||
json=request_data.body,
|
||||
content=dump_json_safely(request_data.body, ensure_ascii=False),
|
||||
)
|
||||
|
||||
logger.debug(f"📥 响应状态码: {http_response.status_code}")
|
||||
@ -394,7 +397,7 @@ class LLMModel(LLMModelBase):
|
||||
return LLMResponse(
|
||||
text=response_data.text,
|
||||
usage_info=response_data.usage_info,
|
||||
image_bytes=response_data.image_bytes,
|
||||
images=response_data.images,
|
||||
raw_response=response_data.raw_response,
|
||||
tool_calls=response_tool_calls if response_tool_calls else None,
|
||||
code_executions=response_data.code_executions,
|
||||
@ -424,7 +427,7 @@ class LLMModel(LLMModelBase):
|
||||
|
||||
policy = config.validation_policy
|
||||
if policy:
|
||||
if policy.get("require_image") and not parsed_data.image_bytes:
|
||||
if policy.get("require_image") and not parsed_data.images:
|
||||
if self.api_type == "gemini" and parsed_data.raw_response:
|
||||
usage_metadata = parsed_data.raw_response.get(
|
||||
"usageMetadata", {}
|
||||
|
||||
@ -425,7 +425,7 @@ class LLMResponse(BaseModel):
|
||||
"""LLM 响应"""
|
||||
|
||||
text: str
|
||||
image_bytes: bytes | None = None
|
||||
images: list[bytes] | None = None
|
||||
usage_info: dict[str, Any] | None = None
|
||||
raw_response: dict[str, Any] | None = None
|
||||
tool_calls: list[Any] | None = None
|
||||
|
||||
@ -217,16 +217,17 @@ class RendererService:
|
||||
context.processed_components.add(component_id)
|
||||
|
||||
component_path_base = str(component.template_name)
|
||||
variant = getattr(component, "variant", None)
|
||||
manifest = await context.theme_manager.get_template_manifest(
|
||||
component_path_base
|
||||
component_path_base, skin=variant
|
||||
)
|
||||
|
||||
style_paths_to_load = []
|
||||
if manifest and manifest.styles:
|
||||
if manifest and "styles" in manifest:
|
||||
styles = (
|
||||
[manifest.styles]
|
||||
if isinstance(manifest.styles, str)
|
||||
else manifest.styles
|
||||
[manifest["styles"]]
|
||||
if isinstance(manifest["styles"], str)
|
||||
else manifest["styles"]
|
||||
)
|
||||
for style_path in styles:
|
||||
full_style_path = str(Path(component_path_base) / style_path).replace(
|
||||
@ -383,6 +384,7 @@ class RendererService:
|
||||
)
|
||||
|
||||
temp_env.globals.update(context.theme_manager.jinja_env.globals)
|
||||
temp_env.filters.update(context.theme_manager.jinja_env.filters)
|
||||
temp_env.globals["asset"] = (
|
||||
context.theme_manager._create_standalone_asset_loader(template_dir)
|
||||
)
|
||||
@ -431,10 +433,11 @@ class RendererService:
|
||||
component_render_options = {}
|
||||
|
||||
manifest_options = {}
|
||||
variant = getattr(component, "variant", None)
|
||||
if manifest := await context.theme_manager.get_template_manifest(
|
||||
component.template_name
|
||||
component.template_name, skin=variant
|
||||
):
|
||||
manifest_options = manifest.render_options or {}
|
||||
manifest_options = manifest.get("render_options", {})
|
||||
|
||||
final_render_options = component_render_options.copy()
|
||||
final_render_options.update(manifest_options)
|
||||
@ -557,6 +560,8 @@ class RendererService:
|
||||
await self.initialize()
|
||||
assert self._theme_manager is not None, "ThemeManager 未初始化"
|
||||
|
||||
self._theme_manager._manifest_cache.clear()
|
||||
logger.debug("已清除UI清单缓存 (manifest cache)。")
|
||||
current_theme_name = Config.get_config("UI", "THEME", "default")
|
||||
await self._theme_manager.load_theme(current_theme_name)
|
||||
logger.info(f"主题 '{current_theme_name}' 已成功重载。")
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiofiles
|
||||
from jinja2 import (
|
||||
ChoiceLoader,
|
||||
Environment,
|
||||
@ -21,7 +21,6 @@ import ujson as json
|
||||
|
||||
from zhenxun.configs.path_config import THEMES_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.renderer.models import TemplateManifest
|
||||
from zhenxun.services.renderer.protocols import Renderable
|
||||
from zhenxun.services.renderer.registry import asset_registry
|
||||
from zhenxun.utils.pydantic_compat import model_dump
|
||||
@ -32,6 +31,20 @@ if TYPE_CHECKING:
|
||||
from .config import RESERVED_TEMPLATE_KEYS
|
||||
|
||||
|
||||
def deep_merge_dict(base: dict, new: dict) -> dict:
|
||||
"""
|
||||
递归地将 new 字典合并到 base 字典中。
|
||||
new 字典中的值会覆盖 base 字典中的值。
|
||||
"""
|
||||
result = base.copy()
|
||||
for key, value in new.items():
|
||||
if isinstance(value, dict) and key in result and isinstance(result[key], dict):
|
||||
result[key] = deep_merge_dict(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
class RelativePathEnvironment(Environment):
|
||||
"""
|
||||
一个自定义的 Jinja2 环境,重写了 join_path 方法以支持模板间的相对路径引用。
|
||||
@ -151,14 +164,42 @@ class ResourceResolver:
|
||||
|
||||
def resolve_asset_uri(self, asset_path: str, current_template_name: str) -> str:
|
||||
"""解析资源路径,实现完整的回退逻辑,并返回可用的URI。"""
|
||||
if not self.theme_manager.current_theme:
|
||||
if (
|
||||
not self.theme_manager.current_theme
|
||||
or not self.theme_manager.jinja_env.loader
|
||||
):
|
||||
return ""
|
||||
|
||||
if asset_path.startswith("@"):
|
||||
try:
|
||||
full_asset_path = self.theme_manager.jinja_env.join_path(
|
||||
asset_path, current_template_name
|
||||
)
|
||||
_source, file_abs_path, _uptodate = (
|
||||
self.theme_manager.jinja_env.loader.get_source(
|
||||
self.theme_manager.jinja_env, full_asset_path
|
||||
)
|
||||
)
|
||||
if file_abs_path:
|
||||
logger.debug(
|
||||
f"Jinja Loader resolved asset '{asset_path}'->'{file_abs_path}'"
|
||||
)
|
||||
return Path(file_abs_path).absolute().as_uri()
|
||||
except TemplateNotFound:
|
||||
logger.warning(
|
||||
f"资源文件在命名空间中未找到: '{asset_path}'"
|
||||
f"(在模板 '{current_template_name}' 中引用)"
|
||||
)
|
||||
return ""
|
||||
|
||||
search_paths: list[tuple[str, Path]] = []
|
||||
if asset_path.startswith("./"):
|
||||
if asset_path.startswith("./") or asset_path.startswith("../"):
|
||||
relative_part = (
|
||||
asset_path[2:] if asset_path.startswith("./") else asset_path
|
||||
)
|
||||
search_paths.extend(
|
||||
self._search_paths_for_relative_asset(
|
||||
asset_path[2:], current_template_name
|
||||
relative_part, current_template_name
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -209,6 +250,9 @@ class ThemeManager:
|
||||
|
||||
self.jinja_env.filters["md"] = self._markdown_filter
|
||||
|
||||
self._manifest_cache: dict[str, Any] = {}
|
||||
self._manifest_cache_lock = asyncio.Lock()
|
||||
|
||||
def list_available_themes(self) -> list[str]:
|
||||
"""扫描主题目录并返回所有可用的主题名称。"""
|
||||
if not THEMES_PATH.is_dir():
|
||||
@ -377,16 +421,26 @@ class ThemeManager:
|
||||
logger.error(f"指定的模板文件路径不存在: '{component_path_base}'", e=e)
|
||||
raise e
|
||||
|
||||
entrypoint_filename = "main.html"
|
||||
manifest = await self.get_template_manifest(component_path_base)
|
||||
if manifest and manifest.entrypoint:
|
||||
entrypoint_filename = manifest.entrypoint
|
||||
base_manifest = await self.get_template_manifest(component_path_base)
|
||||
|
||||
skin_to_use = variant or (base_manifest.get("skin") if base_manifest else None)
|
||||
|
||||
final_manifest = await self.get_template_manifest(
|
||||
component_path_base, skin=skin_to_use
|
||||
)
|
||||
logger.debug(f"final_manifest: {final_manifest}")
|
||||
|
||||
entrypoint_filename = (
|
||||
final_manifest.get("entrypoint", "main.html")
|
||||
if final_manifest
|
||||
else "main.html"
|
||||
)
|
||||
|
||||
potential_paths = []
|
||||
|
||||
if variant:
|
||||
if skin_to_use:
|
||||
potential_paths.append(
|
||||
f"{component_path_base}/skins/{variant}/{entrypoint_filename}"
|
||||
f"{component_path_base}/skins/{skin_to_use}/{entrypoint_filename}"
|
||||
)
|
||||
|
||||
potential_paths.append(f"{component_path_base}/{entrypoint_filename}")
|
||||
@ -410,28 +464,88 @@ class ThemeManager:
|
||||
logger.error(err_msg)
|
||||
raise TemplateNotFound(err_msg)
|
||||
|
||||
async def get_template_manifest(
|
||||
self, component_path: str
|
||||
) -> TemplateManifest | None:
|
||||
"""
|
||||
查找并解析组件的 manifest.json 文件。
|
||||
"""
|
||||
manifest_path_str = f"{component_path}/manifest.json"
|
||||
async def _load_single_manifest(self, path_str: str) -> dict[str, Any] | None:
|
||||
"""从指定路径加载单个 manifest.json 文件。"""
|
||||
normalized_path = path_str.replace("\\", "/")
|
||||
manifest_path_str = f"{normalized_path}/manifest.json"
|
||||
|
||||
if not self.jinja_env.loader:
|
||||
return None
|
||||
|
||||
try:
|
||||
_, full_path, _ = self.jinja_env.loader.get_source(
|
||||
source, filepath, _ = self.jinja_env.loader.get_source(
|
||||
self.jinja_env, manifest_path_str
|
||||
)
|
||||
if full_path and Path(full_path).exists():
|
||||
async with aiofiles.open(full_path, encoding="utf-8") as f:
|
||||
manifest_data = json.loads(await f.read())
|
||||
return TemplateManifest(**manifest_data)
|
||||
logger.debug(f"找到清单文件: '{manifest_path_str}' (从 '{filepath}' 加载)")
|
||||
return json.loads(source)
|
||||
except TemplateNotFound:
|
||||
logger.trace(f"未找到清单文件: '{manifest_path_str}'")
|
||||
return None
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"清单文件 '{manifest_path_str}' 解析失败")
|
||||
return None
|
||||
|
||||
async def _load_and_merge_manifests(
|
||||
self, component_path: Path | str, skin: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""加载基础和皮肤清单并进行合并。"""
|
||||
logger.debug(f"开始加载清单: component_path='{component_path}', skin='{skin}'")
|
||||
|
||||
base_manifest = await self._load_single_manifest(str(component_path))
|
||||
|
||||
if skin:
|
||||
skin_path = Path(component_path) / "skins" / skin
|
||||
skin_manifest = await self._load_single_manifest(str(skin_path))
|
||||
|
||||
if skin_manifest:
|
||||
if base_manifest:
|
||||
merged = deep_merge_dict(base_manifest, skin_manifest)
|
||||
logger.debug(
|
||||
f"已合并基础清单和皮肤清单: '{component_path}' + skin '{skin}'"
|
||||
)
|
||||
return merged
|
||||
else:
|
||||
logger.debug(f"只找到皮肤清单: '{skin_path}'")
|
||||
return skin_manifest
|
||||
|
||||
if base_manifest:
|
||||
logger.debug(f"只找到基础清单: '{component_path}'")
|
||||
else:
|
||||
logger.debug(f"未找到任何清单: '{component_path}'")
|
||||
|
||||
return base_manifest
|
||||
|
||||
async def get_template_manifest(
|
||||
self, component_path: str, skin: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
查找并解析组件的 manifest.json 文件。
|
||||
支持皮肤清单的继承与合并,并带有缓存。
|
||||
|
||||
Args:
|
||||
component_path: 组件路径
|
||||
skin: 皮肤名称(可选)
|
||||
|
||||
Returns:
|
||||
合并后的清单字典,如果不存在则返回 None
|
||||
"""
|
||||
cache_key = f"{component_path}:{skin or 'base'}"
|
||||
|
||||
if cache_key in self._manifest_cache:
|
||||
logger.debug(f"清单缓存命中: '{cache_key}'")
|
||||
return self._manifest_cache[cache_key]
|
||||
|
||||
async with self._manifest_cache_lock:
|
||||
if cache_key in self._manifest_cache:
|
||||
logger.debug(f"清单缓存命中(锁内): '{cache_key}'")
|
||||
return self._manifest_cache[cache_key]
|
||||
|
||||
manifest = await self._load_and_merge_manifests(component_path, skin)
|
||||
|
||||
self._manifest_cache[cache_key] = manifest
|
||||
logger.debug(f"清单已缓存: '{cache_key}'")
|
||||
|
||||
return manifest
|
||||
|
||||
async def resolve_markdown_style_path(
|
||||
self, style_name: str, context: "RenderContext"
|
||||
|
||||
@ -126,12 +126,15 @@ class SqlUtils:
|
||||
def format_usage_for_markdown(text: str) -> str:
|
||||
"""
|
||||
智能地将Python多行字符串转换为适合Markdown渲染的格式。
|
||||
- 将单个换行符替换为Markdown的硬换行(行尾加两个空格)。
|
||||
- 在列表、标题等块级元素前自动插入换行,确保正确解析。
|
||||
- 将段落内的单个换行符替换为Markdown的硬换行(行尾加两个空格)。
|
||||
- 保留两个或更多的连续换行符,使其成为Markdown的段落分隔。
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
text = re.sub(r"\n{2,}", "<<PARAGRAPH_BREAK>>", text)
|
||||
text = text.replace("\n", " \n")
|
||||
text = text.replace("<<PARAGRAPH_BREAK>>", "\n\n")
|
||||
|
||||
text = re.sub(r"([^\n])\n(\s*[-*] |\s*#+\s|\s*>)", r"\1\n\n\2", text)
|
||||
|
||||
text = re.sub(r"(?<!\n)\n(?!\n)", " \n", text)
|
||||
|
||||
return text
|
||||
|
||||
@ -5,10 +5,14 @@ Pydantic V1 & V2 兼容层模块
|
||||
包括 model_dump, model_copy, model_json_schema, parse_as 等。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, get_args, get_origin
|
||||
|
||||
from nonebot.compat import PYDANTIC_V2, model_dump
|
||||
from pydantic import VERSION, BaseModel
|
||||
import ujson as json
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
V = TypeVar("V")
|
||||
@ -19,6 +23,7 @@ __all__ = [
|
||||
"_dump_pydantic_obj",
|
||||
"_is_pydantic_type",
|
||||
"compat_computed_field",
|
||||
"dump_json_safely",
|
||||
"model_copy",
|
||||
"model_dump",
|
||||
"model_json_schema",
|
||||
@ -93,3 +98,26 @@ def parse_as(type_: type[V], obj: Any) -> V:
|
||||
from pydantic import TypeAdapter # type: ignore
|
||||
|
||||
return TypeAdapter(type_).validate_python(obj)
|
||||
|
||||
|
||||
def dump_json_safely(obj: Any, **kwargs) -> str:
|
||||
"""
|
||||
安全地将可能包含 Pydantic 特定类型 (如 Enum) 的对象序列化为 JSON 字符串。
|
||||
"""
|
||||
|
||||
def default_serializer(o):
|
||||
if isinstance(o, Enum):
|
||||
return o.value
|
||||
if isinstance(o, datetime):
|
||||
return o.isoformat()
|
||||
if isinstance(o, Path):
|
||||
return str(o.as_posix())
|
||||
if isinstance(o, set):
|
||||
return list(o)
|
||||
if isinstance(o, BaseModel):
|
||||
return model_dump(o)
|
||||
raise TypeError(
|
||||
f"Object of type {o.__class__.__name__} is not JSON serializable"
|
||||
)
|
||||
|
||||
return json.dumps(obj, default=default_serializer, **kwargs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user