超级无敌大优化,解决延迟与卡死问题

This commit is contained in:
HibiKier 2025-07-14 00:42:14 +08:00
parent 419c8934a0
commit defe99e66c
21 changed files with 1493 additions and 525 deletions

View File

@ -1,11 +1,17 @@
import asyncio
import time
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import send_message
@ -17,42 +23,77 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
plugin: PluginInfo
session: Uninfo
"""
start_time = time.time()
if not plugin.admin_level:
return
entity = get_entity_ids(session)
level_dao = DataAccess(LevelUser)
global_user = await level_dao.safe_get_or_none(
user_id=session.user.id, group_id__isnull=True
)
user_level = 0
if global_user:
user_level = global_user.user_level
if entity.group_id:
# 获取用户在当前群组的权限数据
group_users = await level_dao.safe_get_or_none(
user_id=session.user.id, group_id=entity.group_id
try:
entity = get_entity_ids(session)
level_dao = DataAccess(LevelUser)
# 并行查询用户权限数据
global_user: LevelUser | None = None
group_users: LevelUser | None = None
# 查询全局权限
global_user_task = level_dao.safe_get_or_none(
user_id=session.user.id, group_id__isnull=True
)
if group_users:
# 如果在群组中,查询群组权限
group_users_task = None
if entity.group_id:
group_users_task = level_dao.safe_get_or_none(
user_id=session.user.id, group_id=entity.group_id
)
# 等待查询完成,添加超时控制
try:
results = await asyncio.wait_for(
asyncio.gather(global_user_task, group_users_task or asyncio.sleep(0)),
timeout=DB_TIMEOUT_SECONDS,
)
global_user = results[0]
group_users = results[1] if group_users_task else None
except asyncio.TimeoutError:
logger.error(f"查询用户权限超时: user_id={session.user.id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
return
user_level = global_user.user_level if global_user else 0
if entity.group_id and group_users:
user_level = max(user_level, group_users.user_level)
if user_level < plugin.admin_level:
await send_message(
session,
[
At(flag="user", target=session.user.id),
if user_level < plugin.admin_level:
await send_message(
session,
[
At(flag="user", target=session.user.id),
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
],
entity.user_id,
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
elif global_user:
if global_user.user_level < plugin.admin_level:
await send_message(
session,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
],
entity.user_id,
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
elif global_user:
if global_user.user_level < plugin.admin_level:
await send_message(
session,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_admin 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=session,
)

View File

@ -1,3 +1,6 @@
import asyncio
import time
from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
@ -7,9 +10,12 @@ from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import freq, send_message
@ -21,32 +27,105 @@ Config.add_plugin_config(
)
def calculate_ban_time(ban_record: BanConsole | None) -> int:
"""根据ban记录计算剩余ban时间
参数:
ban_record: BanConsole记录
返回:
int: ban剩余时长-1时为永久ban0表示未被ban
"""
if not ban_record:
return 0
if ban_record.duration == -1:
return -1
_time = time.time() - (ban_record.ban_time + ban_record.duration)
return 0 if _time > 0 else int(abs(_time))
async def is_ban(user_id: str | None, group_id: str | None) -> int:
"""检查用户或群组是否被ban
参数:
user_id: 用户ID
group_id: 群组ID
返回:
int: ban的剩余时间0表示未被ban
"""
if not user_id and not group_id:
return 0
start_time = time.time()
ban_dao = DataAccess(BanConsole)
# 分别获取用户在群组中的ban记录和全局ban记录
group_user = None
user = None
if user_id and group_id:
group_user = await ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id)
try:
# 并行查询用户和群组的 ban 记录
tasks = []
if user_id and group_id:
tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id))
if user_id:
tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=""))
if user_id:
user = await ban_dao.safe_get_or_none(user_id=user_id, group_id="")
# 等待所有查询完成,添加超时控制
if tasks:
try:
ban_records = await asyncio.wait_for(
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS
)
if len(tasks) == 2:
group_user, user = ban_records
elif user_id and group_id:
group_user = ban_records[0]
else:
user = ban_records[0]
except asyncio.TimeoutError:
logger.error(
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
LOGGER_COMMAND,
)
# 超时时返回0避免阻塞
return 0
results = []
if group_user:
results.append(group_user)
if user:
results.append(user)
if not results:
return 0
for result in results:
if result.duration > 0 or result.duration == -1:
return await BanConsole.check_ban_time(user_id, group_id)
return 0
# 检查记录并计算ban时间
results = []
if group_user:
results.append(group_user)
if user:
results.append(user)
# 如果没有找到记录返回0
if not results:
return 0
logger.debug(f"查询到的ban记录: {results}", LOGGER_COMMAND)
# 检查所有记录找出最严格的ban时间最长的
max_ban_time: int = 0
for result in results:
if result.duration > 0 or result.duration == -1:
# 直接计算ban时间避免再次查询数据库
ban_time = calculate_ban_time(result)
if ban_time == -1 or ban_time > max_ban_time:
max_ban_time = ban_time
return max_ban_time
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"is_ban 耗时: {elapsed:.3f}s",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
def check_plugin_type(matcher: Matcher) -> bool:
@ -66,22 +145,22 @@ def check_plugin_type(matcher: Matcher) -> bool:
return True
def format_time(time: float) -> str:
def format_time(time_val: float) -> str:
"""格式化时间
参数:
time: ban时长
time_val: ban时长
返回:
str: 格式化时间文本
"""
if time == -1:
if time_val == -1:
return ""
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
time_val = abs(int(time_val))
if time_val < 60:
time_str = f"{time_val!s}"
else:
minute = int(time / 60)
minute = int(time_val / 60)
if minute > 60:
hours = minute // 60
minute %= 60
@ -91,66 +170,132 @@ def format_time(time: float) -> str:
return time_str
async def group_handle(group_id: str):
async def group_handle(group_id: str) -> None:
"""群组ban检查
参数:
ban_dao: BanConsole数据访问对象
group_id: 群组id
异常:
SkipPluginException: 群组处于黑名单
"""
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
start_time = time.time()
try:
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"group_handle 耗时: {elapsed:.3f}s",
LOGGER_COMMAND,
group_id=group_id,
)
async def user_handle(module: str, entity: EntityIDs, session: Uninfo):
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
"""用户ban检查
参数:
module: 插件模块名
ban_dao: BanConsole数据访问对象
entity: 实体ID信息
session: Uninfo
异常:
SkipPluginException: 用户处于黑名单
"""
ban_result = Config.get_config("hook", "BAN_RESULT")
time = await is_ban(entity.user_id, entity.group_id)
if not time:
return
time_str = format_time(time)
plugin_dao = DataAccess(PluginInfo)
db_plugin = await plugin_dao.safe_get_or_none(module=module)
if (
db_plugin
and not db_plugin.ignore_prompt
and time != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
):
await send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
)
raise SkipPluginException("用户处于黑名单中...")
start_time = time.time()
try:
ban_result = Config.get_config("hook", "BAN_RESULT")
time_val = await is_ban(entity.user_id, entity.group_id)
if not time_val:
return
time_str = format_time(time_val)
plugin_dao = DataAccess(PluginInfo)
try:
db_plugin = await asyncio.wait_for(
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
raise SkipPluginException("用户处于黑名单中...")
if (
db_plugin
and not db_plugin.ignore_prompt
and time_val != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
):
try:
await asyncio.wait_for(
send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {entity.user_id}", LOGGER_COMMAND)
raise SkipPluginException("用户处于黑名单中...")
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"user_handle 耗时: {elapsed:.3f}s",
LOGGER_COMMAND,
session=session,
)
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
if not check_plugin_type(matcher):
return
if not matcher.plugin_name:
return
entity = get_entity_ids(session)
if entity.user_id in bot.config.superusers:
return
if entity.group_id:
await group_handle(entity.group_id)
if entity.user_id:
await user_handle(matcher.plugin_name, entity, session)
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
"""权限检查 - ban 检查
参数:
matcher: Matcher
bot: Bot
session: Uninfo
"""
start_time = time.time()
try:
if not check_plugin_type(matcher):
return
if not matcher.plugin_name:
return
entity = get_entity_ids(session)
if entity.user_id in bot.config.superusers:
return
if entity.group_id:
try:
await asyncio.wait_for(
group_handle(entity.group_id), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"群组ban检查超时: {entity.group_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
if entity.user_id:
try:
await asyncio.wait_for(
user_handle(matcher.plugin_name, entity, session),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"用户ban检查超时: {entity.user_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
finally:
# 记录总执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_ban 总耗时: {elapsed:.3f}s, plugin={matcher.plugin_name}",
LOGGER_COMMAND,
session=session,
)

View File

@ -1,8 +1,14 @@
import asyncio
import time
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
@ -17,11 +23,33 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
SkipPluginException: 忽略插件
SkipPluginException: 忽略插件
"""
bot_dao = DataAccess(BotConsole)
bot = await bot_dao.safe_get_or_none(bot_id=bot_id)
if not bot or not bot.status:
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
if CommonUtils.format(plugin.module) in bot.block_plugins:
raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
)
start_time = time.time()
try:
# 从数据库或缓存中获取 bot 信息
bot_dao = DataAccess(BotConsole)
try:
bot: BotConsole | None = await asyncio.wait_for(
bot_dao.safe_get_or_none(bot_id=bot_id), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"查询Bot信息超时: bot_id={bot_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
return
if not bot or not bot.status:
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
if CommonUtils.format(plugin.module) in bot.block_plugins:
raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_bot 耗时: {elapsed:.3f}s, "
f"bot_id={bot_id}, plugin={plugin.module}",
LOGGER_COMMAND,
)

