mirror of https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⚡ Major optimization: fix latency and freeze issues
This commit is contained in:
parent 419c8934a0
commit defe99e66c
@@ -1,11 +1,17 @@
import asyncio
import time

from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo

from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import get_entity_ids

from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import send_message

@@ -17,42 +23,77 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
        plugin: PluginInfo
        session: Uninfo
    """
    start_time = time.time()

    if not plugin.admin_level:
        return
    entity = get_entity_ids(session)
    level_dao = DataAccess(LevelUser)
    global_user = await level_dao.safe_get_or_none(
        user_id=session.user.id, group_id__isnull=True
    )
    user_level = 0
    if global_user:
        user_level = global_user.user_level
    if entity.group_id:
        # 获取用户在当前群组的权限数据
        group_users = await level_dao.safe_get_or_none(
            user_id=session.user.id, group_id=entity.group_id

    try:
        entity = get_entity_ids(session)
        level_dao = DataAccess(LevelUser)

        # 并行查询用户权限数据
        global_user: LevelUser | None = None
        group_users: LevelUser | None = None

        # 查询全局权限
        global_user_task = level_dao.safe_get_or_none(
            user_id=session.user.id, group_id__isnull=True
        )
        if group_users:

        # 如果在群组中,查询群组权限
        group_users_task = None
        if entity.group_id:
            group_users_task = level_dao.safe_get_or_none(
                user_id=session.user.id, group_id=entity.group_id
            )

        # 等待查询完成,添加超时控制
        try:
            results = await asyncio.wait_for(
                asyncio.gather(global_user_task, group_users_task or asyncio.sleep(0)),
                timeout=DB_TIMEOUT_SECONDS,
            )
            global_user = results[0]
            group_users = results[1] if group_users_task else None
        except asyncio.TimeoutError:
            logger.error(f"查询用户权限超时: user_id={session.user.id}", LOGGER_COMMAND)
            # 超时时不阻塞,继续执行
            return

        user_level = global_user.user_level if global_user else 0
        if entity.group_id and group_users:
            user_level = max(user_level, group_users.user_level)

        if user_level < plugin.admin_level:
            await send_message(
                session,
                [
                    At(flag="user", target=session.user.id),
        if user_level < plugin.admin_level:
            await send_message(
                session,
                [
                    At(flag="user", target=session.user.id),
                    f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
                ],
                entity.user_id,
            )

            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 管理员权限不足..."
            )
    elif global_user:
        if global_user.user_level < plugin.admin_level:
            await send_message(
                session,
                f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
                ],
                entity.user_id,
            )
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 管理员权限不足..."
            )
    elif global_user:
        if global_user.user_level < plugin.admin_level:
            await send_message(
                session,
                f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
            )
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 管理员权限不足..."
            )

        raise SkipPluginException(
            f"{plugin.name}({plugin.module}) 管理员权限不足..."
        )
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"auth_admin 耗时: {elapsed:.3f}s, plugin={plugin.module}",
                LOGGER_COMMAND,
                session=session,
            )
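Editor's note: the core pattern in this hunk is running the global and group permission lookups concurrently under a single deadline, with `asyncio.sleep(0)` standing in as a no-op placeholder so `gather` always returns two positional results. A minimal standalone sketch of that pattern follows; `fetch_global`, `fetch_group`, and the timeout value are illustrative, not from the codebase:

```python
import asyncio

DB_TIMEOUT_SECONDS = 3.0  # assumed value; the real constant lives in zhenxun.services.db_context


async def fetch_global() -> int | None:  # hypothetical query helper
    await asyncio.sleep(0.1)
    return 5


async def fetch_group() -> int | None:  # hypothetical query helper
    await asyncio.sleep(0.1)
    return 9


async def resolve_level() -> int:
    try:
        # Run both lookups concurrently; one timeout bounds the whole pair.
        global_level, group_level = await asyncio.wait_for(
            asyncio.gather(fetch_global(), fetch_group()),
            timeout=DB_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        return 0  # fail open: a slow database reads as "no elevated permission"
    return max(global_level or 0, group_level or 0)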
@@ -1,3 +1,6 @@
import asyncio
import time

from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
@@ -7,9 +10,12 @@ from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.utils import EntityIDs, get_entity_ids

from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import freq, send_message

@@ -21,32 +27,105 @@ Config.add_plugin_config(
)


def calculate_ban_time(ban_record: BanConsole | None) -> int:
    """根据ban记录计算剩余ban时间

    参数:
        ban_record: BanConsole记录

    返回:
        int: ban剩余时长,-1时为永久ban,0表示未被ban
    """
    if not ban_record:
        return 0

    if ban_record.duration == -1:
        return -1

    _time = time.time() - (ban_record.ban_time + ban_record.duration)
    return 0 if _time > 0 else int(abs(_time))


async def is_ban(user_id: str | None, group_id: str | None) -> int:
    """检查用户或群组是否被ban

    参数:
        user_id: 用户ID
        group_id: 群组ID

    返回:
        int: ban的剩余时间,0表示未被ban
    """
    if not user_id and not group_id:
        return 0

    start_time = time.time()
    ban_dao = DataAccess(BanConsole)

    # 分别获取用户在群组中的ban记录和全局ban记录
    group_user = None
    user = None

    if user_id and group_id:
        group_user = await ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id)
    try:
        # 并行查询用户和群组的 ban 记录
        tasks = []
        if user_id and group_id:
            tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id))
        if user_id:
            tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=""))

    if user_id:
        user = await ban_dao.safe_get_or_none(user_id=user_id, group_id="")
        # 等待所有查询完成,添加超时控制
        if tasks:
            try:
                ban_records = await asyncio.wait_for(
                    asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS
                )
                if len(tasks) == 2:
                    group_user, user = ban_records
                elif user_id and group_id:
                    group_user = ban_records[0]
                else:
                    user = ban_records[0]
            except asyncio.TimeoutError:
                logger.error(
                    f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
                    LOGGER_COMMAND,
                )
                # 超时时返回0,避免阻塞
                return 0

    results = []
    if group_user:
        results.append(group_user)
    if user:
        results.append(user)
    if not results:
        return 0
    for result in results:
        if result.duration > 0 or result.duration == -1:
            return await BanConsole.check_ban_time(user_id, group_id)
    return 0
        # 检查记录并计算ban时间
        results = []
        if group_user:
            results.append(group_user)
        if user:
            results.append(user)

        # 如果没有找到记录,返回0
        if not results:
            return 0

        logger.debug(f"查询到的ban记录: {results}", LOGGER_COMMAND)
        # 检查所有记录,找出最严格的ban(时间最长的)
        max_ban_time: int = 0
        for result in results:
            if result.duration > 0 or result.duration == -1:
                # 直接计算ban时间,避免再次查询数据库
                ban_time = calculate_ban_time(result)
                if ban_time == -1 or ban_time > max_ban_time:
                    max_ban_time = ban_time

        return max_ban_time
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"is_ban 耗时: {elapsed:.3f}s",
                LOGGER_COMMAND,
                session=user_id,
                group_id=group_id,
            )
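Editor's note: `calculate_ban_time` replaces a second database round-trip with pure arithmetic; the ban ends at `ban_time + duration`, so any overshoot past now maps to 0. A self-contained restatement of that arithmetic, with illustrative names:

```python
import time


def remaining_ban_seconds(ban_start: float, duration: float) -> int:
    """Remaining ban time in seconds; -1 means permanent, 0 means expired.

    Mirrors calculate_ban_time above: the ban ends at ban_start + duration,
    so anything past that point maps to 0.
    """
    if duration == -1:
        return -1
    overshoot = time.time() - (ban_start + duration)
    return 0 if overshoot > 0 else int(abs(overshoot))


# e.g. a 600s ban issued 200s ago still has ~400s left:
# remaining_ban_seconds(time.time() - 200, 600) -> 400
```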
def check_plugin_type(matcher: Matcher) -> bool:
@@ -66,22 +145,22 @@
    return True


def format_time(time: float) -> str:
def format_time(time_val: float) -> str:
    """格式化时间

    参数:
        time: ban时长
        time_val: ban时长

    返回:
        str: 格式化时间文本
    """
    if time == -1:
    if time_val == -1:
        return "∞"
    time = abs(int(time))
    if time < 60:
        time_str = f"{time!s} 秒"
    time_val = abs(int(time_val))
    if time_val < 60:
        time_str = f"{time_val!s} 秒"
    else:
        minute = int(time / 60)
        minute = int(time_val / 60)
        if minute > 60:
            hours = minute // 60
            minute %= 60
@@ -91,66 +170,132 @@ def format_time(time: float) -> str:
    return time_str


async def group_handle(group_id: str):
async def group_handle(group_id: str) -> None:
    """群组ban检查

    参数:
        ban_dao: BanConsole数据访问对象
        group_id: 群组id

    异常:
        SkipPluginException: 群组处于黑名单
    """
    if await is_ban(None, group_id):
        raise SkipPluginException("群组处于黑名单中...")
    start_time = time.time()
    try:
        if await is_ban(None, group_id):
            raise SkipPluginException("群组处于黑名单中...")
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"group_handle 耗时: {elapsed:.3f}s",
                LOGGER_COMMAND,
                group_id=group_id,
            )


async def user_handle(module: str, entity: EntityIDs, session: Uninfo):
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
    """用户ban检查

    参数:
        module: 插件模块名
        ban_dao: BanConsole数据访问对象
        entity: 实体ID信息
        session: Uninfo

    异常:
        SkipPluginException: 用户处于黑名单
    """
    ban_result = Config.get_config("hook", "BAN_RESULT")
    time = await is_ban(entity.user_id, entity.group_id)
    if not time:
        return
    time_str = format_time(time)
    plugin_dao = DataAccess(PluginInfo)
    db_plugin = await plugin_dao.safe_get_or_none(module=module)
    if (
        db_plugin
        and not db_plugin.ignore_prompt
        and time != -1
        and ban_result
        and freq.is_send_limit_message(db_plugin, entity.user_id, False)
    ):
        await send_message(
            session,
            [
                At(flag="user", target=entity.user_id),
                f"{ban_result}\n在..在 {time_str} 后才会理你喔",
            ],
            entity.user_id,
        )
    raise SkipPluginException("用户处于黑名单中...")
    start_time = time.time()
    try:
        ban_result = Config.get_config("hook", "BAN_RESULT")
        time_val = await is_ban(entity.user_id, entity.group_id)
        if not time_val:
            return
        time_str = format_time(time_val)
        plugin_dao = DataAccess(PluginInfo)
        try:
            db_plugin = await asyncio.wait_for(
                plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError:
            logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
            # 超时时不阻塞,继续执行
            raise SkipPluginException("用户处于黑名单中...")

        if (
            db_plugin
            and not db_plugin.ignore_prompt
            and time_val != -1
            and ban_result
            and freq.is_send_limit_message(db_plugin, entity.user_id, False)
        ):
            try:
                await asyncio.wait_for(
                    send_message(
                        session,
                        [
                            At(flag="user", target=entity.user_id),
                            f"{ban_result}\n在..在 {time_str} 后才会理你喔",
                        ],
                        entity.user_id,
                    ),
                    timeout=DB_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"发送消息超时: {entity.user_id}", LOGGER_COMMAND)
        raise SkipPluginException("用户处于黑名单中...")
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"user_handle 耗时: {elapsed:.3f}s",
                LOGGER_COMMAND,
                session=session,
            )


async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
    if not check_plugin_type(matcher):
        return
    if not matcher.plugin_name:
        return
    entity = get_entity_ids(session)
    if entity.user_id in bot.config.superusers:
        return
    if entity.group_id:
        await group_handle(entity.group_id)
    if entity.user_id:
        await user_handle(matcher.plugin_name, entity, session)
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
    """权限检查 - ban 检查

    参数:
        matcher: Matcher
        bot: Bot
        session: Uninfo
    """
    start_time = time.time()
    try:
        if not check_plugin_type(matcher):
            return
        if not matcher.plugin_name:
            return
        entity = get_entity_ids(session)
        if entity.user_id in bot.config.superusers:
            return
        if entity.group_id:
            try:
                await asyncio.wait_for(
                    group_handle(entity.group_id), timeout=DB_TIMEOUT_SECONDS
                )
            except asyncio.TimeoutError:
                logger.error(f"群组ban检查超时: {entity.group_id}", LOGGER_COMMAND)
                # 超时时不阻塞,继续执行

        if entity.user_id:
            try:
                await asyncio.wait_for(
                    user_handle(matcher.plugin_name, entity, session),
                    timeout=DB_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"用户ban检查超时: {entity.user_id}", LOGGER_COMMAND)
                # 超时时不阻塞,继续执行
    finally:
        # 记录总执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"auth_ban 总耗时: {elapsed:.3f}s, plugin={matcher.plugin_name}",
                LOGGER_COMMAND,
                session=session,
            )
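Editor's note: every check in this file now wraps its await in `asyncio.wait_for` and, on timeout, logs and continues rather than raising. A generic sketch of that fail-open guard, with illustrative names (`guard`, the logger setup), trading strict enforcement for availability under database stalls:

```python
import asyncio
import logging
from collections.abc import Awaitable
from typing import TypeVar

T = TypeVar("T")
logger = logging.getLogger("AuthChecker")


async def guard(aw: Awaitable[T], timeout: float, default: T, what: str) -> T:
    """Run one check with a deadline; on timeout, log and fall back to a default.

    This is the fail-open shape used throughout auth_ban: a slow database
    must not freeze message handling, so a timeout degrades to "allow".
    """
    try:
        return await asyncio.wait_for(aw, timeout=timeout)
    except asyncio.TimeoutError:
        logger.error("%s timed out after %.1fs", what, timeout)
        return default
```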
@@ -1,8 +1,14 @@
import asyncio
import time

from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils

from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException


@@ -17,11 +23,33 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
        SkipPluginException: 忽略插件
        SkipPluginException: 忽略插件
    """
    bot_dao = DataAccess(BotConsole)
    bot = await bot_dao.safe_get_or_none(bot_id=bot_id)
    if not bot or not bot.status:
        raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
    if CommonUtils.format(plugin.module) in bot.block_plugins:
        raise SkipPluginException(
            f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
        )
    start_time = time.time()

    try:
        # 从数据库或缓存中获取 bot 信息
        bot_dao = DataAccess(BotConsole)

        try:
            bot: BotConsole | None = await asyncio.wait_for(
                bot_dao.safe_get_or_none(bot_id=bot_id), timeout=DB_TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError:
            logger.error(f"查询Bot信息超时: bot_id={bot_id}", LOGGER_COMMAND)
            # 超时时不阻塞,继续执行
            return

        if not bot or not bot.status:
            raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
        if CommonUtils.format(plugin.module) in bot.block_plugins:
            raise SkipPluginException(
                f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
            )
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"auth_bot 耗时: {elapsed:.3f}s, "
                f"bot_id={bot_id}, plugin={plugin.module}",
                LOGGER_COMMAND,
            )
@@ -1,8 +1,12 @@
import time

from nonebot_plugin_uninfo import Uninfo

from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger

from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import send_message

@@ -18,8 +22,20 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
    返回:
        int: 需要消耗的金币
    """
    if user.gold < plugin.cost_gold:
        """插件消耗金币不足"""
        await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
        raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
    return plugin.cost_gold
    start_time = time.time()

    try:
        if user.gold < plugin.cost_gold:
            """插件消耗金币不足"""
            await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
            raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
        return plugin.cost_gold
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"auth_cost 耗时: {elapsed:.3f}s, plugin={plugin.module}",
                LOGGER_COMMAND,
                session=session,
            )
@@ -1,11 +1,16 @@
import asyncio
import time

from nonebot_plugin_alconna import UniMsg

from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import EntityIDs

from .config import SwitchEnum
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
from .exception import SkipPluginException


@@ -17,21 +22,47 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
        entity: EntityIDs
        message: UniMsg
    """
    start_time = time.time()

    if not entity.group_id:
        return
    text = message.extract_plain_text()
    group_dao = DataAccess(GroupConsole)
    group = await group_dao.safe_get_or_none(
        group_id=entity.group_id, channel_id__isnull=True
    )
    if not group:
        raise SkipPluginException("群组信息不存在...")
    if group.level < 0:
        raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
    if text.strip() != SwitchEnum.ENABLE and not group.status:
        raise SkipPluginException("群组休眠状态...")
    if plugin.level > group.level:
        raise SkipPluginException(
            f"{plugin.name}({plugin.module}) 群等级限制,"
            f"该功能需要的群等级: {plugin.level}..."
        )

    try:
        text = message.extract_plain_text()

        # 从数据库或缓存中获取群组信息
        group_dao = DataAccess(GroupConsole)

        try:
            group: GroupConsole | None = await asyncio.wait_for(
                group_dao.safe_get_or_none(
                    group_id=entity.group_id, channel_id__isnull=True
                ),
                timeout=DB_TIMEOUT_SECONDS,
            )
        except asyncio.TimeoutError:
            logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
            # 超时时不阻塞,继续执行
            return

        if not group:
            raise SkipPluginException("群组信息不存在...")
        if group.level < 0:
            raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
        if text.strip() != SwitchEnum.ENABLE and not group.status:
            raise SkipPluginException("群组休眠状态...")
        if plugin.level > group.level:
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 群等级限制,"
                f"该功能需要的群等级: {plugin.level}..."
            )
    finally:
        # 记录执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
                LOGGER_COMMAND,
                session=entity.user_id,
                group_id=entity.group_id,
            )
@@ -1,3 +1,5 @@
import asyncio
import time
from typing import ClassVar

import nonebot
@@ -6,6 +8,7 @@ from pydantic import BaseModel

from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
@@ -17,7 +20,7 @@ from zhenxun.utils.utils import (
    get_entity_ids,
)

from .config import LOGGER_COMMAND
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException

driver = nonebot.get_driver()
@@ -39,17 +42,61 @@ class Limit(BaseModel):

class LimitManager:
    add_module: ClassVar[list] = []
    last_update_time: ClassVar[float] = 0
    update_interval: ClassVar[float] = 6000  # 1小时更新一次
    is_updating: ClassVar[bool] = False  # 防止并发更新

    cd_limit: ClassVar[dict[str, Limit]] = {}
    block_limit: ClassVar[dict[str, Limit]] = {}
    count_limit: ClassVar[dict[str, Limit]] = {}

    # 模块限制缓存,避免频繁查询数据库
    module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
    module_cache_ttl: ClassVar[float] = 60  # 模块缓存有效期(秒)

    @classmethod
    async def init_limit(cls):
        """初始化限制"""
        limit_list = await PluginLimit.filter(status=True).all()
        for limit in limit_list:
            cls.add_limit(limit)
        cls.last_update_time = time.time()
        try:
            await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
        except asyncio.TimeoutError:
            logger.error("初始化限制超时", LOGGER_COMMAND)

    @classmethod
    async def update_limits(cls):
        """更新限制信息"""
        # 防止并发更新
        if cls.is_updating:
            return

        cls.is_updating = True
        try:
            start_time = time.time()
            try:
                limit_list = await asyncio.wait_for(
                    PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
                )
            except asyncio.TimeoutError:
                logger.error("查询限制信息超时", LOGGER_COMMAND)
                cls.is_updating = False
                return

            # 清空旧数据
            cls.add_module = []
            cls.cd_limit = {}
            cls.block_limit = {}
            cls.count_limit = {}
            # 添加新数据
            for limit in limit_list:
                cls.add_limit(limit)

            cls.last_update_time = time.time()
            elapsed = time.time() - start_time
            if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的更新
                logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
        finally:
            cls.is_updating = False

    @classmethod
    def add_limit(cls, limit: PluginLimit):
@@ -99,6 +146,46 @@ class LimitManager:
        )
        limiter.set_false(key_type)

    @classmethod
    async def get_module_limits(cls, module: str) -> list[PluginLimit]:
        """获取模块的限制信息,使用缓存减少数据库查询

        参数:
            module: 模块名

        返回:
            list[PluginLimit]: 限制列表
        """
        current_time = time.time()

        # 检查缓存
        if module in cls.module_limit_cache:
            cache_time, limits = cls.module_limit_cache[module]
            if current_time - cache_time < cls.module_cache_ttl:
                return limits

        # 缓存不存在或已过期,从数据库查询
        try:
            start_time = time.time()
            limits = await asyncio.wait_for(
                PluginLimit.filter(module=module, status=True).all(),
                timeout=DB_TIMEOUT_SECONDS,
            )
            elapsed = time.time() - start_time
            if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的查询
                logger.warning(
                    f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
                    LOGGER_COMMAND,
                )

            # 更新缓存
            cls.module_limit_cache[module] = (current_time, limits)
            return limits
        except asyncio.TimeoutError:
            logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
            # 超时时返回空列表,避免阻塞
            return []

    @classmethod
    async def check(
        cls,
@@ -118,12 +205,40 @@ class LimitManager:
        异常:
            IgnoredException: IgnoredException
        """
        if limit_model := cls.cd_limit.get(module):
            await cls.__check(limit_model, user_id, group_id, channel_id)
        if limit_model := cls.block_limit.get(module):
            await cls.__check(limit_model, user_id, group_id, channel_id)
        if limit_model := cls.count_limit.get(module):
            await cls.__check(limit_model, user_id, group_id, channel_id)
        start_time = time.time()

        # 定期更新全局限制信息
        if (
            time.time() - cls.last_update_time > cls.update_interval
            and not cls.is_updating
        ):
            # 使用异步任务更新,避免阻塞当前请求
            asyncio.create_task(cls.update_limits())  # noqa: RUF006

        # 如果模块不在已加载列表中,只加载该模块的限制
        if module not in cls.add_module:
            limits = await cls.get_module_limits(module)
            for limit in limits:
                cls.add_limit(limit)

        # 检查各种限制
        try:
            if limit_model := cls.cd_limit.get(module):
                await cls.__check(limit_model, user_id, group_id, channel_id)
            if limit_model := cls.block_limit.get(module):
                await cls.__check(limit_model, user_id, group_id, channel_id)
            if limit_model := cls.count_limit.get(module):
                await cls.__check(limit_model, user_id, group_id, channel_id)
        finally:
            # 记录总执行时间
            elapsed = time.time() - start_time
            if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
                logger.warning(
                    f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
                    LOGGER_COMMAND,
                    session=user_id,
                    group_id=group_id,
                )

    @classmethod
    async def __check(
@@ -158,7 +273,13 @@ class LimitManager:
        key_type = channel_id or group_id
        if is_limit and not limiter.check(key_type):
            if limit.result:
                await MessageUtils.build_message(limit.result).send()
                try:
                    await asyncio.wait_for(
                        MessageUtils.build_message(limit.result).send(),
                        timeout=DB_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
            raise SkipPluginException(
                f"{limit.module}({limit.limit_type}) 正在限制中..."
            )
@@ -185,11 +306,13 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
        session: Uninfo
    """
    entity = get_entity_ids(session)
    if plugin.module not in LimitManager.add_module:
        limit_list = await PluginLimit.filter(module=plugin.module, status=True).all()
        for limit in limit_list:
            LimitManager.add_limit(limit)
    if entity.user_id:
        await LimitManager.check(
            plugin.module, entity.user_id, entity.group_id, entity.channel_id
        try:
            await asyncio.wait_for(
                LimitManager.check(
                    plugin.module, entity.user_id, entity.group_id, entity.channel_id
                ),
                timeout=DB_TIMEOUT_SECONDS * 2,  # 给予更长的超时时间
            )
        except asyncio.TimeoutError:
            logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
            # 超时时不抛出异常,允许继续执行
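Editor's note: `get_module_limits` is a plain TTL cache, keeping a `(timestamp, rows)` pair per module and re-querying only once the entry is older than `module_cache_ttl`. A minimal sketch of that pattern, where `fetch` stands in for the `PluginLimit` query:

```python
import time
from typing import Any

_cache: dict[str, tuple[float, list[Any]]] = {}
TTL = 60.0  # matches module_cache_ttl in the diff


async def cached_limits(module: str, fetch) -> list[Any]:
    """Return cached rows for `module`, re-fetching once the entry exceeds TTL."""
    now = time.time()
    if module in _cache:
        stamp, rows = _cache[module]
        if now - stamp < TTL:
            return rows  # fresh enough, skip the database entirely
    rows = await fetch(module)  # `fetch` stands in for the PluginLimit query
    _cache[module] = (now, rows)
    return rows
```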
@@ -1,13 +1,19 @@
import asyncio
import time

from nonebot.adapters import Event
from nonebot_plugin_uninfo import Uninfo

from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType
from zhenxun.utils.utils import get_entity_ids

from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import IsSuperuserException, SkipPluginException
from .utils import freq, is_poke, send_message

@@ -21,69 +27,89 @@ class GroupCheck:
        self.is_poke = is_poke
        self.plugin = plugin
        self.group_dao = DataAccess(GroupConsole)

    async def __get_data(self):
        return await self.group_dao.safe_get_or_none(
            group_id=self.group_id, channel_id__isnull=True
        )
        self.group_data = None

    async def check(self):
        await self.check_superuser_block(self.plugin)

    async def check_superuser_block(self, plugin: PluginInfo):
        """超级用户禁用群组插件检测

        参数:
            plugin: PluginInfo

        异常:
            IgnoredException: 忽略插件
        """
        group = await self.__get_data()
        if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
            if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
                await send_message(
                    self.session, "超级管理员禁用了该群此功能...", self.group_id
        start_time = time.time()
        try:
            # 只查询一次数据库,使用 DataAccess 的缓存机制
            try:
                self.group_data = await asyncio.wait_for(
                    self.group_dao.safe_get_or_none(
                        group_id=self.group_id, channel_id__isnull=True
                    ),
                    timeout=DB_TIMEOUT_SECONDS,
                )
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..."
            )
        await self.check_normal_block(self.plugin)
            except asyncio.TimeoutError:
                logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
                return  # 超时时不阻塞,继续执行

    async def check_normal_block(self, plugin: PluginInfo):
        """群组插件状态

        参数:
            plugin: PluginInfo

        异常:
            IgnoredException: 忽略插件
        """
        group = await self.__get_data()
        if group and CommonUtils.format(plugin.module) in group.block_plugin:
            if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
                await send_message(self.session, "该群未开启此功能...", self.group_id)
            raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...")
        await self.check_global_block(self.plugin)

    async def check_global_block(self, plugin: PluginInfo):
        """全局禁用插件检测

        参数:
            plugin: PluginInfo

        异常:
            IgnoredException: 忽略插件
        """
        if plugin.block_type == BlockType.GROUP:
            """全局群组禁用"""
            if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
                await send_message(
                    self.session, "该功能在群组中已被禁用...", self.group_id
            # 检查超级用户禁用
            if (
                self.group_data
                and CommonUtils.format(self.plugin.module)
                in self.group_data.superuser_block_plugin
            ):
                if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
                    try:
                        await asyncio.wait_for(
                            send_message(
                                self.session,
                                "超级管理员禁用了该群此功能...",
                                self.group_id,
                            ),
                            timeout=DB_TIMEOUT_SECONDS,
                        )
                    except asyncio.TimeoutError:
                        logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
                raise SkipPluginException(
                    f"{self.plugin.name}({self.plugin.module})"
                    f" 超级管理员禁用了该群此功能..."
                )

            # 检查普通禁用
            if (
                self.group_data
                and CommonUtils.format(self.plugin.module)
                in self.group_data.block_plugin
            ):
                if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
                    try:
                        await asyncio.wait_for(
                            send_message(
                                self.session, "该群未开启此功能...", self.group_id
                            ),
                            timeout=DB_TIMEOUT_SECONDS,
                        )
                    except asyncio.TimeoutError:
                        logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
                raise SkipPluginException(
                    f"{self.plugin.name}({self.plugin.module}) 未开启此功能..."
                )

            # 检查全局禁用
            if self.plugin.block_type == BlockType.GROUP:
                if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
                    try:
                        await asyncio.wait_for(
                            send_message(
                                self.session, "该功能在群组中已被禁用...", self.group_id
                            ),
                            timeout=DB_TIMEOUT_SECONDS,
                        )
                    except asyncio.TimeoutError:
                        logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
                raise SkipPluginException(
                    f"{self.plugin.name}({self.plugin.module})该插件在群组中已被禁用..."
                )
        finally:
            # 记录执行时间
            elapsed = time.time() - start_time
            if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
                logger.warning(
                    f"GroupCheck.check 耗时: {elapsed:.3f}s, 群组: {self.group_id}",
                    LOGGER_COMMAND,
                )
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..."
            )


class PluginCheck:
@@ -92,6 +118,7 @@ class PluginCheck:
        self.is_poke = is_poke
        self.group_id = group_id
        self.group_dao = DataAccess(GroupConsole)
        self.group_data = None

    async def check_user(self, plugin: PluginInfo):
        """全局私聊禁用检测
@@ -104,7 +131,13 @@ class PluginCheck:
        """
        if plugin.block_type == BlockType.PRIVATE:
            if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
                await send_message(self.session, "该功能在私聊中已被禁用...")
                try:
                    await asyncio.wait_for(
                        send_message(self.session, "该功能在私聊中已被禁用..."),
                        timeout=DB_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    logger.error("发送消息超时", LOGGER_COMMAND)
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
            )
@@ -118,19 +151,46 @@ class PluginCheck:
        异常:
            IgnoredException: 忽略插件
        """
        if plugin.status or plugin.block_type != BlockType.ALL:
            return
        """全局状态"""
        if self.group_id:
            group = await self.group_dao.safe_get_or_none(
                group_id=self.group_id, channel_id__isnull=True
        start_time = time.time()
        try:
            if plugin.status or plugin.block_type != BlockType.ALL:
                return
            """全局状态"""
            if self.group_id:
                # 使用 DataAccess 的缓存机制
                try:
                    self.group_data = await asyncio.wait_for(
                        self.group_dao.safe_get_or_none(
                            group_id=self.group_id, channel_id__isnull=True
                        ),
                        timeout=DB_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
                    return  # 超时时不阻塞,继续执行

            if self.group_data and self.group_data.is_super:
                raise IsSuperuserException()

            sid = self.group_id or self.session.user.id
            if freq.is_send_limit_message(plugin, sid, self.is_poke):
                try:
                    await asyncio.wait_for(
                        send_message(self.session, "全局未开启此功能...", sid),
                        timeout=DB_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    logger.error(f"发送消息超时: {sid}", LOGGER_COMMAND)
            raise SkipPluginException(
                f"{plugin.name}({plugin.module}) 全局未开启此功能..."
            )
        if group and group.is_super:
            raise IsSuperuserException()
        sid = self.group_id or self.session.user.id
        if freq.is_send_limit_message(plugin, sid, self.is_poke):
            await send_message(self.session, "全局未开启此功能...", sid)
        raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...")
        finally:
            # 记录执行时间
            elapsed = time.time() - start_time
            if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
                logger.warning(
                    f"PluginCheck.check_global 耗时: {elapsed:.3f}s", LOGGER_COMMAND
                )


async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
@@ -141,12 +201,42 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
        session: Uninfo
        event: Event
    """
    entity = get_entity_ids(session)
    is_poke_event = is_poke(event)
    user_check = PluginCheck(entity.group_id, session, is_poke_event)
    if entity.group_id:
        group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
        await group_check.check()
    else:
        await user_check.check_user(plugin)
    await user_check.check_global(plugin)
    start_time = time.time()
    try:
        entity = get_entity_ids(session)
        is_poke_event = is_poke(event)
        user_check = PluginCheck(entity.group_id, session, is_poke_event)

        if entity.group_id:
            group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
            try:
                await asyncio.wait_for(
                    group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
                )
            except asyncio.TimeoutError:
                logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
                # 超时时不阻塞,继续执行
        else:
            try:
                await asyncio.wait_for(
                    user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
                )
            except asyncio.TimeoutError:
                logger.error("用户检查超时", LOGGER_COMMAND)
                # 超时时不阻塞,继续执行

        try:
            await asyncio.wait_for(
                user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError:
            logger.error("全局检查超时", LOGGER_COMMAND)
            # 超时时不阻塞,继续执行
    finally:
        # 记录总执行时间
        elapsed = time.time() - start_time
        if elapsed > WARNING_THRESHOLD:  # 记录耗时超过500ms的检查
            logger.warning(
                f"auth_plugin 总耗时: {elapsed:.3f}s, 模块: {plugin.module}",
                LOGGER_COMMAND,
            )
@@ -11,3 +11,6 @@ LOGGER_COMMAND = "AuthChecker"
class SwitchEnum(StrEnum):
    ENABLE = "醒来"
    DISABLE = "休息吧"


WARNING_THRESHOLD = 0.5  # 警告阈值(秒)
@@ -1,4 +1,5 @@
import asyncio
import time

from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException
@@ -24,13 +25,88 @@ from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManager, auth_limit
from .auth.auth_plugin import auth_plugin
from .auth.bot_filter import bot_filter
from .auth.config import LOGGER_COMMAND
from .auth.config import LOGGER_COMMAND, WARNING_THRESHOLD
from .auth.exception import (
    IsSuperuserException,
    PermissionExemption,
    SkipPluginException,
)

# 超时设置(秒)
TIMEOUT_SECONDS = 5.0
# 熔断计数器
CIRCUIT_BREAKERS = {
    "auth_ban": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
    "auth_bot": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
    "auth_group": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
    "auth_admin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
    "auth_plugin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
    "auth_limit": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
}
# 熔断重置时间(秒)
CIRCUIT_RESET_TIME = 300  # 5分钟


# 超时装饰器
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
    """带超时控制的协程执行

    参数:
        coro: 要执行的协程
        timeout: 超时时间(秒)
        name: 操作名称,用于日志记录

    返回:
        协程的返回值,或者在超时时抛出 TimeoutError
    """
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        if name:
            logger.error(f"{name} 操作超时 (>{timeout}s)", LOGGER_COMMAND)
        # 更新熔断计数器
        if name in CIRCUIT_BREAKERS:
            CIRCUIT_BREAKERS[name]["failures"] += 1
            if (
                CIRCUIT_BREAKERS[name]["failures"]
                >= CIRCUIT_BREAKERS[name]["threshold"]
                and not CIRCUIT_BREAKERS[name]["active"]
            ):
                CIRCUIT_BREAKERS[name]["active"] = True
                CIRCUIT_BREAKERS[name]["reset_time"] = (
                    time.time() + CIRCUIT_RESET_TIME
                )
                logger.warning(
                    f"{name} 熔断器已激活,将在 {CIRCUIT_RESET_TIME} 秒后重置",
                    LOGGER_COMMAND,
                )
        raise


# 检查熔断状态
def check_circuit_breaker(name):
    """检查熔断器状态

    参数:
        name: 操作名称

    返回:
        bool: 是否已熔断
    """
    if name not in CIRCUIT_BREAKERS:
        return False

    # 检查是否需要重置熔断器
    if (
        CIRCUIT_BREAKERS[name]["active"]
        and time.time() > CIRCUIT_BREAKERS[name]["reset_time"]
    ):
        CIRCUIT_BREAKERS[name]["active"] = False
        CIRCUIT_BREAKERS[name]["failures"] = 0
        logger.info(f"{name} 熔断器已重置", LOGGER_COMMAND)

    return CIRCUIT_BREAKERS[name]["active"]
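Editor's note: the two helpers above form a basic circuit breaker: repeated timeouts trip it, and it auto-resets after `CIRCUIT_RESET_TIME`. A hedged usage sketch, assuming the definitions above are in scope (`slow_check` and the 0.1s timeout are illustrative):

```python
import asyncio


async def slow_check() -> None:
    await asyncio.sleep(10)  # stands in for a stalled database call


async def run_guarded() -> None:
    # Skip the hook entirely while its breaker is active...
    if check_circuit_breaker("auth_ban"):
        return
    try:
        # ...otherwise run it; repeated timeouts trip the breaker,
        # and CIRCUIT_RESET_TIME later it closes again on its own.
        await with_timeout(slow_check(), timeout=0.1, name="auth_ban")
    except asyncio.TimeoutError:
        pass  # already logged and counted inside with_timeout
```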
async def get_plugin_and_user(
    module: str, user_id: str
@@ -52,7 +128,25 @@ async def get_plugin_and_user(
    """
    user_dao = DataAccess(UserConsole)
    plugin_dao = DataAccess(PluginInfo)
    plugin = await plugin_dao.safe_get_or_none(module=module)

    # 并行查询插件和用户数据
    plugin_task = plugin_dao.safe_get_or_none(module=module)
    user_task = user_dao.safe_get_or_none(user_id=user_id)

    try:
        plugin, user = await with_timeout(
            asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
        )
    except asyncio.TimeoutError:
        # 如果并行查询超时,尝试串行查询
        logger.warning("并行查询超时,尝试串行查询", LOGGER_COMMAND)
        plugin = await with_timeout(
            plugin_dao.safe_get_or_none(module=module), name="get_plugin"
        )
        user = await with_timeout(
            user_dao.safe_get_or_none(user_id=user_id), name="get_user"
        )

    if not plugin:
        raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
    if plugin.plugin_type == PluginType.HIDDEN:
@@ -87,7 +181,7 @@ async def get_plugin_cost(
    返回:
        int: 调用插件金币费用
    """
    cost_gold = await auth_cost(user, plugin, session)
    cost_gold = await with_timeout(auth_cost(user, plugin, session), name="auth_cost")
    if session.user.id in bot.config.superusers:
        if plugin.plugin_type == PluginType.SUPERUSER:
            raise IsSuperuserException()
@@ -107,22 +201,51 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo
    """
    user_dao = DataAccess(UserConsole)
    try:
        await UserConsole.reduce_gold(
            user_id,
            cost_gold,
            GoldHandle.PLUGIN,
            module,
            PlatformUtils.get_platform(session),
        await with_timeout(
            UserConsole.reduce_gold(
                user_id,
                cost_gold,
                GoldHandle.PLUGIN,
                module,
                PlatformUtils.get_platform(session),
            ),
            name="reduce_gold",
        )
    except InsufficientGold:
        if u := await UserConsole.get_user(user_id):
            u.gold = 0
            await u.save(update_fields=["gold"])
    except asyncio.TimeoutError:
        logger.error(
            f"扣除金币超时,用户: {user_id}, 金币: {cost_gold}",
            LOGGER_COMMAND,
            session=session,
        )

    # 清除缓存,使下次查询时从数据库获取最新数据
    await user_dao.clear_cache(user_id=user_id)
    logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)


# 辅助函数,用于记录每个 hook 的执行时间
async def time_hook(coro, name, time_dict):
    start = time.time()
    try:
        # 检查熔断状态
        if check_circuit_breaker(name):
            logger.info(f"{name} 熔断器激活中,跳过执行", LOGGER_COMMAND)
            time_dict[name] = "熔断跳过"
            return

        # 添加超时控制
        return await with_timeout(coro, name=name)
    except asyncio.TimeoutError:
        time_dict[name] = f"超时 (>{TIMEOUT_SECONDS}s)"
    finally:
        if name not in time_dict:
            time_dict[name] = f"{time.time() - start:.3f}s"


async def auth(
    matcher: Matcher,
    event: Event,
@@ -139,26 +262,81 @@ async def auth(
    session: Uninfo
    message: UniMsg
    """
    start_time = time.time()
    cost_gold = 0
    ignore_flag = False
    entity = get_entity_ids(session)
    module = matcher.plugin_name or ""

    # 用于记录各个 hook 的执行时间
    hook_times = {}
    hooks_time = 0  # 初始化 hooks_time 变量

    try:
        if not module:
            raise PermissionExemption("Matcher插件名称不存在...")
        plugin, user = await get_plugin_and_user(module, entity.user_id)
        cost_gold = await get_plugin_cost(bot, user, plugin, session)

        # 获取插件和用户数据
        plugin_user_start = time.time()
        try:
            plugin, user = await with_timeout(
                get_plugin_and_user(module, entity.user_id), name="get_plugin_and_user"
            )
            hook_times["get_plugin_user"] = f"{time.time() - plugin_user_start:.3f}s"
        except asyncio.TimeoutError:
            logger.error(
                f"获取插件和用户数据超时,模块: {module}",
                LOGGER_COMMAND,
                session=session,
            )
            raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")

        # 获取插件费用
        cost_start = time.time()
        try:
            cost_gold = await with_timeout(
                get_plugin_cost(bot, user, plugin, session), name="get_plugin_cost"
            )
            hook_times["cost_gold"] = f"{time.time() - cost_start:.3f}s"
        except asyncio.TimeoutError:
            logger.error(
                f"获取插件费用超时,模块: {module}", LOGGER_COMMAND, session=session
            )
            # 继续执行,不阻止权限检查

        # 执行 bot_filter
        bot_filter(session)
        await asyncio.gather(
            *[
                auth_ban(matcher, bot, session),
                auth_bot(plugin, bot.self_id),
                auth_group(plugin, entity, message),
                auth_admin(plugin, session),
                auth_plugin(plugin, session, event),
                auth_limit(plugin, session),
            ]
        )

        # 并行执行所有 hook 检查,并记录执行时间
        hooks_start = time.time()

        # 创建所有 hook 任务
        hook_tasks = [
            time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
            time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
            time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
            time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
            time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
            time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
        ]

        # 使用 gather 并行执行所有 hook,但添加总体超时控制
        try:
            await with_timeout(
                asyncio.gather(*hook_tasks),
                timeout=TIMEOUT_SECONDS * 2,  # 给总体执行更多时间
                name="auth_hooks_gather",
            )
        except asyncio.TimeoutError:
            logger.error(
                f"权限检查 hooks 总体执行超时,模块: {module}",
                LOGGER_COMMAND,
                session=session,
            )
            # 不抛出异常,允许继续执行

        hooks_time = time.time() - hooks_start

    except SkipPluginException as e:
        LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
        logger.info(str(e), LOGGER_COMMAND, session=session)
@@ -167,7 +345,31 @@ async def auth(
        logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
    except PermissionExemption as e:
        logger.info(str(e), LOGGER_COMMAND, session=session)

    # 扣除金币
    if not ignore_flag and cost_gold > 0:
        await reduce_gold(entity.user_id, module, cost_gold, session)
        gold_start = time.time()
        try:
            await with_timeout(
                reduce_gold(entity.user_id, module, cost_gold, session),
                name="reduce_gold",
            )
            hook_times["reduce_gold"] = f"{time.time() - gold_start:.3f}s"
        except asyncio.TimeoutError:
            logger.error(
                f"扣除金币超时,模块: {module}", LOGGER_COMMAND, session=session
            )

    # 记录总执行时间
    total_time = time.time() - start_time
    if total_time > WARNING_THRESHOLD:  # 如果总时间超过500ms,记录详细信息
        logger.warning(
            f"权限检查耗时过长: {total_time:.3f}s, 模块: {module}, "
            f"hooks时间: {hooks_time:.3f}s, "
            f"详情: {hook_times}",
            LOGGER_COMMAND,
            session=session,
        )

    if ignore_flag:
        raise IgnoredException("权限检测 ignore")
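Editor's note: `time_hook` annotates each hook's outcome in a shared dict, which is what the slow-path warning prints. A hedged driver sketch, assuming `time_hook` above is in scope (the sleeps merely stand in for real hooks):

```python
# Each entry ends up as a duration, "超时 (>5.0s)", or "熔断跳过".
import asyncio


async def demo() -> None:
    times: dict[str, str] = {}
    await asyncio.gather(
        time_hook(asyncio.sleep(0.01), "fast_hook", times),
        time_hook(asyncio.sleep(10), "stuck_hook", times),
    )
    print(times)  # e.g. {"fast_hook": "0.010s", "stuck_hook": "超时 (>5.0s)"}
```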
@@ -1,3 +1,5 @@
import asyncio

import aiofiles
import nonebot
from nonebot import get_loaded_plugins
@@ -112,24 +114,29 @@ async def _():
    await _handle_setting(plugin, plugin_list, limit_list)
    create_list = []
    update_list = []
    update_task_list = []
    for plugin in plugin_list:
        if plugin.module_path not in module2id:
            create_list.append(plugin)
        else:
            plugin.id = module2id[plugin.module_path]
            await plugin.save(
                update_fields=[
                    "name",
                    "author",
                    "version",
                    "admin_level",
                    "plugin_type",
                    "is_show",
                ]
            update_task_list.append(
                plugin.save(
                    update_fields=[
                        "name",
                        "author",
                        "version",
                        "admin_level",
                        "plugin_type",
                        "is_show",
                    ]
                )
            )
            update_list.append(plugin)
    if create_list:
        await PluginInfo.bulk_create(create_list, 10)
    if update_task_list:
        await asyncio.gather(*update_task_list)
    # if update_list:
    #     # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
    #     column "superuser" does not exist
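Editor's note: the startup hook now collects the `save()` coroutines and awaits them together, so N row updates overlap instead of running serially. A hedged sketch of that shape; the semaphore cap and field list are assumptions, not from the commit:

```python
import asyncio


async def save_all(objects: list) -> None:
    """Fire all .save() calls concurrently instead of awaiting each in the loop.

    With N independent row updates this costs roughly one round of latency
    instead of N; a semaphore caps concurrency if the connection pool is small.
    """
    sem = asyncio.Semaphore(10)  # cap chosen arbitrarily for this sketch

    async def one(obj):
        async with sem:
            await obj.save(update_fields=["name", "version"])

    await asyncio.gather(*(one(o) for o in objects))
```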
@@ -1,30 +0,0 @@
from zhenxun.models.group_console import GroupConsole
from zhenxun.utils.manager.priority_manager import PriorityLifecycle


@PriorityLifecycle.on_startup(priority=5)
async def _():
    """开启/禁用插件格式修改"""
    _, is_create = await GroupConsole.get_or_create(group_id=133133133)
    """标记"""
    if is_create:
        data_list = []
        for group in await GroupConsole.all():
            if group.block_plugin:
                if modules := group.block_plugin.split(","):
                    block_plugin = "".join(
                        (f"{module}," if module.startswith("<") else f"<{module},")
                        for module in modules
                        if module.strip()
                    )
                    group.block_plugin = block_plugin.replace("<,", "")
            if group.block_task:
                if modules := group.block_task.split(","):
                    block_task = "".join(
                        (f"{module}," if module.startswith("<") else f"<{module},")
                        for module in modules
                        if module.strip()
                    )
                    group.block_task = block_task.replace("<,", "")
            data_list.append(group)
        await GroupConsole.bulk_update(data_list, ["block_plugin", "block_task"], 10)
@@ -177,7 +177,9 @@ async def _(session: EventSession, arparma: Arparma, state: T_State):
async def _(session: EventSession, arparma: Arparma, state: T_State):
    gid = state["group_id"]
    await GroupConsole.update_or_create(
        group_id=gid, defaults={"group_flag": 0 if arparma.find("delete") else 1}
        group_id=gid,
        channel_id__isnull=True,
        defaults={"group_flag": 0 if arparma.find("delete") else 1},
    )
    s = "删除" if arparma.find("delete") else "添加"
    await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)
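Editor's note: the added `channel_id__isnull=True` matters because `(group_id, channel_id)` is the unique pair on `GroupConsole`; without it the upsert's lookup can match a channel-scoped row or insert a second group-level one. A hedged sketch (the group ID is illustrative):

```python
async def set_group_flag(delete: bool) -> None:
    # The channel filter is part of the lookup, not the defaults: it pins
    # the upsert to exactly the group-level record.
    await GroupConsole.update_or_create(
        group_id="123456",          # illustrative ID
        channel_id__isnull=True,
        defaults={"group_flag": 0 if delete else 1},
    )
```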
@@ -30,12 +30,13 @@ class BanConsole(Model):
        table = "ban_console"
        table_description = "封禁人员/群组数据表"
        unique_together = ("user_id", "group_id")
        indexes = [("user_id",), ("group_id",)]  # noqa: RUF012

    cache_type = CacheType.BAN
    """缓存类型"""
    cache_key_field = ("user_id", "group_id")
    """缓存键字段"""
    enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
    enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
    """开启锁"""

    @classmethod
@@ -199,3 +200,10 @@ class BanConsole(Model):
        if id is not None:
            return await cls.safe_get_or_none(id=id)
        return await cls._get_data(user_id, group_id)

    @classmethod
    async def _run_script(cls):
        return [
            "CREATE INDEX idx_ban_console_user_id ON ban_console(user_id);",
            "CREATE INDEX idx_ban_console_group_id ON ban_console(group_id);",
        ]
@@ -87,12 +87,15 @@ class GroupConsole(Model):
        table = "group_console"
        table_description = "群组信息表"
        unique_together = ("group_id", "channel_id")
        indexes = [  # noqa: RUF012
            ("group_id",)
        ]

    cache_type = CacheType.GROUPS
    """缓存类型"""
    cache_key_field = ("group_id", "channel_id")
    """缓存键字段"""
    enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
    enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
    """开启锁"""

    @classmethod
@@ -537,4 +540,6 @@ class GroupConsole(Model):
            " character varying(255) NOT NULL DEFAULT '';",
            "ALTER TABLE group_console ADD superuser_block_task"
            " character varying(255) NOT NULL DEFAULT '';",
            "CREATE INDEX idx_group_console_group_id ON group_console(group_id);",
            "CREATE INDEX idx_group_console_group_null_channel ON group_console(group_id) WHERE channel_id IS NULL;",  # 单独创建channel为空的索引  # noqa: E501
        ]
@@ -29,6 +29,7 @@ class UserConsole(Model):
    class Meta:  # pyright: ignore [reportIncompatibleVariableOverride]
        table = "user_console"
        table_description = "用户数据表"
        indexes = [("user_id",), ("uid",)]  # noqa: RUF012

    cache_type = CacheType.USERS
    """缓存类型"""
@@ -198,3 +199,10 @@ class UserConsole(Model):
        if goods := await GoodsInfo.get_or_none(goods_name=name):
            return await cls.use_props(user_id, goods.uuid, num, platform)
        raise GoodsNotFound("未找到商品...")

    @classmethod
    async def _run_script(cls):
        return [
            "CREATE INDEX idx_user_console_user_id ON user_console(user_id);",
            "CREATE INDEX idx_user_console_uid ON user_console(uid);",
        ]
zhenxun/services/cache/__init__.py (vendored, 31 changes)
@ -56,6 +56,7 @@ await message_list.save()
```
"""

import asyncio
from collections.abc import Callable
from datetime import datetime
from functools import wraps

@ -500,7 +501,8 @@ class CacheManager:
            logger.debug(f"清除所有 {cache_type} 缓存", LOG_COMMAND)
            return await self.clear(cache_type)
        except Exception as e:
            logger.error(f"清除缓存 {cache_type} 失败", LOG_COMMAND, e=e)
            if f"缓存类型 {cache_type} 不存在" not in str(e):
                logger.warning(f"清除缓存 {cache_type} 失败", LOG_COMMAND, e=e)
            return False

    async def get(
@ -516,13 +518,18 @@ class CacheManager:
        返回:
            Any: 缓存数据,如果不存在返回默认值
        """
        from zhenxun.services.db_context import DB_TIMEOUT_SECONDS

        # 如果缓存被禁用或缓存模式为NONE,直接返回默认值
        if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
            return default

        cache_key = None
        try:
            cache_key = self._build_key(cache_type, key)
            data = await self.cache_backend.get(cache_key)  # type: ignore
            data = await asyncio.wait_for(
                self.cache_backend.get(cache_key),  # type: ignore
                timeout=DB_TIMEOUT_SECONDS,
            )

            if data is None:
                return default
@ -534,6 +541,9 @@ class CacheManager:
            if model.result_type:
                return self._deserialize_value(data, model.result_type)
            return data
        except asyncio.TimeoutError:
            logger.error(f"获取缓存 {cache_type}:{cache_key} 超时", LOG_COMMAND)
            return default
        except Exception as e:
            logger.error(f"获取缓存 {cache_type} 失败", LOG_COMMAND, e=e)
            return default
@ -556,10 +566,12 @@ class CacheManager:
        返回:
            bool: 是否成功
        """
        from zhenxun.services.db_context import DB_TIMEOUT_SECONDS

        # 如果缓存被禁用或缓存模式为NONE,直接返回False
        if not self.enabled or cache_config.cache_mode == CacheMode.NONE:
            return False

        cache_key = None
        try:
            cache_key = self._build_key(cache_type, key)
            model = self.get_model(cache_type)
@ -571,8 +583,14 @@ class CacheManager:
            ttl = expire if expire is not None else model.expire

            # 设置缓存
            await self.cache_backend.set(cache_key, serialized_value, ttl=ttl)  # type: ignore
            await asyncio.wait_for(
                self.cache_backend.set(cache_key, serialized_value, ttl=ttl),  # type: ignore
                timeout=DB_TIMEOUT_SECONDS,
            )
            return True
        except asyncio.TimeoutError:
            logger.error(f"设置缓存 {cache_type}:{cache_key} 超时", LOG_COMMAND)
            return False
        except Exception as e:
            logger.error(f"设置缓存 {cache_type} 失败", LOG_COMMAND, e=e)
            return False
@ -647,7 +665,8 @@ class CacheManager:
            await self.cache_backend.clear()  # type: ignore
            return True
        except Exception as e:
            logger.error("清除缓存失败", LOG_COMMAND, e=e)
            if f"缓存类型 {cache_type} 不存在" not in str(e):
                logger.warning("清除缓存失败", LOG_COMMAND, e=e)
            return False

    async def close(self):
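The core trick in both `get` and `set` above is the same: every backend await is bounded by `asyncio.wait_for`, and a timeout degrades to a cache miss (or a failed set) instead of blocking the handler. A self-contained sketch of that behavior; the names below are illustrative, not the project's API:

```python
import asyncio

DB_TIMEOUT_SECONDS = 3.0  # same budget the commit applies to backend calls

async def slow_backend_get(key: str):
    await asyncio.sleep(10)  # simulate a wedged Redis/DB connection
    return f"value:{key}"

async def get_with_timeout(key: str, default=None):
    try:
        data = await asyncio.wait_for(slow_backend_get(key), timeout=DB_TIMEOUT_SECONDS)
        return default if data is None else data
    except asyncio.TimeoutError:
        return default  # degrade to a miss; never block the caller

print(asyncio.run(get_with_timeout("user:123", default="<miss>")))  # -> <miss>
```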
zhenxun/services/data_access.py

@ -1,11 +1,8 @@
from typing import Any, Generic, TypeVar, cast
from typing import Any, ClassVar, Generic, TypeVar, cast

from zhenxun.services.cache import Cache, CacheRoot, cache_config
from zhenxun.services.cache.config import (
    COMPOSITE_KEY_SEPARATOR,
    CacheMode,
)
from zhenxun.services.db_context import Model
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR, CacheMode
from zhenxun.services.db_context import Model, with_db_timeout
from zhenxun.services.log import logger

T = TypeVar("T", bound=Model)
@ -40,6 +37,34 @@ class DataAccess(Generic[T]):
    ```
    """

    # 添加缓存统计信息
    _cache_stats: ClassVar[dict] = {}
    # 空结果标记
    _NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"
    # 默认空结果缓存时间(秒)- 设置为5分钟,避免频繁查询数据库
    _NULL_RESULT_TTL = 300

    @classmethod
    def set_null_result_ttl(cls, seconds: int) -> None:
        """设置空结果缓存时间

        参数:
            seconds: 缓存时间(秒)
        """
        if seconds < 0:
            raise ValueError("缓存时间不能为负数")
        cls._NULL_RESULT_TTL = seconds
        logger.info(f"已设置DataAccess空结果缓存时间为 {seconds} 秒")

    @classmethod
    def get_null_result_ttl(cls) -> int:
        """获取空结果缓存时间

        返回:
            int: 缓存时间(秒)
        """
        return cls._NULL_RESULT_TTL

    def __init__(
        self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
    ):
@ -57,6 +82,52 @@ class DataAccess(Generic[T]):
            raise ValueError("缓存类型不能为空")
        self.cache = Cache(self.cache_type)

        # 初始化缓存统计
        if self.cache_type not in self._cache_stats:
            self._cache_stats[self.cache_type] = {
                "hits": 0,  # 缓存命中次数
                "misses": 0,  # 缓存未命中次数
                "null_hits": 0,  # 空结果缓存命中次数
                "sets": 0,  # 缓存设置次数
                "null_sets": 0,  # 空结果缓存设置次数
                "deletes": 0,  # 缓存删除次数
            }

    @classmethod
    def get_cache_stats(cls):
        """获取缓存统计信息"""
        result = []
        for cache_type, stats in cls._cache_stats.items():
            hits = stats["hits"]
            null_hits = stats.get("null_hits", 0)
            misses = stats["misses"]
            total = hits + null_hits + misses
            hit_rate = ((hits + null_hits) / total * 100) if total > 0 else 0
            result.append(
                {
                    "cache_type": cache_type,
                    "hits": hits,
                    "null_hits": null_hits,
                    "misses": misses,
                    "sets": stats["sets"],
                    "null_sets": stats.get("null_sets", 0),
                    "deletes": stats["deletes"],
                    "hit_rate": f"{hit_rate:.2f}%",
                }
            )
        return result

    @classmethod
    def reset_cache_stats(cls):
        """重置缓存统计信息"""
        for stats in cls._cache_stats.values():
            stats["hits"] = 0
            stats["null_hits"] = 0
            stats["misses"] = 0
            stats["sets"] = 0
            stats["null_sets"] = 0
            stats["deletes"] = 0
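A quick illustration of how these counters roll up into a hit rate; note that null hits count toward the hit rate, because a cached "no such row" still saved a database round trip. The numbers below are made up:

```python
# Same arithmetic as get_cache_stats(), on a hand-built stats dict.
stats = {"hits": 90, "null_hits": 5, "misses": 5, "sets": 40, "null_sets": 8, "deletes": 3}

total = stats["hits"] + stats["null_hits"] + stats["misses"]
hit_rate = (stats["hits"] + stats["null_hits"]) / total * 100 if total else 0
print(f"{hit_rate:.2f}%")  # -> 95.00%
```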
    def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
        """从关键字参数构建缓存键

@ -69,12 +140,8 @@ class DataAccess(Generic[T]):
        if isinstance(self.key_field, tuple):
            # 多字段主键
            key_parts = []
            for field in self.key_field:
                key_parts.append(str(kwargs.get(field, "")))

            if key_parts:
                return COMPOSITE_KEY_SEPARATOR.join(key_parts)
            return None
            key_parts.extend(str(kwargs.get(field, "")) for field in self.key_field)
            return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
        elif self.key_field in kwargs:
            # 单字段主键
            return str(kwargs[self.key_field])
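For a model keyed on `("user_id", "group_id")`, the builder above simply joins the stringified values with `COMPOSITE_KEY_SEPARATOR`, substituting `""` for missing fields. A standalone sketch; the separator value is assumed here, the project defines its own constant in `cache.config`:

```python
COMPOSITE_KEY_SEPARATOR = ":"  # assumed; the real value comes from cache.config

def build_cache_key(key_field, **kwargs):
    if isinstance(key_field, tuple):
        parts = [str(kwargs.get(field, "")) for field in key_field]
        return COMPOSITE_KEY_SEPARATOR.join(parts) if parts else None
    return str(kwargs[key_field]) if key_field in kwargs else None

print(build_cache_key(("user_id", "group_id"), user_id=123, group_id=456))  # 123:456
print(build_cache_key("id", id=7))                                          # 7
print(build_cache_key(("user_id", "group_id"), user_id=123))                # 123:
```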
@ -92,9 +159,14 @@ class DataAccess(Generic[T]):
        """
        # 如果没有缓存类型,直接从数据库获取
        if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
            return await self.model_cls.safe_get_or_none(*args, **kwargs)
            logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
            return await with_db_timeout(
                self.model_cls.safe_get_or_none(*args, **kwargs),
                operation=f"{self.model_cls.__name__}.safe_get_or_none",
            )

        # 尝试从缓存获取
        cache_key = None
        try:
            # 尝试构建缓存键
            cache_key = self._build_cache_key_from_kwargs(**kwargs)
@ -102,12 +174,33 @@ class DataAccess(Generic[T]):
            # 如果成功构建缓存键,尝试从缓存获取
            if cache_key is not None:
                data = await self.cache.get(cache_key)
                if data:
                    logger.debug(
                        f"{self.model_cls.__name__} self.cache.get(cache_key)"
                        f" 从缓存获取到的数据 {type(data)}: {data}"
                    )
                if data == self._NULL_RESULT:
                    # 空结果缓存命中
                    self._cache_stats[self.cache_type]["null_hits"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
                    )
                    return None
                elif data:
                    # 缓存命中
                    self._cache_stats[self.cache_type]["hits"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
                    )
                    return cast(T, data)
                else:
                    # 缓存未命中
                    self._cache_stats[self.cache_type]["misses"] += 1
                    logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
        except Exception as e:
            logger.error("从缓存获取数据失败", e=e)
            logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)

        # 如果缓存中没有,从数据库获取
        logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
        data = await self.model_cls.safe_get_or_none(*args, **kwargs)

        # 如果获取到数据,存入缓存
@ -118,9 +211,30 @@ class DataAccess(Generic[T]):
                if cache_key is not None:
                    # 存入缓存
                    await self.cache.set(cache_key, data)
                    logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
                    self._cache_stats[self.cache_type]["sets"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
                    )
            except Exception as e:
                logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
                logger.error(
                    f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
                )
        elif cache_key is not None:
            # 如果没有获取到数据,缓存空结果
            try:
                # 存入空结果缓存,使用较短的过期时间
                await self.cache.set(
                    cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
                )
                self._cache_stats[self.cache_type]["null_sets"] += 1
                logger.debug(
                    f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
                    f" TTL={self._NULL_RESULT_TTL}秒"
                )
            except Exception as e:
                logger.error(
                    f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
                )

        return data
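The `_NULL_RESULT` placeholder above is classic negative caching: a database miss is itself cached (with a short TTL) so repeated lookups for absent rows stop hammering the DB — exactly the hot path of permission checks for users who have no LevelUser row. A minimal sketch of the protocol, with a plain dict standing in for the cache backend (TTL expiry omitted for brevity):

```python
import asyncio

_NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"  # sentinel, never a real row

cache: dict[str, object] = {}

async def fetch_user_from_db(key: str):
    return None  # pretend the row does not exist

async def cached_get(key: str):
    hit = cache.get(key)
    if hit == _NULL_RESULT:
        return None            # negative-cache hit: no DB query at all
    if hit is not None:
        return hit             # positive hit
    row = await fetch_user_from_db(key)
    cache[key] = row if row is not None else _NULL_RESULT
    return row

asyncio.run(cached_get("level_user:42"))  # first call stores the placeholder
asyncio.run(cached_get("level_user:42"))  # second call never touches the DB
```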
@ -136,9 +250,14 @@ class DataAccess(Generic[T]):
        """
        # 如果没有缓存类型,直接从数据库获取
        if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
            return await self.model_cls.get_or_none(*args, **kwargs)
            logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
            return await with_db_timeout(
                self.model_cls.get_or_none(*args, **kwargs),
                operation=f"{self.model_cls.__name__}.get_or_none",
            )

        # 尝试从缓存获取
        cache_key = None
        try:
            # 尝试构建缓存键
            cache_key = self._build_cache_key_from_kwargs(**kwargs)
@ -146,12 +265,29 @@ class DataAccess(Generic[T]):
            # 如果成功构建缓存键,尝试从缓存获取
            if cache_key is not None:
                data = await self.cache.get(cache_key)
                if data:
                if data == self._NULL_RESULT:
                    # 空结果缓存命中
                    self._cache_stats[self.cache_type]["null_hits"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
                    )
                    return None
                elif data:
                    # 缓存命中
                    self._cache_stats[self.cache_type]["hits"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
                    )
                    return cast(T, data)
                else:
                    # 缓存未命中
                    self._cache_stats[self.cache_type]["misses"] += 1
                    logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
        except Exception as e:
            logger.error("从缓存获取数据失败", e=e)
            logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)

        # 如果缓存中没有,从数据库获取
        logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
        data = await self.model_cls.get_or_none(*args, **kwargs)

        # 如果获取到数据,存入缓存
@ -162,9 +298,30 @@ class DataAccess(Generic[T]):
                if cache_key is not None:
                    # 存入缓存
                    await self.cache.set(cache_key, data)
                    logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
                    self._cache_stats[self.cache_type]["sets"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
                    )
            except Exception as e:
                logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
                logger.error(
                    f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
                )
        elif cache_key is not None:
            # 如果没有获取到数据,缓存空结果
            try:
                # 存入空结果缓存,使用较短的过期时间
                await self.cache.set(
                    cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
                )
                self._cache_stats[self.cache_type]["null_sets"] += 1
                logger.debug(
                    f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
                    f" TTL={self._NULL_RESULT_TTL}秒"
                )
            except Exception as e:
                logger.error(
                    f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
                )

        return data
@ -203,6 +360,7 @@ class DataAccess(Generic[T]):

            # 删除缓存
            await self.cache.delete(cache_key)
            self._cache_stats[self.cache_type]["deletes"] += 1
            logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
            return True
        except Exception as e:
@ -227,12 +385,7 @@ class DataAccess(Generic[T]):
                    key_parts.append(value if value is not None else "")

            # 如果没有有效参数,返回None
            if not key_parts:
                return None

            return COMPOSITE_KEY_SEPARATOR.join(key_parts)

            # 单个字段作为键
            return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
        elif hasattr(data, self.key_field):
            value = getattr(data, self.key_field, None)
            return str(value) if value is not None else None
@ -255,24 +408,22 @@ class DataAccess(Generic[T]):
        # 获取缓存类型的配置信息
        cache_model = CacheRoot.get_model(self.cache_type)

        # 如果有键格式定义,则需要构建特殊格式的键
        if cache_model.key_format:
            # 构建键参数字典
            key_parts = []
            # 从格式字符串中提取所需的字段名
            import re

            field_names = re.findall(r"{([^}]+)}", cache_model.key_format)

            # 收集所有字段值
            for field in field_names:
                value = getattr(item, field, "")
                key_parts.append(value if value is not None else "")

            return COMPOSITE_KEY_SEPARATOR.join(key_parts)
        else:
        if not cache_model.key_format:
            # 常规处理,使用主键作为缓存键
            return self._build_composite_key(item)
        # 构建键参数字典
        key_parts = []
        # 从格式字符串中提取所需的字段名
        import re

        field_names = re.findall(r"{([^}]+)}", cache_model.key_format)

        # 收集所有字段值
        for field in field_names:
            value = getattr(item, field, "")
            key_parts.append(value if value is not None else "")

        return COMPOSITE_KEY_SEPARATOR.join(key_parts)
    async def _cache_items(self, data_list: list[T]) -> None:
        """将数据列表存入缓存

@ -289,14 +440,19 @@ class DataAccess(Generic[T]):

        try:
            # 遍历数据列表,将每条数据存入缓存
            cached_count = 0
            for item in data_list:
                cache_key = self._build_cache_key_for_item(item)
                if cache_key is not None:
                    await self.cache.set(cache_key, item)
                    cached_count += 1
                    self._cache_stats[self.cache_type]["sets"] += 1

            logger.debug(f"{self.cache_type} 数据已存入缓存,数量: {len(data_list)}")
            logger.debug(
                f"{self.model_cls.__name__} 批量缓存: {cached_count}/{len(data_list)}项"
            )
        except Exception as e:
            logger.error(f"{self.cache_type} 数据存入缓存失败", e=e)
            logger.error(f"{self.model_cls.__name__} 批量缓存失败", e=e)

    async def filter(self, *args, **kwargs) -> list[T]:
        """筛选数据

@ -309,7 +465,11 @@ class DataAccess(Generic[T]):
            List[T]: 查询结果列表
        """
        # 从数据库获取数据
        logger.debug(f"{self.model_cls.__name__} filter: 从数据库查询, 参数: {kwargs}")
        data_list = await self.model_cls.filter(*args, **kwargs)
        logger.debug(
            f"{self.model_cls.__name__} filter: 查询结果数量: {len(data_list)}"
        )

        # 将数据存入缓存
        await self._cache_items(data_list)
@ -323,7 +483,9 @@ class DataAccess(Generic[T]):
            List[T]: 所有数据列表
        """
        # 直接从数据库获取
        logger.debug(f"{self.model_cls.__name__} all: 从数据库查询所有数据")
        data_list = await self.model_cls.all()
        logger.debug(f"{self.model_cls.__name__} all: 查询结果数量: {len(data_list)}")

        # 将数据存入缓存
        await self._cache_items(data_list)
@ -366,6 +528,7 @@ class DataAccess(Generic[T]):
            T: 创建的数据
        """
        # 创建数据
        logger.debug(f"{self.model_cls.__name__} create: 创建数据, 参数: {kwargs}")
        data = await self.model_cls.create(**kwargs)

        # 如果有缓存类型,将数据存入缓存
@ -376,11 +539,16 @@ class DataAccess(Generic[T]):
                if cache_key is not None:
                    # 存入缓存
                    await self.cache.set(cache_key, data)
                    self._cache_stats[self.cache_type]["sets"] += 1
                    logger.debug(
                        f"{self.cache_type} 新创建的数据已存入缓存: {cache_key}"
                        f"{self.model_cls.__name__} create: "
                        f"新创建的数据已存入缓存: {cache_key}"
                    )
            except Exception as e:
                logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
                logger.error(
                    f"{self.model_cls.__name__} create: 存入缓存失败,参数: {kwargs}",
                    e=e,
                )

        return data

@ -409,6 +577,7 @@ class DataAccess(Generic[T]):
                if cache_key is not None:
                    # 存入缓存
                    await self.cache.set(cache_key, data)
                    self._cache_stats[self.cache_type]["sets"] += 1
                    logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
            except Exception as e:
                logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
@ -425,6 +594,8 @@ class DataAccess(Generic[T]):
        返回:
            int: 删除的数据数量
        """
        logger.debug(f"{self.model_cls.__name__} delete: 删除数据, 参数: {kwargs}")

        # 如果有缓存类型且有key_field参数,先尝试删除缓存
        if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
            try:
@ -434,21 +605,36 @@ class DataAccess(Generic[T]):
                if cache_key is not None:
                    # 如果成功构建缓存键,直接删除缓存
                    await self.cache.delete(cache_key)
                    logger.debug(f"已删除缓存: {cache_key}")
                    self._cache_stats[self.cache_type]["deletes"] += 1
                    logger.debug(
                        f"{self.model_cls.__name__} delete: 已删除缓存: {cache_key}"
                    )
                else:
                    # 否则需要先查询出要删除的数据,然后删除对应的缓存
                    items = await self.model_cls.filter(*args, **kwargs)
                    logger.debug(
                        f"{self.model_cls.__name__} delete:"
                        f" 查询到 {len(items)} 条要删除的数据"
                    )
                    for item in items:
                        item_cache_key = self._build_cache_key_for_item(item)
                        if item_cache_key is not None:
                            await self.cache.delete(item_cache_key)
                            self._cache_stats[self.cache_type]["deletes"] += 1
                    if items:
                        logger.debug(f"已删除{len(items)}条数据的缓存")
                        logger.debug(
                            f"{self.model_cls.__name__} delete:"
                            f" 已删除 {len(items)} 条数据的缓存"
                        )
            except Exception as e:
                logger.error("删除缓存失败", e=e)
                logger.error(f"{self.model_cls.__name__} delete: 删除缓存失败", e=e)

        # 删除数据
        return await self.model_cls.filter(*args, **kwargs).delete()
        result = await self.model_cls.filter(*args, **kwargs).delete()
        logger.debug(
            f"{self.model_cls.__name__} delete: 已从数据库删除 {result} 条数据"
        )
        return result

    def _generate_cache_key(self, data: T) -> str:
        """根据数据对象生成缓存键
zhenxun/services/db_context.py

@ -1,67 +1,83 @@
from asyncio import Semaphore
import asyncio
from collections.abc import Iterable
import contextlib
import time
from typing import Any, ClassVar
from typing_extensions import Self
from urllib.parse import urlparse

import nonebot
from nonebot import get_driver
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.connection import connections
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction

from zhenxun.configs.config import BotConfig
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle

from .cache import CacheRoot
from .cache.config import COMPOSITE_KEY_SEPARATOR
from .log import logger

driver = get_driver()

SCRIPT_METHOD = []
MODELS: list[str] = []

driver = nonebot.get_driver()

# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0

# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5

LOG_COMMAND = "DbContext"


async def with_db_timeout(
    coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
    """带超时控制的数据库操作"""
    start_time = time.time()
    try:
        result = await asyncio.wait_for(coro, timeout=timeout)
        elapsed = time.time() - start_time
        if elapsed > SLOW_QUERY_THRESHOLD and operation:
            logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
        return result
    except asyncio.TimeoutError:
        if operation:
            logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
        raise
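`with_db_timeout` is the workhorse of this commit: it bounds any awaitable, flags queries slower than `SLOW_QUERY_THRESHOLD`, and re-raises on timeout so the caller decides the fallback. A hedged usage sketch — it assumes the helper defined above is importable, and `Ban.get_or_none` is a hypothetical call stubbed out here:

```python
import asyncio

async def fake_get_or_none(user_id: str):
    await asyncio.sleep(0.1)  # stands in for Ban.get_or_none(user_id=user_id)
    return None

async def is_banned(user_id: str) -> bool:
    try:
        row = await with_db_timeout(
            fake_get_or_none(user_id), operation="Ban.get_or_none"
        )
        return row is not None
    except asyncio.TimeoutError:
        return False  # fail open: a stuck database must not freeze message handling
```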
class Model(TortoiseModel):
    """
    自动添加模块
    增强的ORM基类,解决锁嵌套问题
    """

    sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
    sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
    _current_locks: ClassVar[dict[int, DbLockType]] = {}  # 跟踪当前协程持有的锁

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.__module__ not in MODELS:
            MODELS.append(cls.__module__)

        if func := getattr(cls, "_run_script", None):
            SCRIPT_METHOD.append((cls.__module__, func))
        if enable_lock := getattr(cls, "enable_lock", []):
            """创建锁"""
            cls.sem_data[cls.__module__] = {}
            for lock in enable_lock:
                cls.sem_data[cls.__module__][lock] = Semaphore(1)

    @classmethod
    def get_semaphore(cls, lock_type: DbLockType):
        return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)

    @classmethod
    def get_cache_type(cls) -> str | None:
        """获取缓存类型"""
        return getattr(cls, "cache_type", None)

    @classmethod
    def get_cache_key_field(cls) -> str | tuple[str]:
        """获取缓存键字段名

        返回:
            str | tuple[str]: 缓存键字段名,可能是单个字段名或字段名元组
        """
        if hasattr(cls, "cache_key_field"):
            return getattr(cls, "cache_key_field", "id")
        return "id"
        """获取缓存键字段"""
        return getattr(cls, "cache_key_field", "id")
    @classmethod
    def get_cache_key(cls, instance) -> str | None:
@ -71,13 +87,14 @@ class Model(TortoiseModel):
            instance: 模型实例

        返回:
            str | None
            str | None: 缓存键,如果无法获取则返回None
        """
        from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR

        key_field = cls.get_cache_key_field()

        # 如果是元组,表示多个字段组成键
        if isinstance(key_field, tuple):
            # 构建键参数列表
            # 多字段主键
            key_parts = []
            for field in key_field:
                if hasattr(instance, field):
@ -85,25 +102,60 @@ class Model(TortoiseModel):
                    key_parts.append(value if value is not None else "")
                else:
                    # 如果缺少任何必要的字段,返回None
                    return None
                    key_parts.append("")

            # 如果没有有效参数,返回None
            if not key_parts:
                return None

            return COMPOSITE_KEY_SEPARATOR.join(str(param) for param in key_parts)

            # 单个字段作为键
            return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
        elif hasattr(instance, key_field):
            return getattr(instance, key_field, None)
            value = getattr(instance, key_field, None)
            return str(value) if value is not None else None

        return None

    @classmethod
    def get_semaphore(cls, lock_type: DbLockType):
        enable_lock = getattr(cls, "enable_lock", None)
        if not enable_lock or lock_type not in enable_lock:
            return None

        if cls.__name__ not in cls.sem_data:
            cls.sem_data[cls.__name__] = {}
        if lock_type not in cls.sem_data[cls.__name__]:
            cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
        return cls.sem_data[cls.__name__][lock_type]

    @classmethod
    def _require_lock(cls, lock_type: DbLockType) -> bool:
        """检查是否需要真正加锁"""
        task_id = id(asyncio.current_task())
        return cls._current_locks.get(task_id) != lock_type

    @classmethod
    @contextlib.asynccontextmanager
    async def _lock_context(cls, lock_type: DbLockType):
        """带重入检查的锁上下文"""
        task_id = id(asyncio.current_task())
        need_lock = cls._require_lock(lock_type)

        if need_lock and (sem := cls.get_semaphore(lock_type)):
            cls._current_locks[task_id] = lock_type
            async with sem:
                yield
            cls._current_locks.pop(task_id, None)
        else:
            yield
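`_lock_context` is the heart of the "卡死" (deadlock) fix: previously, `update_or_create` holding a semaphore would call `create`/`save`, which tried to acquire another semaphore in the same task and hung. The per-task bookkeeping makes a nested acquisition of the same lock type a no-op. A distilled, runnable sketch of the idea, independent of the ORM:

```python
import asyncio
import contextlib

_current_locks: dict[int, str] = {}
_sem = asyncio.Semaphore(1)

@contextlib.asynccontextmanager
async def lock_context(lock_type: str):
    task_id = id(asyncio.current_task())
    if _current_locks.get(task_id) == lock_type:
        yield  # same task already holds this lock type: re-enter freely
        return
    _current_locks[task_id] = lock_type
    try:
        async with _sem:
            yield
    finally:
        _current_locks.pop(task_id, None)

async def inner():
    async with lock_context("CREATE"):  # nested acquire, same task: no deadlock
        print("inner ran")

async def outer():
    async with lock_context("CREATE"):
        await inner()

asyncio.run(outer())
```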
    @classmethod
    async def create(
        cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
    ) -> Self:
        return await super().create(using_db=using_db, **kwargs)
        """创建数据(使用CREATE锁)"""
        async with cls._lock_context(DbLockType.CREATE):
            # 直接调用父类的create方法,重入检查避免再次触发save的锁
            result = await super().create(using_db=using_db, **kwargs)
            if cache_type := cls.get_cache_type():
                await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
            return result

    @classmethod
    async def get_or_create(
@ -112,31 +164,13 @@ class Model(TortoiseModel):
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        if sem := cls.get_semaphore(DbLockType.CREATE):
            async with sem:
                # 在锁内执行查询和创建操作
                result, is_create = await super().get_or_create(
                    defaults=defaults, using_db=using_db, **kwargs
                )
                if is_create and (cache_type := cls.get_cache_type()):
                    # 获取缓存键
                    key = cls.get_cache_key(result)
                    await CacheRoot.invalidate_cache(
                        cache_type, key if key is not None else None
                    )
                return (result, is_create)
        else:
            # 如果没有锁,则执行原来的逻辑
            result, is_create = await super().get_or_create(
                defaults=defaults, using_db=using_db, **kwargs
            )
            if is_create and (cache_type := cls.get_cache_type()):
                # 获取缓存键
                key = cls.get_cache_key(result)
                await CacheRoot.invalidate_cache(
                    cache_type, key if key is not None else None
                )
            return (result, is_create)
        """获取或创建数据(无锁版本,依赖数据库约束)"""
        result = await super().get_or_create(
            defaults=defaults, using_db=using_db, **kwargs
        )
        if cache_type := cls.get_cache_type():
            await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
        return result
    @classmethod
    async def update_or_create(
@ -145,73 +179,28 @@ class Model(TortoiseModel):
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        if sem := cls.get_semaphore(DbLockType.CREATE):
            async with sem:
                # 在锁内执行查询和创建操作
                result = await super().update_or_create(
                    defaults=defaults, using_db=using_db, **kwargs
                )
        """更新或创建数据(使用UPSERT锁)"""
        async with cls._lock_context(DbLockType.UPSERT):
            try:
                # 先尝试更新(带行锁)
                async with in_transaction():
                    if obj := await cls.filter(**kwargs).select_for_update().first():
                        await obj.update_from_dict(defaults or {})
                        await obj.save()
                        result = (obj, False)
                    else:
                        # 创建时不重复加锁
                        result = await cls.create(**kwargs, **(defaults or {})), True

                if cache_type := cls.get_cache_type():
                    # 获取缓存键
                    key = cls.get_cache_key(result[0])
                    await CacheRoot.invalidate_cache(
                        cache_type, key if key is not None else None
                        cache_type, cls.get_cache_key(result[0])
                    )
                return result
            else:
                # 如果没有锁,则执行原来的逻辑
                result = await super().update_or_create(
                    defaults=defaults, using_db=using_db, **kwargs
                )
                if cache_type := cls.get_cache_type():
                    # 获取缓存键
                    key = cls.get_cache_key(result[0])
                    await CacheRoot.invalidate_cache(
                        cache_type, key if key is not None else None
                    )
                return result

    @classmethod
    async def bulk_create(  # type: ignore
        cls,
        objects: Iterable[Self],  # type: ignore
        batch_size: int | None = None,
        ignore_conflicts: bool = False,
        update_fields: Iterable[str] | None = None,
        on_conflict: Iterable[str] | None = None,
        using_db: BaseDBAsyncClient | None = None,
    ) -> list[Self]:  # type: ignore
        result = await super().bulk_create(
            objects=objects,
            batch_size=batch_size,
            ignore_conflicts=ignore_conflicts,
            update_fields=update_fields,
            on_conflict=on_conflict,
            using_db=using_db,
        )
        if cache_type := cls.get_cache_type():
            # 批量创建时清除整个类型的缓存
            await CacheRoot.invalidate_cache(cache_type)
        return result

    @classmethod
    async def bulk_update(  # type: ignore
        cls,
        objects: Iterable[Self],  # type: ignore
        fields: Iterable[str],
        batch_size: int | None = None,
        using_db: BaseDBAsyncClient | None = None,
    ) -> int:  # type: ignore
        result = await super().bulk_update(
            objects=objects,
            fields=fields,
            batch_size=batch_size,
            using_db=using_db,
        )
        if cache_type := cls.get_cache_type():
            # 批量更新时清除整个类型的缓存
            await CacheRoot.invalidate_cache(cache_type)
        return result
            except IntegrityError:
                # 处理极端情况下的唯一约束冲突
                obj = await cls.get(**kwargs)
                return obj, False
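The new `update_or_create` is an update-first upsert: take a row lock with `SELECT ... FOR UPDATE`, fall back to `create`, and let a unique-constraint race resolve through the `IntegrityError` branch. A runnable simulation of that flow under stated assumptions — a dict stands in for the table, and a local `IntegrityError` stands in for the database's constraint violation:

```python
class IntegrityError(Exception):
    pass

table: dict[str, dict] = {}

def upsert(key: str, defaults: dict) -> tuple[dict, bool]:
    if row := table.get(key):          # "SELECT ... FOR UPDATE" found a row
        row.update(defaults)           # update in place under the row lock
        return row, False
    try:
        if key in table:               # lost a race between check and insert
            raise IntegrityError(key)
        table[key] = dict(defaults)    # "INSERT"
        return table[key], True
    except IntegrityError:
        return table[key], False       # another writer won; return their row

print(upsert("u1", {"gold": 10}))   # ({'gold': 10}, True)
print(upsert("u1", {"gold": 25}))   # ({'gold': 25}, False)
```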
    async def save(
        self,
@ -220,37 +209,27 @@ class Model(TortoiseModel):
        force_create: bool = False,
        force_update: bool = False,
    ):
        if getattr(self, "id", None) is None:
            sem = self.get_semaphore(DbLockType.CREATE)
        else:
            sem = self.get_semaphore(DbLockType.UPDATE)
        if sem:
            async with sem:
                await super().save(
                    using_db=using_db,
                    update_fields=update_fields,
                    force_create=force_create,
                    force_update=force_update,
                )
        else:
        """保存数据(根据操作类型自动选择锁)"""
        lock_type = (
            DbLockType.CREATE
            if getattr(self, "id", None) is None
            else DbLockType.UPDATE
        )
        async with self._lock_context(lock_type):
            await super().save(
                using_db=using_db,
                update_fields=update_fields,
                force_create=force_create,
                force_update=force_update,
            )
        if cache_type := getattr(self, "cache_type", None):
            # 获取缓存键
            key = self.__class__.get_cache_key(self)
            await CacheRoot.invalidate_cache(cache_type, key)
        if cache_type := getattr(self, "cache_type", None):
            await CacheRoot.invalidate_cache(
                cache_type, self.__class__.get_cache_key(self)
            )

    async def delete(self, using_db: BaseDBAsyncClient | None = None):
        # 在删除前获取缓存键
        cache_type = getattr(self, "cache_type", None)
        key = None
        if cache_type:
            key = self.__class__.get_cache_key(self)

        key = self.__class__.get_cache_key(self) if cache_type else None
        # 执行删除操作
        await super().delete(using_db=using_db)
@ -280,15 +259,23 @@ class Model(TortoiseModel):
        """
        try:
            # 先尝试使用 get_or_none 获取单个记录
            return await cls.get_or_none(*args, using_db=using_db, **kwargs)
        except Exception as e:
            # 如果出现错误(可能是存在多个记录)
            if "Multiple objects" in str(e):
                logger.warning(
                    f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}"
        try:
            return await with_db_timeout(
                cls.get_or_none(*args, using_db=using_db, **kwargs),
                operation=f"{cls.__name__}.get_or_none",
            )
        except MultipleObjectsReturned:
            # 如果出现多个记录的情况,进行特殊处理
            logger.warning(
                f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
                LOG_COMMAND,
            )

            # 查询所有匹配记录
            records = await cls.filter(*args, **kwargs)
            records = await with_db_timeout(
                cls.filter(*args, **kwargs).all(),
                operation=f"{cls.__name__}.filter.all",
            )

            if not records:
                return None
@ -301,20 +288,39 @@ class Model(TortoiseModel):
                )
                for record in records[1:]:
                    try:
                        await record.delete()
                        await with_db_timeout(
                            record.delete(),
                            operation=f"{cls.__name__}.delete_duplicate",
                        )
                        logger.info(
                            f"{cls.__name__} 删除重复记录:"
                            f" id={getattr(record, 'id', None)}"
                            f" id={getattr(record, 'id', None)}",
                            LOG_COMMAND,
                        )
                    except Exception as del_e:
                        logger.error(f"删除重复记录失败: {del_e}")
                return records[0]
            # 如果不需要清理或没有 id 字段,则返回最新的记录
            if hasattr(cls, "id"):
                return await cls.filter(*args, **kwargs).order_by("-id").first()
                return await with_db_timeout(
                    cls.filter(*args, **kwargs).order_by("-id").first(),
                    operation=f"{cls.__name__}.filter.order_by.first",
                )
            # 如果没有 id 字段,则返回第一个记录
            return await cls.filter(*args, **kwargs).first()
            return await with_db_timeout(
                cls.filter(*args, **kwargs).first(),
                operation=f"{cls.__name__}.filter.first",
            )
        except asyncio.TimeoutError:
            logger.error(
                f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
            )
            return None
        except Exception as e:
            # 其他类型的错误则继续抛出
            logger.error(
                f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
            )
            raise
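`safe_get_or_none` now combines three behaviors: a bounded lookup, fail-soft `None` on timeout, and self-healing removal of duplicate rows when `MultipleObjectsReturned` fires. The duplicate-cleanup policy in isolation — keep the first record, drop the rest — with plain dicts standing in for ORM rows:

```python
def dedupe(records: list[dict]) -> dict | None:
    """Keep the first record, report the rest as deleted duplicates."""
    if not records:
        return None
    keep, *extras = records
    for extra in extras:
        print(f"deleting duplicate id={extra.get('id')}")
    return keep

rows = [{"id": 3, "user": "a"}, {"id": 7, "user": "a"}]
print(dedupe(rows))  # keeps id=3, reports id=7 as deleted
```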
@ -334,6 +340,77 @@ class DbConnectError(Exception):
    pass


POSTGRESQL_CONFIG = {
    "max_size": 30,  # 最大连接数
    "min_size": 5,  # 最小保持的连接数(可选)
}

MYSQL_CONFIG = {
    "max_connections": 20,  # 最大连接数
    "connect_timeout": 30,  # 连接超时(可选)
}

SQLITE_CONFIG = {
    "journal_mode": "WAL",  # 提高并发写入性能
    "timeout": 30,  # 锁等待超时(可选)
}


def get_config(db_url: str) -> dict:
    """获取数据库配置"""
    parsed = urlparse(db_url)

    # 基础配置
    config = {
        "connections": {
            "default": db_url  # 默认直接使用连接字符串
        },
        "apps": {
            "models": {
                "models": MODELS,
                "default_connection": "default",
            }
        },
        "timezone": "Asia/Shanghai",
    }

    # 根据数据库类型应用高级配置
    if parsed.scheme.startswith("postgres"):
        config["connections"]["default"] = {
            "engine": "tortoise.backends.asyncpg",
            "credentials": {
                "host": parsed.hostname,
                "port": parsed.port or 5432,
                "user": parsed.username,
                "password": parsed.password,
                "database": parsed.path[1:],
            },
            **POSTGRESQL_CONFIG,
        }
    elif parsed.scheme == "mysql":
        config["connections"]["default"] = {
            "engine": "tortoise.backends.mysql",
            "credentials": {
                "host": parsed.hostname,
                "port": parsed.port or 3306,
                "user": parsed.username,
                "password": parsed.password,
                "database": parsed.path[1:],
            },
            **MYSQL_CONFIG,
        }
    elif parsed.scheme == "sqlite":
        config["connections"]["default"] = {
            "engine": "tortoise.backends.sqlite",
            "credentials": {
                "file_path": parsed.path[1:] or ":memory:",
            },
            **SQLITE_CONFIG,
        }
    return config
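`get_config` upgrades the old bare `db_url` string into a structured Tortoise config so pool limits and SQLite WAL mode can be applied per backend. How a URL maps to the credentials dict, using the same `urlparse` logic (values are illustrative):

```python
from urllib.parse import urlparse

url = "postgres://bot:secret@localhost:5432/zhenxun"
parsed = urlparse(url)

credentials = {
    "host": parsed.hostname,      # localhost
    "port": parsed.port or 5432,  # 5432
    "user": parsed.username,      # bot
    "password": parsed.password,  # secret
    "database": parsed.path[1:],  # zhenxun (leading "/" stripped)
}
print(credentials)
```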
@PriorityLifecycle.on_startup(priority=1)
async def init():
    if not BotConfig.db_url:
@ -349,9 +426,7 @@ async def init():
        raise DbUrlIsNode("\n" + error.strip())
    try:
        await Tortoise.init(
            db_url=BotConfig.db_url,
            modules={"models": MODELS},
            timezone="Asia/Shanghai",
            config=get_config(BotConfig.db_url),
        )
        if SCRIPT_METHOD:
            db = Tortoise.get_connection("default")
@ -366,17 +441,21 @@ async def init():
                    if sql:
                        sql_list += sql
                except Exception as e:
                    logger.trace(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
                    logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
            for sql in sql_list:
                logger.trace(f"执行SQL: {sql}")
                logger.debug(f"执行SQL: {sql}")
                try:
                    await db.execute_query_dict(sql)
                    await asyncio.wait_for(
                        db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
                    )
                    # await TestSQL.raw(sql)
                except Exception as e:
                    logger.trace(f"执行SQL: {sql} 错误...", e=e)
                    logger.debug(f"执行SQL: {sql} 错误...", e=e)
            if sql_list:
                logger.debug("SCRIPT_METHOD方法执行完毕!")
        logger.debug("开始生成数据库表结构...")
        await Tortoise.generate_schemas()
        logger.debug("数据库表结构生成完毕!")
        logger.info("Database loaded successfully!")
    except Exception as e:
        raise DbConnectError(f"数据库连接错误... e:{e}") from e
zhenxun/utils/enum.py

@ -78,6 +78,8 @@ class DbLockType(StrEnum):
    """更新"""
    QUERY = "QUERY"
    """查询"""
    UPSERT = "UPSERT"
    """创建或更新"""


class GoldHandle(StrEnum):
zhenxun/utils/manager/priority_manager.py

@ -49,6 +49,9 @@ async def _():
    try:
        for priority in priority_list:
            for func in priority_data[priority]:
                logger.debug(
                    f"执行优先级 [{priority}] on_startup 方法: {func.__module__}"
                )
                if is_coroutine_callable(func):
                    await func()
                else:
zhenxun/utils/utils.py

@ -4,7 +4,7 @@ from datetime import date, datetime
import os
from pathlib import Path
import time
from typing import Any
from typing import Any, ClassVar

import httpx
from nonebot_plugin_uninfo import Uninfo
@ -30,38 +30,38 @@ class ResourceDirManager:
    临时文件管理器
    """

    temp_path = []  # noqa: RUF012
    temp_path: ClassVar[set[Path]] = set()

    @classmethod
    def __tree_append(cls, path: Path):
        """递归添加文件夹

        参数:
            path: 文件夹路径
        """
    def __tree_append(cls, path: Path, deep: int = 1, current: int = 0):
        """递归添加文件夹"""
        if current >= deep and deep != -1:
            return
        path = path.resolve()  # 标准化路径
        for f in os.listdir(path):
            file = path / f
            file = (path / f).resolve()  # 标准化子路径
            if file.is_dir():
                if file not in cls.temp_path:
                    cls.temp_path.append(file)
                    logger.debug(f"添加临时文件夹: {path}")
                cls.__tree_append(file)
                    cls.temp_path.add(file)
                    logger.debug(f"添加临时文件夹: {file}")
                cls.__tree_append(file, deep, current + 1)

    @classmethod
    def add_temp_dir(cls, path: str | Path, tree: bool = False):
    def add_temp_dir(cls, path: str | Path, tree: bool = False, deep: int = 1):
        """添加临时清理文件夹,这些文件夹会被自动清理

        参数:
            path: 文件夹路径
            tree: 是否递归添加文件夹
            deep: 深度, -1 为无限深度
        """
        if isinstance(path, str):
            path = Path(path)
        if path not in cls.temp_path:
            cls.temp_path.append(path)
            cls.temp_path.add(path)
            logger.debug(f"添加临时文件夹: {path}")
        if tree:
            cls.__tree_append(path)
            cls.__tree_append(path, deep)
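The new `deep` parameter bounds how far `__tree_append` recurses (`-1` meaning unbounded), and the switch to a `ClassVar[set[Path]]` plus `resolve()` deduplicates paths that differ only in notation. A toy version of the depth cutoff, independent of the manager class:

```python
from pathlib import Path

def walk_dirs(path: Path, deep: int = 1, current: int = 0, found=None) -> set[Path]:
    """Collect subdirectories down to `deep` levels; deep=-1 means no limit."""
    found = set() if found is None else found
    if current >= deep and deep != -1:
        return found
    for child in path.resolve().iterdir():
        if child.is_dir():
            found.add(child.resolve())  # resolve() dedupes ./a vs a
            walk_dirs(child, deep, current + 1, found)
    return found

print(walk_dirs(Path("."), deep=1))  # only first-level subdirectories
```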

class CountLimiter: