mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
✨ 引入缓存机制 (#1889)
* 添加全局cache * ✨ 构建缓存,hook使用缓存 * ✨ 新增数据库Model方法监控 * ✨ 数据库添加semaphore锁 * 🩹 优化webapi返回数据 * ✨ 添加增量缓存与缓存过期 * 🎨 优化检测代码结构 * ⚡ 优化hook权限检测性能 * 🐛 添加新异常判断跳过权限检测 * ✨ 添加插件limit缓存 * 🎨 代码格式优化 * 🐛 修复代码导入 * 🐛 修复刷新时检查 * 👽 Rename exception for missing database URL in initialization * ♿ Update default database URL to SQLite in configuration * 🔧 Update tortoise-orm and aiocache dependencies restrictions; add optional redis and asyncpg support * 🐛 修复ban检测 * 🐛 修复所有插件关闭时缓存更新 * 🐛 尝试迁移至aiocache * 🐛 完善aiocache缓存 * ⚡ 代码性能优化 * 🐛 移除获取封禁缓存时的日志记录 * 🐛 修复缓存类型声明,优化封禁用户处理逻辑 * 🐛 优化LevelUser权限更新逻辑及数据库迁移 * ✨ cache支持redis连接 * 🚨 auto fix by pre-commit hooks * ⚡ :增强获取群组的安全性和准确性。同时,优化了缓存管理中的相关逻辑,确保缓存操作的一致性。 * ✨ feat(auth_limit): 将插件初始化逻辑的启动装饰器更改为优先级管理器 * 🔧 修复日志记录级别 * 🔧 更新数据库连接字符串 * 🔧 更新数据库连接字符串为内存数据库,并优化权限检查逻辑 * ✨ feat(cache): 增加缓存功能配置项,并新增数据访问层以支持缓存逻辑 * ♻️ 重构cache * ✨ feat(cache): 增强缓存管理,新增缓存字典和缓存列表功能,支持过期时间管理 * 🔧 修复Notebook类中的viewport高度设置,将其从1000调整为10 * ✨ 更新插件管理逻辑,替换缓存服务为CacheRoot并优化缓存失效处理 * ✨ 更新RegisterConfig类中的type字段 * ✨ 修复清理重复记录逻辑,确保检查记录的id属性有效性 * ⚡ 超级无敌大优化,解决延迟与卡死问题 * ✨ 更新封禁功能,增加封禁时长参数和描述,优化插件信息返回结构 * ✨ 更新zhenxun_help.py中的viewport高度,将其从453调整为10,以优化页面显示效果 * ✨ 优化插件分类逻辑,增加插件ID排序,并更新插件信息返回结构 --------- Co-authored-by: BalconyJH <balconyjh@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6283c3d13d
commit
8649aaaa54
14
.env.dev
14
.env.dev
@ -27,6 +27,18 @@ QBOT_ID_DATA = '{
|
||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||
DB_URL = ""
|
||||
|
||||
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
|
||||
CACHE_MODE = NONE
|
||||
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||
# REDIS地址
|
||||
# REDIS_HOST = "127.0.0.1"
|
||||
# REDIS端口
|
||||
# REDIS_PORT = 6379
|
||||
# REDIS密码
|
||||
# REDIS_PASSWORD = ""
|
||||
# REDIS过期时间
|
||||
# REDIS_EXPIRE = 600
|
||||
|
||||
# 系统代理
|
||||
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
||||
|
||||
@ -40,7 +52,7 @@ PLATFORM_SUPERUSERS = '
|
||||
DRIVER=~fastapi+~httpx+~websockets
|
||||
|
||||
|
||||
# LOG_LEVEL=DEBUG
|
||||
# LOG_LEVEL = DEBUG
|
||||
# 服务器和端口
|
||||
HOST = 127.0.0.1
|
||||
PORT = 8080
|
||||
|
||||
951
poetry.lock
generated
951
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -16,7 +16,7 @@ python = "^3.10"
|
||||
playwright = "^1.41.1"
|
||||
nonebot-adapter-onebot = "^2.3.1"
|
||||
nonebot-plugin-apscheduler = "^0.5"
|
||||
tortoise-orm = { extras = ["asyncpg"], version = "^0.20.0" }
|
||||
tortoise-orm = "^0.20.0"
|
||||
cattrs = "^23.2.3"
|
||||
ruamel-yaml = "^0.18.5"
|
||||
strenum = "^0.4.15"
|
||||
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
|
||||
bilireq = "0.2.3post0"
|
||||
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
||||
python-multipart = "^0.0.9"
|
||||
aiocache = "^0.12.2"
|
||||
aiocache = {extras = ["redis"], version = "^0.12.3"}
|
||||
py-cpuinfo = "^9.0.0"
|
||||
nonebot-plugin-alconna = "^0.54.0"
|
||||
tenacity = "^9.0.0"
|
||||
@ -47,6 +47,9 @@ nonebot-plugin-uninfo = ">0.4.1"
|
||||
nonebot-plugin-waiter = "^0.8.1"
|
||||
multidict = ">=6.0.0,!=6.3.2"
|
||||
|
||||
redis = { version = ">=5", optional = true }
|
||||
asyncpg = { version = ">=0.20.0", optional = true }
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
nonebug = "^0.4"
|
||||
pytest-cov = "^5.0.0"
|
||||
@ -57,6 +60,9 @@ respx = "^0.21.1"
|
||||
ruff = "^0.8.0"
|
||||
pre-commit = "^4.0.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
redis = ["redis"]
|
||||
postgresql = ["asyncpg"]
|
||||
|
||||
[tool.nonebot]
|
||||
plugins = [
|
||||
|
||||
@ -87,13 +87,17 @@ __plugin_meta__ = PluginMetadata(
|
||||
smart_tools=[
|
||||
AICallableTag(
|
||||
name="call_ban",
|
||||
description="某人多次(至少三次)辱骂你,调用此方法进行封禁",
|
||||
description="如果你讨厌某个人(好感度过低并让你感到困扰,或者多次辱骂你),调用此方法进行封禁,调用该方法后要告知用户被封禁和原因",
|
||||
parameters=AICallableParam(
|
||||
type="object",
|
||||
properties={
|
||||
"user_id": AICallableProperties(
|
||||
type="string", description="用户的id"
|
||||
),
|
||||
"duration": AICallableProperties(
|
||||
type="integer",
|
||||
description="封禁时长(选择的值只能是1-360),单位为分钟,如果频繁触发,按情况增加",
|
||||
),
|
||||
},
|
||||
required=["user_id"],
|
||||
),
|
||||
|
||||
@ -9,14 +9,14 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||
|
||||
|
||||
async def call_ban(user_id: str):
|
||||
async def call_ban(user_id: str, duration: int = 1):
|
||||
"""调用ban
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
"""
|
||||
await BanConsole.ban(user_id, None, 9, 60 * 12)
|
||||
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
|
||||
await BanConsole.ban(user_id, None, 9, duration * 60)
|
||||
logger.info("被讨厌了,已将用户加入黑名单...", "ban", session=user_id)
|
||||
|
||||
|
||||
class BanManage:
|
||||
@ -114,7 +114,7 @@ class BanManage:
|
||||
if not is_superuser and user_id and session.id1:
|
||||
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
||||
if idx:
|
||||
ban_data = await BanConsole.get_or_none(id=idx)
|
||||
ban_data = await BanConsole.get_ban(id=idx)
|
||||
if not ban_data:
|
||||
return False, "该用户/群组不在黑名单中捏..."
|
||||
if ban_data.ban_level > user_level:
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
from zhenxun.utils.exception import GroupInfoNotFound
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
|
||||
@ -116,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
|
||||
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
||||
group = None
|
||||
if group_id:
|
||||
group = await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
if not group:
|
||||
raise GroupInfoNotFound()
|
||||
else:
|
||||
@ -200,26 +201,26 @@ class PluginManager:
|
||||
)
|
||||
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
||||
if group_id:
|
||||
if group := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
module_list = await PluginInfo.filter(
|
||||
plugin_type=PluginType.NORMAL
|
||||
).values_list("module", flat=True)
|
||||
if group := await GroupConsole.get_group(group_id=group_id):
|
||||
module_list = cast(
|
||||
list[str],
|
||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
|
||||
"module", flat=True
|
||||
),
|
||||
)
|
||||
if status:
|
||||
for module in module_list:
|
||||
group.block_plugin = group.block_plugin.replace(
|
||||
f"<{module},", ""
|
||||
)
|
||||
# 开启所有功能 - 清空禁用列表
|
||||
group.block_plugin = ""
|
||||
else:
|
||||
module_list = [f"<{module}" for module in module_list]
|
||||
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
||||
# 关闭所有功能 - 将模块列表转换为禁用格式
|
||||
group.block_plugin = CommonUtils.convert_module_format(module_list)
|
||||
await group.save(update_fields=["block_plugin"])
|
||||
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
||||
return "获取群组失败..."
|
||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
||||
status=status, block_type=None if status else BlockType.ALL
|
||||
)
|
||||
await CacheRoot.invalidate_cache(CacheType.PLUGINS)
|
||||
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
|
||||
|
||||
@classmethod
|
||||
@ -232,9 +233,7 @@ class PluginManager:
|
||||
返回:
|
||||
bool: 是否醒来
|
||||
"""
|
||||
if c := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
if c := await GroupConsole.get_group(group_id=group_id):
|
||||
return c.status
|
||||
return False
|
||||
|
||||
@ -245,9 +244,11 @@ class PluginManager:
|
||||
参数:
|
||||
group_id: 群组id
|
||||
"""
|
||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
||||
status=False
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group.status = False
|
||||
await group.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def wake(cls, group_id: str):
|
||||
@ -256,9 +257,11 @@ class PluginManager:
|
||||
参数:
|
||||
group_id: 群组id
|
||||
"""
|
||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
||||
status=True
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group.status = True
|
||||
await group.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def block(cls, module: str):
|
||||
@ -267,7 +270,9 @@ class PluginManager:
|
||||
参数:
|
||||
module: 模块名
|
||||
"""
|
||||
await PluginInfo.filter(module=module).update(status=False)
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
plugin.status = False
|
||||
await plugin.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def unblock(cls, module: str):
|
||||
@ -276,7 +281,9 @@ class PluginManager:
|
||||
参数:
|
||||
module: 模块名
|
||||
"""
|
||||
await PluginInfo.filter(module=module).update(status=True)
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
plugin.status = True
|
||||
await plugin.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:
|
||||
@ -437,17 +444,18 @@ class PluginManager:
|
||||
"""
|
||||
status_str = "关闭" if status else "开启"
|
||||
if is_all:
|
||||
modules = await TaskInfo.annotate().values_list("module", flat=True)
|
||||
if modules:
|
||||
module_list = cast(
|
||||
list[str], await TaskInfo.annotate().values_list("module", flat=True)
|
||||
)
|
||||
if module_list:
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
modules = [f"<{module}" for module in modules]
|
||||
if status:
|
||||
group.block_task = ",".join(modules) + "," # type: ignore
|
||||
group.block_task = CommonUtils.convert_module_format(module_list)
|
||||
else:
|
||||
for module in modules:
|
||||
group.block_task = group.block_task.replace(f"{module},", "")
|
||||
# 开启所有模块 - 清空禁用列表
|
||||
group.block_task = ""
|
||||
await group.save(update_fields=["block_task"])
|
||||
return f"已成功{status_str}全部被动技能!"
|
||||
elif task := await TaskInfo.get_or_none(name=task_name):
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
from nonebot import on_message
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.models.chat_history import ChatHistory
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="消息存储",
|
||||
@ -37,18 +39,34 @@ def rule(message: UniMsg) -> bool:
|
||||
|
||||
chat_history = on_message(rule=rule, priority=1, block=False)
|
||||
|
||||
TEMP_LIST = []
|
||||
|
||||
|
||||
@chat_history.handle()
|
||||
async def handle_message(message: UniMsg, session: EventSession):
|
||||
"""处理消息存储"""
|
||||
try:
|
||||
await ChatHistory.create(
|
||||
user_id=session.id1,
|
||||
group_id=session.id2,
|
||||
async def _(message: UniMsg, session: Uninfo):
|
||||
entity = get_entity_ids(session)
|
||||
TEMP_LIST.append(
|
||||
ChatHistory(
|
||||
user_id=entity.user_id,
|
||||
group_id=entity.group_id,
|
||||
text=str(message),
|
||||
plain_text=message.extract_plain_text(),
|
||||
bot_id=session.bot_id,
|
||||
bot_id=session.self_id,
|
||||
platform=session.platform,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@scheduler.scheduled_job(
|
||||
"interval",
|
||||
minutes=1,
|
||||
)
|
||||
async def _():
|
||||
try:
|
||||
message_list = TEMP_LIST.copy()
|
||||
TEMP_LIST.clear()
|
||||
if message_list:
|
||||
await ChatHistory.bulk_create(message_list)
|
||||
logger.debug(f"批量添加聊天记录 {len(message_list)} 条", "定时任务")
|
||||
except Exception as e:
|
||||
logger.warning("存储聊天记录失败", "chat_history", e=e)
|
||||
|
||||
@ -45,11 +45,13 @@ async def classify_plugin(
|
||||
"""
|
||||
sort_data = await sort_type()
|
||||
classify: dict[str, list] = {}
|
||||
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
|
||||
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
||||
for menu, value in sort_data.items():
|
||||
for plugin in value:
|
||||
if not classify.get(menu):
|
||||
classify[menu] = []
|
||||
classify[menu].append(handle(bot, plugin, group, is_detail))
|
||||
for value in classify.values():
|
||||
value.sort(key=lambda x: x.id)
|
||||
return classify
|
||||
|
||||
@ -21,6 +21,8 @@ class Item(BaseModel):
|
||||
"""插件名称"""
|
||||
sta: int
|
||||
"""插件状态"""
|
||||
id: int
|
||||
"""插件id"""
|
||||
|
||||
|
||||
class PluginList(BaseModel):
|
||||
@ -80,10 +82,9 @@ def __handle_item(
|
||||
sta = 2
|
||||
if f"{plugin.module}," in group.block_plugin:
|
||||
sta = 1
|
||||
if bot:
|
||||
if f"{plugin.module}," in bot.block_plugins:
|
||||
sta = 2
|
||||
return Item(plugin_name=plugin.name, sta=sta)
|
||||
if bot and f"{plugin.module}," in bot.block_plugins:
|
||||
sta = 2
|
||||
return Item(plugin_name=plugin.name, sta=sta, id=plugin.id)
|
||||
|
||||
|
||||
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
@ -142,7 +143,7 @@ async def build_html_image(
|
||||
template_name="zhenxun_menu.html",
|
||||
templates={"plugin_list": plugin_list},
|
||||
pages={
|
||||
"viewport": {"width": 1903, "height": 975},
|
||||
"viewport": {"width": 1903, "height": 10},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
|
||||
@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
||||
color="black" if idx % 2 else "white",
|
||||
)
|
||||
curr_h = 10
|
||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
||||
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||
for _, plugin in enumerate(plugin_list):
|
||||
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
||||
if group and f"{plugin.module}," in group.block_plugin:
|
||||
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
||||
width, height = 10, 10
|
||||
for s in [
|
||||
"目前支持的功能列表:",
|
||||
"可以通过 ‘帮助 [功能名称或功能Id]’ 来获取对应功能的使用方法",
|
||||
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
|
||||
]:
|
||||
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
||||
await result.paste(text, (width, height))
|
||||
|
||||
@ -20,6 +20,12 @@ class Item(BaseModel):
|
||||
"""插件名称"""
|
||||
commands: list[str]
|
||||
"""插件命令"""
|
||||
id: str
|
||||
"""插件id"""
|
||||
status: bool
|
||||
"""插件状态"""
|
||||
has_superuser_help: bool
|
||||
"""插件是否拥有超级用户帮助"""
|
||||
|
||||
|
||||
def __handle_item(
|
||||
@ -39,23 +45,36 @@ def __handle_item(
|
||||
返回:
|
||||
Item: Item
|
||||
"""
|
||||
status = True
|
||||
has_superuser_help = False
|
||||
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||
if nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
||||
if extra_data.superuser_help:
|
||||
has_superuser_help = True
|
||||
if not plugin.status:
|
||||
if plugin.block_type == BlockType.ALL:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif group and plugin.block_type == BlockType.GROUP:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif not group and plugin.block_type == BlockType.PRIVATE:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif group and f"{plugin.module}," in group.block_plugin:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
elif bot and f"{plugin.module}," in bot.block_plugins:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
status = False
|
||||
commands = []
|
||||
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||
if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
||||
commands = [cmd.command for cmd in extra_data.commands]
|
||||
return Item(plugin_name=f"{plugin.id}-{plugin.name}", commands=commands)
|
||||
return Item(
|
||||
plugin_name=plugin.name,
|
||||
commands=commands,
|
||||
id=str(plugin.id),
|
||||
status=status,
|
||||
has_superuser_help=has_superuser_help,
|
||||
)
|
||||
|
||||
|
||||
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
@ -78,68 +97,10 @@ def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
}
|
||||
for menu, value in classify.items()
|
||||
]
|
||||
plugin_list = build_line_data(plugin_list)
|
||||
plugin_list.insert(
|
||||
0,
|
||||
build_plugin_line(
|
||||
menu_key if menu_key not in ["normal", "功能"] else "主要功能",
|
||||
max_data,
|
||||
30,
|
||||
100,
|
||||
True,
|
||||
),
|
||||
)
|
||||
return plugin_list
|
||||
|
||||
|
||||
def build_plugin_line(
|
||||
name: str, items: list, left: int, width: int | None = None, is_max: bool = False
|
||||
) -> dict:
|
||||
"""构造插件行数据
|
||||
|
||||
参数:
|
||||
name: 菜单名称
|
||||
items: 插件名称列表
|
||||
left: 左边距
|
||||
width: 总插件长度.
|
||||
is_max: 是否为最大长度的插件菜单
|
||||
|
||||
返回:
|
||||
dict: 插件数据
|
||||
"""
|
||||
_plugins = []
|
||||
width = width or 50
|
||||
if len(items) // 2 > 6 or is_max:
|
||||
width = 100
|
||||
plugin_list1 = []
|
||||
plugin_list2 = []
|
||||
for i in range(len(items)):
|
||||
if i % 2:
|
||||
plugin_list1.append(items[i])
|
||||
else:
|
||||
plugin_list2.append(items[i])
|
||||
_plugins = [(30, 50, plugin_list1), (0, 50, plugin_list2)]
|
||||
else:
|
||||
_plugins = [(left, 100, items)]
|
||||
return {"name": name, "items": _plugins, "width": width}
|
||||
|
||||
|
||||
def build_line_data(plugin_list: list[dict]) -> list[dict]:
|
||||
"""构造插件数据
|
||||
|
||||
参数:
|
||||
plugin_list: 插件列表
|
||||
|
||||
返回:
|
||||
list[dict]: 插件数据
|
||||
"""
|
||||
left = 30
|
||||
data = []
|
||||
plugin_list.insert(0, {"name": menu_key, "items": max_data})
|
||||
for plugin in plugin_list:
|
||||
data.append(build_plugin_line(plugin["name"], plugin["items"], left))
|
||||
if len(plugin["items"]) // 2 <= 6:
|
||||
left = 15 if left == 30 else 30
|
||||
return data
|
||||
plugin["items"].sort(key=lambda x: x.id)
|
||||
return plugin_list
|
||||
|
||||
|
||||
async def build_zhenxun_image(
|
||||
@ -160,6 +121,7 @@ async def build_zhenxun_image(
|
||||
width = int(637 * 1.5) if is_detail else 637
|
||||
title_font = int(53 * 1.5) if is_detail else 53
|
||||
tip_font = int(19 * 1.5) if is_detail else 19
|
||||
plugin_count = sum(len(plugin["items"]) for plugin in plugin_list)
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "ss_menu").absolute()),
|
||||
template_name="main.html",
|
||||
@ -170,10 +132,11 @@ async def build_zhenxun_image(
|
||||
"width": width,
|
||||
"font_size": (title_font, tip_font),
|
||||
"is_detail": is_detail,
|
||||
"plugin_count": plugin_count,
|
||||
}
|
||||
},
|
||||
pages={
|
||||
"viewport": {"width": width, "height": 453},
|
||||
"viewport": {"width": width, "height": 10},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
|
||||
@ -1,597 +0,0 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At, UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from pydantic import BaseModel
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import (
|
||||
BlockType,
|
||||
GoldHandle,
|
||||
LimitWatchType,
|
||||
PluginLimitType,
|
||||
PluginType,
|
||||
)
|
||||
from zhenxun.utils.exception import InsufficientGold
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||
|
||||
base_config = Config.get("hook")
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManage:
|
||||
add_module: ClassVar[list] = []
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
|
||||
参数:
|
||||
limit: PluginLimit
|
||||
"""
|
||||
if limit.module not in cls.add_module:
|
||||
cls.add_module.append(limit.module)
|
||||
if limit.limit_type == PluginLimitType.BLOCK:
|
||||
cls.block_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=UserBlockLimiter()
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.CD:
|
||||
cls.cd_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.COUNT:
|
||||
cls.count_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unblock(
|
||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||
):
|
||||
"""解除插件block
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
"""
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
limit = limit_model.limit
|
||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
logger.debug(
|
||||
f"解除对象: {key_type} 的block限制",
|
||||
"AuthChecker",
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
limiter.set_false(key_type)
|
||||
|
||||
@classmethod
|
||||
async def check(
|
||||
cls,
|
||||
module: str,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: EventSession,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
cls,
|
||||
limit_model: Limit | None,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: EventSession,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
limit_model: Limit
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if not limit_model:
|
||||
return
|
||||
limit = limit_model.limit
|
||||
limiter = limit_model.limiter
|
||||
is_limit = (
|
||||
LimitWatchType.ALL
|
||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||
)
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
await MessageUtils.build_message(limit.result).send()
|
||||
logger.debug(
|
||||
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{limit.module} 正在限制中...")
|
||||
else:
|
||||
logger.debug(
|
||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||
"AuthChecker",
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(key_type)
|
||||
if isinstance(limiter, UserBlockLimiter):
|
||||
limiter.set_true(key_type)
|
||||
if isinstance(limiter, CountLimiter):
|
||||
limiter.increase(key_type)
|
||||
|
||||
|
||||
class IsSuperuserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthChecker:
|
||||
"""
|
||||
权限检查
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||
|
||||
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
|
||||
"""是否发送提示消息
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
返回:
|
||||
bool: 是否发送提示消息
|
||||
"""
|
||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
if plugin.ignore_prompt:
|
||||
return False
|
||||
return self._flmt_s.check(sid)
|
||||
|
||||
async def auth(
|
||||
self,
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
session: EventSession,
|
||||
message: UniMsg,
|
||||
):
|
||||
"""权限检查
|
||||
|
||||
参数:
|
||||
matcher: matcher
|
||||
bot: bot
|
||||
session: EventSession
|
||||
message: UniMsg
|
||||
"""
|
||||
is_ignore = False
|
||||
cost_gold = 0
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
||||
"""过滤除poke外的notice"""
|
||||
return
|
||||
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
|
||||
try:
|
||||
user = await UserConsole.get_user(user_id, session.platform)
|
||||
except IntegrityError as e:
|
||||
logger.debug(
|
||||
"重复创建用户,已跳过该次权限...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
return
|
||||
if plugin := await PluginInfo.get_or_none(module_path=module_path):
|
||||
if plugin.plugin_type == PluginType.HIDDEN:
|
||||
logger.debug(
|
||||
f"插件: {plugin.name}:{plugin.module} "
|
||||
"为HIDDEN,已跳过权限检查..."
|
||||
)
|
||||
return
|
||||
try:
|
||||
cost_gold = await self.auth_cost(user, plugin, session)
|
||||
if session.id1 in bot.config.superusers:
|
||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||
raise IsSuperuserException()
|
||||
if not plugin.limit_superuser:
|
||||
cost_gold = 0
|
||||
raise IsSuperuserException()
|
||||
await self.auth_bot(plugin, bot.self_id)
|
||||
await self.auth_group(plugin, session, message)
|
||||
await self.auth_admin(plugin, session)
|
||||
await self.auth_plugin(plugin, session, event)
|
||||
await self.auth_limit(plugin, session)
|
||||
except IsSuperuserException:
|
||||
logger.debug(
|
||||
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
||||
)
|
||||
except IgnoredException:
|
||||
is_ignore = True
|
||||
LimitManage.unblock(
|
||||
matcher.plugin.name, user_id, group_id, channel_id
|
||||
)
|
||||
except AssertionError as e:
|
||||
is_ignore = True
|
||||
logger.debug("消息无法发送", session=session, e=e)
|
||||
if cost_gold and user_id:
|
||||
"""花费金币"""
|
||||
try:
|
||||
await UserConsole.reduce_gold(
|
||||
user_id,
|
||||
cost_gold,
|
||||
GoldHandle.PLUGIN,
|
||||
matcher.plugin.name if matcher.plugin else "",
|
||||
session.platform,
|
||||
)
|
||||
except InsufficientGold:
|
||||
if u := await UserConsole.get_user(user_id):
|
||||
u.gold = 0
|
||||
await u.save(update_fields=["gold"])
|
||||
logger.debug(
|
||||
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
|
||||
)
|
||||
if is_ignore:
|
||||
raise IgnoredException("权限检测 ignore")
|
||||
|
||||
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
|
||||
"""机器人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
bot_id: bot_id
|
||||
"""
|
||||
if not await BotConsole.get_bot_status(bot_id):
|
||||
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
|
||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
||||
if await BotConsole.is_block_plugin(bot_id, plugin.module):
|
||||
logger.debug(
|
||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
||||
"AuthChecker",
|
||||
)
|
||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
||||
|
||||
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
|
||||
"""插件限制
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if plugin.module not in LimitManage.add_module:
|
||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||
status=True
|
||||
).all() # type: ignore
|
||||
for limit in limit_list:
|
||||
LimitManage.add_limit(limit)
|
||||
if user_id:
|
||||
await LimitManage.check(
|
||||
plugin.module, user_id, group_id, channel_id, session
|
||||
)
|
||||
|
||||
async def auth_plugin(
|
||||
self, plugin: PluginInfo, session: EventSession, event: Event
|
||||
):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if user_id := session.id1:
|
||||
if plugin.impression > 0:
|
||||
sign_user = await SignUser.get_user(user_id)
|
||||
if float(sign_user.impression) < plugin.impression:
|
||||
if self.is_send_limit_message(plugin, user_id):
|
||||
self._flmt_s.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
f"好感度不足哦,当前功能需要好感度: {plugin.impression},"
|
||||
"请继续签到提升好感度吧!"
|
||||
).send(reply_to=True)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 用户好感度不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("好感度不足...")
|
||||
if group_id:
|
||||
sid = group_id or user_id
|
||||
if await GroupConsole.is_superuser_block_plugin(
|
||||
group_id, plugin.module
|
||||
):
|
||||
"""超级用户群组插件状态"""
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message(
|
||||
"超级管理员禁用了该群此功能..."
|
||||
).send(reply_to=True)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("超级管理员禁用了该群此功能...")
|
||||
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
|
||||
"""群组插件状态"""
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message("该群未开启此功能...").send(
|
||||
reply_to=True
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该群未开启此功能...")
|
||||
if plugin.block_type == BlockType.GROUP:
|
||||
"""全局群组禁用"""
|
||||
try:
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_c.start_cd(group_id)
|
||||
await MessageUtils.build_message(
|
||||
"该功能在群组中已被禁用..."
|
||||
).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_plugin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该插件在群组中已被禁用...")
|
||||
else:
|
||||
sid = user_id
|
||||
if plugin.block_type == BlockType.PRIVATE:
|
||||
"""全局私聊禁用"""
|
||||
try:
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_c.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
"该功能在私聊中已被禁用..."
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该插件在私聊中已被禁用...")
|
||||
if not plugin.status and plugin.block_type == BlockType.ALL:
|
||||
"""全局状态"""
|
||||
if group_id and await GroupConsole.is_super_group(group_id):
|
||||
raise IsSuperuserException()
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message("全局未开启此功能...").send()
|
||||
raise IgnoredException("全局未开启此功能...")
|
||||
|
||||
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
|
||||
"""管理员命令 个人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
user_id = session.id1
|
||||
if user_id and plugin.admin_level:
|
||||
if group_id := session.id3 or session.id2:
|
||||
if not await LevelUser.check_level(
|
||||
user_id, group_id, plugin.admin_level
|
||||
):
|
||||
try:
|
||||
if self._flmt.check(user_id):
|
||||
self._flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"你的权限不足喔,"
|
||||
f"该功能需要的权限等级: {plugin.admin_level}",
|
||||
]
|
||||
).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("管理员权限不足...")
|
||||
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
|
||||
try:
|
||||
await MessageUtils.build_message(
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("权限不足")
|
||||
|
||||
async def auth_group(
|
||||
self, plugin: PluginInfo, session: EventSession, message: UniMsg
|
||||
):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
message: UniMsg
|
||||
"""
|
||||
if not (group_id := session.id3 or session.id2):
|
||||
return
|
||||
text = message.extract_plain_text()
|
||||
group = await GroupConsole.get_group(group_id)
|
||||
if not group:
|
||||
"""群不存在"""
|
||||
logger.debug(
|
||||
"群组信息不存在...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("群不存在")
|
||||
if group.level < 0:
|
||||
"""群权限小于0"""
|
||||
logger.debug(
|
||||
"群黑名单, 群权限-1...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("群黑名单")
|
||||
if not group.status:
|
||||
"""群休眠"""
|
||||
if text.strip() != "醒来":
|
||||
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
||||
raise IgnoredException("群休眠状态")
|
||||
if plugin.level > group.level:
|
||||
"""插件等级大于群等级"""
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 群等级限制.."
|
||||
f"该功能需要的群等级: {plugin.level}..",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
||||
|
||||
async def auth_cost(
|
||||
self, user: UserConsole, plugin: PluginInfo, session: EventSession
|
||||
) -> int:
|
||||
"""检测是否满足金币条件
|
||||
|
||||
参数:
|
||||
user: UserConsole
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
|
||||
返回:
|
||||
int: 需要消耗的金币
|
||||
"""
|
||||
if user.gold < plugin.cost_gold:
|
||||
"""插件消耗金币不足"""
|
||||
try:
|
||||
await MessageUtils.build_message(
|
||||
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 金币限制.."
|
||||
f"该功能需要{plugin.cost_gold}金币..",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||
return plugin.cost_gold
|
||||
|
||||
|
||||
checker = AuthChecker()
|
||||
99
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
99
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
|
||||
|
||||
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
||||
"""管理员命令 个人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not plugin.admin_level:
|
||||
return
|
||||
|
||||
try:
|
||||
entity = get_entity_ids(session)
|
||||
level_dao = DataAccess(LevelUser)
|
||||
|
||||
# 并行查询用户权限数据
|
||||
global_user: LevelUser | None = None
|
||||
group_users: LevelUser | None = None
|
||||
|
||||
# 查询全局权限
|
||||
global_user_task = level_dao.safe_get_or_none(
|
||||
user_id=session.user.id, group_id__isnull=True
|
||||
)
|
||||
|
||||
# 如果在群组中,查询群组权限
|
||||
group_users_task = None
|
||||
if entity.group_id:
|
||||
group_users_task = level_dao.safe_get_or_none(
|
||||
user_id=session.user.id, group_id=entity.group_id
|
||||
)
|
||||
|
||||
# 等待查询完成,添加超时控制
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(global_user_task, group_users_task or asyncio.sleep(0)),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
global_user = results[0]
|
||||
group_users = results[1] if group_users_task else None
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询用户权限超时: user_id={session.user.id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
user_level = global_user.user_level if global_user else 0
|
||||
if entity.group_id and group_users:
|
||||
user_level = max(user_level, group_users.user_level)
|
||||
|
||||
if user_level < plugin.admin_level:
|
||||
await send_message(
|
||||
session,
|
||||
[
|
||||
At(flag="user", target=session.user.id),
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||
],
|
||||
entity.user_id,
|
||||
)
|
||||
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||
)
|
||||
elif global_user:
|
||||
if global_user.user_level < plugin.admin_level:
|
||||
await send_message(
|
||||
session,
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||
)
|
||||
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_admin 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
303
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
303
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
@ -0,0 +1,303 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
from .utils import freq, send_message
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"BAN_RESULT",
|
||||
"才不会给你发消息.",
|
||||
help="对被ban用户发送的消息",
|
||||
)
|
||||
|
||||
|
||||
def calculate_ban_time(ban_record: BanConsole | None) -> int:
|
||||
"""根据ban记录计算剩余ban时间
|
||||
|
||||
参数:
|
||||
ban_record: BanConsole记录
|
||||
|
||||
返回:
|
||||
int: ban剩余时长,-1时为永久ban,0表示未被ban
|
||||
"""
|
||||
if not ban_record:
|
||||
return 0
|
||||
|
||||
if ban_record.duration == -1:
|
||||
return -1
|
||||
|
||||
_time = time.time() - (ban_record.ban_time + ban_record.duration)
|
||||
return 0 if _time > 0 else int(abs(_time))
|
||||
|
||||
|
||||
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||
"""检查用户或群组是否被ban
|
||||
|
||||
参数:
|
||||
user_id: 用户ID
|
||||
group_id: 群组ID
|
||||
|
||||
返回:
|
||||
int: ban的剩余时间,0表示未被ban
|
||||
"""
|
||||
if not user_id and not group_id:
|
||||
return 0
|
||||
|
||||
start_time = time.time()
|
||||
ban_dao = DataAccess(BanConsole)
|
||||
|
||||
# 分别获取用户在群组中的ban记录和全局ban记录
|
||||
group_user = None
|
||||
user = None
|
||||
|
||||
try:
|
||||
# 并行查询用户和群组的 ban 记录
|
||||
tasks = []
|
||||
if user_id and group_id:
|
||||
tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id))
|
||||
if user_id:
|
||||
tasks.append(
|
||||
ban_dao.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
)
|
||||
|
||||
# 等待所有查询完成,添加超时控制
|
||||
if tasks:
|
||||
try:
|
||||
ban_records = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
if len(tasks) == 2:
|
||||
group_user, user = ban_records
|
||||
elif user_id and group_id:
|
||||
group_user = ban_records[0]
|
||||
else:
|
||||
user = ban_records[0]
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
# 超时时返回0,避免阻塞
|
||||
return 0
|
||||
|
||||
# 检查记录并计算ban时间
|
||||
results = []
|
||||
if group_user:
|
||||
results.append(group_user)
|
||||
if user:
|
||||
results.append(user)
|
||||
|
||||
# 如果没有找到记录,返回0
|
||||
if not results:
|
||||
return 0
|
||||
|
||||
logger.debug(f"查询到的ban记录: {results}", LOGGER_COMMAND)
|
||||
# 检查所有记录,找出最严格的ban(时间最长的)
|
||||
max_ban_time: int = 0
|
||||
for result in results:
|
||||
if result.duration > 0 or result.duration == -1:
|
||||
# 直接计算ban时间,避免再次查询数据库
|
||||
ban_time = calculate_ban_time(result)
|
||||
if ban_time == -1 or ban_time > max_ban_time:
|
||||
max_ban_time = ban_time
|
||||
|
||||
return max_ban_time
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"is_ban 耗时: {elapsed:.3f}s",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
|
||||
def check_plugin_type(matcher: Matcher) -> bool:
|
||||
"""判断插件类型是否是隐藏插件
|
||||
|
||||
参数:
|
||||
matcher: Matcher
|
||||
|
||||
返回:
|
||||
bool: 是否为隐藏插件
|
||||
"""
|
||||
if plugin := matcher.plugin:
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def format_time(time_val: float) -> str:
|
||||
"""格式化时间
|
||||
|
||||
参数:
|
||||
time_val: ban时长
|
||||
|
||||
返回:
|
||||
str: 格式化时间文本
|
||||
"""
|
||||
if time_val == -1:
|
||||
return "∞"
|
||||
time_val = abs(int(time_val))
|
||||
if time_val < 60:
|
||||
time_str = f"{time_val!s} 秒"
|
||||
else:
|
||||
minute = int(time_val / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
return time_str
|
||||
|
||||
|
||||
async def group_handle(group_id: str) -> None:
|
||||
"""群组ban检查
|
||||
|
||||
参数:
|
||||
group_id: 群组id
|
||||
|
||||
异常:
|
||||
SkipPluginException: 群组处于黑名单
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
if await is_ban(None, group_id):
|
||||
raise SkipPluginException("群组处于黑名单中...")
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"group_handle 耗时: {elapsed:.3f}s",
|
||||
LOGGER_COMMAND,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
|
||||
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
||||
"""用户ban检查
|
||||
|
||||
参数:
|
||||
module: 插件模块名
|
||||
entity: 实体ID信息
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: 用户处于黑名单
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
time_val = await is_ban(entity.user_id, entity.group_id)
|
||||
if not time_val:
|
||||
return
|
||||
time_str = format_time(time_val)
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
try:
|
||||
db_plugin = await asyncio.wait_for(
|
||||
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
raise SkipPluginException("用户处于黑名单中...")
|
||||
|
||||
if (
|
||||
db_plugin
|
||||
and not db_plugin.ignore_prompt
|
||||
and time_val != -1
|
||||
and ban_result
|
||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||
):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
session,
|
||||
[
|
||||
At(flag="user", target=entity.user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
],
|
||||
entity.user_id,
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {entity.user_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException("用户处于黑名单中...")
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"user_handle 耗时: {elapsed:.3f}s",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
|
||||
|
||||
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
||||
"""权限检查 - ban 检查
|
||||
|
||||
参数:
|
||||
matcher: Matcher
|
||||
bot: Bot
|
||||
session: Uninfo
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
if not check_plugin_type(matcher):
|
||||
return
|
||||
if not matcher.plugin_name:
|
||||
return
|
||||
entity = get_entity_ids(session)
|
||||
if entity.user_id in bot.config.superusers:
|
||||
return
|
||||
if entity.group_id:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
group_handle(entity.group_id), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"群组ban检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
|
||||
if entity.user_id:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_handle(matcher.plugin_name, entity, session),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"用户ban检查超时: {entity.user_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_ban 总耗时: {elapsed:.3f}s, plugin={matcher.plugin_name}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
55
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
55
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
||||
"""bot层面的权限检查
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
bot_id: bot id
|
||||
|
||||
异常:
|
||||
SkipPluginException: 忽略插件
|
||||
SkipPluginException: 忽略插件
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 从数据库或缓存中获取 bot 信息
|
||||
bot_dao = DataAccess(BotConsole)
|
||||
|
||||
try:
|
||||
bot: BotConsole | None = await asyncio.wait_for(
|
||||
bot_dao.safe_get_or_none(bot_id=bot_id), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询Bot信息超时: bot_id={bot_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
if not bot or not bot.status:
|
||||
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
||||
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||
raise SkipPluginException(
|
||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_bot 耗时: {elapsed:.3f}s, "
|
||||
f"bot_id={bot_id}, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
41
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
41
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
@ -0,0 +1,41 @@
|
||||
import time
|
||||
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
|
||||
|
||||
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
||||
"""检测是否满足金币条件
|
||||
|
||||
参数:
|
||||
user: UserConsole
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
|
||||
返回:
|
||||
int: 需要消耗的金币
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if user.gold < plugin.cost_gold:
|
||||
"""插件消耗金币不足"""
|
||||
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
|
||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||
return plugin.cost_gold
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_cost 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
68
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
68
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.utils import EntityIDs
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
entity: EntityIDs
|
||||
message: UniMsg
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not entity.group_id:
|
||||
return
|
||||
|
||||
try:
|
||||
text = message.extract_plain_text()
|
||||
|
||||
# 从数据库或缓存中获取群组信息
|
||||
group_dao = DataAccess(GroupConsole)
|
||||
|
||||
try:
|
||||
group: GroupConsole | None = await asyncio.wait_for(
|
||||
group_dao.safe_get_or_none(
|
||||
group_id=entity.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
|
||||
# 超时时不阻塞,继续执行
|
||||
return
|
||||
|
||||
if not group:
|
||||
raise SkipPluginException("群组信息不存在...")
|
||||
if group.level < 0:
|
||||
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
|
||||
if text.strip() != SwitchEnum.ENABLE and not group.status:
|
||||
raise SkipPluginException("群组休眠状态...")
|
||||
if plugin.level > group.level:
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 群等级限制,"
|
||||
f"该功能需要的群等级: {plugin.level}..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
session=entity.user_id,
|
||||
group_id=entity.group_id,
|
||||
)
|
||||
318
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
318
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
@ -0,0 +1,318 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import (
|
||||
CountLimiter,
|
||||
FreqLimiter,
|
||||
UserBlockLimiter,
|
||||
get_entity_ids,
|
||||
)
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""初始化限制"""
|
||||
await LimitManager.init_limit()
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManager:
|
||||
add_module: ClassVar[list] = []
|
||||
last_update_time: ClassVar[float] = 0
|
||||
update_interval: ClassVar[float] = 6000 # 1小时更新一次
|
||||
is_updating: ClassVar[bool] = False # 防止并发更新
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
# 模块限制缓存,避免频繁查询数据库
|
||||
module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
|
||||
module_cache_ttl: ClassVar[float] = 60 # 模块缓存有效期(秒)
|
||||
|
||||
@classmethod
|
||||
async def init_limit(cls):
|
||||
"""初始化限制"""
|
||||
cls.last_update_time = time.time()
|
||||
try:
|
||||
await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("初始化限制超时", LOGGER_COMMAND)
|
||||
|
||||
@classmethod
|
||||
async def update_limits(cls):
|
||||
"""更新限制信息"""
|
||||
# 防止并发更新
|
||||
if cls.is_updating:
|
||||
return
|
||||
|
||||
cls.is_updating = True
|
||||
try:
|
||||
start_time = time.time()
|
||||
try:
|
||||
limit_list = await asyncio.wait_for(
|
||||
PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("查询限制信息超时", LOGGER_COMMAND)
|
||||
cls.is_updating = False
|
||||
return
|
||||
|
||||
# 清空旧数据
|
||||
cls.add_module = []
|
||||
cls.cd_limit = {}
|
||||
cls.block_limit = {}
|
||||
cls.count_limit = {}
|
||||
# 添加新数据
|
||||
for limit in limit_list:
|
||||
cls.add_limit(limit)
|
||||
|
||||
cls.last_update_time = time.time()
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的更新
|
||||
logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
|
||||
finally:
|
||||
cls.is_updating = False
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
|
||||
参数:
|
||||
limit: PluginLimit
|
||||
"""
|
||||
if limit.module not in cls.add_module:
|
||||
cls.add_module.append(limit.module)
|
||||
if limit.limit_type == PluginLimitType.BLOCK:
|
||||
cls.block_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=UserBlockLimiter()
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.CD:
|
||||
cls.cd_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.COUNT:
|
||||
cls.count_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unblock(
|
||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||
):
|
||||
"""解除插件block
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
"""
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
limit = limit_model.limit
|
||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
logger.debug(
|
||||
f"解除对象: {key_type} 的block限制",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
limiter.set_false(key_type)
|
||||
|
||||
@classmethod
|
||||
async def get_module_limits(cls, module: str) -> list[PluginLimit]:
|
||||
"""获取模块的限制信息,使用缓存减少数据库查询
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
|
||||
返回:
|
||||
list[PluginLimit]: 限制列表
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查缓存
|
||||
if module in cls.module_limit_cache:
|
||||
cache_time, limits = cls.module_limit_cache[module]
|
||||
if current_time - cache_time < cls.module_cache_ttl:
|
||||
return limits
|
||||
|
||||
# 缓存不存在或已过期,从数据库查询
|
||||
try:
|
||||
start_time = time.time()
|
||||
limits = await asyncio.wait_for(
|
||||
PluginLimit.filter(module=module, status=True).all(),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的查询
|
||||
logger.warning(
|
||||
f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
|
||||
# 更新缓存
|
||||
cls.module_limit_cache[module] = (current_time, limits)
|
||||
return limits
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
|
||||
# 超时时返回空列表,避免阻塞
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
async def check(
|
||||
cls,
|
||||
module: str,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 定期更新全局限制信息
|
||||
if (
|
||||
time.time() - cls.last_update_time > cls.update_interval
|
||||
and not cls.is_updating
|
||||
):
|
||||
# 使用异步任务更新,避免阻塞当前请求
|
||||
asyncio.create_task(cls.update_limits()) # noqa: RUF006
|
||||
|
||||
# 如果模块不在已加载列表中,只加载该模块的限制
|
||||
if module not in cls.add_module:
|
||||
limits = await cls.get_module_limits(module)
|
||||
for limit in limits:
|
||||
cls.add_limit(limit)
|
||||
|
||||
# 检查各种限制
|
||||
try:
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
cls,
|
||||
limit_model: Limit | None,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
limit_model: Limit
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if not limit_model:
|
||||
return
|
||||
limit = limit_model.limit
|
||||
limiter = limit_model.limiter
|
||||
is_limit = (
|
||||
LimitWatchType.ALL
|
||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||
)
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
MessageUtils.build_message(limit.result).send(),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{limit.module}({limit.limit_type}) 正在限制中..."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(key_type)
|
||||
if isinstance(limiter, UserBlockLimiter):
|
||||
limiter.set_true(key_type)
|
||||
if isinstance(limiter, CountLimiter):
|
||||
limiter.increase(key_type)
|
||||
|
||||
|
||||
async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||
"""插件限制
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
entity = get_entity_ids(session)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
LimitManager.check(
|
||||
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS * 2, # 给予更长的超时时间
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
|
||||
# 超时时不抛出异常,允许继续执行
|
||||
242
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
242
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
@ -0,0 +1,242 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import IsSuperuserException, SkipPluginException
|
||||
from .utils import freq, is_poke, send_message
|
||||
|
||||
|
||||
class GroupCheck:
|
||||
def __init__(
|
||||
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.plugin = plugin
|
||||
self.group_dao = DataAccess(GroupConsole)
|
||||
self.group_data = None
|
||||
|
||||
async def check(self):
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 只查询一次数据库,使用 DataAccess 的缓存机制
|
||||
try:
|
||||
self.group_data = await asyncio.wait_for(
|
||||
self.group_dao.safe_get_or_none(
|
||||
group_id=self.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||
return # 超时时不阻塞,继续执行
|
||||
|
||||
# 检查超级用户禁用
|
||||
if (
|
||||
self.group_data
|
||||
and CommonUtils.format(self.plugin.module)
|
||||
in self.group_data.superuser_block_plugin
|
||||
):
|
||||
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
self.session,
|
||||
"超级管理员禁用了该群此功能...",
|
||||
self.group_id,
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{self.plugin.name}({self.plugin.module})"
|
||||
f" 超级管理员禁用了该群此功能..."
|
||||
)
|
||||
|
||||
# 检查普通禁用
|
||||
if (
|
||||
self.group_data
|
||||
and CommonUtils.format(self.plugin.module)
|
||||
in self.group_data.block_plugin
|
||||
):
|
||||
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
self.session, "该群未开启此功能...", self.group_id
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{self.plugin.name}({self.plugin.module}) 未开启此功能..."
|
||||
)
|
||||
|
||||
# 检查全局禁用
|
||||
if self.plugin.block_type == BlockType.GROUP:
|
||||
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(
|
||||
self.session, "该功能在群组中已被禁用...", self.group_id
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{self.plugin.name}({self.plugin.module})该插件在群组中已被禁用..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"GroupCheck.check 耗时: {elapsed:.3f}s, 群组: {self.group_id}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
|
||||
|
||||
class PluginCheck:
|
||||
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.group_id = group_id
|
||||
self.group_dao = DataAccess(GroupConsole)
|
||||
self.group_data = None
|
||||
|
||||
async def check_user(self, plugin: PluginInfo):
|
||||
"""全局私聊禁用检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
if plugin.block_type == BlockType.PRIVATE:
|
||||
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(self.session, "该功能在私聊中已被禁用..."),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("发送消息超时", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
|
||||
)
|
||||
|
||||
async def check_global(self, plugin: PluginInfo):
|
||||
"""全局状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||
return
|
||||
"""全局状态"""
|
||||
if self.group_id:
|
||||
# 使用 DataAccess 的缓存机制
|
||||
try:
|
||||
self.group_data = await asyncio.wait_for(
|
||||
self.group_dao.safe_get_or_none(
|
||||
group_id=self.group_id, channel_id__isnull=True
|
||||
),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||
return # 超时时不阻塞,继续执行
|
||||
|
||||
if self.group_data and self.group_data.is_super:
|
||||
raise IsSuperuserException()
|
||||
|
||||
sid = self.group_id or self.session.user.id
|
||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_message(self.session, "全局未开启此功能...", sid),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"发送消息超时: {sid}", LOGGER_COMMAND)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 全局未开启此功能..."
|
||||
)
|
||||
finally:
|
||||
# 记录执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"PluginCheck.check_global 耗时: {elapsed:.3f}s", LOGGER_COMMAND
|
||||
)
|
||||
|
||||
|
||||
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
event: Event
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
entity = get_entity_ids(session)
|
||||
is_poke_event = is_poke(event)
|
||||
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
||||
|
||||
if entity.group_id:
|
||||
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("用户检查超时", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("全局检查超时", LOGGER_COMMAND)
|
||||
# 超时时不阻塞,继续执行
|
||||
finally:
|
||||
# 记录总执行时间
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||
logger.warning(
|
||||
f"auth_plugin 总耗时: {elapsed:.3f}s, 模块: {plugin.module}",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
@ -0,0 +1,35 @@
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
|
||||
from .exception import SkipPluginException
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"FILTER_BOT",
|
||||
True,
|
||||
help="过滤当前连接bot(防止bot互相调用)",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
|
||||
|
||||
def bot_filter(session: Uninfo):
|
||||
"""过滤bot调用bot
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: bot互相调用
|
||||
"""
|
||||
if not Config.get_config("hook", "FILTER_BOT"):
|
||||
return
|
||||
bot_ids = list(nonebot.get_bots().keys())
|
||||
if session.user.id == session.self_id:
|
||||
return
|
||||
if session.user.id in bot_ids:
|
||||
raise SkipPluginException(
|
||||
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
|
||||
)
|
||||
16
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
16
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
from strenum import StrEnum
|
||||
|
||||
LOGGER_COMMAND = "AuthChecker"
|
||||
|
||||
|
||||
class SwitchEnum(StrEnum):
|
||||
ENABLE = "醒来"
|
||||
DISABLE = "休息吧"
|
||||
|
||||
|
||||
WARNING_THRESHOLD = 0.5 # 警告阈值(秒)
|
||||
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
@ -0,0 +1,26 @@
|
||||
class IsSuperuserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SkipPluginException(Exception):
|
||||
def __init__(self, info: str, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.info = info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.info
|
||||
|
||||
|
||||
class PermissionExemption(Exception):
|
||||
def __init__(self, info: str, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.info = info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.info
|
||||
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
@ -0,0 +1,91 @@
|
||||
import contextlib
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import FreqLimiter
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
|
||||
base_config = Config.get("hook")
|
||||
|
||||
|
||||
def is_poke(event: Event) -> bool:
|
||||
"""判断是否为poke类型
|
||||
|
||||
参数:
|
||||
event: Event
|
||||
|
||||
返回:
|
||||
bool: 是否为poke类型
|
||||
"""
|
||||
with contextlib.suppress(ImportError):
|
||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||
|
||||
return isinstance(event, PokeNotifyEvent)
|
||||
return False
|
||||
|
||||
|
||||
async def send_message(
|
||||
session: Uninfo, message: list | str, check_tag: str | None = None
|
||||
):
|
||||
"""发送消息
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
message: 消息
|
||||
check_tag: cd flag
|
||||
"""
|
||||
try:
|
||||
if not check_tag:
|
||||
await MessageUtils.build_message(message).send(reply_to=True)
|
||||
elif freq._flmt.check(check_tag):
|
||||
freq._flmt.start_cd(check_tag)
|
||||
await MessageUtils.build_message(message).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"发送消息失败",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
|
||||
|
||||
class FreqUtils:
|
||||
def __init__(self):
|
||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||
|
||||
def is_send_limit_message(
|
||||
self, plugin: PluginInfo, sid: str, is_poke: bool
|
||||
) -> bool:
|
||||
"""是否发送提示消息
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
sid: 检测键
|
||||
is_poke: 是否是戳一戳
|
||||
|
||||
返回:
|
||||
bool: 是否发送提示消息
|
||||
"""
|
||||
if is_poke:
|
||||
return False
|
||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
||||
|
||||
|
||||
freq = FreqUtils()
|
||||
375
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
375
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
@ -0,0 +1,375 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import GoldHandle, PluginType
|
||||
from zhenxun.utils.exception import InsufficientGold
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .auth.auth_admin import auth_admin
|
||||
from .auth.auth_ban import auth_ban
|
||||
from .auth.auth_bot import auth_bot
|
||||
from .auth.auth_cost import auth_cost
|
||||
from .auth.auth_group import auth_group
|
||||
from .auth.auth_limit import LimitManager, auth_limit
|
||||
from .auth.auth_plugin import auth_plugin
|
||||
from .auth.bot_filter import bot_filter
|
||||
from .auth.config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .auth.exception import (
|
||||
IsSuperuserException,
|
||||
PermissionExemption,
|
||||
SkipPluginException,
|
||||
)
|
||||
|
||||
# 超时设置(秒)
|
||||
TIMEOUT_SECONDS = 5.0
|
||||
# 熔断计数器
|
||||
CIRCUIT_BREAKERS = {
|
||||
"auth_ban": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_bot": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_group": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_admin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_plugin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
"auth_limit": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||
}
|
||||
# 熔断重置时间(秒)
|
||||
CIRCUIT_RESET_TIME = 300 # 5分钟
|
||||
|
||||
|
||||
# 超时装饰器
|
||||
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
|
||||
"""带超时控制的协程执行
|
||||
|
||||
参数:
|
||||
coro: 要执行的协程
|
||||
timeout: 超时时间(秒)
|
||||
name: 操作名称,用于日志记录
|
||||
|
||||
返回:
|
||||
协程的返回值,或者在超时时抛出 TimeoutError
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(coro, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if name:
|
||||
logger.error(f"{name} 操作超时 (>{timeout}s)", LOGGER_COMMAND)
|
||||
# 更新熔断计数器
|
||||
if name in CIRCUIT_BREAKERS:
|
||||
CIRCUIT_BREAKERS[name]["failures"] += 1
|
||||
if (
|
||||
CIRCUIT_BREAKERS[name]["failures"]
|
||||
>= CIRCUIT_BREAKERS[name]["threshold"]
|
||||
and not CIRCUIT_BREAKERS[name]["active"]
|
||||
):
|
||||
CIRCUIT_BREAKERS[name]["active"] = True
|
||||
CIRCUIT_BREAKERS[name]["reset_time"] = (
|
||||
time.time() + CIRCUIT_RESET_TIME
|
||||
)
|
||||
logger.warning(
|
||||
f"{name} 熔断器已激活,将在 {CIRCUIT_RESET_TIME} 秒后重置",
|
||||
LOGGER_COMMAND,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# 检查熔断状态
|
||||
def check_circuit_breaker(name):
|
||||
"""检查熔断器状态
|
||||
|
||||
参数:
|
||||
name: 操作名称
|
||||
|
||||
返回:
|
||||
bool: 是否已熔断
|
||||
"""
|
||||
if name not in CIRCUIT_BREAKERS:
|
||||
return False
|
||||
|
||||
# 检查是否需要重置熔断器
|
||||
if (
|
||||
CIRCUIT_BREAKERS[name]["active"]
|
||||
and time.time() > CIRCUIT_BREAKERS[name]["reset_time"]
|
||||
):
|
||||
CIRCUIT_BREAKERS[name]["active"] = False
|
||||
CIRCUIT_BREAKERS[name]["failures"] = 0
|
||||
logger.info(f"{name} 熔断器已重置", LOGGER_COMMAND)
|
||||
|
||||
return CIRCUIT_BREAKERS[name]["active"]
|
||||
|
||||
|
||||
async def get_plugin_and_user(
|
||||
module: str, user_id: str
|
||||
) -> tuple[PluginInfo, UserConsole]:
|
||||
"""获取用户数据和插件信息
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
|
||||
异常:
|
||||
PermissionExemption: 插件数据不存在
|
||||
PermissionExemption: 插件类型为HIDDEN
|
||||
PermissionExemption: 重复创建用户
|
||||
PermissionExemption: 用户数据不存在
|
||||
|
||||
返回:
|
||||
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
||||
"""
|
||||
user_dao = DataAccess(UserConsole)
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
|
||||
# 并行查询插件和用户数据
|
||||
plugin_task = plugin_dao.safe_get_or_none(module=module)
|
||||
user_task = user_dao.safe_get_or_none(user_id=user_id)
|
||||
|
||||
try:
|
||||
plugin, user = await with_timeout(
|
||||
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# 如果并行查询超时,尝试串行查询
|
||||
logger.warning("并行查询超时,尝试串行查询", LOGGER_COMMAND)
|
||||
plugin = await with_timeout(
|
||||
plugin_dao.safe_get_or_none(module=module), name="get_plugin"
|
||||
)
|
||||
user = await with_timeout(
|
||||
user_dao.safe_get_or_none(user_id=user_id), name="get_user"
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
|
||||
if plugin.plugin_type == PluginType.HIDDEN:
|
||||
raise PermissionExemption(
|
||||
f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..."
|
||||
)
|
||||
user = None
|
||||
try:
|
||||
user = await user_dao.safe_get_or_none(user_id=user_id)
|
||||
except IntegrityError as e:
|
||||
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
|
||||
if not user:
|
||||
raise PermissionExemption("用户数据不存在,已跳过权限检查...")
|
||||
return plugin, user
|
||||
|
||||
|
||||
async def get_plugin_cost(
|
||||
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
|
||||
) -> int:
|
||||
"""获取插件费用
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
user: 用户数据
|
||||
plugin: 插件数据
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
IsSuperuserException: 超级用户
|
||||
IsSuperuserException: 超级用户
|
||||
|
||||
返回:
|
||||
int: 调用插件金币费用
|
||||
"""
|
||||
cost_gold = await with_timeout(auth_cost(user, plugin, session), name="auth_cost")
|
||||
if session.user.id in bot.config.superusers:
|
||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||
raise IsSuperuserException()
|
||||
if not plugin.limit_superuser:
|
||||
raise IsSuperuserException()
|
||||
return cost_gold
|
||||
|
||||
|
||||
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
|
||||
"""扣除用户金币
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
module: 插件模块名称
|
||||
cost_gold: 消耗金币
|
||||
session: Uninfo
|
||||
"""
|
||||
user_dao = DataAccess(UserConsole)
|
||||
try:
|
||||
await with_timeout(
|
||||
UserConsole.reduce_gold(
|
||||
user_id,
|
||||
cost_gold,
|
||||
GoldHandle.PLUGIN,
|
||||
module,
|
||||
PlatformUtils.get_platform(session),
|
||||
),
|
||||
name="reduce_gold",
|
||||
)
|
||||
except InsufficientGold:
|
||||
if u := await UserConsole.get_user(user_id):
|
||||
u.gold = 0
|
||||
await u.save(update_fields=["gold"])
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"扣除金币超时,用户: {user_id}, 金币: {cost_gold}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# 清除缓存,使下次查询时从数据库获取最新数据
|
||||
await user_dao.clear_cache(user_id=user_id)
|
||||
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
||||
|
||||
|
||||
# 辅助函数,用于记录每个 hook 的执行时间
|
||||
async def time_hook(coro, name, time_dict):
|
||||
start = time.time()
|
||||
try:
|
||||
# 检查熔断状态
|
||||
if check_circuit_breaker(name):
|
||||
logger.info(f"{name} 熔断器激活中,跳过执行", LOGGER_COMMAND)
|
||||
time_dict[name] = "熔断跳过"
|
||||
return
|
||||
|
||||
# 添加超时控制
|
||||
return await with_timeout(coro, name=name)
|
||||
except asyncio.TimeoutError:
|
||||
time_dict[name] = f"超时 (>{TIMEOUT_SECONDS}s)"
|
||||
finally:
|
||||
if name not in time_dict:
|
||||
time_dict[name] = f"{time.time() - start:.3f}s"
|
||||
|
||||
|
||||
async def auth(
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
message: UniMsg,
|
||||
):
|
||||
"""权限检查
|
||||
|
||||
参数:
|
||||
matcher: matcher
|
||||
event: Event
|
||||
bot: bot
|
||||
session: Uninfo
|
||||
message: UniMsg
|
||||
"""
|
||||
start_time = time.time()
|
||||
cost_gold = 0
|
||||
ignore_flag = False
|
||||
entity = get_entity_ids(session)
|
||||
module = matcher.plugin_name or ""
|
||||
|
||||
# 用于记录各个 hook 的执行时间
|
||||
hook_times = {}
|
||||
hooks_time = 0 # 初始化 hooks_time 变量
|
||||
|
||||
try:
|
||||
if not module:
|
||||
raise PermissionExemption("Matcher插件名称不存在...")
|
||||
|
||||
# 获取插件和用户数据
|
||||
plugin_user_start = time.time()
|
||||
try:
|
||||
plugin, user = await with_timeout(
|
||||
get_plugin_and_user(module, entity.user_id), name="get_plugin_and_user"
|
||||
)
|
||||
hook_times["get_plugin_user"] = f"{time.time() - plugin_user_start:.3f}s"
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"获取插件和用户数据超时,模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
|
||||
|
||||
# 获取插件费用
|
||||
cost_start = time.time()
|
||||
try:
|
||||
cost_gold = await with_timeout(
|
||||
get_plugin_cost(bot, user, plugin, session), name="get_plugin_cost"
|
||||
)
|
||||
hook_times["cost_gold"] = f"{time.time() - cost_start:.3f}s"
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"获取插件费用超时,模块: {module}", LOGGER_COMMAND, session=session
|
||||
)
|
||||
# 继续执行,不阻止权限检查
|
||||
|
||||
# 执行 bot_filter
|
||||
bot_filter(session)
|
||||
|
||||
# 并行执行所有 hook 检查,并记录执行时间
|
||||
hooks_start = time.time()
|
||||
|
||||
# 创建所有 hook 任务
|
||||
hook_tasks = [
|
||||
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
|
||||
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
|
||||
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
|
||||
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
|
||||
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
|
||||
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
|
||||
]
|
||||
|
||||
# 使用 gather 并行执行所有 hook,但添加总体超时控制
|
||||
try:
|
||||
await with_timeout(
|
||||
asyncio.gather(*hook_tasks),
|
||||
timeout=TIMEOUT_SECONDS * 2, # 给总体执行更多时间
|
||||
name="auth_hooks_gather",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"权限检查 hooks 总体执行超时,模块: {module}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
# 不抛出异常,允许继续执行
|
||||
|
||||
hooks_time = time.time() - hooks_start
|
||||
|
||||
except SkipPluginException as e:
|
||||
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
ignore_flag = True
|
||||
except IsSuperuserException:
|
||||
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||
except PermissionExemption as e:
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
|
||||
# 扣除金币
|
||||
if not ignore_flag and cost_gold > 0:
|
||||
gold_start = time.time()
|
||||
try:
|
||||
await with_timeout(
|
||||
reduce_gold(entity.user_id, module, cost_gold, session),
|
||||
name="reduce_gold",
|
||||
)
|
||||
hook_times["reduce_gold"] = f"{time.time() - gold_start:.3f}s"
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"扣除金币超时,模块: {module}", LOGGER_COMMAND, session=session
|
||||
)
|
||||
|
||||
# 记录总执行时间
|
||||
total_time = time.time() - start_time
|
||||
if total_time > WARNING_THRESHOLD: # 如果总时间超过500ms,记录详细信息
|
||||
logger.warning(
|
||||
f"权限检查耗时过长: {total_time:.3f}s, 模块: {module}, "
|
||||
f"hooks时间: {hooks_time:.3f}s, "
|
||||
f"详情: {hook_times}",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
)
|
||||
|
||||
if ignore_flag:
|
||||
raise IgnoredException("权限检测 ignore")
|
||||
@ -1,41 +1,43 @@
|
||||
from nonebot.adapters.onebot.v11 import Bot, Event
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_postprocessor, run_preprocessor
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from ._auth_checker import LimitManage, checker
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .auth.config import LOGGER_COMMAND
|
||||
from .auth_checker import LimitManager, auth
|
||||
|
||||
|
||||
# # 权限检测
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg
|
||||
):
|
||||
await checker.auth(
|
||||
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
||||
start_time = time.time()
|
||||
await auth(
|
||||
matcher,
|
||||
event,
|
||||
bot,
|
||||
session,
|
||||
message,
|
||||
)
|
||||
logger.debug(f"权限检测耗时:{time.time() - start_time}秒", LOGGER_COMMAND)
|
||||
|
||||
|
||||
# 解除命令block阻塞
|
||||
@run_postprocessor
|
||||
async def _(
|
||||
matcher: Matcher,
|
||||
exception: Exception | None,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
async def _(matcher: Matcher, session: Uninfo):
|
||||
user_id = session.user.id
|
||||
group_id = None
|
||||
channel_id = None
|
||||
if session.group:
|
||||
if session.group.parent:
|
||||
group_id = session.group.parent.id
|
||||
channel_id = session.group.id
|
||||
else:
|
||||
group_id = session.group.id
|
||||
if user_id and matcher.plugin:
|
||||
module = matcher.plugin.name
|
||||
LimitManage.unblock(module, user_id, group_id, channel_id)
|
||||
LimitManager.unblock(module, user_id, group_id, channel_id)
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import FreqLimiter
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"BAN_RESULT",
|
||||
"才不会给你发消息.",
|
||||
help="对被ban用户发送的消息",
|
||||
)
|
||||
|
||||
_flmt = FreqLimiter(300)
|
||||
|
||||
|
||||
# 检查是否被ban
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, bot: Bot, event: Event, state: T_State, session: EventSession
|
||||
):
|
||||
extra = {}
|
||||
if plugin := matcher.plugin:
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||
return
|
||||
user_id = session.id1
|
||||
group_id = session.id3 or session.id2
|
||||
if group_id:
|
||||
if user_id in bot.config.superusers:
|
||||
return
|
||||
if await BanConsole.is_ban(None, group_id):
|
||||
logger.debug("群组处于黑名单中...", "ban_hook")
|
||||
raise IgnoredException("群组处于黑名单中...")
|
||||
if g := await GroupConsole.get_group(group_id):
|
||||
if g.level < 0:
|
||||
logger.debug("群黑名单, 群权限-1...", "ban_hook")
|
||||
raise IgnoredException("群黑名单, 群权限-1..")
|
||||
if user_id:
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
if user_id in bot.config.superusers:
|
||||
return
|
||||
if await BanConsole.is_ban(user_id, group_id):
|
||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
||||
if time == -1:
|
||||
time_str = "∞"
|
||||
else:
|
||||
time = abs(int(time))
|
||||
if time < 60:
|
||||
time_str = f"{time!s} 秒"
|
||||
else:
|
||||
minute = int(time / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
if (
|
||||
not extra.get("ignore_prompt")
|
||||
and time != -1
|
||||
and ban_result
|
||||
and _flmt.check(user_id)
|
||||
):
|
||||
_flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
]
|
||||
).send()
|
||||
logger.debug("用户处于黑名单中...", "ban_hook")
|
||||
raise IgnoredException("用户处于黑名单中...")
|
||||
@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType
|
||||
from zhenxun.utils.manager.message_manager import MessageManager
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
LOG_COMMAND = "MessageHook"
|
||||
|
||||
|
||||
def replace_message(message: Message) -> str:
|
||||
"""将消息中的at、image、record、face替换为字符串
|
||||
@ -54,11 +56,11 @@ async def handle_api_result(
|
||||
if user_id and message_id:
|
||||
MessageManager.add(str(user_id), str(message_id))
|
||||
logger.debug(
|
||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook"
|
||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", LOG_COMMAND
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
|
||||
f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e
|
||||
)
|
||||
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
||||
return
|
||||
@ -80,6 +82,6 @@ async def handle_api_result(
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"消息发送记录发生错误...data: {data}, result: {result}",
|
||||
"msg_hook",
|
||||
LOG_COMMAND,
|
||||
e=e,
|
||||
)
|
||||
|
||||
@ -4,15 +4,27 @@ import nonebot
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.cache import CacheException
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
try:
|
||||
from .__init_cache import register_cache_types
|
||||
except CacheException as e:
|
||||
raise SystemError(f"ERROR:{e}")
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
register_cache_types()
|
||||
logger.info("缓存类型注册完成")
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: Bot):
|
||||
"""将bot已存在的群组添加群认证
|
||||
|
||||
35
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
35
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
缓存初始化模块
|
||||
|
||||
负责注册各种缓存类型,实现按需缓存机制
|
||||
"""
|
||||
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheRegistry, cache_config
|
||||
from zhenxun.services.cache.config import CacheMode
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
# 注册缓存类型
|
||||
def register_cache_types():
|
||||
"""注册所有缓存类型"""
|
||||
CacheRegistry.register(CacheType.PLUGINS, PluginInfo)
|
||||
CacheRegistry.register(CacheType.GROUPS, GroupConsole)
|
||||
CacheRegistry.register(CacheType.BOT, BotConsole)
|
||||
CacheRegistry.register(CacheType.USERS, UserConsole)
|
||||
CacheRegistry.register(
|
||||
CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
|
||||
)
|
||||
CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}")
|
||||
|
||||
if cache_config.cache_mode == CacheMode.NONE:
|
||||
logger.info("缓存功能已禁用,将直接从数据库获取数据")
|
||||
else:
|
||||
logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}")
|
||||
logger.info("使用增量缓存模式,数据将按需加载到缓存中")
|
||||
@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
import aiofiles
|
||||
import nonebot
|
||||
from nonebot import get_loaded_plugins
|
||||
@ -112,24 +114,29 @@ async def _():
|
||||
await _handle_setting(plugin, plugin_list, limit_list)
|
||||
create_list = []
|
||||
update_list = []
|
||||
update_task_list = []
|
||||
for plugin in plugin_list:
|
||||
if plugin.module_path not in module2id:
|
||||
create_list.append(plugin)
|
||||
else:
|
||||
plugin.id = module2id[plugin.module_path]
|
||||
await plugin.save(
|
||||
update_fields=[
|
||||
"name",
|
||||
"author",
|
||||
"version",
|
||||
"admin_level",
|
||||
"plugin_type",
|
||||
"is_show",
|
||||
]
|
||||
update_task_list.append(
|
||||
plugin.save(
|
||||
update_fields=[
|
||||
"name",
|
||||
"author",
|
||||
"version",
|
||||
"admin_level",
|
||||
"plugin_type",
|
||||
"is_show",
|
||||
]
|
||||
)
|
||||
)
|
||||
update_list.append(plugin)
|
||||
if create_list:
|
||||
await PluginInfo.bulk_create(create_list, 10)
|
||||
if update_task_list:
|
||||
await asyncio.gather(*update_task_list)
|
||||
# if update_list:
|
||||
# # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
|
||||
# column "superuser" does not exist
|
||||
|
||||
@ -205,7 +205,7 @@ class Manager:
|
||||
self.cd_data: dict[str, PluginCdBlock] = {}
|
||||
if self.cd_file.exists():
|
||||
with open(self.cd_file, encoding="utf8") as f:
|
||||
temp = _yaml.load(f)
|
||||
temp = _yaml.load(f) or {}
|
||||
if "PluginCdLimit" in temp.keys():
|
||||
for k, v in temp["PluginCdLimit"].items():
|
||||
if "." in k:
|
||||
@ -216,7 +216,7 @@ class Manager:
|
||||
self.block_data: dict[str, BaseBlock] = {}
|
||||
if self.block_file.exists():
|
||||
with open(self.block_file, encoding="utf8") as f:
|
||||
temp = _yaml.load(f)
|
||||
temp = _yaml.load(f) or {}
|
||||
if "PluginBlockLimit" in temp.keys():
|
||||
for k, v in temp["PluginBlockLimit"].items():
|
||||
if "." in k:
|
||||
@ -227,7 +227,7 @@ class Manager:
|
||||
self.count_data: dict[str, PluginCountBlock] = {}
|
||||
if self.count_file.exists():
|
||||
with open(self.count_file, encoding="utf8") as f:
|
||||
temp = _yaml.load(f)
|
||||
temp = _yaml.load(f) or {}
|
||||
if "PluginCountLimit" in temp.keys():
|
||||
for k, v in temp["PluginCountLimit"].items():
|
||||
if "." in k:
|
||||
|
||||
@ -55,15 +55,17 @@ class GroupManager:
|
||||
if plugin_list := await PluginInfo.filter(default_status=False).all():
|
||||
for plugin in plugin_list:
|
||||
block_plugin += f"<{plugin.module},"
|
||||
group_info = await bot.get_group_info(group_id=group_id, no_cache=True)
|
||||
await GroupConsole.create(
|
||||
group_info = await bot.get_group_info(group_id=group_id)
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=group_info["group_id"],
|
||||
group_name=group_info["group_name"],
|
||||
max_member_count=group_info["max_member_count"],
|
||||
member_count=group_info["member_count"],
|
||||
group_flag=1,
|
||||
block_plugin=block_plugin,
|
||||
platform="qq",
|
||||
defaults={
|
||||
"group_name": group_info["group_name"],
|
||||
"max_member_count": group_info["max_member_count"],
|
||||
"member_count": group_info["member_count"],
|
||||
"group_flag": 1,
|
||||
"block_plugin": block_plugin,
|
||||
"platform": "qq",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -145,7 +147,7 @@ class GroupManager:
|
||||
e=e,
|
||||
)
|
||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||
await GroupConsole.filter(group_id=group_id).delete()
|
||||
# await GroupConsole.filter(group_id=group_id).delete()
|
||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||
else:
|
||||
await cls.__handle_add_group(bot, group_id, group)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot import on_message
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.friend_user import FriendUser
|
||||
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
@run_preprocessor
|
||||
async def do_something(session: Uninfo):
|
||||
def rule(session: Uninfo) -> bool:
|
||||
return PlatformUtils.is_qbot(session)
|
||||
|
||||
|
||||
_matcher = on_message(priority=999, block=False, rule=rule)
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(session: Uninfo):
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
if session.group:
|
||||
if not await GroupConsole.exists(group_id=session.group.id):
|
||||
await GroupConsole.create(group_id=session.group.id)
|
||||
logger.info("添加当前群组ID信息" "", session=session)
|
||||
|
||||
if not await GroupInfoUser.exists(
|
||||
user_id=session.user.id, group_id=session.group.id
|
||||
):
|
||||
await GroupInfoUser.create(
|
||||
user_id=session.user.id, group_id=session.group.id, platform=platform
|
||||
)
|
||||
logger.info("添加当前用户群组ID信息", "", session=session)
|
||||
logger.info("添加当前群组ID信息", session=session)
|
||||
await GroupInfoUser.update_or_create(
|
||||
user_id=session.user.id,
|
||||
group_id=session.group.id,
|
||||
platform=PlatformUtils.get_platform(session),
|
||||
)
|
||||
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
||||
try:
|
||||
await FriendUser.create(user_id=session.user.id, platform=platform)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
except Exception as e:
|
||||
logger.error("添加当前好友用户信息失败", session=session, e=e)
|
||||
await FriendUser.create(
|
||||
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||
)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""开启/禁用插件格式修改"""
|
||||
_, is_create = await GroupConsole.get_or_create(group_id=133133133)
|
||||
"""标记"""
|
||||
if is_create:
|
||||
data_list = []
|
||||
for group in await GroupConsole.all():
|
||||
if group.block_plugin:
|
||||
if modules := group.block_plugin.split(","):
|
||||
block_plugin = "".join(
|
||||
(f"{module}," if module.startswith("<") else f"<{module},")
|
||||
for module in modules
|
||||
if module.strip()
|
||||
)
|
||||
group.block_plugin = block_plugin.replace("<,", "")
|
||||
if group.block_task:
|
||||
if modules := group.block_task.split(","):
|
||||
block_task = "".join(
|
||||
(f"{module}," if module.startswith("<") else f"<{module},")
|
||||
for module in modules
|
||||
if module.strip()
|
||||
)
|
||||
group.block_task = block_task.replace("<,", "")
|
||||
data_list.append(group)
|
||||
await GroupConsole.bulk_update(data_list, ["block_plugin", "block_task"], 10)
|
||||
@ -44,9 +44,7 @@ class StatisticsManage:
|
||||
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
||||
elif group_id:
|
||||
"""查群组"""
|
||||
group = await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
||||
else:
|
||||
title = "功能调用统计"
|
||||
|
||||
@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
|
||||
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
gid = state["group_id"]
|
||||
group = await GroupConsole.get_or_none(group_id=gid)
|
||||
group = await GroupConsole.get_group(group_id=gid)
|
||||
if not group:
|
||||
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
||||
s = "删除" if arparma.find("delete") else "添加"
|
||||
@ -177,7 +177,9 @@ async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
gid = state["group_id"]
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=gid, defaults={"group_flag": 0 if arparma.find("delete") else 1}
|
||||
group_id=gid,
|
||||
channel_id__isnull=True,
|
||||
defaults={"group_flag": 0 if arparma.find("delete") else 1},
|
||||
)
|
||||
s = "删除" if arparma.find("delete") else "添加"
|
||||
await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)
|
||||
|
||||
@ -119,7 +119,7 @@ class ApiDataSource:
|
||||
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e)
|
||||
logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e)
|
||||
select_bot.group_count = 0
|
||||
select_bot.friend_count = 0
|
||||
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)
|
||||
|
||||
@ -250,7 +250,7 @@ class ApiDataSource:
|
||||
返回:
|
||||
GroupDetail | None: 群组详情数据
|
||||
"""
|
||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
if not group:
|
||||
return None
|
||||
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
||||
|
||||
@ -45,6 +45,7 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
|
||||
mtime=file_path.stat().st_mtime,
|
||||
)
|
||||
)
|
||||
data_list.sort(key=lambda f: f.name)
|
||||
return Result.ok(data_list)
|
||||
except Exception as e:
|
||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||
|
||||
@ -13,8 +13,8 @@ class BotSetting(BaseModel):
|
||||
"""回复时NICKNAME"""
|
||||
system_proxy: str | None = None
|
||||
"""系统代理"""
|
||||
db_url: str = ""
|
||||
"""数据库链接"""
|
||||
db_url: str = "sqlite:data/zhenxun.db"
|
||||
"""数据库链接, 默认值为sqlite:data/zhenxun.db"""
|
||||
platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
|
||||
"""平台超级用户"""
|
||||
qbot_id_data: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@ -155,8 +155,6 @@ class AICallableProperties(BaseModel):
|
||||
"""参数类型"""
|
||||
description: str
|
||||
"""参数描述"""
|
||||
enums: list[str] | None = None
|
||||
"""参数枚举"""
|
||||
|
||||
|
||||
class AICallableParam(BaseModel):
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import time
|
||||
from typing import ClassVar
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType, DbLockType
|
||||
from zhenxun.utils.exception import UserAndGroupIsNone
|
||||
|
||||
|
||||
@ -27,6 +29,15 @@ class BanConsole(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "ban_console"
|
||||
table_description = "封禁人员/群组数据表"
|
||||
unique_together = ("user_id", "group_id")
|
||||
indexes = [("user_id",), ("group_id",)] # noqa: RUF012
|
||||
|
||||
cache_type = CacheType.BAN
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("user_id", "group_id")
|
||||
"""缓存键字段"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
||||
@ -46,12 +57,12 @@ class BanConsole(Model):
|
||||
raise UserAndGroupIsNone()
|
||||
if user_id:
|
||||
return (
|
||||
await cls.get_or_none(user_id=user_id, group_id=group_id)
|
||||
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
|
||||
if group_id
|
||||
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
)
|
||||
else:
|
||||
return await cls.get_or_none(user_id="", group_id=group_id)
|
||||
return await cls.safe_get_or_none(user_id="", group_id=group_id)
|
||||
|
||||
@classmethod
|
||||
async def check_ban_level(
|
||||
@ -167,3 +178,32 @@ class BanConsole(Model):
|
||||
await user.delete()
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_ban(
|
||||
cls,
|
||||
*,
|
||||
id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> Self | None:
|
||||
"""安全地获取ban记录
|
||||
|
||||
参数:
|
||||
id: 记录id
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
|
||||
返回:
|
||||
Self | None: ban记录
|
||||
"""
|
||||
if id is not None:
|
||||
return await cls.safe_get_or_none(id=id)
|
||||
return await cls._get_data(user_id, group_id)
|
||||
|
||||
@classmethod
|
||||
async def _run_script(cls):
|
||||
return [
|
||||
"CREATE INDEX idx_ban_console_user_id ON ban_console(user_id);",
|
||||
"CREATE INDEX idx_ban_console_group_id ON ban_console(group_id);",
|
||||
]
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Literal, overload
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class BotConsole(Model):
|
||||
@ -29,6 +30,11 @@ class BotConsole(Model):
|
||||
table = "bot_console"
|
||||
table_description = "Bot数据表"
|
||||
|
||||
cache_type = CacheType.BOT
|
||||
"""缓存类型"""
|
||||
cache_key_field = "bot_id"
|
||||
"""缓存键字段"""
|
||||
|
||||
@staticmethod
|
||||
def format(name: str) -> str:
|
||||
return f"<{name},"
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, cast, overload
|
||||
from typing import Any, ClassVar, cast, overload
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
@ -6,8 +6,9 @@ from tortoise.backends.base.client import BaseDBAsyncClient
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||
|
||||
|
||||
def add_disable_marker(name: str) -> str:
|
||||
@ -86,6 +87,16 @@ class GroupConsole(Model):
|
||||
table = "group_console"
|
||||
table_description = "群组信息表"
|
||||
unique_together = ("group_id", "channel_id")
|
||||
indexes = [ # noqa: RUF012
|
||||
("group_id",)
|
||||
]
|
||||
|
||||
cache_type = CacheType.GROUPS
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("group_id", "channel_id")
|
||||
"""缓存键字段"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
|
||||
@ -116,6 +127,18 @@ class GroupConsole(Model):
|
||||
).values_list("module", flat=True),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _update_cache(cls, instance):
|
||||
"""更新缓存
|
||||
|
||||
参数:
|
||||
instance: 需要更新缓存的实例
|
||||
"""
|
||||
if cache_type := cls.get_cache_type():
|
||||
key = cls.get_cache_key(instance)
|
||||
if key is not None:
|
||||
await CacheRoot.invalidate_cache(cache_type, key)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||
@ -129,6 +152,9 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
return group
|
||||
|
||||
@classmethod
|
||||
@ -180,6 +206,10 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
# 更新缓存
|
||||
if is_create:
|
||||
await cls._update_cache(group)
|
||||
|
||||
return group, is_create
|
||||
|
||||
@classmethod
|
||||
@ -202,24 +232,39 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
return group, is_create
|
||||
|
||||
@classmethod
|
||||
async def get_group(
|
||||
cls, group_id: str, channel_id: str | None = None
|
||||
cls,
|
||||
group_id: str,
|
||||
channel_id: str | None = None,
|
||||
clean_duplicates: bool = True,
|
||||
) -> Self | None:
|
||||
"""获取群组
|
||||
|
||||
参数:
|
||||
group_id: 群组id
|
||||
channel_id: 频道id.
|
||||
channel_id: 频道id
|
||||
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||
|
||||
返回:
|
||||
Self: GroupConsole
|
||||
"""
|
||||
if channel_id:
|
||||
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
|
||||
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
|
||||
return await cls.safe_get_or_none(
|
||||
group_id=group_id,
|
||||
channel_id=channel_id,
|
||||
clean_duplicates=clean_duplicates,
|
||||
)
|
||||
return await cls.safe_get_or_none(
|
||||
group_id=group_id,
|
||||
channel_id__isnull=True,
|
||||
clean_duplicates=clean_duplicates,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def is_super_group(cls, group_id: str) -> bool:
|
||||
@ -303,6 +348,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
async def set_unblock_plugin(
|
||||
cls,
|
||||
@ -339,6 +387,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
async def is_normal_block_plugin(
|
||||
cls, group_id: str, module: str, channel_id: str | None = None
|
||||
@ -442,6 +493,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
async def set_unblock_task(
|
||||
cls,
|
||||
@ -476,6 +530,9 @@ class GroupConsole(Model):
|
||||
if update_fields:
|
||||
await group.save(update_fields=update_fields)
|
||||
|
||||
# 更新缓存
|
||||
await cls._update_cache(group)
|
||||
|
||||
@classmethod
|
||||
def _run_script(cls):
|
||||
return [
|
||||
@ -483,4 +540,6 @@ class GroupConsole(Model):
|
||||
" character varying(255) NOT NULL DEFAULT '';",
|
||||
"ALTER TABLE group_console ADD superuser_block_task"
|
||||
" character varying(255) NOT NULL DEFAULT '';",
|
||||
"CREATE INDEX idx_group_console_group_id ON group_console(group_id);",
|
||||
"CREATE INDEX idx_group_console_group_null_channel ON group_console(group_id) WHERE channel_id IS NULL;", # 单独创建channel为空的索引 # noqa: E501
|
||||
]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class LevelUser(Model):
|
||||
@ -20,6 +21,11 @@ class LevelUser(Model):
|
||||
table_description = "用户权限数据库"
|
||||
unique_together = ("user_id", "group_id")
|
||||
|
||||
cache_type = CacheType.LEVEL
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("user_id", "group_id")
|
||||
"""缓存键字段"""
|
||||
|
||||
@classmethod
|
||||
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
||||
"""获取用户在群内的等级
|
||||
@ -53,6 +59,9 @@ class LevelUser(Model):
|
||||
level: 权限等级
|
||||
group_flag: 是否被自动更新刷新权限 0:是, 1:否.
|
||||
"""
|
||||
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
|
||||
# 权限相同时跳过
|
||||
return
|
||||
await cls.update_or_create(
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
|
||||
@ -4,7 +4,7 @@ from tortoise import fields
|
||||
|
||||
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
|
||||
|
||||
class PluginInfo(Model):
|
||||
@ -59,6 +59,11 @@ class PluginInfo(Model):
|
||||
table = "plugin_info"
|
||||
table_description = "插件基本信息"
|
||||
|
||||
cache_type = CacheType.PLUGINS
|
||||
"""缓存类型"""
|
||||
cache_key_field = "module"
|
||||
"""缓存键字段"""
|
||||
|
||||
@classmethod
|
||||
async def get_plugin(
|
||||
cls, load_status: bool = True, filter_parent: bool = True, **kwargs
|
||||
|
||||
@ -2,7 +2,7 @@ from tortoise import fields
|
||||
|
||||
from zhenxun.models.goods_info import GoodsInfo
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import GoldHandle
|
||||
from zhenxun.utils.enum import CacheType, GoldHandle
|
||||
from zhenxun.utils.exception import GoodsNotFound, InsufficientGold
|
||||
|
||||
from .user_gold_log import UserGoldLog
|
||||
@ -29,6 +29,12 @@ class UserConsole(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "user_console"
|
||||
table_description = "用户数据表"
|
||||
indexes = [("user_id",), ("uid",)] # noqa: RUF012
|
||||
|
||||
cache_type = CacheType.USERS
|
||||
"""缓存类型"""
|
||||
cache_key_field = "user_id"
|
||||
"""缓存键字段"""
|
||||
|
||||
@classmethod
|
||||
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
||||
@ -193,3 +199,10 @@ class UserConsole(Model):
|
||||
if goods := await GoodsInfo.get_or_none(goods_name=name):
|
||||
return await cls.use_props(user_id, goods.uuid, num, platform)
|
||||
raise GoodsNotFound("未找到商品...")
|
||||
|
||||
@classmethod
|
||||
async def _run_script(cls):
|
||||
return [
|
||||
"CREATE INDEX idx_user_console_user_id ON user_console(user_id);",
|
||||
"CREATE INDEX idx_user_console_uid ON user_console(uid);",
|
||||
]
|
||||
|
||||
1065
zhenxun/services/cache/__init__.py
vendored
Normal file
1065
zhenxun/services/cache/__init__.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
452
zhenxun/services/cache/cache_containers.py
vendored
Normal file
452
zhenxun/services/cache/cache_containers.py
vendored
Normal file
@ -0,0 +1,452 @@
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheData(Generic[T]):
|
||||
"""缓存数据类,存储数据和过期时间"""
|
||||
|
||||
value: T
|
||||
expire_time: float = 0 # 0表示永不过期
|
||||
|
||||
|
||||
class CacheDict:
|
||||
"""缓存字典类,提供类似普通字典的接口,数据只存储在内存中"""
|
||||
|
||||
def __init__(self, name: str, expire: int = 0):
|
||||
"""初始化缓存字典
|
||||
|
||||
参数:
|
||||
name: 字典名称
|
||||
expire: 过期时间(秒),默认为0表示永不过期
|
||||
"""
|
||||
self.name = name.upper()
|
||||
self.expire = expire
|
||||
self._data: dict[str, CacheData[Any]] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""获取字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
|
||||
返回:
|
||||
Any: 字典值
|
||||
"""
|
||||
data = self._data.get(key)
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if data.expire_time > 0 and data.expire_time < time.time():
|
||||
del self._data[key]
|
||||
return None
|
||||
|
||||
return data.value
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""设置字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
value: 字典值
|
||||
"""
|
||||
# 计算过期时间
|
||||
expire_time = 0
|
||||
if self.expire > 0:
|
||||
expire_time = time.time() + self.expire
|
||||
|
||||
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""删除字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
"""
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""检查键是否存在
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
|
||||
返回:
|
||||
bool: 是否存在
|
||||
"""
|
||||
if key not in self._data:
|
||||
return False
|
||||
|
||||
# 检查是否过期
|
||||
data = self._data[key]
|
||||
if data.expire_time > 0 and data.expire_time < time.time():
|
||||
del self._data[key]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""获取字典项,如果不存在返回默认值
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
default: 默认值
|
||||
|
||||
返回:
|
||||
Any: 字典值或默认值
|
||||
"""
|
||||
value = self[key]
|
||||
return default if value is None else value
|
||||
|
||||
def set(self, key: str, value: Any, expire: int | None = None) -> None:
|
||||
"""设置字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
value: 字典值
|
||||
expire: 过期时间(秒),为None时使用默认值
|
||||
"""
|
||||
# 计算过期时间
|
||||
expire_time = 0
|
||||
if expire is not None and expire > 0:
|
||||
expire_time = time.time() + expire
|
||||
elif self.expire > 0:
|
||||
expire_time = time.time() + self.expire
|
||||
|
||||
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||
|
||||
def pop(self, key: str, default: Any = None) -> Any:
|
||||
"""删除并返回字典项
|
||||
|
||||
参数:
|
||||
key: 字典键
|
||||
default: 默认值
|
||||
|
||||
返回:
|
||||
Any: 字典值或默认值
|
||||
"""
|
||||
if key not in self._data:
|
||||
return default
|
||||
|
||||
data = self._data.pop(key)
|
||||
|
||||
# 检查是否过期
|
||||
if data.expire_time > 0 and data.expire_time < time.time():
|
||||
return default
|
||||
|
||||
return data.value
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空字典"""
|
||||
self._data.clear()
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""获取所有键
|
||||
|
||||
返回:
|
||||
list[str]: 键列表
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return list(self._data.keys())
|
||||
|
||||
def values(self) -> list[Any]:
|
||||
"""获取所有值
|
||||
|
||||
返回:
|
||||
list[Any]: 值列表
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return [data.value for data in self._data.values()]
|
||||
|
||||
def items(self) -> list[tuple[str, Any]]:
|
||||
"""获取所有键值对
|
||||
|
||||
返回:
|
||||
list[tuple[str, Any]]: 键值对列表
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return [(key, data.value) for key, data in self._data.items()]
|
||||
|
||||
def _clean_expired(self) -> None:
|
||||
"""清理过期的键"""
|
||||
now = time.time()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, data in self._data.items()
|
||||
if data.expire_time > 0 and data.expire_time < now
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._data[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""获取字典长度
|
||||
|
||||
返回:
|
||||
int: 字典长度
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return len(self._data)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示
|
||||
|
||||
返回:
|
||||
str: 字符串表示
|
||||
"""
|
||||
# 清理过期的键
|
||||
self._clean_expired()
|
||||
return f"CacheDict({self.name}, {len(self._data)} items)"
|
||||
|
||||
|
||||
class CacheList:
|
||||
"""缓存列表类,提供类似普通列表的接口,数据只存储在内存中"""
|
||||
|
||||
def __init__(self, name: str, expire: int = 0):
|
||||
"""初始化缓存列表
|
||||
|
||||
参数:
|
||||
name: 列表名称
|
||||
expire: 过期时间(秒),默认为0表示永不过期
|
||||
"""
|
||||
self.name = name.upper()
|
||||
self.expire = expire
|
||||
self._data: list[CacheData[Any]] = []
|
||||
self._expire_time = 0
|
||||
|
||||
# 如果设置了过期时间,计算整个列表的过期时间
|
||||
if self.expire > 0:
|
||||
self._expire_time = time.time() + self.expire
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""获取列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引
|
||||
|
||||
返回:
|
||||
Any: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
if 0 <= index < len(self._data):
|
||||
return self._data[index].value
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
def __setitem__(self, index: int, value: Any) -> None:
|
||||
"""设置列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引
|
||||
value: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
# 确保索引有效
|
||||
while len(self._data) <= index:
|
||||
self._data.append(CacheData(value=None))
|
||||
self._data[index] = CacheData(value=value)
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def __delitem__(self, index: int) -> None:
|
||||
"""删除列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
if 0 <= index < len(self._data):
|
||||
del self._data[index]
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
else:
|
||||
raise IndexError(f"列表索引 {index} 超出范围")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""获取列表长度
|
||||
|
||||
返回:
|
||||
int: 列表长度
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
return len(self._data)
|
||||
|
||||
def append(self, value: Any) -> None:
|
||||
"""添加列表项
|
||||
|
||||
参数:
|
||||
value: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
self._data.append(CacheData(value=value))
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def extend(self, values: list[Any]) -> None:
|
||||
"""扩展列表
|
||||
|
||||
参数:
|
||||
values: 要添加的值列表
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
self._data.extend([CacheData(value=v) for v in values])
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def insert(self, index: int, value: Any) -> None:
|
||||
"""插入列表项
|
||||
|
||||
参数:
|
||||
index: 插入位置
|
||||
value: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
|
||||
self._data.insert(index, CacheData(value=value))
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def pop(self, index: int = -1) -> Any:
|
||||
"""删除并返回列表项
|
||||
|
||||
参数:
|
||||
index: 列表索引,默认为最后一项
|
||||
|
||||
返回:
|
||||
Any: 列表值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise IndexError("从空列表中弹出")
|
||||
|
||||
if not self._data:
|
||||
raise IndexError("从空列表中弹出")
|
||||
|
||||
item = self._data.pop(index)
|
||||
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
return item.value
|
||||
|
||||
def remove(self, value: Any) -> None:
|
||||
"""删除第一个匹配的列表项
|
||||
|
||||
参数:
|
||||
value: 要删除的值
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
# 查找匹配的项
|
||||
for i, item in enumerate(self._data):
|
||||
if item.value == value:
|
||||
del self._data[i]
|
||||
# 更新过期时间
|
||||
self._update_expire_time()
|
||||
return
|
||||
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空列表"""
|
||||
self._data.clear()
|
||||
# 重置过期时间
|
||||
self._update_expire_time()
|
||||
|
||||
def index(self, value: Any, start: int = 0, end: int | None = None) -> int:
|
||||
"""查找值的索引
|
||||
|
||||
参数:
|
||||
value: 要查找的值
|
||||
start: 起始索引
|
||||
end: 结束索引
|
||||
|
||||
返回:
|
||||
int: 索引位置
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
end = end if end is not None else len(self._data)
|
||||
|
||||
for i in range(start, min(end, len(self._data))):
|
||||
if self._data[i].value == value:
|
||||
return i
|
||||
|
||||
raise ValueError(f"{value} 不在列表中")
|
||||
|
||||
def count(self, value: Any) -> int:
|
||||
"""计算值出现的次数
|
||||
|
||||
参数:
|
||||
value: 要计数的值
|
||||
|
||||
返回:
|
||||
int: 出现次数
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
return 0
|
||||
|
||||
return sum(1 for item in self._data if item.value == value)
|
||||
|
||||
def _is_expired(self) -> bool:
|
||||
"""检查整个列表是否过期"""
|
||||
return self._expire_time > 0 and self._expire_time < time.time()
|
||||
|
||||
def _update_expire_time(self) -> None:
|
||||
"""更新过期时间"""
|
||||
if self.expire > 0:
|
||||
self._expire_time = time.time() + self.expire
|
||||
else:
|
||||
self._expire_time = 0
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示
|
||||
|
||||
返回:
|
||||
str: 字符串表示
|
||||
"""
|
||||
# 检查整个列表是否过期
|
||||
if self._is_expired():
|
||||
self.clear()
|
||||
return f"CacheList({self.name}, {len(self._data)} items)"
|
||||
35
zhenxun/services/cache/config.py
vendored
Normal file
35
zhenxun/services/cache/config.py
vendored
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
缓存系统配置
|
||||
"""
|
||||
|
||||
# 日志标识
|
||||
LOG_COMMAND = "CacheRoot"
|
||||
|
||||
# 默认缓存过期时间(秒)
|
||||
DEFAULT_EXPIRE = 600
|
||||
|
||||
# 缓存键前缀
|
||||
CACHE_KEY_PREFIX = "ZHENXUN"
|
||||
|
||||
# 缓存键分隔符
|
||||
CACHE_KEY_SEPARATOR = ":"
|
||||
|
||||
# 复合键分隔符(用于分隔tuple类型的cache_key_field)
|
||||
COMPOSITE_KEY_SEPARATOR = "_"
|
||||
|
||||
|
||||
# 缓存模式
|
||||
class CacheMode:
|
||||
# 内存缓存 - 使用内存存储缓存数据
|
||||
MEMORY = "MEMORY"
|
||||
# Redis缓存 - 使用Redis服务器存储缓存数据
|
||||
REDIS = "REDIS"
|
||||
# 不使用缓存 - 将使用ttl=0的内存缓存,相当于直接从数据库获取数据
|
||||
NONE = "NONE"
|
||||
|
||||
|
||||
SPECIAL_KEY_FORMATS = {
|
||||
"LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||
"BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||
"GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}",
|
||||
}
|
||||
653
zhenxun/services/data_access.py
Normal file
653
zhenxun/services/data_access.py
Normal file
@ -0,0 +1,653 @@
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
|
||||
from zhenxun.services.cache import Cache, CacheRoot, cache_config
|
||||
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR, CacheMode
|
||||
from zhenxun.services.db_context import Model, with_db_timeout
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
|
||||
class DataAccess(Generic[T]):
|
||||
"""数据访问层,根据配置决定是否使用缓存
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
from zhenxun.services import DataAccess
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
|
||||
# 创建数据访问对象
|
||||
plugin_dao = DataAccess(PluginInfo)
|
||||
|
||||
# 获取单个数据
|
||||
plugin = await plugin_dao.get(module="example_module")
|
||||
|
||||
# 获取所有数据
|
||||
all_plugins = await plugin_dao.all()
|
||||
|
||||
# 筛选数据
|
||||
enabled_plugins = await plugin_dao.filter(status=True)
|
||||
|
||||
# 创建数据
|
||||
new_plugin = await plugin_dao.create(
|
||||
module="new_module",
|
||||
name="新插件",
|
||||
status=True
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
# 添加缓存统计信息
|
||||
_cache_stats: ClassVar[dict] = {}
|
||||
# 空结果标记
|
||||
_NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"
|
||||
# 默认空结果缓存时间(秒)- 设置为5分钟,避免频繁查询数据库
|
||||
_NULL_RESULT_TTL = 300
|
||||
|
||||
@classmethod
|
||||
def set_null_result_ttl(cls, seconds: int) -> None:
|
||||
"""设置空结果缓存时间
|
||||
|
||||
参数:
|
||||
seconds: 缓存时间(秒)
|
||||
"""
|
||||
if seconds < 0:
|
||||
raise ValueError("缓存时间不能为负数")
|
||||
cls._NULL_RESULT_TTL = seconds
|
||||
logger.info(f"已设置DataAccess空结果缓存时间为 {seconds} 秒")
|
||||
|
||||
@classmethod
|
||||
def get_null_result_ttl(cls) -> int:
|
||||
"""获取空结果缓存时间
|
||||
|
||||
返回:
|
||||
int: 缓存时间(秒)
|
||||
"""
|
||||
return cls._NULL_RESULT_TTL
|
||||
|
||||
def __init__(
|
||||
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
|
||||
):
|
||||
"""初始化数据访问对象
|
||||
|
||||
参数:
|
||||
model_cls: 模型类
|
||||
key_field: 主键字段
|
||||
"""
|
||||
self.model_cls = model_cls
|
||||
self.key_field = getattr(model_cls, "cache_key_field", key_field)
|
||||
self.cache_type = getattr(model_cls, "cache_type", cache_type)
|
||||
|
||||
if not self.cache_type:
|
||||
raise ValueError("缓存类型不能为空")
|
||||
self.cache = Cache(self.cache_type)
|
||||
|
||||
# 初始化缓存统计
|
||||
if self.cache_type not in self._cache_stats:
|
||||
self._cache_stats[self.cache_type] = {
|
||||
"hits": 0, # 缓存命中次数
|
||||
"misses": 0, # 缓存未命中次数
|
||||
"null_hits": 0, # 空结果缓存命中次数
|
||||
"sets": 0, # 缓存设置次数
|
||||
"null_sets": 0, # 空结果缓存设置次数
|
||||
"deletes": 0, # 缓存删除次数
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls):
|
||||
"""获取缓存统计信息"""
|
||||
result = []
|
||||
for cache_type, stats in cls._cache_stats.items():
|
||||
hits = stats["hits"]
|
||||
null_hits = stats.get("null_hits", 0)
|
||||
misses = stats["misses"]
|
||||
total = hits + null_hits + misses
|
||||
hit_rate = ((hits + null_hits) / total * 100) if total > 0 else 0
|
||||
result.append(
|
||||
{
|
||||
"cache_type": cache_type,
|
||||
"hits": hits,
|
||||
"null_hits": null_hits,
|
||||
"misses": misses,
|
||||
"sets": stats["sets"],
|
||||
"null_sets": stats.get("null_sets", 0),
|
||||
"deletes": stats["deletes"],
|
||||
"hit_rate": f"{hit_rate:.2f}%",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def reset_cache_stats(cls):
|
||||
"""重置缓存统计信息"""
|
||||
for stats in cls._cache_stats.values():
|
||||
stats["hits"] = 0
|
||||
stats["null_hits"] = 0
|
||||
stats["misses"] = 0
|
||||
stats["sets"] = 0
|
||||
stats["null_sets"] = 0
|
||||
stats["deletes"] = 0
|
||||
|
||||
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
|
||||
"""从关键字参数构建缓存键
|
||||
|
||||
参数:
|
||||
**kwargs: 关键字参数
|
||||
|
||||
返回:
|
||||
str | None: 缓存键,如果无法构建则返回None
|
||||
"""
|
||||
if isinstance(self.key_field, tuple):
|
||||
# 多字段主键
|
||||
key_parts = []
|
||||
key_parts.extend(str(kwargs.get(field, "")) for field in self.key_field)
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||
elif self.key_field in kwargs:
|
||||
# 单字段主键
|
||||
return str(kwargs[self.key_field])
|
||||
return None
|
||||
|
||||
async def safe_get_or_none(self, *args, **kwargs) -> T | None:
|
||||
"""安全的获取单条数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Optional[T]: 查询结果,如果不存在返回None
|
||||
"""
|
||||
# 如果没有缓存类型,直接从数据库获取
|
||||
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
|
||||
return await with_db_timeout(
|
||||
self.model_cls.safe_get_or_none(*args, **kwargs),
|
||||
operation=f"{self.model_cls.__name__}.safe_get_or_none",
|
||||
)
|
||||
|
||||
# 尝试从缓存获取
|
||||
cache_key = None
|
||||
try:
|
||||
# 尝试构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
|
||||
# 如果成功构建缓存键,尝试从缓存获取
|
||||
if cache_key is not None:
|
||||
data = await self.cache.get(cache_key)
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} self.cache.get(cache_key)"
|
||||
f" 从缓存获取到的数据 {type(data)}: {data}"
|
||||
)
|
||||
if data == self._NULL_RESULT:
|
||||
# 空结果缓存命中
|
||||
self._cache_stats[self.cache_type]["null_hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
|
||||
)
|
||||
return None
|
||||
elif data:
|
||||
# 缓存命中
|
||||
self._cache_stats[self.cache_type]["hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
|
||||
)
|
||||
return cast(T, data)
|
||||
else:
|
||||
# 缓存未命中
|
||||
self._cache_stats[self.cache_type]["misses"] += 1
|
||||
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
|
||||
|
||||
# 如果缓存中没有,从数据库获取
|
||||
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
|
||||
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
|
||||
|
||||
# 如果获取到数据,存入缓存
|
||||
if data:
|
||||
try:
|
||||
# 生成缓存键
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
elif cache_key is not None:
|
||||
# 如果没有获取到数据,缓存空结果
|
||||
try:
|
||||
# 存入空结果缓存,使用较短的过期时间
|
||||
await self.cache.set(
|
||||
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
|
||||
)
|
||||
self._cache_stats[self.cache_type]["null_sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
|
||||
f" TTL={self._NULL_RESULT_TTL}秒"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def get_or_none(self, *args, **kwargs) -> T | None:
|
||||
"""获取单条数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Optional[T]: 查询结果,如果不存在返回None
|
||||
"""
|
||||
# 如果没有缓存类型,直接从数据库获取
|
||||
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
|
||||
return await with_db_timeout(
|
||||
self.model_cls.get_or_none(*args, **kwargs),
|
||||
operation=f"{self.model_cls.__name__}.get_or_none",
|
||||
)
|
||||
|
||||
# 尝试从缓存获取
|
||||
cache_key = None
|
||||
try:
|
||||
# 尝试构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
|
||||
# 如果成功构建缓存键,尝试从缓存获取
|
||||
if cache_key is not None:
|
||||
data = await self.cache.get(cache_key)
|
||||
if data == self._NULL_RESULT:
|
||||
# 空结果缓存命中
|
||||
self._cache_stats[self.cache_type]["null_hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
|
||||
)
|
||||
return None
|
||||
elif data:
|
||||
# 缓存命中
|
||||
self._cache_stats[self.cache_type]["hits"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
|
||||
)
|
||||
return cast(T, data)
|
||||
else:
|
||||
# 缓存未命中
|
||||
self._cache_stats[self.cache_type]["misses"] += 1
|
||||
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
|
||||
|
||||
# 如果缓存中没有,从数据库获取
|
||||
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
|
||||
data = await self.model_cls.get_or_none(*args, **kwargs)
|
||||
|
||||
# 如果获取到数据,存入缓存
|
||||
if data:
|
||||
try:
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
# 生成缓存键
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
elif cache_key is not None:
|
||||
# 如果没有获取到数据,缓存空结果
|
||||
try:
|
||||
# 存入空结果缓存,使用较短的过期时间
|
||||
await self.cache.set(
|
||||
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
|
||||
)
|
||||
self._cache_stats[self.cache_type]["null_sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
|
||||
f" TTL={self._NULL_RESULT_TTL}秒"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def clear_cache(self, **kwargs) -> bool:
|
||||
"""只清除缓存,不影响数据库数据
|
||||
|
||||
参数:
|
||||
**kwargs: 查询参数,必须包含主键字段
|
||||
|
||||
返回:
|
||||
bool: 是否成功清除缓存
|
||||
"""
|
||||
# 如果没有缓存类型,直接返回True
|
||||
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
if cache_key is None:
|
||||
if isinstance(self.key_field, tuple):
|
||||
# 如果是复合键,检查缺少哪些字段
|
||||
missing_fields = [
|
||||
field for field in self.key_field if field not in kwargs
|
||||
]
|
||||
logger.error(
|
||||
f"清除{self.model_cls.__name__}缓存失败: "
|
||||
f"缺少主键字段 {', '.join(missing_fields)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"清除{self.model_cls.__name__}缓存失败: "
|
||||
f"缺少主键字段 {self.key_field}"
|
||||
)
|
||||
return False
|
||||
|
||||
# 删除缓存
|
||||
await self.cache.delete(cache_key)
|
||||
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e)
|
||||
return False
|
||||
|
||||
def _build_composite_key(self, data: T) -> str | None:
|
||||
"""构建复合缓存键
|
||||
|
||||
参数:
|
||||
data: 数据对象
|
||||
|
||||
返回:
|
||||
str | None: 构建的缓存键,如果无法构建则返回None
|
||||
"""
|
||||
# 如果是元组,表示多个字段组成键
|
||||
if isinstance(self.key_field, tuple):
|
||||
# 构建键参数列表
|
||||
key_parts = []
|
||||
for field in self.key_field:
|
||||
value = getattr(data, field, "")
|
||||
key_parts.append(value if value is not None else "")
|
||||
|
||||
# 如果没有有效参数,返回None
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||
elif hasattr(data, self.key_field):
|
||||
value = getattr(data, self.key_field, None)
|
||||
return str(value) if value is not None else None
|
||||
|
||||
return None
|
||||
|
||||
def _build_cache_key_for_item(self, item: T) -> str | None:
|
||||
"""为数据项构建缓存键
|
||||
|
||||
参数:
|
||||
item: 数据项
|
||||
|
||||
返回:
|
||||
str | None: 缓存键,如果无法生成则返回None
|
||||
"""
|
||||
# 如果没有缓存类型,返回None
|
||||
if not self.cache_type:
|
||||
return None
|
||||
|
||||
# 获取缓存类型的配置信息
|
||||
cache_model = CacheRoot.get_model(self.cache_type)
|
||||
|
||||
if not cache_model.key_format:
|
||||
# 常规处理,使用主键作为缓存键
|
||||
return self._build_composite_key(item)
|
||||
# 构建键参数字典
|
||||
key_parts = []
|
||||
# 从格式字符串中提取所需的字段名
|
||||
import re
|
||||
|
||||
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
|
||||
|
||||
# 收集所有字段值
|
||||
for field in field_names:
|
||||
value = getattr(item, field, "")
|
||||
key_parts.append(value if value is not None else "")
|
||||
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
|
||||
|
||||
async def _cache_items(self, data_list: list[T]) -> None:
|
||||
"""将数据列表存入缓存
|
||||
|
||||
参数:
|
||||
data_list: 数据列表
|
||||
"""
|
||||
if (
|
||||
not data_list
|
||||
or not self.cache_type
|
||||
or cache_config.cache_mode == CacheMode.NONE
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# 遍历数据列表,将每条数据存入缓存
|
||||
cached_count = 0
|
||||
for item in data_list:
|
||||
cache_key = self._build_cache_key_for_item(item)
|
||||
if cache_key is not None:
|
||||
await self.cache.set(cache_key, item)
|
||||
cached_count += 1
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} 批量缓存: {cached_count}/{len(data_list)}项"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} 批量缓存失败", e=e)
|
||||
|
||||
async def filter(self, *args, **kwargs) -> list[T]:
|
||||
"""筛选数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
List[T]: 查询结果列表
|
||||
"""
|
||||
# 从数据库获取数据
|
||||
logger.debug(f"{self.model_cls.__name__} filter: 从数据库查询, 参数: {kwargs}")
|
||||
data_list = await self.model_cls.filter(*args, **kwargs)
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} filter: 查询结果数量: {len(data_list)}"
|
||||
)
|
||||
|
||||
# 将数据存入缓存
|
||||
await self._cache_items(data_list)
|
||||
|
||||
return data_list
|
||||
|
||||
async def all(self) -> list[T]:
|
||||
"""获取所有数据
|
||||
|
||||
返回:
|
||||
List[T]: 所有数据列表
|
||||
"""
|
||||
# 直接从数据库获取
|
||||
logger.debug(f"{self.model_cls.__name__} all: 从数据库查询所有数据")
|
||||
data_list = await self.model_cls.all()
|
||||
logger.debug(f"{self.model_cls.__name__} all: 查询结果数量: {len(data_list)}")
|
||||
|
||||
# 将数据存入缓存
|
||||
await self._cache_items(data_list)
|
||||
|
||||
return data_list
|
||||
|
||||
async def count(self, *args, **kwargs) -> int:
|
||||
"""获取数据数量
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
int: 数据数量
|
||||
"""
|
||||
# 直接从数据库获取数量
|
||||
return await self.model_cls.filter(*args, **kwargs).count()
|
||||
|
||||
async def exists(self, *args, **kwargs) -> bool:
|
||||
"""判断数据是否存在
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
bool: 是否存在
|
||||
"""
|
||||
# 直接从数据库判断是否存在
|
||||
return await self.model_cls.filter(*args, **kwargs).exists()
|
||||
|
||||
async def create(self, **kwargs) -> T:
|
||||
"""创建数据
|
||||
|
||||
参数:
|
||||
**kwargs: 创建参数
|
||||
|
||||
返回:
|
||||
T: 创建的数据
|
||||
"""
|
||||
# 创建数据
|
||||
logger.debug(f"{self.model_cls.__name__} create: 创建数据, 参数: {kwargs}")
|
||||
data = await self.model_cls.create(**kwargs)
|
||||
|
||||
# 如果有缓存类型,将数据存入缓存
|
||||
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||
try:
|
||||
# 生成缓存键
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} create: "
|
||||
f"新创建的数据已存入缓存: {cache_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.model_cls.__name__} create: 存入缓存失败,参数: {kwargs}",
|
||||
e=e,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def update_or_create(
|
||||
self, defaults: dict[str, Any] | None = None, **kwargs
|
||||
) -> tuple[T, bool]:
|
||||
"""更新或创建数据
|
||||
|
||||
参数:
|
||||
defaults: 默认值
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
tuple[T, bool]: (数据, 是否创建)
|
||||
"""
|
||||
# 更新或创建数据
|
||||
data, created = await self.model_cls.update_or_create(
|
||||
defaults=defaults, **kwargs
|
||||
)
|
||||
|
||||
# 如果有缓存类型,将数据存入缓存
|
||||
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||
try:
|
||||
# 生成缓存键
|
||||
cache_key = self._build_cache_key_for_item(data)
|
||||
if cache_key is not None:
|
||||
# 存入缓存
|
||||
await self.cache.set(cache_key, data)
|
||||
self._cache_stats[self.cache_type]["sets"] += 1
|
||||
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
|
||||
|
||||
return data, created
|
||||
|
||||
async def delete(self, *args, **kwargs) -> int:
|
||||
"""删除数据
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
int: 删除的数据数量
|
||||
"""
|
||||
logger.debug(f"{self.model_cls.__name__} delete: 删除数据, 参数: {kwargs}")
|
||||
|
||||
# 如果有缓存类型且有key_field参数,先尝试删除缓存
|
||||
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||
try:
|
||||
# 尝试构建缓存键
|
||||
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||
|
||||
if cache_key is not None:
|
||||
# 如果成功构建缓存键,直接删除缓存
|
||||
await self.cache.delete(cache_key)
|
||||
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete: 已删除缓存: {cache_key}"
|
||||
)
|
||||
else:
|
||||
# 否则需要先查询出要删除的数据,然后删除对应的缓存
|
||||
items = await self.model_cls.filter(*args, **kwargs)
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete:"
|
||||
f" 查询到 {len(items)} 条要删除的数据"
|
||||
)
|
||||
for item in items:
|
||||
item_cache_key = self._build_cache_key_for_item(item)
|
||||
if item_cache_key is not None:
|
||||
await self.cache.delete(item_cache_key)
|
||||
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||
if items:
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete:"
|
||||
f" 已删除 {len(items)} 条数据的缓存"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_cls.__name__} delete: 删除缓存失败", e=e)
|
||||
|
||||
# 删除数据
|
||||
result = await self.model_cls.filter(*args, **kwargs).delete()
|
||||
logger.debug(
|
||||
f"{self.model_cls.__name__} delete: 已从数据库删除 {result} 条数据"
|
||||
)
|
||||
return result
|
||||
|
||||
def _generate_cache_key(self, data: T) -> str:
|
||||
"""根据数据对象生成缓存键
|
||||
|
||||
参数:
|
||||
data: 数据对象
|
||||
|
||||
返回:
|
||||
str: 缓存键
|
||||
"""
|
||||
# 使用新方法构建复合键
|
||||
if composite_key := self._build_composite_key(data):
|
||||
return composite_key
|
||||
|
||||
# 如果无法生成复合键,生成一个唯一键
|
||||
return f"object_{id(data)}"
|
||||
@ -1,37 +1,328 @@
|
||||
import nonebot
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
import contextlib
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing_extensions import Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from nonebot import get_driver
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from tortoise import Tortoise
|
||||
from tortoise.backends.base.client import BaseDBAsyncClient
|
||||
from tortoise.connection import connections
|
||||
from tortoise.models import Model as Model_
|
||||
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
|
||||
from tortoise.models import Model as TortoiseModel
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import DbLockType
|
||||
from zhenxun.utils.exception import HookPriorityException
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from .log import logger
|
||||
driver = get_driver()
|
||||
|
||||
SCRIPT_METHOD = []
|
||||
MODELS: list[str] = []
|
||||
|
||||
# 数据库操作超时设置(秒)
|
||||
DB_TIMEOUT_SECONDS = 3.0
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
# 性能监控阈值(秒)
|
||||
SLOW_QUERY_THRESHOLD = 0.5
|
||||
|
||||
LOG_COMMAND = "DbContext"
|
||||
|
||||
|
||||
class Model(Model_):
|
||||
async def with_db_timeout(
|
||||
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
|
||||
):
|
||||
"""带超时控制的数据库操作"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = await asyncio.wait_for(coro, timeout=timeout)
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > SLOW_QUERY_THRESHOLD and operation:
|
||||
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
if operation:
|
||||
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
||||
raise
|
||||
|
||||
|
||||
class Model(TortoiseModel):
|
||||
"""
|
||||
自动添加模块
|
||||
|
||||
Args:
|
||||
Model_: Model
|
||||
增强的ORM基类,解决锁嵌套问题
|
||||
"""
|
||||
|
||||
sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
|
||||
_current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if cls.__module__ not in MODELS:
|
||||
MODELS.append(cls.__module__)
|
||||
|
||||
if func := getattr(cls, "_run_script", None):
|
||||
SCRIPT_METHOD.append((cls.__module__, func))
|
||||
|
||||
@classmethod
|
||||
def get_cache_type(cls) -> str | None:
|
||||
"""获取缓存类型"""
|
||||
return getattr(cls, "cache_type", None)
|
||||
|
||||
@classmethod
|
||||
def get_cache_key_field(cls) -> str | tuple[str]:
|
||||
"""获取缓存键字段"""
|
||||
return getattr(cls, "cache_key_field", "id")
|
||||
|
||||
@classmethod
|
||||
def get_cache_key(cls, instance) -> str | None:
|
||||
"""获取缓存键
|
||||
|
||||
参数:
|
||||
instance: 模型实例
|
||||
|
||||
返回:
|
||||
str | None: 缓存键,如果无法获取则返回None
|
||||
"""
|
||||
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR
|
||||
|
||||
key_field = cls.get_cache_key_field()
|
||||
|
||||
if isinstance(key_field, tuple):
|
||||
# 多字段主键
|
||||
key_parts = []
|
||||
for field in key_field:
|
||||
if hasattr(instance, field):
|
||||
value = getattr(instance, field, None)
|
||||
key_parts.append(value if value is not None else "")
|
||||
else:
|
||||
# 如果缺少任何必要的字段,返回None
|
||||
key_parts.append("")
|
||||
|
||||
# 如果没有有效参数,返回None
|
||||
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||
elif hasattr(instance, key_field):
|
||||
value = getattr(instance, key_field, None)
|
||||
return str(value) if value is not None else None
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_semaphore(cls, lock_type: DbLockType):
|
||||
enable_lock = getattr(cls, "enable_lock", None)
|
||||
if not enable_lock or lock_type not in enable_lock:
|
||||
return None
|
||||
|
||||
if cls.__name__ not in cls.sem_data:
|
||||
cls.sem_data[cls.__name__] = {}
|
||||
if lock_type not in cls.sem_data[cls.__name__]:
|
||||
cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
|
||||
return cls.sem_data[cls.__name__][lock_type]
|
||||
|
||||
@classmethod
|
||||
def _require_lock(cls, lock_type: DbLockType) -> bool:
|
||||
"""检查是否需要真正加锁"""
|
||||
task_id = id(asyncio.current_task())
|
||||
return cls._current_locks.get(task_id) != lock_type
|
||||
|
||||
@classmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lock_context(cls, lock_type: DbLockType):
|
||||
"""带重入检查的锁上下文"""
|
||||
task_id = id(asyncio.current_task())
|
||||
need_lock = cls._require_lock(lock_type)
|
||||
|
||||
if need_lock and (sem := cls.get_semaphore(lock_type)):
|
||||
cls._current_locks[task_id] = lock_type
|
||||
async with sem:
|
||||
yield
|
||||
cls._current_locks.pop(task_id, None)
|
||||
else:
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||
) -> Self:
|
||||
"""创建数据(使用CREATE锁)"""
|
||||
async with cls._lock_context(DbLockType.CREATE):
|
||||
# 直接调用父类的_create方法避免触发save的锁
|
||||
result = await super().create(using_db=using_db, **kwargs)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def get_or_create(
|
||||
cls,
|
||||
defaults: dict | None = None,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Self, bool]:
|
||||
"""获取或创建数据(无锁版本,依赖数据库约束)"""
|
||||
result = await super().get_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def update_or_create(
|
||||
cls,
|
||||
defaults: dict | None = None,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Self, bool]:
|
||||
"""更新或创建数据(使用UPSERT锁)"""
|
||||
async with cls._lock_context(DbLockType.UPSERT):
|
||||
try:
|
||||
# 先尝试更新(带行锁)
|
||||
async with in_transaction():
|
||||
if obj := await cls.filter(**kwargs).select_for_update().first():
|
||||
await obj.update_from_dict(defaults or {})
|
||||
await obj.save()
|
||||
result = (obj, False)
|
||||
else:
|
||||
# 创建时不重复加锁
|
||||
result = await cls.create(**kwargs, **(defaults or {})), True
|
||||
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.invalidate_cache(
|
||||
cache_type, cls.get_cache_key(result[0])
|
||||
)
|
||||
return result
|
||||
except IntegrityError:
|
||||
# 处理极端情况下的唯一约束冲突
|
||||
obj = await cls.get(**kwargs)
|
||||
return obj, False
|
||||
|
||||
async def save(
|
||||
self,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
update_fields: Iterable[str] | None = None,
|
||||
force_create: bool = False,
|
||||
force_update: bool = False,
|
||||
):
|
||||
"""保存数据(根据操作类型自动选择锁)"""
|
||||
lock_type = (
|
||||
DbLockType.CREATE
|
||||
if getattr(self, "id", None) is None
|
||||
else DbLockType.UPDATE
|
||||
)
|
||||
async with self._lock_context(lock_type):
|
||||
await super().save(
|
||||
using_db=using_db,
|
||||
update_fields=update_fields,
|
||||
force_create=force_create,
|
||||
force_update=force_update,
|
||||
)
|
||||
if cache_type := getattr(self, "cache_type", None):
|
||||
await CacheRoot.invalidate_cache(
|
||||
cache_type, self.__class__.get_cache_key(self)
|
||||
)
|
||||
|
||||
async def delete(self, using_db: BaseDBAsyncClient | None = None):
|
||||
cache_type = getattr(self, "cache_type", None)
|
||||
key = self.__class__.get_cache_key(self) if cache_type else None
|
||||
# 执行删除操作
|
||||
await super().delete(using_db=using_db)
|
||||
|
||||
# 清除缓存
|
||||
if cache_type:
|
||||
await CacheRoot.invalidate_cache(cache_type, key)
|
||||
|
||||
@classmethod
|
||||
async def safe_get_or_none(
|
||||
cls,
|
||||
*args,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
clean_duplicates: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Self | None:
|
||||
"""安全地获取一条记录或None,处理存在多个记录时返回最新的那个
|
||||
注意,默认会删除重复的记录,仅保留最新的
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
using_db: 数据库连接
|
||||
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Self | None: 查询结果,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
# 先尝试使用 get_or_none 获取单个记录
|
||||
try:
|
||||
return await with_db_timeout(
|
||||
cls.get_or_none(*args, using_db=using_db, **kwargs),
|
||||
operation=f"{cls.__name__}.get_or_none",
|
||||
)
|
||||
except MultipleObjectsReturned:
|
||||
# 如果出现多个记录的情况,进行特殊处理
|
||||
logger.warning(
|
||||
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
|
||||
# 查询所有匹配记录
|
||||
records = await with_db_timeout(
|
||||
cls.filter(*args, **kwargs).all(),
|
||||
operation=f"{cls.__name__}.filter.all",
|
||||
)
|
||||
|
||||
if not records:
|
||||
return None
|
||||
|
||||
# 如果需要清理重复记录
|
||||
if clean_duplicates and hasattr(records[0], "id"):
|
||||
# 按 id 排序
|
||||
records = sorted(
|
||||
records, key=lambda x: getattr(x, "id", 0), reverse=True
|
||||
)
|
||||
for record in records[1:]:
|
||||
try:
|
||||
await with_db_timeout(
|
||||
record.delete(),
|
||||
operation=f"{cls.__name__}.delete_duplicate",
|
||||
)
|
||||
logger.info(
|
||||
f"{cls.__name__} 删除重复记录:"
|
||||
f" id={getattr(record, 'id', None)}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
except Exception as del_e:
|
||||
logger.error(f"删除重复记录失败: {del_e}")
|
||||
return records[0]
|
||||
# 如果不需要清理或没有 id 字段,则返回最新的记录
|
||||
if hasattr(cls, "id"):
|
||||
return await with_db_timeout(
|
||||
cls.filter(*args, **kwargs).order_by("-id").first(),
|
||||
operation=f"{cls.__name__}.filter.order_by.first",
|
||||
)
|
||||
# 如果没有 id 字段,则返回第一个记录
|
||||
return await with_db_timeout(
|
||||
cls.filter(*args, **kwargs).first(),
|
||||
operation=f"{cls.__name__}.filter.first",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# 其他类型的错误则继续抛出
|
||||
logger.error(
|
||||
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class DbUrlIsNode(HookPriorityException):
|
||||
"""
|
||||
@ -49,6 +340,77 @@ class DbConnectError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
POSTGRESQL_CONFIG = {
|
||||
"max_size": 30, # 最大连接数
|
||||
"min_size": 5, # 最小保持的连接数(可选)
|
||||
}
|
||||
|
||||
|
||||
MYSQL_CONFIG = {
|
||||
"max_connections": 20, # 最大连接数
|
||||
"connect_timeout": 30, # 连接超时(可选)
|
||||
}
|
||||
|
||||
SQLITE_CONFIG = {
|
||||
"journal_mode": "WAL", # 提高并发写入性能
|
||||
"timeout": 30, # 锁等待超时(可选)
|
||||
}
|
||||
|
||||
|
||||
def get_config(db_url: str) -> dict:
|
||||
"""获取数据库配置"""
|
||||
parsed = urlparse(BotConfig.db_url)
|
||||
|
||||
# 基础配置
|
||||
config = {
|
||||
"connections": {
|
||||
"default": BotConfig.db_url # 默认直接使用连接字符串
|
||||
},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": MODELS,
|
||||
"default_connection": "default",
|
||||
}
|
||||
},
|
||||
"timezone": "Asia/Shanghai",
|
||||
}
|
||||
|
||||
# 根据数据库类型应用高级配置
|
||||
if parsed.scheme.startswith("postgres"):
|
||||
config["connections"]["default"] = {
|
||||
"engine": "tortoise.backends.asyncpg",
|
||||
"credentials": {
|
||||
"host": parsed.hostname,
|
||||
"port": parsed.port or 5432,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path[1:],
|
||||
},
|
||||
**POSTGRESQL_CONFIG,
|
||||
}
|
||||
elif parsed.scheme == "mysql":
|
||||
config["connections"]["default"] = {
|
||||
"engine": "tortoise.backends.mysql",
|
||||
"credentials": {
|
||||
"host": parsed.hostname,
|
||||
"port": parsed.port or 3306,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"database": parsed.path[1:],
|
||||
},
|
||||
**MYSQL_CONFIG,
|
||||
}
|
||||
elif parsed.scheme == "sqlite":
|
||||
config["connections"]["default"] = {
|
||||
"engine": "tortoise.backends.sqlite",
|
||||
"credentials": {
|
||||
"file_path": parsed.path[1:] or ":memory:",
|
||||
},
|
||||
**SQLITE_CONFIG,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=1)
|
||||
async def init():
|
||||
if not BotConfig.db_url:
|
||||
@ -64,9 +426,7 @@ async def init():
|
||||
raise DbUrlIsNode("\n" + error.strip())
|
||||
try:
|
||||
await Tortoise.init(
|
||||
db_url=BotConfig.db_url,
|
||||
modules={"models": MODELS},
|
||||
timezone="Asia/Shanghai",
|
||||
config=get_config(BotConfig.db_url),
|
||||
)
|
||||
if SCRIPT_METHOD:
|
||||
db = Tortoise.get_connection("default")
|
||||
@ -85,13 +445,17 @@ async def init():
|
||||
for sql in sql_list:
|
||||
logger.debug(f"执行SQL: {sql}")
|
||||
try:
|
||||
await db.execute_query_dict(sql)
|
||||
await asyncio.wait_for(
|
||||
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
|
||||
)
|
||||
# await TestSQL.raw(sql)
|
||||
except Exception as e:
|
||||
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
||||
if sql_list:
|
||||
logger.debug("SCRIPT_METHOD方法执行完毕!")
|
||||
logger.debug("开始生成数据库表结构...")
|
||||
await Tortoise.generate_schemas()
|
||||
logger.debug("数据库表结构生成完毕!")
|
||||
logger.info("Database loaded successfully!")
|
||||
except Exception as e:
|
||||
raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
||||
|
||||
@ -469,7 +469,7 @@ class Notebook:
|
||||
template_name="main.html",
|
||||
templates={"elements": self._data},
|
||||
pages={
|
||||
"viewport": {"width": 700, "height": 1000},
|
||||
"viewport": {"width": 700, "height": 10},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
|
||||
@ -53,9 +53,7 @@ class CommonUtils:
|
||||
if await GroupConsole.is_block_task(group_id, module):
|
||||
"""群组是否禁用被动"""
|
||||
return True
|
||||
if g := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
if g := await GroupConsole.get_group(group_id=group_id):
|
||||
"""群组权限是否小于0"""
|
||||
if g.level < 0:
|
||||
return True
|
||||
|
||||
@ -44,6 +44,44 @@ class EventLogType(StrEnum):
|
||||
"""主动退群"""
|
||||
|
||||
|
||||
class CacheType(StrEnum):
|
||||
"""
|
||||
缓存类型
|
||||
"""
|
||||
|
||||
PLUGINS = "GLOBAL_ALL_PLUGINS"
|
||||
"""全局全部插件"""
|
||||
GROUPS = "GLOBAL_ALL_GROUPS"
|
||||
"""全局全部群组"""
|
||||
USERS = "GLOBAL_ALL_USERS"
|
||||
"""全部用户"""
|
||||
BAN = "GLOBAL_ALL_BAN"
|
||||
"""全局ban列表"""
|
||||
BOT = "GLOBAL_BOT"
|
||||
"""全局bot信息"""
|
||||
LEVEL = "GLOBAL_USER_LEVEL"
|
||||
"""用户权限"""
|
||||
LIMIT = "GLOBAL_LIMIT"
|
||||
"""插件限制"""
|
||||
|
||||
|
||||
class DbLockType(StrEnum):
|
||||
"""
|
||||
锁类型
|
||||
"""
|
||||
|
||||
CREATE = "CREATE"
|
||||
"""创建"""
|
||||
DELETE = "DELETE"
|
||||
"""删除"""
|
||||
UPDATE = "UPDATE"
|
||||
"""更新"""
|
||||
QUERY = "QUERY"
|
||||
"""查询"""
|
||||
UPSERT = "UPSERT"
|
||||
"""创建或更新"""
|
||||
|
||||
|
||||
class GoldHandle(StrEnum):
|
||||
"""
|
||||
金币处理
|
||||
|
||||
@ -49,6 +49,9 @@ async def _():
|
||||
try:
|
||||
for priority in priority_list:
|
||||
for func in priority_data[priority]:
|
||||
logger.debug(
|
||||
f"执行优先级 [{priority}] on_startup 方法: {func.__module__}"
|
||||
)
|
||||
if is_coroutine_callable(func):
|
||||
await func()
|
||||
else:
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import httpx
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
import pypinyin
|
||||
import pytz
|
||||
|
||||
@ -13,43 +15,53 @@ from zhenxun.configs.config import Config
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIDs:
|
||||
user_id: str
|
||||
"""用户id"""
|
||||
group_id: str | None
|
||||
"""群组id"""
|
||||
channel_id: str | None
|
||||
"""频道id"""
|
||||
|
||||
|
||||
class ResourceDirManager:
|
||||
"""
|
||||
临时文件管理器
|
||||
"""
|
||||
|
||||
temp_path = [] # noqa: RUF012
|
||||
temp_path: ClassVar[set[Path]] = set()
|
||||
|
||||
@classmethod
|
||||
def __tree_append(cls, path: Path):
|
||||
"""递归添加文件夹
|
||||
|
||||
参数:
|
||||
path: 文件夹路径
|
||||
"""
|
||||
def __tree_append(cls, path: Path, deep: int = 1, current: int = 0):
|
||||
"""递归添加文件夹"""
|
||||
if current >= deep and deep != -1:
|
||||
return
|
||||
path = path.resolve() # 标准化路径
|
||||
for f in os.listdir(path):
|
||||
file = path / f
|
||||
file = (path / f).resolve() # 标准化子路径
|
||||
if file.is_dir():
|
||||
if file not in cls.temp_path:
|
||||
cls.temp_path.append(file)
|
||||
logger.debug(f"添加临时文件夹: {path}")
|
||||
cls.__tree_append(file)
|
||||
cls.temp_path.add(file)
|
||||
logger.debug(f"添加临时文件夹: {file}")
|
||||
cls.__tree_append(file, deep, current + 1)
|
||||
|
||||
@classmethod
|
||||
def add_temp_dir(cls, path: str | Path, tree: bool = False):
|
||||
def add_temp_dir(cls, path: str | Path, tree: bool = False, deep: int = 1):
|
||||
"""添加临时清理文件夹,这些文件夹会被自动清理
|
||||
|
||||
参数:
|
||||
path: 文件夹路径
|
||||
tree: 是否递归添加文件夹
|
||||
deep: 深度, -1 为无限深度
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if path not in cls.temp_path:
|
||||
cls.temp_path.append(path)
|
||||
cls.temp_path.add(path)
|
||||
logger.debug(f"添加临时文件夹: {path}")
|
||||
if tree:
|
||||
cls.__tree_append(path)
|
||||
cls.__tree_append(path, deep)
|
||||
|
||||
|
||||
class CountLimiter:
|
||||
@ -230,6 +242,27 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_entity_ids(session: Uninfo) -> EntityIDs:
|
||||
"""获取用户id,群组id,频道id
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
返回:
|
||||
EntityIDs: 用户id,群组id,频道id
|
||||
"""
|
||||
user_id = session.user.id
|
||||
group_id = None
|
||||
channel_id = None
|
||||
if session.group:
|
||||
if session.group.parent:
|
||||
group_id = session.group.parent.id
|
||||
channel_id = session.group.id
|
||||
else:
|
||||
group_id = session.group.id
|
||||
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)
|
||||
|
||||
|
||||
def is_number(text: str) -> bool:
|
||||
"""是否为数字
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user