Compare commits

...

4 Commits

Author SHA1 Message Date
pre-commit-ci[bot]
90cd5f473f
Merge 82bc83b85d into 1cc18bb195 2025-10-11 15:51:08 +05:30
molanp
1cc18bb195
fix(shop): change the prompt shown when an item does not exist (#2061)
- Change the prompt shown when an item does not exist from the specific item name to a generic message, to avoid exposing internal implementation details and to improve user experience and security.
- resolve Bug: item-usage feature optimization
Fixes #2060
2025-10-09 09:01:20 +08:00
Rumio
74a9f3a843
feat(core): support multi-image LLM responses, enhance the UI theme skin system, and improve JSON/Markdown handling (#2062)
- [LLM service]
  - The `LLMResponse` model now supports `images: list[bytes]`, allowing a model to return multiple images.
  - The LLM adapters (`base.py`, `gemini.py`) and the API layer (`api.py`, `service.py`) have been updated to handle multi-image responses.
  - Response validation now checks the `images` list instead of the single `image_bytes` field.
- [UI rendering service]
  - Introduced the component "skin" (variant) concept, allowing different visual styles for the same component.
  - Improved the loading, merging, and caching of `manifest.json`, with recursive merging of base and skin manifests.
  - `ThemeManager` now caches loaded manifests and clears the cache on theme reload.
  - Enhanced the resource resolver (`ResourceResolver`) to support `@` namespace paths and more robust relative-path handling.
  - Standalone templates now inherit the filters of the main Jinja environment.
- [Utility functions]
  - Introduced the `dump_json_safely` utility for safer JSON serialization of objects containing complex types such as Pydantic models and enums.
  - Request-body and cache-key generation in the LLM service now use `dump_json_safely`.
  - Improved `format_usage_for_markdown`: block-level elements are now preceded by the correct line breaks, and hard line breaks inside paragraphs are handled properly.

Co-authored-by: webjoin111 <455457521@qq.com>
2025-10-09 08:50:40 +08:00
HibiKier
e7f3c210df
Fix database timeouts under concurrency (#2063)
* 🔧 Fixes and optimizations: adjust timeout settings, refactor check logic, simplify code structure

- Refactor the `check` method in `chkdsk_hook.py` and extract the shared logic
- Update the timeout setting in `CacheManager` to use the new `CACHE_TIMEOUT`
- Add caching logic in `utils.py` and record how database operations execute

* feat(auth): add concurrency control and optimize the permission-check logic

* Update utils.py

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-10-09 08:46:08 +08:00
21 changed files with 383 additions and 210 deletions

View File

@@ -58,5 +58,14 @@ Config.add_plugin_config(
     type=bool,
 )
+Config.add_plugin_config(
+    "hook",
+    "AUTH_HOOKS_CONCURRENCY_LIMIT",
+    5,
+    help="同步进入权限钩子最大并发数",
+    default_value=5,
+    type=int,
+)
 nonebot.load_plugins(str(Path(__file__).parent.resolve()))

View File

@@ -96,7 +96,6 @@ async def is_ban(user_id: str | None, group_id: str | None) -> int:
             f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
             LOGGER_COMMAND,
         )
-        # 超时时返回0避免阻塞
         return 0
     # 检查记录并计算ban时间
@@ -199,7 +198,7 @@ async def group_handle(group_id: str) -> None:
 )
-async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
+async def user_handle(plugin: PluginInfo, entity: EntityIDs, session: Uninfo) -> None:
     """用户ban检查
     参数:
@@ -217,22 +216,12 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
     if not time_val:
         return
     time_str = format_time(time_val)
-    plugin_dao = DataAccess(PluginInfo)
-    try:
-        db_plugin = await asyncio.wait_for(
-            plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
-        )
-    except asyncio.TimeoutError:
-        logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
-        # 超时时不阻塞,继续执行
-        raise SkipPluginException("用户处于黑名单中...")
     if (
-        db_plugin
-        and not db_plugin.ignore_prompt
+        plugin
         and time_val != -1
         and ban_result
-        and freq.is_send_limit_message(db_plugin, entity.user_id, False)
+        and freq.is_send_limit_message(plugin, entity.user_id, False)
     ):
         try:
             await asyncio.wait_for(
@@ -260,7 +249,9 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
 )
-async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
+async def auth_ban(
+    matcher: Matcher, bot: Bot, session: Uninfo, plugin: PluginInfo
+) -> None:
     """权限检查 - ban 检查
     参数:
@@ -289,7 +280,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
     if entity.user_id:
         try:
             await asyncio.wait_for(
-                user_handle(matcher.plugin_name, entity, session),
+                user_handle(plugin, entity, session),
                 timeout=DB_TIMEOUT_SECONDS,
             )
         except asyncio.TimeoutError:

View File

@@ -1,50 +1,36 @@
-import asyncio
 import time
 from nonebot_plugin_alconna import UniMsg
 from zhenxun.models.group_console import GroupConsole
 from zhenxun.models.plugin_info import PluginInfo
-from zhenxun.services.data_access import DataAccess
-from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
 from zhenxun.services.log import logger
-from zhenxun.utils.utils import EntityIDs
 from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
 from .exception import SkipPluginException
-async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
+async def auth_group(
+    plugin: PluginInfo,
+    group: GroupConsole | None,
+    message: UniMsg,
+    group_id: str | None,
+):
     """群黑名单检测 群总开关检测
     参数:
         plugin: PluginInfo
-        entity: EntityIDs
+        group: GroupConsole
         message: UniMsg
     """
-    start_time = time.time()
-    if not entity.group_id:
+    if not group_id:
         return
+    start_time = time.time()
     try:
         text = message.extract_plain_text()
-        # 从数据库或缓存中获取群组信息
-        group_dao = DataAccess(GroupConsole)
-        try:
-            group: GroupConsole | None = await asyncio.wait_for(
-                group_dao.safe_get_or_none(
-                    group_id=entity.group_id, channel_id__isnull=True
-                ),
-                timeout=DB_TIMEOUT_SECONDS,
-            )
-        except asyncio.TimeoutError:
-            logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
-            # 超时时不阻塞,继续执行
-            return
         if not group:
             raise SkipPluginException("群组信息不存在...")
         if group.level < 0:
@@ -63,6 +49,5 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
         logger.warning(
             f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
             LOGGER_COMMAND,
-            session=entity.user_id,
-            group_id=entity.group_id,
+            group_id=group_id,
         )

View File

@@ -6,12 +6,10 @@ from nonebot_plugin_uninfo import Uninfo
 from zhenxun.models.group_console import GroupConsole
 from zhenxun.models.plugin_info import PluginInfo
-from zhenxun.services.data_access import DataAccess
 from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
 from zhenxun.services.log import logger
 from zhenxun.utils.common_utils import CommonUtils
 from zhenxun.utils.enum import BlockType
-from zhenxun.utils.utils import get_entity_ids
 from .config import LOGGER_COMMAND, WARNING_THRESHOLD
 from .exception import IsSuperuserException, SkipPluginException
@@ -20,30 +18,17 @@ from .utils import freq, is_poke, send_message
 class GroupCheck:
     def __init__(
-        self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
+        self, plugin: PluginInfo, group: GroupConsole, session: Uninfo, is_poke: bool
     ) -> None:
-        self.group_id = group_id
         self.session = session
         self.is_poke = is_poke
        self.plugin = plugin
-        self.group_dao = DataAccess(GroupConsole)
-        self.group_data = None
+        self.group_data = group
+        self.group_id = group.group_id
     async def check(self):
         start_time = time.time()
         try:
-            # 只查询一次数据库,使用 DataAccess 的缓存机制
-            try:
-                self.group_data = await asyncio.wait_for(
-                    self.group_dao.safe_get_or_none(
-                        group_id=self.group_id, channel_id__isnull=True
-                    ),
-                    timeout=DB_TIMEOUT_SECONDS,
-                )
-            except asyncio.TimeoutError:
-                logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
-                return  # 超时时不阻塞,继续执行
             # 检查超级用户禁用
             if (
                 self.group_data
@@ -113,12 +98,13 @@ class GroupCheck:
 class PluginCheck:
-    def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
+    def __init__(self, group: GroupConsole | None, session: Uninfo, is_poke: bool):
         self.session = session
         self.is_poke = is_poke
-        self.group_id = group_id
-        self.group_dao = DataAccess(GroupConsole)
-        self.group_data = None
+        self.group_data = group
+        self.group_id = None
+        if group:
+            self.group_id = group.group_id
     async def check_user(self, plugin: PluginInfo):
         """全局私聊禁用检测
@@ -156,21 +142,8 @@ class PluginCheck:
         if plugin.status or plugin.block_type != BlockType.ALL:
             return
         """全局状态"""
-        if self.group_id:
-            # 使用 DataAccess 的缓存机制
-            try:
-                self.group_data = await asyncio.wait_for(
-                    self.group_dao.safe_get_or_none(
-                        group_id=self.group_id, channel_id__isnull=True
-                    ),
-                    timeout=DB_TIMEOUT_SECONDS,
-                )
-            except asyncio.TimeoutError:
-                logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
-                return  # 超时时不阻塞,继续执行
-            if self.group_data and self.group_data.is_super:
-                raise IsSuperuserException()
+        if self.group_data and self.group_data.is_super:
+            raise IsSuperuserException()
         sid = self.group_id or self.session.user.id
         if freq.is_send_limit_message(plugin, sid, self.is_poke):
@@ -193,7 +166,9 @@ class PluginCheck:
 )
-async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
+async def auth_plugin(
+    plugin: PluginInfo, group: GroupConsole | None, session: Uninfo, event: Event
+):
     """插件状态
     参数:
@@ -203,35 +178,23 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
     """
     start_time = time.time()
     try:
-        entity = get_entity_ids(session)
         is_poke_event = is_poke(event)
-        user_check = PluginCheck(entity.group_id, session, is_poke_event)
-        if entity.group_id:
-            group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
-            try:
-                await asyncio.wait_for(
-                    group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
-                )
-            except asyncio.TimeoutError:
-                logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
-                # 超时时不阻塞,继续执行
+        user_check = PluginCheck(group, session, is_poke_event)
+        tasks = []
+        if group:
+            tasks.append(GroupCheck(plugin, group, session, is_poke_event).check())
         else:
-            try:
-                await asyncio.wait_for(
-                    user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
-                )
-            except asyncio.TimeoutError:
-                logger.error("用户检查超时", LOGGER_COMMAND)
-                # 超时时不阻塞,继续执行
+            tasks.append(user_check.check_user(plugin))
+        tasks.append(user_check.check_global(plugin))
         try:
             await asyncio.wait_for(
-                user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
+                asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2
             )
         except asyncio.TimeoutError:
-            logger.error("全局检查超时", LOGGER_COMMAND)
-            # 超时时不阻塞,继续执行
+            logger.error("插件用户/群组/全局检查超时...", LOGGER_COMMAND)
     finally:
         # 记录总执行时间
         elapsed = time.time() - start_time
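
The key change above is that the per-check `asyncio.wait_for` wrappers are replaced by one timeout around `asyncio.gather`. A standalone sketch of that pattern, with placeholder check coroutines and a made-up timeout constant (not the project's real values):

import asyncio

DB_TIMEOUT_SECONDS = 3.0  # placeholder for the project's constant


async def group_check() -> None:
    await asyncio.sleep(0.1)  # stand-in for a database-backed check


async def global_check() -> None:
    await asyncio.sleep(0.1)


async def run_checks(in_group: bool) -> None:
    tasks = [group_check()] if in_group else []
    tasks.append(global_check())
    try:
        # A single timeout now bounds all checks together instead of one per check.
        await asyncio.wait_for(asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2)
    except asyncio.TimeoutError:
        print("checks timed out; continue without blocking the handler")


asyncio.run(run_checks(in_group=True))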

View File

@@ -85,7 +85,7 @@ class FreqUtils:
             return False
         if plugin.plugin_type == PluginType.DEPENDANT:
             return False
-        return plugin.module != "ai" if self._flmt_s.check(sid) else False
+        return False if plugin.ignore_prompt else self._flmt_s.check(sid)
 freq = FreqUtils()

View File

@@ -8,6 +8,7 @@ from nonebot_plugin_alconna import UniMsg
 from nonebot_plugin_uninfo import Uninfo
 from tortoise.exceptions import IntegrityError
+from zhenxun.models.group_console import GroupConsole
 from zhenxun.models.plugin_info import PluginInfo
 from zhenxun.models.user_console import UserConsole
 from zhenxun.services.data_access import DataAccess
@@ -31,6 +32,7 @@ from .auth.exception import (
     PermissionExemption,
     SkipPluginException,
 )
+from .auth.utils import base_config
 # 超时设置(秒)
 TIMEOUT_SECONDS = 5.0
@@ -46,6 +48,16 @@ CIRCUIT_BREAKERS = {
 # 熔断重置时间(秒)
 CIRCUIT_RESET_TIME = 300  # 5分钟
+# 并发控制:限制同时进入 hooks 并行检查的协程数
+# 默认为 6,可通过环境变量 AUTH_HOOKS_CONCURRENCY_LIMIT 调整
+HOOKS_CONCURRENCY_LIMIT = base_config.get("AUTH_HOOKS_CONCURRENCY_LIMIT")
+# 全局信号量与计数器
+HOOKS_SEMAPHORE = asyncio.Semaphore(HOOKS_CONCURRENCY_LIMIT)
+HOOKS_ACTIVE_COUNT = 0
+HOOKS_ACTIVE_LOCK = asyncio.Lock()
 # 超时装饰器
 async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
@@ -259,6 +271,30 @@ async def time_hook(coro, name, time_dict):
         time_dict[name] = f"{time.time() - start:.3f}s"
+async def _enter_hooks_section():
+    """尝试获取全局信号量并更新计数器,超时则抛出 PermissionExemption。"""
+    global HOOKS_ACTIVE_COUNT
+    # 队列模式:如果达到上限,协程将排队等待直到获取到信号量
+    await HOOKS_SEMAPHORE.acquire()
+    async with HOOKS_ACTIVE_LOCK:
+        HOOKS_ACTIVE_COUNT += 1
+        logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
+
+
+async def _leave_hooks_section():
+    """释放信号量并更新计数器。"""
+    global HOOKS_ACTIVE_COUNT
+    from contextlib import suppress
+
+    with suppress(Exception):
+        HOOKS_SEMAPHORE.release()
+    async with HOOKS_ACTIVE_LOCK:
+        HOOKS_ACTIVE_COUNT -= 1
+        # 保证计数不为负
+        HOOKS_ACTIVE_COUNT = max(HOOKS_ACTIVE_COUNT, 0)
+        logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
 async def auth(
     matcher: Matcher,
     event: Event,
@@ -285,6 +321,9 @@ async def auth(
     hook_times = {}
     hooks_time = 0  # 初始化 hooks_time 变量
+    # 记录是否已进入 hooks 区域(用于 finally 中释放)
+    entered_hooks = False
     try:
         if not module:
             raise PermissionExemption("Matcher插件名称不存在...")
@@ -304,6 +343,10 @@ async def auth(
             )
             raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
+        # 进入 hooks 并行检查区域(会在高并发时排队)
+        await _enter_hooks_section()
+        entered_hooks = True
         # 获取插件费用
         cost_start = time.time()
         try:
@@ -320,16 +363,32 @@ async def auth(
         # 执行 bot_filter
         bot_filter(session)
+        group = None
+        if entity.group_id:
+            group_dao = DataAccess(GroupConsole)
+            group = await with_timeout(
+                group_dao.safe_get_or_none(
+                    group_id=entity.group_id, channel_id__isnull=True
+                ),
+                name="get_group",
+            )
         # 并行执行所有 hook 检查,并记录执行时间
         hooks_start = time.time()
         # 创建所有 hook 任务
         hook_tasks = [
-            time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
+            time_hook(auth_ban(matcher, bot, session, plugin), "auth_ban", hook_times),
             time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
-            time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
+            time_hook(
+                auth_group(plugin, group, message, entity.group_id),
+                "auth_group",
+                hook_times,
+            ),
             time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
-            time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
+            time_hook(
+                auth_plugin(plugin, group, session, event), "auth_plugin", hook_times
+            ),
             time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
         ]
@@ -358,7 +417,17 @@ async def auth(
         logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
     except PermissionExemption as e:
         logger.info(str(e), LOGGER_COMMAND, session=session)
+    finally:
+        # 如果进入过 hooks 区域,确保释放信号量(即使上层处理抛出了异常)
+        if entered_hooks:
+            try:
+                await _leave_hooks_section()
+            except Exception:
+                logger.error(
+                    "释放 hooks 信号量时出错",
+                    LOGGER_COMMAND,
+                    session=session,
+                )
     # 扣除金币
     if not ignore_flag and cost_gold > 0:
         gold_start = time.time()
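
The concurrency control introduced here reduces to acquire-before-the-parallel-checks and release-in-finally. A self-contained sketch of the same pattern (the limit value and the body are placeholders, not the project's actual values):

import asyncio

HOOKS_SEMAPHORE = asyncio.Semaphore(5)  # placeholder limit


async def guarded_auth_check(i: int) -> None:
    # Queueing mode: once the limit is reached, further callers wait here
    # instead of piling concurrent permission queries onto the database.
    await HOOKS_SEMAPHORE.acquire()
    try:
        await asyncio.sleep(0.05)  # stand-in for the parallel hook checks
    finally:
        # Mirrors _leave_hooks_section() being called from the finally block.
        HOOKS_SEMAPHORE.release()


async def main() -> None:
    await asyncio.gather(*(guarded_auth_check(i) for i in range(20)))


asyncio.run(main())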

View File

@@ -43,18 +43,20 @@ class BanCheckLimiter:
     def check(self, key: str | float) -> bool:
         if time.time() - self.mtime[key] > self.default_check_time:
-            self.mtime[key] = time.time()
-            self.mint[key] = 0
-            return False
+            return self._extracted_from_check_3(key, False)
         if (
             self.mint[key] >= self.default_count
             and time.time() - self.mtime[key] < self.default_check_time
         ):
-            self.mtime[key] = time.time()
-            self.mint[key] = 0
-            return True
+            return self._extracted_from_check_3(key, True)
         return False
+    # TODO Rename this here and in `check`
+    def _extracted_from_check_3(self, key, arg1):
+        self.mtime[key] = time.time()
+        self.mint[key] = 0
+        return arg1
 _blmt = BanCheckLimiter(
     malicious_check_time,
@@ -70,16 +72,15 @@ async def _(
     module = None
     if plugin := matcher.plugin:
         module = plugin.module_name
-        if metadata := plugin.metadata:
-            extra = metadata.extra
-            if extra.get("plugin_type") in [
-                PluginType.HIDDEN,
-                PluginType.DEPENDANT,
-                PluginType.ADMIN,
-                PluginType.SUPERUSER,
-            ]:
-                return
-        else:
+        if not (metadata := plugin.metadata):
             return
+        extra = metadata.extra
+        if extra.get("plugin_type") in [
+            PluginType.HIDDEN,
+            PluginType.DEPENDANT,
+            PluginType.ADMIN,
+            PluginType.SUPERUSER,
+        ]:
+            return
     if matcher.type == "notice":
         return
@@ -88,32 +89,31 @@ async def _(
     malicious_ban_time = Config.get_config("hook", "MALICIOUS_BAN_TIME")
     if not malicious_ban_time:
         raise ValueError("模块: [hook], 配置项: [MALICIOUS_BAN_TIME] 为空或小于0")
-    if user_id:
-        if module:
-            if _blmt.check(f"{user_id}__{module}"):
-                await BanConsole.ban(
-                    user_id,
-                    group_id,
-                    9,
-                    "恶意触发命令检测",
-                    malicious_ban_time * 60,
-                    bot.self_id,
-                )
-                logger.info(
-                    f"触发了恶意触发检测: {matcher.plugin_name}",
-                    "HOOK",
-                    session=session,
-                )
-                await MessageUtils.build_message(
-                    [
-                        At(flag="user", target=user_id),
-                        "检测到恶意触发命令,您将被封禁 30 分钟",
-                    ]
-                ).send()
-                logger.debug(
-                    f"触发了恶意触发检测: {matcher.plugin_name}",
-                    "HOOK",
-                    session=session,
-                )
-                raise IgnoredException("检测到恶意触发命令")
-            _blmt.add(f"{user_id}__{module}")
+    if user_id and module:
+        if _blmt.check(f"{user_id}__{module}"):
+            await BanConsole.ban(
+                user_id,
+                group_id,
+                9,
+                "恶意触发命令检测",
+                malicious_ban_time * 60,
+                bot.self_id,
+            )
+            logger.info(
+                f"触发了恶意触发检测: {matcher.plugin_name}",
+                "HOOK",
+                session=session,
+            )
+            await MessageUtils.build_message(
+                [
+                    At(flag="user", target=user_id),
+                    "检测到恶意触发命令,您将被封禁 30 分钟",
+                ]
+            ).send()
+            logger.debug(
+                f"触发了恶意触发检测: {matcher.plugin_name}",
+                "HOOK",
+                session=session,
+            )
+            raise IgnoredException("检测到恶意触发命令")
+        _blmt.add(f"{user_id}__{module}")

View File

@@ -367,7 +367,7 @@ class ShopManage:
         else:
             goods_info = await GoodsInfo.get_or_none(goods_name=goods_name)
         if not goods_info:
-            return f"{goods_name} 不存在..."
+            return "对应的道具不存在..."
         if goods_info.is_passive:
             return f"{goods_info.goods_name} 是被动道具, 无法使用..."
         goods = cls.uuid2goods.get(goods_info.uuid)

View File

@@ -98,6 +98,7 @@ from .cache_containers import CacheDict, CacheList
 from .config import (
     CACHE_KEY_PREFIX,
     CACHE_KEY_SEPARATOR,
+    CACHE_TIMEOUT,
     DEFAULT_EXPIRE,
     LOG_COMMAND,
     SPECIAL_KEY_FORMATS,
@@ -551,7 +552,6 @@ class CacheManager:
         返回:
             Any: 缓存数据如果不存在返回默认值
         """
-        from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
         # 如果缓存被禁用或缓存模式为NONE直接返回默认值
         if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
@@ -561,7 +561,7 @@ class CacheManager:
            cache_key = self._build_key(cache_type, key)
            data = await asyncio.wait_for(
                self.cache_backend.get(cache_key),  # type: ignore
-                timeout=DB_TIMEOUT_SECONDS,
+                timeout=CACHE_TIMEOUT,
            )
            if data is None:

View File

@@ -5,6 +5,9 @@
 # 日志标识
 LOG_COMMAND = "CacheRoot"
+# 缓存获取超时时间(秒)
+CACHE_TIMEOUT = 10
 # 默认缓存过期时间(秒)
 DEFAULT_EXPIRE = 600

View File

@@ -27,5 +27,8 @@ async def with_db_timeout(
         return result
     except asyncio.TimeoutError:
         if operation:
-            logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
+            logger.error(
+                f"数据库操作超时: {operation} (>{timeout}s) 来源: {source}",
+                LOG_COMMAND,
+            )
         raise

View File

@@ -35,7 +35,7 @@ class ResponseData(BaseModel):
     """响应数据封装 - 支持所有高级功能"""
     text: str
-    image_bytes: bytes | None = None
+    images: list[bytes] | None = None
     usage_info: dict[str, Any] | None = None
     raw_response: dict[str, Any] | None = None
     tool_calls: list[LLMToolCall] | None = None
@@ -246,17 +246,17 @@ class BaseAdapter(ABC):
         if content:
             content = content.strip()
-        image_bytes: bytes | None = None
+        images_bytes: list[bytes] = []
         if content and content.startswith("{") and content.endswith("}"):
             try:
                 content_json = json.loads(content)
                 if "b64_json" in content_json:
-                    image_bytes = base64.b64decode(content_json["b64_json"])
+                    images_bytes.append(base64.b64decode(content_json["b64_json"]))
                     content = "[图片已生成]"
                 elif "data" in content_json and isinstance(
                     content_json["data"], str
                 ):
-                    image_bytes = base64.b64decode(content_json["data"])
+                    images_bytes.append(base64.b64decode(content_json["data"]))
                     content = "[图片已生成]"
             except (json.JSONDecodeError, KeyError, binascii.Error):
@@ -273,7 +273,7 @@ class BaseAdapter(ABC):
                 if url_str.startswith("data:image/png;base64,"):
                     try:
                         b64_data = url_str.split(",", 1)[1]
-                        image_bytes = base64.b64decode(b64_data)
+                        images_bytes.append(base64.b64decode(b64_data))
                         content = content if content else "[图片已生成]"
                     except (IndexError, binascii.Error) as e:
                         logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")
@@ -316,7 +316,7 @@ class BaseAdapter(ABC):
             text=final_text,
             tool_calls=parsed_tool_calls,
             usage_info=usage_info,
-            image_bytes=image_bytes,
+            images=images_bytes if images_bytes else None,
             raw_response=response_json,
         )

View File

@@ -408,7 +408,7 @@ class GeminiAdapter(BaseAdapter):
         parts = content_data.get("parts", [])
         text_content = ""
-        image_bytes: bytes | None = None
+        images_bytes: list[bytes] = []
         parsed_tool_calls: list["LLMToolCall"] | None = None
         thought_summary_parts = []
         answer_parts = []
@@ -423,10 +423,7 @@ class GeminiAdapter(BaseAdapter):
             elif "inlineData" in part:
                 inline_data = part["inlineData"]
                 if "data" in inline_data:
-                    image_bytes = base64.b64decode(inline_data["data"])
-                    answer_parts.append(
-                        f"[图片已生成: {inline_data.get('mimeType', 'image')}]"
-                    )
+                    images_bytes.append(base64.b64decode(inline_data["data"]))
             elif "functionCall" in part:
                 if parsed_tool_calls is None:
@@ -494,7 +491,7 @@ class GeminiAdapter(BaseAdapter):
         return ResponseData(
             text=text_content,
             tool_calls=parsed_tool_calls,
-            image_bytes=image_bytes,
+            images=images_bytes if images_bytes else None,
             usage_info=usage_info,
             raw_response=response_json,
             grounding_metadata=grounding_metadata_obj,

View File

@@ -339,7 +339,7 @@ async def _generate_image_from_message(
     response = await model_instance.generate_response(messages, config=config)
-    if not response.image_bytes:
+    if not response.images:
         error_text = response.text or "模型未返回图片数据。"
         logger.warning(f"图片生成调用未返回图片,返回文本内容: {error_text}")

View File

@@ -5,12 +5,12 @@ LLM 模型管理器
 """
 import hashlib
-import json
 import time
 from typing import Any
 from zhenxun.configs.config import Config
 from zhenxun.services.log import logger
+from zhenxun.utils.pydantic_compat import dump_json_safely
 from .config import validate_override_params
 from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
@@ -43,7 +43,7 @@ def _make_cache_key(
 ) -> str:
     """生成缓存键"""
     config_str = (
-        json.dumps(override_config, sort_keys=True) if override_config else "None"
+        dump_json_safely(override_config, sort_keys=True) if override_config else "None"
     )
     key_data = f"{provider_model_name}:{config_str}"
     return hashlib.md5(key_data.encode()).hexdigest()

View File

@@ -13,6 +13,7 @@ from pydantic import BaseModel
 from zhenxun.services.log import logger
 from zhenxun.utils.log_sanitizer import sanitize_for_logging
+from zhenxun.utils.pydantic_compat import dump_json_safely
 from .adapters.base import RequestData
 from .config import LLMGenerationConfig
@@ -194,13 +195,15 @@ class LLMModel(LLMModelBase):
         sanitized_body = sanitize_for_logging(
             request_data.body, context=sanitizer_req_context
         )
-        request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
+        request_body_str = dump_json_safely(
+            sanitized_body, ensure_ascii=False, indent=2
+        )
         logger.debug(f"📦 请求体: {request_body_str}")
         http_response = await http_client.post(
             request_data.url,
             headers=request_data.headers,
-            json=request_data.body,
+            content=dump_json_safely(request_data.body, ensure_ascii=False),
         )
         logger.debug(f"📥 响应状态码: {http_response.status_code}")
@@ -394,7 +397,7 @@ class LLMModel(LLMModelBase):
         return LLMResponse(
             text=response_data.text,
             usage_info=response_data.usage_info,
-            image_bytes=response_data.image_bytes,
+            images=response_data.images,
             raw_response=response_data.raw_response,
             tool_calls=response_tool_calls if response_tool_calls else None,
             code_executions=response_data.code_executions,
@@ -424,7 +427,7 @@ class LLMModel(LLMModelBase):
         policy = config.validation_policy
         if policy:
-            if policy.get("require_image") and not parsed_data.image_bytes:
+            if policy.get("require_image") and not parsed_data.images:
                 if self.api_type == "gemini" and parsed_data.raw_response:
                     usage_metadata = parsed_data.raw_response.get(
                         "usageMetadata", {}

View File

@@ -425,7 +425,7 @@ class LLMResponse(BaseModel):
     """LLM 响应"""
     text: str
-    image_bytes: bytes | None = None
+    images: list[bytes] | None = None
     usage_info: dict[str, Any] | None = None
     raw_response: dict[str, Any] | None = None
     tool_calls: list[Any] | None = None
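
For callers, the practical effect of the `image_bytes` → `images` change is iterating over a list rather than testing a single field. A hedged sketch of a consumer (the helper below is illustrative, not part of the repository):

from pathlib import Path


def save_response_images(response, out_dir: Path) -> list[Path]:
    """Persist every image of a multi-image response; `response.images` may be None."""
    out_dir.mkdir(parents=True, exist_ok=True)
    saved: list[Path] = []
    for index, image_bytes in enumerate(response.images or []):
        target = out_dir / f"generated_{index}.png"
        target.write_bytes(image_bytes)
        saved.append(target)
    return saved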

View File

@@ -217,16 +217,17 @@ class RendererService:
         context.processed_components.add(component_id)
         component_path_base = str(component.template_name)
+        variant = getattr(component, "variant", None)
         manifest = await context.theme_manager.get_template_manifest(
-            component_path_base
+            component_path_base, skin=variant
         )
         style_paths_to_load = []
-        if manifest and manifest.styles:
+        if manifest and "styles" in manifest:
             styles = (
-                [manifest.styles]
-                if isinstance(manifest.styles, str)
-                else manifest.styles
+                [manifest["styles"]]
+                if isinstance(manifest["styles"], str)
+                else manifest["styles"]
             )
             for style_path in styles:
                 full_style_path = str(Path(component_path_base) / style_path).replace(
@@ -383,6 +384,7 @@ class RendererService:
         )
         temp_env.globals.update(context.theme_manager.jinja_env.globals)
+        temp_env.filters.update(context.theme_manager.jinja_env.filters)
         temp_env.globals["asset"] = (
             context.theme_manager._create_standalone_asset_loader(template_dir)
         )
@@ -431,10 +433,11 @@ class RendererService:
         component_render_options = {}
         manifest_options = {}
+        variant = getattr(component, "variant", None)
         if manifest := await context.theme_manager.get_template_manifest(
-            component.template_name
+            component.template_name, skin=variant
         ):
-            manifest_options = manifest.render_options or {}
+            manifest_options = manifest.get("render_options", {})
         final_render_options = component_render_options.copy()
         final_render_options.update(manifest_options)
@@ -557,6 +560,8 @@ class RendererService:
             await self.initialize()
         assert self._theme_manager is not None, "ThemeManager 未初始化"
+        self._theme_manager._manifest_cache.clear()
+        logger.debug("已清除UI清单缓存 (manifest cache)。")
         current_theme_name = Config.get_config("UI", "THEME", "default")
         await self._theme_manager.load_theme(current_theme_name)
         logger.info(f"主题 '{current_theme_name}' 已成功重载。")

View File

@@ -1,11 +1,11 @@
 from __future__ import annotations
+import asyncio
 from collections.abc import Callable
 import os
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
-import aiofiles
 from jinja2 import (
     ChoiceLoader,
     Environment,
@@ -21,7 +21,6 @@ import ujson as json
 from zhenxun.configs.path_config import THEMES_PATH
 from zhenxun.services.log import logger
-from zhenxun.services.renderer.models import TemplateManifest
 from zhenxun.services.renderer.protocols import Renderable
 from zhenxun.services.renderer.registry import asset_registry
 from zhenxun.utils.pydantic_compat import model_dump
@@ -32,6 +31,20 @@ if TYPE_CHECKING:
 from .config import RESERVED_TEMPLATE_KEYS
+def deep_merge_dict(base: dict, new: dict) -> dict:
+    """
+    递归地将 new 字典合并到 base 字典中
+    new 字典中的值会覆盖 base 字典中的值
+    """
+    result = base.copy()
+    for key, value in new.items():
+        if isinstance(value, dict) and key in result and isinstance(result[key], dict):
+            result[key] = deep_merge_dict(result[key], value)
+        else:
+            result[key] = value
+    return result
 class RelativePathEnvironment(Environment):
     """
     一个自定义的 Jinja2 环境重写了 join_path 方法以支持模板间的相对路径引用
@@ -151,14 +164,42 @@ class ResourceResolver:
     def resolve_asset_uri(self, asset_path: str, current_template_name: str) -> str:
         """解析资源路径实现完整的回退逻辑并返回可用的URI。"""
-        if not self.theme_manager.current_theme:
+        if (
+            not self.theme_manager.current_theme
+            or not self.theme_manager.jinja_env.loader
+        ):
             return ""
+        if asset_path.startswith("@"):
+            try:
+                full_asset_path = self.theme_manager.jinja_env.join_path(
+                    asset_path, current_template_name
+                )
+                _source, file_abs_path, _uptodate = (
+                    self.theme_manager.jinja_env.loader.get_source(
+                        self.theme_manager.jinja_env, full_asset_path
+                    )
+                )
+                if file_abs_path:
+                    logger.debug(
+                        f"Jinja Loader resolved asset '{asset_path}'->'{file_abs_path}'"
+                    )
+                    return Path(file_abs_path).absolute().as_uri()
+            except TemplateNotFound:
+                logger.warning(
+                    f"资源文件在命名空间中未找到: '{asset_path}'"
+                    f"(在模板 '{current_template_name}' 中引用)"
+                )
+            return ""
         search_paths: list[tuple[str, Path]] = []
-        if asset_path.startswith("./"):
+        if asset_path.startswith("./") or asset_path.startswith("../"):
+            relative_part = (
+                asset_path[2:] if asset_path.startswith("./") else asset_path
+            )
             search_paths.extend(
                 self._search_paths_for_relative_asset(
-                    asset_path[2:], current_template_name
+                    relative_part, current_template_name
                 )
             )
         else:
@@ -209,6 +250,9 @@ class ThemeManager:
         self.jinja_env.filters["md"] = self._markdown_filter
+        self._manifest_cache: dict[str, Any] = {}
+        self._manifest_cache_lock = asyncio.Lock()
     def list_available_themes(self) -> list[str]:
         """扫描主题目录并返回所有可用的主题名称。"""
         if not THEMES_PATH.is_dir():
@@ -377,16 +421,26 @@ class ThemeManager:
             logger.error(f"指定的模板文件路径不存在: '{component_path_base}'", e=e)
             raise e
-        entrypoint_filename = "main.html"
-        manifest = await self.get_template_manifest(component_path_base)
-        if manifest and manifest.entrypoint:
-            entrypoint_filename = manifest.entrypoint
+        base_manifest = await self.get_template_manifest(component_path_base)
+        skin_to_use = variant or (base_manifest.get("skin") if base_manifest else None)
+        final_manifest = await self.get_template_manifest(
+            component_path_base, skin=skin_to_use
+        )
+        logger.debug(f"final_manifest: {final_manifest}")
+        entrypoint_filename = (
+            final_manifest.get("entrypoint", "main.html")
+            if final_manifest
+            else "main.html"
+        )
         potential_paths = []
-        if variant:
+        if skin_to_use:
             potential_paths.append(
-                f"{component_path_base}/skins/{variant}/{entrypoint_filename}"
+                f"{component_path_base}/skins/{skin_to_use}/{entrypoint_filename}"
             )
         potential_paths.append(f"{component_path_base}/{entrypoint_filename}")
@@ -410,28 +464,88 @@ class ThemeManager:
             logger.error(err_msg)
             raise TemplateNotFound(err_msg)
-    async def get_template_manifest(
-        self, component_path: str
-    ) -> TemplateManifest | None:
-        """
-        查找并解析组件的 manifest.json 文件
-        """
-        manifest_path_str = f"{component_path}/manifest.json"
+    async def _load_single_manifest(self, path_str: str) -> dict[str, Any] | None:
+        """从指定路径加载单个 manifest.json 文件。"""
+        normalized_path = path_str.replace("\\", "/")
+        manifest_path_str = f"{normalized_path}/manifest.json"
         if not self.jinja_env.loader:
             return None
         try:
-            _, full_path, _ = self.jinja_env.loader.get_source(
+            source, filepath, _ = self.jinja_env.loader.get_source(
                 self.jinja_env, manifest_path_str
             )
-            if full_path and Path(full_path).exists():
-                async with aiofiles.open(full_path, encoding="utf-8") as f:
-                    manifest_data = json.loads(await f.read())
-                return TemplateManifest(**manifest_data)
+            logger.debug(f"找到清单文件: '{manifest_path_str}' (从 '{filepath}' 加载)")
+            return json.loads(source)
         except TemplateNotFound:
+            logger.trace(f"未找到清单文件: '{manifest_path_str}'")
             return None
-        return None
+        except json.JSONDecodeError:
+            logger.warning(f"清单文件 '{manifest_path_str}' 解析失败")
+            return None
+
+    async def _load_and_merge_manifests(
+        self, component_path: Path | str, skin: str | None = None
+    ) -> dict[str, Any] | None:
+        """加载基础和皮肤清单并进行合并。"""
+        logger.debug(f"开始加载清单: component_path='{component_path}', skin='{skin}'")
+        base_manifest = await self._load_single_manifest(str(component_path))
+        if skin:
+            skin_path = Path(component_path) / "skins" / skin
+            skin_manifest = await self._load_single_manifest(str(skin_path))
+            if skin_manifest:
+                if base_manifest:
+                    merged = deep_merge_dict(base_manifest, skin_manifest)
+                    logger.debug(
+                        f"已合并基础清单和皮肤清单: '{component_path}' + skin '{skin}'"
+                    )
+                    return merged
+                else:
+                    logger.debug(f"只找到皮肤清单: '{skin_path}'")
+                    return skin_manifest
+        if base_manifest:
+            logger.debug(f"只找到基础清单: '{component_path}'")
+        else:
+            logger.debug(f"未找到任何清单: '{component_path}'")
+        return base_manifest
+
+    async def get_template_manifest(
+        self, component_path: str, skin: str | None = None
+    ) -> dict[str, Any] | None:
+        """
+        查找并解析组件的 manifest.json 文件
+        支持皮肤清单的继承与合并,并带有缓存
+        Args:
+            component_path: 组件路径
+            skin: 皮肤名称(可选)
+        Returns:
+            合并后的清单字典,如果不存在则返回 None
+        """
+        cache_key = f"{component_path}:{skin or 'base'}"
+        if cache_key in self._manifest_cache:
+            logger.debug(f"清单缓存命中: '{cache_key}'")
+            return self._manifest_cache[cache_key]
+        async with self._manifest_cache_lock:
+            if cache_key in self._manifest_cache:
+                logger.debug(f"清单缓存命中(锁内): '{cache_key}'")
+                return self._manifest_cache[cache_key]
+            manifest = await self._load_and_merge_manifests(component_path, skin)
+            self._manifest_cache[cache_key] = manifest
+            logger.debug(f"清单已缓存: '{cache_key}'")
+            return manifest
     async def resolve_markdown_style_path(
         self, style_name: str, context: "RenderContext"
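
To illustrate how the skin-manifest merge behaves, a small example using the `deep_merge_dict` helper added above (the manifest values are made up):

base = {"entrypoint": "main.html", "render_options": {"width": 800, "scale": 2}}
skin = {"styles": "dark.css", "render_options": {"width": 600}}

merged = deep_merge_dict(base, skin)
# Nested dicts are merged recursively and skin values win on conflict:
# {"entrypoint": "main.html",
#  "render_options": {"width": 600, "scale": 2},
#  "styles": "dark.css"}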

View File

@@ -126,12 +126,15 @@ class SqlUtils:
 def format_usage_for_markdown(text: str) -> str:
     """
     智能地将Python多行字符串转换为适合Markdown渲染的格式
-    - 将单个换行符替换为Markdown的硬换行行尾加两个空格
+    - 在列表标题等块级元素前自动插入换行确保正确解析
+    - 将段落内的单个换行符替换为Markdown的硬换行行尾加两个空格
     - 保留两个或更多的连续换行符使其成为Markdown的段落分隔
     """
     if not text:
         return ""
-    text = re.sub(r"\n{2,}", "<<PARAGRAPH_BREAK>>", text)
-    text = text.replace("\n", "  \n")
-    text = text.replace("<<PARAGRAPH_BREAK>>", "\n\n")
+    text = re.sub(r"([^\n])\n(\s*[-*] |\s*#+\s|\s*>)", r"\1\n\n\2", text)
+    text = re.sub(r"(?<!\n)\n(?!\n)", "  \n", text)
     return text
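
A quick illustration of the new behaviour, with the function restated from the hunk above and a made-up usage string:

import re


def format_usage_for_markdown(text: str) -> str:
    # Restated from the diff above for illustration only.
    if not text:
        return ""
    text = re.sub(r"([^\n])\n(\s*[-*] |\s*#+\s|\s*>)", r"\1\n\n\2", text)
    return re.sub(r"(?<!\n)\n(?!\n)", "  \n", text)


raw = "可用道具:\n- 好感度双倍卡\n- 补签卡\n具体用法见帮助"
print(format_usage_for_markdown(raw))
# Output (a blank line is inserted before each list item; the final single
# newline becomes a Markdown hard break, i.e. that line ends with two spaces):
# 可用道具:
#
# - 好感度双倍卡
#
# - 补签卡
# 具体用法见帮助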

View File

@@ -5,10 +5,14 @@ Pydantic V1 & V2 兼容层模块
 包括 model_dump, model_copy, model_json_schema, parse_as
 """
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
 from typing import Any, TypeVar, get_args, get_origin
 from nonebot.compat import PYDANTIC_V2, model_dump
 from pydantic import VERSION, BaseModel
+import ujson as json
 T = TypeVar("T", bound=BaseModel)
 V = TypeVar("V")
@@ -19,6 +23,7 @@ __all__ = [
     "_dump_pydantic_obj",
     "_is_pydantic_type",
     "compat_computed_field",
+    "dump_json_safely",
     "model_copy",
     "model_dump",
     "model_json_schema",
@@ -93,3 +98,26 @@ def parse_as(type_: type[V], obj: Any) -> V:
         from pydantic import TypeAdapter  # type: ignore
         return TypeAdapter(type_).validate_python(obj)
+def dump_json_safely(obj: Any, **kwargs) -> str:
+    """
+    安全地将可能包含 Pydantic 特定类型 (如 Enum) 的对象序列化为 JSON 字符串
+    """
+
+    def default_serializer(o):
+        if isinstance(o, Enum):
+            return o.value
+        if isinstance(o, datetime):
+            return o.isoformat()
+        if isinstance(o, Path):
+            return str(o.as_posix())
+        if isinstance(o, set):
+            return list(o)
+        if isinstance(o, BaseModel):
+            return model_dump(o)
+        raise TypeError(
+            f"Object of type {o.__class__.__name__} is not JSON serializable"
+        )
+
+    return json.dumps(obj, default=default_serializer, **kwargs)
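
A hedged usage sketch of `dump_json_safely` — the model, enum, and payload below are invented for illustration; only the import path comes from the diffs above:

from enum import Enum
from pathlib import Path

from pydantic import BaseModel

from zhenxun.utils.pydantic_compat import dump_json_safely


class Rarity(Enum):
    RARE = "rare"


class Item(BaseModel):
    name: str
    rarity: Rarity
    icon: Path


payload = {
    "item": Item(name="sign-in card", rarity=Rarity.RARE, icon=Path("icons/sign.png")),
    "tags": {"shop", "props"},
}

# json.dumps(payload) would raise TypeError; the default_serializer hook handles
# the Pydantic model, the Enum, the Path, and the set:
print(dump_json_safely(payload, ensure_ascii=False))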