From e7f3c210dfd14bb23d1fbf99f99352aa0fd4bd95 Mon Sep 17 00:00:00 2001 From: HibiKier <45528451+HibiKier@users.noreply.github.com> Date: Thu, 9 Oct 2025 08:46:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=B9=B6=E5=8F=91=E6=97=B6?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E8=B6=85=E6=97=B6=20(#2063)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 修复和优化:调整超时设置,重构检查逻辑,简化代码结构 - 在 `chkdsk_hook.py` 中重构 `check` 方法,提取公共逻辑 - 更新 `CacheManager` 中的超时设置,使用新的 `CACHE_TIMEOUT` - 在 `utils.py` 中添加缓存逻辑,记录数据库操作的执行情况 * ✨ feat(auth): 添加并发控制,优化权限检查逻辑 * Update utils.py * :rotating_light: auto fix by pre-commit hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- zhenxun/builtin_plugins/hooks/__init__.py | 9 ++ .../builtin_plugins/hooks/auth/auth_ban.py | 23 ++--- .../builtin_plugins/hooks/auth/auth_group.py | 37 +++----- .../builtin_plugins/hooks/auth/auth_plugin.py | 81 +++++------------ zhenxun/builtin_plugins/hooks/auth/utils.py | 2 +- zhenxun/builtin_plugins/hooks/auth_checker.py | 77 +++++++++++++++- zhenxun/builtin_plugins/hooks/chkdsk_hook.py | 90 +++++++++---------- zhenxun/services/cache/__init__.py | 4 +- zhenxun/services/cache/config.py | 3 + zhenxun/services/db_context/utils.py | 5 +- 10 files changed, 177 insertions(+), 154 deletions(-) diff --git a/zhenxun/builtin_plugins/hooks/__init__.py b/zhenxun/builtin_plugins/hooks/__init__.py index 2f8c79de..e61ec71d 100644 --- a/zhenxun/builtin_plugins/hooks/__init__.py +++ b/zhenxun/builtin_plugins/hooks/__init__.py @@ -58,5 +58,14 @@ Config.add_plugin_config( type=bool, ) +Config.add_plugin_config( + "hook", + "AUTH_HOOKS_CONCURRENCY_LIMIT", + 5, + help="同步进入权限钩子最大并发数", + default_value=5, + type=int, +) + nonebot.load_plugins(str(Path(__file__).parent.resolve())) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py index b7663090..7eea7f57 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py @@ -96,7 +96,6 @@ async def is_ban(user_id: str | None, group_id: str | None) -> int: f"查询ban记录超时: user_id={user_id}, group_id={group_id}", LOGGER_COMMAND, ) - # 超时时返回0,避免阻塞 return 0 # 检查记录并计算ban时间 @@ -199,7 +198,7 @@ async def group_handle(group_id: str) -> None: ) -async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None: +async def user_handle(plugin: PluginInfo, entity: EntityIDs, session: Uninfo) -> None: """用户ban检查 参数: @@ -217,22 +216,12 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None: if not time_val: return time_str = format_time(time_val) - plugin_dao = DataAccess(PluginInfo) - try: - db_plugin = await asyncio.wait_for( - plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS - ) - except asyncio.TimeoutError: - logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND) - # 超时时不阻塞,继续执行 - raise SkipPluginException("用户处于黑名单中...") if ( - db_plugin - and not db_plugin.ignore_prompt + plugin and time_val != -1 and ban_result - and freq.is_send_limit_message(db_plugin, entity.user_id, False) + and freq.is_send_limit_message(plugin, entity.user_id, False) ): try: await asyncio.wait_for( @@ -260,7 +249,9 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None: ) -async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None: +async def auth_ban( + matcher: Matcher, bot: Bot, session: Uninfo, plugin: PluginInfo +) -> None: """权限检查 - ban 检查 参数: @@ -289,7 +280,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None: if entity.user_id: try: await asyncio.wait_for( - user_handle(matcher.plugin_name, entity, session), + user_handle(plugin, entity, session), timeout=DB_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_group.py b/zhenxun/builtin_plugins/hooks/auth/auth_group.py index 24086812..20114bef 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_group.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_group.py @@ -1,50 +1,36 @@ -import asyncio import time from nonebot_plugin_alconna import UniMsg from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.data_access import DataAccess -from zhenxun.services.db_context import DB_TIMEOUT_SECONDS from zhenxun.services.log import logger -from zhenxun.utils.utils import EntityIDs from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum from .exception import SkipPluginException -async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg): +async def auth_group( + plugin: PluginInfo, + group: GroupConsole | None, + message: UniMsg, + group_id: str | None, +): """群黑名单检测 群总开关检测 参数: plugin: PluginInfo - entity: EntityIDs + group: GroupConsole message: UniMsg """ - start_time = time.time() - - if not entity.group_id: + if not group_id: return + start_time = time.time() + try: text = message.extract_plain_text() - # 从数据库或缓存中获取群组信息 - group_dao = DataAccess(GroupConsole) - - try: - group: GroupConsole | None = await asyncio.wait_for( - group_dao.safe_get_or_none( - group_id=entity.group_id, channel_id__isnull=True - ), - timeout=DB_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id) - # 超时时不阻塞,继续执行 - return - if not group: raise SkipPluginException("群组信息不存在...") if group.level < 0: @@ -63,6 +49,5 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg): logger.warning( f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}", LOGGER_COMMAND, - session=entity.user_id, - group_id=entity.group_id, + group_id=group_id, ) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py index 002c97b4..ddab3161 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py @@ -6,12 +6,10 @@ from nonebot_plugin_uninfo import Uninfo from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.data_access import DataAccess from zhenxun.services.db_context import DB_TIMEOUT_SECONDS from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.enum import BlockType -from zhenxun.utils.utils import get_entity_ids from .config import LOGGER_COMMAND, WARNING_THRESHOLD from .exception import IsSuperuserException, SkipPluginException @@ -20,30 +18,17 @@ from .utils import freq, is_poke, send_message class GroupCheck: def __init__( - self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool + self, plugin: PluginInfo, group: GroupConsole, session: Uninfo, is_poke: bool ) -> None: - self.group_id = group_id self.session = session self.is_poke = is_poke self.plugin = plugin - self.group_dao = DataAccess(GroupConsole) - self.group_data = None + self.group_data = group + self.group_id = group.group_id async def check(self): start_time = time.time() try: - # 只查询一次数据库,使用 DataAccess 的缓存机制 - try: - self.group_data = await asyncio.wait_for( - self.group_dao.safe_get_or_none( - group_id=self.group_id, channel_id__isnull=True - ), - timeout=DB_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND) - return # 超时时不阻塞,继续执行 - # 检查超级用户禁用 if ( self.group_data @@ -113,12 +98,13 @@ class GroupCheck: class PluginCheck: - def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool): + def __init__(self, group: GroupConsole | None, session: Uninfo, is_poke: bool): self.session = session self.is_poke = is_poke - self.group_id = group_id - self.group_dao = DataAccess(GroupConsole) - self.group_data = None + self.group_data = group + self.group_id = None + if group: + self.group_id = group.group_id async def check_user(self, plugin: PluginInfo): """全局私聊禁用检测 @@ -156,21 +142,8 @@ class PluginCheck: if plugin.status or plugin.block_type != BlockType.ALL: return """全局状态""" - if self.group_id: - # 使用 DataAccess 的缓存机制 - try: - self.group_data = await asyncio.wait_for( - self.group_dao.safe_get_or_none( - group_id=self.group_id, channel_id__isnull=True - ), - timeout=DB_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND) - return # 超时时不阻塞,继续执行 - - if self.group_data and self.group_data.is_super: - raise IsSuperuserException() + if self.group_data and self.group_data.is_super: + raise IsSuperuserException() sid = self.group_id or self.session.user.id if freq.is_send_limit_message(plugin, sid, self.is_poke): @@ -193,7 +166,9 @@ class PluginCheck: ) -async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event): +async def auth_plugin( + plugin: PluginInfo, group: GroupConsole | None, session: Uninfo, event: Event +): """插件状态 参数: @@ -203,35 +178,23 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event): """ start_time = time.time() try: - entity = get_entity_ids(session) is_poke_event = is_poke(event) - user_check = PluginCheck(entity.group_id, session, is_poke_event) + user_check = PluginCheck(group, session, is_poke_event) - if entity.group_id: - group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event) - try: - await asyncio.wait_for( - group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2 - ) - except asyncio.TimeoutError: - logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND) - # 超时时不阻塞,继续执行 + tasks = [] + if group: + tasks.append(GroupCheck(plugin, group, session, is_poke_event).check()) else: - try: - await asyncio.wait_for( - user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS - ) - except asyncio.TimeoutError: - logger.error("用户检查超时", LOGGER_COMMAND) - # 超时时不阻塞,继续执行 + tasks.append(user_check.check_user(plugin)) + tasks.append(user_check.check_global(plugin)) try: await asyncio.wait_for( - user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS + asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2 ) except asyncio.TimeoutError: - logger.error("全局检查超时", LOGGER_COMMAND) - # 超时时不阻塞,继续执行 + logger.error("插件用户/群组/全局检查超时...", LOGGER_COMMAND) + finally: # 记录总执行时间 elapsed = time.time() - start_time diff --git a/zhenxun/builtin_plugins/hooks/auth/utils.py b/zhenxun/builtin_plugins/hooks/auth/utils.py index 0f925590..d2f1b551 100644 --- a/zhenxun/builtin_plugins/hooks/auth/utils.py +++ b/zhenxun/builtin_plugins/hooks/auth/utils.py @@ -85,7 +85,7 @@ class FreqUtils: return False if plugin.plugin_type == PluginType.DEPENDANT: return False - return plugin.module != "ai" if self._flmt_s.check(sid) else False + return False if plugin.ignore_prompt else self._flmt_s.check(sid) freq = FreqUtils() diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py index ce4799ea..ae757cec 100644 --- a/zhenxun/builtin_plugins/hooks/auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -8,6 +8,7 @@ from nonebot_plugin_alconna import UniMsg from nonebot_plugin_uninfo import Uninfo from tortoise.exceptions import IntegrityError +from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.user_console import UserConsole from zhenxun.services.data_access import DataAccess @@ -31,6 +32,7 @@ from .auth.exception import ( PermissionExemption, SkipPluginException, ) +from .auth.utils import base_config # 超时设置(秒) TIMEOUT_SECONDS = 5.0 @@ -46,6 +48,16 @@ CIRCUIT_BREAKERS = { # 熔断重置时间(秒) CIRCUIT_RESET_TIME = 300 # 5分钟 +# 并发控制:限制同时进入 hooks 并行检查的协程数 + +# 默认为 6,可通过环境变量 AUTH_HOOKS_CONCURRENCY_LIMIT 调整 +HOOKS_CONCURRENCY_LIMIT = base_config.get("AUTH_HOOKS_CONCURRENCY_LIMIT") + +# 全局信号量与计数器 +HOOKS_SEMAPHORE = asyncio.Semaphore(HOOKS_CONCURRENCY_LIMIT) +HOOKS_ACTIVE_COUNT = 0 +HOOKS_ACTIVE_LOCK = asyncio.Lock() + # 超时装饰器 async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None): @@ -259,6 +271,30 @@ async def time_hook(coro, name, time_dict): time_dict[name] = f"{time.time() - start:.3f}s" +async def _enter_hooks_section(): + """尝试获取全局信号量并更新计数器,超时则抛出 PermissionExemption。""" + global HOOKS_ACTIVE_COUNT + # 队列模式:如果达到上限,协程将排队等待直到获取到信号量 + await HOOKS_SEMAPHORE.acquire() + async with HOOKS_ACTIVE_LOCK: + HOOKS_ACTIVE_COUNT += 1 + logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND) + + +async def _leave_hooks_section(): + """释放信号量并更新计数器。""" + global HOOKS_ACTIVE_COUNT + from contextlib import suppress + + with suppress(Exception): + HOOKS_SEMAPHORE.release() + async with HOOKS_ACTIVE_LOCK: + HOOKS_ACTIVE_COUNT -= 1 + # 保证计数不为负 + HOOKS_ACTIVE_COUNT = max(HOOKS_ACTIVE_COUNT, 0) + logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND) + + async def auth( matcher: Matcher, event: Event, @@ -285,6 +321,9 @@ async def auth( hook_times = {} hooks_time = 0 # 初始化 hooks_time 变量 + # 记录是否已进入 hooks 区域(用于 finally 中释放) + entered_hooks = False + try: if not module: raise PermissionExemption("Matcher插件名称不存在...") @@ -304,6 +343,10 @@ async def auth( ) raise PermissionExemption("获取插件和用户数据超时,请稍后再试...") + # 进入 hooks 并行检查区域(会在高并发时排队) + await _enter_hooks_section() + entered_hooks = True + # 获取插件费用 cost_start = time.time() try: @@ -320,16 +363,32 @@ async def auth( # 执行 bot_filter bot_filter(session) + group = None + if entity.group_id: + group_dao = DataAccess(GroupConsole) + group = await with_timeout( + group_dao.safe_get_or_none( + group_id=entity.group_id, channel_id__isnull=True + ), + name="get_group", + ) + # 并行执行所有 hook 检查,并记录执行时间 hooks_start = time.time() # 创建所有 hook 任务 hook_tasks = [ - time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times), + time_hook(auth_ban(matcher, bot, session, plugin), "auth_ban", hook_times), time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times), - time_hook(auth_group(plugin, entity, message), "auth_group", hook_times), + time_hook( + auth_group(plugin, group, message, entity.group_id), + "auth_group", + hook_times, + ), time_hook(auth_admin(plugin, session), "auth_admin", hook_times), - time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times), + time_hook( + auth_plugin(plugin, group, session, event), "auth_plugin", hook_times + ), time_hook(auth_limit(plugin, session), "auth_limit", hook_times), ] @@ -358,7 +417,17 @@ async def auth( logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session) except PermissionExemption as e: logger.info(str(e), LOGGER_COMMAND, session=session) - + finally: + # 如果进入过 hooks 区域,确保释放信号量(即使上层处理抛出了异常) + if entered_hooks: + try: + await _leave_hooks_section() + except Exception: + logger.error( + "释放 hooks 信号量时出错", + LOGGER_COMMAND, + session=session, + ) # 扣除金币 if not ignore_flag and cost_gold > 0: gold_start = time.time() diff --git a/zhenxun/builtin_plugins/hooks/chkdsk_hook.py b/zhenxun/builtin_plugins/hooks/chkdsk_hook.py index 30080281..c657e3ed 100644 --- a/zhenxun/builtin_plugins/hooks/chkdsk_hook.py +++ b/zhenxun/builtin_plugins/hooks/chkdsk_hook.py @@ -43,18 +43,20 @@ class BanCheckLimiter: def check(self, key: str | float) -> bool: if time.time() - self.mtime[key] > self.default_check_time: - self.mtime[key] = time.time() - self.mint[key] = 0 - return False + return self._extracted_from_check_3(key, False) if ( self.mint[key] >= self.default_count and time.time() - self.mtime[key] < self.default_check_time ): - self.mtime[key] = time.time() - self.mint[key] = 0 - return True + return self._extracted_from_check_3(key, True) return False + # TODO Rename this here and in `check` + def _extracted_from_check_3(self, key, arg1): + self.mtime[key] = time.time() + self.mint[key] = 0 + return arg1 + _blmt = BanCheckLimiter( malicious_check_time, @@ -70,16 +72,15 @@ async def _( module = None if plugin := matcher.plugin: module = plugin.module_name - if metadata := plugin.metadata: - extra = metadata.extra - if extra.get("plugin_type") in [ - PluginType.HIDDEN, - PluginType.DEPENDANT, - PluginType.ADMIN, - PluginType.SUPERUSER, - ]: - return - else: + if not (metadata := plugin.metadata): + return + extra = metadata.extra + if extra.get("plugin_type") in [ + PluginType.HIDDEN, + PluginType.DEPENDANT, + PluginType.ADMIN, + PluginType.SUPERUSER, + ]: return if matcher.type == "notice": return @@ -88,32 +89,31 @@ async def _( malicious_ban_time = Config.get_config("hook", "MALICIOUS_BAN_TIME") if not malicious_ban_time: raise ValueError("模块: [hook], 配置项: [MALICIOUS_BAN_TIME] 为空或小于0") - if user_id: - if module: - if _blmt.check(f"{user_id}__{module}"): - await BanConsole.ban( - user_id, - group_id, - 9, - "恶意触发命令检测", - malicious_ban_time * 60, - bot.self_id, - ) - logger.info( - f"触发了恶意触发检测: {matcher.plugin_name}", - "HOOK", - session=session, - ) - await MessageUtils.build_message( - [ - At(flag="user", target=user_id), - "检测到恶意触发命令,您将被封禁 30 分钟", - ] - ).send() - logger.debug( - f"触发了恶意触发检测: {matcher.plugin_name}", - "HOOK", - session=session, - ) - raise IgnoredException("检测到恶意触发命令") - _blmt.add(f"{user_id}__{module}") + if user_id and module: + if _blmt.check(f"{user_id}__{module}"): + await BanConsole.ban( + user_id, + group_id, + 9, + "恶意触发命令检测", + malicious_ban_time * 60, + bot.self_id, + ) + logger.info( + f"触发了恶意触发检测: {matcher.plugin_name}", + "HOOK", + session=session, + ) + await MessageUtils.build_message( + [ + At(flag="user", target=user_id), + "检测到恶意触发命令,您将被封禁 30 分钟", + ] + ).send() + logger.debug( + f"触发了恶意触发检测: {matcher.plugin_name}", + "HOOK", + session=session, + ) + raise IgnoredException("检测到恶意触发命令") + _blmt.add(f"{user_id}__{module}") diff --git a/zhenxun/services/cache/__init__.py b/zhenxun/services/cache/__init__.py index ee19a4ce..9e222a44 100644 --- a/zhenxun/services/cache/__init__.py +++ b/zhenxun/services/cache/__init__.py @@ -98,6 +98,7 @@ from .cache_containers import CacheDict, CacheList from .config import ( CACHE_KEY_PREFIX, CACHE_KEY_SEPARATOR, + CACHE_TIMEOUT, DEFAULT_EXPIRE, LOG_COMMAND, SPECIAL_KEY_FORMATS, @@ -551,7 +552,6 @@ class CacheManager: 返回: Any: 缓存数据,如果不存在返回默认值 """ - from zhenxun.services.db_context import DB_TIMEOUT_SECONDS # 如果缓存被禁用或缓存模式为NONE,直接返回默认值 if not self.enabled or cache_config.cache_mode == CacheMode.NONE: @@ -561,7 +561,7 @@ class CacheManager: cache_key = self._build_key(cache_type, key) data = await asyncio.wait_for( self.cache_backend.get(cache_key), # type: ignore - timeout=DB_TIMEOUT_SECONDS, + timeout=CACHE_TIMEOUT, ) if data is None: diff --git a/zhenxun/services/cache/config.py b/zhenxun/services/cache/config.py index b974787b..f699657c 100644 --- a/zhenxun/services/cache/config.py +++ b/zhenxun/services/cache/config.py @@ -5,6 +5,9 @@ # 日志标识 LOG_COMMAND = "CacheRoot" +# 缓存获取超时时间(秒) +CACHE_TIMEOUT = 10 + # 默认缓存过期时间(秒) DEFAULT_EXPIRE = 600 diff --git a/zhenxun/services/db_context/utils.py b/zhenxun/services/db_context/utils.py index d0f58a1e..a1bb3824 100644 --- a/zhenxun/services/db_context/utils.py +++ b/zhenxun/services/db_context/utils.py @@ -27,5 +27,8 @@ async def with_db_timeout( return result except asyncio.TimeoutError: if operation: - logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND) + logger.error( + f"数据库操作超时: {operation} (>{timeout}s) 来源: {source}", + LOG_COMMAND, + ) raise