feat(auth): 添加并发控制,优化权限检查逻辑
Some checks failed
Sequential Lint and Type Check / ruff-call (push) Has been cancelled
Sequential Lint and Type Check / pyright-call (push) Has been cancelled

This commit is contained in:
HibiKier 2025-10-08 23:15:38 +08:00
parent a8ee63b199
commit 7058850444
6 changed files with 123 additions and 106 deletions

View File

@ -58,5 +58,14 @@ Config.add_plugin_config(
type=bool,
)
Config.add_plugin_config(
"hook",
"AUTH_HOOKS_CONCURRENCY_LIMIT",
5,
help="同步进入权限钩子最大并发数",
default_value=5,
type=int,
)
nonebot.load_plugins(str(Path(__file__).parent.resolve()))

View File

@ -96,7 +96,6 @@ async def is_ban(user_id: str | None, group_id: str | None) -> int:
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
LOGGER_COMMAND,
)
# 超时时返回0避免阻塞
return 0
# 检查记录并计算ban时间
@ -199,7 +198,7 @@ async def group_handle(group_id: str) -> None:
)
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
async def user_handle(plugin: PluginInfo, entity: EntityIDs, session: Uninfo) -> None:
"""用户ban检查
参数:
@ -217,22 +216,12 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
if not time_val:
return
time_str = format_time(time_val)
plugin_dao = DataAccess(PluginInfo)
try:
db_plugin = await asyncio.wait_for(
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
raise SkipPluginException("用户处于黑名单中...")
if (
db_plugin
and not db_plugin.ignore_prompt
plugin
and time_val != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
and freq.is_send_limit_message(plugin, entity.user_id, False)
):
try:
await asyncio.wait_for(
@ -260,7 +249,9 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
)
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
async def auth_ban(
matcher: Matcher, bot: Bot, session: Uninfo, plugin: PluginInfo
) -> None:
"""权限检查 - ban 检查
参数:
@ -289,7 +280,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
if entity.user_id:
try:
await asyncio.wait_for(
user_handle(matcher.plugin_name, entity, session),
user_handle(plugin, entity, session),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:

View File

@ -1,50 +1,36 @@
import asyncio
import time
from nonebot_plugin_alconna import UniMsg
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import EntityIDs
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
from .exception import SkipPluginException
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
async def auth_group(
plugin: PluginInfo,
group: GroupConsole | None,
message: UniMsg,
group_id: str | None,
):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
entity: EntityIDs
group: GroupConsole
message: UniMsg
"""
start_time = time.time()
if not entity.group_id:
if not group_id:
return
start_time = time.time()
try:
text = message.extract_plain_text()
# 从数据库或缓存中获取群组信息
group_dao = DataAccess(GroupConsole)
try:
group: GroupConsole | None = await asyncio.wait_for(
group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
# 超时时不阻塞,继续执行
return
if not group:
raise SkipPluginException("群组信息不存在...")
if group.level < 0:
@ -63,6 +49,5 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
logger.warning(
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=entity.user_id,
group_id=entity.group_id,
group_id=group_id,
)

View File

@ -6,12 +6,10 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import IsSuperuserException, SkipPluginException
@ -20,30 +18,17 @@ from .utils import freq, is_poke, send_message
class GroupCheck:
def __init__(
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
self, plugin: PluginInfo, group: GroupConsole, session: Uninfo, is_poke: bool
) -> None:
self.group_id = group_id
self.session = session
self.is_poke = is_poke
self.plugin = plugin
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
self.group_data = group
self.group_id = group.group_id
async def check(self):
start_time = time.time()
try:
# 只查询一次数据库,使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
# 检查超级用户禁用
if (
self.group_data
@ -113,12 +98,13 @@ class GroupCheck:
class PluginCheck:
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
def __init__(self, group: GroupConsole | None, session: Uninfo, is_poke: bool):
self.session = session
self.is_poke = is_poke
self.group_id = group_id
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
self.group_data = group
self.group_id = None
if group:
self.group_id = group.group_id
async def check_user(self, plugin: PluginInfo):
"""全局私聊禁用检测
@ -156,21 +142,8 @@ class PluginCheck:
if plugin.status or plugin.block_type != BlockType.ALL:
return
"""全局状态"""
if self.group_id:
# 使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
if self.group_data and self.group_data.is_super:
raise IsSuperuserException()
if self.group_data and self.group_data.is_super:
raise IsSuperuserException()
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):
@ -193,7 +166,9 @@ class PluginCheck:
)
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
async def auth_plugin(
plugin: PluginInfo, group: GroupConsole | None, session: Uninfo, event: Event
):
"""插件状态
参数:
@ -203,35 +178,23 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
"""
start_time = time.time()
try:
entity = get_entity_ids(session)
is_poke_event = is_poke(event)
user_check = PluginCheck(entity.group_id, session, is_poke_event)
user_check = PluginCheck(group, session, is_poke_event)
if entity.group_id:
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
try:
await asyncio.wait_for(
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
)
except asyncio.TimeoutError:
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
tasks = []
if group:
tasks.append(GroupCheck(plugin, group, session, is_poke_event).check())
else:
try:
await asyncio.wait_for(
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("用户检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
tasks.append(user_check.check_user(plugin))
tasks.append(user_check.check_global(plugin))
try:
await asyncio.wait_for(
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2
)
except asyncio.TimeoutError:
logger.error("全局检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
logger.error("插件用户/群组/全局检查超时...", LOGGER_COMMAND)
finally:
# 记录总执行时间
elapsed = time.time() - start_time

View File

@ -85,7 +85,7 @@ class FreqUtils:
return False
if plugin.plugin_type == PluginType.DEPENDANT:
return False
return plugin.module != "ai" if self._flmt_s.check(sid) else False
return False if plugin.ignore_prompt else self._flmt_s.check(sid)
freq = FreqUtils()

View File

@ -8,6 +8,7 @@ from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import IntegrityError
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.data_access import DataAccess
@ -31,6 +32,7 @@ from .auth.exception import (
PermissionExemption,
SkipPluginException,
)
from .auth.utils import base_config
# 超时设置(秒)
TIMEOUT_SECONDS = 5.0
@ -46,6 +48,16 @@ CIRCUIT_BREAKERS = {
# 熔断重置时间(秒)
CIRCUIT_RESET_TIME = 300 # 5分钟
# 并发控制:限制同时进入 hooks 并行检查的协程数
# 默认为 6可通过环境变量 AUTH_HOOKS_CONCURRENCY_LIMIT 调整
HOOKS_CONCURRENCY_LIMIT = base_config.get("AUTH_HOOKS_CONCURRENCY_LIMIT")
# 全局信号量与计数器
HOOKS_SEMAPHORE = asyncio.Semaphore(HOOKS_CONCURRENCY_LIMIT)
HOOKS_ACTIVE_COUNT = 0
HOOKS_ACTIVE_LOCK = asyncio.Lock()
# 超时装饰器
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
@ -259,6 +271,30 @@ async def time_hook(coro, name, time_dict):
time_dict[name] = f"{time.time() - start:.3f}s"
async def _enter_hooks_section():
"""尝试获取全局信号量并更新计数器,超时则抛出 PermissionExemption。"""
global HOOKS_ACTIVE_COUNT
# 队列模式:如果达到上限,协程将排队等待直到获取到信号量
await HOOKS_SEMAPHORE.acquire()
async with HOOKS_ACTIVE_LOCK:
HOOKS_ACTIVE_COUNT += 1
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
async def _leave_hooks_section():
"""释放信号量并更新计数器。"""
global HOOKS_ACTIVE_COUNT
from contextlib import suppress
with suppress(Exception):
HOOKS_SEMAPHORE.release()
async with HOOKS_ACTIVE_LOCK:
HOOKS_ACTIVE_COUNT -= 1
# 保证计数不为负
HOOKS_ACTIVE_COUNT = max(HOOKS_ACTIVE_COUNT, 0)
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
async def auth(
matcher: Matcher,
event: Event,
@ -285,6 +321,9 @@ async def auth(
hook_times = {}
hooks_time = 0 # 初始化 hooks_time 变量
# 记录是否已进入 hooks 区域(用于 finally 中释放)
entered_hooks = False
try:
if not module:
raise PermissionExemption("Matcher插件名称不存在...")
@ -304,6 +343,10 @@ async def auth(
)
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
# 进入 hooks 并行检查区域(会在高并发时排队)
await _enter_hooks_section()
entered_hooks = True
# 获取插件费用
cost_start = time.time()
try:
@ -320,16 +363,32 @@ async def auth(
# 执行 bot_filter
bot_filter(session)
group = None
if entity.group_id:
group_dao = DataAccess(GroupConsole)
group = await with_timeout(
group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
),
name="get_group",
)
# 并行执行所有 hook 检查,并记录执行时间
hooks_start = time.time()
# 创建所有 hook 任务
hook_tasks = [
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
time_hook(auth_ban(matcher, bot, session, plugin), "auth_ban", hook_times),
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
time_hook(
auth_group(plugin, group, message, entity.group_id),
"auth_group",
hook_times,
),
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
time_hook(
auth_plugin(plugin, group, session, event), "auth_plugin", hook_times
),
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
]
@ -358,7 +417,17 @@ async def auth(
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
except PermissionExemption as e:
logger.info(str(e), LOGGER_COMMAND, session=session)
finally:
# 如果进入过 hooks 区域,确保释放信号量(即使上层处理抛出了异常)
if entered_hooks:
try:
await _leave_hooks_section()
except Exception:
logger.error(
"释放 hooks 信号量时出错",
LOGGER_COMMAND,
session=session,
)
# 扣除金币
if not ignore_flag and cost_gold > 0:
gold_start = time.time()