View File

@ -1,8 +1,12 @@
import time
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import send_message
@ -18,8 +22,20 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
返回:
int: 需要消耗的金币
"""
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold
start_time = time.time()
try:
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_cost 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=session,
)

View File

@ -1,11 +1,16 @@
import asyncio
import time
from nonebot_plugin_alconna import UniMsg
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import EntityIDs
from .config import SwitchEnum
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
from .exception import SkipPluginException
@ -17,21 +22,47 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
entity: EntityIDs
message: UniMsg
"""
start_time = time.time()
if not entity.group_id:
return
text = message.extract_plain_text()
group_dao = DataAccess(GroupConsole)
group = await group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
)
if not group:
raise SkipPluginException("群组信息不存在...")
if group.level < 0:
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
if text.strip() != SwitchEnum.ENABLE and not group.status:
raise SkipPluginException("群组休眠状态...")
if plugin.level > group.level:
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 群等级限制,"
f"该功能需要的群等级: {plugin.level}..."
)
try:
text = message.extract_plain_text()
# 从数据库或缓存中获取群组信息
group_dao = DataAccess(GroupConsole)
try:
group: GroupConsole | None = await asyncio.wait_for(
group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
# 超时时不阻塞,继续执行
return
if not group:
raise SkipPluginException("群组信息不存在...")
if group.level < 0:
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
if text.strip() != SwitchEnum.ENABLE and not group.status:
raise SkipPluginException("群组休眠状态...")
if plugin.level > group.level:
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 群等级限制,"
f"该功能需要的群等级: {plugin.level}..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=entity.user_id,
group_id=entity.group_id,
)

View File

