修复并发时数据库超时 (#2063)

* 🔧 修复和优化:调整超时设置,重构检查逻辑,简化代码结构

- 在 `chkdsk_hook.py` 中重构 `check` 方法,提取公共逻辑
- 更新 `CacheManager` 中的超时设置,使用新的 `CACHE_TIMEOUT`
- 在 `utils.py` 中添加缓存逻辑,记录数据库操作的执行情况

*  feat(auth): 添加并发控制,优化权限检查逻辑

* Update utils.py

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
HibiKier 2025-10-09 08:46:08 +08:00 committed by GitHub
parent f94121080f
commit e7f3c210df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 177 additions and 154 deletions

View File

@ -58,5 +58,14 @@ Config.add_plugin_config(
type=bool,
)
Config.add_plugin_config(
"hook",
"AUTH_HOOKS_CONCURRENCY_LIMIT",
5,
help="同步进入权限钩子最大并发数",
default_value=5,
type=int,
)
nonebot.load_plugins(str(Path(__file__).parent.resolve()))

View File

@ -96,7 +96,6 @@ async def is_ban(user_id: str | None, group_id: str | None) -> int:
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
LOGGER_COMMAND,
)
# 超时时返回0避免阻塞
return 0
# 检查记录并计算ban时间
@ -199,7 +198,7 @@ async def group_handle(group_id: str) -> None:
)
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
async def user_handle(plugin: PluginInfo, entity: EntityIDs, session: Uninfo) -> None:
"""用户ban检查
参数:
@ -217,22 +216,12 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
if not time_val:
return
time_str = format_time(time_val)
plugin_dao = DataAccess(PluginInfo)
try:
db_plugin = await asyncio.wait_for(
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
raise SkipPluginException("用户处于黑名单中...")
if (
db_plugin
and not db_plugin.ignore_prompt
plugin
and time_val != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
and freq.is_send_limit_message(plugin, entity.user_id, False)
):
try:
await asyncio.wait_for(
@ -260,7 +249,9 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
)
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
async def auth_ban(
matcher: Matcher, bot: Bot, session: Uninfo, plugin: PluginInfo
) -> None:
"""权限检查 - ban 检查
参数:
@ -289,7 +280,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
if entity.user_id:
try:
await asyncio.wait_for(
user_handle(matcher.plugin_name, entity, session),
user_handle(plugin, entity, session),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:

View File

@ -1,50 +1,36 @@
import asyncio
import time
from nonebot_plugin_alconna import UniMsg
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import EntityIDs
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
from .exception import SkipPluginException
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
async def auth_group(
plugin: PluginInfo,
group: GroupConsole | None,
message: UniMsg,
group_id: str | None,
):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
entity: EntityIDs
group: GroupConsole
message: UniMsg
"""
start_time = time.time()
if not entity.group_id:
if not group_id:
return
start_time = time.time()
try:
text = message.extract_plain_text()
# 从数据库或缓存中获取群组信息
group_dao = DataAccess(GroupConsole)
try:
group: GroupConsole | None = await asyncio.wait_for(
group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
# 超时时不阻塞,继续执行
return
if not group:
raise SkipPluginException("群组信息不存在...")
if group.level < 0:
@ -63,6 +49,5 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
logger.warning(
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=entity.user_id,
group_id=entity.group_id,
group_id=group_id,
)

View File

@ -6,12 +6,10 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import IsSuperuserException, SkipPluginException
@ -20,30 +18,17 @@ from .utils import freq, is_poke, send_message
class GroupCheck:
def __init__(
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
self, plugin: PluginInfo, group: GroupConsole, session: Uninfo, is_poke: bool
) -> None:
self.group_id = group_id
self.session = session
self.is_poke = is_poke
self.plugin = plugin
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
self.group_data = group
self.group_id = group.group_id
async def check(self):
start_time = time.time()
try:
# 只查询一次数据库,使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
# 检查超级用户禁用
if (
self.group_data
@ -113,12 +98,13 @@ class GroupCheck:
class PluginCheck:
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
def __init__(self, group: GroupConsole | None, session: Uninfo, is_poke: bool):
self.session = session
self.is_poke = is_poke
self.group_id = group_id
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
self.group_data = group
self.group_id = None
if group:
self.group_id = group.group_id
async def check_user(self, plugin: PluginInfo):
"""全局私聊禁用检测
@ -156,19 +142,6 @@ class PluginCheck:
if plugin.status or plugin.block_type != BlockType.ALL:
return
"""全局状态"""
if self.group_id:
# 使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
if self.group_data and self.group_data.is_super:
raise IsSuperuserException()
@ -193,7 +166,9 @@ class PluginCheck:
)
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
async def auth_plugin(
plugin: PluginInfo, group: GroupConsole | None, session: Uninfo, event: Event
):
"""插件状态
参数:
@ -203,35 +178,23 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
"""
start_time = time.time()
try:
entity = get_entity_ids(session)
is_poke_event = is_poke(event)
user_check = PluginCheck(entity.group_id, session, is_poke_event)
user_check = PluginCheck(group, session, is_poke_event)
if entity.group_id:
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
try:
await asyncio.wait_for(
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
)
except asyncio.TimeoutError:
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
tasks = []
if group:
tasks.append(GroupCheck(plugin, group, session, is_poke_event).check())
else:
try:
await asyncio.wait_for(
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("用户检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
tasks.append(user_check.check_user(plugin))
tasks.append(user_check.check_global(plugin))
try:
await asyncio.wait_for(
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2
)
except asyncio.TimeoutError:
logger.error("全局检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
logger.error("插件用户/群组/全局检查超时...", LOGGER_COMMAND)
finally:
# 记录总执行时间
elapsed = time.time() - start_time

View File

@ -85,7 +85,7 @@ class FreqUtils:
return False
if plugin.plugin_type == PluginType.DEPENDANT:
return False
return plugin.module != "ai" if self._flmt_s.check(sid) else False
return False if plugin.ignore_prompt else self._flmt_s.check(sid)
freq = FreqUtils()

View File

@ -8,6 +8,7 @@ from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import IntegrityError
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.data_access import DataAccess
@ -31,6 +32,7 @@ from .auth.exception import (
PermissionExemption,
SkipPluginException,
)
from .auth.utils import base_config
# 超时设置(秒)
TIMEOUT_SECONDS = 5.0
@ -46,6 +48,16 @@ CIRCUIT_BREAKERS = {
# 熔断重置时间(秒)
CIRCUIT_RESET_TIME = 300 # 5分钟
# 并发控制:限制同时进入 hooks 并行检查的协程数
# 默认为 6可通过环境变量 AUTH_HOOKS_CONCURRENCY_LIMIT 调整
HOOKS_CONCURRENCY_LIMIT = base_config.get("AUTH_HOOKS_CONCURRENCY_LIMIT")
# 全局信号量与计数器
HOOKS_SEMAPHORE = asyncio.Semaphore(HOOKS_CONCURRENCY_LIMIT)
HOOKS_ACTIVE_COUNT = 0
HOOKS_ACTIVE_LOCK = asyncio.Lock()
# 超时装饰器
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
@ -259,6 +271,30 @@ async def time_hook(coro, name, time_dict):
time_dict[name] = f"{time.time() - start:.3f}s"
async def _enter_hooks_section():
"""尝试获取全局信号量并更新计数器,超时则抛出 PermissionExemption。"""
global HOOKS_ACTIVE_COUNT
# 队列模式:如果达到上限,协程将排队等待直到获取到信号量
await HOOKS_SEMAPHORE.acquire()
async with HOOKS_ACTIVE_LOCK:
HOOKS_ACTIVE_COUNT += 1
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
async def _leave_hooks_section():
"""释放信号量并更新计数器。"""
global HOOKS_ACTIVE_COUNT
from contextlib import suppress
with suppress(Exception):
HOOKS_SEMAPHORE.release()
async with HOOKS_ACTIVE_LOCK:
HOOKS_ACTIVE_COUNT -= 1
# 保证计数不为负
HOOKS_ACTIVE_COUNT = max(HOOKS_ACTIVE_COUNT, 0)
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
async def auth(
matcher: Matcher,
event: Event,
@ -285,6 +321,9 @@ async def auth(
hook_times = {}
hooks_time = 0 # 初始化 hooks_time 变量
# 记录是否已进入 hooks 区域(用于 finally 中释放)
entered_hooks = False
try:
if not module:
raise PermissionExemption("Matcher插件名称不存在...")
@ -304,6 +343,10 @@ async def auth(
)
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
# 进入 hooks 并行检查区域(会在高并发时排队)
await _enter_hooks_section()
entered_hooks = True
# 获取插件费用
cost_start = time.time()
try:
@ -320,16 +363,32 @@ async def auth(
# 执行 bot_filter
bot_filter(session)
group = None
if entity.group_id:
group_dao = DataAccess(GroupConsole)
group = await with_timeout(
group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
),
name="get_group",
)
# 并行执行所有 hook 检查,并记录执行时间
hooks_start = time.time()
# 创建所有 hook 任务
hook_tasks = [
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
time_hook(auth_ban(matcher, bot, session, plugin), "auth_ban", hook_times),
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
time_hook(
auth_group(plugin, group, message, entity.group_id),
"auth_group",
hook_times,
),
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
time_hook(
auth_plugin(plugin, group, session, event), "auth_plugin", hook_times
),
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
]
@ -358,7 +417,17 @@ async def auth(
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
except PermissionExemption as e:
logger.info(str(e), LOGGER_COMMAND, session=session)
finally:
# 如果进入过 hooks 区域,确保释放信号量(即使上层处理抛出了异常)
if entered_hooks:
try:
await _leave_hooks_section()
except Exception:
logger.error(
"释放 hooks 信号量时出错",
LOGGER_COMMAND,
session=session,
)
# 扣除金币
if not ignore_flag and cost_gold > 0:
gold_start = time.time()

View File

@ -43,17 +43,19 @@ class BanCheckLimiter:
def check(self, key: str | float) -> bool:
if time.time() - self.mtime[key] > self.default_check_time:
self.mtime[key] = time.time()
self.mint[key] = 0
return False
return self._extracted_from_check_3(key, False)
if (
self.mint[key] >= self.default_count
and time.time() - self.mtime[key] < self.default_check_time
):
return self._extracted_from_check_3(key, True)
return False
# TODO Rename this here and in `check`
def _extracted_from_check_3(self, key, arg1):
self.mtime[key] = time.time()
self.mint[key] = 0
return True
return False
return arg1
_blmt = BanCheckLimiter(
@ -70,7 +72,8 @@ async def _(
module = None
if plugin := matcher.plugin:
module = plugin.module_name
if metadata := plugin.metadata:
if not (metadata := plugin.metadata):
return
extra = metadata.extra
if extra.get("plugin_type") in [
PluginType.HIDDEN,
@ -79,8 +82,6 @@ async def _(
PluginType.SUPERUSER,
]:
return
else:
return
if matcher.type == "notice":
return
user_id = session.id1
@ -88,8 +89,7 @@ async def _(
malicious_ban_time = Config.get_config("hook", "MALICIOUS_BAN_TIME")
if not malicious_ban_time:
raise ValueError("模块: [hook], 配置项: [MALICIOUS_BAN_TIME] 为空或小于0")
if user_id:
if module:
if user_id and module:
if _blmt.check(f"{user_id}__{module}"):
await BanConsole.ban(
user_id,

View File

@ -98,6 +98,7 @@ from .cache_containers import CacheDict, CacheList
from .config import (
CACHE_KEY_PREFIX,
CACHE_KEY_SEPARATOR,
CACHE_TIMEOUT,
DEFAULT_EXPIRE,
LOG_COMMAND,
SPECIAL_KEY_FORMATS,
@ -551,7 +552,6 @@ class CacheManager:
返回:
Any: 缓存数据如果不存在返回默认值
"""
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
# 如果缓存被禁用或缓存模式为NONE直接返回默认值
if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
@ -561,7 +561,7 @@ class CacheManager:
cache_key = self._build_key(cache_type, key)
data = await asyncio.wait_for(
self.cache_backend.get(cache_key), # type: ignore
timeout=DB_TIMEOUT_SECONDS,
timeout=CACHE_TIMEOUT,
)
if data is None:

View File

@ -5,6 +5,9 @@
# 日志标识
LOG_COMMAND = "CacheRoot"
# 缓存获取超时时间(秒)
CACHE_TIMEOUT = 10
# 默认缓存过期时间(秒)
DEFAULT_EXPIRE = 600

View File

@ -27,5 +27,8 @@ async def with_db_timeout(
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
logger.error(
f"数据库操作超时: {operation} (>{timeout}s) 来源: {source}",
LOG_COMMAND,
)
raise