mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 06:12:53 +08:00
修复并发时数据库超时 (#2063)
* 🔧 修复和优化:调整超时设置,重构检查逻辑,简化代码结构 - 在 `chkdsk_hook.py` 中重构 `check` 方法,提取公共逻辑 - 更新 `CacheManager` 中的超时设置,使用新的 `CACHE_TIMEOUT` - 在 `utils.py` 中添加缓存逻辑,记录数据库操作的执行情况 * ✨ feat(auth): 添加并发控制,优化权限检查逻辑 * Update utils.py * 🚨 auto fix by pre-commit hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f94121080f
commit
e7f3c210df
@ -58,5 +58,14 @@ Config.add_plugin_config(
|
|||||||
type=bool,
|
type=bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Config.add_plugin_config(
|
||||||
|
"hook",
|
||||||
|
"AUTH_HOOKS_CONCURRENCY_LIMIT",
|
||||||
|
5,
|
||||||
|
help="同步进入权限钩子最大并发数",
|
||||||
|
default_value=5,
|
||||||
|
type=int,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||||
|
|||||||
@ -96,7 +96,6 @@ async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
|||||||
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
|
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
|
||||||
LOGGER_COMMAND,
|
LOGGER_COMMAND,
|
||||||
)
|
)
|
||||||
# 超时时返回0,避免阻塞
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 检查记录并计算ban时间
|
# 检查记录并计算ban时间
|
||||||
@ -199,7 +198,7 @@ async def group_handle(group_id: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
async def user_handle(plugin: PluginInfo, entity: EntityIDs, session: Uninfo) -> None:
|
||||||
"""用户ban检查
|
"""用户ban检查
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@ -217,22 +216,12 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
|||||||
if not time_val:
|
if not time_val:
|
||||||
return
|
return
|
||||||
time_str = format_time(time_val)
|
time_str = format_time(time_val)
|
||||||
plugin_dao = DataAccess(PluginInfo)
|
|
||||||
try:
|
|
||||||
db_plugin = await asyncio.wait_for(
|
|
||||||
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
|
|
||||||
# 超时时不阻塞,继续执行
|
|
||||||
raise SkipPluginException("用户处于黑名单中...")
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
db_plugin
|
plugin
|
||||||
and not db_plugin.ignore_prompt
|
|
||||||
and time_val != -1
|
and time_val != -1
|
||||||
and ban_result
|
and ban_result
|
||||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
and freq.is_send_limit_message(plugin, entity.user_id, False)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@ -260,7 +249,9 @@ async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
async def auth_ban(
|
||||||
|
matcher: Matcher, bot: Bot, session: Uninfo, plugin: PluginInfo
|
||||||
|
) -> None:
|
||||||
"""权限检查 - ban 检查
|
"""权限检查 - ban 检查
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@ -289,7 +280,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
|||||||
if entity.user_id:
|
if entity.user_id:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
user_handle(matcher.plugin_name, entity, session),
|
user_handle(plugin, entity, session),
|
||||||
timeout=DB_TIMEOUT_SECONDS,
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
@ -1,50 +1,36 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from nonebot_plugin_alconna import UniMsg
|
from nonebot_plugin_alconna import UniMsg
|
||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.data_access import DataAccess
|
|
||||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.utils import EntityIDs
|
|
||||||
|
|
||||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
async def auth_group(
|
||||||
|
plugin: PluginInfo,
|
||||||
|
group: GroupConsole | None,
|
||||||
|
message: UniMsg,
|
||||||
|
group_id: str | None,
|
||||||
|
):
|
||||||
"""群黑名单检测 群总开关检测
|
"""群黑名单检测 群总开关检测
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
entity: EntityIDs
|
group: GroupConsole
|
||||||
message: UniMsg
|
message: UniMsg
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
if not group_id:
|
||||||
|
|
||||||
if not entity.group_id:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
text = message.extract_plain_text()
|
text = message.extract_plain_text()
|
||||||
|
|
||||||
# 从数据库或缓存中获取群组信息
|
|
||||||
group_dao = DataAccess(GroupConsole)
|
|
||||||
|
|
||||||
try:
|
|
||||||
group: GroupConsole | None = await asyncio.wait_for(
|
|
||||||
group_dao.safe_get_or_none(
|
|
||||||
group_id=entity.group_id, channel_id__isnull=True
|
|
||||||
),
|
|
||||||
timeout=DB_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
|
|
||||||
# 超时时不阻塞,继续执行
|
|
||||||
return
|
|
||||||
|
|
||||||
if not group:
|
if not group:
|
||||||
raise SkipPluginException("群组信息不存在...")
|
raise SkipPluginException("群组信息不存在...")
|
||||||
if group.level < 0:
|
if group.level < 0:
|
||||||
@ -63,6 +49,5 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||||
LOGGER_COMMAND,
|
LOGGER_COMMAND,
|
||||||
session=entity.user_id,
|
group_id=group_id,
|
||||||
group_id=entity.group_id,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,12 +6,10 @@ from nonebot_plugin_uninfo import Uninfo
|
|||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.data_access import DataAccess
|
|
||||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.common_utils import CommonUtils
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import BlockType
|
from zhenxun.utils.enum import BlockType
|
||||||
from zhenxun.utils.utils import get_entity_ids
|
|
||||||
|
|
||||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
from .exception import IsSuperuserException, SkipPluginException
|
from .exception import IsSuperuserException, SkipPluginException
|
||||||
@ -20,30 +18,17 @@ from .utils import freq, is_poke, send_message
|
|||||||
|
|
||||||
class GroupCheck:
|
class GroupCheck:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
self, plugin: PluginInfo, group: GroupConsole, session: Uninfo, is_poke: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
self.group_id = group_id
|
|
||||||
self.session = session
|
self.session = session
|
||||||
self.is_poke = is_poke
|
self.is_poke = is_poke
|
||||||
self.plugin = plugin
|
self.plugin = plugin
|
||||||
self.group_dao = DataAccess(GroupConsole)
|
self.group_data = group
|
||||||
self.group_data = None
|
self.group_id = group.group_id
|
||||||
|
|
||||||
async def check(self):
|
async def check(self):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
# 只查询一次数据库,使用 DataAccess 的缓存机制
|
|
||||||
try:
|
|
||||||
self.group_data = await asyncio.wait_for(
|
|
||||||
self.group_dao.safe_get_or_none(
|
|
||||||
group_id=self.group_id, channel_id__isnull=True
|
|
||||||
),
|
|
||||||
timeout=DB_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
|
||||||
return # 超时时不阻塞,继续执行
|
|
||||||
|
|
||||||
# 检查超级用户禁用
|
# 检查超级用户禁用
|
||||||
if (
|
if (
|
||||||
self.group_data
|
self.group_data
|
||||||
@ -113,12 +98,13 @@ class GroupCheck:
|
|||||||
|
|
||||||
|
|
||||||
class PluginCheck:
|
class PluginCheck:
|
||||||
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
def __init__(self, group: GroupConsole | None, session: Uninfo, is_poke: bool):
|
||||||
self.session = session
|
self.session = session
|
||||||
self.is_poke = is_poke
|
self.is_poke = is_poke
|
||||||
self.group_id = group_id
|
self.group_data = group
|
||||||
self.group_dao = DataAccess(GroupConsole)
|
self.group_id = None
|
||||||
self.group_data = None
|
if group:
|
||||||
|
self.group_id = group.group_id
|
||||||
|
|
||||||
async def check_user(self, plugin: PluginInfo):
|
async def check_user(self, plugin: PluginInfo):
|
||||||
"""全局私聊禁用检测
|
"""全局私聊禁用检测
|
||||||
@ -156,21 +142,8 @@ class PluginCheck:
|
|||||||
if plugin.status or plugin.block_type != BlockType.ALL:
|
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||||
return
|
return
|
||||||
"""全局状态"""
|
"""全局状态"""
|
||||||
if self.group_id:
|
if self.group_data and self.group_data.is_super:
|
||||||
# 使用 DataAccess 的缓存机制
|
raise IsSuperuserException()
|
||||||
try:
|
|
||||||
self.group_data = await asyncio.wait_for(
|
|
||||||
self.group_dao.safe_get_or_none(
|
|
||||||
group_id=self.group_id, channel_id__isnull=True
|
|
||||||
),
|
|
||||||
timeout=DB_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
|
||||||
return # 超时时不阻塞,继续执行
|
|
||||||
|
|
||||||
if self.group_data and self.group_data.is_super:
|
|
||||||
raise IsSuperuserException()
|
|
||||||
|
|
||||||
sid = self.group_id or self.session.user.id
|
sid = self.group_id or self.session.user.id
|
||||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||||
@ -193,7 +166,9 @@ class PluginCheck:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
async def auth_plugin(
|
||||||
|
plugin: PluginInfo, group: GroupConsole | None, session: Uninfo, event: Event
|
||||||
|
):
|
||||||
"""插件状态
|
"""插件状态
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@ -203,35 +178,23 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
|||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
entity = get_entity_ids(session)
|
|
||||||
is_poke_event = is_poke(event)
|
is_poke_event = is_poke(event)
|
||||||
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
user_check = PluginCheck(group, session, is_poke_event)
|
||||||
|
|
||||||
if entity.group_id:
|
tasks = []
|
||||||
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
if group:
|
||||||
try:
|
tasks.append(GroupCheck(plugin, group, session, is_poke_event).check())
|
||||||
await asyncio.wait_for(
|
|
||||||
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
|
|
||||||
# 超时时不阻塞,继续执行
|
|
||||||
else:
|
else:
|
||||||
try:
|
tasks.append(user_check.check_user(plugin))
|
||||||
await asyncio.wait_for(
|
tasks.append(user_check.check_global(plugin))
|
||||||
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error("用户检查超时", LOGGER_COMMAND)
|
|
||||||
# 超时时不阻塞,继续执行
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
|
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS * 2
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error("全局检查超时", LOGGER_COMMAND)
|
logger.error("插件用户/群组/全局检查超时...", LOGGER_COMMAND)
|
||||||
# 超时时不阻塞,继续执行
|
|
||||||
finally:
|
finally:
|
||||||
# 记录总执行时间
|
# 记录总执行时间
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|||||||
@ -85,7 +85,7 @@ class FreqUtils:
|
|||||||
return False
|
return False
|
||||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||||
return False
|
return False
|
||||||
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
return False if plugin.ignore_prompt else self._flmt_s.check(sid)
|
||||||
|
|
||||||
|
|
||||||
freq = FreqUtils()
|
freq = FreqUtils()
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from nonebot_plugin_alconna import UniMsg
|
|||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
from tortoise.exceptions import IntegrityError
|
from tortoise.exceptions import IntegrityError
|
||||||
|
|
||||||
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.user_console import UserConsole
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.data_access import DataAccess
|
from zhenxun.services.data_access import DataAccess
|
||||||
@ -31,6 +32,7 @@ from .auth.exception import (
|
|||||||
PermissionExemption,
|
PermissionExemption,
|
||||||
SkipPluginException,
|
SkipPluginException,
|
||||||
)
|
)
|
||||||
|
from .auth.utils import base_config
|
||||||
|
|
||||||
# 超时设置(秒)
|
# 超时设置(秒)
|
||||||
TIMEOUT_SECONDS = 5.0
|
TIMEOUT_SECONDS = 5.0
|
||||||
@ -46,6 +48,16 @@ CIRCUIT_BREAKERS = {
|
|||||||
# 熔断重置时间(秒)
|
# 熔断重置时间(秒)
|
||||||
CIRCUIT_RESET_TIME = 300 # 5分钟
|
CIRCUIT_RESET_TIME = 300 # 5分钟
|
||||||
|
|
||||||
|
# 并发控制:限制同时进入 hooks 并行检查的协程数
|
||||||
|
|
||||||
|
# 默认为 6,可通过环境变量 AUTH_HOOKS_CONCURRENCY_LIMIT 调整
|
||||||
|
HOOKS_CONCURRENCY_LIMIT = base_config.get("AUTH_HOOKS_CONCURRENCY_LIMIT")
|
||||||
|
|
||||||
|
# 全局信号量与计数器
|
||||||
|
HOOKS_SEMAPHORE = asyncio.Semaphore(HOOKS_CONCURRENCY_LIMIT)
|
||||||
|
HOOKS_ACTIVE_COUNT = 0
|
||||||
|
HOOKS_ACTIVE_LOCK = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
# 超时装饰器
|
# 超时装饰器
|
||||||
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
|
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
|
||||||
@ -259,6 +271,30 @@ async def time_hook(coro, name, time_dict):
|
|||||||
time_dict[name] = f"{time.time() - start:.3f}s"
|
time_dict[name] = f"{time.time() - start:.3f}s"
|
||||||
|
|
||||||
|
|
||||||
|
async def _enter_hooks_section():
|
||||||
|
"""尝试获取全局信号量并更新计数器,超时则抛出 PermissionExemption。"""
|
||||||
|
global HOOKS_ACTIVE_COUNT
|
||||||
|
# 队列模式:如果达到上限,协程将排队等待直到获取到信号量
|
||||||
|
await HOOKS_SEMAPHORE.acquire()
|
||||||
|
async with HOOKS_ACTIVE_LOCK:
|
||||||
|
HOOKS_ACTIVE_COUNT += 1
|
||||||
|
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
|
||||||
|
|
||||||
|
|
||||||
|
async def _leave_hooks_section():
|
||||||
|
"""释放信号量并更新计数器。"""
|
||||||
|
global HOOKS_ACTIVE_COUNT
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
with suppress(Exception):
|
||||||
|
HOOKS_SEMAPHORE.release()
|
||||||
|
async with HOOKS_ACTIVE_LOCK:
|
||||||
|
HOOKS_ACTIVE_COUNT -= 1
|
||||||
|
# 保证计数不为负
|
||||||
|
HOOKS_ACTIVE_COUNT = max(HOOKS_ACTIVE_COUNT, 0)
|
||||||
|
logger.debug(f"当前并发权限检查数量: {HOOKS_ACTIVE_COUNT}", LOGGER_COMMAND)
|
||||||
|
|
||||||
|
|
||||||
async def auth(
|
async def auth(
|
||||||
matcher: Matcher,
|
matcher: Matcher,
|
||||||
event: Event,
|
event: Event,
|
||||||
@ -285,6 +321,9 @@ async def auth(
|
|||||||
hook_times = {}
|
hook_times = {}
|
||||||
hooks_time = 0 # 初始化 hooks_time 变量
|
hooks_time = 0 # 初始化 hooks_time 变量
|
||||||
|
|
||||||
|
# 记录是否已进入 hooks 区域(用于 finally 中释放)
|
||||||
|
entered_hooks = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not module:
|
if not module:
|
||||||
raise PermissionExemption("Matcher插件名称不存在...")
|
raise PermissionExemption("Matcher插件名称不存在...")
|
||||||
@ -304,6 +343,10 @@ async def auth(
|
|||||||
)
|
)
|
||||||
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
|
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
|
||||||
|
|
||||||
|
# 进入 hooks 并行检查区域(会在高并发时排队)
|
||||||
|
await _enter_hooks_section()
|
||||||
|
entered_hooks = True
|
||||||
|
|
||||||
# 获取插件费用
|
# 获取插件费用
|
||||||
cost_start = time.time()
|
cost_start = time.time()
|
||||||
try:
|
try:
|
||||||
@ -320,16 +363,32 @@ async def auth(
|
|||||||
# 执行 bot_filter
|
# 执行 bot_filter
|
||||||
bot_filter(session)
|
bot_filter(session)
|
||||||
|
|
||||||
|
group = None
|
||||||
|
if entity.group_id:
|
||||||
|
group_dao = DataAccess(GroupConsole)
|
||||||
|
group = await with_timeout(
|
||||||
|
group_dao.safe_get_or_none(
|
||||||
|
group_id=entity.group_id, channel_id__isnull=True
|
||||||
|
),
|
||||||
|
name="get_group",
|
||||||
|
)
|
||||||
|
|
||||||
# 并行执行所有 hook 检查,并记录执行时间
|
# 并行执行所有 hook 检查,并记录执行时间
|
||||||
hooks_start = time.time()
|
hooks_start = time.time()
|
||||||
|
|
||||||
# 创建所有 hook 任务
|
# 创建所有 hook 任务
|
||||||
hook_tasks = [
|
hook_tasks = [
|
||||||
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
|
time_hook(auth_ban(matcher, bot, session, plugin), "auth_ban", hook_times),
|
||||||
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
|
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
|
||||||
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
|
time_hook(
|
||||||
|
auth_group(plugin, group, message, entity.group_id),
|
||||||
|
"auth_group",
|
||||||
|
hook_times,
|
||||||
|
),
|
||||||
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
|
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
|
||||||
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
|
time_hook(
|
||||||
|
auth_plugin(plugin, group, session, event), "auth_plugin", hook_times
|
||||||
|
),
|
||||||
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
|
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -358,7 +417,17 @@ async def auth(
|
|||||||
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||||
except PermissionExemption as e:
|
except PermissionExemption as e:
|
||||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||||
|
finally:
|
||||||
|
# 如果进入过 hooks 区域,确保释放信号量(即使上层处理抛出了异常)
|
||||||
|
if entered_hooks:
|
||||||
|
try:
|
||||||
|
await _leave_hooks_section()
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
"释放 hooks 信号量时出错",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
# 扣除金币
|
# 扣除金币
|
||||||
if not ignore_flag and cost_gold > 0:
|
if not ignore_flag and cost_gold > 0:
|
||||||
gold_start = time.time()
|
gold_start = time.time()
|
||||||
|
|||||||
@ -43,18 +43,20 @@ class BanCheckLimiter:
|
|||||||
|
|
||||||
def check(self, key: str | float) -> bool:
|
def check(self, key: str | float) -> bool:
|
||||||
if time.time() - self.mtime[key] > self.default_check_time:
|
if time.time() - self.mtime[key] > self.default_check_time:
|
||||||
self.mtime[key] = time.time()
|
return self._extracted_from_check_3(key, False)
|
||||||
self.mint[key] = 0
|
|
||||||
return False
|
|
||||||
if (
|
if (
|
||||||
self.mint[key] >= self.default_count
|
self.mint[key] >= self.default_count
|
||||||
and time.time() - self.mtime[key] < self.default_check_time
|
and time.time() - self.mtime[key] < self.default_check_time
|
||||||
):
|
):
|
||||||
self.mtime[key] = time.time()
|
return self._extracted_from_check_3(key, True)
|
||||||
self.mint[key] = 0
|
|
||||||
return True
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# TODO Rename this here and in `check`
|
||||||
|
def _extracted_from_check_3(self, key, arg1):
|
||||||
|
self.mtime[key] = time.time()
|
||||||
|
self.mint[key] = 0
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
_blmt = BanCheckLimiter(
|
_blmt = BanCheckLimiter(
|
||||||
malicious_check_time,
|
malicious_check_time,
|
||||||
@ -70,16 +72,15 @@ async def _(
|
|||||||
module = None
|
module = None
|
||||||
if plugin := matcher.plugin:
|
if plugin := matcher.plugin:
|
||||||
module = plugin.module_name
|
module = plugin.module_name
|
||||||
if metadata := plugin.metadata:
|
if not (metadata := plugin.metadata):
|
||||||
extra = metadata.extra
|
return
|
||||||
if extra.get("plugin_type") in [
|
extra = metadata.extra
|
||||||
PluginType.HIDDEN,
|
if extra.get("plugin_type") in [
|
||||||
PluginType.DEPENDANT,
|
PluginType.HIDDEN,
|
||||||
PluginType.ADMIN,
|
PluginType.DEPENDANT,
|
||||||
PluginType.SUPERUSER,
|
PluginType.ADMIN,
|
||||||
]:
|
PluginType.SUPERUSER,
|
||||||
return
|
]:
|
||||||
else:
|
|
||||||
return
|
return
|
||||||
if matcher.type == "notice":
|
if matcher.type == "notice":
|
||||||
return
|
return
|
||||||
@ -88,32 +89,31 @@ async def _(
|
|||||||
malicious_ban_time = Config.get_config("hook", "MALICIOUS_BAN_TIME")
|
malicious_ban_time = Config.get_config("hook", "MALICIOUS_BAN_TIME")
|
||||||
if not malicious_ban_time:
|
if not malicious_ban_time:
|
||||||
raise ValueError("模块: [hook], 配置项: [MALICIOUS_BAN_TIME] 为空或小于0")
|
raise ValueError("模块: [hook], 配置项: [MALICIOUS_BAN_TIME] 为空或小于0")
|
||||||
if user_id:
|
if user_id and module:
|
||||||
if module:
|
if _blmt.check(f"{user_id}__{module}"):
|
||||||
if _blmt.check(f"{user_id}__{module}"):
|
await BanConsole.ban(
|
||||||
await BanConsole.ban(
|
user_id,
|
||||||
user_id,
|
group_id,
|
||||||
group_id,
|
9,
|
||||||
9,
|
"恶意触发命令检测",
|
||||||
"恶意触发命令检测",
|
malicious_ban_time * 60,
|
||||||
malicious_ban_time * 60,
|
bot.self_id,
|
||||||
bot.self_id,
|
)
|
||||||
)
|
logger.info(
|
||||||
logger.info(
|
f"触发了恶意触发检测: {matcher.plugin_name}",
|
||||||
f"触发了恶意触发检测: {matcher.plugin_name}",
|
"HOOK",
|
||||||
"HOOK",
|
session=session,
|
||||||
session=session,
|
)
|
||||||
)
|
await MessageUtils.build_message(
|
||||||
await MessageUtils.build_message(
|
[
|
||||||
[
|
At(flag="user", target=user_id),
|
||||||
At(flag="user", target=user_id),
|
"检测到恶意触发命令,您将被封禁 30 分钟",
|
||||||
"检测到恶意触发命令,您将被封禁 30 分钟",
|
]
|
||||||
]
|
).send()
|
||||||
).send()
|
logger.debug(
|
||||||
logger.debug(
|
f"触发了恶意触发检测: {matcher.plugin_name}",
|
||||||
f"触发了恶意触发检测: {matcher.plugin_name}",
|
"HOOK",
|
||||||
"HOOK",
|
session=session,
|
||||||
session=session,
|
)
|
||||||
)
|
raise IgnoredException("检测到恶意触发命令")
|
||||||
raise IgnoredException("检测到恶意触发命令")
|
_blmt.add(f"{user_id}__{module}")
|
||||||
_blmt.add(f"{user_id}__{module}")
|
|
||||||
|
|||||||
4
zhenxun/services/cache/__init__.py
vendored
4
zhenxun/services/cache/__init__.py
vendored
@ -98,6 +98,7 @@ from .cache_containers import CacheDict, CacheList
|
|||||||
from .config import (
|
from .config import (
|
||||||
CACHE_KEY_PREFIX,
|
CACHE_KEY_PREFIX,
|
||||||
CACHE_KEY_SEPARATOR,
|
CACHE_KEY_SEPARATOR,
|
||||||
|
CACHE_TIMEOUT,
|
||||||
DEFAULT_EXPIRE,
|
DEFAULT_EXPIRE,
|
||||||
LOG_COMMAND,
|
LOG_COMMAND,
|
||||||
SPECIAL_KEY_FORMATS,
|
SPECIAL_KEY_FORMATS,
|
||||||
@ -551,7 +552,6 @@ class CacheManager:
|
|||||||
返回:
|
返回:
|
||||||
Any: 缓存数据,如果不存在返回默认值
|
Any: 缓存数据,如果不存在返回默认值
|
||||||
"""
|
"""
|
||||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
|
||||||
|
|
||||||
# 如果缓存被禁用或缓存模式为NONE,直接返回默认值
|
# 如果缓存被禁用或缓存模式为NONE,直接返回默认值
|
||||||
if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
|
if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
|
||||||
@ -561,7 +561,7 @@ class CacheManager:
|
|||||||
cache_key = self._build_key(cache_type, key)
|
cache_key = self._build_key(cache_type, key)
|
||||||
data = await asyncio.wait_for(
|
data = await asyncio.wait_for(
|
||||||
self.cache_backend.get(cache_key), # type: ignore
|
self.cache_backend.get(cache_key), # type: ignore
|
||||||
timeout=DB_TIMEOUT_SECONDS,
|
timeout=CACHE_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
|
|||||||
3
zhenxun/services/cache/config.py
vendored
3
zhenxun/services/cache/config.py
vendored
@ -5,6 +5,9 @@
|
|||||||
# 日志标识
|
# 日志标识
|
||||||
LOG_COMMAND = "CacheRoot"
|
LOG_COMMAND = "CacheRoot"
|
||||||
|
|
||||||
|
# 缓存获取超时时间(秒)
|
||||||
|
CACHE_TIMEOUT = 10
|
||||||
|
|
||||||
# 默认缓存过期时间(秒)
|
# 默认缓存过期时间(秒)
|
||||||
DEFAULT_EXPIRE = 600
|
DEFAULT_EXPIRE = 600
|
||||||
|
|
||||||
|
|||||||
@ -27,5 +27,8 @@ async def with_db_timeout(
|
|||||||
return result
|
return result
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if operation:
|
if operation:
|
||||||
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
logger.error(
|
||||||
|
f"数据库操作超时: {operation} (>{timeout}s) 来源: {source}",
|
||||||
|
LOG_COMMAND,
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user