@ -1,3 +1,5 @@
import asyncio
import time
from typing import ClassVar
import nonebot
@ -6,6 +8,7 @@ from pydantic import BaseModel
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
@ -17,7 +20,7 @@ from zhenxun.utils.utils import (
get_entity_ids,
)
from .config import LOGGER_COMMAND
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
driver = nonebot.get_driver()
@ -39,17 +42,61 @@ class Limit(BaseModel):
class LimitManager:
add_module: ClassVar[list] = []
last_update_time: ClassVar[float] = 0
update_interval: ClassVar[float] = 6000 # 1小时更新一次
is_updating: ClassVar[bool] = False # 防止并发更新
cd_limit: ClassVar[dict[str, Limit]] = {}
block_limit: ClassVar[dict[str, Limit]] = {}
count_limit: ClassVar[dict[str, Limit]] = {}
# 模块限制缓存,避免频繁查询数据库
module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
module_cache_ttl: ClassVar[float] = 60 # 模块缓存有效期(秒)
@classmethod
async def init_limit(cls):
"""初始化限制"""
limit_list = await PluginLimit.filter(status=True).all()
for limit in limit_list:
cls.add_limit(limit)
cls.last_update_time = time.time()
try:
await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
except asyncio.TimeoutError:
logger.error("初始化限制超时", LOGGER_COMMAND)
@classmethod
async def update_limits(cls):
"""更新限制信息"""
# 防止并发更新
if cls.is_updating:
return
cls.is_updating = True
try:
start_time = time.time()
try:
limit_list = await asyncio.wait_for(
PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("查询限制信息超时", LOGGER_COMMAND)
cls.is_updating = False
return
# 清空旧数据
cls.add_module = []
cls.cd_limit = {}
cls.block_limit = {}
cls.count_limit = {}
# 添加新数据
for limit in limit_list:
cls.add_limit(limit)
cls.last_update_time = time.time()
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的更新
logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
finally:
cls.is_updating = False
@classmethod
def add_limit(cls, limit: PluginLimit):
@ -99,6 +146,46 @@ class LimitManager:
)
limiter.set_false(key_type)
@classmethod
async def get_module_limits(cls, module: str) -> list[PluginLimit]:
"""获取模块的限制信息,使用缓存减少数据库查询
参数:
module: 模块名
返回:
list[PluginLimit]: 限制列表
"""
current_time = time.time()
# 检查缓存
if module in cls.module_limit_cache:
cache_time, limits = cls.module_limit_cache[module]
if current_time - cache_time < cls.module_cache_ttl:
return limits
# 缓存不存在或已过期,从数据库查询
try:
start_time = time.time()
limits = await asyncio.wait_for(
PluginLimit.filter(module=module, status=True).all(),
timeout=DB_TIMEOUT_SECONDS,
)
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的查询
logger.warning(
f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
LOGGER_COMMAND,
)
# 更新缓存
cls.module_limit_cache[module] = (current_time, limits)
return limits
except asyncio.TimeoutError:
logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
# 超时时返回空列表,避免阻塞
return []
@classmethod
async def check(
cls,
@ -118,12 +205,40 @@ class LimitManager:
异常:
IgnoredException: IgnoredException
"""
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
start_time = time.time()
# 定期更新全局限制信息
if (
time.time() - cls.last_update_time > cls.update_interval
and not cls.is_updating
):
# 使用异步任务更新,避免阻塞当前请求
asyncio.create_task(cls.update_limits()) # noqa: RUF006
# 如果模块不在已加载列表中,只加载该模块的限制
if module not in cls.add_module:
limits = await cls.get_module_limits(module)
for limit in limits:
cls.add_limit(limit)
# 检查各种限制
try:
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
finally:
# 记录总执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
@classmethod
async def __check(
@ -158,7 +273,13 @@ class LimitManager:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
await MessageUtils.build_message(limit.result).send()
try:
await asyncio.wait_for(
MessageUtils.build_message(limit.result).send(),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
raise SkipPluginException(
f"{limit.module}({limit.limit_type}) 正在限制中..."
)
@ -185,11 +306,13 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
session: Uninfo
"""
entity = get_entity_ids(session)
if plugin.module not in LimitManager.add_module:
limit_list = await PluginLimit.filter(module=plugin.module, status=True).all()
for limit in limit_list:
LimitManager.add_limit(limit)
if entity.user_id:
await LimitManager.check(
plugin.module, entity.user_id, entity.group_id, entity.channel_id
try:
await asyncio.wait_for(
LimitManager.check(
plugin.module, entity.user_id, entity.group_id, entity.channel_id
),
timeout=DB_TIMEOUT_SECONDS * 2, # 给予更长的超时时间
)
except asyncio.TimeoutError:
logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
# 超时时不抛出异常,允许继续执行

View File

@ -1,13 +1,19 @@
import asyncio
import time
from nonebot.adapters import Event
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import IsSuperuserException, SkipPluginException
from .utils import freq, is_poke, send_message
@ -21,69 +27,89 @@ class GroupCheck:
self.is_poke = is_poke
self.plugin = plugin
self.group_dao = DataAccess(GroupConsole)
async def __get_data(self):
return await self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
)
self.group_data = None
async def check(self):
await self.check_superuser_block(self.plugin)
async def check_superuser_block(self, plugin: PluginInfo):
"""超级用户禁用群组插件检测
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
await send_message(
self.session, "超级管理员禁用了该群此功能...", self.group_id
start_time = time.time()
try:
# 只查询一次数据库,使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..."
)
await self.check_normal_block(self.plugin)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
async def check_normal_block(self, plugin: PluginInfo):
"""群组插件状态
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.block_plugin:
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
await send_message(self.session, "该群未开启此功能...", self.group_id)
raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...")
await self.check_global_block(self.plugin)
async def check_global_block(self, plugin: PluginInfo):
"""全局禁用插件检测
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
if plugin.block_type == BlockType.GROUP:
"""全局群组禁用"""
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
await send_message(
self.session, "该功能在群组中已被禁用...", self.group_id
# 检查超级用户禁用
if (
self.group_data
and CommonUtils.format(self.plugin.module)
in self.group_data.superuser_block_plugin
):
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
try:
await asyncio.wait_for(
send_message(
self.session,
"超级管理员禁用了该群此功能...",
self.group_id,
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
raise SkipPluginException(
f"{self.plugin.name}({self.plugin.module})"
f" 超级管理员禁用了该群此功能..."
)
# 检查普通禁用
if (
self.group_data
and CommonUtils.format(self.plugin.module)
in self.group_data.block_plugin
):
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
try:
await asyncio.wait_for(
send_message(
self.session, "该群未开启此功能...", self.group_id
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
raise SkipPluginException(
f"{self.plugin.name}({self.plugin.module}) 未开启此功能..."
)
# 检查全局禁用
if self.plugin.block_type == BlockType.GROUP:
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
try:
await asyncio.wait_for(
send_message(
self.session, "该功能在群组中已被禁用...", self.group_id
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
raise SkipPluginException(
f"{self.plugin.name}({self.plugin.module})该插件在群组中已被禁用..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"GroupCheck.check 耗时: {elapsed:.3f}s, 群组: {self.group_id}",
LOGGER_COMMAND,
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..."
)
class PluginCheck:
@ -92,6 +118,7 @@ class PluginCheck:
self.is_poke = is_poke
self.group_id = group_id
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
async def check_user(self, plugin: PluginInfo):
"""全局私聊禁用检测
@ -104,7 +131,13 @@ class PluginCheck:
"""
if plugin.block_type == BlockType.PRIVATE:
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
await send_message(self.session, "该功能在私聊中已被禁用...")
try:
await asyncio.wait_for(
send_message(self.session, "该功能在私聊中已被禁用..."),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error("发送消息超时", LOGGER_COMMAND)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
)
@ -118,19 +151,46 @@ class PluginCheck:
异常:
IgnoredException: 忽略插件
"""
if plugin.status or plugin.block_type != BlockType.ALL:
return
"""全局状态"""
if self.group_id:
group = await self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
start_time = time.time()
try:
if plugin.status or plugin.block_type != BlockType.ALL:
return
"""全局状态"""
if self.group_id:
# 使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
if self.group_data and self.group_data.is_super:
raise IsSuperuserException()
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):
try:
await asyncio.wait_for(
send_message(self.session, "全局未开启此功能...", sid),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {sid}", LOGGER_COMMAND)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 全局未开启此功能..."
)
if group and group.is_super:
raise IsSuperuserException()
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):
await send_message(self.session, "全局未开启此功能...", sid)
raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...")
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"PluginCheck.check_global 耗时: {elapsed:.3f}s", LOGGER_COMMAND
)
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
@ -141,12 +201,42 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
session: Uninfo
event: Event
"""
entity = get_entity_ids(session)
is_poke_event = is_poke(event)
user_check = PluginCheck(entity.group_id, session, is_poke_event)
if entity.group_id:
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
await group_check.check()
else:
await user_check.check_user(plugin)
await user_check.check_global(plugin)
start_time = time.time()
try:
entity = get_entity_ids(session)
is_poke_event = is_poke(event)
user_check = PluginCheck(entity.group_id, session, is_poke_event)
if entity.group_id:
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
try:
await asyncio.wait_for(
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
)
except asyncio.TimeoutError:
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
else:
try:
await asyncio.wait_for(
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("用户检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
try:
await asyncio.wait_for(
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("全局检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
finally:
# 记录总执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_plugin 总耗时: {elapsed:.3f}s, 模块: {plugin.module}",
LOGGER_COMMAND,
)

View File

@ -11,3 +11,6 @@ LOGGER_COMMAND = "AuthChecker"
class SwitchEnum(StrEnum):
ENABLE = "醒来"
DISABLE = "休息吧"
WARNING_THRESHOLD = 0.5 # 警告阈值(秒)

View File

@ -1,4 +1,5 @@
import asyncio
import time
from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException
@ -24,13 +25,88 @@ from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManager, auth_limit
from .auth.auth_plugin import auth_plugin
from .auth.bot_filter import bot_filter
from .auth.config import LOGGER_COMMAND
from .auth.config import LOGGER_COMMAND, WARNING_THRESHOLD
from .auth.exception import (
IsSuperuserException,
PermissionExemption,
SkipPluginException,
)
# 超时设置(秒)
TIMEOUT_SECONDS = 5.0
# 熔断计数器
CIRCUIT_BREAKERS = {
"auth_ban": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_bot": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_group": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_admin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_plugin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_limit": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
}
# 熔断重置时间(秒)
CIRCUIT_RESET_TIME = 300 # 5分钟
# 超时装饰器
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
"""带超时控制的协程执行
参数:
coro: 要执行的协程
timeout: 超时时间
name: 操作名称用于日志记录
返回:
协程的返回值或者在超时时抛出 TimeoutError
"""
try:
return await asyncio.wait_for(coro, timeout=timeout)
except asyncio.TimeoutError:
if name:
logger.error(f"{name} 操作超时 (>{timeout}s)", LOGGER_COMMAND)
# 更新熔断计数器
if name in CIRCUIT_BREAKERS:
CIRCUIT_BREAKERS[name]["failures"] += 1
if (
CIRCUIT_BREAKERS[name]["failures"]
>= CIRCUIT_BREAKERS[name]["threshold"]
and not CIRCUIT_BREAKERS[name]["active"]
):
CIRCUIT_BREAKERS[name]["active"] = True
CIRCUIT_BREAKERS[name]["reset_time"] = (
time.time() + CIRCUIT_RESET_TIME
)
logger.warning(
f"{name} 熔断器已激活,将在 {CIRCUIT_RESET_TIME} 秒后重置",
LOGGER_COMMAND,
)
raise
# 检查熔断状态
def check_circuit_breaker(name):
"""检查熔断器状态
参数:
name: 操作名称
返回:
bool: 是否已熔断
"""
if name not in CIRCUIT_BREAKERS:
return False
# 检查是否需要重置熔断器
if (
CIRCUIT_BREAKERS[name]["active"]
and time.time() > CIRCUIT_BREAKERS[name]["reset_time"]
):
CIRCUIT_BREAKERS[name]["active"] = False
CIRCUIT_BREAKERS[name]["failures"] = 0
logger.info(f"{name} 熔断器已重置", LOGGER_COMMAND)
return CIRCUIT_BREAKERS[name]["active"]
async def get_plugin_and_user(
module: str, user_id: str
@ -52,7 +128,25 @@ async def get_plugin_and_user(
"""
user_dao = DataAccess(UserConsole)
plugin_dao = DataAccess(PluginInfo)
plugin = await plugin_dao.safe_get_or_none(module=module)
# 并行查询插件和用户数据
plugin_task = plugin_dao.safe_get_or_none(module=module)
user_task = user_dao.safe_get_or_none(user_id=user_id)
try:
plugin, user = await with_timeout(
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
)
except asyncio.TimeoutError:
# 如果并行查询超时,尝试串行查询
logger.warning("并行查询超时,尝试串行查询", LOGGER_COMMAND)
plugin = await with_timeout(
plugin_dao.safe_get_or_none(module=module), name="get_plugin"
)
user = await with_timeout(
user_dao.safe_get_or_none(user_id=user_id), name="get_user"
)
if not plugin:
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
if plugin.plugin_type == PluginType.HIDDEN:
@ -87,7 +181,7 @@ async def get_plugin_cost(
返回:
int: 调用插件金币费用
"""
cost_gold = await auth_cost(user, plugin, session)
cost_gold = await with_timeout(auth_cost(user, plugin, session), name="auth_cost")
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
@ -107,22 +201,51 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo
"""
user_dao = DataAccess(UserConsole)
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
module,
PlatformUtils.get_platform(session),
await with_timeout(
UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
module,
PlatformUtils.get_platform(session),
),
name="reduce_gold",
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
except asyncio.TimeoutError:
logger.error(
f"扣除金币超时,用户: {user_id}, 金币: {cost_gold}",
LOGGER_COMMAND,
session=session,
)
# 清除缓存,使下次查询时从数据库获取最新数据
await user_dao.clear_cache(user_id=user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
# 辅助函数,用于记录每个 hook 的执行时间
async def time_hook(coro, name, time_dict):
start = time.time()
try:
# 检查熔断状态
if check_circuit_breaker(name):
logger.info(f"{name} 熔断器激活中,跳过执行", LOGGER_COMMAND)
time_dict[name] = "熔断跳过"
return
# 添加超时控制
return await with_timeout(coro, name=name)
except asyncio.TimeoutError:
time_dict[name] = f"超时 (>{TIMEOUT_SECONDS}s)"
finally:
if name not in time_dict:
time_dict[name] = f"{time.time() - start:.3f}s"
async def auth(
matcher: Matcher,
event: Event,
@ -139,26 +262,81 @@ async def auth(
session: Uninfo
message: UniMsg
"""
start_time = time.time()
cost_gold = 0
ignore_flag = False
entity = get_entity_ids(session)
module = matcher.plugin_name or ""
# 用于记录各个 hook 的执行时间
hook_times = {}
hooks_time = 0 # 初始化 hooks_time 变量
try:
if not module:
raise PermissionExemption("Matcher插件名称不存在...")
plugin, user = await get_plugin_and_user(module, entity.user_id)
cost_gold = await get_plugin_cost(bot, user, plugin, session)
# 获取插件和用户数据
plugin_user_start = time.time()
try:
plugin, user = await with_timeout(
get_plugin_and_user(module, entity.user_id), name="get_plugin_and_user"
)
hook_times["get_plugin_user"] = f"{time.time() - plugin_user_start:.3f}s"
except asyncio.TimeoutError:
logger.error(
f"获取插件和用户数据超时,模块: {module}",
LOGGER_COMMAND,
session=session,
)
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
# 获取插件费用
cost_start = time.time()
try:
cost_gold = await with_timeout(
get_plugin_cost(bot, user, plugin, session), name="get_plugin_cost"
)
hook_times["cost_gold"] = f"{time.time() - cost_start:.3f}s"
except asyncio.TimeoutError:
logger.error(
f"获取插件费用超时,模块: {module}", LOGGER_COMMAND, session=session
)
# 继续执行,不阻止权限检查
# 执行 bot_filter
bot_filter(session)
await asyncio.gather(
*[
auth_ban(matcher, bot, session),
auth_bot(plugin, bot.self_id),
auth_group(plugin, entity, message),
auth_admin(plugin, session),
auth_plugin(plugin, session, event),
auth_limit(plugin, session),
]
)
# 并行执行所有 hook 检查,并记录执行时间
hooks_start = time.time()
# 创建所有 hook 任务
hook_tasks = [
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
]
# 使用 gather 并行执行所有 hook但添加总体超时控制
try:
await with_timeout(
asyncio.gather(*hook_tasks),
timeout=TIMEOUT_SECONDS * 2, # 给总体执行更多时间
name="auth_hooks_gather",
)
except asyncio.TimeoutError:
logger.error(
f"权限检查 hooks 总体执行超时,模块: {module}",
LOGGER_COMMAND,
session=session,
)
# 不抛出异常,允许继续执行
hooks_time = time.time() - hooks_start
except SkipPluginException as e:
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
logger.info(str(e), LOGGER_COMMAND, session=session)
@ -167,7 +345,31 @@ async def auth(
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
except PermissionExemption as e:
logger.info(str(e), LOGGER_COMMAND, session=session)
# 扣除金币
if not ignore_flag and cost_gold > 0:
await reduce_gold(entity.user_id, module, cost_gold, session)
gold_start = time.time()
try:
await with_timeout(
reduce_gold(entity.user_id, module, cost_gold, session),
name="reduce_gold",
)
hook_times["reduce_gold"] = f"{time.time() - gold_start:.3f}s"
except asyncio.TimeoutError:
logger.error(
f"扣除金币超时,模块: {module}", LOGGER_COMMAND, session=session
)
# 记录总执行时间
total_time = time.time() - start_time
if total_time > WARNING_THRESHOLD: # 如果总时间超过500ms记录详细信息
logger.warning(
f"权限检查耗时过长: {total_time:.3f}s, 模块: {module}, "
f"hooks时间: {hooks_time:.3f}s, "
f"详情: {hook_times}",
LOGGER_COMMAND,
session=session,
)
if ignore_flag:
raise IgnoredException("权限检测 ignore")

View File

@ -1,3 +1,5 @@
import asyncio
import aiofiles
import nonebot
from nonebot import get_loaded_plugins
@ -112,24 +114,29 @@ async def _():
await _handle_setting(plugin, plugin_list, limit_list)
create_list = []
update_list = []
update_task_list = []
for plugin in plugin_list:
if plugin.module_path not in module2id:
create_list.append(plugin)
else:
plugin.id = module2id[plugin.module_path]
await plugin.save(
update_fields=[
"name",
"author",
"version",
"admin_level",
"plugin_type",
"is_show",
]
update_task_list.append(
plugin.save(
update_fields=[
"name",
"author",
"version",
"admin_level",
"plugin_type",
"is_show",
]
)
)
update_list.append(plugin)
if create_list:
await PluginInfo.bulk_create(create_list, 10)
if update_task_list:
await asyncio.gather(*update_task_list)
# if update_list:
# # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
# column "superuser" does not exist

View File

@ -1,30 +0,0 @@
from zhenxun.models.group_console import GroupConsole
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
@PriorityLifecycle.on_startup(priority=5)
async def _():
"""开启/禁用插件格式修改"""
_, is_create = await GroupConsole.get_or_create(group_id=133133133)
"""标记"""
if is_create:
data_list = []
for group in await GroupConsole.all():
if group.block_plugin:
if modules := group.block_plugin.split(","):
block_plugin = "".join(
(f"{module}," if module.startswith("<") else f"<{module},")
for module in modules
if module.strip()
)
group.block_plugin = block_plugin.replace("<,", "")
if group.block_task:
if modules := group.block_task.split(","):
block_task = "".join(
(f"{module}," if module.startswith("<") else f"<{module},")
for module in modules
if module.strip()
)
group.block_task = block_task.replace("<,", "")
data_list.append(group)
await GroupConsole.bulk_update(data_list, ["block_plugin", "block_task"], 10)

View File

@ -177,7 +177,9 @@ async def _(session: EventSession, arparma: Arparma, state: T_State):
async def _(session: EventSession, arparma: Arparma, state: T_State):
gid = state["group_id"]
await GroupConsole.update_or_create(
group_id=gid, defaults={"group_flag": 0 if arparma.find("delete") else 1}
group_id=gid,
channel_id__isnull=True,
defaults={"group_flag": 0 if arparma.find("delete") else 1},
)
s = "删除" if arparma.find("delete") else "添加"
await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)

View File

@ -30,12 +30,13 @@ class BanConsole(Model):
table = "ban_console"
table_description = "封禁人员/群组数据表"
unique_together = ("user_id", "group_id")
indexes = [("user_id",), ("group_id",)] # noqa: RUF012
cache_type = CacheType.BAN
"""缓存类型"""
cache_key_field = ("user_id", "group_id")
"""缓存键字段"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
"""开启锁"""
@classmethod
@ -199,3 +200,10 @@ class BanConsole(Model):
if id is not None:
return await cls.safe_get_or_none(id=id)
return await cls._get_data(user_id, group_id)
@classmethod
async def _run_script(cls):
return [
"CREATE INDEX idx_ban_console_user_id ON ban_console(user_id);",
"CREATE INDEX idx_ban_console_group_id ON ban_console(group_id);",
]

View File

@ -87,12 +87,15 @@ class GroupConsole(Model):
table = "group_console"
table_description = "群组信息表"
unique_together = ("group_id", "channel_id")
indexes = [ # noqa: RUF012
("group_id",)
]
cache_type = CacheType.GROUPS
"""缓存类型"""
cache_key_field = ("group_id", "channel_id")
"""缓存键字段"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
"""开启锁"""
@classmethod
@ -537,4 +540,6 @@ class GroupConsole(Model):
" character varying(255) NOT NULL DEFAULT '';",
"ALTER TABLE group_console ADD superuser_block_task"
" character varying(255) NOT NULL DEFAULT '';",
"CREATE INDEX idx_group_console_group_id ON group_console(group_id);",
"CREATE INDEX idx_group_console_group_null_channel ON group_console(group_id) WHERE channel_id IS NULL;", # 单独创建channel为空的索引 # noqa: E501
]

View File

@ -29,6 +29,7 @@ class UserConsole(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "user_console"
table_description = "用户数据表"
indexes = [("user_id",), ("uid",)] # noqa: RUF012
cache_type = CacheType.USERS
"""缓存类型"""
@ -198,3 +199,10 @@ class UserConsole(Model):
if goods := await GoodsInfo.get_or_none(goods_name=name):
return await cls.use_props(user_id, goods.uuid, num, platform)
raise GoodsNotFound("未找到商品...")
@classmethod
async def _run_script(cls):
return [
"CREATE INDEX idx_user_console_user_id ON user_console(user_id);",
"CREATE INDEX idx_user_console_uid ON user_console(uid);",
]

View File

@ -56,6 +56,7 @@ await message_list.save()
```
"""
import asyncio
from collections.abc import Callable
from datetime import datetime
from functools import wraps
@ -500,7 +501,8 @@ class CacheManager:
logger.debug(f"清除所有 {cache_type} 缓存", LOG_COMMAND)
return await self.clear(cache_type)
except Exception as e:
logger.error(f"清除缓存 {cache_type} 失败", LOG_COMMAND, e=e)
if f"缓存类型 {cache_type} 不存在" not in str(e):
logger.warning(f"清除缓存 {cache_type} 失败", LOG_COMMAND, e=e)
return False
async def get(
@ -516,13 +518,18 @@ class CacheManager:
返回:
Any: 缓存数据如果不存在返回默认值
"""
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
# 如果缓存被禁用或缓存模式为NONE直接返回默认值
if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
return default
cache_key = None
try:
cache_key = self._build_key(cache_type, key)
data = await self.cache_backend.get(cache_key) # type: ignore
data = await asyncio.wait_for(
self.cache_backend.get(cache_key), # type: ignore
timeout=DB_TIMEOUT_SECONDS,
)
if data is None:
return default
@ -534,6 +541,9 @@ class CacheManager:
if model.result_type:
return self._deserialize_value(data, model.result_type)
return data
except asyncio.TimeoutError:
logger.error(f"获取缓存 {cache_type}:{cache_key} 超时", LOG_COMMAND)
return default
except Exception as e:
logger.error(f"获取缓存 {cache_type} 失败", LOG_COMMAND, e=e)
return default
@ -556,10 +566,12 @@ class CacheManager:
返回:
bool: 是否成功
"""
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
# 如果缓存被禁用或缓存模式为NONE直接返回False
if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
return False
cache_key = None
try:
cache_key = self._build_key(cache_type, key)
model = self.get_model(cache_type)
@ -571,8 +583,14 @@ class CacheManager:
ttl = expire if expire is not None else model.expire
# 设置缓存
await self.cache_backend.set(cache_key, serialized_value, ttl=ttl) # type: ignore
await asyncio.wait_for(
self.cache_backend.set(cache_key, serialized_value, ttl=ttl), # type: ignore
timeout=DB_TIMEOUT_SECONDS,
)
return True
except asyncio.TimeoutError:
logger.error(f"设置缓存 {cache_type}:{cache_key} 超时", LOG_COMMAND)
return False
except Exception as e:
logger.error(f"设置缓存 {cache_type} 失败", LOG_COMMAND, e=e)
return False
@ -647,7 +665,8 @@ class CacheManager:
await self.cache_backend.clear() # type: ignore
return True
except Exception as e:
logger.error("清除缓存失败", LOG_COMMAND, e=e)
if f"缓存类型 {cache_type} 不存在" not in str(e):
logger.warning("清除缓存失败", LOG_COMMAND, e=e)
return False
async def close(self):

View File

@ -1,11 +1,8 @@
from typing import Any, Generic, TypeVar, cast
from typing import Any, ClassVar, Generic, TypeVar, cast
from zhenxun.services.cache import Cache, CacheRoot, cache_config
from zhenxun.services.cache.config import (
COMPOSITE_KEY_SEPARATOR,
CacheMode,
)
from zhenxun.services.db_context import Model
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR, CacheMode
from zhenxun.services.db_context import Model, with_db_timeout
from zhenxun.services.log import logger
T = TypeVar("T", bound=Model)
@ -40,6 +37,34 @@ class DataAccess(Generic[T]):
```
"""
# 添加缓存统计信息
_cache_stats: ClassVar[dict] = {}
# 空结果标记
_NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"
# 默认空结果缓存时间(秒)- 设置为5分钟避免频繁查询数据库
_NULL_RESULT_TTL = 300
@classmethod
def set_null_result_ttl(cls, seconds: int) -> None:
"""设置空结果缓存时间
参数:
seconds: 缓存时间
"""
if seconds < 0:
raise ValueError("缓存时间不能为负数")
cls._NULL_RESULT_TTL = seconds
logger.info(f"已设置DataAccess空结果缓存时间为 {seconds}")
@classmethod
def get_null_result_ttl(cls) -> int:
"""获取空结果缓存时间
返回:
int: 缓存时间
"""
return cls._NULL_RESULT_TTL
def __init__(
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
):
@ -57,6 +82,52 @@ class DataAccess(Generic[T]):
raise ValueError("缓存类型不能为空")
self.cache = Cache(self.cache_type)
# 初始化缓存统计
if self.cache_type not in self._cache_stats:
self._cache_stats[self.cache_type] = {
"hits": 0, # 缓存命中次数
"misses": 0, # 缓存未命中次数
"null_hits": 0, # 空结果缓存命中次数
"sets": 0, # 缓存设置次数
"null_sets": 0, # 空结果缓存设置次数
"deletes": 0, # 缓存删除次数
}
@classmethod
def get_cache_stats(cls):
"""获取缓存统计信息"""
result = []
for cache_type, stats in cls._cache_stats.items():
hits = stats["hits"]
null_hits = stats.get("null_hits", 0)
misses = stats["misses"]
total = hits + null_hits + misses
hit_rate = ((hits + null_hits) / total * 100) if total > 0 else 0
result.append(
{
"cache_type": cache_type,
"hits": hits,
"null_hits": null_hits,
"misses": misses,
"sets": stats["sets"],
"null_sets": stats.get("null_sets", 0),
"deletes": stats["deletes"],
"hit_rate": f"{hit_rate:.2f}%",
}
)
return result
@classmethod
def reset_cache_stats(cls):
"""重置缓存统计信息"""
for stats in cls._cache_stats.values():
stats["hits"] = 0
stats["null_hits"] = 0
stats["misses"] = 0
stats["sets"] = 0
stats["null_sets"] = 0
stats["deletes"] = 0
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
"""从关键字参数构建缓存键
@ -69,12 +140,8 @@ class DataAccess(Generic[T]):
if isinstance(self.key_field, tuple):
# 多字段主键
key_parts = []
for field in self.key_field:
key_parts.append(str(kwargs.get(field, "")))
if key_parts:
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
return None
key_parts.extend(str(kwargs.get(field, "")) for field in self.key_field)
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif self.key_field in kwargs:
# 单字段主键
return str(kwargs[self.key_field])
@ -92,9 +159,14 @@ class DataAccess(Generic[T]):
"""
# 如果没有缓存类型,直接从数据库获取
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
return await self.model_cls.safe_get_or_none(*args, **kwargs)
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
return await with_db_timeout(
self.model_cls.safe_get_or_none(*args, **kwargs),
operation=f"{self.model_cls.__name__}.safe_get_or_none",
)
# 尝试从缓存获取
cache_key = None
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
@ -102,12 +174,33 @@ class DataAccess(Generic[T]):
# 如果成功构建缓存键,尝试从缓存获取
if cache_key is not None:
data = await self.cache.get(cache_key)
if data:
logger.debug(
f"{self.model_cls.__name__} self.cache.get(cache_key)"
f" 从缓存获取到的数据 {type(data)}: {data}"
)
if data == self._NULL_RESULT:
# 空结果缓存命中
self._cache_stats[self.cache_type]["null_hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
)
return None
elif data:
# 缓存命中
self._cache_stats[self.cache_type]["hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
)
return cast(T, data)
else:
# 缓存未命中
self._cache_stats[self.cache_type]["misses"] += 1
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
except Exception as e:
logger.error("从缓存获取数据失败", e=e)
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
# 如果缓存中没有,从数据库获取
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
# 如果获取到数据,存入缓存
@ -118,9 +211,30 @@ class DataAccess(Generic[T]):
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
logger.error(
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
)
elif cache_key is not None:
# 如果没有获取到数据,缓存空结果
try:
# 存入空结果缓存,使用较短的过期时间
await self.cache.set(
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
)
self._cache_stats[self.cache_type]["null_sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
f" TTL={self._NULL_RESULT_TTL}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
)
return data
@ -136,9 +250,14 @@ class DataAccess(Generic[T]):
"""
# 如果没有缓存类型,直接从数据库获取
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
return await self.model_cls.get_or_none(*args, **kwargs)
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
return await with_db_timeout(
self.model_cls.get_or_none(*args, **kwargs),
operation=f"{self.model_cls.__name__}.get_or_none",
)
# 尝试从缓存获取
cache_key = None
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
@ -146,12 +265,29 @@ class DataAccess(Generic[T]):
# 如果成功构建缓存键,尝试从缓存获取
if cache_key is not None:
data = await self.cache.get(cache_key)
if data:
if data == self._NULL_RESULT:
# 空结果缓存命中
self._cache_stats[self.cache_type]["null_hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
)
return None
elif data:
# 缓存命中
self._cache_stats[self.cache_type]["hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
)
return cast(T, data)
else:
# 缓存未命中
self._cache_stats[self.cache_type]["misses"] += 1
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
except Exception as e:
logger.error("从缓存获取数据失败", e=e)
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
# 如果缓存中没有,从数据库获取
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
data = await self.model_cls.get_or_none(*args, **kwargs)
# 如果获取到数据,存入缓存
@ -162,9 +298,30 @@ class DataAccess(Generic[T]):
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
logger.error(
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
)
elif cache_key is not None:
# 如果没有获取到数据,缓存空结果
try:
# 存入空结果缓存,使用较短的过期时间
await self.cache.set(
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
)
self._cache_stats[self.cache_type]["null_sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
f" TTL={self._NULL_RESULT_TTL}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
)
return data
@ -203,6 +360,7 @@ class DataAccess(Generic[T]):
# 删除缓存
await self.cache.delete(cache_key)
self._cache_stats[self.cache_type]["deletes"] += 1
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
return True
except Exception as e:
@ -227,12 +385,7 @@ class DataAccess(Generic[T]):
key_parts.append(value if value is not None else "")
# 如果没有有效参数返回None
if not key_parts:
return None
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
# 单个字段作为键
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif hasattr(data, self.key_field):
value = getattr(data, self.key_field, None)
return str(value) if value is not None else None
@ -255,24 +408,22 @@ class DataAccess(Generic[T]):
# 获取缓存类型的配置信息
cache_model = CacheRoot.get_model(self.cache_type)
# 如果有键格式定义,则需要构建特殊格式的键
if cache_model.key_format:
# 构建键参数字典
key_parts = []
# 从格式字符串中提取所需的字段名
import re
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
# 收集所有字段值
for field in field_names:
value = getattr(item, field, "")
key_parts.append(value if value is not None else "")
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
else:
if not cache_model.key_format:
# 常规处理,使用主键作为缓存键
return self._build_composite_key(item)
# 构建键参数字典
key_parts = []
# 从格式字符串中提取所需的字段名
import re
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
# 收集所有字段值
for field in field_names:
value = getattr(item, field, "")
key_parts.append(value if value is not None else "")
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
async def _cache_items(self, data_list: list[T]) -> None:
"""将数据列表存入缓存
@ -289,14 +440,19 @@ class DataAccess(Generic[T]):
try:
# 遍历数据列表,将每条数据存入缓存
cached_count = 0
for item in data_list:
cache_key = self._build_cache_key_for_item(item)
if cache_key is not None:
await self.cache.set(cache_key, item)
cached_count += 1
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(f"{self.cache_type} 数据已存入缓存,数量: {len(data_list)}")
logger.debug(
f"{self.model_cls.__name__} 批量缓存: {cached_count}/{len(data_list)}"
)
except Exception as e:
logger.error(f"{self.cache_type} 数据存入缓存失败", e=e)
logger.error(f"{self.model_cls.__name__} 批量缓存失败", e=e)
async def filter(self, *args, **kwargs) -> list[T]:
"""筛选数据
@ -309,7 +465,11 @@ class DataAccess(Generic[T]):
List[T]: 查询结果列表
"""
# 从数据库获取数据
logger.debug(f"{self.model_cls.__name__} filter: 从数据库查询, 参数: {kwargs}")
data_list = await self.model_cls.filter(*args, **kwargs)
logger.debug(
f"{self.model_cls.__name__} filter: 查询结果数量: {len(data_list)}"
)
# 将数据存入缓存
await self._cache_items(data_list)
@ -323,7 +483,9 @@ class DataAccess(Generic[T]):
List[T]: 所有数据列表
"""
# 直接从数据库获取
logger.debug(f"{self.model_cls.__name__} all: 从数据库查询所有数据")
data_list = await self.model_cls.all()
logger.debug(f"{self.model_cls.__name__} all: 查询结果数量: {len(data_list)}")
# 将数据存入缓存
await self._cache_items(data_list)
@ -366,6 +528,7 @@ class DataAccess(Generic[T]):
T: 创建的数据
"""
# 创建数据
logger.debug(f"{self.model_cls.__name__} create: 创建数据, 参数: {kwargs}")
data = await self.model_cls.create(**kwargs)
# 如果有缓存类型,将数据存入缓存
@ -376,11 +539,16 @@ class DataAccess(Generic[T]):
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.cache_type} 新创建的数据已存入缓存: {cache_key}"
f"{self.model_cls.__name__} create: "
f"新创建的数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
logger.error(
f"{self.model_cls.__name__} create: 存入缓存失败,参数: {kwargs}",
e=e,
)
return data
@ -409,6 +577,7 @@ class DataAccess(Generic[T]):
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
except Exception as e:
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
@ -425,6 +594,8 @@ class DataAccess(Generic[T]):
返回:
int: 删除的数据数量
"""
logger.debug(f"{self.model_cls.__name__} delete: 删除数据, 参数: {kwargs}")
# 如果有缓存类型且有key_field参数先尝试删除缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
@ -434,21 +605,36 @@ class DataAccess(Generic[T]):
if cache_key is not None:
# 如果成功构建缓存键,直接删除缓存
await self.cache.delete(cache_key)
logger.debug(f"已删除缓存: {cache_key}")
self._cache_stats[self.cache_type]["deletes"] += 1
logger.debug(
f"{self.model_cls.__name__} delete: 已删除缓存: {cache_key}"
)
else:
# 否则需要先查询出要删除的数据,然后删除对应的缓存
items = await self.model_cls.filter(*args, **kwargs)
logger.debug(
f"{self.model_cls.__name__} delete:"
f" 查询到 {len(items)} 条要删除的数据"
)
for item in items:
item_cache_key = self._build_cache_key_for_item(item)
if item_cache_key is not None:
await self.cache.delete(item_cache_key)
self._cache_stats[self.cache_type]["deletes"] += 1
if items:
logger.debug(f"已删除{len(items)}条数据的缓存")
logger.debug(
f"{self.model_cls.__name__} delete:"
f" 已删除 {len(items)} 条数据的缓存"
)
except Exception as e:
logger.error("删除缓存失败", e=e)
logger.error(f"{self.model_cls.__name__} delete: 删除缓存失败", e=e)
# 删除数据
return await self.model_cls.filter(*args, **kwargs).delete()
result = await self.model_cls.filter(*args, **kwargs).delete()
logger.debug(
f"{self.model_cls.__name__} delete: 已从数据库删除 {result} 条数据"
)
return result
def _generate_cache_key(self, data: T) -> str:
"""根据数据对象生成缓存键

View File

@ -1,67 +1,83 @@
from asyncio import Semaphore
import asyncio
from collections.abc import Iterable
import contextlib
import time
from typing import Any, ClassVar
from typing_extensions import Self
from urllib.parse import urlparse
import nonebot
from nonebot import get_driver
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.connection import connections
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction
from zhenxun.configs.config import BotConfig
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .cache import CacheRoot
from .cache.config import COMPOSITE_KEY_SEPARATOR
from .log import logger
driver = get_driver()
SCRIPT_METHOD = []
MODELS: list[str] = []
driver = nonebot.get_driver()
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise
class Model(TortoiseModel):
"""
自动添加模块
增强的ORM基类解决锁嵌套问题
"""
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
_current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__module__ not in MODELS:
MODELS.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func))
if enable_lock := getattr(cls, "enable_lock", []):
"""创建锁"""
cls.sem_data[cls.__module__] = {}
for lock in enable_lock:
cls.sem_data[cls.__module__][lock] = Semaphore(1)
@classmethod
def get_semaphore(cls, lock_type: DbLockType):
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
@classmethod
def get_cache_type(cls) -> str | None:
"""获取缓存类型"""
return getattr(cls, "cache_type", None)
@classmethod
def get_cache_key_field(cls) -> str | tuple[str]:
"""获取缓存键字段名
返回:
str | tuple[str]: 缓存键字段名可能是单个字段名或字段名元组
"""
if hasattr(cls, "cache_key_field"):
return getattr(cls, "cache_key_field", "id")
return "id"
"""获取缓存键字段"""
return getattr(cls, "cache_key_field", "id")
@classmethod
def get_cache_key(cls, instance) -> str | None:
@ -71,13 +87,14 @@ class Model(TortoiseModel):
instance: 模型实例
返回:
str | None
str | None: 缓存键如果无法获取则返回None
"""
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR
key_field = cls.get_cache_key_field()
# 如果是元组,表示多个字段组成键
if isinstance(key_field, tuple):
# 构建键参数列表
# 多字段主键
key_parts = []
for field in key_field:
if hasattr(instance, field):
@ -85,25 +102,60 @@ class Model(TortoiseModel):
key_parts.append(value if value is not None else "")
else:
# 如果缺少任何必要的字段返回None
return None
key_parts.append("")
# 如果没有有效参数返回None
if not key_parts:
return None
return COMPOSITE_KEY_SEPARATOR.join(str(param) for param in key_parts)
# 单个字段作为键
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif hasattr(instance, key_field):
return getattr(instance, key_field, None)
value = getattr(instance, key_field, None)
return str(value) if value is not None else None
return None
@classmethod
def get_semaphore(cls, lock_type: DbLockType):
enable_lock = getattr(cls, "enable_lock", None)
if not enable_lock or lock_type not in enable_lock:
return None
if cls.__name__ not in cls.sem_data:
cls.sem_data[cls.__name__] = {}
if lock_type not in cls.sem_data[cls.__name__]:
cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
return cls.sem_data[cls.__name__][lock_type]
@classmethod
def _require_lock(cls, lock_type: DbLockType) -> bool:
"""检查是否需要真正加锁"""
task_id = id(asyncio.current_task())
return cls._current_locks.get(task_id) != lock_type
@classmethod
@contextlib.asynccontextmanager
async def _lock_context(cls, lock_type: DbLockType):
"""带重入检查的锁上下文"""
task_id = id(asyncio.current_task())
need_lock = cls._require_lock(lock_type)
if need_lock and (sem := cls.get_semaphore(lock_type)):
cls._current_locks[task_id] = lock_type
async with sem:
yield
cls._current_locks.pop(task_id, None)
else:
yield
@classmethod
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self:
return await super().create(using_db=using_db, **kwargs)
"""创建数据使用CREATE锁"""
async with cls._lock_context(DbLockType.CREATE):
# 直接调用父类的_create方法避免触发save的锁
result = await super().create(using_db=using_db, **kwargs)
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
return result
@classmethod
async def get_or_create(
@ -112,31 +164,13 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
if sem := cls.get_semaphore(DbLockType.CREATE):
async with sem:
# 在锁内执行查询和创建操作
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
# 获取缓存键
key = cls.get_cache_key(result)
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return (result, is_create)
else:
# 如果没有锁,则执行原来的逻辑
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
# 获取缓存键
key = cls.get_cache_key(result)
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return (result, is_create)
"""获取或创建数据(无锁版本,依赖数据库约束)"""
result = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
return result
@classmethod
async def update_or_create(
@ -145,73 +179,28 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
if sem := cls.get_semaphore(DbLockType.CREATE):
async with sem:
# 在锁内执行查询和创建操作
result = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
"""更新或创建数据使用UPSERT锁"""
async with cls._lock_context(DbLockType.UPSERT):
try:
# 先尝试更新(带行锁)
async with in_transaction():
if obj := await cls.filter(**kwargs).select_for_update().first():
await obj.update_from_dict(defaults or {})
await obj.save()
result = (obj, False)
else:
# 创建时不重复加锁
result = await cls.create(**kwargs, **(defaults or {})), True
if cache_type := cls.get_cache_type():
# 获取缓存键
key = cls.get_cache_key(result[0])
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
cache_type, cls.get_cache_key(result[0])
)
return result
else:
# 如果没有锁,则执行原来的逻辑
result = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
# 获取缓存键
key = cls.get_cache_key(result[0])
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return result
@classmethod
async def bulk_create( # type: ignore
cls,
objects: Iterable[Self], # type: ignore
batch_size: int | None = None,
ignore_conflicts: bool = False,
update_fields: Iterable[str] | None = None,
on_conflict: Iterable[str] | None = None,
using_db: BaseDBAsyncClient | None = None,
) -> list[Self]: # type: ignore
result = await super().bulk_create(
objects=objects,
batch_size=batch_size,
ignore_conflicts=ignore_conflicts,
update_fields=update_fields,
on_conflict=on_conflict,
using_db=using_db,
)
if cache_type := cls.get_cache_type():
# 批量创建时清除整个类型的缓存
await CacheRoot.invalidate_cache(cache_type)
return result
@classmethod
async def bulk_update( # type: ignore
cls,
objects: Iterable[Self], # type: ignore
fields: Iterable[str],
batch_size: int | None = None,
using_db: BaseDBAsyncClient | None = None,
) -> int: # type: ignore
result = await super().bulk_update(
objects=objects,
fields=fields,
batch_size=batch_size,
using_db=using_db,
)
if cache_type := cls.get_cache_type():
# 批量更新时清除整个类型的缓存
await CacheRoot.invalidate_cache(cache_type)
return result
except IntegrityError:
# 处理极端情况下的唯一约束冲突
obj = await cls.get(**kwargs)
return obj, False
async def save(
self,
@ -220,37 +209,27 @@ class Model(TortoiseModel):
force_create: bool = False,
force_update: bool = False,
):
if getattr(self, "id", None) is None:
sem = self.get_semaphore(DbLockType.CREATE)
else:
sem = self.get_semaphore(DbLockType.UPDATE)
if sem:
async with sem:
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
else:
"""保存数据(根据操作类型自动选择锁)"""
lock_type = (
DbLockType.CREATE
if getattr(self, "id", None) is None
else DbLockType.UPDATE
)
async with self._lock_context(lock_type):
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
if cache_type := getattr(self, "cache_type", None):
# 获取缓存键
key = self.__class__.get_cache_key(self)
await CacheRoot.invalidate_cache(cache_type, key)
if cache_type := getattr(self, "cache_type", None):
await CacheRoot.invalidate_cache(
cache_type, self.__class__.get_cache_key(self)
)
async def delete(self, using_db: BaseDBAsyncClient | None = None):
# 在删除前获取缓存键
cache_type = getattr(self, "cache_type", None)
key = None
if cache_type:
key = self.__class__.get_cache_key(self)
key = self.__class__.get_cache_key(self) if cache_type else None
# 执行删除操作
await super().delete(using_db=using_db)
@ -280,15 +259,23 @@ class Model(TortoiseModel):
"""
try:
# 先尝试使用 get_or_none 获取单个记录
return await cls.get_or_none(*args, using_db=using_db, **kwargs)
except Exception as e:
# 如果出现错误(可能是存在多个记录)
if "Multiple objects" in str(e):
logger.warning(
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}"
try:
return await with_db_timeout(
cls.get_or_none(*args, using_db=using_db, **kwargs),
operation=f"{cls.__name__}.get_or_none",
)
except MultipleObjectsReturned:
# 如果出现多个记录的情况,进行特殊处理
logger.warning(
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
LOG_COMMAND,
)
# 查询所有匹配记录
records = await cls.filter(*args, **kwargs)
records = await with_db_timeout(
cls.filter(*args, **kwargs).all(),
operation=f"{cls.__name__}.filter.all",
)
if not records:
return None
@ -301,20 +288,39 @@ class Model(TortoiseModel):
)
for record in records[1:]:
try:
await record.delete()
await with_db_timeout(
record.delete(),
operation=f"{cls.__name__}.delete_duplicate",
)
logger.info(
f"{cls.__name__} 删除重复记录:"
f" id={getattr(record, 'id', None)}"
f" id={getattr(record, 'id', None)}",
LOG_COMMAND,
)
except Exception as del_e:
logger.error(f"删除重复记录失败: {del_e}")
return records[0]
# 如果不需要清理或没有 id 字段,则返回最新的记录
if hasattr(cls, "id"):
return await cls.filter(*args, **kwargs).order_by("-id").first()
return await with_db_timeout(
cls.filter(*args, **kwargs).order_by("-id").first(),
operation=f"{cls.__name__}.filter.order_by.first",
)
# 如果没有 id 字段,则返回第一个记录
return await cls.filter(*args, **kwargs).first()
return await with_db_timeout(
cls.filter(*args, **kwargs).first(),
operation=f"{cls.__name__}.filter.first",
)
except asyncio.TimeoutError:
logger.error(
f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
)
return None
except Exception as e:
# 其他类型的错误则继续抛出
logger.error(
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
)
raise
@ -334,6 +340,77 @@ class DbConnectError(Exception):
pass
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
def get_config(db_url: str) -> dict:
"""获取数据库配置"""
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path[1:] or ":memory:",
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
@ -349,9 +426,7 @@ async def init():
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
db_url=BotConfig.db_url,
modules={"models": MODELS},
timezone="Asia/Shanghai",
config=get_config(BotConfig.db_url),
)
if SCRIPT_METHOD:
db = Tortoise.get_connection("default")
@ -366,17 +441,21 @@ async def init():
if sql:
sql_list += sql
except Exception as e:
logger.trace(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.trace(f"执行SQL: {sql}")
logger.debug(f"执行SQL: {sql}")
try:
await db.execute_query_dict(sql)
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.trace(f"执行SQL: {sql} 错误...", e=e)
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e

View File

@ -78,6 +78,8 @@ class DbLockType(StrEnum):
"""更新"""
QUERY = "QUERY"
"""查询"""
UPSERT = "UPSERT"
"""创建或更新"""
class GoldHandle(StrEnum):

View File

@ -49,6 +49,9 @@ async def _():
try:
for priority in priority_list:
for func in priority_data[priority]:
logger.debug(
f"执行优先级 [{priority}] on_startup 方法: {func.__module__}"
)
if is_coroutine_callable(func):
await func()
else:

View File

@ -4,7 +4,7 @@ from datetime import date, datetime
import os
from pathlib import Path
import time
from typing import Any
from typing import Any, ClassVar
import httpx
from nonebot_plugin_uninfo import Uninfo
@ -30,38 +30,38 @@ class ResourceDirManager:
临时文件管理器
"""
temp_path = [] # noqa: RUF012
temp_path: ClassVar[set[Path]] = set()
@classmethod
def __tree_append(cls, path: Path):
"""递归添加文件夹
参数:
path: 文件夹路径
"""
def __tree_append(cls, path: Path, deep: int = 1, current: int = 0):
"""递归添加文件夹"""
if current >= deep and deep != -1:
return
path = path.resolve() # 标准化路径
for f in os.listdir(path):
file = path / f
file = (path / f).resolve() # 标准化子路径
if file.is_dir():
if file not in cls.temp_path:
cls.temp_path.append(file)
logger.debug(f"添加临时文件夹: {path}")
cls.__tree_append(file)
cls.temp_path.add(file)
logger.debug(f"添加临时文件夹: {file}")
cls.__tree_append(file, deep, current + 1)
@classmethod
def add_temp_dir(cls, path: str | Path, tree: bool = False):
def add_temp_dir(cls, path: str | Path, tree: bool = False, deep: int = 1):
"""添加临时清理文件夹,这些文件夹会被自动清理
参数:
path: 文件夹路径
tree: 是否递归添加文件夹
deep: 深度, -1 为无限深度
"""
if isinstance(path, str):
path = Path(path)
if path not in cls.temp_path:
cls.temp_path.append(path)
cls.temp_path.add(path)
logger.debug(f"添加临时文件夹: {path}")
if tree:
cls.__tree_append(path)
cls.__tree_append(path, deep)
class CountLimiter: