Introduce caching mechanism (#1889)

* Add a global cache

* Build the cache and use it in hooks

* Add monitoring of database Model methods

* Add a semaphore lock for database access

* 🩹 Improve WebAPI response data

* Add incremental caching and cache expiration

* 🎨 Improve the structure of the check code

* Improve hook permission-check performance

* 🐛 Add new exception handling to skip permission checks

* Add a plugin limit cache

* 🎨 Code formatting improvements

* 🐛 Fix code imports

* 🐛 Fix checks on refresh

* 👽 Rename exception for missing database URL in initialization

* Update default database URL to SQLite in configuration

* 🔧 Update tortoise-orm and aiocache dependency constraints; add optional redis and asyncpg support

* 🐛 Fix ban check

* 🐛 Fix cache update when all plugins are disabled

* 🐛 Attempt migration to aiocache

* 🐛 Refine aiocache caching

* Code performance optimization

* 🐛 Remove logging when fetching the ban cache

* 🐛 Fix cache type declarations and improve banned-user handling logic

* 🐛 Improve LevelUser permission update logic and database migration

* Cache supports Redis connections

* 🚨 auto fix by pre-commit hooks

* Improve the safety and accuracy of group lookups; also refine cache-management logic to keep cache operations consistent

* feat(auth_limit): switch plugin initialization from the startup decorator to the priority lifecycle manager

* 🔧 Fix logging level

* 🔧 Update database connection string

* 🔧 Switch the database connection string to an in-memory database and improve the permission-check logic

* feat(cache): add cache configuration options and a new data-access layer to support caching

* ♻️ Refactor cache

* feat(cache): enhance cache management with cache-dict and cache-list features and expiration handling

* 🔧 Fix the viewport height in the Notebook class, changing it from 1000 to 10

* Update plugin-management logic, replace the cache service with CacheRoot, and improve cache invalidation

* Update the type field in the RegisterConfig class

* Fix duplicate-record cleanup logic, ensuring the record's id attribute is valid

* Major optimization to resolve latency and freeze issues

* Update the ban feature with a ban-duration parameter and description, and improve the plugin-info response structure

* Update the viewport height in zhenxun_help.py from 453 to 10 to improve page rendering

* Improve plugin classification logic, add sorting by plugin ID, and update the plugin-info response structure
---------

Co-authored-by: BalconyJH <balconyjh@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
HibiKier 2025-07-14 22:35:29 +08:00 committed by GitHub
parent 6283c3d13d
commit 8649aaaa54
57 changed files with 5179 additions and 1479 deletions
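
The bullets above refer to a new data-access layer (DataAccess) and a cache root (CacheRoot). The diffs below only show their call sites, so the following is a minimal usage sketch pieced together from those call sites; the exact semantics of the cached layer are assumptions.

import asyncio

from zhenxun.models.ban_console import BanConsole
from zhenxun.services.cache import CacheRoot
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.utils.enum import CacheType


async def lookup_global_ban(user_id: str):
    # DataAccess wraps a Tortoise model and serves reads through the cache layer.
    ban_dao = DataAccess(BanConsole)
    # safe_get_or_none mirrors get_or_none but goes through the cache; the hook
    # code always guards it with the shared DB timeout.
    return await asyncio.wait_for(
        ban_dao.safe_get_or_none(user_id=user_id, group_id__isnull=True),
        timeout=DB_TIMEOUT_SECONDS,
    )


async def after_plugin_state_change():
    # Writes that change plugin state invalidate the relevant cache bucket.
    await CacheRoot.invalidate_cache(CacheType.PLUGINS)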

View File

@ -27,6 +27,18 @@ QBOT_ID_DATA = '{
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
DB_URL = ""
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
CACHE_MODE = NONE
# REDIS配置使用REDIS替换Cache内存缓存
# REDIS地址
# REDIS_HOST = "127.0.0.1"
# REDIS端口
# REDIS_PORT = 6379
# REDIS密码
# REDIS_PASSWORD = ""
# REDIS过期时间
# REDIS_EXPIRE = 600
# 系统代理
# SYSTEM_PROXY = "http://127.0.0.1:7890"
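
The CACHE_MODE and REDIS_* options above choose between no cache, an in-process memory cache, and Redis. A minimal sketch of how such settings could map onto aiocache (the library this commit migrates the cache layer to); how zhenxun actually wires them into CacheRoot is not shown in this diff, so treat the wrapper below as an assumption, while the aiocache calls themselves are standard.

from aiocache import SimpleMemoryCache


def build_cache(cache_mode: str, host: str = "127.0.0.1", port: int = 6379,
                password: str = "", expire: int = 600):
    # CACHE_MODE = NONE is handled by the caller (no cache object at all).
    if cache_mode == "REDIS":
        # Redis backend; only importable when the optional redis extra is installed.
        from aiocache import RedisCache
        return RedisCache(endpoint=host, port=port, password=password or None), expire
    # MEMORY: in-process cache, no external service needed.
    return SimpleMemoryCache(), expire


async def demo():
    cache, expire = build_cache("MEMORY")
    await cache.set("plugins:sign", {"status": True}, ttl=expire)  # ttl in seconds
    return await cache.get("plugins:sign")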
@ -40,7 +52,7 @@ PLATFORM_SUPERUSERS = '
DRIVER=~fastapi+~httpx+~websockets
# LOG_LEVEL=DEBUG
# LOG_LEVEL = DEBUG
# 服务器和端口
HOST = 127.0.0.1
PORT = 8080

poetry.lock (generated): 951 lines changed; diff suppressed because it is too large.

View File

@ -16,7 +16,7 @@ python = "^3.10"
playwright = "^1.41.1"
nonebot-adapter-onebot = "^2.3.1"
nonebot-plugin-apscheduler = "^0.5"
tortoise-orm = { extras = ["asyncpg"], version = "^0.20.0" }
tortoise-orm = "^0.20.0"
cattrs = "^23.2.3"
ruamel-yaml = "^0.18.5"
strenum = "^0.4.15"
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
bilireq = "0.2.3post0"
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
python-multipart = "^0.0.9"
aiocache = "^0.12.2"
aiocache = {extras = ["redis"], version = "^0.12.3"}
py-cpuinfo = "^9.0.0"
nonebot-plugin-alconna = "^0.54.0"
tenacity = "^9.0.0"
@ -47,6 +47,9 @@ nonebot-plugin-uninfo = ">0.4.1"
nonebot-plugin-waiter = "^0.8.1"
multidict = ">=6.0.0,!=6.3.2"
redis = { version = ">=5", optional = true }
asyncpg = { version = ">=0.20.0", optional = true }
[tool.poetry.group.dev.dependencies]
nonebug = "^0.4"
pytest-cov = "^5.0.0"
@ -57,6 +60,9 @@ respx = "^0.21.1"
ruff = "^0.8.0"
pre-commit = "^4.0.0"
[tool.poetry.extras]
redis = ["redis"]
postgresql = ["asyncpg"]
[tool.nonebot]
plugins = [

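
The redis and asyncpg packages above are declared optional and exposed through the new redis and postgresql extras. A minimal sketch of the kind of runtime guard this enables (the HAS_REDIS flag is illustrative, not taken from the codebase):

try:
    import redis  # provided by the optional "redis" extra
    HAS_REDIS = True
except ImportError:
    HAS_REDIS = False

# Code can then fall back to the in-memory cache when HAS_REDIS is False.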
View File

@ -87,13 +87,17 @@ __plugin_meta__ = PluginMetadata(
smart_tools=[
AICallableTag(
name="call_ban",
description="某人多次(至少三次)辱骂你,调用此方法进行封禁",
description="如果你讨厌(好感度过低并让你感到困扰,或者多次辱骂你,调用此方法进行封禁,调用该方法后要告知用户被封禁和原因",
parameters=AICallableParam(
type="object",
properties={
"user_id": AICallableProperties(
type="string", description="用户的id"
),
"duration": AICallableProperties(
type="integer",
description="封禁时长选择的值只能是1-360单位为分钟如果频繁触发按情况增加",
),
},
required=["user_id"],
),

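
For context, an AICallableTag like call_ban above amounts to a function-calling tool schema roughly like the following; the JSON layout is an assumption about how the AI plugin serializes it, and the descriptions are translated and abridged.

call_ban_tool = {
    "name": "call_ban",
    "description": "Ban a user whose favorability is too low or who repeatedly "
                   "insults you; afterwards tell the user they were banned and why.",
    "parameters": {
        "type": "object",
        "properties": {
            "user_id": {"type": "string", "description": "the user's id"},
            "duration": {"type": "integer",
                         "description": "ban duration in minutes, 1-360"},
        },
        "required": ["user_id"],
    },
}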
View File

@ -9,14 +9,14 @@ from zhenxun.services.log import logger
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
async def call_ban(user_id: str):
async def call_ban(user_id: str, duration: int = 1):
"""调用ban
参数:
user_id: 用户id
"""
await BanConsole.ban(user_id, None, 9, 60 * 12)
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
await BanConsole.ban(user_id, None, 9, duration * 60)
logger.info("被讨厌了,已将用户加入黑名单...", "ban", session=user_id)
class BanManage:
@ -114,7 +114,7 @@ class BanManage:
if not is_superuser and user_id and session.id1:
user_level = await LevelUser.get_user_level(session.id1, group_id)
if idx:
ban_data = await BanConsole.get_or_none(id=idx)
ban_data = await BanConsole.get_ban(id=idx)
if not ban_data:
return False, "该用户/群组不在黑名单中捏..."
if ban_data.ban_level > user_level:

View File

@ -1,10 +1,13 @@
import os
from typing import cast
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.services.cache import CacheRoot
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType, PluginType
from zhenxun.utils.exception import GroupInfoNotFound
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
@ -116,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
group = None
if group_id:
group = await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
)
group = await GroupConsole.get_group(group_id=group_id)
if not group:
raise GroupInfoNotFound()
else:
@ -200,26 +201,26 @@ class PluginManager:
)
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
if group_id:
if group := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
):
module_list = await PluginInfo.filter(
plugin_type=PluginType.NORMAL
).values_list("module", flat=True)
if group := await GroupConsole.get_group(group_id=group_id):
module_list = cast(
list[str],
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
"module", flat=True
),
)
if status:
for module in module_list:
group.block_plugin = group.block_plugin.replace(
f"<{module},", ""
)
# 开启所有功能 - 清空禁用列表
group.block_plugin = ""
else:
module_list = [f"<{module}" for module in module_list]
group.block_plugin = ",".join(module_list) + "," # type: ignore
# 关闭所有功能 - 将模块列表转换为禁用格式
group.block_plugin = CommonUtils.convert_module_format(module_list)
await group.save(update_fields=["block_plugin"])
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
return "获取群组失败..."
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
status=status, block_type=None if status else BlockType.ALL
)
await CacheRoot.invalidate_cache(CacheType.PLUGINS)
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
@classmethod
@ -232,9 +233,7 @@ class PluginManager:
返回:
bool: 是否醒来
"""
if c := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
):
if c := await GroupConsole.get_group(group_id=group_id):
return c.status
return False
@ -245,9 +244,11 @@ class PluginManager:
参数:
group_id: 群组id
"""
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
status=False
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
group.status = False
await group.save(update_fields=["status"])
@classmethod
async def wake(cls, group_id: str):
@ -256,9 +257,11 @@ class PluginManager:
参数:
group_id: 群组id
"""
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
status=True
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
group.status = True
await group.save(update_fields=["status"])
@classmethod
async def block(cls, module: str):
@ -267,7 +270,9 @@ class PluginManager:
参数:
module: 模块名
"""
await PluginInfo.filter(module=module).update(status=False)
if plugin := await PluginInfo.get_plugin(module=module):
plugin.status = False
await plugin.save(update_fields=["status"])
@classmethod
async def unblock(cls, module: str):
@ -276,7 +281,9 @@ class PluginManager:
参数:
module: 模块名
"""
await PluginInfo.filter(module=module).update(status=True)
if plugin := await PluginInfo.get_plugin(module=module):
plugin.status = True
await plugin.save(update_fields=["status"])
@classmethod
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:
@ -437,17 +444,18 @@ class PluginManager:
"""
status_str = "关闭" if status else "开启"
if is_all:
modules = await TaskInfo.annotate().values_list("module", flat=True)
if modules:
module_list = cast(
list[str], await TaskInfo.annotate().values_list("module", flat=True)
)
if module_list:
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
modules = [f"<{module}" for module in modules]
if status:
group.block_task = ",".join(modules) + "," # type: ignore
group.block_task = CommonUtils.convert_module_format(module_list)
else:
for module in modules:
group.block_task = group.block_task.replace(f"{module},", "")
# 开启所有模块 - 清空禁用列表
group.block_task = ""
await group.save(update_fields=["block_task"])
return f"已成功{status_str}全部被动技能!"
elif task := await TaskInfo.get_or_none(name=task_name):

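
CommonUtils.convert_module_format replaces the inline string building that the removed lines did by hand. Judging from the code it replaces and the membership checks elsewhere (f"{plugin.module}," in group.block_plugin), it presumably produces the same comma-terminated block string; a sketch under that assumption:

# Equivalent of the removed inline code; the helper's exact behaviour is assumed
# to match it.
def convert_module_format_sketch(modules: list[str]) -> str:
    return ",".join(f"<{m}" for m in modules) + "," if modules else ""

# convert_module_format_sketch(["sign", "shop"]) -> "<sign,<shop,"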
View File

@ -1,13 +1,15 @@
from nonebot import on_message
from nonebot.plugin import PluginMetadata
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_session import EventSession
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
from zhenxun.models.chat_history import ChatHistory
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.utils import get_entity_ids
__plugin_meta__ = PluginMetadata(
name="消息存储",
@ -37,18 +39,34 @@ def rule(message: UniMsg) -> bool:
chat_history = on_message(rule=rule, priority=1, block=False)
TEMP_LIST = []
@chat_history.handle()
async def handle_message(message: UniMsg, session: EventSession):
"""处理消息存储"""
try:
await ChatHistory.create(
user_id=session.id1,
group_id=session.id2,
async def _(message: UniMsg, session: Uninfo):
entity = get_entity_ids(session)
TEMP_LIST.append(
ChatHistory(
user_id=entity.user_id,
group_id=entity.group_id,
text=str(message),
plain_text=message.extract_plain_text(),
bot_id=session.bot_id,
bot_id=session.self_id,
platform=session.platform,
)
)
@scheduler.scheduled_job(
"interval",
minutes=1,
)
async def _():
try:
message_list = TEMP_LIST.copy()
TEMP_LIST.clear()
if message_list:
await ChatHistory.bulk_create(message_list)
logger.debug(f"批量添加聊天记录 {len(message_list)}", "定时任务")
except Exception as e:
logger.warning("存储聊天记录失败", "chat_history", e=e)

View File

@ -45,11 +45,13 @@ async def classify_plugin(
"""
sort_data = await sort_type()
classify: dict[str, list] = {}
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
bot = await BotConsole.get_or_none(bot_id=session.self_id)
for menu, value in sort_data.items():
for plugin in value:
if not classify.get(menu):
classify[menu] = []
classify[menu].append(handle(bot, plugin, group, is_detail))
for value in classify.values():
value.sort(key=lambda x: x.id)
return classify

View File

@ -21,6 +21,8 @@ class Item(BaseModel):
"""插件名称"""
sta: int
"""插件状态"""
id: int
"""插件id"""
class PluginList(BaseModel):
@ -80,10 +82,9 @@ def __handle_item(
sta = 2
if f"{plugin.module}," in group.block_plugin:
sta = 1
if bot:
if f"{plugin.module}," in bot.block_plugins:
sta = 2
return Item(plugin_name=plugin.name, sta=sta)
if bot and f"{plugin.module}," in bot.block_plugins:
sta = 2
return Item(plugin_name=plugin.name, sta=sta, id=plugin.id)
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
@ -142,7 +143,7 @@ async def build_html_image(
template_name="zhenxun_menu.html",
templates={"plugin_list": plugin_list},
pages={
"viewport": {"width": 1903, "height": 975},
"viewport": {"width": 1903, "height": 10},
"base_url": f"file://{TEMPLATE_PATH}",
},
wait=2,

View File

@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
color="black" if idx % 2 else "white",
)
curr_h = 10
group = await GroupConsole.get_or_none(group_id=group_id)
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
for _, plugin in enumerate(plugin_list):
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
if group and f"{plugin.module}," in group.block_plugin:
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
width, height = 10, 10
for s in [
"目前支持的功能列表:",
"可以通过 ‘帮助 [功能名称或功能Id] 来获取对应功能的使用方法",
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
]:
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
await result.paste(text, (width, height))

View File

@ -20,6 +20,12 @@ class Item(BaseModel):
"""插件名称"""
commands: list[str]
"""插件命令"""
id: str
"""插件id"""
status: bool
"""插件状态"""
has_superuser_help: bool
"""插件是否拥有超级用户帮助"""
def __handle_item(
@ -39,23 +45,36 @@ def __handle_item(
返回:
Item: Item
"""
status = True
has_superuser_help = False
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
if nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
if extra_data.superuser_help:
has_superuser_help = True
if not plugin.status:
if plugin.block_type == BlockType.ALL:
plugin.name = f"{plugin.name}(不可用)"
status = False
elif group and plugin.block_type == BlockType.GROUP:
plugin.name = f"{plugin.name}(不可用)"
status = False
elif not group and plugin.block_type == BlockType.PRIVATE:
plugin.name = f"{plugin.name}(不可用)"
status = False
elif group and f"{plugin.module}," in group.block_plugin:
plugin.name = f"{plugin.name}(不可用)"
status = False
elif bot and f"{plugin.module}," in bot.block_plugins:
plugin.name = f"{plugin.name}(不可用)"
status = False
commands = []
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
commands = [cmd.command for cmd in extra_data.commands]
return Item(plugin_name=f"{plugin.id}-{plugin.name}", commands=commands)
return Item(
plugin_name=plugin.name,
commands=commands,
id=str(plugin.id),
status=status,
has_superuser_help=has_superuser_help,
)
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
@ -78,68 +97,10 @@ def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
}
for menu, value in classify.items()
]
plugin_list = build_line_data(plugin_list)
plugin_list.insert(
0,
build_plugin_line(
menu_key if menu_key not in ["normal", "功能"] else "主要功能",
max_data,
30,
100,
True,
),
)
return plugin_list
def build_plugin_line(
name: str, items: list, left: int, width: int | None = None, is_max: bool = False
) -> dict:
"""构造插件行数据
参数:
name: 菜单名称
items: 插件名称列表
left: 左边距
width: 总插件长度.
is_max: 是否为最大长度的插件菜单
返回:
dict: 插件数据
"""
_plugins = []
width = width or 50
if len(items) // 2 > 6 or is_max:
width = 100
plugin_list1 = []
plugin_list2 = []
for i in range(len(items)):
if i % 2:
plugin_list1.append(items[i])
else:
plugin_list2.append(items[i])
_plugins = [(30, 50, plugin_list1), (0, 50, plugin_list2)]
else:
_plugins = [(left, 100, items)]
return {"name": name, "items": _plugins, "width": width}
def build_line_data(plugin_list: list[dict]) -> list[dict]:
"""构造插件数据
参数:
plugin_list: 插件列表
返回:
list[dict]: 插件数据
"""
left = 30
data = []
plugin_list.insert(0, {"name": menu_key, "items": max_data})
for plugin in plugin_list:
data.append(build_plugin_line(plugin["name"], plugin["items"], left))
if len(plugin["items"]) // 2 <= 6:
left = 15 if left == 30 else 30
return data
plugin["items"].sort(key=lambda x: x.id)
return plugin_list
async def build_zhenxun_image(
@ -160,6 +121,7 @@ async def build_zhenxun_image(
width = int(637 * 1.5) if is_detail else 637
title_font = int(53 * 1.5) if is_detail else 53
tip_font = int(19 * 1.5) if is_detail else 19
plugin_count = sum(len(plugin["items"]) for plugin in plugin_list)
return await template_to_pic(
template_path=str((TEMPLATE_PATH / "ss_menu").absolute()),
template_name="main.html",
@ -170,10 +132,11 @@ async def build_zhenxun_image(
"width": width,
"font_size": (title_font, tip_font),
"is_detail": is_detail,
"plugin_count": plugin_count,
}
},
pages={
"viewport": {"width": width, "height": 453},
"viewport": {"width": width, "height": 10},
"base_url": f"file://{TEMPLATE_PATH}",
},
wait=2,

View File

@ -1,597 +0,0 @@
from typing import ClassVar
from nonebot.adapters import Bot, Event
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At, UniMsg
from nonebot_plugin_session import EventSession
from pydantic import BaseModel
from tortoise.exceptions import IntegrityError
from zhenxun.configs.config import Config
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.sign_user import SignUser
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.enum import (
BlockType,
GoldHandle,
LimitWatchType,
PluginLimitType,
PluginType,
)
from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
base_config = Config.get("hook")
class Limit(BaseModel):
limit: PluginLimit
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
class Config:
arbitrary_types_allowed = True
class LimitManage:
add_module: ClassVar[list] = []
cd_limit: ClassVar[dict[str, Limit]] = {}
block_limit: ClassVar[dict[str, Limit]] = {}
count_limit: ClassVar[dict[str, Limit]] = {}
@classmethod
def add_limit(cls, limit: PluginLimit):
"""添加限制
参数:
limit: PluginLimit
"""
if limit.module not in cls.add_module:
cls.add_module.append(limit.module)
if limit.limit_type == PluginLimitType.BLOCK:
cls.block_limit[limit.module] = Limit(
limit=limit, limiter=UserBlockLimiter()
)
elif limit.limit_type == PluginLimitType.CD:
cls.cd_limit[limit.module] = Limit(
limit=limit, limiter=FreqLimiter(limit.cd)
)
elif limit.limit_type == PluginLimitType.COUNT:
cls.count_limit[limit.module] = Limit(
limit=limit, limiter=CountLimiter(limit.max_count)
)
@classmethod
def unblock(
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
):
"""解除插件block
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
"""
if limit_model := cls.block_limit.get(module):
limit = limit_model.limit
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
logger.debug(
f"解除对象: {key_type} 的block限制",
"AuthChecker",
session=user_id,
group_id=group_id,
)
limiter.set_false(key_type)
@classmethod
async def check(
cls,
module: str,
user_id: str,
group_id: str | None,
channel_id: str | None,
session: EventSession,
):
"""检测限制
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
@classmethod
async def __check(
cls,
limit_model: Limit | None,
user_id: str,
group_id: str | None,
channel_id: str | None,
session: EventSession,
):
"""检测限制
参数:
limit_model: Limit
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if not limit_model:
return
limit = limit_model.limit
limiter = limit_model.limiter
is_limit = (
LimitWatchType.ALL
or (group_id and limit.watch_type == LimitWatchType.GROUP)
or (not group_id and limit.watch_type == LimitWatchType.USER)
)
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
await MessageUtils.build_message(limit.result).send()
logger.debug(
f"{limit.module}({limit.limit_type}) 正在限制中...",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{limit.module} 正在限制中...")
else:
logger.debug(
f"开始进行限制 {limit.module}({limit.limit_type})...",
"AuthChecker",
session=user_id,
group_id=group_id,
)
if isinstance(limiter, FreqLimiter):
limiter.start_cd(key_type)
if isinstance(limiter, UserBlockLimiter):
limiter.set_true(key_type)
if isinstance(limiter, CountLimiter):
limiter.increase(key_type)
class IsSuperuserException(Exception):
pass
class AuthChecker:
"""
权限检查
"""
def __init__(self):
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
if check_notice_info_cd is None or check_notice_info_cd < 0:
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
self._flmt = FreqLimiter(check_notice_info_cd)
self._flmt_g = FreqLimiter(check_notice_info_cd)
self._flmt_s = FreqLimiter(check_notice_info_cd)
self._flmt_c = FreqLimiter(check_notice_info_cd)
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
"""是否发送提示消息
参数:
plugin: PluginInfo
返回:
bool: 是否发送提示消息
"""
if not base_config.get("IS_SEND_TIP_MESSAGE"):
return False
if plugin.plugin_type == PluginType.DEPENDANT:
return False
if plugin.ignore_prompt:
return False
return self._flmt_s.check(sid)
async def auth(
self,
matcher: Matcher,
event: Event,
bot: Bot,
session: EventSession,
message: UniMsg,
):
"""权限检查
参数:
matcher: matcher
bot: bot
session: EventSession
message: UniMsg
"""
is_ignore = False
cost_gold = 0
user_id = session.id1
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
"""过滤除poke外的notice"""
return
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
try:
user = await UserConsole.get_user(user_id, session.platform)
except IntegrityError as e:
logger.debug(
"重复创建用户,已跳过该次权限...",
"AuthChecker",
session=session,
e=e,
)
return
if plugin := await PluginInfo.get_or_none(module_path=module_path):
if plugin.plugin_type == PluginType.HIDDEN:
logger.debug(
f"插件: {plugin.name}:{plugin.module} "
"为HIDDEN已跳过权限检查..."
)
return
try:
cost_gold = await self.auth_cost(user, plugin, session)
if session.id1 in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
cost_gold = 0
raise IsSuperuserException()
await self.auth_bot(plugin, bot.self_id)
await self.auth_group(plugin, session, message)
await self.auth_admin(plugin, session)
await self.auth_plugin(plugin, session, event)
await self.auth_limit(plugin, session)
except IsSuperuserException:
logger.debug(
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
)
except IgnoredException:
is_ignore = True
LimitManage.unblock(
matcher.plugin.name, user_id, group_id, channel_id
)
except AssertionError as e:
is_ignore = True
logger.debug("消息无法发送", session=session, e=e)
if cost_gold and user_id:
"""花费金币"""
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
matcher.plugin.name if matcher.plugin else "",
session.platform,
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
logger.debug(
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
)
if is_ignore:
raise IgnoredException("权限检测 ignore")
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
"""机器人权限
参数:
plugin: PluginInfo
bot_id: bot_id
"""
if not await BotConsole.get_bot_status(bot_id):
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
raise IgnoredException("BotConsole休眠权限检测 ignore")
if await BotConsole.is_block_plugin(bot_id, plugin.module):
logger.debug(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
"AuthChecker",
)
raise IgnoredException("BotConsole插件权限检测 ignore")
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
"""插件限制
参数:
plugin: PluginInfo
session: EventSession
"""
user_id = session.id1
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
if plugin.module not in LimitManage.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
status=True
).all() # type: ignore
for limit in limit_list:
LimitManage.add_limit(limit)
if user_id:
await LimitManage.check(
plugin.module, user_id, group_id, channel_id, session
)
async def auth_plugin(
self, plugin: PluginInfo, session: EventSession, event: Event
):
"""插件状态
参数:
plugin: PluginInfo
session: EventSession
"""
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
if user_id := session.id1:
if plugin.impression > 0:
sign_user = await SignUser.get_user(user_id)
if float(sign_user.impression) < plugin.impression:
if self.is_send_limit_message(plugin, user_id):
self._flmt_s.start_cd(user_id)
await MessageUtils.build_message(
f"好感度不足哦,当前功能需要好感度: {plugin.impression}"
"请继续签到提升好感度吧!"
).send(reply_to=True)
logger.debug(
f"{plugin.name}({plugin.module}) 用户好感度不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("好感度不足...")
if group_id:
sid = group_id or user_id
if await GroupConsole.is_superuser_block_plugin(
group_id, plugin.module
):
"""超级用户群组插件状态"""
if self.is_send_limit_message(plugin, sid):
self._flmt_s.start_cd(group_id or user_id)
await MessageUtils.build_message(
"超级管理员禁用了该群此功能..."
).send(reply_to=True)
logger.debug(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
"AuthChecker",
session=session,
)
raise IgnoredException("超级管理员禁用了该群此功能...")
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
"""群组插件状态"""
if self.is_send_limit_message(plugin, sid):
self._flmt_s.start_cd(group_id or user_id)
await MessageUtils.build_message("该群未开启此功能...").send(
reply_to=True
)
logger.debug(
f"{plugin.name}({plugin.module}) 未开启此功能...",
"AuthChecker",
session=session,
)
raise IgnoredException("该群未开启此功能...")
if plugin.block_type == BlockType.GROUP:
"""全局群组禁用"""
try:
if self.is_send_limit_message(plugin, sid):
self._flmt_c.start_cd(group_id)
await MessageUtils.build_message(
"该功能在群组中已被禁用..."
).send(reply_to=True)
except Exception as e:
logger.error(
"auth_plugin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
"AuthChecker",
session=session,
)
raise IgnoredException("该插件在群组中已被禁用...")
else:
sid = user_id
if plugin.block_type == BlockType.PRIVATE:
"""全局私聊禁用"""
try:
if self.is_send_limit_message(plugin, sid):
self._flmt_c.start_cd(user_id)
await MessageUtils.build_message(
"该功能在私聊中已被禁用..."
).send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
"AuthChecker",
session=session,
)
raise IgnoredException("该插件在私聊中已被禁用...")
if not plugin.status and plugin.block_type == BlockType.ALL:
"""全局状态"""
if group_id and await GroupConsole.is_super_group(group_id):
raise IsSuperuserException()
logger.debug(
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
"AuthChecker",
session=session,
)
if self.is_send_limit_message(plugin, sid):
self._flmt_s.start_cd(group_id or user_id)
await MessageUtils.build_message("全局未开启此功能...").send()
raise IgnoredException("全局未开启此功能...")
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
"""管理员命令 个人权限
参数:
plugin: PluginInfo
session: EventSession
"""
user_id = session.id1
if user_id and plugin.admin_level:
if group_id := session.id3 or session.id2:
if not await LevelUser.check_level(
user_id, group_id, plugin.admin_level
):
try:
if self._flmt.check(user_id):
self._flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"你的权限不足喔,"
f"该功能需要的权限等级: {plugin.admin_level}",
]
).send(reply_to=True)
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("管理员权限不足...")
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
try:
await MessageUtils.build_message(
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
).send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("权限不足")
async def auth_group(
self, plugin: PluginInfo, session: EventSession, message: UniMsg
):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
session: EventSession
message: UniMsg
"""
if not (group_id := session.id3 or session.id2):
return
text = message.extract_plain_text()
group = await GroupConsole.get_group(group_id)
if not group:
"""群不存在"""
logger.debug(
"群组信息不存在...",
"AuthChecker",
session=session,
)
raise IgnoredException("群不存在")
if group.level < 0:
"""群权限小于0"""
logger.debug(
"群黑名单, 群权限-1...",
"AuthChecker",
session=session,
)
raise IgnoredException("群黑名单")
if not group.status:
"""群休眠"""
if text.strip() != "醒来":
logger.debug("群休眠状态...", "AuthChecker", session=session)
raise IgnoredException("群休眠状态")
if plugin.level > group.level:
"""插件等级大于群等级"""
logger.debug(
f"{plugin.name}({plugin.module}) 群等级限制.."
f"该功能需要的群等级: {plugin.level}..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
async def auth_cost(
self, user: UserConsole, plugin: PluginInfo, session: EventSession
) -> int:
"""检测是否满足金币条件
参数:
user: UserConsole
plugin: PluginInfo
session: EventSession
返回:
int: 需要消耗的金币
"""
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
try:
await MessageUtils.build_message(
f"金币不足..该功能需要{plugin.cost_gold}金币.."
).send()
except Exception as e:
logger.error(
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
)
logger.debug(
f"{plugin.name}({plugin.module}) 金币限制.."
f"该功能需要{plugin.cost_gold}金币..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold
checker = AuthChecker()

View File

@ -0,0 +1,99 @@
import asyncio
import time
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import send_message
async def auth_admin(plugin: PluginInfo, session: Uninfo):
"""管理员命令 个人权限
参数:
plugin: PluginInfo
session: Uninfo
"""
start_time = time.time()
if not plugin.admin_level:
return
try:
entity = get_entity_ids(session)
level_dao = DataAccess(LevelUser)
# 并行查询用户权限数据
global_user: LevelUser | None = None
group_users: LevelUser | None = None
# 查询全局权限
global_user_task = level_dao.safe_get_or_none(
user_id=session.user.id, group_id__isnull=True
)
# 如果在群组中,查询群组权限
group_users_task = None
if entity.group_id:
group_users_task = level_dao.safe_get_or_none(
user_id=session.user.id, group_id=entity.group_id
)
# 等待查询完成,添加超时控制
try:
results = await asyncio.wait_for(
asyncio.gather(global_user_task, group_users_task or asyncio.sleep(0)),
timeout=DB_TIMEOUT_SECONDS,
)
global_user = results[0]
group_users = results[1] if group_users_task else None
except asyncio.TimeoutError:
logger.error(f"查询用户权限超时: user_id={session.user.id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
return
user_level = global_user.user_level if global_user else 0
if entity.group_id and group_users:
user_level = max(user_level, group_users.user_level)
if user_level < plugin.admin_level:
await send_message(
session,
[
At(flag="user", target=session.user.id),
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
],
entity.user_id,
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
elif global_user:
if global_user.user_level < plugin.admin_level:
await send_message(
session,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_admin 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=session,
)

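
auth_admin above, like the other new auth modules, wraps every DataAccess call in asyncio.wait_for with DB_TIMEOUT_SECONDS and treats a timeout as "log and let the event through" rather than blocking. The pattern in isolation (the helper name is illustrative, not part of the codebase):

import asyncio

from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger


async def query_with_timeout(coro, what: str):
    """Run a DB coroutine with the shared timeout; on timeout, log and return None
    so the permission check degrades gracefully instead of blocking the message."""
    try:
        return await asyncio.wait_for(coro, timeout=DB_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        logger.error(f"{what} query timed out", "AuthChecker")
        return None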
View File

@ -0,0 +1,303 @@
import asyncio
import time
from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import freq, send_message
Config.add_plugin_config(
"hook",
"BAN_RESULT",
"才不会给你发消息.",
help="对被ban用户发送的消息",
)
def calculate_ban_time(ban_record: BanConsole | None) -> int:
"""根据ban记录计算剩余ban时间
参数:
ban_record: BanConsole记录
返回:
int: ban剩余时长-1时为永久ban0表示未被ban
"""
if not ban_record:
return 0
if ban_record.duration == -1:
return -1
_time = time.time() - (ban_record.ban_time + ban_record.duration)
return 0 if _time > 0 else int(abs(_time))
async def is_ban(user_id: str | None, group_id: str | None) -> int:
"""检查用户或群组是否被ban
参数:
user_id: 用户ID
group_id: 群组ID
返回:
int: ban的剩余时间0表示未被ban
"""
if not user_id and not group_id:
return 0
start_time = time.time()
ban_dao = DataAccess(BanConsole)
# 分别获取用户在群组中的ban记录和全局ban记录
group_user = None
user = None
try:
# 并行查询用户和群组的 ban 记录
tasks = []
if user_id and group_id:
tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id))
if user_id:
tasks.append(
ban_dao.safe_get_or_none(user_id=user_id, group_id__isnull=True)
)
# 等待所有查询完成,添加超时控制
if tasks:
try:
ban_records = await asyncio.wait_for(
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS
)
if len(tasks) == 2:
group_user, user = ban_records
elif user_id and group_id:
group_user = ban_records[0]
else:
user = ban_records[0]
except asyncio.TimeoutError:
logger.error(
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
LOGGER_COMMAND,
)
# 超时时返回0避免阻塞
return 0
# 检查记录并计算ban时间
results = []
if group_user:
results.append(group_user)
if user:
results.append(user)
# 如果没有找到记录返回0
if not results:
return 0
logger.debug(f"查询到的ban记录: {results}", LOGGER_COMMAND)
# 检查所有记录找出最严格的ban时间最长的
max_ban_time: int = 0
for result in results:
if result.duration > 0 or result.duration == -1:
# 直接计算ban时间避免再次查询数据库
ban_time = calculate_ban_time(result)
if ban_time == -1 or ban_time > max_ban_time:
max_ban_time = ban_time
return max_ban_time
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"is_ban 耗时: {elapsed:.3f}s",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
def check_plugin_type(matcher: Matcher) -> bool:
"""判断插件类型是否是隐藏插件
参数:
matcher: Matcher
返回:
bool: 是否为隐藏插件
"""
if plugin := matcher.plugin:
if metadata := plugin.metadata:
extra = metadata.extra
if extra.get("plugin_type") in [PluginType.HIDDEN]:
return False
return True
def format_time(time_val: float) -> str:
"""格式化时间
参数:
time_val: ban时长
返回:
str: 格式化时间文本
"""
if time_val == -1:
return ""
time_val = abs(int(time_val))
if time_val < 60:
time_str = f"{time_val!s}"
else:
minute = int(time_val / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
return time_str
async def group_handle(group_id: str) -> None:
"""群组ban检查
参数:
group_id: 群组id
异常:
SkipPluginException: 群组处于黑名单
"""
start_time = time.time()
try:
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"group_handle 耗时: {elapsed:.3f}s",
LOGGER_COMMAND,
group_id=group_id,
)
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
"""用户ban检查
参数:
module: 插件模块名
entity: 实体ID信息
session: Uninfo
异常:
SkipPluginException: 用户处于黑名单
"""
start_time = time.time()
try:
ban_result = Config.get_config("hook", "BAN_RESULT")
time_val = await is_ban(entity.user_id, entity.group_id)
if not time_val:
return
time_str = format_time(time_val)
plugin_dao = DataAccess(PluginInfo)
try:
db_plugin = await asyncio.wait_for(
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
raise SkipPluginException("用户处于黑名单中...")
if (
db_plugin
and not db_plugin.ignore_prompt
and time_val != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
):
try:
await asyncio.wait_for(
send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {entity.user_id}", LOGGER_COMMAND)
raise SkipPluginException("用户处于黑名单中...")
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"user_handle 耗时: {elapsed:.3f}s",
LOGGER_COMMAND,
session=session,
)
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
"""权限检查 - ban 检查
参数:
matcher: Matcher
bot: Bot
session: Uninfo
"""
start_time = time.time()
try:
if not check_plugin_type(matcher):
return
if not matcher.plugin_name:
return
entity = get_entity_ids(session)
if entity.user_id in bot.config.superusers:
return
if entity.group_id:
try:
await asyncio.wait_for(
group_handle(entity.group_id), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"群组ban检查超时: {entity.group_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
if entity.user_id:
try:
await asyncio.wait_for(
user_handle(matcher.plugin_name, entity, session),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"用户ban检查超时: {entity.user_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
finally:
# 记录总执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_ban 总耗时: {elapsed:.3f}s, plugin={matcher.plugin_name}",
LOGGER_COMMAND,
session=session,
)

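
A quick worked example of calculate_ban_time and format_time above, to make the arithmetic concrete (the numbers are illustrative):

# Suppose a user was banned at ban_time = 1_000 (epoch seconds) for duration = 600.
# At time.time() == 1_300:
#   time.time() - (ban_time + duration) = 1300 - 1600 = -300, so the ban still holds
#   calculate_ban_time(record) == 300 (seconds remaining)
# duration == -1 means a permanent ban and returns -1 unchanged.
# format_time(300)   -> "5 分钟"
# format_time(3_660) -> "1 小时 1分钟"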
View File

@ -0,0 +1,55 @@
import asyncio
import time
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
async def auth_bot(plugin: PluginInfo, bot_id: str):
"""bot层面的权限检查
参数:
plugin: PluginInfo
bot_id: bot id
异常:
SkipPluginException: 忽略插件
SkipPluginException: 忽略插件
"""
start_time = time.time()
try:
# 从数据库或缓存中获取 bot 信息
bot_dao = DataAccess(BotConsole)
try:
bot: BotConsole | None = await asyncio.wait_for(
bot_dao.safe_get_or_none(bot_id=bot_id), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error(f"查询Bot信息超时: bot_id={bot_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
return
if not bot or not bot.status:
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
if CommonUtils.format(plugin.module) in bot.block_plugins:
raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_bot 耗时: {elapsed:.3f}s, "
f"bot_id={bot_id}, plugin={plugin.module}",
LOGGER_COMMAND,
)

View File

@ -0,0 +1,41 @@
import time
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
from .utils import send_message
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
"""检测是否满足金币条件
参数:
user: UserConsole
plugin: PluginInfo
session: Uninfo
返回:
int: 需要消耗的金币
"""
start_time = time.time()
try:
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_cost 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=session,
)

View File

@ -0,0 +1,68 @@
import asyncio
import time
from nonebot_plugin_alconna import UniMsg
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.utils import EntityIDs
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
from .exception import SkipPluginException
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
entity: EntityIDs
message: UniMsg
"""
start_time = time.time()
if not entity.group_id:
return
try:
text = message.extract_plain_text()
# 从数据库或缓存中获取群组信息
group_dao = DataAccess(GroupConsole)
try:
group: GroupConsole | None = await asyncio.wait_for(
group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
# 超时时不阻塞,继续执行
return
if not group:
raise SkipPluginException("群组信息不存在...")
if group.level < 0:
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
if text.strip() != SwitchEnum.ENABLE and not group.status:
raise SkipPluginException("群组休眠状态...")
if plugin.level > group.level:
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 群等级限制,"
f"该功能需要的群等级: {plugin.level}..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
LOGGER_COMMAND,
session=entity.user_id,
group_id=entity.group_id,
)

View File

@ -0,0 +1,318 @@
import asyncio
import time
from typing import ClassVar
import nonebot
from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import (
CountLimiter,
FreqLimiter,
UserBlockLimiter,
get_entity_ids,
)
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
driver = nonebot.get_driver()
@PriorityLifecycle.on_startup(priority=5)
async def _():
"""初始化限制"""
await LimitManager.init_limit()
class Limit(BaseModel):
limit: PluginLimit
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
class Config:
arbitrary_types_allowed = True
class LimitManager:
add_module: ClassVar[list] = []
last_update_time: ClassVar[float] = 0
update_interval: ClassVar[float] = 6000 # 1小时更新一次
is_updating: ClassVar[bool] = False # 防止并发更新
cd_limit: ClassVar[dict[str, Limit]] = {}
block_limit: ClassVar[dict[str, Limit]] = {}
count_limit: ClassVar[dict[str, Limit]] = {}
# 模块限制缓存,避免频繁查询数据库
module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
module_cache_ttl: ClassVar[float] = 60 # 模块缓存有效期(秒)
@classmethod
async def init_limit(cls):
"""初始化限制"""
cls.last_update_time = time.time()
try:
await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
except asyncio.TimeoutError:
logger.error("初始化限制超时", LOGGER_COMMAND)
@classmethod
async def update_limits(cls):
"""更新限制信息"""
# 防止并发更新
if cls.is_updating:
return
cls.is_updating = True
try:
start_time = time.time()
try:
limit_list = await asyncio.wait_for(
PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("查询限制信息超时", LOGGER_COMMAND)
cls.is_updating = False
return
# 清空旧数据
cls.add_module = []
cls.cd_limit = {}
cls.block_limit = {}
cls.count_limit = {}
# 添加新数据
for limit in limit_list:
cls.add_limit(limit)
cls.last_update_time = time.time()
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的更新
logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
finally:
cls.is_updating = False
@classmethod
def add_limit(cls, limit: PluginLimit):
"""添加限制
参数:
limit: PluginLimit
"""
if limit.module not in cls.add_module:
cls.add_module.append(limit.module)
if limit.limit_type == PluginLimitType.BLOCK:
cls.block_limit[limit.module] = Limit(
limit=limit, limiter=UserBlockLimiter()
)
elif limit.limit_type == PluginLimitType.CD:
cls.cd_limit[limit.module] = Limit(
limit=limit, limiter=FreqLimiter(limit.cd)
)
elif limit.limit_type == PluginLimitType.COUNT:
cls.count_limit[limit.module] = Limit(
limit=limit, limiter=CountLimiter(limit.max_count)
)
@classmethod
def unblock(
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
):
"""解除插件block
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
"""
if limit_model := cls.block_limit.get(module):
limit = limit_model.limit
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
logger.debug(
f"解除对象: {key_type} 的block限制",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
limiter.set_false(key_type)
@classmethod
async def get_module_limits(cls, module: str) -> list[PluginLimit]:
"""获取模块的限制信息,使用缓存减少数据库查询
参数:
module: 模块名
返回:
list[PluginLimit]: 限制列表
"""
current_time = time.time()
# 检查缓存
if module in cls.module_limit_cache:
cache_time, limits = cls.module_limit_cache[module]
if current_time - cache_time < cls.module_cache_ttl:
return limits
# 缓存不存在或已过期,从数据库查询
try:
start_time = time.time()
limits = await asyncio.wait_for(
PluginLimit.filter(module=module, status=True).all(),
timeout=DB_TIMEOUT_SECONDS,
)
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的查询
logger.warning(
f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
LOGGER_COMMAND,
)
# 更新缓存
cls.module_limit_cache[module] = (current_time, limits)
return limits
except asyncio.TimeoutError:
logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
# 超时时返回空列表,避免阻塞
return []
@classmethod
async def check(
cls,
module: str,
user_id: str,
group_id: str | None,
channel_id: str | None,
):
"""检测限制
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
异常:
IgnoredException: IgnoredException
"""
start_time = time.time()
# 定期更新全局限制信息
if (
time.time() - cls.last_update_time > cls.update_interval
and not cls.is_updating
):
# 使用异步任务更新,避免阻塞当前请求
asyncio.create_task(cls.update_limits()) # noqa: RUF006
# 如果模块不在已加载列表中,只加载该模块的限制
if module not in cls.add_module:
limits = await cls.get_module_limits(module)
for limit in limits:
cls.add_limit(limit)
# 检查各种限制
try:
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id)
finally:
# 记录总执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
@classmethod
async def __check(
cls,
limit_model: Limit | None,
user_id: str,
group_id: str | None,
channel_id: str | None,
):
"""检测限制
参数:
limit_model: Limit
user_id: 用户id
group_id: 群组id
channel_id: 频道id
异常:
IgnoredException: IgnoredException
"""
if not limit_model:
return
limit = limit_model.limit
limiter = limit_model.limiter
is_limit = (
LimitWatchType.ALL
or (group_id and limit.watch_type == LimitWatchType.GROUP)
or (not group_id and limit.watch_type == LimitWatchType.USER)
)
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
try:
await asyncio.wait_for(
MessageUtils.build_message(limit.result).send(),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
raise SkipPluginException(
f"{limit.module}({limit.limit_type}) 正在限制中..."
)
else:
logger.debug(
f"开始进行限制 {limit.module}({limit.limit_type})...",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
if isinstance(limiter, FreqLimiter):
limiter.start_cd(key_type)
if isinstance(limiter, UserBlockLimiter):
limiter.set_true(key_type)
if isinstance(limiter, CountLimiter):
limiter.increase(key_type)
async def auth_limit(plugin: PluginInfo, session: Uninfo):
"""插件限制
参数:
plugin: PluginInfo
session: Uninfo
"""
entity = get_entity_ids(session)
try:
await asyncio.wait_for(
LimitManager.check(
plugin.module, entity.user_id, entity.group_id, entity.channel_id
),
timeout=DB_TIMEOUT_SECONDS * 2, # 给予更长的超时时间
)
except asyncio.TimeoutError:
logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
# 超时时不抛出异常,允许继续执行

View File

@ -0,0 +1,242 @@
import asyncio
import time
from nonebot.adapters import Event
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.data_access import DataAccess
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import IsSuperuserException, SkipPluginException
from .utils import freq, is_poke, send_message
class GroupCheck:
def __init__(
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
) -> None:
self.group_id = group_id
self.session = session
self.is_poke = is_poke
self.plugin = plugin
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
async def check(self):
start_time = time.time()
try:
# 只查询一次数据库,使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
# 检查超级用户禁用
if (
self.group_data
and CommonUtils.format(self.plugin.module)
in self.group_data.superuser_block_plugin
):
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
try:
await asyncio.wait_for(
send_message(
self.session,
"超级管理员禁用了该群此功能...",
self.group_id,
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
raise SkipPluginException(
f"{self.plugin.name}({self.plugin.module})"
f" 超级管理员禁用了该群此功能..."
)
# 检查普通禁用
if (
self.group_data
and CommonUtils.format(self.plugin.module)
in self.group_data.block_plugin
):
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
try:
await asyncio.wait_for(
send_message(
self.session, "该群未开启此功能...", self.group_id
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
raise SkipPluginException(
f"{self.plugin.name}({self.plugin.module}) 未开启此功能..."
)
# 检查全局禁用
if self.plugin.block_type == BlockType.GROUP:
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
try:
await asyncio.wait_for(
send_message(
self.session, "该功能在群组中已被禁用...", self.group_id
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
raise SkipPluginException(
f"{self.plugin.name}({self.plugin.module})该插件在群组中已被禁用..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"GroupCheck.check 耗时: {elapsed:.3f}s, 群组: {self.group_id}",
LOGGER_COMMAND,
)
class PluginCheck:
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
self.session = session
self.is_poke = is_poke
self.group_id = group_id
self.group_dao = DataAccess(GroupConsole)
self.group_data = None
async def check_user(self, plugin: PluginInfo):
"""全局私聊禁用检测
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
if plugin.block_type == BlockType.PRIVATE:
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
try:
await asyncio.wait_for(
send_message(self.session, "该功能在私聊中已被禁用..."),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error("发送消息超时", LOGGER_COMMAND)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
)
async def check_global(self, plugin: PluginInfo):
"""全局状态
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
start_time = time.time()
try:
if plugin.status or plugin.block_type != BlockType.ALL:
return
"""全局状态"""
if self.group_id:
# 使用 DataAccess 的缓存机制
try:
self.group_data = await asyncio.wait_for(
self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
return # 超时时不阻塞,继续执行
if self.group_data and self.group_data.is_super:
raise IsSuperuserException()
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):
try:
await asyncio.wait_for(
send_message(self.session, "全局未开启此功能...", sid),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.error(f"发送消息超时: {sid}", LOGGER_COMMAND)
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 全局未开启此功能..."
)
finally:
# 记录执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"PluginCheck.check_global 耗时: {elapsed:.3f}s", LOGGER_COMMAND
)
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
"""插件状态
参数:
plugin: PluginInfo
session: Uninfo
event: Event
"""
start_time = time.time()
try:
entity = get_entity_ids(session)
is_poke_event = is_poke(event)
user_check = PluginCheck(entity.group_id, session, is_poke_event)
if entity.group_id:
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
try:
await asyncio.wait_for(
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
)
except asyncio.TimeoutError:
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
else:
try:
await asyncio.wait_for(
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("用户检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
try:
await asyncio.wait_for(
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.error("全局检查超时", LOGGER_COMMAND)
# 超时时不阻塞,继续执行
finally:
# 记录总执行时间
elapsed = time.time() - start_time
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
logger.warning(
f"auth_plugin 总耗时: {elapsed:.3f}s, 模块: {plugin.module}",
LOGGER_COMMAND,
)

View File

@ -0,0 +1,35 @@
import nonebot
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from .exception import SkipPluginException
Config.add_plugin_config(
"hook",
"FILTER_BOT",
True,
help="过滤当前连接bot防止bot互相调用",
default_value=True,
type=bool,
)
def bot_filter(session: Uninfo):
"""过滤bot调用bot
参数:
session: Uninfo
异常:
SkipPluginException: bot互相调用
"""
if not Config.get_config("hook", "FILTER_BOT"):
return
bot_ids = list(nonebot.get_bots().keys())
if session.user.id == session.self_id:
return
if session.user.id in bot_ids:
raise SkipPluginException(
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
)

View File

@ -0,0 +1,16 @@
import sys
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
LOGGER_COMMAND = "AuthChecker"
class SwitchEnum(StrEnum):
ENABLE = "醒来"
DISABLE = "休息吧"
WARNING_THRESHOLD = 0.5 # 警告阈值(秒)

View File

@ -0,0 +1,26 @@
class IsSuperuserException(Exception):
pass
class SkipPluginException(Exception):
def __init__(self, info: str, *args: object) -> None:
super().__init__(*args)
self.info = info
def __str__(self) -> str:
return self.info
def __repr__(self) -> str:
return self.info
class PermissionExemption(Exception):
def __init__(self, info: str, *args: object) -> None:
super().__init__(*args)
self.info = info
def __str__(self) -> str:
return self.info
def __repr__(self) -> str:
return self.info
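
A minimal sketch, not part of the diff, of how these three exception types are consumed by the checker code later in this commit: SkipPluginException means "ignore the event", while IsSuperuserException and PermissionExemption mean "stop checking but let the event through".

```python
# Illustrative only -- mirrors how auth() in auth_checker.py reacts to the
# three exception types defined above.
async def run_checks(checks) -> bool:
    """Return True when the matcher should be ignored."""
    try:
        for check in checks:
            await check()
    except SkipPluginException as e:
        print(f"skip: {e}")        # plugin disabled/limited -> ignore event
        return True
    except IsSuperuserException:
        return False               # superuser bypasses the remaining checks
    except PermissionExemption as e:
        print(f"exempt: {e}")      # missing data / hidden plugin -> allow
        return False
    return False
```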

View File

@ -0,0 +1,91 @@
import contextlib
from nonebot.adapters import Event
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import FreqLimiter
from .config import LOGGER_COMMAND
base_config = Config.get("hook")
def is_poke(event: Event) -> bool:
"""判断是否为poke类型
参数:
event: Event
返回:
bool: 是否为poke类型
"""
with contextlib.suppress(ImportError):
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
return isinstance(event, PokeNotifyEvent)
return False
async def send_message(
session: Uninfo, message: list | str, check_tag: str | None = None
):
"""发送消息
参数:
session: Uninfo
message: 消息
check_tag: cd flag
"""
try:
if not check_tag:
await MessageUtils.build_message(message).send(reply_to=True)
elif freq._flmt.check(check_tag):
freq._flmt.start_cd(check_tag)
await MessageUtils.build_message(message).send(reply_to=True)
except Exception as e:
logger.error(
"发送消息失败",
LOGGER_COMMAND,
session=session,
e=e,
)
class FreqUtils:
def __init__(self):
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
if check_notice_info_cd is None or check_notice_info_cd < 0:
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
self._flmt = FreqLimiter(check_notice_info_cd)
self._flmt_g = FreqLimiter(check_notice_info_cd)
self._flmt_s = FreqLimiter(check_notice_info_cd)
self._flmt_c = FreqLimiter(check_notice_info_cd)
def is_send_limit_message(
self, plugin: PluginInfo, sid: str, is_poke: bool
) -> bool:
"""是否发送提示消息
参数:
plugin: PluginInfo
sid: 检测键
is_poke: 是否是戳一戳
返回:
bool: 是否发送提示消息
"""
if is_poke:
return False
if not base_config.get("IS_SEND_TIP_MESSAGE"):
return False
if plugin.plugin_type == PluginType.DEPENDANT:
return False
return plugin.module != "ai" if self._flmt_s.check(sid) else False
freq = FreqUtils()
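
A hypothetical call site (not in the commit) showing how the module-level `freq` singleton gates tip messages: `is_send_limit_message` decides whether a tip is warranted at all, and passing the same id as `check_tag` lets `send_message` throttle repeats within the CHECK_NOTICE_INFO_CD window.

```python
# Hypothetical usage; `plugin` and `session` come from the calling hook.
async def notify_blocked(plugin, session, sid: str, poke: bool) -> None:
    if freq.is_send_limit_message(plugin, sid, poke):
        # check_tag=sid: only the first trigger inside the CD window replies
        await send_message(session, "该功能暂不可用...", check_tag=sid)
```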

View File

@ -0,0 +1,375 @@
import asyncio
import time
from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import IntegrityError
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.data_access import DataAccess
from zhenxun.services.log import logger
from zhenxun.utils.enum import GoldHandle, PluginType
from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.platform import PlatformUtils
from zhenxun.utils.utils import get_entity_ids
from .auth.auth_admin import auth_admin
from .auth.auth_ban import auth_ban
from .auth.auth_bot import auth_bot
from .auth.auth_cost import auth_cost
from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManager, auth_limit
from .auth.auth_plugin import auth_plugin
from .auth.bot_filter import bot_filter
from .auth.config import LOGGER_COMMAND, WARNING_THRESHOLD
from .auth.exception import (
IsSuperuserException,
PermissionExemption,
SkipPluginException,
)
# 超时设置(秒)
TIMEOUT_SECONDS = 5.0
# 熔断计数器
CIRCUIT_BREAKERS = {
"auth_ban": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_bot": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_group": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_admin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_plugin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
"auth_limit": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
}
# 熔断重置时间(秒)
CIRCUIT_RESET_TIME = 300 # 5分钟
# 超时装饰器
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
"""带超时控制的协程执行
参数:
coro: 要执行的协程
timeout: 超时时间
name: 操作名称用于日志记录
返回:
协程的返回值或者在超时时抛出 TimeoutError
"""
try:
return await asyncio.wait_for(coro, timeout=timeout)
except asyncio.TimeoutError:
if name:
logger.error(f"{name} 操作超时 (>{timeout}s)", LOGGER_COMMAND)
# 更新熔断计数器
if name in CIRCUIT_BREAKERS:
CIRCUIT_BREAKERS[name]["failures"] += 1
if (
CIRCUIT_BREAKERS[name]["failures"]
>= CIRCUIT_BREAKERS[name]["threshold"]
and not CIRCUIT_BREAKERS[name]["active"]
):
CIRCUIT_BREAKERS[name]["active"] = True
CIRCUIT_BREAKERS[name]["reset_time"] = (
time.time() + CIRCUIT_RESET_TIME
)
logger.warning(
f"{name} 熔断器已激活,将在 {CIRCUIT_RESET_TIME} 秒后重置",
LOGGER_COMMAND,
)
raise
# 检查熔断状态
def check_circuit_breaker(name):
"""检查熔断器状态
参数:
name: 操作名称
返回:
bool: 是否已熔断
"""
if name not in CIRCUIT_BREAKERS:
return False
# 检查是否需要重置熔断器
if (
CIRCUIT_BREAKERS[name]["active"]
and time.time() > CIRCUIT_BREAKERS[name]["reset_time"]
):
CIRCUIT_BREAKERS[name]["active"] = False
CIRCUIT_BREAKERS[name]["failures"] = 0
logger.info(f"{name} 熔断器已重置", LOGGER_COMMAND)
return CIRCUIT_BREAKERS[name]["active"]
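
A minimal composition sketch (illustrative, not from the commit) of the two helpers above around one of the names registered in CIRCUIT_BREAKERS; `fetch_ban_state` is a made-up stand-in for the real auth_ban call.

```python
# Illustrative only: skip the check while the breaker is open, otherwise run
# it under the shared timeout; with_timeout itself counts the failure.
async def guarded_auth_ban(fetch_ban_state):
    if check_circuit_breaker("auth_ban"):
        return None
    try:
        return await with_timeout(fetch_ban_state(), name="auth_ban")
    except asyncio.TimeoutError:
        return None
```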
async def get_plugin_and_user(
module: str, user_id: str
) -> tuple[PluginInfo, UserConsole]:
"""获取用户数据和插件信息
参数:
module: 模块名
user_id: 用户id
异常:
PermissionExemption: 插件数据不存在
PermissionExemption: 插件类型为HIDDEN
PermissionExemption: 重复创建用户
PermissionExemption: 用户数据不存在
返回:
tuple[PluginInfo, UserConsole]: 插件信息用户信息
"""
user_dao = DataAccess(UserConsole)
plugin_dao = DataAccess(PluginInfo)
# 并行查询插件和用户数据
plugin_task = plugin_dao.safe_get_or_none(module=module)
user_task = user_dao.safe_get_or_none(user_id=user_id)
try:
plugin, user = await with_timeout(
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
)
except asyncio.TimeoutError:
# 如果并行查询超时,尝试串行查询
logger.warning("并行查询超时,尝试串行查询", LOGGER_COMMAND)
plugin = await with_timeout(
plugin_dao.safe_get_or_none(module=module), name="get_plugin"
)
user = await with_timeout(
user_dao.safe_get_or_none(user_id=user_id), name="get_user"
)
if not plugin:
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
if plugin.plugin_type == PluginType.HIDDEN:
raise PermissionExemption(
f"插件: {plugin.name}:{plugin.module} 为HIDDEN已跳过权限检查..."
)
user = None
try:
user = await user_dao.safe_get_or_none(user_id=user_id)
except IntegrityError as e:
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
if not user:
raise PermissionExemption("用户数据不存在,已跳过权限检查...")
return plugin, user
async def get_plugin_cost(
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
) -> int:
"""获取插件费用
参数:
bot: Bot
user: 用户数据
plugin: 插件数据
session: Uninfo
异常:
IsSuperuserException: 超级用户
IsSuperuserException: 超级用户
返回:
int: 调用插件金币费用
"""
cost_gold = await with_timeout(auth_cost(user, plugin, session), name="auth_cost")
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
raise IsSuperuserException()
return cost_gold
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
"""扣除用户金币
参数:
user_id: 用户id
module: 插件模块名称
cost_gold: 消耗金币
session: Uninfo
"""
user_dao = DataAccess(UserConsole)
try:
await with_timeout(
UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
module,
PlatformUtils.get_platform(session),
),
name="reduce_gold",
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
except asyncio.TimeoutError:
logger.error(
f"扣除金币超时,用户: {user_id}, 金币: {cost_gold}",
LOGGER_COMMAND,
session=session,
)
# 清除缓存,使下次查询时从数据库获取最新数据
await user_dao.clear_cache(user_id=user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
# 辅助函数,用于记录每个 hook 的执行时间
async def time_hook(coro, name, time_dict):
start = time.time()
try:
# 检查熔断状态
if check_circuit_breaker(name):
logger.info(f"{name} 熔断器激活中,跳过执行", LOGGER_COMMAND)
time_dict[name] = "熔断跳过"
return
# 添加超时控制
return await with_timeout(coro, name=name)
except asyncio.TimeoutError:
time_dict[name] = f"超时 (>{TIMEOUT_SECONDS}s)"
finally:
if name not in time_dict:
time_dict[name] = f"{time.time() - start:.3f}s"
async def auth(
matcher: Matcher,
event: Event,
bot: Bot,
session: Uninfo,
message: UniMsg,
):
"""权限检查
参数:
matcher: matcher
event: Event
bot: bot
session: Uninfo
message: UniMsg
"""
start_time = time.time()
cost_gold = 0
ignore_flag = False
entity = get_entity_ids(session)
module = matcher.plugin_name or ""
# 用于记录各个 hook 的执行时间
hook_times = {}
hooks_time = 0 # 初始化 hooks_time 变量
try:
if not module:
raise PermissionExemption("Matcher插件名称不存在...")
# 获取插件和用户数据
plugin_user_start = time.time()
try:
plugin, user = await with_timeout(
get_plugin_and_user(module, entity.user_id), name="get_plugin_and_user"
)
hook_times["get_plugin_user"] = f"{time.time() - plugin_user_start:.3f}s"
except asyncio.TimeoutError:
logger.error(
f"获取插件和用户数据超时,模块: {module}",
LOGGER_COMMAND,
session=session,
)
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
# 获取插件费用
cost_start = time.time()
try:
cost_gold = await with_timeout(
get_plugin_cost(bot, user, plugin, session), name="get_plugin_cost"
)
hook_times["cost_gold"] = f"{time.time() - cost_start:.3f}s"
except asyncio.TimeoutError:
logger.error(
f"获取插件费用超时,模块: {module}", LOGGER_COMMAND, session=session
)
# 继续执行,不阻止权限检查
# 执行 bot_filter
bot_filter(session)
# 并行执行所有 hook 检查,并记录执行时间
hooks_start = time.time()
# 创建所有 hook 任务
hook_tasks = [
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
]
# 使用 gather 并行执行所有 hook但添加总体超时控制
try:
await with_timeout(
asyncio.gather(*hook_tasks),
timeout=TIMEOUT_SECONDS * 2, # 给总体执行更多时间
name="auth_hooks_gather",
)
except asyncio.TimeoutError:
logger.error(
f"权限检查 hooks 总体执行超时,模块: {module}",
LOGGER_COMMAND,
session=session,
)
# 不抛出异常,允许继续执行
hooks_time = time.time() - hooks_start
except SkipPluginException as e:
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
logger.info(str(e), LOGGER_COMMAND, session=session)
ignore_flag = True
except IsSuperuserException:
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
except PermissionExemption as e:
logger.info(str(e), LOGGER_COMMAND, session=session)
# 扣除金币
if not ignore_flag and cost_gold > 0:
gold_start = time.time()
try:
await with_timeout(
reduce_gold(entity.user_id, module, cost_gold, session),
name="reduce_gold",
)
hook_times["reduce_gold"] = f"{time.time() - gold_start:.3f}s"
except asyncio.TimeoutError:
logger.error(
f"扣除金币超时,模块: {module}", LOGGER_COMMAND, session=session
)
# 记录总执行时间
total_time = time.time() - start_time
if total_time > WARNING_THRESHOLD: # 如果总时间超过500ms记录详细信息
logger.warning(
f"权限检查耗时过长: {total_time:.3f}s, 模块: {module}, "
f"hooks时间: {hooks_time:.3f}s, "
f"详情: {hook_times}",
LOGGER_COMMAND,
session=session,
)
if ignore_flag:
raise IgnoredException("权限检测 ignore")

View File

@ -1,41 +1,43 @@
from nonebot.adapters.onebot.v11 import Bot, Event
import time
from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher
from nonebot.message import run_postprocessor, run_preprocessor
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from ._auth_checker import LimitManage, checker
from zhenxun.services.log import logger
from .auth.config import LOGGER_COMMAND
from .auth_checker import LimitManager, auth
# # 权限检测
@run_preprocessor
async def _(
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg
):
await checker.auth(
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
start_time = time.time()
await auth(
matcher,
event,
bot,
session,
message,
)
logger.debug(f"权限检测耗时:{time.time() - start_time}", LOGGER_COMMAND)
# 解除命令block阻塞
@run_postprocessor
async def _(
matcher: Matcher,
exception: Exception | None,
bot: Bot,
event: Event,
session: EventSession,
):
user_id = session.id1
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
async def _(matcher: Matcher, session: Uninfo):
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
if user_id and matcher.plugin:
module = matcher.plugin.name
LimitManage.unblock(module, user_id, group_id, channel_id)
LimitManager.unblock(module, user_id, group_id, channel_id)

View File

@ -1,84 +0,0 @@
from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot.message import run_preprocessor
from nonebot.typing import T_State
from nonebot_plugin_alconna import At
from nonebot_plugin_session import EventSession
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import FreqLimiter
Config.add_plugin_config(
"hook",
"BAN_RESULT",
"才不会给你发消息.",
help="对被ban用户发送的消息",
)
_flmt = FreqLimiter(300)
# 检查是否被ban
@run_preprocessor
async def _(
matcher: Matcher, bot: Bot, event: Event, state: T_State, session: EventSession
):
extra = {}
if plugin := matcher.plugin:
if metadata := plugin.metadata:
extra = metadata.extra
if extra.get("plugin_type") in [PluginType.HIDDEN]:
return
user_id = session.id1
group_id = session.id3 or session.id2
if group_id:
if user_id in bot.config.superusers:
return
if await BanConsole.is_ban(None, group_id):
logger.debug("群组处于黑名单中...", "ban_hook")
raise IgnoredException("群组处于黑名单中...")
if g := await GroupConsole.get_group(group_id):
if g.level < 0:
logger.debug("群黑名单, 群权限-1...", "ban_hook")
raise IgnoredException("群黑名单, 群权限-1..")
if user_id:
ban_result = Config.get_config("hook", "BAN_RESULT")
if user_id in bot.config.superusers:
return
if await BanConsole.is_ban(user_id, group_id):
time = await BanConsole.check_ban_time(user_id, group_id)
if time == -1:
time_str = ""
else:
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
if (
not extra.get("ignore_prompt")
and time != -1
and ban_result
and _flmt.check(user_id)
):
_flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "ban_hook")
raise IgnoredException("用户处于黑名单中...")

View File

@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType
from zhenxun.utils.manager.message_manager import MessageManager
from zhenxun.utils.platform import PlatformUtils
LOG_COMMAND = "MessageHook"
def replace_message(message: Message) -> str:
"""将消息中的at、image、record、face替换为字符串
@ -54,11 +56,11 @@ async def handle_api_result(
if user_id and message_id:
MessageManager.add(str(user_id), str(message_id))
logger.debug(
f"收集消息iduser_id: {user_id}, msg_id: {message_id}", "msg_hook"
f"收集消息iduser_id: {user_id}, msg_id: {message_id}", LOG_COMMAND
)
except Exception as e:
logger.warning(
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e
)
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
return
@ -80,6 +82,6 @@ async def handle_api_result(
except Exception as e:
logger.warning(
f"消息发送记录发生错误...data: {data}, result: {result}",
"msg_hook",
LOG_COMMAND,
e=e,
)

View File

@ -4,15 +4,27 @@ import nonebot
from nonebot.adapters import Bot
from zhenxun.models.group_console import GroupConsole
from zhenxun.services.cache import CacheException
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
try:
from .__init_cache import register_cache_types
except CacheException as e:
raise SystemError(f"ERROR: {e}") from e
driver = nonebot.get_driver()
@PriorityLifecycle.on_startup(priority=5)
async def _():
register_cache_types()
logger.info("缓存类型注册完成")
@driver.on_bot_connect
async def _(bot: Bot):
"""将bot已存在的群组添加群认证

View File

@ -0,0 +1,35 @@
"""
缓存初始化模块
负责注册各种缓存类型实现按需缓存机制
"""
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheRegistry, cache_config
from zhenxun.services.cache.config import CacheMode
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
# 注册缓存类型
def register_cache_types():
"""注册所有缓存类型"""
CacheRegistry.register(CacheType.PLUGINS, PluginInfo)
CacheRegistry.register(CacheType.GROUPS, GroupConsole)
CacheRegistry.register(CacheType.BOT, BotConsole)
CacheRegistry.register(CacheType.USERS, UserConsole)
CacheRegistry.register(
CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
)
CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}")
if cache_config.cache_mode == CacheMode.NONE:
logger.info("缓存功能已禁用,将直接从数据库获取数据")
else:
logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}")
logger.info("使用增量缓存模式,数据将按需加载到缓存中")

View File

@ -1,3 +1,5 @@
import asyncio
import aiofiles
import nonebot
from nonebot import get_loaded_plugins
@ -112,24 +114,29 @@ async def _():
await _handle_setting(plugin, plugin_list, limit_list)
create_list = []
update_list = []
update_task_list = []
for plugin in plugin_list:
if plugin.module_path not in module2id:
create_list.append(plugin)
else:
plugin.id = module2id[plugin.module_path]
await plugin.save(
update_fields=[
"name",
"author",
"version",
"admin_level",
"plugin_type",
"is_show",
]
update_task_list.append(
plugin.save(
update_fields=[
"name",
"author",
"version",
"admin_level",
"plugin_type",
"is_show",
]
)
)
update_list.append(plugin)
if create_list:
await PluginInfo.bulk_create(create_list, 10)
if update_task_list:
await asyncio.gather(*update_task_list)
# if update_list:
# # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
# column "superuser" does not exist

View File

@ -205,7 +205,7 @@ class Manager:
self.cd_data: dict[str, PluginCdBlock] = {}
if self.cd_file.exists():
with open(self.cd_file, encoding="utf8") as f:
temp = _yaml.load(f)
temp = _yaml.load(f) or {}
if "PluginCdLimit" in temp.keys():
for k, v in temp["PluginCdLimit"].items():
if "." in k:
@ -216,7 +216,7 @@ class Manager:
self.block_data: dict[str, BaseBlock] = {}
if self.block_file.exists():
with open(self.block_file, encoding="utf8") as f:
temp = _yaml.load(f)
temp = _yaml.load(f) or {}
if "PluginBlockLimit" in temp.keys():
for k, v in temp["PluginBlockLimit"].items():
if "." in k:
@ -227,7 +227,7 @@ class Manager:
self.count_data: dict[str, PluginCountBlock] = {}
if self.count_file.exists():
with open(self.count_file, encoding="utf8") as f:
temp = _yaml.load(f)
temp = _yaml.load(f) or {}
if "PluginCountLimit" in temp.keys():
for k, v in temp["PluginCountLimit"].items():
if "." in k:

View File

@ -55,15 +55,17 @@ class GroupManager:
if plugin_list := await PluginInfo.filter(default_status=False).all():
for plugin in plugin_list:
block_plugin += f"<{plugin.module},"
group_info = await bot.get_group_info(group_id=group_id, no_cache=True)
await GroupConsole.create(
group_info = await bot.get_group_info(group_id=group_id)
await GroupConsole.update_or_create(
group_id=group_info["group_id"],
group_name=group_info["group_name"],
max_member_count=group_info["max_member_count"],
member_count=group_info["member_count"],
group_flag=1,
block_plugin=block_plugin,
platform="qq",
defaults={
"group_name": group_info["group_name"],
"max_member_count": group_info["max_member_count"],
"member_count": group_info["member_count"],
"group_flag": 1,
"block_plugin": block_plugin,
"platform": "qq",
},
)
@classmethod
@ -145,7 +147,7 @@ class GroupManager:
e=e,
)
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
await GroupConsole.filter(group_id=group_id).delete()
# await GroupConsole.filter(group_id=group_id).delete()
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
else:
await cls.__handle_add_group(bot, group_id, group)

View File

@ -1,4 +1,4 @@
from nonebot.message import run_preprocessor
from nonebot import on_message
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.friend_user import FriendUser
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
from zhenxun.utils.platform import PlatformUtils
@run_preprocessor
async def do_something(session: Uninfo):
def rule(session: Uninfo) -> bool:
return PlatformUtils.is_qbot(session)
_matcher = on_message(priority=999, block=False, rule=rule)
@_matcher.handle()
async def _(session: Uninfo):
platform = PlatformUtils.get_platform(session)
if session.group:
if not await GroupConsole.exists(group_id=session.group.id):
await GroupConsole.create(group_id=session.group.id)
logger.info("添加当前群组ID信息" "", session=session)
if not await GroupInfoUser.exists(
user_id=session.user.id, group_id=session.group.id
):
await GroupInfoUser.create(
user_id=session.user.id, group_id=session.group.id, platform=platform
)
logger.info("添加当前用户群组ID信息", "", session=session)
logger.info("添加当前群组ID信息", session=session)
await GroupInfoUser.update_or_create(
user_id=session.user.id,
group_id=session.group.id,
platform=PlatformUtils.get_platform(session),
)
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
try:
await FriendUser.create(user_id=session.user.id, platform=platform)
logger.info("添加当前好友用户信息", "", session=session)
except Exception as e:
logger.error("添加当前好友用户信息失败", session=session, e=e)
await FriendUser.create(
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
)
logger.info("添加当前好友用户信息", "", session=session)

View File

@ -1,30 +0,0 @@
from zhenxun.models.group_console import GroupConsole
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
@PriorityLifecycle.on_startup(priority=5)
async def _():
"""开启/禁用插件格式修改"""
_, is_create = await GroupConsole.get_or_create(group_id=133133133)
"""标记"""
if is_create:
data_list = []
for group in await GroupConsole.all():
if group.block_plugin:
if modules := group.block_plugin.split(","):
block_plugin = "".join(
(f"{module}," if module.startswith("<") else f"<{module},")
for module in modules
if module.strip()
)
group.block_plugin = block_plugin.replace("<,", "")
if group.block_task:
if modules := group.block_task.split(","):
block_task = "".join(
(f"{module}," if module.startswith("<") else f"<{module},")
for module in modules
if module.strip()
)
group.block_task = block_task.replace("<,", "")
data_list.append(group)
await GroupConsole.bulk_update(data_list, ["block_plugin", "block_task"], 10)

View File

@ -44,9 +44,7 @@ class StatisticsManage:
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
elif group_id:
"""查群组"""
group = await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
)
group = await GroupConsole.get_group(group_id=group_id)
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
else:
title = "功能调用统计"

View File

@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
async def _(session: EventSession, arparma: Arparma, state: T_State):
gid = state["group_id"]
group = await GroupConsole.get_or_none(group_id=gid)
group = await GroupConsole.get_group(group_id=gid)
if not group:
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
s = "删除" if arparma.find("delete") else "添加"
@ -177,7 +177,9 @@ async def _(session: EventSession, arparma: Arparma, state: T_State):
async def _(session: EventSession, arparma: Arparma, state: T_State):
gid = state["group_id"]
await GroupConsole.update_or_create(
group_id=gid, defaults={"group_flag": 0 if arparma.find("delete") else 1}
group_id=gid,
channel_id__isnull=True,
defaults={"group_flag": 0 if arparma.find("delete") else 1},
)
s = "删除" if arparma.find("delete") else "添加"
await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)

View File

@ -119,7 +119,7 @@ class ApiDataSource:
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
)
except Exception as e:
logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e)
logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e)
select_bot.group_count = 0
select_bot.friend_count = 0
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)

View File

@ -250,7 +250,7 @@ class ApiDataSource:
返回:
GroupDetail | None: 群组详情数据
"""
group = await GroupConsole.get_or_none(group_id=group_id)
group = await GroupConsole.get_group(group_id=group_id)
if not group:
return None
like_plugin = await cls.__get_group_detail_like_plugin(group_id)

View File

@ -45,6 +45,7 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
mtime=file_path.stat().st_mtime,
)
)
data_list.sort(key=lambda f: f.name)
return Result.ok(data_list)
except Exception as e:
return Result.fail(f"获取文件列表失败: {e!s}")

View File

@ -13,8 +13,8 @@ class BotSetting(BaseModel):
"""回复时NICKNAME"""
system_proxy: str | None = None
"""系统代理"""
db_url: str = ""
"""数据库链接"""
db_url: str = "sqlite:data/zhenxun.db"
"""数据库链接, 默认值为sqlite:data/zhenxun.db"""
platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
"""平台超级用户"""
qbot_id_data: dict[str, str] = Field(default_factory=dict)

View File

@ -155,8 +155,6 @@ class AICallableProperties(BaseModel):
"""参数类型"""
description: str
"""参数描述"""
enums: list[str] | None = None
"""参数枚举"""
class AICallableParam(BaseModel):

View File

@ -1,10 +1,12 @@
import time
from typing import ClassVar
from typing_extensions import Self
from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, DbLockType
from zhenxun.utils.exception import UserAndGroupIsNone
@ -27,6 +29,15 @@ class BanConsole(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "ban_console"
table_description = "封禁人员/群组数据表"
unique_together = ("user_id", "group_id")
indexes = [("user_id",), ("group_id",)] # noqa: RUF012
cache_type = CacheType.BAN
"""缓存类型"""
cache_key_field = ("user_id", "group_id")
"""缓存键字段"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
"""开启锁"""
@classmethod
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
@ -46,12 +57,12 @@ class BanConsole(Model):
raise UserAndGroupIsNone()
if user_id:
return (
await cls.get_or_none(user_id=user_id, group_id=group_id)
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
if group_id
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
)
else:
return await cls.get_or_none(user_id="", group_id=group_id)
return await cls.safe_get_or_none(user_id="", group_id=group_id)
@classmethod
async def check_ban_level(
@ -167,3 +178,32 @@ class BanConsole(Model):
await user.delete()
return True
return False
@classmethod
async def get_ban(
cls,
*,
id: int | None = None,
user_id: str | None = None,
group_id: str | None = None,
) -> Self | None:
"""安全地获取ban记录
参数:
id: 记录id
user_id: 用户id
group_id: 群组id
返回:
Self | None: ban记录
"""
if id is not None:
return await cls.safe_get_or_none(id=id)
return await cls._get_data(user_id, group_id)
@classmethod
async def _run_script(cls):
return [
"CREATE INDEX idx_ban_console_user_id ON ban_console(user_id);",
"CREATE INDEX idx_ban_console_group_id ON ban_console(group_id);",
]
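
An illustrative call pattern (not from the diff) for the new keyword-only `get_ban` accessor; the ids are placeholders.

```python
# Placeholder ids: the keyword-only signature supports lookup by primary key,
# by (user, group), or by user alone (global ban).
async def inspect_bans() -> None:
    by_id = await BanConsole.get_ban(id=1)
    in_group = await BanConsole.get_ban(user_id="123456", group_id="654321")
    global_ban = await BanConsole.get_ban(user_id="123456")
    print(by_id, in_group, global_ban)
```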

View File

@ -3,6 +3,7 @@ from typing import Literal, overload
from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType
class BotConsole(Model):
@ -29,6 +30,11 @@ class BotConsole(Model):
table = "bot_console"
table_description = "Bot数据表"
cache_type = CacheType.BOT
"""缓存类型"""
cache_key_field = "bot_id"
"""缓存键字段"""
@staticmethod
def format(name: str) -> str:
return f"<{name},"

View File

@ -1,4 +1,4 @@
from typing import Any, cast, overload
from typing import Any, ClassVar, cast, overload
from typing_extensions import Self
from tortoise import fields
@ -6,8 +6,9 @@ from tortoise.backends.base.client import BaseDBAsyncClient
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.cache import CacheRoot
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import PluginType
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
def add_disable_marker(name: str) -> str:
@ -86,6 +87,16 @@ class GroupConsole(Model):
table = "group_console"
table_description = "群组信息表"
unique_together = ("group_id", "channel_id")
indexes = [ # noqa: RUF012
("group_id",)
]
cache_type = CacheType.GROUPS
"""缓存类型"""
cache_key_field = ("group_id", "channel_id")
"""缓存键字段"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
"""开启锁"""
@classmethod
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
@ -116,6 +127,18 @@ class GroupConsole(Model):
).values_list("module", flat=True),
)
@classmethod
async def _update_cache(cls, instance):
"""更新缓存
参数:
instance: 需要更新缓存的实例
"""
if cache_type := cls.get_cache_type():
key = cls.get_cache_key(instance)
if key is not None:
await CacheRoot.invalidate_cache(cache_type, key)
@classmethod
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
@ -129,6 +152,9 @@ class GroupConsole(Model):
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
# 更新缓存
await cls._update_cache(group)
return group
@classmethod
@ -180,6 +206,10 @@ class GroupConsole(Model):
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
# 更新缓存
if is_create:
await cls._update_cache(group)
return group, is_create
@classmethod
@ -202,24 +232,39 @@ class GroupConsole(Model):
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
# 更新缓存
await cls._update_cache(group)
return group, is_create
@classmethod
async def get_group(
cls, group_id: str, channel_id: str | None = None
cls,
group_id: str,
channel_id: str | None = None,
clean_duplicates: bool = True,
) -> Self | None:
"""获取群组
参数:
group_id: 群组id
channel_id: 频道id.
channel_id: 频道id
clean_duplicates: 是否删除重复的记录仅保留最新的
返回:
Self: GroupConsole
"""
if channel_id:
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
return await cls.safe_get_or_none(
group_id=group_id,
channel_id=channel_id,
clean_duplicates=clean_duplicates,
)
return await cls.safe_get_or_none(
group_id=group_id,
channel_id__isnull=True,
clean_duplicates=clean_duplicates,
)
@classmethod
async def is_super_group(cls, group_id: str) -> bool:
@ -303,6 +348,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
async def set_unblock_plugin(
cls,
@ -339,6 +387,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
async def is_normal_block_plugin(
cls, group_id: str, module: str, channel_id: str | None = None
@ -442,6 +493,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
async def set_unblock_task(
cls,
@ -476,6 +530,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
def _run_script(cls):
return [
@ -483,4 +540,6 @@ class GroupConsole(Model):
" character varying(255) NOT NULL DEFAULT '';",
"ALTER TABLE group_console ADD superuser_block_task"
" character varying(255) NOT NULL DEFAULT '';",
"CREATE INDEX idx_group_console_group_id ON group_console(group_id);",
"CREATE INDEX idx_group_console_group_null_channel ON group_console(group_id) WHERE channel_id IS NULL;", # 单独创建channel为空的索引 # noqa: E501
]

View File

@ -1,6 +1,7 @@
from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType
class LevelUser(Model):
@ -20,6 +21,11 @@ class LevelUser(Model):
table_description = "用户权限数据库"
unique_together = ("user_id", "group_id")
cache_type = CacheType.LEVEL
"""缓存类型"""
cache_key_field = ("user_id", "group_id")
"""缓存键字段"""
@classmethod
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
"""获取用户在群内的等级
@ -53,6 +59,9 @@ class LevelUser(Model):
level: 权限等级
group_flag: 是否被自动更新刷新权限 0:, 1:.
"""
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
# 权限相同时跳过
return
await cls.update_or_create(
user_id=user_id,
group_id=group_id,

View File

@ -4,7 +4,7 @@ from tortoise import fields
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.utils.enum import BlockType, CacheType, PluginType
class PluginInfo(Model):
@ -59,6 +59,11 @@ class PluginInfo(Model):
table = "plugin_info"
table_description = "插件基本信息"
cache_type = CacheType.PLUGINS
"""缓存类型"""
cache_key_field = "module"
"""缓存键字段"""
@classmethod
async def get_plugin(
cls, load_status: bool = True, filter_parent: bool = True, **kwargs

View File

@ -2,7 +2,7 @@ from tortoise import fields
from zhenxun.models.goods_info import GoodsInfo
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import GoldHandle
from zhenxun.utils.enum import CacheType, GoldHandle
from zhenxun.utils.exception import GoodsNotFound, InsufficientGold
from .user_gold_log import UserGoldLog
@ -29,6 +29,12 @@ class UserConsole(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "user_console"
table_description = "用户数据表"
indexes = [("user_id",), ("uid",)] # noqa: RUF012
cache_type = CacheType.USERS
"""缓存类型"""
cache_key_field = "user_id"
"""缓存键字段"""
@classmethod
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
@ -193,3 +199,10 @@ class UserConsole(Model):
if goods := await GoodsInfo.get_or_none(goods_name=name):
return await cls.use_props(user_id, goods.uuid, num, platform)
raise GoodsNotFound("未找到商品...")
@classmethod
async def _run_script(cls):
return [
"CREATE INDEX idx_user_console_user_id ON user_console(user_id);",
"CREATE INDEX idx_user_console_uid ON user_console(uid);",
]

zhenxun/services/cache/__init__.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,452 @@
from dataclasses import dataclass
import time
from typing import Any, Generic, TypeVar
T = TypeVar("T")
@dataclass
class CacheData(Generic[T]):
"""缓存数据类,存储数据和过期时间"""
value: T
expire_time: float = 0 # 0表示永不过期
class CacheDict:
"""缓存字典类,提供类似普通字典的接口,数据只存储在内存中"""
def __init__(self, name: str, expire: int = 0):
"""初始化缓存字典
参数:
name: 字典名称
expire: 过期时间默认为0表示永不过期
"""
self.name = name.upper()
self.expire = expire
self._data: dict[str, CacheData[Any]] = {}
def __getitem__(self, key: str) -> Any:
"""获取字典项
参数:
key: 字典键
返回:
Any: 字典值
"""
data = self._data.get(key)
if data is None:
return None
# 检查是否过期
if data.expire_time > 0 and data.expire_time < time.time():
del self._data[key]
return None
return data.value
def __setitem__(self, key: str, value: Any) -> None:
"""设置字典项
参数:
key: 字典键
value: 字典值
"""
# 计算过期时间
expire_time = 0
if self.expire > 0:
expire_time = time.time() + self.expire
self._data[key] = CacheData(value=value, expire_time=expire_time)
def __delitem__(self, key: str) -> None:
"""删除字典项
参数:
key: 字典键
"""
if key in self._data:
del self._data[key]
def __contains__(self, key: str) -> bool:
"""检查键是否存在
参数:
key: 字典键
返回:
bool: 是否存在
"""
if key not in self._data:
return False
# 检查是否过期
data = self._data[key]
if data.expire_time > 0 and data.expire_time < time.time():
del self._data[key]
return False
return True
def get(self, key: str, default: Any = None) -> Any:
"""获取字典项,如果不存在返回默认值
参数:
key: 字典键
default: 默认值
返回:
Any: 字典值或默认值
"""
value = self[key]
return default if value is None else value
def set(self, key: str, value: Any, expire: int | None = None) -> None:
"""设置字典项
参数:
key: 字典键
value: 字典值
expire: 过期时间为None时使用默认值
"""
# 计算过期时间
expire_time = 0
if expire is not None and expire > 0:
expire_time = time.time() + expire
elif self.expire > 0:
expire_time = time.time() + self.expire
self._data[key] = CacheData(value=value, expire_time=expire_time)
def pop(self, key: str, default: Any = None) -> Any:
"""删除并返回字典项
参数:
key: 字典键
default: 默认值
返回:
Any: 字典值或默认值
"""
if key not in self._data:
return default
data = self._data.pop(key)
# 检查是否过期
if data.expire_time > 0 and data.expire_time < time.time():
return default
return data.value
def clear(self) -> None:
"""清空字典"""
self._data.clear()
def keys(self) -> list[str]:
"""获取所有键
返回:
list[str]: 键列表
"""
# 清理过期的键
self._clean_expired()
return list(self._data.keys())
def values(self) -> list[Any]:
"""获取所有值
返回:
list[Any]: 值列表
"""
# 清理过期的键
self._clean_expired()
return [data.value for data in self._data.values()]
def items(self) -> list[tuple[str, Any]]:
"""获取所有键值对
返回:
list[tuple[str, Any]]: 键值对列表
"""
# 清理过期的键
self._clean_expired()
return [(key, data.value) for key, data in self._data.items()]
def _clean_expired(self) -> None:
"""清理过期的键"""
now = time.time()
expired_keys = [
key
for key, data in self._data.items()
if data.expire_time > 0 and data.expire_time < now
]
for key in expired_keys:
del self._data[key]
def __len__(self) -> int:
"""获取字典长度
返回:
int: 字典长度
"""
# 清理过期的键
self._clean_expired()
return len(self._data)
def __str__(self) -> str:
"""字符串表示
返回:
str: 字符串表示
"""
# 清理过期的键
self._clean_expired()
return f"CacheDict({self.name}, {len(self._data)} items)"
class CacheList:
"""缓存列表类,提供类似普通列表的接口,数据只存储在内存中"""
def __init__(self, name: str, expire: int = 0):
"""初始化缓存列表
参数:
name: 列表名称
expire: 过期时间默认为0表示永不过期
"""
self.name = name.upper()
self.expire = expire
self._data: list[CacheData[Any]] = []
self._expire_time = 0
# 如果设置了过期时间,计算整个列表的过期时间
if self.expire > 0:
self._expire_time = time.time() + self.expire
def __getitem__(self, index: int) -> Any:
"""获取列表项
参数:
index: 列表索引
返回:
Any: 列表值
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
raise IndexError(f"列表索引 {index} 超出范围")
if 0 <= index < len(self._data):
return self._data[index].value
raise IndexError(f"列表索引 {index} 超出范围")
def __setitem__(self, index: int, value: Any) -> None:
"""设置列表项
参数:
index: 列表索引
value: 列表值
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
# 确保索引有效
while len(self._data) <= index:
self._data.append(CacheData(value=None))
self._data[index] = CacheData(value=value)
# 更新过期时间
self._update_expire_time()
def __delitem__(self, index: int) -> None:
"""删除列表项
参数:
index: 列表索引
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
raise IndexError(f"列表索引 {index} 超出范围")
if 0 <= index < len(self._data):
del self._data[index]
# 更新过期时间
self._update_expire_time()
else:
raise IndexError(f"列表索引 {index} 超出范围")
def __len__(self) -> int:
"""获取列表长度
返回:
int: 列表长度
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
return len(self._data)
def append(self, value: Any) -> None:
"""添加列表项
参数:
value: 列表值
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
self._data.append(CacheData(value=value))
# 更新过期时间
self._update_expire_time()
def extend(self, values: list[Any]) -> None:
"""扩展列表
参数:
values: 要添加的值列表
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
self._data.extend([CacheData(value=v) for v in values])
# 更新过期时间
self._update_expire_time()
def insert(self, index: int, value: Any) -> None:
"""插入列表项
参数:
index: 插入位置
value: 列表值
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
self._data.insert(index, CacheData(value=value))
# 更新过期时间
self._update_expire_time()
def pop(self, index: int = -1) -> Any:
"""删除并返回列表项
参数:
index: 列表索引默认为最后一项
返回:
Any: 列表值
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
raise IndexError("从空列表中弹出")
if not self._data:
raise IndexError("从空列表中弹出")
item = self._data.pop(index)
# 更新过期时间
self._update_expire_time()
return item.value
def remove(self, value: Any) -> None:
"""删除第一个匹配的列表项
参数:
value: 要删除的值
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
raise ValueError(f"{value} 不在列表中")
# 查找匹配的项
for i, item in enumerate(self._data):
if item.value == value:
del self._data[i]
# 更新过期时间
self._update_expire_time()
return
raise ValueError(f"{value} 不在列表中")
def clear(self) -> None:
"""清空列表"""
self._data.clear()
# 重置过期时间
self._update_expire_time()
def index(self, value: Any, start: int = 0, end: int | None = None) -> int:
"""查找值的索引
参数:
value: 要查找的值
start: 起始索引
end: 结束索引
返回:
int: 索引位置
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
raise ValueError(f"{value} 不在列表中")
end = end if end is not None else len(self._data)
for i in range(start, min(end, len(self._data))):
if self._data[i].value == value:
return i
raise ValueError(f"{value} 不在列表中")
def count(self, value: Any) -> int:
"""计算值出现的次数
参数:
value: 要计数的值
返回:
int: 出现次数
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
return 0
return sum(1 for item in self._data if item.value == value)
def _is_expired(self) -> bool:
"""检查整个列表是否过期"""
return self._expire_time > 0 and self._expire_time < time.time()
def _update_expire_time(self) -> None:
"""更新过期时间"""
if self.expire > 0:
self._expire_time = time.time() + self.expire
else:
self._expire_time = 0
def __str__(self) -> str:
"""字符串表示
返回:
str: 字符串表示
"""
# 检查整个列表是否过期
if self._is_expired():
self.clear()
return f"CacheList({self.name}, {len(self._data)} items)"

zhenxun/services/cache/config.py
View File

@ -0,0 +1,35 @@
"""
缓存系统配置
"""
# 日志标识
LOG_COMMAND = "CacheRoot"
# 默认缓存过期时间(秒)
DEFAULT_EXPIRE = 600
# 缓存键前缀
CACHE_KEY_PREFIX = "ZHENXUN"
# 缓存键分隔符
CACHE_KEY_SEPARATOR = ":"
# 复合键分隔符用于分隔tuple类型的cache_key_field
COMPOSITE_KEY_SEPARATOR = "_"
# 缓存模式
class CacheMode:
# 内存缓存 - 使用内存存储缓存数据
MEMORY = "MEMORY"
# Redis缓存 - 使用Redis服务器存储缓存数据
REDIS = "REDIS"
# 不使用缓存 - 将使用ttl=0的内存缓存相当于直接从数据库获取数据
NONE = "NONE"
SPECIAL_KEY_FORMATS = {
"LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
"BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
"GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}",
}
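
A hedged illustration of how these constants presumably fit together; the code that actually assembles keys lives in the suppressed cache/__init__.py, so the exact join shown here is an assumption, not the committed implementation.

```python
# Assumption: final key = PREFIX : TYPE : formatted-id, with composite ids
# joined by COMPOSITE_KEY_SEPARATOR via the SPECIAL_KEY_FORMATS templates.
def build_key(cache_type: str, **fields) -> str:
    fmt = SPECIAL_KEY_FORMATS.get(cache_type, "{id}")
    return CACHE_KEY_SEPARATOR.join(
        [CACHE_KEY_PREFIX, cache_type, fmt.format(**fields)]
    )

# build_key("BAN", user_id="123", group_id="456") -> "ZHENXUN:BAN:123_456"
```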

View File

@ -0,0 +1,653 @@
from typing import Any, ClassVar, Generic, TypeVar, cast
from zhenxun.services.cache import Cache, CacheRoot, cache_config
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR, CacheMode
from zhenxun.services.db_context import Model, with_db_timeout
from zhenxun.services.log import logger
T = TypeVar("T", bound=Model)
class DataAccess(Generic[T]):
"""数据访问层,根据配置决定是否使用缓存
使用示例:
```python
from zhenxun.services import DataAccess
from zhenxun.models.plugin_info import PluginInfo
# 创建数据访问对象
plugin_dao = DataAccess(PluginInfo)
# 获取单个数据
plugin = await plugin_dao.get(module="example_module")
# 获取所有数据
all_plugins = await plugin_dao.all()
# 筛选数据
enabled_plugins = await plugin_dao.filter(status=True)
# 创建数据
new_plugin = await plugin_dao.create(
module="new_module",
name="新插件",
status=True
)
```
"""
# 添加缓存统计信息
_cache_stats: ClassVar[dict] = {}
# 空结果标记
_NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"
# 默认空结果缓存时间(秒)- 设置为5分钟避免频繁查询数据库
_NULL_RESULT_TTL = 300
@classmethod
def set_null_result_ttl(cls, seconds: int) -> None:
"""设置空结果缓存时间
参数:
seconds: 缓存时间
"""
if seconds < 0:
raise ValueError("缓存时间不能为负数")
cls._NULL_RESULT_TTL = seconds
logger.info(f"已设置DataAccess空结果缓存时间为 {seconds}")
@classmethod
def get_null_result_ttl(cls) -> int:
"""获取空结果缓存时间
返回:
int: 缓存时间
"""
return cls._NULL_RESULT_TTL
def __init__(
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
):
"""初始化数据访问对象
参数:
model_cls: 模型类
key_field: 主键字段
cache_type: 缓存类型,默认为模型类上声明的 cache_type
"""
self.model_cls = model_cls
self.key_field = getattr(model_cls, "cache_key_field", key_field)
self.cache_type = getattr(model_cls, "cache_type", cache_type)
if not self.cache_type:
raise ValueError("缓存类型不能为空")
self.cache = Cache(self.cache_type)
# 初始化缓存统计
if self.cache_type not in self._cache_stats:
self._cache_stats[self.cache_type] = {
"hits": 0, # 缓存命中次数
"misses": 0, # 缓存未命中次数
"null_hits": 0, # 空结果缓存命中次数
"sets": 0, # 缓存设置次数
"null_sets": 0, # 空结果缓存设置次数
"deletes": 0, # 缓存删除次数
}
@classmethod
def get_cache_stats(cls):
"""获取缓存统计信息"""
result = []
for cache_type, stats in cls._cache_stats.items():
hits = stats["hits"]
null_hits = stats.get("null_hits", 0)
misses = stats["misses"]
total = hits + null_hits + misses
hit_rate = ((hits + null_hits) / total * 100) if total > 0 else 0
result.append(
{
"cache_type": cache_type,
"hits": hits,
"null_hits": null_hits,
"misses": misses,
"sets": stats["sets"],
"null_sets": stats.get("null_sets", 0),
"deletes": stats["deletes"],
"hit_rate": f"{hit_rate:.2f}%",
}
)
return result
@classmethod
def reset_cache_stats(cls):
"""重置缓存统计信息"""
for stats in cls._cache_stats.values():
stats["hits"] = 0
stats["null_hits"] = 0
stats["misses"] = 0
stats["sets"] = 0
stats["null_sets"] = 0
stats["deletes"] = 0
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
"""从关键字参数构建缓存键
参数:
**kwargs: 关键字参数
返回:
str | None: 缓存键如果无法构建则返回None
"""
if isinstance(self.key_field, tuple):
# 多字段主键
key_parts = []
key_parts.extend(str(kwargs.get(field, "")) for field in self.key_field)
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif self.key_field in kwargs:
# 单字段主键
return str(kwargs[self.key_field])
return None
async def safe_get_or_none(self, *args, **kwargs) -> T | None:
"""安全的获取单条数据
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
Optional[T]: 查询结果如果不存在返回None
"""
# 如果没有缓存类型,直接从数据库获取
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
return await with_db_timeout(
self.model_cls.safe_get_or_none(*args, **kwargs),
operation=f"{self.model_cls.__name__}.safe_get_or_none",
)
# 尝试从缓存获取
cache_key = None
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
# 如果成功构建缓存键,尝试从缓存获取
if cache_key is not None:
data = await self.cache.get(cache_key)
logger.debug(
f"{self.model_cls.__name__} self.cache.get(cache_key)"
f" 从缓存获取到的数据 {type(data)}: {data}"
)
if data == self._NULL_RESULT:
# 空结果缓存命中
self._cache_stats[self.cache_type]["null_hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
)
return None
elif data:
# 缓存命中
self._cache_stats[self.cache_type]["hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
)
return cast(T, data)
else:
# 缓存未命中
self._cache_stats[self.cache_type]["misses"] += 1
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
except Exception as e:
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
# 如果缓存中没有,从数据库获取
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
# 如果获取到数据,存入缓存
if data:
try:
# 生成缓存键
cache_key = self._build_cache_key_for_item(data)
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
)
elif cache_key is not None:
# 如果没有获取到数据,缓存空结果
try:
# 存入空结果缓存,使用较短的过期时间
await self.cache.set(
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
)
self._cache_stats[self.cache_type]["null_sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
f" TTL={self._NULL_RESULT_TTL}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
)
return data
async def get_or_none(self, *args, **kwargs) -> T | None:
"""获取单条数据
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
Optional[T]: 查询结果如果不存在返回None
"""
# 如果没有缓存类型,直接从数据库获取
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
return await with_db_timeout(
self.model_cls.get_or_none(*args, **kwargs),
operation=f"{self.model_cls.__name__}.get_or_none",
)
# 尝试从缓存获取
cache_key = None
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
# 如果成功构建缓存键,尝试从缓存获取
if cache_key is not None:
data = await self.cache.get(cache_key)
if data == self._NULL_RESULT:
# 空结果缓存命中
self._cache_stats[self.cache_type]["null_hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
)
return None
elif data:
# 缓存命中
self._cache_stats[self.cache_type]["hits"] += 1
logger.debug(
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
)
return cast(T, data)
else:
# 缓存未命中
self._cache_stats[self.cache_type]["misses"] += 1
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
except Exception as e:
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
# 如果缓存中没有,从数据库获取
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
data = await self.model_cls.get_or_none(*args, **kwargs)
# 如果获取到数据,存入缓存
if data:
try:
cache_key = self._build_cache_key_for_item(data)
# 生成缓存键
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
)
elif cache_key is not None:
# 如果没有获取到数据,缓存空结果
try:
# 存入空结果缓存,使用较短的过期时间
await self.cache.set(
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
)
self._cache_stats[self.cache_type]["null_sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
f" TTL={self._NULL_RESULT_TTL}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
)
return data
async def clear_cache(self, **kwargs) -> bool:
"""只清除缓存,不影响数据库数据
参数:
**kwargs: 查询参数必须包含主键字段
返回:
bool: 是否成功清除缓存
"""
# 如果没有缓存类型直接返回True
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
return True
try:
# 构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
if cache_key is None:
if isinstance(self.key_field, tuple):
# 如果是复合键,检查缺少哪些字段
missing_fields = [
field for field in self.key_field if field not in kwargs
]
logger.error(
f"清除{self.model_cls.__name__}缓存失败: "
f"缺少主键字段 {', '.join(missing_fields)}"
)
else:
logger.error(
f"清除{self.model_cls.__name__}缓存失败: "
f"缺少主键字段 {self.key_field}"
)
return False
# 删除缓存
await self.cache.delete(cache_key)
self._cache_stats[self.cache_type]["deletes"] += 1
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
return True
except Exception as e:
logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e)
return False
def _build_composite_key(self, data: T) -> str | None:
"""构建复合缓存键
参数:
data: 数据对象
返回:
str | None: 构建的缓存键如果无法构建则返回None
"""
# 如果是元组,表示多个字段组成键
if isinstance(self.key_field, tuple):
# 构建键参数列表
key_parts = []
for field in self.key_field:
value = getattr(data, field, "")
key_parts.append(value if value is not None else "")
# 如果没有有效参数返回None
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif hasattr(data, self.key_field):
value = getattr(data, self.key_field, None)
return str(value) if value is not None else None
return None
def _build_cache_key_for_item(self, item: T) -> str | None:
"""为数据项构建缓存键
参数:
item: 数据项
返回:
str | None: 缓存键如果无法生成则返回None
"""
# 如果没有缓存类型返回None
if not self.cache_type:
return None
# 获取缓存类型的配置信息
cache_model = CacheRoot.get_model(self.cache_type)
if not cache_model.key_format:
# 常规处理,使用主键作为缓存键
return self._build_composite_key(item)
# 构建键参数字典
key_parts = []
# 从格式字符串中提取所需的字段名
import re
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
# 收集所有字段值
for field in field_names:
value = getattr(item, field, "")
key_parts.append(value if value is not None else "")
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
async def _cache_items(self, data_list: list[T]) -> None:
"""将数据列表存入缓存
参数:
data_list: 数据列表
"""
if (
not data_list
or not self.cache_type
or cache_config.cache_mode == CacheMode.NONE
):
return
try:
# 遍历数据列表,将每条数据存入缓存
cached_count = 0
for item in data_list:
cache_key = self._build_cache_key_for_item(item)
if cache_key is not None:
await self.cache.set(cache_key, item)
cached_count += 1
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.model_cls.__name__} 批量缓存: {cached_count}/{len(data_list)}"
)
except Exception as e:
logger.error(f"{self.model_cls.__name__} 批量缓存失败", e=e)
async def filter(self, *args, **kwargs) -> list[T]:
"""筛选数据
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
List[T]: 查询结果列表
"""
# 从数据库获取数据
logger.debug(f"{self.model_cls.__name__} filter: 从数据库查询, 参数: {kwargs}")
data_list = await self.model_cls.filter(*args, **kwargs)
logger.debug(
f"{self.model_cls.__name__} filter: 查询结果数量: {len(data_list)}"
)
# 将数据存入缓存
await self._cache_items(data_list)
return data_list
async def all(self) -> list[T]:
"""获取所有数据
返回:
List[T]: 所有数据列表
"""
# 直接从数据库获取
logger.debug(f"{self.model_cls.__name__} all: 从数据库查询所有数据")
data_list = await self.model_cls.all()
logger.debug(f"{self.model_cls.__name__} all: 查询结果数量: {len(data_list)}")
# 将数据存入缓存
await self._cache_items(data_list)
return data_list
async def count(self, *args, **kwargs) -> int:
"""获取数据数量
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
int: 数据数量
"""
# 直接从数据库获取数量
return await self.model_cls.filter(*args, **kwargs).count()
async def exists(self, *args, **kwargs) -> bool:
"""判断数据是否存在
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
bool: 是否存在
"""
# 直接从数据库判断是否存在
return await self.model_cls.filter(*args, **kwargs).exists()
async def create(self, **kwargs) -> T:
"""创建数据
参数:
**kwargs: 创建参数
返回:
T: 创建的数据
"""
# 创建数据
logger.debug(f"{self.model_cls.__name__} create: 创建数据, 参数: {kwargs}")
data = await self.model_cls.create(**kwargs)
# 如果有缓存类型,将数据存入缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
# 生成缓存键
cache_key = self._build_cache_key_for_item(data)
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(
f"{self.model_cls.__name__} create: "
f"新创建的数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(
f"{self.model_cls.__name__} create: 存入缓存失败,参数: {kwargs}",
e=e,
)
return data
async def update_or_create(
self, defaults: dict[str, Any] | None = None, **kwargs
) -> tuple[T, bool]:
"""更新或创建数据
参数:
defaults: 默认值
**kwargs: 查询参数
返回:
tuple[T, bool]: (数据, 是否创建)
"""
# 更新或创建数据
data, created = await self.model_cls.update_or_create(
defaults=defaults, **kwargs
)
# 如果有缓存类型,将数据存入缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
# 生成缓存键
cache_key = self._build_cache_key_for_item(data)
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
self._cache_stats[self.cache_type]["sets"] += 1
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
except Exception as e:
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
return data, created
async def delete(self, *args, **kwargs) -> int:
"""删除数据
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
int: 删除的数据数量
"""
logger.debug(f"{self.model_cls.__name__} delete: 删除数据, 参数: {kwargs}")
# 如果有缓存类型且有key_field参数先尝试删除缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
if cache_key is not None:
# 如果成功构建缓存键,直接删除缓存
await self.cache.delete(cache_key)
self._cache_stats[self.cache_type]["deletes"] += 1
logger.debug(
f"{self.model_cls.__name__} delete: 已删除缓存: {cache_key}"
)
else:
# 否则需要先查询出要删除的数据,然后删除对应的缓存
items = await self.model_cls.filter(*args, **kwargs)
logger.debug(
f"{self.model_cls.__name__} delete:"
f" 查询到 {len(items)} 条要删除的数据"
)
for item in items:
item_cache_key = self._build_cache_key_for_item(item)
if item_cache_key is not None:
await self.cache.delete(item_cache_key)
self._cache_stats[self.cache_type]["deletes"] += 1
if items:
logger.debug(
f"{self.model_cls.__name__} delete:"
f" 已删除 {len(items)} 条数据的缓存"
)
except Exception as e:
logger.error(f"{self.model_cls.__name__} delete: 删除缓存失败", e=e)
# 删除数据
result = await self.model_cls.filter(*args, **kwargs).delete()
logger.debug(
f"{self.model_cls.__name__} delete: 已从数据库删除 {result} 条数据"
)
return result
def _generate_cache_key(self, data: T) -> str:
"""根据数据对象生成缓存键
参数:
data: 数据对象
返回:
str: 缓存键
"""
# 使用新方法构建复合键
if composite_key := self._build_composite_key(data):
return composite_key
# 如果无法生成复合键,生成一个唯一键
return f"object_{id(data)}"

View File

@ -1,37 +1,328 @@
import nonebot
import asyncio
from collections.abc import Iterable
import contextlib
import time
from typing import Any, ClassVar
from typing_extensions import Self
from urllib.parse import urlparse
from nonebot import get_driver
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.connection import connections
from tortoise.models import Model as Model_
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction
from zhenxun.configs.config import BotConfig
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .log import logger
driver = get_driver()
SCRIPT_METHOD = []
MODELS: list[str] = []
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
driver = nonebot.get_driver()
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
class Model(Model_):
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise
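# Editor's sketch (not part of this commit). with_db_timeout() only wraps an
# awaitable with asyncio.wait_for plus slow-query logging, so a call site looks
# like this; `model_cls` stands for any Tortoise model:
async def _demo_timed_count(model_cls) -> int:
    return await with_db_timeout(
        model_cls.all().count(),
        timeout=2.0,
        operation=f"{model_cls.__name__}.count",
    )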
class Model(TortoiseModel):
"""
自动添加模块
Args:
Model_: Model
增强的ORM基类解决锁嵌套问题
"""
sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
_current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__module__ not in MODELS:
MODELS.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func))
@classmethod
def get_cache_type(cls) -> str | None:
"""获取缓存类型"""
return getattr(cls, "cache_type", None)
@classmethod
def get_cache_key_field(cls) -> str | tuple[str]:
"""获取缓存键字段"""
return getattr(cls, "cache_key_field", "id")
@classmethod
def get_cache_key(cls, instance) -> str | None:
"""获取缓存键
参数:
instance: 模型实例
返回:
str | None: 缓存键,如果无法获取则返回None
"""
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR
key_field = cls.get_cache_key_field()
if isinstance(key_field, tuple):
# 多字段主键
key_parts = []
for field in key_field:
if hasattr(instance, field):
value = getattr(instance, field, None)
key_parts.append(value if value is not None else "")
else:
# 字段缺失时以空字符串占位,保持复合键位数一致
key_parts.append("")
# 如果没有有效参数返回None
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif hasattr(instance, key_field):
value = getattr(instance, key_field, None)
return str(value) if value is not None else None
return None
@classmethod
def get_semaphore(cls, lock_type: DbLockType):
enable_lock = getattr(cls, "enable_lock", None)
if not enable_lock or lock_type not in enable_lock:
return None
if cls.__name__ not in cls.sem_data:
cls.sem_data[cls.__name__] = {}
if lock_type not in cls.sem_data[cls.__name__]:
cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
return cls.sem_data[cls.__name__][lock_type]
@classmethod
def _require_lock(cls, lock_type: DbLockType) -> bool:
"""检查是否需要真正加锁"""
task_id = id(asyncio.current_task())
return cls._current_locks.get(task_id) != lock_type
@classmethod
@contextlib.asynccontextmanager
async def _lock_context(cls, lock_type: DbLockType):
"""带重入检查的锁上下文"""
task_id = id(asyncio.current_task())
need_lock = cls._require_lock(lock_type)
if need_lock and (sem := cls.get_semaphore(lock_type)):
cls._current_locks[task_id] = lock_type
async with sem:
yield
cls._current_locks.pop(task_id, None)
else:
yield
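# Editor's note (not part of this commit). The task-id bookkeeping above makes
# each lock re-entrant per asyncio task, so nested calls with the same lock
# type do not deadlock:
#
#   async with cls._lock_context(DbLockType.CREATE):
#       async with cls._lock_context(DbLockType.CREATE):
#           ...  # same task, same type: _require_lock() is False, no re-acquire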
@classmethod
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self:
"""创建数据使用CREATE锁"""
async with cls._lock_context(DbLockType.CREATE):
# 直接调用父类的_create方法避免触发save的锁
result = await super().create(using_db=using_db, **kwargs)
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
return result
@classmethod
async def get_or_create(
cls,
defaults: dict | None = None,
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
"""获取或创建数据(无锁版本,依赖数据库约束)"""
result = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
return result
@classmethod
async def update_or_create(
cls,
defaults: dict | None = None,
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
"""更新或创建数据使用UPSERT锁"""
async with cls._lock_context(DbLockType.UPSERT):
try:
# 先尝试更新(带行锁)
async with in_transaction():
if obj := await cls.filter(**kwargs).select_for_update().first():
await obj.update_from_dict(defaults or {})
await obj.save()
result = (obj, False)
else:
# 创建时不重复加锁
result = await cls.create(**kwargs, **(defaults or {})), True
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(
cache_type, cls.get_cache_key(result[0])
)
return result
except IntegrityError:
# 处理极端情况下的唯一约束冲突
obj = await cls.get(**kwargs)
return obj, False
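# Editor's note (not part of this commit). The UPSERT path above is
# "SELECT ... FOR UPDATE inside a transaction, otherwise create", with the
# unique constraint as the final arbiter: an IntegrityError from a concurrent
# insert is resolved by simply re-reading the row. A caller needs nothing
# special (model/field names below are illustrative):
#
#   obj, created = await SignUser.update_or_create(
#       user_id=uid, defaults={"impression": 10}
#   )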
async def save(
self,
using_db: BaseDBAsyncClient | None = None,
update_fields: Iterable[str] | None = None,
force_create: bool = False,
force_update: bool = False,
):
"""保存数据(根据操作类型自动选择锁)"""
lock_type = (
DbLockType.CREATE
if getattr(self, "id", None) is None
else DbLockType.UPDATE
)
async with self._lock_context(lock_type):
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
if cache_type := getattr(self, "cache_type", None):
await CacheRoot.invalidate_cache(
cache_type, self.__class__.get_cache_key(self)
)
async def delete(self, using_db: BaseDBAsyncClient | None = None):
cache_type = getattr(self, "cache_type", None)
key = self.__class__.get_cache_key(self) if cache_type else None
# 执行删除操作
await super().delete(using_db=using_db)
# 清除缓存
if cache_type:
await CacheRoot.invalidate_cache(cache_type, key)
@classmethod
async def safe_get_or_none(
cls,
*args,
using_db: BaseDBAsyncClient | None = None,
clean_duplicates: bool = True,
**kwargs: Any,
) -> Self | None:
"""安全地获取一条记录或None处理存在多个记录时返回最新的那个
注意默认会删除重复的记录仅保留最新的
参数:
*args: 查询参数
using_db: 数据库连接
clean_duplicates: 是否删除重复的记录仅保留最新的
**kwargs: 查询参数
返回:
Self | None: 查询结果如果不存在返回None
"""
try:
# 先尝试使用 get_or_none 获取单个记录
try:
return await with_db_timeout(
cls.get_or_none(*args, using_db=using_db, **kwargs),
operation=f"{cls.__name__}.get_or_none",
)
except MultipleObjectsReturned:
# 如果出现多个记录的情况,进行特殊处理
logger.warning(
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
LOG_COMMAND,
)
# 查询所有匹配记录
records = await with_db_timeout(
cls.filter(*args, **kwargs).all(),
operation=f"{cls.__name__}.filter.all",
)
if not records:
return None
# 如果需要清理重复记录
if clean_duplicates and hasattr(records[0], "id"):
# 按 id 排序
records = sorted(
records, key=lambda x: getattr(x, "id", 0), reverse=True
)
for record in records[1:]:
try:
await with_db_timeout(
record.delete(),
operation=f"{cls.__name__}.delete_duplicate",
)
logger.info(
f"{cls.__name__} 删除重复记录:"
f" id={getattr(record, 'id', None)}",
LOG_COMMAND,
)
except Exception as del_e:
logger.error(f"删除重复记录失败: {del_e}")
return records[0]
# 如果不需要清理或没有 id 字段,则返回最新的记录
if hasattr(cls, "id"):
return await with_db_timeout(
cls.filter(*args, **kwargs).order_by("-id").first(),
operation=f"{cls.__name__}.filter.order_by.first",
)
# 如果没有 id 字段,则返回第一个记录
return await with_db_timeout(
cls.filter(*args, **kwargs).first(),
operation=f"{cls.__name__}.filter.first",
)
except asyncio.TimeoutError:
logger.error(
f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
)
return None
except Exception as e:
# 其他类型的错误则继续抛出
logger.error(
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
)
raise
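# Editor's usage sketch (not part of this commit). safe_get_or_none() behaves
# like get_or_none() but survives duplicate rows and can prune them; a typical
# call site just treats it as a nullable lookup:
async def _demo_safe_lookup(model_cls: type[Model], **filters) -> Model | None:
    row = await model_cls.safe_get_or_none(clean_duplicates=True, **filters)
    if row is None:
        logger.debug(f"{model_cls.__name__}: no record for {filters}", LOG_COMMAND)
    return row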
class DbUrlIsNode(HookPriorityException):
"""
@ -49,6 +340,77 @@ class DbConnectError(Exception):
pass
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
def get_config(db_url: str) -> dict:
"""获取数据库配置"""
parsed = urlparse(db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path[1:] or ":memory:",
},
**SQLITE_CONFIG,
}
return config
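# Editor's sketch (not part of this commit). get_config() expands DB_URL into a
# full Tortoise config: postgres/mysql URLs become explicit credentials plus the
# pool settings above, sqlite URLs get WAL journal mode, and any other scheme
# falls back to the raw connection string. This mirrors the call in init() below:
def _demo_build_config() -> dict:
    cfg = get_config(BotConfig.db_url)
    # cfg["connections"]["default"] is either the raw URL string (fallback) or an
    # engine/credentials dict for postgres, mysql and sqlite
    return cfg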
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
@ -64,9 +426,7 @@ async def init():
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
db_url=BotConfig.db_url,
modules={"models": MODELS},
timezone="Asia/Shanghai",
config=get_config(BotConfig.db_url),
)
if SCRIPT_METHOD:
db = Tortoise.get_connection("default")
@ -85,13 +445,17 @@ async def init():
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await db.execute_query_dict(sql)
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e

View File

@ -469,7 +469,7 @@ class Notebook:
template_name="main.html",
templates={"elements": self._data},
pages={
"viewport": {"width": 700, "height": 1000},
"viewport": {"width": 700, "height": 10},
"base_url": f"file://{TEMPLATE_PATH}",
},
wait=2,

View File

@ -53,9 +53,7 @@ class CommonUtils:
if await GroupConsole.is_block_task(group_id, module):
"""群组是否禁用被动"""
return True
if g := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
):
if g := await GroupConsole.get_group(group_id=group_id):
"""群组权限是否小于0"""
if g.level < 0:
return True

View File

@ -44,6 +44,44 @@ class EventLogType(StrEnum):
"""主动退群"""
class CacheType(StrEnum):
"""
缓存类型
"""
PLUGINS = "GLOBAL_ALL_PLUGINS"
"""全局全部插件"""
GROUPS = "GLOBAL_ALL_GROUPS"
"""全局全部群组"""
USERS = "GLOBAL_ALL_USERS"
"""全部用户"""
BAN = "GLOBAL_ALL_BAN"
"""全局ban列表"""
BOT = "GLOBAL_BOT"
"""全局bot信息"""
LEVEL = "GLOBAL_USER_LEVEL"
"""用户权限"""
LIMIT = "GLOBAL_LIMIT"
"""插件限制"""
class DbLockType(StrEnum):
"""
锁类型
"""
CREATE = "CREATE"
"""创建"""
DELETE = "DELETE"
"""删除"""
UPDATE = "UPDATE"
"""更新"""
QUERY = "QUERY"
"""查询"""
UPSERT = "UPSERT"
"""创建或更新"""
class GoldHandle(StrEnum):
"""
金币处理

View File

@ -49,6 +49,9 @@ async def _():
try:
for priority in priority_list:
for func in priority_data[priority]:
logger.debug(
f"执行优先级 [{priority}] on_startup 方法: {func.__module__}"
)
if is_coroutine_callable(func):
await func()
else:

View File

@ -1,11 +1,13 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import date, datetime
import os
from pathlib import Path
import time
from typing import Any
from typing import Any, ClassVar
import httpx
from nonebot_plugin_uninfo import Uninfo
import pypinyin
import pytz
@ -13,43 +15,53 @@ from zhenxun.configs.config import Config
from zhenxun.services.log import logger
@dataclass
class EntityIDs:
user_id: str
"""用户id"""
group_id: str | None
"""群组id"""
channel_id: str | None
"""频道id"""
class ResourceDirManager:
"""
临时文件管理器
"""
temp_path = [] # noqa: RUF012
temp_path: ClassVar[set[Path]] = set()
@classmethod
def __tree_append(cls, path: Path):
"""递归添加文件夹
参数:
path: 文件夹路径
"""
def __tree_append(cls, path: Path, deep: int = 1, current: int = 0):
"""递归添加文件夹"""
if current >= deep and deep != -1:
return
path = path.resolve() # 标准化路径
for f in os.listdir(path):
file = path / f
file = (path / f).resolve() # 标准化子路径
if file.is_dir():
if file not in cls.temp_path:
cls.temp_path.append(file)
logger.debug(f"添加临时文件夹: {path}")
cls.__tree_append(file)
cls.temp_path.add(file)
logger.debug(f"添加临时文件夹: {file}")
cls.__tree_append(file, deep, current + 1)
@classmethod
def add_temp_dir(cls, path: str | Path, tree: bool = False):
def add_temp_dir(cls, path: str | Path, tree: bool = False, deep: int = 1):
"""添加临时清理文件夹,这些文件夹会被自动清理
参数:
path: 文件夹路径
tree: 是否递归添加文件夹
deep: 深度, -1 为无限深度
"""
if isinstance(path, str):
path = Path(path)
if path not in cls.temp_path:
cls.temp_path.append(path)
cls.temp_path.add(path)
logger.debug(f"添加临时文件夹: {path}")
if tree:
cls.__tree_append(path)
cls.__tree_append(path, deep)
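# Editor's usage sketch (not part of this commit). The reworked manager stores
# resolved Paths in a set and only recurses as deep as requested; the directories
# below are placeholders:
#
#   ResourceDirManager.add_temp_dir("data/temp", tree=True, deep=2)
#   ResourceDirManager.add_temp_dir("data/tmp_images")  # single dir, no recursion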
class CountLimiter:
@ -230,6 +242,27 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
return False
def get_entity_ids(session: Uninfo) -> EntityIDs:
"""获取用户id群组id频道id
参数:
session: Uninfo
返回:
EntityIDs: 用户id群组id频道id
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)
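# Editor's sketch (not part of this commit). get_entity_ids() flattens a Uninfo
# session into plain ids, including the channel/parent-group split:
def _demo_entity_repr(session: Uninfo) -> str:
    entity = get_entity_ids(session)
    return f"user={entity.user_id} group={entity.group_id} channel={entity.channel_id}"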
def is_number(text: str) -> bool:
"""是否为数字