mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 引入缓存机制 (#1889)
* 添加全局cache * ✨ 构建缓存,hook使用缓存 * ✨ 新增数据库Model方法监控 * ✨ 数据库添加semaphore锁 * 🩹 优化webapi返回数据 * ✨ 添加增量缓存与缓存过期 * 🎨 优化检测代码结构 * ⚡ 优化hook权限检测性能 * 🐛 添加新异常判断跳过权限检测 * ✨ 添加插件limit缓存 * 🎨 代码格式优化 * 🐛 修复代码导入 * 🐛 修复刷新时检查 * 👽 Rename exception for missing database URL in initialization * ♿ Update default database URL to SQLite in configuration * 🔧 Update tortoise-orm and aiocache dependencies restrictions; add optional redis and asyncpg support * 🐛 修复ban检测 * 🐛 修复所有插件关闭时缓存更新 * 🐛 尝试迁移至aiocache * 🐛 完善aiocache缓存 * ⚡ 代码性能优化 * 🐛 移除获取封禁缓存时的日志记录 * 🐛 修复缓存类型声明,优化封禁用户处理逻辑 * 🐛 优化LevelUser权限更新逻辑及数据库迁移 * ✨ cache支持redis连接 * 🚨 auto fix by pre-commit hooks * ⚡ :增强获取群组的安全性和准确性。同时,优化了缓存管理中的相关逻辑,确保缓存操作的一致性。 * ✨ feat(auth_limit): 将插件初始化逻辑的启动装饰器更改为优先级管理器 * 🔧 修复日志记录级别 * 🔧 更新数据库连接字符串 * 🔧 更新数据库连接字符串为内存数据库,并优化权限检查逻辑 * ✨ feat(cache): 增加缓存功能配置项,并新增数据访问层以支持缓存逻辑 * ♻️ 重构cache * ✨ feat(cache): 增强缓存管理,新增缓存字典和缓存列表功能,支持过期时间管理 * 🔧 修复Notebook类中的viewport高度设置,将其从1000调整为10 * ✨ 更新插件管理逻辑,替换缓存服务为CacheRoot并优化缓存失效处理 * ✨ 更新RegisterConfig类中的type字段 * ✨ 修复清理重复记录逻辑,确保检查记录的id属性有效性 * ⚡ 超级无敌大优化,解决延迟与卡死问题 * ✨ 更新封禁功能,增加封禁时长参数和描述,优化插件信息返回结构 * ✨ 更新zhenxun_help.py中的viewport高度,将其从453调整为10,以优化页面显示效果 * ✨ 优化插件分类逻辑,增加插件ID排序,并更新插件信息返回结构 --------- Co-authored-by: BalconyJH <balconyjh@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6283c3d13d
commit
8649aaaa54
14
.env.dev
14
.env.dev
@ -27,6 +27,18 @@ QBOT_ID_DATA = '{
|
|||||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||||
DB_URL = ""
|
DB_URL = ""
|
||||||
|
|
||||||
|
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
|
||||||
|
CACHE_MODE = NONE
|
||||||
|
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||||
|
# REDIS地址
|
||||||
|
# REDIS_HOST = "127.0.0.1"
|
||||||
|
# REDIS端口
|
||||||
|
# REDIS_PORT = 6379
|
||||||
|
# REDIS密码
|
||||||
|
# REDIS_PASSWORD = ""
|
||||||
|
# REDIS过期时间
|
||||||
|
# REDIS_EXPIRE = 600
|
||||||
|
|
||||||
# 系统代理
|
# 系统代理
|
||||||
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
||||||
|
|
||||||
@ -40,7 +52,7 @@ PLATFORM_SUPERUSERS = '
|
|||||||
DRIVER=~fastapi+~httpx+~websockets
|
DRIVER=~fastapi+~httpx+~websockets
|
||||||
|
|
||||||
|
|
||||||
# LOG_LEVEL=DEBUG
|
# LOG_LEVEL = DEBUG
|
||||||
# 服务器和端口
|
# 服务器和端口
|
||||||
HOST = 127.0.0.1
|
HOST = 127.0.0.1
|
||||||
PORT = 8080
|
PORT = 8080
|
||||||
|
|||||||
951
poetry.lock
generated
951
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -16,7 +16,7 @@ python = "^3.10"
|
|||||||
playwright = "^1.41.1"
|
playwright = "^1.41.1"
|
||||||
nonebot-adapter-onebot = "^2.3.1"
|
nonebot-adapter-onebot = "^2.3.1"
|
||||||
nonebot-plugin-apscheduler = "^0.5"
|
nonebot-plugin-apscheduler = "^0.5"
|
||||||
tortoise-orm = { extras = ["asyncpg"], version = "^0.20.0" }
|
tortoise-orm = "^0.20.0"
|
||||||
cattrs = "^23.2.3"
|
cattrs = "^23.2.3"
|
||||||
ruamel-yaml = "^0.18.5"
|
ruamel-yaml = "^0.18.5"
|
||||||
strenum = "^0.4.15"
|
strenum = "^0.4.15"
|
||||||
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
|
|||||||
bilireq = "0.2.3post0"
|
bilireq = "0.2.3post0"
|
||||||
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
||||||
python-multipart = "^0.0.9"
|
python-multipart = "^0.0.9"
|
||||||
aiocache = "^0.12.2"
|
aiocache = {extras = ["redis"], version = "^0.12.3"}
|
||||||
py-cpuinfo = "^9.0.0"
|
py-cpuinfo = "^9.0.0"
|
||||||
nonebot-plugin-alconna = "^0.54.0"
|
nonebot-plugin-alconna = "^0.54.0"
|
||||||
tenacity = "^9.0.0"
|
tenacity = "^9.0.0"
|
||||||
@ -47,6 +47,9 @@ nonebot-plugin-uninfo = ">0.4.1"
|
|||||||
nonebot-plugin-waiter = "^0.8.1"
|
nonebot-plugin-waiter = "^0.8.1"
|
||||||
multidict = ">=6.0.0,!=6.3.2"
|
multidict = ">=6.0.0,!=6.3.2"
|
||||||
|
|
||||||
|
redis = { version = ">=5", optional = true }
|
||||||
|
asyncpg = { version = ">=0.20.0", optional = true }
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
nonebug = "^0.4"
|
nonebug = "^0.4"
|
||||||
pytest-cov = "^5.0.0"
|
pytest-cov = "^5.0.0"
|
||||||
@ -57,6 +60,9 @@ respx = "^0.21.1"
|
|||||||
ruff = "^0.8.0"
|
ruff = "^0.8.0"
|
||||||
pre-commit = "^4.0.0"
|
pre-commit = "^4.0.0"
|
||||||
|
|
||||||
|
[tool.poetry.extras]
|
||||||
|
redis = ["redis"]
|
||||||
|
postgresql = ["asyncpg"]
|
||||||
|
|
||||||
[tool.nonebot]
|
[tool.nonebot]
|
||||||
plugins = [
|
plugins = [
|
||||||
|
|||||||
@ -87,13 +87,17 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
smart_tools=[
|
smart_tools=[
|
||||||
AICallableTag(
|
AICallableTag(
|
||||||
name="call_ban",
|
name="call_ban",
|
||||||
description="某人多次(至少三次)辱骂你,调用此方法进行封禁",
|
description="如果你讨厌某个人(好感度过低并让你感到困扰,或者多次辱骂你),调用此方法进行封禁,调用该方法后要告知用户被封禁和原因",
|
||||||
parameters=AICallableParam(
|
parameters=AICallableParam(
|
||||||
type="object",
|
type="object",
|
||||||
properties={
|
properties={
|
||||||
"user_id": AICallableProperties(
|
"user_id": AICallableProperties(
|
||||||
type="string", description="用户的id"
|
type="string", description="用户的id"
|
||||||
),
|
),
|
||||||
|
"duration": AICallableProperties(
|
||||||
|
type="integer",
|
||||||
|
description="封禁时长(选择的值只能是1-360),单位为分钟,如果频繁触发,按情况增加",
|
||||||
|
),
|
||||||
},
|
},
|
||||||
required=["user_id"],
|
required=["user_id"],
|
||||||
),
|
),
|
||||||
|
|||||||
@ -9,14 +9,14 @@ from zhenxun.services.log import logger
|
|||||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||||
|
|
||||||
|
|
||||||
async def call_ban(user_id: str):
|
async def call_ban(user_id: str, duration: int = 1):
|
||||||
"""调用ban
|
"""调用ban
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
user_id: 用户id
|
user_id: 用户id
|
||||||
"""
|
"""
|
||||||
await BanConsole.ban(user_id, None, 9, 60 * 12)
|
await BanConsole.ban(user_id, None, 9, duration * 60)
|
||||||
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
|
logger.info("被讨厌了,已将用户加入黑名单...", "ban", session=user_id)
|
||||||
|
|
||||||
|
|
||||||
class BanManage:
|
class BanManage:
|
||||||
@ -114,7 +114,7 @@ class BanManage:
|
|||||||
if not is_superuser and user_id and session.id1:
|
if not is_superuser and user_id and session.id1:
|
||||||
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
||||||
if idx:
|
if idx:
|
||||||
ban_data = await BanConsole.get_or_none(id=idx)
|
ban_data = await BanConsole.get_ban(id=idx)
|
||||||
if not ban_data:
|
if not ban_data:
|
||||||
return False, "该用户/群组不在黑名单中捏..."
|
return False, "该用户/群组不在黑名单中捏..."
|
||||||
if ban_data.ban_level > user_level:
|
if ban_data.ban_level > user_level:
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.task_info import TaskInfo
|
from zhenxun.models.task_info import TaskInfo
|
||||||
from zhenxun.utils.enum import BlockType, PluginType
|
from zhenxun.services.cache import CacheRoot
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
|
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||||
from zhenxun.utils.exception import GroupInfoNotFound
|
from zhenxun.utils.exception import GroupInfoNotFound
|
||||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||||
|
|
||||||
@ -116,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
|
|||||||
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
||||||
group = None
|
group = None
|
||||||
if group_id:
|
if group_id:
|
||||||
group = await GroupConsole.get_or_none(
|
group = await GroupConsole.get_group(group_id=group_id)
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
)
|
|
||||||
if not group:
|
if not group:
|
||||||
raise GroupInfoNotFound()
|
raise GroupInfoNotFound()
|
||||||
else:
|
else:
|
||||||
@ -200,26 +201,26 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
||||||
if group_id:
|
if group_id:
|
||||||
if group := await GroupConsole.get_or_none(
|
if group := await GroupConsole.get_group(group_id=group_id):
|
||||||
group_id=group_id, channel_id__isnull=True
|
module_list = cast(
|
||||||
):
|
list[str],
|
||||||
module_list = await PluginInfo.filter(
|
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
|
||||||
plugin_type=PluginType.NORMAL
|
"module", flat=True
|
||||||
).values_list("module", flat=True)
|
),
|
||||||
|
)
|
||||||
if status:
|
if status:
|
||||||
for module in module_list:
|
# 开启所有功能 - 清空禁用列表
|
||||||
group.block_plugin = group.block_plugin.replace(
|
group.block_plugin = ""
|
||||||
f"<{module},", ""
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
module_list = [f"<{module}" for module in module_list]
|
# 关闭所有功能 - 将模块列表转换为禁用格式
|
||||||
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
group.block_plugin = CommonUtils.convert_module_format(module_list)
|
||||||
await group.save(update_fields=["block_plugin"])
|
await group.save(update_fields=["block_plugin"])
|
||||||
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
||||||
return "获取群组失败..."
|
return "获取群组失败..."
|
||||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
||||||
status=status, block_type=None if status else BlockType.ALL
|
status=status, block_type=None if status else BlockType.ALL
|
||||||
)
|
)
|
||||||
|
await CacheRoot.invalidate_cache(CacheType.PLUGINS)
|
||||||
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
|
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -232,9 +233,7 @@ class PluginManager:
|
|||||||
返回:
|
返回:
|
||||||
bool: 是否醒来
|
bool: 是否醒来
|
||||||
"""
|
"""
|
||||||
if c := await GroupConsole.get_or_none(
|
if c := await GroupConsole.get_group(group_id=group_id):
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
):
|
|
||||||
return c.status
|
return c.status
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -245,9 +244,11 @@ class PluginManager:
|
|||||||
参数:
|
参数:
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
"""
|
"""
|
||||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
group, _ = await GroupConsole.get_or_create(
|
||||||
status=False
|
group_id=group_id, channel_id__isnull=True
|
||||||
)
|
)
|
||||||
|
group.status = False
|
||||||
|
await group.save(update_fields=["status"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def wake(cls, group_id: str):
|
async def wake(cls, group_id: str):
|
||||||
@ -256,9 +257,11 @@ class PluginManager:
|
|||||||
参数:
|
参数:
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
"""
|
"""
|
||||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
group, _ = await GroupConsole.get_or_create(
|
||||||
status=True
|
group_id=group_id, channel_id__isnull=True
|
||||||
)
|
)
|
||||||
|
group.status = True
|
||||||
|
await group.save(update_fields=["status"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def block(cls, module: str):
|
async def block(cls, module: str):
|
||||||
@ -267,7 +270,9 @@ class PluginManager:
|
|||||||
参数:
|
参数:
|
||||||
module: 模块名
|
module: 模块名
|
||||||
"""
|
"""
|
||||||
await PluginInfo.filter(module=module).update(status=False)
|
if plugin := await PluginInfo.get_plugin(module=module):
|
||||||
|
plugin.status = False
|
||||||
|
await plugin.save(update_fields=["status"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def unblock(cls, module: str):
|
async def unblock(cls, module: str):
|
||||||
@ -276,7 +281,9 @@ class PluginManager:
|
|||||||
参数:
|
参数:
|
||||||
module: 模块名
|
module: 模块名
|
||||||
"""
|
"""
|
||||||
await PluginInfo.filter(module=module).update(status=True)
|
if plugin := await PluginInfo.get_plugin(module=module):
|
||||||
|
plugin.status = True
|
||||||
|
await plugin.save(update_fields=["status"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:
|
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:
|
||||||
@ -437,17 +444,18 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
status_str = "关闭" if status else "开启"
|
status_str = "关闭" if status else "开启"
|
||||||
if is_all:
|
if is_all:
|
||||||
modules = await TaskInfo.annotate().values_list("module", flat=True)
|
module_list = cast(
|
||||||
if modules:
|
list[str], await TaskInfo.annotate().values_list("module", flat=True)
|
||||||
|
)
|
||||||
|
if module_list:
|
||||||
group, _ = await GroupConsole.get_or_create(
|
group, _ = await GroupConsole.get_or_create(
|
||||||
group_id=group_id, channel_id__isnull=True
|
group_id=group_id, channel_id__isnull=True
|
||||||
)
|
)
|
||||||
modules = [f"<{module}" for module in modules]
|
|
||||||
if status:
|
if status:
|
||||||
group.block_task = ",".join(modules) + "," # type: ignore
|
group.block_task = CommonUtils.convert_module_format(module_list)
|
||||||
else:
|
else:
|
||||||
for module in modules:
|
# 开启所有模块 - 清空禁用列表
|
||||||
group.block_task = group.block_task.replace(f"{module},", "")
|
group.block_task = ""
|
||||||
await group.save(update_fields=["block_task"])
|
await group.save(update_fields=["block_task"])
|
||||||
return f"已成功{status_str}全部被动技能!"
|
return f"已成功{status_str}全部被动技能!"
|
||||||
elif task := await TaskInfo.get_or_none(name=task_name):
|
elif task := await TaskInfo.get_or_none(name=task_name):
|
||||||
|
|||||||
@ -1,13 +1,15 @@
|
|||||||
from nonebot import on_message
|
from nonebot import on_message
|
||||||
from nonebot.plugin import PluginMetadata
|
from nonebot.plugin import PluginMetadata
|
||||||
from nonebot_plugin_alconna import UniMsg
|
from nonebot_plugin_alconna import UniMsg
|
||||||
from nonebot_plugin_session import EventSession
|
from nonebot_plugin_apscheduler import scheduler
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||||
from zhenxun.models.chat_history import ChatHistory
|
from zhenxun.models.chat_history import ChatHistory
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import PluginType
|
||||||
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
__plugin_meta__ = PluginMetadata(
|
__plugin_meta__ = PluginMetadata(
|
||||||
name="消息存储",
|
name="消息存储",
|
||||||
@ -37,18 +39,34 @@ def rule(message: UniMsg) -> bool:
|
|||||||
|
|
||||||
chat_history = on_message(rule=rule, priority=1, block=False)
|
chat_history = on_message(rule=rule, priority=1, block=False)
|
||||||
|
|
||||||
|
TEMP_LIST = []
|
||||||
|
|
||||||
|
|
||||||
@chat_history.handle()
|
@chat_history.handle()
|
||||||
async def handle_message(message: UniMsg, session: EventSession):
|
async def _(message: UniMsg, session: Uninfo):
|
||||||
"""处理消息存储"""
|
entity = get_entity_ids(session)
|
||||||
try:
|
TEMP_LIST.append(
|
||||||
await ChatHistory.create(
|
ChatHistory(
|
||||||
user_id=session.id1,
|
user_id=entity.user_id,
|
||||||
group_id=session.id2,
|
group_id=entity.group_id,
|
||||||
text=str(message),
|
text=str(message),
|
||||||
plain_text=message.extract_plain_text(),
|
plain_text=message.extract_plain_text(),
|
||||||
bot_id=session.bot_id,
|
bot_id=session.self_id,
|
||||||
platform=session.platform,
|
platform=session.platform,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@scheduler.scheduled_job(
|
||||||
|
"interval",
|
||||||
|
minutes=1,
|
||||||
|
)
|
||||||
|
async def _():
|
||||||
|
try:
|
||||||
|
message_list = TEMP_LIST.copy()
|
||||||
|
TEMP_LIST.clear()
|
||||||
|
if message_list:
|
||||||
|
await ChatHistory.bulk_create(message_list)
|
||||||
|
logger.debug(f"批量添加聊天记录 {len(message_list)} 条", "定时任务")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("存储聊天记录失败", "chat_history", e=e)
|
logger.warning("存储聊天记录失败", "chat_history", e=e)
|
||||||
|
|||||||
@ -45,11 +45,13 @@ async def classify_plugin(
|
|||||||
"""
|
"""
|
||||||
sort_data = await sort_type()
|
sort_data = await sort_type()
|
||||||
classify: dict[str, list] = {}
|
classify: dict[str, list] = {}
|
||||||
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
|
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||||
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
||||||
for menu, value in sort_data.items():
|
for menu, value in sort_data.items():
|
||||||
for plugin in value:
|
for plugin in value:
|
||||||
if not classify.get(menu):
|
if not classify.get(menu):
|
||||||
classify[menu] = []
|
classify[menu] = []
|
||||||
classify[menu].append(handle(bot, plugin, group, is_detail))
|
classify[menu].append(handle(bot, plugin, group, is_detail))
|
||||||
|
for value in classify.values():
|
||||||
|
value.sort(key=lambda x: x.id)
|
||||||
return classify
|
return classify
|
||||||
|
|||||||
@ -21,6 +21,8 @@ class Item(BaseModel):
|
|||||||
"""插件名称"""
|
"""插件名称"""
|
||||||
sta: int
|
sta: int
|
||||||
"""插件状态"""
|
"""插件状态"""
|
||||||
|
id: int
|
||||||
|
"""插件id"""
|
||||||
|
|
||||||
|
|
||||||
class PluginList(BaseModel):
|
class PluginList(BaseModel):
|
||||||
@ -80,10 +82,9 @@ def __handle_item(
|
|||||||
sta = 2
|
sta = 2
|
||||||
if f"{plugin.module}," in group.block_plugin:
|
if f"{plugin.module}," in group.block_plugin:
|
||||||
sta = 1
|
sta = 1
|
||||||
if bot:
|
if bot and f"{plugin.module}," in bot.block_plugins:
|
||||||
if f"{plugin.module}," in bot.block_plugins:
|
sta = 2
|
||||||
sta = 2
|
return Item(plugin_name=plugin.name, sta=sta, id=plugin.id)
|
||||||
return Item(plugin_name=plugin.name, sta=sta)
|
|
||||||
|
|
||||||
|
|
||||||
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||||
@ -142,7 +143,7 @@ async def build_html_image(
|
|||||||
template_name="zhenxun_menu.html",
|
template_name="zhenxun_menu.html",
|
||||||
templates={"plugin_list": plugin_list},
|
templates={"plugin_list": plugin_list},
|
||||||
pages={
|
pages={
|
||||||
"viewport": {"width": 1903, "height": 975},
|
"viewport": {"width": 1903, "height": 10},
|
||||||
"base_url": f"file://{TEMPLATE_PATH}",
|
"base_url": f"file://{TEMPLATE_PATH}",
|
||||||
},
|
},
|
||||||
wait=2,
|
wait=2,
|
||||||
|
|||||||
@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
|||||||
color="black" if idx % 2 else "white",
|
color="black" if idx % 2 else "white",
|
||||||
)
|
)
|
||||||
curr_h = 10
|
curr_h = 10
|
||||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||||
for _, plugin in enumerate(plugin_list):
|
for _, plugin in enumerate(plugin_list):
|
||||||
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
||||||
if group and f"{plugin.module}," in group.block_plugin:
|
if group and f"{plugin.module}," in group.block_plugin:
|
||||||
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
|||||||
width, height = 10, 10
|
width, height = 10, 10
|
||||||
for s in [
|
for s in [
|
||||||
"目前支持的功能列表:",
|
"目前支持的功能列表:",
|
||||||
"可以通过 ‘帮助 [功能名称或功能Id]’ 来获取对应功能的使用方法",
|
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
|
||||||
]:
|
]:
|
||||||
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
||||||
await result.paste(text, (width, height))
|
await result.paste(text, (width, height))
|
||||||
|
|||||||
@ -20,6 +20,12 @@ class Item(BaseModel):
|
|||||||
"""插件名称"""
|
"""插件名称"""
|
||||||
commands: list[str]
|
commands: list[str]
|
||||||
"""插件命令"""
|
"""插件命令"""
|
||||||
|
id: str
|
||||||
|
"""插件id"""
|
||||||
|
status: bool
|
||||||
|
"""插件状态"""
|
||||||
|
has_superuser_help: bool
|
||||||
|
"""插件是否拥有超级用户帮助"""
|
||||||
|
|
||||||
|
|
||||||
def __handle_item(
|
def __handle_item(
|
||||||
@ -39,23 +45,36 @@ def __handle_item(
|
|||||||
返回:
|
返回:
|
||||||
Item: Item
|
Item: Item
|
||||||
"""
|
"""
|
||||||
|
status = True
|
||||||
|
has_superuser_help = False
|
||||||
|
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||||
|
if nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||||
|
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
||||||
|
if extra_data.superuser_help:
|
||||||
|
has_superuser_help = True
|
||||||
if not plugin.status:
|
if not plugin.status:
|
||||||
if plugin.block_type == BlockType.ALL:
|
if plugin.block_type == BlockType.ALL:
|
||||||
plugin.name = f"{plugin.name}(不可用)"
|
status = False
|
||||||
elif group and plugin.block_type == BlockType.GROUP:
|
elif group and plugin.block_type == BlockType.GROUP:
|
||||||
plugin.name = f"{plugin.name}(不可用)"
|
status = False
|
||||||
elif not group and plugin.block_type == BlockType.PRIVATE:
|
elif not group and plugin.block_type == BlockType.PRIVATE:
|
||||||
plugin.name = f"{plugin.name}(不可用)"
|
status = False
|
||||||
elif group and f"{plugin.module}," in group.block_plugin:
|
elif group and f"{plugin.module}," in group.block_plugin:
|
||||||
plugin.name = f"{plugin.name}(不可用)"
|
status = False
|
||||||
elif bot and f"{plugin.module}," in bot.block_plugins:
|
elif bot and f"{plugin.module}," in bot.block_plugins:
|
||||||
plugin.name = f"{plugin.name}(不可用)"
|
status = False
|
||||||
commands = []
|
commands = []
|
||||||
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||||
if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||||
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
extra_data = PluginExtraData(**nb_plugin.metadata.extra)
|
||||||
commands = [cmd.command for cmd in extra_data.commands]
|
commands = [cmd.command for cmd in extra_data.commands]
|
||||||
return Item(plugin_name=f"{plugin.id}-{plugin.name}", commands=commands)
|
return Item(
|
||||||
|
plugin_name=plugin.name,
|
||||||
|
commands=commands,
|
||||||
|
id=str(plugin.id),
|
||||||
|
status=status,
|
||||||
|
has_superuser_help=has_superuser_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||||
@ -78,68 +97,10 @@ def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
|||||||
}
|
}
|
||||||
for menu, value in classify.items()
|
for menu, value in classify.items()
|
||||||
]
|
]
|
||||||
plugin_list = build_line_data(plugin_list)
|
plugin_list.insert(0, {"name": menu_key, "items": max_data})
|
||||||
plugin_list.insert(
|
|
||||||
0,
|
|
||||||
build_plugin_line(
|
|
||||||
menu_key if menu_key not in ["normal", "功能"] else "主要功能",
|
|
||||||
max_data,
|
|
||||||
30,
|
|
||||||
100,
|
|
||||||
True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return plugin_list
|
|
||||||
|
|
||||||
|
|
||||||
def build_plugin_line(
|
|
||||||
name: str, items: list, left: int, width: int | None = None, is_max: bool = False
|
|
||||||
) -> dict:
|
|
||||||
"""构造插件行数据
|
|
||||||
|
|
||||||
参数:
|
|
||||||
name: 菜单名称
|
|
||||||
items: 插件名称列表
|
|
||||||
left: 左边距
|
|
||||||
width: 总插件长度.
|
|
||||||
is_max: 是否为最大长度的插件菜单
|
|
||||||
|
|
||||||
返回:
|
|
||||||
dict: 插件数据
|
|
||||||
"""
|
|
||||||
_plugins = []
|
|
||||||
width = width or 50
|
|
||||||
if len(items) // 2 > 6 or is_max:
|
|
||||||
width = 100
|
|
||||||
plugin_list1 = []
|
|
||||||
plugin_list2 = []
|
|
||||||
for i in range(len(items)):
|
|
||||||
if i % 2:
|
|
||||||
plugin_list1.append(items[i])
|
|
||||||
else:
|
|
||||||
plugin_list2.append(items[i])
|
|
||||||
_plugins = [(30, 50, plugin_list1), (0, 50, plugin_list2)]
|
|
||||||
else:
|
|
||||||
_plugins = [(left, 100, items)]
|
|
||||||
return {"name": name, "items": _plugins, "width": width}
|
|
||||||
|
|
||||||
|
|
||||||
def build_line_data(plugin_list: list[dict]) -> list[dict]:
|
|
||||||
"""构造插件数据
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin_list: 插件列表
|
|
||||||
|
|
||||||
返回:
|
|
||||||
list[dict]: 插件数据
|
|
||||||
"""
|
|
||||||
left = 30
|
|
||||||
data = []
|
|
||||||
for plugin in plugin_list:
|
for plugin in plugin_list:
|
||||||
data.append(build_plugin_line(plugin["name"], plugin["items"], left))
|
plugin["items"].sort(key=lambda x: x.id)
|
||||||
if len(plugin["items"]) // 2 <= 6:
|
return plugin_list
|
||||||
left = 15 if left == 30 else 30
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
async def build_zhenxun_image(
|
async def build_zhenxun_image(
|
||||||
@ -160,6 +121,7 @@ async def build_zhenxun_image(
|
|||||||
width = int(637 * 1.5) if is_detail else 637
|
width = int(637 * 1.5) if is_detail else 637
|
||||||
title_font = int(53 * 1.5) if is_detail else 53
|
title_font = int(53 * 1.5) if is_detail else 53
|
||||||
tip_font = int(19 * 1.5) if is_detail else 19
|
tip_font = int(19 * 1.5) if is_detail else 19
|
||||||
|
plugin_count = sum(len(plugin["items"]) for plugin in plugin_list)
|
||||||
return await template_to_pic(
|
return await template_to_pic(
|
||||||
template_path=str((TEMPLATE_PATH / "ss_menu").absolute()),
|
template_path=str((TEMPLATE_PATH / "ss_menu").absolute()),
|
||||||
template_name="main.html",
|
template_name="main.html",
|
||||||
@ -170,10 +132,11 @@ async def build_zhenxun_image(
|
|||||||
"width": width,
|
"width": width,
|
||||||
"font_size": (title_font, tip_font),
|
"font_size": (title_font, tip_font),
|
||||||
"is_detail": is_detail,
|
"is_detail": is_detail,
|
||||||
|
"plugin_count": plugin_count,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
pages={
|
pages={
|
||||||
"viewport": {"width": width, "height": 453},
|
"viewport": {"width": width, "height": 10},
|
||||||
"base_url": f"file://{TEMPLATE_PATH}",
|
"base_url": f"file://{TEMPLATE_PATH}",
|
||||||
},
|
},
|
||||||
wait=2,
|
wait=2,
|
||||||
|
|||||||
@ -1,597 +0,0 @@
|
|||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
from nonebot.adapters import Bot, Event
|
|
||||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
|
||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot.matcher import Matcher
|
|
||||||
from nonebot_plugin_alconna import At, UniMsg
|
|
||||||
from nonebot_plugin_session import EventSession
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from tortoise.exceptions import IntegrityError
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
|
||||||
from zhenxun.models.bot_console import BotConsole
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
|
||||||
from zhenxun.models.level_user import LevelUser
|
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
|
||||||
from zhenxun.models.plugin_limit import PluginLimit
|
|
||||||
from zhenxun.models.sign_user import SignUser
|
|
||||||
from zhenxun.models.user_console import UserConsole
|
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.enum import (
|
|
||||||
BlockType,
|
|
||||||
GoldHandle,
|
|
||||||
LimitWatchType,
|
|
||||||
PluginLimitType,
|
|
||||||
PluginType,
|
|
||||||
)
|
|
||||||
from zhenxun.utils.exception import InsufficientGold
|
|
||||||
from zhenxun.utils.message import MessageUtils
|
|
||||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
|
||||||
|
|
||||||
base_config = Config.get("hook")
|
|
||||||
|
|
||||||
|
|
||||||
class Limit(BaseModel):
|
|
||||||
limit: PluginLimit
|
|
||||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class LimitManage:
|
|
||||||
add_module: ClassVar[list] = []
|
|
||||||
|
|
||||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
|
||||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
|
||||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def add_limit(cls, limit: PluginLimit):
|
|
||||||
"""添加限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
limit: PluginLimit
|
|
||||||
"""
|
|
||||||
if limit.module not in cls.add_module:
|
|
||||||
cls.add_module.append(limit.module)
|
|
||||||
if limit.limit_type == PluginLimitType.BLOCK:
|
|
||||||
cls.block_limit[limit.module] = Limit(
|
|
||||||
limit=limit, limiter=UserBlockLimiter()
|
|
||||||
)
|
|
||||||
elif limit.limit_type == PluginLimitType.CD:
|
|
||||||
cls.cd_limit[limit.module] = Limit(
|
|
||||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
|
||||||
)
|
|
||||||
elif limit.limit_type == PluginLimitType.COUNT:
|
|
||||||
cls.count_limit[limit.module] = Limit(
|
|
||||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def unblock(
|
|
||||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
|
||||||
):
|
|
||||||
"""解除插件block
|
|
||||||
|
|
||||||
参数:
|
|
||||||
module: 模块名
|
|
||||||
user_id: 用户id
|
|
||||||
group_id: 群组id
|
|
||||||
channel_id: 频道id
|
|
||||||
"""
|
|
||||||
if limit_model := cls.block_limit.get(module):
|
|
||||||
limit = limit_model.limit
|
|
||||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
|
||||||
key_type = user_id
|
|
||||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
|
||||||
key_type = channel_id or group_id
|
|
||||||
logger.debug(
|
|
||||||
f"解除对象: {key_type} 的block限制",
|
|
||||||
"AuthChecker",
|
|
||||||
session=user_id,
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
limiter.set_false(key_type)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def check(
|
|
||||||
cls,
|
|
||||||
module: str,
|
|
||||||
user_id: str,
|
|
||||||
group_id: str | None,
|
|
||||||
channel_id: str | None,
|
|
||||||
session: EventSession,
|
|
||||||
):
|
|
||||||
"""检测限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
module: 模块名
|
|
||||||
user_id: 用户id
|
|
||||||
group_id: 群组id
|
|
||||||
channel_id: 频道id
|
|
||||||
session: Session
|
|
||||||
|
|
||||||
异常:
|
|
||||||
IgnoredException: IgnoredException
|
|
||||||
"""
|
|
||||||
if limit_model := cls.cd_limit.get(module):
|
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
|
||||||
if limit_model := cls.block_limit.get(module):
|
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
|
||||||
if limit_model := cls.count_limit.get(module):
|
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def __check(
|
|
||||||
cls,
|
|
||||||
limit_model: Limit | None,
|
|
||||||
user_id: str,
|
|
||||||
group_id: str | None,
|
|
||||||
channel_id: str | None,
|
|
||||||
session: EventSession,
|
|
||||||
):
|
|
||||||
"""检测限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
limit_model: Limit
|
|
||||||
user_id: 用户id
|
|
||||||
group_id: 群组id
|
|
||||||
channel_id: 频道id
|
|
||||||
session: Session
|
|
||||||
|
|
||||||
异常:
|
|
||||||
IgnoredException: IgnoredException
|
|
||||||
"""
|
|
||||||
if not limit_model:
|
|
||||||
return
|
|
||||||
limit = limit_model.limit
|
|
||||||
limiter = limit_model.limiter
|
|
||||||
is_limit = (
|
|
||||||
LimitWatchType.ALL
|
|
||||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
|
||||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
|
||||||
)
|
|
||||||
key_type = user_id
|
|
||||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
|
||||||
key_type = channel_id or group_id
|
|
||||||
if is_limit and not limiter.check(key_type):
|
|
||||||
if limit.result:
|
|
||||||
await MessageUtils.build_message(limit.result).send()
|
|
||||||
logger.debug(
|
|
||||||
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{limit.module} 正在限制中...")
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=user_id,
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
if isinstance(limiter, FreqLimiter):
|
|
||||||
limiter.start_cd(key_type)
|
|
||||||
if isinstance(limiter, UserBlockLimiter):
|
|
||||||
limiter.set_true(key_type)
|
|
||||||
if isinstance(limiter, CountLimiter):
|
|
||||||
limiter.increase(key_type)
|
|
||||||
|
|
||||||
|
|
||||||
class IsSuperuserException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AuthChecker:
|
|
||||||
"""
|
|
||||||
权限检查
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
|
||||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
|
||||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
|
||||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
|
||||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
|
||||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
|
||||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
|
||||||
|
|
||||||
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
|
|
||||||
"""是否发送提示消息
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
|
|
||||||
返回:
|
|
||||||
bool: 是否发送提示消息
|
|
||||||
"""
|
|
||||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
|
||||||
return False
|
|
||||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
|
||||||
return False
|
|
||||||
if plugin.ignore_prompt:
|
|
||||||
return False
|
|
||||||
return self._flmt_s.check(sid)
|
|
||||||
|
|
||||||
async def auth(
|
|
||||||
self,
|
|
||||||
matcher: Matcher,
|
|
||||||
event: Event,
|
|
||||||
bot: Bot,
|
|
||||||
session: EventSession,
|
|
||||||
message: UniMsg,
|
|
||||||
):
|
|
||||||
"""权限检查
|
|
||||||
|
|
||||||
参数:
|
|
||||||
matcher: matcher
|
|
||||||
bot: bot
|
|
||||||
session: EventSession
|
|
||||||
message: UniMsg
|
|
||||||
"""
|
|
||||||
is_ignore = False
|
|
||||||
cost_gold = 0
|
|
||||||
user_id = session.id1
|
|
||||||
group_id = session.id3
|
|
||||||
channel_id = session.id2
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
|
||||||
"""过滤除poke外的notice"""
|
|
||||||
return
|
|
||||||
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
|
|
||||||
try:
|
|
||||||
user = await UserConsole.get_user(user_id, session.platform)
|
|
||||||
except IntegrityError as e:
|
|
||||||
logger.debug(
|
|
||||||
"重复创建用户,已跳过该次权限...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if plugin := await PluginInfo.get_or_none(module_path=module_path):
|
|
||||||
if plugin.plugin_type == PluginType.HIDDEN:
|
|
||||||
logger.debug(
|
|
||||||
f"插件: {plugin.name}:{plugin.module} "
|
|
||||||
"为HIDDEN,已跳过权限检查..."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
cost_gold = await self.auth_cost(user, plugin, session)
|
|
||||||
if session.id1 in bot.config.superusers:
|
|
||||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
|
||||||
raise IsSuperuserException()
|
|
||||||
if not plugin.limit_superuser:
|
|
||||||
cost_gold = 0
|
|
||||||
raise IsSuperuserException()
|
|
||||||
await self.auth_bot(plugin, bot.self_id)
|
|
||||||
await self.auth_group(plugin, session, message)
|
|
||||||
await self.auth_admin(plugin, session)
|
|
||||||
await self.auth_plugin(plugin, session, event)
|
|
||||||
await self.auth_limit(plugin, session)
|
|
||||||
except IsSuperuserException:
|
|
||||||
logger.debug(
|
|
||||||
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
|
||||||
)
|
|
||||||
except IgnoredException:
|
|
||||||
is_ignore = True
|
|
||||||
LimitManage.unblock(
|
|
||||||
matcher.plugin.name, user_id, group_id, channel_id
|
|
||||||
)
|
|
||||||
except AssertionError as e:
|
|
||||||
is_ignore = True
|
|
||||||
logger.debug("消息无法发送", session=session, e=e)
|
|
||||||
if cost_gold and user_id:
|
|
||||||
"""花费金币"""
|
|
||||||
try:
|
|
||||||
await UserConsole.reduce_gold(
|
|
||||||
user_id,
|
|
||||||
cost_gold,
|
|
||||||
GoldHandle.PLUGIN,
|
|
||||||
matcher.plugin.name if matcher.plugin else "",
|
|
||||||
session.platform,
|
|
||||||
)
|
|
||||||
except InsufficientGold:
|
|
||||||
if u := await UserConsole.get_user(user_id):
|
|
||||||
u.gold = 0
|
|
||||||
await u.save(update_fields=["gold"])
|
|
||||||
logger.debug(
|
|
||||||
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
|
|
||||||
)
|
|
||||||
if is_ignore:
|
|
||||||
raise IgnoredException("权限检测 ignore")
|
|
||||||
|
|
||||||
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
|
|
||||||
"""机器人权限
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
bot_id: bot_id
|
|
||||||
"""
|
|
||||||
if not await BotConsole.get_bot_status(bot_id):
|
|
||||||
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
|
|
||||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
|
||||||
if await BotConsole.is_block_plugin(bot_id, plugin.module):
|
|
||||||
logger.debug(
|
|
||||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
|
||||||
"AuthChecker",
|
|
||||||
)
|
|
||||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
|
||||||
|
|
||||||
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
|
|
||||||
"""插件限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
"""
|
|
||||||
user_id = session.id1
|
|
||||||
group_id = session.id3
|
|
||||||
channel_id = session.id2
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if plugin.module not in LimitManage.add_module:
|
|
||||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
|
||||||
status=True
|
|
||||||
).all() # type: ignore
|
|
||||||
for limit in limit_list:
|
|
||||||
LimitManage.add_limit(limit)
|
|
||||||
if user_id:
|
|
||||||
await LimitManage.check(
|
|
||||||
plugin.module, user_id, group_id, channel_id, session
|
|
||||||
)
|
|
||||||
|
|
||||||
async def auth_plugin(
|
|
||||||
self, plugin: PluginInfo, session: EventSession, event: Event
|
|
||||||
):
|
|
||||||
"""插件状态
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
"""
|
|
||||||
group_id = session.id3
|
|
||||||
channel_id = session.id2
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if user_id := session.id1:
|
|
||||||
if plugin.impression > 0:
|
|
||||||
sign_user = await SignUser.get_user(user_id)
|
|
||||||
if float(sign_user.impression) < plugin.impression:
|
|
||||||
if self.is_send_limit_message(plugin, user_id):
|
|
||||||
self._flmt_s.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
f"好感度不足哦,当前功能需要好感度: {plugin.impression},"
|
|
||||||
"请继续签到提升好感度吧!"
|
|
||||||
).send(reply_to=True)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 用户好感度不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("好感度不足...")
|
|
||||||
if group_id:
|
|
||||||
sid = group_id or user_id
|
|
||||||
if await GroupConsole.is_superuser_block_plugin(
|
|
||||||
group_id, plugin.module
|
|
||||||
):
|
|
||||||
"""超级用户群组插件状态"""
|
|
||||||
if self.is_send_limit_message(plugin, sid):
|
|
||||||
self._flmt_s.start_cd(group_id or user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
"超级管理员禁用了该群此功能..."
|
|
||||||
).send(reply_to=True)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("超级管理员禁用了该群此功能...")
|
|
||||||
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
|
|
||||||
"""群组插件状态"""
|
|
||||||
if self.is_send_limit_message(plugin, sid):
|
|
||||||
self._flmt_s.start_cd(group_id or user_id)
|
|
||||||
await MessageUtils.build_message("该群未开启此功能...").send(
|
|
||||||
reply_to=True
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该群未开启此功能...")
|
|
||||||
if plugin.block_type == BlockType.GROUP:
|
|
||||||
"""全局群组禁用"""
|
|
||||||
try:
|
|
||||||
if self.is_send_limit_message(plugin, sid):
|
|
||||||
self._flmt_c.start_cd(group_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
"该功能在群组中已被禁用..."
|
|
||||||
).send(reply_to=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_plugin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该插件在群组中已被禁用...")
|
|
||||||
else:
|
|
||||||
sid = user_id
|
|
||||||
if plugin.block_type == BlockType.PRIVATE:
|
|
||||||
"""全局私聊禁用"""
|
|
||||||
try:
|
|
||||||
if self.is_send_limit_message(plugin, sid):
|
|
||||||
self._flmt_c.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
"该功能在私聊中已被禁用..."
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该插件在私聊中已被禁用...")
|
|
||||||
if not plugin.status and plugin.block_type == BlockType.ALL:
|
|
||||||
"""全局状态"""
|
|
||||||
if group_id and await GroupConsole.is_super_group(group_id):
|
|
||||||
raise IsSuperuserException()
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
if self.is_send_limit_message(plugin, sid):
|
|
||||||
self._flmt_s.start_cd(group_id or user_id)
|
|
||||||
await MessageUtils.build_message("全局未开启此功能...").send()
|
|
||||||
raise IgnoredException("全局未开启此功能...")
|
|
||||||
|
|
||||||
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
|
|
||||||
"""管理员命令 个人权限
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
"""
|
|
||||||
user_id = session.id1
|
|
||||||
if user_id and plugin.admin_level:
|
|
||||||
if group_id := session.id3 or session.id2:
|
|
||||||
if not await LevelUser.check_level(
|
|
||||||
user_id, group_id, plugin.admin_level
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
if self._flmt.check(user_id):
|
|
||||||
self._flmt.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
[
|
|
||||||
At(flag="user", target=user_id),
|
|
||||||
f"你的权限不足喔,"
|
|
||||||
f"该功能需要的权限等级: {plugin.admin_level}",
|
|
||||||
]
|
|
||||||
).send(reply_to=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("管理员权限不足...")
|
|
||||||
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
|
|
||||||
try:
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("权限不足")
|
|
||||||
|
|
||||||
async def auth_group(
|
|
||||||
self, plugin: PluginInfo, session: EventSession, message: UniMsg
|
|
||||||
):
|
|
||||||
"""群黑名单检测 群总开关检测
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
message: UniMsg
|
|
||||||
"""
|
|
||||||
if not (group_id := session.id3 or session.id2):
|
|
||||||
return
|
|
||||||
text = message.extract_plain_text()
|
|
||||||
group = await GroupConsole.get_group(group_id)
|
|
||||||
if not group:
|
|
||||||
"""群不存在"""
|
|
||||||
logger.debug(
|
|
||||||
"群组信息不存在...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("群不存在")
|
|
||||||
if group.level < 0:
|
|
||||||
"""群权限小于0"""
|
|
||||||
logger.debug(
|
|
||||||
"群黑名单, 群权限-1...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("群黑名单")
|
|
||||||
if not group.status:
|
|
||||||
"""群休眠"""
|
|
||||||
if text.strip() != "醒来":
|
|
||||||
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
|
||||||
raise IgnoredException("群休眠状态")
|
|
||||||
if plugin.level > group.level:
|
|
||||||
"""插件等级大于群等级"""
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 群等级限制.."
|
|
||||||
f"该功能需要的群等级: {plugin.level}..",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
|
||||||
|
|
||||||
async def auth_cost(
|
|
||||||
self, user: UserConsole, plugin: PluginInfo, session: EventSession
|
|
||||||
) -> int:
|
|
||||||
"""检测是否满足金币条件
|
|
||||||
|
|
||||||
参数:
|
|
||||||
user: UserConsole
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
|
|
||||||
返回:
|
|
||||||
int: 需要消耗的金币
|
|
||||||
"""
|
|
||||||
if user.gold < plugin.cost_gold:
|
|
||||||
"""插件消耗金币不足"""
|
|
||||||
try:
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 金币限制.."
|
|
||||||
f"该功能需要{plugin.cost_gold}金币..",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
|
||||||
return plugin.cost_gold
|
|
||||||
|
|
||||||
|
|
||||||
checker = AuthChecker()
|
|
||||||
99
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
99
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nonebot_plugin_alconna import At
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.level_user import LevelUser
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.data_access import DataAccess
|
||||||
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
from .utils import send_message
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
||||||
|
"""管理员命令 个人权限
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
if not plugin.admin_level:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
entity = get_entity_ids(session)
|
||||||
|
level_dao = DataAccess(LevelUser)
|
||||||
|
|
||||||
|
# 并行查询用户权限数据
|
||||||
|
global_user: LevelUser | None = None
|
||||||
|
group_users: LevelUser | None = None
|
||||||
|
|
||||||
|
# 查询全局权限
|
||||||
|
global_user_task = level_dao.safe_get_or_none(
|
||||||
|
user_id=session.user.id, group_id__isnull=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果在群组中,查询群组权限
|
||||||
|
group_users_task = None
|
||||||
|
if entity.group_id:
|
||||||
|
group_users_task = level_dao.safe_get_or_none(
|
||||||
|
user_id=session.user.id, group_id=entity.group_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 等待查询完成,添加超时控制
|
||||||
|
try:
|
||||||
|
results = await asyncio.wait_for(
|
||||||
|
asyncio.gather(global_user_task, group_users_task or asyncio.sleep(0)),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
global_user = results[0]
|
||||||
|
group_users = results[1] if group_users_task else None
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"查询用户权限超时: user_id={session.user.id}", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
return
|
||||||
|
|
||||||
|
user_level = global_user.user_level if global_user else 0
|
||||||
|
if entity.group_id and group_users:
|
||||||
|
user_level = max(user_level, group_users.user_level)
|
||||||
|
|
||||||
|
if user_level < plugin.admin_level:
|
||||||
|
await send_message(
|
||||||
|
session,
|
||||||
|
[
|
||||||
|
At(flag="user", target=session.user.id),
|
||||||
|
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||||
|
],
|
||||||
|
entity.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||||
|
)
|
||||||
|
elif global_user:
|
||||||
|
if global_user.user_level < plugin.admin_level:
|
||||||
|
await send_message(
|
||||||
|
session,
|
||||||
|
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||||
|
)
|
||||||
|
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"auth_admin 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
303
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
303
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nonebot.adapters import Bot
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot_plugin_alconna import At
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.models.ban_console import BanConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.data_access import DataAccess
|
||||||
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import PluginType
|
||||||
|
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
from .utils import freq, send_message
|
||||||
|
|
||||||
|
Config.add_plugin_config(
|
||||||
|
"hook",
|
||||||
|
"BAN_RESULT",
|
||||||
|
"才不会给你发消息.",
|
||||||
|
help="对被ban用户发送的消息",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_ban_time(ban_record: BanConsole | None) -> int:
|
||||||
|
"""根据ban记录计算剩余ban时间
|
||||||
|
|
||||||
|
参数:
|
||||||
|
ban_record: BanConsole记录
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: ban剩余时长,-1时为永久ban,0表示未被ban
|
||||||
|
"""
|
||||||
|
if not ban_record:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if ban_record.duration == -1:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
_time = time.time() - (ban_record.ban_time + ban_record.duration)
|
||||||
|
return 0 if _time > 0 else int(abs(_time))
|
||||||
|
|
||||||
|
|
||||||
|
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||||
|
"""检查用户或群组是否被ban
|
||||||
|
|
||||||
|
参数:
|
||||||
|
user_id: 用户ID
|
||||||
|
group_id: 群组ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: ban的剩余时间,0表示未被ban
|
||||||
|
"""
|
||||||
|
if not user_id and not group_id:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
ban_dao = DataAccess(BanConsole)
|
||||||
|
|
||||||
|
# 分别获取用户在群组中的ban记录和全局ban记录
|
||||||
|
group_user = None
|
||||||
|
user = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 并行查询用户和群组的 ban 记录
|
||||||
|
tasks = []
|
||||||
|
if user_id and group_id:
|
||||||
|
tasks.append(ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id))
|
||||||
|
if user_id:
|
||||||
|
tasks.append(
|
||||||
|
ban_dao.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 等待所有查询完成,添加超时控制
|
||||||
|
if tasks:
|
||||||
|
try:
|
||||||
|
ban_records = await asyncio.wait_for(
|
||||||
|
asyncio.gather(*tasks), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
if len(tasks) == 2:
|
||||||
|
group_user, user = ban_records
|
||||||
|
elif user_id and group_id:
|
||||||
|
group_user = ban_records[0]
|
||||||
|
else:
|
||||||
|
user = ban_records[0]
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"查询ban记录超时: user_id={user_id}, group_id={group_id}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
)
|
||||||
|
# 超时时返回0,避免阻塞
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 检查记录并计算ban时间
|
||||||
|
results = []
|
||||||
|
if group_user:
|
||||||
|
results.append(group_user)
|
||||||
|
if user:
|
||||||
|
results.append(user)
|
||||||
|
|
||||||
|
# 如果没有找到记录,返回0
|
||||||
|
if not results:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
logger.debug(f"查询到的ban记录: {results}", LOGGER_COMMAND)
|
||||||
|
# 检查所有记录,找出最严格的ban(时间最长的)
|
||||||
|
max_ban_time: int = 0
|
||||||
|
for result in results:
|
||||||
|
if result.duration > 0 or result.duration == -1:
|
||||||
|
# 直接计算ban时间,避免再次查询数据库
|
||||||
|
ban_time = calculate_ban_time(result)
|
||||||
|
if ban_time == -1 or ban_time > max_ban_time:
|
||||||
|
max_ban_time = ban_time
|
||||||
|
|
||||||
|
return max_ban_time
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"is_ban 耗时: {elapsed:.3f}s",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_plugin_type(matcher: Matcher) -> bool:
|
||||||
|
"""判断插件类型是否是隐藏插件
|
||||||
|
|
||||||
|
参数:
|
||||||
|
matcher: Matcher
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否为隐藏插件
|
||||||
|
"""
|
||||||
|
if plugin := matcher.plugin:
|
||||||
|
if metadata := plugin.metadata:
|
||||||
|
extra = metadata.extra
|
||||||
|
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(time_val: float) -> str:
|
||||||
|
"""格式化时间
|
||||||
|
|
||||||
|
参数:
|
||||||
|
time_val: ban时长
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 格式化时间文本
|
||||||
|
"""
|
||||||
|
if time_val == -1:
|
||||||
|
return "∞"
|
||||||
|
time_val = abs(int(time_val))
|
||||||
|
if time_val < 60:
|
||||||
|
time_str = f"{time_val!s} 秒"
|
||||||
|
else:
|
||||||
|
minute = int(time_val / 60)
|
||||||
|
if minute > 60:
|
||||||
|
hours = minute // 60
|
||||||
|
minute %= 60
|
||||||
|
time_str = f"{hours} 小时 {minute}分钟"
|
||||||
|
else:
|
||||||
|
time_str = f"{minute} 分钟"
|
||||||
|
return time_str
|
||||||
|
|
||||||
|
|
||||||
|
async def group_handle(group_id: str) -> None:
|
||||||
|
"""群组ban检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
group_id: 群组id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: 群组处于黑名单
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
if await is_ban(None, group_id):
|
||||||
|
raise SkipPluginException("群组处于黑名单中...")
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"group_handle 耗时: {elapsed:.3f}s",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def user_handle(module: str, entity: EntityIDs, session: Uninfo) -> None:
|
||||||
|
"""用户ban检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 插件模块名
|
||||||
|
entity: 实体ID信息
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: 用户处于黑名单
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||||
|
time_val = await is_ban(entity.user_id, entity.group_id)
|
||||||
|
if not time_val:
|
||||||
|
return
|
||||||
|
time_str = format_time(time_val)
|
||||||
|
plugin_dao = DataAccess(PluginInfo)
|
||||||
|
try:
|
||||||
|
db_plugin = await asyncio.wait_for(
|
||||||
|
plugin_dao.safe_get_or_none(module=module), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"查询插件信息超时: {module}", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
raise SkipPluginException("用户处于黑名单中...")
|
||||||
|
|
||||||
|
if (
|
||||||
|
db_plugin
|
||||||
|
and not db_plugin.ignore_prompt
|
||||||
|
and time_val != -1
|
||||||
|
and ban_result
|
||||||
|
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
send_message(
|
||||||
|
session,
|
||||||
|
[
|
||||||
|
At(flag="user", target=entity.user_id),
|
||||||
|
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||||
|
],
|
||||||
|
entity.user_id,
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"发送消息超时: {entity.user_id}", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException("用户处于黑名单中...")
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"user_handle 耗时: {elapsed:.3f}s",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo) -> None:
|
||||||
|
"""权限检查 - ban 检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
matcher: Matcher
|
||||||
|
bot: Bot
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
if not check_plugin_type(matcher):
|
||||||
|
return
|
||||||
|
if not matcher.plugin_name:
|
||||||
|
return
|
||||||
|
entity = get_entity_ids(session)
|
||||||
|
if entity.user_id in bot.config.superusers:
|
||||||
|
return
|
||||||
|
if entity.group_id:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
group_handle(entity.group_id), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"群组ban检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
|
||||||
|
if entity.user_id:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
user_handle(matcher.plugin_name, entity, session),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"用户ban检查超时: {entity.user_id}", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
finally:
|
||||||
|
# 记录总执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"auth_ban 总耗时: {elapsed:.3f}s, plugin={matcher.plugin_name}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
55
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
55
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from zhenxun.models.bot_console import BotConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.data_access import DataAccess
|
||||||
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
||||||
|
"""bot层面的权限检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
bot_id: bot id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: 忽略插件
|
||||||
|
SkipPluginException: 忽略插件
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 从数据库或缓存中获取 bot 信息
|
||||||
|
bot_dao = DataAccess(BotConsole)
|
||||||
|
|
||||||
|
try:
|
||||||
|
bot: BotConsole | None = await asyncio.wait_for(
|
||||||
|
bot_dao.safe_get_or_none(bot_id=bot_id), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"查询Bot信息超时: bot_id={bot_id}", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
return
|
||||||
|
|
||||||
|
if not bot or not bot.status:
|
||||||
|
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
||||||
|
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"auth_bot 耗时: {elapsed:.3f}s, "
|
||||||
|
f"bot_id={bot_id}, plugin={plugin.module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
)
|
||||||
41
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
41
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.user_console import UserConsole
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
from .utils import send_message
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
||||||
|
"""检测是否满足金币条件
|
||||||
|
|
||||||
|
参数:
|
||||||
|
user: UserConsole
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 需要消耗的金币
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if user.gold < plugin.cost_gold:
|
||||||
|
"""插件消耗金币不足"""
|
||||||
|
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
|
||||||
|
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||||
|
return plugin.cost_gold
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"auth_cost 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
68
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
68
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nonebot_plugin_alconna import UniMsg
|
||||||
|
|
||||||
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.data_access import DataAccess
|
||||||
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.utils import EntityIDs
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD, SwitchEnum
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||||
|
"""群黑名单检测 群总开关检测
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
entity: EntityIDs
|
||||||
|
message: UniMsg
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
if not entity.group_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
text = message.extract_plain_text()
|
||||||
|
|
||||||
|
# 从数据库或缓存中获取群组信息
|
||||||
|
group_dao = DataAccess(GroupConsole)
|
||||||
|
|
||||||
|
try:
|
||||||
|
group: GroupConsole | None = await asyncio.wait_for(
|
||||||
|
group_dao.safe_get_or_none(
|
||||||
|
group_id=entity.group_id, channel_id__isnull=True
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("查询群组信息超时", LOGGER_COMMAND, session=entity.user_id)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
return
|
||||||
|
|
||||||
|
if not group:
|
||||||
|
raise SkipPluginException("群组信息不存在...")
|
||||||
|
if group.level < 0:
|
||||||
|
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
|
||||||
|
if text.strip() != SwitchEnum.ENABLE and not group.status:
|
||||||
|
raise SkipPluginException("群组休眠状态...")
|
||||||
|
if plugin.level > group.level:
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{plugin.name}({plugin.module}) 群等级限制,"
|
||||||
|
f"该功能需要的群等级: {plugin.level}..."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"auth_group 耗时: {elapsed:.3f}s, plugin={plugin.module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=entity.user_id,
|
||||||
|
group_id=entity.group_id,
|
||||||
|
)
|
||||||
318
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
318
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import nonebot
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.plugin_limit import PluginLimit
|
||||||
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
from zhenxun.utils.utils import (
|
||||||
|
CountLimiter,
|
||||||
|
FreqLimiter,
|
||||||
|
UserBlockLimiter,
|
||||||
|
get_entity_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
|
@PriorityLifecycle.on_startup(priority=5)
|
||||||
|
async def _():
|
||||||
|
"""初始化限制"""
|
||||||
|
await LimitManager.init_limit()
|
||||||
|
|
||||||
|
|
||||||
|
class Limit(BaseModel):
|
||||||
|
limit: PluginLimit
|
||||||
|
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class LimitManager:
|
||||||
|
add_module: ClassVar[list] = []
|
||||||
|
last_update_time: ClassVar[float] = 0
|
||||||
|
update_interval: ClassVar[float] = 6000 # 1小时更新一次
|
||||||
|
is_updating: ClassVar[bool] = False # 防止并发更新
|
||||||
|
|
||||||
|
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
|
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
|
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
|
|
||||||
|
# 模块限制缓存,避免频繁查询数据库
|
||||||
|
module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
|
||||||
|
module_cache_ttl: ClassVar[float] = 60 # 模块缓存有效期(秒)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def init_limit(cls):
|
||||||
|
"""初始化限制"""
|
||||||
|
cls.last_update_time = time.time()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("初始化限制超时", LOGGER_COMMAND)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_limits(cls):
|
||||||
|
"""更新限制信息"""
|
||||||
|
# 防止并发更新
|
||||||
|
if cls.is_updating:
|
||||||
|
return
|
||||||
|
|
||||||
|
cls.is_updating = True
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
limit_list = await asyncio.wait_for(
|
||||||
|
PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("查询限制信息超时", LOGGER_COMMAND)
|
||||||
|
cls.is_updating = False
|
||||||
|
return
|
||||||
|
|
||||||
|
# 清空旧数据
|
||||||
|
cls.add_module = []
|
||||||
|
cls.cd_limit = {}
|
||||||
|
cls.block_limit = {}
|
||||||
|
cls.count_limit = {}
|
||||||
|
# 添加新数据
|
||||||
|
for limit in limit_list:
|
||||||
|
cls.add_limit(limit)
|
||||||
|
|
||||||
|
cls.last_update_time = time.time()
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的更新
|
||||||
|
logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
|
||||||
|
finally:
|
||||||
|
cls.is_updating = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_limit(cls, limit: PluginLimit):
|
||||||
|
"""添加限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
limit: PluginLimit
|
||||||
|
"""
|
||||||
|
if limit.module not in cls.add_module:
|
||||||
|
cls.add_module.append(limit.module)
|
||||||
|
if limit.limit_type == PluginLimitType.BLOCK:
|
||||||
|
cls.block_limit[limit.module] = Limit(
|
||||||
|
limit=limit, limiter=UserBlockLimiter()
|
||||||
|
)
|
||||||
|
elif limit.limit_type == PluginLimitType.CD:
|
||||||
|
cls.cd_limit[limit.module] = Limit(
|
||||||
|
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||||
|
)
|
||||||
|
elif limit.limit_type == PluginLimitType.COUNT:
|
||||||
|
cls.count_limit[limit.module] = Limit(
|
||||||
|
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unblock(
|
||||||
|
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||||
|
):
|
||||||
|
"""解除插件block
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
channel_id: 频道id
|
||||||
|
"""
|
||||||
|
if limit_model := cls.block_limit.get(module):
|
||||||
|
limit = limit_model.limit
|
||||||
|
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||||
|
key_type = user_id
|
||||||
|
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||||
|
key_type = channel_id or group_id
|
||||||
|
logger.debug(
|
||||||
|
f"解除对象: {key_type} 的block限制",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
limiter.set_false(key_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_module_limits(cls, module: str) -> list[PluginLimit]:
|
||||||
|
"""获取模块的限制信息,使用缓存减少数据库查询
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[PluginLimit]: 限制列表
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if module in cls.module_limit_cache:
|
||||||
|
cache_time, limits = cls.module_limit_cache[module]
|
||||||
|
if current_time - cache_time < cls.module_cache_ttl:
|
||||||
|
return limits
|
||||||
|
|
||||||
|
# 缓存不存在或已过期,从数据库查询
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
limits = await asyncio.wait_for(
|
||||||
|
PluginLimit.filter(module=module, status=True).all(),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的查询
|
||||||
|
logger.warning(
|
||||||
|
f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
cls.module_limit_cache[module] = (current_time, limits)
|
||||||
|
return limits
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
|
||||||
|
# 超时时返回空列表,避免阻塞
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check(
|
||||||
|
cls,
|
||||||
|
module: str,
|
||||||
|
user_id: str,
|
||||||
|
group_id: str | None,
|
||||||
|
channel_id: str | None,
|
||||||
|
):
|
||||||
|
"""检测限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
channel_id: 频道id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: IgnoredException
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 定期更新全局限制信息
|
||||||
|
if (
|
||||||
|
time.time() - cls.last_update_time > cls.update_interval
|
||||||
|
and not cls.is_updating
|
||||||
|
):
|
||||||
|
# 使用异步任务更新,避免阻塞当前请求
|
||||||
|
asyncio.create_task(cls.update_limits()) # noqa: RUF006
|
||||||
|
|
||||||
|
# 如果模块不在已加载列表中,只加载该模块的限制
|
||||||
|
if module not in cls.add_module:
|
||||||
|
limits = await cls.get_module_limits(module)
|
||||||
|
for limit in limits:
|
||||||
|
cls.add_limit(limit)
|
||||||
|
|
||||||
|
# 检查各种限制
|
||||||
|
try:
|
||||||
|
if limit_model := cls.cd_limit.get(module):
|
||||||
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||||
|
if limit_model := cls.block_limit.get(module):
|
||||||
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||||
|
if limit_model := cls.count_limit.get(module):
|
||||||
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||||
|
finally:
|
||||||
|
# 记录总执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def __check(
|
||||||
|
cls,
|
||||||
|
limit_model: Limit | None,
|
||||||
|
user_id: str,
|
||||||
|
group_id: str | None,
|
||||||
|
channel_id: str | None,
|
||||||
|
):
|
||||||
|
"""检测限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
limit_model: Limit
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
channel_id: 频道id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: IgnoredException
|
||||||
|
"""
|
||||||
|
if not limit_model:
|
||||||
|
return
|
||||||
|
limit = limit_model.limit
|
||||||
|
limiter = limit_model.limiter
|
||||||
|
is_limit = (
|
||||||
|
LimitWatchType.ALL
|
||||||
|
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||||
|
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||||
|
)
|
||||||
|
key_type = user_id
|
||||||
|
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||||
|
key_type = channel_id or group_id
|
||||||
|
if is_limit and not limiter.check(key_type):
|
||||||
|
if limit.result:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
MessageUtils.build_message(limit.result).send(),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{limit.module}({limit.limit_type}) 正在限制中..."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
if isinstance(limiter, FreqLimiter):
|
||||||
|
limiter.start_cd(key_type)
|
||||||
|
if isinstance(limiter, UserBlockLimiter):
|
||||||
|
limiter.set_true(key_type)
|
||||||
|
if isinstance(limiter, CountLimiter):
|
||||||
|
limiter.increase(key_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||||
|
"""插件限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
entity = get_entity_ids(session)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
LimitManager.check(
|
||||||
|
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS * 2, # 给予更长的超时时间
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
|
||||||
|
# 超时时不抛出异常,允许继续执行
|
||||||
242
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
242
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.data_access import DataAccess
|
||||||
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
|
from zhenxun.utils.enum import BlockType
|
||||||
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .exception import IsSuperuserException, SkipPluginException
|
||||||
|
from .utils import freq, is_poke, send_message
|
||||||
|
|
||||||
|
|
||||||
|
class GroupCheck:
|
||||||
|
def __init__(
|
||||||
|
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
||||||
|
) -> None:
|
||||||
|
self.group_id = group_id
|
||||||
|
self.session = session
|
||||||
|
self.is_poke = is_poke
|
||||||
|
self.plugin = plugin
|
||||||
|
self.group_dao = DataAccess(GroupConsole)
|
||||||
|
self.group_data = None
|
||||||
|
|
||||||
|
async def check(self):
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
# 只查询一次数据库,使用 DataAccess 的缓存机制
|
||||||
|
try:
|
||||||
|
self.group_data = await asyncio.wait_for(
|
||||||
|
self.group_dao.safe_get_or_none(
|
||||||
|
group_id=self.group_id, channel_id__isnull=True
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||||
|
return # 超时时不阻塞,继续执行
|
||||||
|
|
||||||
|
# 检查超级用户禁用
|
||||||
|
if (
|
||||||
|
self.group_data
|
||||||
|
and CommonUtils.format(self.plugin.module)
|
||||||
|
in self.group_data.superuser_block_plugin
|
||||||
|
):
|
||||||
|
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
send_message(
|
||||||
|
self.session,
|
||||||
|
"超级管理员禁用了该群此功能...",
|
||||||
|
self.group_id,
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{self.plugin.name}({self.plugin.module})"
|
||||||
|
f" 超级管理员禁用了该群此功能..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查普通禁用
|
||||||
|
if (
|
||||||
|
self.group_data
|
||||||
|
and CommonUtils.format(self.plugin.module)
|
||||||
|
in self.group_data.block_plugin
|
||||||
|
):
|
||||||
|
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
send_message(
|
||||||
|
self.session, "该群未开启此功能...", self.group_id
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{self.plugin.name}({self.plugin.module}) 未开启此功能..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查全局禁用
|
||||||
|
if self.plugin.block_type == BlockType.GROUP:
|
||||||
|
if freq.is_send_limit_message(self.plugin, self.group_id, self.is_poke):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
send_message(
|
||||||
|
self.session, "该功能在群组中已被禁用...", self.group_id
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"发送消息超时: {self.group_id}", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{self.plugin.name}({self.plugin.module})该插件在群组中已被禁用..."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"GroupCheck.check 耗时: {elapsed:.3f}s, 群组: {self.group_id}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginCheck:
|
||||||
|
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
||||||
|
self.session = session
|
||||||
|
self.is_poke = is_poke
|
||||||
|
self.group_id = group_id
|
||||||
|
self.group_dao = DataAccess(GroupConsole)
|
||||||
|
self.group_data = None
|
||||||
|
|
||||||
|
async def check_user(self, plugin: PluginInfo):
|
||||||
|
"""全局私聊禁用检测
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
if plugin.block_type == BlockType.PRIVATE:
|
||||||
|
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
send_message(self.session, "该功能在私聊中已被禁用..."),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("发送消息超时", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def check_global(self, plugin: PluginInfo):
|
||||||
|
"""全局状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||||
|
return
|
||||||
|
"""全局状态"""
|
||||||
|
if self.group_id:
|
||||||
|
# 使用 DataAccess 的缓存机制
|
||||||
|
try:
|
||||||
|
self.group_data = await asyncio.wait_for(
|
||||||
|
self.group_dao.safe_get_or_none(
|
||||||
|
group_id=self.group_id, channel_id__isnull=True
|
||||||
|
),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"查询群组数据超时: {self.group_id}", LOGGER_COMMAND)
|
||||||
|
return # 超时时不阻塞,继续执行
|
||||||
|
|
||||||
|
if self.group_data and self.group_data.is_super:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
|
||||||
|
sid = self.group_id or self.session.user.id
|
||||||
|
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
send_message(self.session, "全局未开启此功能...", sid),
|
||||||
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"发送消息超时: {sid}", LOGGER_COMMAND)
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"{plugin.name}({plugin.module}) 全局未开启此功能..."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# 记录执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"PluginCheck.check_global 耗时: {elapsed:.3f}s", LOGGER_COMMAND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||||
|
"""插件状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
event: Event
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
entity = get_entity_ids(session)
|
||||||
|
is_poke_event = is_poke(event)
|
||||||
|
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
||||||
|
|
||||||
|
if entity.group_id:
|
||||||
|
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
group_check.check(), timeout=DB_TIMEOUT_SECONDS * 2
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"群组检查超时: {entity.group_id}", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
user_check.check_user(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("用户检查超时", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
user_check.check_global(plugin), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("全局检查超时", LOGGER_COMMAND)
|
||||||
|
# 超时时不阻塞,继续执行
|
||||||
|
finally:
|
||||||
|
# 记录总执行时间
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
||||||
|
logger.warning(
|
||||||
|
f"auth_plugin 总耗时: {elapsed:.3f}s, 模块: {plugin.module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
)
|
||||||
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import nonebot
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
Config.add_plugin_config(
|
||||||
|
"hook",
|
||||||
|
"FILTER_BOT",
|
||||||
|
True,
|
||||||
|
help="过滤当前连接bot(防止bot互相调用)",
|
||||||
|
default_value=True,
|
||||||
|
type=bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bot_filter(session: Uninfo):
|
||||||
|
"""过滤bot调用bot
|
||||||
|
|
||||||
|
参数:
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: bot互相调用
|
||||||
|
"""
|
||||||
|
if not Config.get_config("hook", "FILTER_BOT"):
|
||||||
|
return
|
||||||
|
bot_ids = list(nonebot.get_bots().keys())
|
||||||
|
if session.user.id == session.self_id:
|
||||||
|
return
|
||||||
|
if session.user.id in bot_ids:
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
|
||||||
|
)
|
||||||
16
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
16
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from enum import StrEnum
|
||||||
|
else:
|
||||||
|
from strenum import StrEnum
|
||||||
|
|
||||||
|
LOGGER_COMMAND = "AuthChecker"
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchEnum(StrEnum):
|
||||||
|
ENABLE = "醒来"
|
||||||
|
DISABLE = "休息吧"
|
||||||
|
|
||||||
|
|
||||||
|
WARNING_THRESHOLD = 0.5 # 警告阈值(秒)
|
||||||
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
class IsSuperuserException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SkipPluginException(Exception):
|
||||||
|
def __init__(self, info: str, *args: object) -> None:
|
||||||
|
super().__init__(*args)
|
||||||
|
self.info = info
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.info
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return self.info
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionExemption(Exception):
|
||||||
|
def __init__(self, info: str, *args: object) -> None:
|
||||||
|
super().__init__(*args)
|
||||||
|
self.info = info
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.info
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return self.info
|
||||||
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import contextlib
|
||||||
|
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import PluginType
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
from zhenxun.utils.utils import FreqLimiter
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND
|
||||||
|
|
||||||
|
base_config = Config.get("hook")
|
||||||
|
|
||||||
|
|
||||||
|
def is_poke(event: Event) -> bool:
|
||||||
|
"""判断是否为poke类型
|
||||||
|
|
||||||
|
参数:
|
||||||
|
event: Event
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否为poke类型
|
||||||
|
"""
|
||||||
|
with contextlib.suppress(ImportError):
|
||||||
|
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||||
|
|
||||||
|
return isinstance(event, PokeNotifyEvent)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
session: Uninfo, message: list | str, check_tag: str | None = None
|
||||||
|
):
|
||||||
|
"""发送消息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
session: Uninfo
|
||||||
|
message: 消息
|
||||||
|
check_tag: cd flag
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not check_tag:
|
||||||
|
await MessageUtils.build_message(message).send(reply_to=True)
|
||||||
|
elif freq._flmt.check(check_tag):
|
||||||
|
freq._flmt.start_cd(check_tag)
|
||||||
|
await MessageUtils.build_message(message).send(reply_to=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"发送消息失败",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FreqUtils:
|
||||||
|
def __init__(self):
|
||||||
|
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||||
|
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||||
|
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||||
|
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||||
|
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||||
|
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||||
|
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||||
|
|
||||||
|
def is_send_limit_message(
|
||||||
|
self, plugin: PluginInfo, sid: str, is_poke: bool
|
||||||
|
) -> bool:
|
||||||
|
"""是否发送提示消息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
sid: 检测键
|
||||||
|
is_poke: 是否是戳一戳
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否发送提示消息
|
||||||
|
"""
|
||||||
|
if is_poke:
|
||||||
|
return False
|
||||||
|
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||||
|
return False
|
||||||
|
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||||
|
return False
|
||||||
|
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
||||||
|
|
||||||
|
|
||||||
|
freq = FreqUtils()
|
||||||
375
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
375
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nonebot.adapters import Bot, Event
|
||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot_plugin_alconna import UniMsg
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
from tortoise.exceptions import IntegrityError
|
||||||
|
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.user_console import UserConsole
|
||||||
|
from zhenxun.services.data_access import DataAccess
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import GoldHandle, PluginType
|
||||||
|
from zhenxun.utils.exception import InsufficientGold
|
||||||
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
|
from .auth.auth_admin import auth_admin
|
||||||
|
from .auth.auth_ban import auth_ban
|
||||||
|
from .auth.auth_bot import auth_bot
|
||||||
|
from .auth.auth_cost import auth_cost
|
||||||
|
from .auth.auth_group import auth_group
|
||||||
|
from .auth.auth_limit import LimitManager, auth_limit
|
||||||
|
from .auth.auth_plugin import auth_plugin
|
||||||
|
from .auth.bot_filter import bot_filter
|
||||||
|
from .auth.config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
|
from .auth.exception import (
|
||||||
|
IsSuperuserException,
|
||||||
|
PermissionExemption,
|
||||||
|
SkipPluginException,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 超时设置(秒)
|
||||||
|
TIMEOUT_SECONDS = 5.0
|
||||||
|
# 熔断计数器
|
||||||
|
CIRCUIT_BREAKERS = {
|
||||||
|
"auth_ban": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||||
|
"auth_bot": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||||
|
"auth_group": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||||
|
"auth_admin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||||
|
"auth_plugin": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||||
|
"auth_limit": {"failures": 0, "threshold": 3, "active": False, "reset_time": 0},
|
||||||
|
}
|
||||||
|
# 熔断重置时间(秒)
|
||||||
|
CIRCUIT_RESET_TIME = 300 # 5分钟
|
||||||
|
|
||||||
|
|
||||||
|
# 超时装饰器
|
||||||
|
async def with_timeout(coro, timeout=TIMEOUT_SECONDS, name=None):
|
||||||
|
"""带超时控制的协程执行
|
||||||
|
|
||||||
|
参数:
|
||||||
|
coro: 要执行的协程
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
name: 操作名称,用于日志记录
|
||||||
|
|
||||||
|
返回:
|
||||||
|
协程的返回值,或者在超时时抛出 TimeoutError
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(coro, timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if name:
|
||||||
|
logger.error(f"{name} 操作超时 (>{timeout}s)", LOGGER_COMMAND)
|
||||||
|
# 更新熔断计数器
|
||||||
|
if name in CIRCUIT_BREAKERS:
|
||||||
|
CIRCUIT_BREAKERS[name]["failures"] += 1
|
||||||
|
if (
|
||||||
|
CIRCUIT_BREAKERS[name]["failures"]
|
||||||
|
>= CIRCUIT_BREAKERS[name]["threshold"]
|
||||||
|
and not CIRCUIT_BREAKERS[name]["active"]
|
||||||
|
):
|
||||||
|
CIRCUIT_BREAKERS[name]["active"] = True
|
||||||
|
CIRCUIT_BREAKERS[name]["reset_time"] = (
|
||||||
|
time.time() + CIRCUIT_RESET_TIME
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"{name} 熔断器已激活,将在 {CIRCUIT_RESET_TIME} 秒后重置",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 检查熔断状态
|
||||||
|
def check_circuit_breaker(name):
|
||||||
|
"""检查熔断器状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 操作名称
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否已熔断
|
||||||
|
"""
|
||||||
|
if name not in CIRCUIT_BREAKERS:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否需要重置熔断器
|
||||||
|
if (
|
||||||
|
CIRCUIT_BREAKERS[name]["active"]
|
||||||
|
and time.time() > CIRCUIT_BREAKERS[name]["reset_time"]
|
||||||
|
):
|
||||||
|
CIRCUIT_BREAKERS[name]["active"] = False
|
||||||
|
CIRCUIT_BREAKERS[name]["failures"] = 0
|
||||||
|
logger.info(f"{name} 熔断器已重置", LOGGER_COMMAND)
|
||||||
|
|
||||||
|
return CIRCUIT_BREAKERS[name]["active"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_plugin_and_user(
|
||||||
|
module: str, user_id: str
|
||||||
|
) -> tuple[PluginInfo, UserConsole]:
|
||||||
|
"""获取用户数据和插件信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
user_id: 用户id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
PermissionExemption: 插件数据不存在
|
||||||
|
PermissionExemption: 插件类型为HIDDEN
|
||||||
|
PermissionExemption: 重复创建用户
|
||||||
|
PermissionExemption: 用户数据不存在
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
||||||
|
"""
|
||||||
|
user_dao = DataAccess(UserConsole)
|
||||||
|
plugin_dao = DataAccess(PluginInfo)
|
||||||
|
|
||||||
|
# 并行查询插件和用户数据
|
||||||
|
plugin_task = plugin_dao.safe_get_or_none(module=module)
|
||||||
|
user_task = user_dao.safe_get_or_none(user_id=user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
plugin, user = await with_timeout(
|
||||||
|
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 如果并行查询超时,尝试串行查询
|
||||||
|
logger.warning("并行查询超时,尝试串行查询", LOGGER_COMMAND)
|
||||||
|
plugin = await with_timeout(
|
||||||
|
plugin_dao.safe_get_or_none(module=module), name="get_plugin"
|
||||||
|
)
|
||||||
|
user = await with_timeout(
|
||||||
|
user_dao.safe_get_or_none(user_id=user_id), name="get_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not plugin:
|
||||||
|
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
|
||||||
|
if plugin.plugin_type == PluginType.HIDDEN:
|
||||||
|
raise PermissionExemption(
|
||||||
|
f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..."
|
||||||
|
)
|
||||||
|
user = None
|
||||||
|
try:
|
||||||
|
user = await user_dao.safe_get_or_none(user_id=user_id)
|
||||||
|
except IntegrityError as e:
|
||||||
|
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
|
||||||
|
if not user:
|
||||||
|
raise PermissionExemption("用户数据不存在,已跳过权限检查...")
|
||||||
|
return plugin, user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_plugin_cost(
|
||||||
|
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
|
||||||
|
) -> int:
|
||||||
|
"""获取插件费用
|
||||||
|
|
||||||
|
参数:
|
||||||
|
bot: Bot
|
||||||
|
user: 用户数据
|
||||||
|
plugin: 插件数据
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IsSuperuserException: 超级用户
|
||||||
|
IsSuperuserException: 超级用户
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 调用插件金币费用
|
||||||
|
"""
|
||||||
|
cost_gold = await with_timeout(auth_cost(user, plugin, session), name="auth_cost")
|
||||||
|
if session.user.id in bot.config.superusers:
|
||||||
|
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
if not plugin.limit_superuser:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
return cost_gold
|
||||||
|
|
||||||
|
|
||||||
|
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
|
||||||
|
"""扣除用户金币
|
||||||
|
|
||||||
|
参数:
|
||||||
|
user_id: 用户id
|
||||||
|
module: 插件模块名称
|
||||||
|
cost_gold: 消耗金币
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
user_dao = DataAccess(UserConsole)
|
||||||
|
try:
|
||||||
|
await with_timeout(
|
||||||
|
UserConsole.reduce_gold(
|
||||||
|
user_id,
|
||||||
|
cost_gold,
|
||||||
|
GoldHandle.PLUGIN,
|
||||||
|
module,
|
||||||
|
PlatformUtils.get_platform(session),
|
||||||
|
),
|
||||||
|
name="reduce_gold",
|
||||||
|
)
|
||||||
|
except InsufficientGold:
|
||||||
|
if u := await UserConsole.get_user(user_id):
|
||||||
|
u.gold = 0
|
||||||
|
await u.save(update_fields=["gold"])
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"扣除金币超时,用户: {user_id}, 金币: {cost_gold}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 清除缓存,使下次查询时从数据库获取最新数据
|
||||||
|
await user_dao.clear_cache(user_id=user_id)
|
||||||
|
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数,用于记录每个 hook 的执行时间
|
||||||
|
async def time_hook(coro, name, time_dict):
|
||||||
|
start = time.time()
|
||||||
|
try:
|
||||||
|
# 检查熔断状态
|
||||||
|
if check_circuit_breaker(name):
|
||||||
|
logger.info(f"{name} 熔断器激活中,跳过执行", LOGGER_COMMAND)
|
||||||
|
time_dict[name] = "熔断跳过"
|
||||||
|
return
|
||||||
|
|
||||||
|
# 添加超时控制
|
||||||
|
return await with_timeout(coro, name=name)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
time_dict[name] = f"超时 (>{TIMEOUT_SECONDS}s)"
|
||||||
|
finally:
|
||||||
|
if name not in time_dict:
|
||||||
|
time_dict[name] = f"{time.time() - start:.3f}s"
|
||||||
|
|
||||||
|
|
||||||
|
async def auth(
|
||||||
|
matcher: Matcher,
|
||||||
|
event: Event,
|
||||||
|
bot: Bot,
|
||||||
|
session: Uninfo,
|
||||||
|
message: UniMsg,
|
||||||
|
):
|
||||||
|
"""权限检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
matcher: matcher
|
||||||
|
event: Event
|
||||||
|
bot: bot
|
||||||
|
session: Uninfo
|
||||||
|
message: UniMsg
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
cost_gold = 0
|
||||||
|
ignore_flag = False
|
||||||
|
entity = get_entity_ids(session)
|
||||||
|
module = matcher.plugin_name or ""
|
||||||
|
|
||||||
|
# 用于记录各个 hook 的执行时间
|
||||||
|
hook_times = {}
|
||||||
|
hooks_time = 0 # 初始化 hooks_time 变量
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not module:
|
||||||
|
raise PermissionExemption("Matcher插件名称不存在...")
|
||||||
|
|
||||||
|
# 获取插件和用户数据
|
||||||
|
plugin_user_start = time.time()
|
||||||
|
try:
|
||||||
|
plugin, user = await with_timeout(
|
||||||
|
get_plugin_and_user(module, entity.user_id), name="get_plugin_and_user"
|
||||||
|
)
|
||||||
|
hook_times["get_plugin_user"] = f"{time.time() - plugin_user_start:.3f}s"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"获取插件和用户数据超时,模块: {module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise PermissionExemption("获取插件和用户数据超时,请稍后再试...")
|
||||||
|
|
||||||
|
# 获取插件费用
|
||||||
|
cost_start = time.time()
|
||||||
|
try:
|
||||||
|
cost_gold = await with_timeout(
|
||||||
|
get_plugin_cost(bot, user, plugin, session), name="get_plugin_cost"
|
||||||
|
)
|
||||||
|
hook_times["cost_gold"] = f"{time.time() - cost_start:.3f}s"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"获取插件费用超时,模块: {module}", LOGGER_COMMAND, session=session
|
||||||
|
)
|
||||||
|
# 继续执行,不阻止权限检查
|
||||||
|
|
||||||
|
# 执行 bot_filter
|
||||||
|
bot_filter(session)
|
||||||
|
|
||||||
|
# 并行执行所有 hook 检查,并记录执行时间
|
||||||
|
hooks_start = time.time()
|
||||||
|
|
||||||
|
# 创建所有 hook 任务
|
||||||
|
hook_tasks = [
|
||||||
|
time_hook(auth_ban(matcher, bot, session), "auth_ban", hook_times),
|
||||||
|
time_hook(auth_bot(plugin, bot.self_id), "auth_bot", hook_times),
|
||||||
|
time_hook(auth_group(plugin, entity, message), "auth_group", hook_times),
|
||||||
|
time_hook(auth_admin(plugin, session), "auth_admin", hook_times),
|
||||||
|
time_hook(auth_plugin(plugin, session, event), "auth_plugin", hook_times),
|
||||||
|
time_hook(auth_limit(plugin, session), "auth_limit", hook_times),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 使用 gather 并行执行所有 hook,但添加总体超时控制
|
||||||
|
try:
|
||||||
|
await with_timeout(
|
||||||
|
asyncio.gather(*hook_tasks),
|
||||||
|
timeout=TIMEOUT_SECONDS * 2, # 给总体执行更多时间
|
||||||
|
name="auth_hooks_gather",
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"权限检查 hooks 总体执行超时,模块: {module}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
# 不抛出异常,允许继续执行
|
||||||
|
|
||||||
|
hooks_time = time.time() - hooks_start
|
||||||
|
|
||||||
|
except SkipPluginException as e:
|
||||||
|
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
|
||||||
|
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||||
|
ignore_flag = True
|
||||||
|
except IsSuperuserException:
|
||||||
|
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||||
|
except PermissionExemption as e:
|
||||||
|
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||||
|
|
||||||
|
# 扣除金币
|
||||||
|
if not ignore_flag and cost_gold > 0:
|
||||||
|
gold_start = time.time()
|
||||||
|
try:
|
||||||
|
await with_timeout(
|
||||||
|
reduce_gold(entity.user_id, module, cost_gold, session),
|
||||||
|
name="reduce_gold",
|
||||||
|
)
|
||||||
|
hook_times["reduce_gold"] = f"{time.time() - gold_start:.3f}s"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"扣除金币超时,模块: {module}", LOGGER_COMMAND, session=session
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录总执行时间
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
if total_time > WARNING_THRESHOLD: # 如果总时间超过500ms,记录详细信息
|
||||||
|
logger.warning(
|
||||||
|
f"权限检查耗时过长: {total_time:.3f}s, 模块: {module}, "
|
||||||
|
f"hooks时间: {hooks_time:.3f}s, "
|
||||||
|
f"详情: {hook_times}",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ignore_flag:
|
||||||
|
raise IgnoredException("权限检测 ignore")
|
||||||
@ -1,41 +1,43 @@
|
|||||||
from nonebot.adapters.onebot.v11 import Bot, Event
|
import time
|
||||||
|
|
||||||
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.message import run_postprocessor, run_preprocessor
|
from nonebot.message import run_postprocessor, run_preprocessor
|
||||||
from nonebot_plugin_alconna import UniMsg
|
from nonebot_plugin_alconna import UniMsg
|
||||||
from nonebot_plugin_session import EventSession
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from ._auth_checker import LimitManage, checker
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .auth.config import LOGGER_COMMAND
|
||||||
|
from .auth_checker import LimitManager, auth
|
||||||
|
|
||||||
|
|
||||||
# # 权限检测
|
# # 权限检测
|
||||||
@run_preprocessor
|
@run_preprocessor
|
||||||
async def _(
|
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
||||||
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg
|
start_time = time.time()
|
||||||
):
|
await auth(
|
||||||
await checker.auth(
|
|
||||||
matcher,
|
matcher,
|
||||||
event,
|
event,
|
||||||
bot,
|
bot,
|
||||||
session,
|
session,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
logger.debug(f"权限检测耗时:{time.time() - start_time}秒", LOGGER_COMMAND)
|
||||||
|
|
||||||
|
|
||||||
# 解除命令block阻塞
|
# 解除命令block阻塞
|
||||||
@run_postprocessor
|
@run_postprocessor
|
||||||
async def _(
|
async def _(matcher: Matcher, session: Uninfo):
|
||||||
matcher: Matcher,
|
user_id = session.user.id
|
||||||
exception: Exception | None,
|
group_id = None
|
||||||
bot: Bot,
|
channel_id = None
|
||||||
event: Event,
|
if session.group:
|
||||||
session: EventSession,
|
if session.group.parent:
|
||||||
):
|
group_id = session.group.parent.id
|
||||||
user_id = session.id1
|
channel_id = session.group.id
|
||||||
group_id = session.id3
|
else:
|
||||||
channel_id = session.id2
|
group_id = session.group.id
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if user_id and matcher.plugin:
|
if user_id and matcher.plugin:
|
||||||
module = matcher.plugin.name
|
module = matcher.plugin.name
|
||||||
LimitManage.unblock(module, user_id, group_id, channel_id)
|
LimitManager.unblock(module, user_id, group_id, channel_id)
|
||||||
|
|||||||
@ -1,84 +0,0 @@
|
|||||||
from nonebot.adapters import Bot, Event
|
|
||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot.matcher import Matcher
|
|
||||||
from nonebot.message import run_preprocessor
|
|
||||||
from nonebot.typing import T_State
|
|
||||||
from nonebot_plugin_alconna import At
|
|
||||||
from nonebot_plugin_session import EventSession
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
|
||||||
from zhenxun.models.ban_console import BanConsole
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.enum import PluginType
|
|
||||||
from zhenxun.utils.message import MessageUtils
|
|
||||||
from zhenxun.utils.utils import FreqLimiter
|
|
||||||
|
|
||||||
Config.add_plugin_config(
|
|
||||||
"hook",
|
|
||||||
"BAN_RESULT",
|
|
||||||
"才不会给你发消息.",
|
|
||||||
help="对被ban用户发送的消息",
|
|
||||||
)
|
|
||||||
|
|
||||||
_flmt = FreqLimiter(300)
|
|
||||||
|
|
||||||
|
|
||||||
# 检查是否被ban
|
|
||||||
@run_preprocessor
|
|
||||||
async def _(
|
|
||||||
matcher: Matcher, bot: Bot, event: Event, state: T_State, session: EventSession
|
|
||||||
):
|
|
||||||
extra = {}
|
|
||||||
if plugin := matcher.plugin:
|
|
||||||
if metadata := plugin.metadata:
|
|
||||||
extra = metadata.extra
|
|
||||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
|
||||||
return
|
|
||||||
user_id = session.id1
|
|
||||||
group_id = session.id3 or session.id2
|
|
||||||
if group_id:
|
|
||||||
if user_id in bot.config.superusers:
|
|
||||||
return
|
|
||||||
if await BanConsole.is_ban(None, group_id):
|
|
||||||
logger.debug("群组处于黑名单中...", "ban_hook")
|
|
||||||
raise IgnoredException("群组处于黑名单中...")
|
|
||||||
if g := await GroupConsole.get_group(group_id):
|
|
||||||
if g.level < 0:
|
|
||||||
logger.debug("群黑名单, 群权限-1...", "ban_hook")
|
|
||||||
raise IgnoredException("群黑名单, 群权限-1..")
|
|
||||||
if user_id:
|
|
||||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
|
||||||
if user_id in bot.config.superusers:
|
|
||||||
return
|
|
||||||
if await BanConsole.is_ban(user_id, group_id):
|
|
||||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
|
||||||
if time == -1:
|
|
||||||
time_str = "∞"
|
|
||||||
else:
|
|
||||||
time = abs(int(time))
|
|
||||||
if time < 60:
|
|
||||||
time_str = f"{time!s} 秒"
|
|
||||||
else:
|
|
||||||
minute = int(time / 60)
|
|
||||||
if minute > 60:
|
|
||||||
hours = minute // 60
|
|
||||||
minute %= 60
|
|
||||||
time_str = f"{hours} 小时 {minute}分钟"
|
|
||||||
else:
|
|
||||||
time_str = f"{minute} 分钟"
|
|
||||||
if (
|
|
||||||
not extra.get("ignore_prompt")
|
|
||||||
and time != -1
|
|
||||||
and ban_result
|
|
||||||
and _flmt.check(user_id)
|
|
||||||
):
|
|
||||||
_flmt.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
[
|
|
||||||
At(flag="user", target=user_id),
|
|
||||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
|
||||||
]
|
|
||||||
).send()
|
|
||||||
logger.debug("用户处于黑名单中...", "ban_hook")
|
|
||||||
raise IgnoredException("用户处于黑名单中...")
|
|
||||||
@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType
|
|||||||
from zhenxun.utils.manager.message_manager import MessageManager
|
from zhenxun.utils.manager.message_manager import MessageManager
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
|
LOG_COMMAND = "MessageHook"
|
||||||
|
|
||||||
|
|
||||||
def replace_message(message: Message) -> str:
|
def replace_message(message: Message) -> str:
|
||||||
"""将消息中的at、image、record、face替换为字符串
|
"""将消息中的at、image、record、face替换为字符串
|
||||||
@ -54,11 +56,11 @@ async def handle_api_result(
|
|||||||
if user_id and message_id:
|
if user_id and message_id:
|
||||||
MessageManager.add(str(user_id), str(message_id))
|
MessageManager.add(str(user_id), str(message_id))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook"
|
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", LOG_COMMAND
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
|
f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e
|
||||||
)
|
)
|
||||||
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
||||||
return
|
return
|
||||||
@ -80,6 +82,6 @@ async def handle_api_result(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"消息发送记录发生错误...data: {data}, result: {result}",
|
f"消息发送记录发生错误...data: {data}, result: {result}",
|
||||||
"msg_hook",
|
LOG_COMMAND,
|
||||||
e=e,
|
e=e,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,15 +4,27 @@ import nonebot
|
|||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.services.cache import CacheException
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .__init_cache import register_cache_types
|
||||||
|
except CacheException as e:
|
||||||
|
raise SystemError(f"ERROR:{e}")
|
||||||
|
|
||||||
driver = nonebot.get_driver()
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
|
@PriorityLifecycle.on_startup(priority=5)
|
||||||
|
async def _():
|
||||||
|
register_cache_types()
|
||||||
|
logger.info("缓存类型注册完成")
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_connect
|
@driver.on_bot_connect
|
||||||
async def _(bot: Bot):
|
async def _(bot: Bot):
|
||||||
"""将bot已存在的群组添加群认证
|
"""将bot已存在的群组添加群认证
|
||||||
|
|||||||
35
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
35
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
缓存初始化模块
|
||||||
|
|
||||||
|
负责注册各种缓存类型,实现按需缓存机制
|
||||||
|
"""
|
||||||
|
|
||||||
|
from zhenxun.models.ban_console import BanConsole
|
||||||
|
from zhenxun.models.bot_console import BotConsole
|
||||||
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.models.level_user import LevelUser
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.user_console import UserConsole
|
||||||
|
from zhenxun.services.cache import CacheRegistry, cache_config
|
||||||
|
from zhenxun.services.cache.config import CacheMode
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
|
||||||
|
# 注册缓存类型
|
||||||
|
def register_cache_types():
|
||||||
|
"""注册所有缓存类型"""
|
||||||
|
CacheRegistry.register(CacheType.PLUGINS, PluginInfo)
|
||||||
|
CacheRegistry.register(CacheType.GROUPS, GroupConsole)
|
||||||
|
CacheRegistry.register(CacheType.BOT, BotConsole)
|
||||||
|
CacheRegistry.register(CacheType.USERS, UserConsole)
|
||||||
|
CacheRegistry.register(
|
||||||
|
CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
|
||||||
|
)
|
||||||
|
CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}")
|
||||||
|
|
||||||
|
if cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
logger.info("缓存功能已禁用,将直接从数据库获取数据")
|
||||||
|
else:
|
||||||
|
logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}")
|
||||||
|
logger.info("使用增量缓存模式,数据将按需加载到缓存中")
|
||||||
@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import nonebot
|
import nonebot
|
||||||
from nonebot import get_loaded_plugins
|
from nonebot import get_loaded_plugins
|
||||||
@ -112,24 +114,29 @@ async def _():
|
|||||||
await _handle_setting(plugin, plugin_list, limit_list)
|
await _handle_setting(plugin, plugin_list, limit_list)
|
||||||
create_list = []
|
create_list = []
|
||||||
update_list = []
|
update_list = []
|
||||||
|
update_task_list = []
|
||||||
for plugin in plugin_list:
|
for plugin in plugin_list:
|
||||||
if plugin.module_path not in module2id:
|
if plugin.module_path not in module2id:
|
||||||
create_list.append(plugin)
|
create_list.append(plugin)
|
||||||
else:
|
else:
|
||||||
plugin.id = module2id[plugin.module_path]
|
plugin.id = module2id[plugin.module_path]
|
||||||
await plugin.save(
|
update_task_list.append(
|
||||||
update_fields=[
|
plugin.save(
|
||||||
"name",
|
update_fields=[
|
||||||
"author",
|
"name",
|
||||||
"version",
|
"author",
|
||||||
"admin_level",
|
"version",
|
||||||
"plugin_type",
|
"admin_level",
|
||||||
"is_show",
|
"plugin_type",
|
||||||
]
|
"is_show",
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
update_list.append(plugin)
|
update_list.append(plugin)
|
||||||
if create_list:
|
if create_list:
|
||||||
await PluginInfo.bulk_create(create_list, 10)
|
await PluginInfo.bulk_create(create_list, 10)
|
||||||
|
if update_task_list:
|
||||||
|
await asyncio.gather(*update_task_list)
|
||||||
# if update_list:
|
# if update_list:
|
||||||
# # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
|
# # TODO: 批量更新无法更新plugin_type: tortoise.exceptions.OperationalError:
|
||||||
# column "superuser" does not exist
|
# column "superuser" does not exist
|
||||||
|
|||||||
@ -205,7 +205,7 @@ class Manager:
|
|||||||
self.cd_data: dict[str, PluginCdBlock] = {}
|
self.cd_data: dict[str, PluginCdBlock] = {}
|
||||||
if self.cd_file.exists():
|
if self.cd_file.exists():
|
||||||
with open(self.cd_file, encoding="utf8") as f:
|
with open(self.cd_file, encoding="utf8") as f:
|
||||||
temp = _yaml.load(f)
|
temp = _yaml.load(f) or {}
|
||||||
if "PluginCdLimit" in temp.keys():
|
if "PluginCdLimit" in temp.keys():
|
||||||
for k, v in temp["PluginCdLimit"].items():
|
for k, v in temp["PluginCdLimit"].items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
@ -216,7 +216,7 @@ class Manager:
|
|||||||
self.block_data: dict[str, BaseBlock] = {}
|
self.block_data: dict[str, BaseBlock] = {}
|
||||||
if self.block_file.exists():
|
if self.block_file.exists():
|
||||||
with open(self.block_file, encoding="utf8") as f:
|
with open(self.block_file, encoding="utf8") as f:
|
||||||
temp = _yaml.load(f)
|
temp = _yaml.load(f) or {}
|
||||||
if "PluginBlockLimit" in temp.keys():
|
if "PluginBlockLimit" in temp.keys():
|
||||||
for k, v in temp["PluginBlockLimit"].items():
|
for k, v in temp["PluginBlockLimit"].items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
@ -227,7 +227,7 @@ class Manager:
|
|||||||
self.count_data: dict[str, PluginCountBlock] = {}
|
self.count_data: dict[str, PluginCountBlock] = {}
|
||||||
if self.count_file.exists():
|
if self.count_file.exists():
|
||||||
with open(self.count_file, encoding="utf8") as f:
|
with open(self.count_file, encoding="utf8") as f:
|
||||||
temp = _yaml.load(f)
|
temp = _yaml.load(f) or {}
|
||||||
if "PluginCountLimit" in temp.keys():
|
if "PluginCountLimit" in temp.keys():
|
||||||
for k, v in temp["PluginCountLimit"].items():
|
for k, v in temp["PluginCountLimit"].items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
|
|||||||
@ -55,15 +55,17 @@ class GroupManager:
|
|||||||
if plugin_list := await PluginInfo.filter(default_status=False).all():
|
if plugin_list := await PluginInfo.filter(default_status=False).all():
|
||||||
for plugin in plugin_list:
|
for plugin in plugin_list:
|
||||||
block_plugin += f"<{plugin.module},"
|
block_plugin += f"<{plugin.module},"
|
||||||
group_info = await bot.get_group_info(group_id=group_id, no_cache=True)
|
group_info = await bot.get_group_info(group_id=group_id)
|
||||||
await GroupConsole.create(
|
await GroupConsole.update_or_create(
|
||||||
group_id=group_info["group_id"],
|
group_id=group_info["group_id"],
|
||||||
group_name=group_info["group_name"],
|
defaults={
|
||||||
max_member_count=group_info["max_member_count"],
|
"group_name": group_info["group_name"],
|
||||||
member_count=group_info["member_count"],
|
"max_member_count": group_info["max_member_count"],
|
||||||
group_flag=1,
|
"member_count": group_info["member_count"],
|
||||||
block_plugin=block_plugin,
|
"group_flag": 1,
|
||||||
platform="qq",
|
"block_plugin": block_plugin,
|
||||||
|
"platform": "qq",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -145,7 +147,7 @@ class GroupManager:
|
|||||||
e=e,
|
e=e,
|
||||||
)
|
)
|
||||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||||
await GroupConsole.filter(group_id=group_id).delete()
|
# await GroupConsole.filter(group_id=group_id).delete()
|
||||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||||
else:
|
else:
|
||||||
await cls.__handle_add_group(bot, group_id, group)
|
await cls.__handle_add_group(bot, group_id, group)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from nonebot.message import run_preprocessor
|
from nonebot import on_message
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.models.friend_user import FriendUser
|
from zhenxun.models.friend_user import FriendUser
|
||||||
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
|
|||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
|
|
||||||
@run_preprocessor
|
def rule(session: Uninfo) -> bool:
|
||||||
async def do_something(session: Uninfo):
|
return PlatformUtils.is_qbot(session)
|
||||||
|
|
||||||
|
|
||||||
|
_matcher = on_message(priority=999, block=False, rule=rule)
|
||||||
|
|
||||||
|
|
||||||
|
@_matcher.handle()
|
||||||
|
async def _(session: Uninfo):
|
||||||
platform = PlatformUtils.get_platform(session)
|
platform = PlatformUtils.get_platform(session)
|
||||||
if session.group:
|
if session.group:
|
||||||
if not await GroupConsole.exists(group_id=session.group.id):
|
if not await GroupConsole.exists(group_id=session.group.id):
|
||||||
await GroupConsole.create(group_id=session.group.id)
|
await GroupConsole.create(group_id=session.group.id)
|
||||||
logger.info("添加当前群组ID信息" "", session=session)
|
logger.info("添加当前群组ID信息", session=session)
|
||||||
|
await GroupInfoUser.update_or_create(
|
||||||
if not await GroupInfoUser.exists(
|
user_id=session.user.id,
|
||||||
user_id=session.user.id, group_id=session.group.id
|
group_id=session.group.id,
|
||||||
):
|
platform=PlatformUtils.get_platform(session),
|
||||||
await GroupInfoUser.create(
|
)
|
||||||
user_id=session.user.id, group_id=session.group.id, platform=platform
|
|
||||||
)
|
|
||||||
logger.info("添加当前用户群组ID信息", "", session=session)
|
|
||||||
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
||||||
try:
|
await FriendUser.create(
|
||||||
await FriendUser.create(user_id=session.user.id, platform=platform)
|
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||||
logger.info("添加当前好友用户信息", "", session=session)
|
)
|
||||||
except Exception as e:
|
logger.info("添加当前好友用户信息", "", session=session)
|
||||||
logger.error("添加当前好友用户信息失败", session=session, e=e)
|
|
||||||
|
|||||||
@ -1,30 +0,0 @@
|
|||||||
from zhenxun.models.group_console import GroupConsole
|
|
||||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
|
||||||
|
|
||||||
|
|
||||||
@PriorityLifecycle.on_startup(priority=5)
|
|
||||||
async def _():
|
|
||||||
"""开启/禁用插件格式修改"""
|
|
||||||
_, is_create = await GroupConsole.get_or_create(group_id=133133133)
|
|
||||||
"""标记"""
|
|
||||||
if is_create:
|
|
||||||
data_list = []
|
|
||||||
for group in await GroupConsole.all():
|
|
||||||
if group.block_plugin:
|
|
||||||
if modules := group.block_plugin.split(","):
|
|
||||||
block_plugin = "".join(
|
|
||||||
(f"{module}," if module.startswith("<") else f"<{module},")
|
|
||||||
for module in modules
|
|
||||||
if module.strip()
|
|
||||||
)
|
|
||||||
group.block_plugin = block_plugin.replace("<,", "")
|
|
||||||
if group.block_task:
|
|
||||||
if modules := group.block_task.split(","):
|
|
||||||
block_task = "".join(
|
|
||||||
(f"{module}," if module.startswith("<") else f"<{module},")
|
|
||||||
for module in modules
|
|
||||||
if module.strip()
|
|
||||||
)
|
|
||||||
group.block_task = block_task.replace("<,", "")
|
|
||||||
data_list.append(group)
|
|
||||||
await GroupConsole.bulk_update(data_list, ["block_plugin", "block_task"], 10)
|
|
||||||
@ -44,9 +44,7 @@ class StatisticsManage:
|
|||||||
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
||||||
elif group_id:
|
elif group_id:
|
||||||
"""查群组"""
|
"""查群组"""
|
||||||
group = await GroupConsole.get_or_none(
|
group = await GroupConsole.get_group(group_id=group_id)
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
)
|
|
||||||
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
||||||
else:
|
else:
|
||||||
title = "功能调用统计"
|
title = "功能调用统计"
|
||||||
|
|||||||
@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
|
|||||||
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
||||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||||
gid = state["group_id"]
|
gid = state["group_id"]
|
||||||
group = await GroupConsole.get_or_none(group_id=gid)
|
group = await GroupConsole.get_group(group_id=gid)
|
||||||
if not group:
|
if not group:
|
||||||
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
||||||
s = "删除" if arparma.find("delete") else "添加"
|
s = "删除" if arparma.find("delete") else "添加"
|
||||||
@ -177,7 +177,9 @@ async def _(session: EventSession, arparma: Arparma, state: T_State):
|
|||||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||||
gid = state["group_id"]
|
gid = state["group_id"]
|
||||||
await GroupConsole.update_or_create(
|
await GroupConsole.update_or_create(
|
||||||
group_id=gid, defaults={"group_flag": 0 if arparma.find("delete") else 1}
|
group_id=gid,
|
||||||
|
channel_id__isnull=True,
|
||||||
|
defaults={"group_flag": 0 if arparma.find("delete") else 1},
|
||||||
)
|
)
|
||||||
s = "删除" if arparma.find("delete") else "添加"
|
s = "删除" if arparma.find("delete") else "添加"
|
||||||
await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)
|
await MessageUtils.build_message(f"{s}群认证成功!").send(reply_to=True)
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class ApiDataSource:
|
|||||||
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
|
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e)
|
logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e)
|
||||||
select_bot.group_count = 0
|
select_bot.group_count = 0
|
||||||
select_bot.friend_count = 0
|
select_bot.friend_count = 0
|
||||||
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)
|
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)
|
||||||
|
|||||||
@ -250,7 +250,7 @@ class ApiDataSource:
|
|||||||
返回:
|
返回:
|
||||||
GroupDetail | None: 群组详情数据
|
GroupDetail | None: 群组详情数据
|
||||||
"""
|
"""
|
||||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
group = await GroupConsole.get_group(group_id=group_id)
|
||||||
if not group:
|
if not group:
|
||||||
return None
|
return None
|
||||||
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
||||||
|
|||||||
@ -45,6 +45,7 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
|
|||||||
mtime=file_path.stat().st_mtime,
|
mtime=file_path.stat().st_mtime,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
data_list.sort(key=lambda f: f.name)
|
||||||
return Result.ok(data_list)
|
return Result.ok(data_list)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||||
|
|||||||
@ -13,8 +13,8 @@ class BotSetting(BaseModel):
|
|||||||
"""回复时NICKNAME"""
|
"""回复时NICKNAME"""
|
||||||
system_proxy: str | None = None
|
system_proxy: str | None = None
|
||||||
"""系统代理"""
|
"""系统代理"""
|
||||||
db_url: str = ""
|
db_url: str = "sqlite:data/zhenxun.db"
|
||||||
"""数据库链接"""
|
"""数据库链接, 默认值为sqlite:data/zhenxun.db"""
|
||||||
platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
|
platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
|
||||||
"""平台超级用户"""
|
"""平台超级用户"""
|
||||||
qbot_id_data: dict[str, str] = Field(default_factory=dict)
|
qbot_id_data: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|||||||
@ -155,8 +155,6 @@ class AICallableProperties(BaseModel):
|
|||||||
"""参数类型"""
|
"""参数类型"""
|
||||||
description: str
|
description: str
|
||||||
"""参数描述"""
|
"""参数描述"""
|
||||||
enums: list[str] | None = None
|
|
||||||
"""参数枚举"""
|
|
||||||
|
|
||||||
|
|
||||||
class AICallableParam(BaseModel):
|
class AICallableParam(BaseModel):
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
|
from typing import ClassVar
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import CacheType, DbLockType
|
||||||
from zhenxun.utils.exception import UserAndGroupIsNone
|
from zhenxun.utils.exception import UserAndGroupIsNone
|
||||||
|
|
||||||
|
|
||||||
@ -27,6 +29,15 @@ class BanConsole(Model):
|
|||||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||||
table = "ban_console"
|
table = "ban_console"
|
||||||
table_description = "封禁人员/群组数据表"
|
table_description = "封禁人员/群组数据表"
|
||||||
|
unique_together = ("user_id", "group_id")
|
||||||
|
indexes = [("user_id",), ("group_id",)] # noqa: RUF012
|
||||||
|
|
||||||
|
cache_type = CacheType.BAN
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = ("user_id", "group_id")
|
||||||
|
"""缓存键字段"""
|
||||||
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
|
||||||
|
"""开启锁"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
||||||
@ -46,12 +57,12 @@ class BanConsole(Model):
|
|||||||
raise UserAndGroupIsNone()
|
raise UserAndGroupIsNone()
|
||||||
if user_id:
|
if user_id:
|
||||||
return (
|
return (
|
||||||
await cls.get_or_none(user_id=user_id, group_id=group_id)
|
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
|
||||||
if group_id
|
if group_id
|
||||||
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
|
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await cls.get_or_none(user_id="", group_id=group_id)
|
return await cls.safe_get_or_none(user_id="", group_id=group_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def check_ban_level(
|
async def check_ban_level(
|
||||||
@ -167,3 +178,32 @@ class BanConsole(Model):
|
|||||||
await user.delete()
|
await user.delete()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_ban(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
) -> Self | None:
|
||||||
|
"""安全地获取ban记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
id: 记录id
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Self | None: ban记录
|
||||||
|
"""
|
||||||
|
if id is not None:
|
||||||
|
return await cls.safe_get_or_none(id=id)
|
||||||
|
return await cls._get_data(user_id, group_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _run_script(cls):
|
||||||
|
return [
|
||||||
|
"CREATE INDEX idx_ban_console_user_id ON ban_console(user_id);",
|
||||||
|
"CREATE INDEX idx_ban_console_group_id ON ban_console(group_id);",
|
||||||
|
]
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import Literal, overload
|
|||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
|
||||||
class BotConsole(Model):
|
class BotConsole(Model):
|
||||||
@ -29,6 +30,11 @@ class BotConsole(Model):
|
|||||||
table = "bot_console"
|
table = "bot_console"
|
||||||
table_description = "Bot数据表"
|
table_description = "Bot数据表"
|
||||||
|
|
||||||
|
cache_type = CacheType.BOT
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = "bot_id"
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format(name: str) -> str:
|
def format(name: str) -> str:
|
||||||
return f"<{name},"
|
return f"<{name},"
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, cast, overload
|
from typing import Any, ClassVar, cast, overload
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
@ -6,8 +6,9 @@ from tortoise.backends.base.client import BaseDBAsyncClient
|
|||||||
|
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.task_info import TaskInfo
|
from zhenxun.models.task_info import TaskInfo
|
||||||
|
from zhenxun.services.cache import CacheRoot
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||||
|
|
||||||
|
|
||||||
def add_disable_marker(name: str) -> str:
|
def add_disable_marker(name: str) -> str:
|
||||||
@ -86,6 +87,16 @@ class GroupConsole(Model):
|
|||||||
table = "group_console"
|
table = "group_console"
|
||||||
table_description = "群组信息表"
|
table_description = "群组信息表"
|
||||||
unique_together = ("group_id", "channel_id")
|
unique_together = ("group_id", "channel_id")
|
||||||
|
indexes = [ # noqa: RUF012
|
||||||
|
("group_id",)
|
||||||
|
]
|
||||||
|
|
||||||
|
cache_type = CacheType.GROUPS
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = ("group_id", "channel_id")
|
||||||
|
"""缓存键字段"""
|
||||||
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE, DbLockType.UPSERT]
|
||||||
|
"""开启锁"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
|
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
|
||||||
@ -116,6 +127,18 @@ class GroupConsole(Model):
|
|||||||
).values_list("module", flat=True),
|
).values_list("module", flat=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _update_cache(cls, instance):
|
||||||
|
"""更新缓存
|
||||||
|
|
||||||
|
参数:
|
||||||
|
instance: 需要更新缓存的实例
|
||||||
|
"""
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
key = cls.get_cache_key(instance)
|
||||||
|
if key is not None:
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, key)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||||
@ -129,6 +152,9 @@ class GroupConsole(Model):
|
|||||||
if task_modules or plugin_modules:
|
if task_modules or plugin_modules:
|
||||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -180,6 +206,10 @@ class GroupConsole(Model):
|
|||||||
if task_modules or plugin_modules:
|
if task_modules or plugin_modules:
|
||||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
if is_create:
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -202,24 +232,39 @@ class GroupConsole(Model):
|
|||||||
if task_modules or plugin_modules:
|
if task_modules or plugin_modules:
|
||||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_group(
|
async def get_group(
|
||||||
cls, group_id: str, channel_id: str | None = None
|
cls,
|
||||||
|
group_id: str,
|
||||||
|
channel_id: str | None = None,
|
||||||
|
clean_duplicates: bool = True,
|
||||||
) -> Self | None:
|
) -> Self | None:
|
||||||
"""获取群组
|
"""获取群组
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
channel_id: 频道id.
|
channel_id: 频道id
|
||||||
|
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Self: GroupConsole
|
Self: GroupConsole
|
||||||
"""
|
"""
|
||||||
if channel_id:
|
if channel_id:
|
||||||
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
|
return await cls.safe_get_or_none(
|
||||||
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
|
group_id=group_id,
|
||||||
|
channel_id=channel_id,
|
||||||
|
clean_duplicates=clean_duplicates,
|
||||||
|
)
|
||||||
|
return await cls.safe_get_or_none(
|
||||||
|
group_id=group_id,
|
||||||
|
channel_id__isnull=True,
|
||||||
|
clean_duplicates=clean_duplicates,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def is_super_group(cls, group_id: str) -> bool:
|
async def is_super_group(cls, group_id: str) -> bool:
|
||||||
@ -303,6 +348,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_unblock_plugin(
|
async def set_unblock_plugin(
|
||||||
cls,
|
cls,
|
||||||
@ -339,6 +387,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def is_normal_block_plugin(
|
async def is_normal_block_plugin(
|
||||||
cls, group_id: str, module: str, channel_id: str | None = None
|
cls, group_id: str, module: str, channel_id: str | None = None
|
||||||
@ -442,6 +493,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_unblock_task(
|
async def set_unblock_task(
|
||||||
cls,
|
cls,
|
||||||
@ -476,6 +530,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _run_script(cls):
|
def _run_script(cls):
|
||||||
return [
|
return [
|
||||||
@ -483,4 +540,6 @@ class GroupConsole(Model):
|
|||||||
" character varying(255) NOT NULL DEFAULT '';",
|
" character varying(255) NOT NULL DEFAULT '';",
|
||||||
"ALTER TABLE group_console ADD superuser_block_task"
|
"ALTER TABLE group_console ADD superuser_block_task"
|
||||||
" character varying(255) NOT NULL DEFAULT '';",
|
" character varying(255) NOT NULL DEFAULT '';",
|
||||||
|
"CREATE INDEX idx_group_console_group_id ON group_console(group_id);",
|
||||||
|
"CREATE INDEX idx_group_console_group_null_channel ON group_console(group_id) WHERE channel_id IS NULL;", # 单独创建channel为空的索引 # noqa: E501
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
|
||||||
class LevelUser(Model):
|
class LevelUser(Model):
|
||||||
@ -20,6 +21,11 @@ class LevelUser(Model):
|
|||||||
table_description = "用户权限数据库"
|
table_description = "用户权限数据库"
|
||||||
unique_together = ("user_id", "group_id")
|
unique_together = ("user_id", "group_id")
|
||||||
|
|
||||||
|
cache_type = CacheType.LEVEL
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = ("user_id", "group_id")
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
||||||
"""获取用户在群内的等级
|
"""获取用户在群内的等级
|
||||||
@ -53,6 +59,9 @@ class LevelUser(Model):
|
|||||||
level: 权限等级
|
level: 权限等级
|
||||||
group_flag: 是否被自动更新刷新权限 0:是, 1:否.
|
group_flag: 是否被自动更新刷新权限 0:是, 1:否.
|
||||||
"""
|
"""
|
||||||
|
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
|
||||||
|
# 权限相同时跳过
|
||||||
|
return
|
||||||
await cls.update_or_create(
|
await cls.update_or_create(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from tortoise import fields
|
|||||||
|
|
||||||
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.utils.enum import BlockType, PluginType
|
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||||
|
|
||||||
|
|
||||||
class PluginInfo(Model):
|
class PluginInfo(Model):
|
||||||
@ -59,6 +59,11 @@ class PluginInfo(Model):
|
|||||||
table = "plugin_info"
|
table = "plugin_info"
|
||||||
table_description = "插件基本信息"
|
table_description = "插件基本信息"
|
||||||
|
|
||||||
|
cache_type = CacheType.PLUGINS
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = "module"
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_plugin(
|
async def get_plugin(
|
||||||
cls, load_status: bool = True, filter_parent: bool = True, **kwargs
|
cls, load_status: bool = True, filter_parent: bool = True, **kwargs
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from tortoise import fields
|
|||||||
|
|
||||||
from zhenxun.models.goods_info import GoodsInfo
|
from zhenxun.models.goods_info import GoodsInfo
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.utils.enum import GoldHandle
|
from zhenxun.utils.enum import CacheType, GoldHandle
|
||||||
from zhenxun.utils.exception import GoodsNotFound, InsufficientGold
|
from zhenxun.utils.exception import GoodsNotFound, InsufficientGold
|
||||||
|
|
||||||
from .user_gold_log import UserGoldLog
|
from .user_gold_log import UserGoldLog
|
||||||
@ -29,6 +29,12 @@ class UserConsole(Model):
|
|||||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||||
table = "user_console"
|
table = "user_console"
|
||||||
table_description = "用户数据表"
|
table_description = "用户数据表"
|
||||||
|
indexes = [("user_id",), ("uid",)] # noqa: RUF012
|
||||||
|
|
||||||
|
cache_type = CacheType.USERS
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = "user_id"
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
||||||
@ -193,3 +199,10 @@ class UserConsole(Model):
|
|||||||
if goods := await GoodsInfo.get_or_none(goods_name=name):
|
if goods := await GoodsInfo.get_or_none(goods_name=name):
|
||||||
return await cls.use_props(user_id, goods.uuid, num, platform)
|
return await cls.use_props(user_id, goods.uuid, num, platform)
|
||||||
raise GoodsNotFound("未找到商品...")
|
raise GoodsNotFound("未找到商品...")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _run_script(cls):
|
||||||
|
return [
|
||||||
|
"CREATE INDEX idx_user_console_user_id ON user_console(user_id);",
|
||||||
|
"CREATE INDEX idx_user_console_uid ON user_console(uid);",
|
||||||
|
]
|
||||||
|
|||||||
1065
zhenxun/services/cache/__init__.py
vendored
Normal file
1065
zhenxun/services/cache/__init__.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
452
zhenxun/services/cache/cache_containers.py
vendored
Normal file
452
zhenxun/services/cache/cache_containers.py
vendored
Normal file
@ -0,0 +1,452 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import time
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheData(Generic[T]):
|
||||||
|
"""缓存数据类,存储数据和过期时间"""
|
||||||
|
|
||||||
|
value: T
|
||||||
|
expire_time: float = 0 # 0表示永不过期
|
||||||
|
|
||||||
|
|
||||||
|
class CacheDict:
|
||||||
|
"""缓存字典类,提供类似普通字典的接口,数据只存储在内存中"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, expire: int = 0):
|
||||||
|
"""初始化缓存字典
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 字典名称
|
||||||
|
expire: 过期时间(秒),默认为0表示永不过期
|
||||||
|
"""
|
||||||
|
self.name = name.upper()
|
||||||
|
self.expire = expire
|
||||||
|
self._data: dict[str, CacheData[Any]] = {}
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
"""获取字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 字典值
|
||||||
|
"""
|
||||||
|
data = self._data.get(key)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查是否过期
|
||||||
|
if data.expire_time > 0 and data.expire_time < time.time():
|
||||||
|
del self._data[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data.value
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
|
"""设置字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
value: 字典值
|
||||||
|
"""
|
||||||
|
# 计算过期时间
|
||||||
|
expire_time = 0
|
||||||
|
if self.expire > 0:
|
||||||
|
expire_time = time.time() + self.expire
|
||||||
|
|
||||||
|
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""删除字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
"""
|
||||||
|
if key in self._data:
|
||||||
|
del self._data[key]
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""检查键是否存在
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否存在
|
||||||
|
"""
|
||||||
|
if key not in self._data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否过期
|
||||||
|
data = self._data[key]
|
||||||
|
if data.expire_time > 0 and data.expire_time < time.time():
|
||||||
|
del self._data[key]
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""获取字典项,如果不存在返回默认值
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 字典值或默认值
|
||||||
|
"""
|
||||||
|
value = self[key]
|
||||||
|
return default if value is None else value
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any, expire: int | None = None) -> None:
|
||||||
|
"""设置字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
value: 字典值
|
||||||
|
expire: 过期时间(秒),为None时使用默认值
|
||||||
|
"""
|
||||||
|
# 计算过期时间
|
||||||
|
expire_time = 0
|
||||||
|
if expire is not None and expire > 0:
|
||||||
|
expire_time = time.time() + expire
|
||||||
|
elif self.expire > 0:
|
||||||
|
expire_time = time.time() + self.expire
|
||||||
|
|
||||||
|
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||||
|
|
||||||
|
def pop(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""删除并返回字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 字典值或默认值
|
||||||
|
"""
|
||||||
|
if key not in self._data:
|
||||||
|
return default
|
||||||
|
|
||||||
|
data = self._data.pop(key)
|
||||||
|
|
||||||
|
# 检查是否过期
|
||||||
|
if data.expire_time > 0 and data.expire_time < time.time():
|
||||||
|
return default
|
||||||
|
|
||||||
|
return data.value
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空字典"""
|
||||||
|
self._data.clear()
|
||||||
|
|
||||||
|
def keys(self) -> list[str]:
|
||||||
|
"""获取所有键
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[str]: 键列表
|
||||||
|
"""
|
||||||
|
# 清理过期的键
|
||||||
|
self._clean_expired()
|
||||||
|
return list(self._data.keys())
|
||||||
|
|
||||||
|
def values(self) -> list[Any]:
|
||||||
|
"""获取所有值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[Any]: 值列表
|
||||||
|
"""
|
||||||
|
# 清理过期的键
|
||||||
|
self._clean_expired()
|
||||||
|
return [data.value for data in self._data.values()]
|
||||||
|
|
||||||
|
def items(self) -> list[tuple[str, Any]]:
|
||||||
|
"""获取所有键值对
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[tuple[str, Any]]: 键值对列表
|
||||||
|
"""
|
||||||
|
# 清理过期的键
|
||||||
|
self._clean_expired()
|
||||||
|
return [(key, data.value) for key, data in self._data.items()]
|
||||||
|
|
||||||
|
def _clean_expired(self) -> None:
|
||||||
|
"""清理过期的键"""
|
||||||
|
now = time.time()
|
||||||
|
expired_keys = [
|
||||||
|
key
|
||||||
|
for key, data in self._data.items()
|
||||||
|
if data.expire_time > 0 and data.expire_time < now
|
||||||
|
]
|
||||||
|
for key in expired_keys:
|
||||||
|
del self._data[key]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""获取字典长度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 字典长度
|
||||||
|
"""
|
||||||
|
# 清理过期的键
|
||||||
|
self._clean_expired()
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""字符串表示
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 字符串表示
|
||||||
|
"""
|
||||||
|
# 清理过期的键
|
||||||
|
self._clean_expired()
|
||||||
|
return f"CacheDict({self.name}, {len(self._data)} items)"
|
||||||
|
|
||||||
|
|
||||||
|
class CacheList:
|
||||||
|
"""缓存列表类,提供类似普通列表的接口,数据只存储在内存中"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, expire: int = 0):
|
||||||
|
"""初始化缓存列表
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 列表名称
|
||||||
|
expire: 过期时间(秒),默认为0表示永不过期
|
||||||
|
"""
|
||||||
|
self.name = name.upper()
|
||||||
|
self.expire = expire
|
||||||
|
self._data: list[CacheData[Any]] = []
|
||||||
|
self._expire_time = 0
|
||||||
|
|
||||||
|
# 如果设置了过期时间,计算整个列表的过期时间
|
||||||
|
if self.expire > 0:
|
||||||
|
self._expire_time = time.time() + self.expire
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Any:
|
||||||
|
"""获取列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 列表值
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
|
if 0 <= index < len(self._data):
|
||||||
|
return self._data[index].value
|
||||||
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
|
def __setitem__(self, index: int, value: Any) -> None:
|
||||||
|
"""设置列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引
|
||||||
|
value: 列表值
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
# 确保索引有效
|
||||||
|
while len(self._data) <= index:
|
||||||
|
self._data.append(CacheData(value=None))
|
||||||
|
self._data[index] = CacheData(value=value)
|
||||||
|
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
|
def __delitem__(self, index: int) -> None:
|
||||||
|
"""删除列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
|
if 0 <= index < len(self._data):
|
||||||
|
del self._data[index]
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
else:
|
||||||
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""获取列表长度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 列表长度
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
def append(self, value: Any) -> None:
|
||||||
|
"""添加列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 列表值
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
self._data.append(CacheData(value=value))
|
||||||
|
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
|
def extend(self, values: list[Any]) -> None:
|
||||||
|
"""扩展列表
|
||||||
|
|
||||||
|
参数:
|
||||||
|
values: 要添加的值列表
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
self._data.extend([CacheData(value=v) for v in values])
|
||||||
|
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
|
def insert(self, index: int, value: Any) -> None:
|
||||||
|
"""插入列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 插入位置
|
||||||
|
value: 列表值
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
self._data.insert(index, CacheData(value=value))
|
||||||
|
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
|
def pop(self, index: int = -1) -> Any:
|
||||||
|
"""删除并返回列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引,默认为最后一项
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 列表值
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
raise IndexError("从空列表中弹出")
|
||||||
|
|
||||||
|
if not self._data:
|
||||||
|
raise IndexError("从空列表中弹出")
|
||||||
|
|
||||||
|
item = self._data.pop(index)
|
||||||
|
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
|
return item.value
|
||||||
|
|
||||||
|
def remove(self, value: Any) -> None:
|
||||||
|
"""删除第一个匹配的列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 要删除的值
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
raise ValueError(f"{value} 不在列表中")
|
||||||
|
|
||||||
|
# 查找匹配的项
|
||||||
|
for i, item in enumerate(self._data):
|
||||||
|
if item.value == value:
|
||||||
|
del self._data[i]
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(f"{value} 不在列表中")
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空列表"""
|
||||||
|
self._data.clear()
|
||||||
|
# 重置过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
|
def index(self, value: Any, start: int = 0, end: int | None = None) -> int:
|
||||||
|
"""查找值的索引
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 要查找的值
|
||||||
|
start: 起始索引
|
||||||
|
end: 结束索引
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 索引位置
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
raise ValueError(f"{value} 不在列表中")
|
||||||
|
|
||||||
|
end = end if end is not None else len(self._data)
|
||||||
|
|
||||||
|
for i in range(start, min(end, len(self._data))):
|
||||||
|
if self._data[i].value == value:
|
||||||
|
return i
|
||||||
|
|
||||||
|
raise ValueError(f"{value} 不在列表中")
|
||||||
|
|
||||||
|
def count(self, value: Any) -> int:
|
||||||
|
"""计算值出现的次数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 要计数的值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 出现次数
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return sum(1 for item in self._data if item.value == value)
|
||||||
|
|
||||||
|
def _is_expired(self) -> bool:
|
||||||
|
"""检查整个列表是否过期"""
|
||||||
|
return self._expire_time > 0 and self._expire_time < time.time()
|
||||||
|
|
||||||
|
def _update_expire_time(self) -> None:
|
||||||
|
"""更新过期时间"""
|
||||||
|
if self.expire > 0:
|
||||||
|
self._expire_time = time.time() + self.expire
|
||||||
|
else:
|
||||||
|
self._expire_time = 0
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""字符串表示
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 字符串表示
|
||||||
|
"""
|
||||||
|
# 检查整个列表是否过期
|
||||||
|
if self._is_expired():
|
||||||
|
self.clear()
|
||||||
|
return f"CacheList({self.name}, {len(self._data)} items)"
|
||||||
35
zhenxun/services/cache/config.py
vendored
Normal file
35
zhenxun/services/cache/config.py
vendored
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
缓存系统配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 日志标识
|
||||||
|
LOG_COMMAND = "CacheRoot"
|
||||||
|
|
||||||
|
# 默认缓存过期时间(秒)
|
||||||
|
DEFAULT_EXPIRE = 600
|
||||||
|
|
||||||
|
# 缓存键前缀
|
||||||
|
CACHE_KEY_PREFIX = "ZHENXUN"
|
||||||
|
|
||||||
|
# 缓存键分隔符
|
||||||
|
CACHE_KEY_SEPARATOR = ":"
|
||||||
|
|
||||||
|
# 复合键分隔符(用于分隔tuple类型的cache_key_field)
|
||||||
|
COMPOSITE_KEY_SEPARATOR = "_"
|
||||||
|
|
||||||
|
|
||||||
|
# 缓存模式
|
||||||
|
class CacheMode:
|
||||||
|
# 内存缓存 - 使用内存存储缓存数据
|
||||||
|
MEMORY = "MEMORY"
|
||||||
|
# Redis缓存 - 使用Redis服务器存储缓存数据
|
||||||
|
REDIS = "REDIS"
|
||||||
|
# 不使用缓存 - 将使用ttl=0的内存缓存,相当于直接从数据库获取数据
|
||||||
|
NONE = "NONE"
|
||||||
|
|
||||||
|
|
||||||
|
SPECIAL_KEY_FORMATS = {
|
||||||
|
"LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||||
|
"BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||||
|
"GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}",
|
||||||
|
}
|
||||||
653
zhenxun/services/data_access.py
Normal file
653
zhenxun/services/data_access.py
Normal file
@ -0,0 +1,653 @@
|
|||||||
|
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||||
|
|
||||||
|
from zhenxun.services.cache import Cache, CacheRoot, cache_config
|
||||||
|
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR, CacheMode
|
||||||
|
from zhenxun.services.db_context import Model, with_db_timeout
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Model)
|
||||||
|
|
||||||
|
|
||||||
|
class DataAccess(Generic[T]):
|
||||||
|
"""数据访问层,根据配置决定是否使用缓存
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
```python
|
||||||
|
from zhenxun.services import DataAccess
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
|
||||||
|
# 创建数据访问对象
|
||||||
|
plugin_dao = DataAccess(PluginInfo)
|
||||||
|
|
||||||
|
# 获取单个数据
|
||||||
|
plugin = await plugin_dao.get(module="example_module")
|
||||||
|
|
||||||
|
# 获取所有数据
|
||||||
|
all_plugins = await plugin_dao.all()
|
||||||
|
|
||||||
|
# 筛选数据
|
||||||
|
enabled_plugins = await plugin_dao.filter(status=True)
|
||||||
|
|
||||||
|
# 创建数据
|
||||||
|
new_plugin = await plugin_dao.create(
|
||||||
|
module="new_module",
|
||||||
|
name="新插件",
|
||||||
|
status=True
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 添加缓存统计信息
|
||||||
|
_cache_stats: ClassVar[dict] = {}
|
||||||
|
# 空结果标记
|
||||||
|
_NULL_RESULT = "__NULL_RESULT_PLACEHOLDER__"
|
||||||
|
# 默认空结果缓存时间(秒)- 设置为5分钟,避免频繁查询数据库
|
||||||
|
_NULL_RESULT_TTL = 300
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_null_result_ttl(cls, seconds: int) -> None:
|
||||||
|
"""设置空结果缓存时间
|
||||||
|
|
||||||
|
参数:
|
||||||
|
seconds: 缓存时间(秒)
|
||||||
|
"""
|
||||||
|
if seconds < 0:
|
||||||
|
raise ValueError("缓存时间不能为负数")
|
||||||
|
cls._NULL_RESULT_TTL = seconds
|
||||||
|
logger.info(f"已设置DataAccess空结果缓存时间为 {seconds} 秒")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_null_result_ttl(cls) -> int:
|
||||||
|
"""获取空结果缓存时间
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 缓存时间(秒)
|
||||||
|
"""
|
||||||
|
return cls._NULL_RESULT_TTL
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
|
||||||
|
):
|
||||||
|
"""初始化数据访问对象
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model_cls: 模型类
|
||||||
|
key_field: 主键字段
|
||||||
|
"""
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.key_field = getattr(model_cls, "cache_key_field", key_field)
|
||||||
|
self.cache_type = getattr(model_cls, "cache_type", cache_type)
|
||||||
|
|
||||||
|
if not self.cache_type:
|
||||||
|
raise ValueError("缓存类型不能为空")
|
||||||
|
self.cache = Cache(self.cache_type)
|
||||||
|
|
||||||
|
# 初始化缓存统计
|
||||||
|
if self.cache_type not in self._cache_stats:
|
||||||
|
self._cache_stats[self.cache_type] = {
|
||||||
|
"hits": 0, # 缓存命中次数
|
||||||
|
"misses": 0, # 缓存未命中次数
|
||||||
|
"null_hits": 0, # 空结果缓存命中次数
|
||||||
|
"sets": 0, # 缓存设置次数
|
||||||
|
"null_sets": 0, # 空结果缓存设置次数
|
||||||
|
"deletes": 0, # 缓存删除次数
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache_stats(cls):
|
||||||
|
"""获取缓存统计信息"""
|
||||||
|
result = []
|
||||||
|
for cache_type, stats in cls._cache_stats.items():
|
||||||
|
hits = stats["hits"]
|
||||||
|
null_hits = stats.get("null_hits", 0)
|
||||||
|
misses = stats["misses"]
|
||||||
|
total = hits + null_hits + misses
|
||||||
|
hit_rate = ((hits + null_hits) / total * 100) if total > 0 else 0
|
||||||
|
result.append(
|
||||||
|
{
|
||||||
|
"cache_type": cache_type,
|
||||||
|
"hits": hits,
|
||||||
|
"null_hits": null_hits,
|
||||||
|
"misses": misses,
|
||||||
|
"sets": stats["sets"],
|
||||||
|
"null_sets": stats.get("null_sets", 0),
|
||||||
|
"deletes": stats["deletes"],
|
||||||
|
"hit_rate": f"{hit_rate:.2f}%",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_cache_stats(cls):
|
||||||
|
"""重置缓存统计信息"""
|
||||||
|
for stats in cls._cache_stats.values():
|
||||||
|
stats["hits"] = 0
|
||||||
|
stats["null_hits"] = 0
|
||||||
|
stats["misses"] = 0
|
||||||
|
stats["sets"] = 0
|
||||||
|
stats["null_sets"] = 0
|
||||||
|
stats["deletes"] = 0
|
||||||
|
|
||||||
|
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
|
||||||
|
"""从关键字参数构建缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**kwargs: 关键字参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 缓存键,如果无法构建则返回None
|
||||||
|
"""
|
||||||
|
if isinstance(self.key_field, tuple):
|
||||||
|
# 多字段主键
|
||||||
|
key_parts = []
|
||||||
|
key_parts.extend(str(kwargs.get(field, "")) for field in self.key_field)
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||||
|
elif self.key_field in kwargs:
|
||||||
|
# 单字段主键
|
||||||
|
return str(kwargs[self.key_field])
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def safe_get_or_none(self, *args, **kwargs) -> T | None:
|
||||||
|
"""安全的获取单条数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Optional[T]: 查询结果,如果不存在返回None
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,直接从数据库获取
|
||||||
|
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
|
||||||
|
return await with_db_timeout(
|
||||||
|
self.model_cls.safe_get_or_none(*args, **kwargs),
|
||||||
|
operation=f"{self.model_cls.__name__}.safe_get_or_none",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 尝试从缓存获取
|
||||||
|
cache_key = None
|
||||||
|
try:
|
||||||
|
# 尝试构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
|
||||||
|
# 如果成功构建缓存键,尝试从缓存获取
|
||||||
|
if cache_key is not None:
|
||||||
|
data = await self.cache.get(cache_key)
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} self.cache.get(cache_key)"
|
||||||
|
f" 从缓存获取到的数据 {type(data)}: {data}"
|
||||||
|
)
|
||||||
|
if data == self._NULL_RESULT:
|
||||||
|
# 空结果缓存命中
|
||||||
|
self._cache_stats[self.cache_type]["null_hits"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
elif data:
|
||||||
|
# 缓存命中
|
||||||
|
self._cache_stats[self.cache_type]["hits"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
|
||||||
|
)
|
||||||
|
return cast(T, data)
|
||||||
|
else:
|
||||||
|
# 缓存未命中
|
||||||
|
self._cache_stats[self.cache_type]["misses"] += 1
|
||||||
|
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
|
||||||
|
|
||||||
|
# 如果缓存中没有,从数据库获取
|
||||||
|
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
|
||||||
|
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
|
||||||
|
|
||||||
|
# 如果获取到数据,存入缓存
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
# 生成缓存键
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
self._cache_stats[self.cache_type]["sets"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
|
||||||
|
)
|
||||||
|
elif cache_key is not None:
|
||||||
|
# 如果没有获取到数据,缓存空结果
|
||||||
|
try:
|
||||||
|
# 存入空结果缓存,使用较短的过期时间
|
||||||
|
await self.cache.set(
|
||||||
|
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
|
||||||
|
)
|
||||||
|
self._cache_stats[self.cache_type]["null_sets"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
|
||||||
|
f" TTL={self._NULL_RESULT_TTL}秒"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def get_or_none(self, *args, **kwargs) -> T | None:
|
||||||
|
"""获取单条数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Optional[T]: 查询结果,如果不存在返回None
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,直接从数据库获取
|
||||||
|
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
logger.debug(f"{self.model_cls.__name__} 直接从数据库获取数据: {kwargs}")
|
||||||
|
return await with_db_timeout(
|
||||||
|
self.model_cls.get_or_none(*args, **kwargs),
|
||||||
|
operation=f"{self.model_cls.__name__}.get_or_none",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 尝试从缓存获取
|
||||||
|
cache_key = None
|
||||||
|
try:
|
||||||
|
# 尝试构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
|
||||||
|
# 如果成功构建缓存键,尝试从缓存获取
|
||||||
|
if cache_key is not None:
|
||||||
|
data = await self.cache.get(cache_key)
|
||||||
|
if data == self._NULL_RESULT:
|
||||||
|
# 空结果缓存命中
|
||||||
|
self._cache_stats[self.cache_type]["null_hits"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 从缓存获取到空结果: {cache_key}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
elif data:
|
||||||
|
# 缓存命中
|
||||||
|
self._cache_stats[self.cache_type]["hits"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 从缓存获取数据成功: {cache_key}"
|
||||||
|
)
|
||||||
|
return cast(T, data)
|
||||||
|
else:
|
||||||
|
# 缓存未命中
|
||||||
|
self._cache_stats[self.cache_type]["misses"] += 1
|
||||||
|
logger.debug(f"{self.model_cls.__name__} 缓存未命中: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.model_cls.__name__} 从缓存获取数据失败: {kwargs}", e=e)
|
||||||
|
|
||||||
|
# 如果缓存中没有,从数据库获取
|
||||||
|
logger.debug(f"{self.model_cls.__name__} 从数据库获取数据: {kwargs}")
|
||||||
|
data = await self.model_cls.get_or_none(*args, **kwargs)
|
||||||
|
|
||||||
|
# 如果获取到数据,存入缓存
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
# 生成缓存键
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
self._cache_stats[self.cache_type]["sets"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 数据已存入缓存: {cache_key}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"{self.model_cls.__name__} 存入缓存失败,参数: {kwargs}", e=e
|
||||||
|
)
|
||||||
|
elif cache_key is not None:
|
||||||
|
# 如果没有获取到数据,缓存空结果
|
||||||
|
try:
|
||||||
|
# 存入空结果缓存,使用较短的过期时间
|
||||||
|
await self.cache.set(
|
||||||
|
cache_key, self._NULL_RESULT, expire=self._NULL_RESULT_TTL
|
||||||
|
)
|
||||||
|
self._cache_stats[self.cache_type]["null_sets"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 空结果已存入缓存: {cache_key},"
|
||||||
|
f" TTL={self._NULL_RESULT_TTL}秒"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"{self.model_cls.__name__} 存入空结果缓存失败,参数: {kwargs}", e=e
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def clear_cache(self, **kwargs) -> bool:
|
||||||
|
"""只清除缓存,不影响数据库数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**kwargs: 查询参数,必须包含主键字段
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否成功清除缓存
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,直接返回True
|
||||||
|
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
if cache_key is None:
|
||||||
|
if isinstance(self.key_field, tuple):
|
||||||
|
# 如果是复合键,检查缺少哪些字段
|
||||||
|
missing_fields = [
|
||||||
|
field for field in self.key_field if field not in kwargs
|
||||||
|
]
|
||||||
|
logger.error(
|
||||||
|
f"清除{self.model_cls.__name__}缓存失败: "
|
||||||
|
f"缺少主键字段 {', '.join(missing_fields)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"清除{self.model_cls.__name__}缓存失败: "
|
||||||
|
f"缺少主键字段 {self.key_field}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 删除缓存
|
||||||
|
await self.cache.delete(cache_key)
|
||||||
|
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||||
|
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _build_composite_key(self, data: T) -> str | None:
|
||||||
|
"""构建复合缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data: 数据对象
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 构建的缓存键,如果无法构建则返回None
|
||||||
|
"""
|
||||||
|
# 如果是元组,表示多个字段组成键
|
||||||
|
if isinstance(self.key_field, tuple):
|
||||||
|
# 构建键参数列表
|
||||||
|
key_parts = []
|
||||||
|
for field in self.key_field:
|
||||||
|
value = getattr(data, field, "")
|
||||||
|
key_parts.append(value if value is not None else "")
|
||||||
|
|
||||||
|
# 如果没有有效参数,返回None
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||||
|
elif hasattr(data, self.key_field):
|
||||||
|
value = getattr(data, self.key_field, None)
|
||||||
|
return str(value) if value is not None else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_cache_key_for_item(self, item: T) -> str | None:
|
||||||
|
"""为数据项构建缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
item: 数据项
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 缓存键,如果无法生成则返回None
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,返回None
|
||||||
|
if not self.cache_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取缓存类型的配置信息
|
||||||
|
cache_model = CacheRoot.get_model(self.cache_type)
|
||||||
|
|
||||||
|
if not cache_model.key_format:
|
||||||
|
# 常规处理,使用主键作为缓存键
|
||||||
|
return self._build_composite_key(item)
|
||||||
|
# 构建键参数字典
|
||||||
|
key_parts = []
|
||||||
|
# 从格式字符串中提取所需的字段名
|
||||||
|
import re
|
||||||
|
|
||||||
|
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
|
||||||
|
|
||||||
|
# 收集所有字段值
|
||||||
|
for field in field_names:
|
||||||
|
value = getattr(item, field, "")
|
||||||
|
key_parts.append(value if value is not None else "")
|
||||||
|
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
|
||||||
|
|
||||||
|
async def _cache_items(self, data_list: list[T]) -> None:
|
||||||
|
"""将数据列表存入缓存
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data_list: 数据列表
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
not data_list
|
||||||
|
or not self.cache_type
|
||||||
|
or cache_config.cache_mode == CacheMode.NONE
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 遍历数据列表,将每条数据存入缓存
|
||||||
|
cached_count = 0
|
||||||
|
for item in data_list:
|
||||||
|
cache_key = self._build_cache_key_for_item(item)
|
||||||
|
if cache_key is not None:
|
||||||
|
await self.cache.set(cache_key, item)
|
||||||
|
cached_count += 1
|
||||||
|
self._cache_stats[self.cache_type]["sets"] += 1
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} 批量缓存: {cached_count}/{len(data_list)}项"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.model_cls.__name__} 批量缓存失败", e=e)
|
||||||
|
|
||||||
|
async def filter(self, *args, **kwargs) -> list[T]:
|
||||||
|
"""筛选数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[T]: 查询结果列表
|
||||||
|
"""
|
||||||
|
# 从数据库获取数据
|
||||||
|
logger.debug(f"{self.model_cls.__name__} filter: 从数据库查询, 参数: {kwargs}")
|
||||||
|
data_list = await self.model_cls.filter(*args, **kwargs)
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} filter: 查询结果数量: {len(data_list)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将数据存入缓存
|
||||||
|
await self._cache_items(data_list)
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
async def all(self) -> list[T]:
|
||||||
|
"""获取所有数据
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[T]: 所有数据列表
|
||||||
|
"""
|
||||||
|
# 直接从数据库获取
|
||||||
|
logger.debug(f"{self.model_cls.__name__} all: 从数据库查询所有数据")
|
||||||
|
data_list = await self.model_cls.all()
|
||||||
|
logger.debug(f"{self.model_cls.__name__} all: 查询结果数量: {len(data_list)}")
|
||||||
|
|
||||||
|
# 将数据存入缓存
|
||||||
|
await self._cache_items(data_list)
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
async def count(self, *args, **kwargs) -> int:
|
||||||
|
"""获取数据数量
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 数据数量
|
||||||
|
"""
|
||||||
|
# 直接从数据库获取数量
|
||||||
|
return await self.model_cls.filter(*args, **kwargs).count()
|
||||||
|
|
||||||
|
async def exists(self, *args, **kwargs) -> bool:
|
||||||
|
"""判断数据是否存在
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否存在
|
||||||
|
"""
|
||||||
|
# 直接从数据库判断是否存在
|
||||||
|
return await self.model_cls.filter(*args, **kwargs).exists()
|
||||||
|
|
||||||
|
async def create(self, **kwargs) -> T:
|
||||||
|
"""创建数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**kwargs: 创建参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
T: 创建的数据
|
||||||
|
"""
|
||||||
|
# 创建数据
|
||||||
|
logger.debug(f"{self.model_cls.__name__} create: 创建数据, 参数: {kwargs}")
|
||||||
|
data = await self.model_cls.create(**kwargs)
|
||||||
|
|
||||||
|
# 如果有缓存类型,将数据存入缓存
|
||||||
|
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||||
|
try:
|
||||||
|
# 生成缓存键
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
self._cache_stats[self.cache_type]["sets"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} create: "
|
||||||
|
f"新创建的数据已存入缓存: {cache_key}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"{self.model_cls.__name__} create: 存入缓存失败,参数: {kwargs}",
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def update_or_create(
|
||||||
|
self, defaults: dict[str, Any] | None = None, **kwargs
|
||||||
|
) -> tuple[T, bool]:
|
||||||
|
"""更新或创建数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
defaults: 默认值
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[T, bool]: (数据, 是否创建)
|
||||||
|
"""
|
||||||
|
# 更新或创建数据
|
||||||
|
data, created = await self.model_cls.update_or_create(
|
||||||
|
defaults=defaults, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有缓存类型,将数据存入缓存
|
||||||
|
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||||
|
try:
|
||||||
|
# 生成缓存键
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
self._cache_stats[self.cache_type]["sets"] += 1
|
||||||
|
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
|
||||||
|
|
||||||
|
return data, created
|
||||||
|
|
||||||
|
async def delete(self, *args, **kwargs) -> int:
|
||||||
|
"""删除数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 删除的数据数量
|
||||||
|
"""
|
||||||
|
logger.debug(f"{self.model_cls.__name__} delete: 删除数据, 参数: {kwargs}")
|
||||||
|
|
||||||
|
# 如果有缓存类型且有key_field参数,先尝试删除缓存
|
||||||
|
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||||
|
try:
|
||||||
|
# 尝试构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
|
||||||
|
if cache_key is not None:
|
||||||
|
# 如果成功构建缓存键,直接删除缓存
|
||||||
|
await self.cache.delete(cache_key)
|
||||||
|
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} delete: 已删除缓存: {cache_key}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 否则需要先查询出要删除的数据,然后删除对应的缓存
|
||||||
|
items = await self.model_cls.filter(*args, **kwargs)
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} delete:"
|
||||||
|
f" 查询到 {len(items)} 条要删除的数据"
|
||||||
|
)
|
||||||
|
for item in items:
|
||||||
|
item_cache_key = self._build_cache_key_for_item(item)
|
||||||
|
if item_cache_key is not None:
|
||||||
|
await self.cache.delete(item_cache_key)
|
||||||
|
self._cache_stats[self.cache_type]["deletes"] += 1
|
||||||
|
if items:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} delete:"
|
||||||
|
f" 已删除 {len(items)} 条数据的缓存"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.model_cls.__name__} delete: 删除缓存失败", e=e)
|
||||||
|
|
||||||
|
# 删除数据
|
||||||
|
result = await self.model_cls.filter(*args, **kwargs).delete()
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_cls.__name__} delete: 已从数据库删除 {result} 条数据"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _generate_cache_key(self, data: T) -> str:
|
||||||
|
"""根据数据对象生成缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data: 数据对象
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 缓存键
|
||||||
|
"""
|
||||||
|
# 使用新方法构建复合键
|
||||||
|
if composite_key := self._build_composite_key(data):
|
||||||
|
return composite_key
|
||||||
|
|
||||||
|
# 如果无法生成复合键,生成一个唯一键
|
||||||
|
return f"object_{id(data)}"
|
||||||
@ -1,37 +1,328 @@
|
|||||||
import nonebot
|
import asyncio
|
||||||
|
from collections.abc import Iterable
|
||||||
|
import contextlib
|
||||||
|
import time
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
from typing_extensions import Self
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from nonebot import get_driver
|
||||||
from nonebot.utils import is_coroutine_callable
|
from nonebot.utils import is_coroutine_callable
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
|
from tortoise.backends.base.client import BaseDBAsyncClient
|
||||||
from tortoise.connection import connections
|
from tortoise.connection import connections
|
||||||
from tortoise.models import Model as Model_
|
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
|
||||||
|
from tortoise.models import Model as TortoiseModel
|
||||||
|
from tortoise.transactions import in_transaction
|
||||||
|
|
||||||
from zhenxun.configs.config import BotConfig
|
from zhenxun.configs.config import BotConfig
|
||||||
|
from zhenxun.services.cache import CacheRoot
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import DbLockType
|
||||||
from zhenxun.utils.exception import HookPriorityException
|
from zhenxun.utils.exception import HookPriorityException
|
||||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
|
||||||
from .log import logger
|
driver = get_driver()
|
||||||
|
|
||||||
SCRIPT_METHOD = []
|
SCRIPT_METHOD = []
|
||||||
MODELS: list[str] = []
|
MODELS: list[str] = []
|
||||||
|
|
||||||
|
# 数据库操作超时设置(秒)
|
||||||
|
DB_TIMEOUT_SECONDS = 3.0
|
||||||
|
|
||||||
driver = nonebot.get_driver()
|
# 性能监控阈值(秒)
|
||||||
|
SLOW_QUERY_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
LOG_COMMAND = "DbContext"
|
||||||
|
|
||||||
|
|
||||||
class Model(Model_):
|
async def with_db_timeout(
|
||||||
|
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
|
||||||
|
):
|
||||||
|
"""带超时控制的数据库操作"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(coro, timeout=timeout)
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > SLOW_QUERY_THRESHOLD and operation:
|
||||||
|
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if operation:
|
||||||
|
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class Model(TortoiseModel):
|
||||||
"""
|
"""
|
||||||
自动添加模块
|
增强的ORM基类,解决锁嵌套问题
|
||||||
|
|
||||||
Args:
|
|
||||||
Model_: Model
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
|
||||||
|
_current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
if cls.__module__ not in MODELS:
|
if cls.__module__ not in MODELS:
|
||||||
MODELS.append(cls.__module__)
|
MODELS.append(cls.__module__)
|
||||||
|
|
||||||
if func := getattr(cls, "_run_script", None):
|
if func := getattr(cls, "_run_script", None):
|
||||||
SCRIPT_METHOD.append((cls.__module__, func))
|
SCRIPT_METHOD.append((cls.__module__, func))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache_type(cls) -> str | None:
|
||||||
|
"""获取缓存类型"""
|
||||||
|
return getattr(cls, "cache_type", None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache_key_field(cls) -> str | tuple[str]:
|
||||||
|
"""获取缓存键字段"""
|
||||||
|
return getattr(cls, "cache_key_field", "id")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache_key(cls, instance) -> str | None:
|
||||||
|
"""获取缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
instance: 模型实例
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 缓存键,如果无法获取则返回None
|
||||||
|
"""
|
||||||
|
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR
|
||||||
|
|
||||||
|
key_field = cls.get_cache_key_field()
|
||||||
|
|
||||||
|
if isinstance(key_field, tuple):
|
||||||
|
# 多字段主键
|
||||||
|
key_parts = []
|
||||||
|
for field in key_field:
|
||||||
|
if hasattr(instance, field):
|
||||||
|
value = getattr(instance, field, None)
|
||||||
|
key_parts.append(value if value is not None else "")
|
||||||
|
else:
|
||||||
|
# 如果缺少任何必要的字段,返回None
|
||||||
|
key_parts.append("")
|
||||||
|
|
||||||
|
# 如果没有有效参数,返回None
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
|
||||||
|
elif hasattr(instance, key_field):
|
||||||
|
value = getattr(instance, key_field, None)
|
||||||
|
return str(value) if value is not None else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_semaphore(cls, lock_type: DbLockType):
|
||||||
|
enable_lock = getattr(cls, "enable_lock", None)
|
||||||
|
if not enable_lock or lock_type not in enable_lock:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cls.__name__ not in cls.sem_data:
|
||||||
|
cls.sem_data[cls.__name__] = {}
|
||||||
|
if lock_type not in cls.sem_data[cls.__name__]:
|
||||||
|
cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
|
||||||
|
return cls.sem_data[cls.__name__][lock_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _require_lock(cls, lock_type: DbLockType) -> bool:
|
||||||
|
"""检查是否需要真正加锁"""
|
||||||
|
task_id = id(asyncio.current_task())
|
||||||
|
return cls._current_locks.get(task_id) != lock_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _lock_context(cls, lock_type: DbLockType):
|
||||||
|
"""带重入检查的锁上下文"""
|
||||||
|
task_id = id(asyncio.current_task())
|
||||||
|
need_lock = cls._require_lock(lock_type)
|
||||||
|
|
||||||
|
if need_lock and (sem := cls.get_semaphore(lock_type)):
|
||||||
|
cls._current_locks[task_id] = lock_type
|
||||||
|
async with sem:
|
||||||
|
yield
|
||||||
|
cls._current_locks.pop(task_id, None)
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||||
|
) -> Self:
|
||||||
|
"""创建数据(使用CREATE锁)"""
|
||||||
|
async with cls._lock_context(DbLockType.CREATE):
|
||||||
|
# 直接调用父类的_create方法避免触发save的锁
|
||||||
|
result = await super().create(using_db=using_db, **kwargs)
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_or_create(
|
||||||
|
cls,
|
||||||
|
defaults: dict | None = None,
|
||||||
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> tuple[Self, bool]:
|
||||||
|
"""获取或创建数据(无锁版本,依赖数据库约束)"""
|
||||||
|
result = await super().get_or_create(
|
||||||
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
|
)
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_or_create(
|
||||||
|
cls,
|
||||||
|
defaults: dict | None = None,
|
||||||
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> tuple[Self, bool]:
|
||||||
|
"""更新或创建数据(使用UPSERT锁)"""
|
||||||
|
async with cls._lock_context(DbLockType.UPSERT):
|
||||||
|
try:
|
||||||
|
# 先尝试更新(带行锁)
|
||||||
|
async with in_transaction():
|
||||||
|
if obj := await cls.filter(**kwargs).select_for_update().first():
|
||||||
|
await obj.update_from_dict(defaults or {})
|
||||||
|
await obj.save()
|
||||||
|
result = (obj, False)
|
||||||
|
else:
|
||||||
|
# 创建时不重复加锁
|
||||||
|
result = await cls.create(**kwargs, **(defaults or {})), True
|
||||||
|
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
await CacheRoot.invalidate_cache(
|
||||||
|
cache_type, cls.get_cache_key(result[0])
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except IntegrityError:
|
||||||
|
# 处理极端情况下的唯一约束冲突
|
||||||
|
obj = await cls.get(**kwargs)
|
||||||
|
return obj, False
|
||||||
|
|
||||||
|
async def save(
|
||||||
|
self,
|
||||||
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
|
update_fields: Iterable[str] | None = None,
|
||||||
|
force_create: bool = False,
|
||||||
|
force_update: bool = False,
|
||||||
|
):
|
||||||
|
"""保存数据(根据操作类型自动选择锁)"""
|
||||||
|
lock_type = (
|
||||||
|
DbLockType.CREATE
|
||||||
|
if getattr(self, "id", None) is None
|
||||||
|
else DbLockType.UPDATE
|
||||||
|
)
|
||||||
|
async with self._lock_context(lock_type):
|
||||||
|
await super().save(
|
||||||
|
using_db=using_db,
|
||||||
|
update_fields=update_fields,
|
||||||
|
force_create=force_create,
|
||||||
|
force_update=force_update,
|
||||||
|
)
|
||||||
|
if cache_type := getattr(self, "cache_type", None):
|
||||||
|
await CacheRoot.invalidate_cache(
|
||||||
|
cache_type, self.__class__.get_cache_key(self)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete(self, using_db: BaseDBAsyncClient | None = None):
|
||||||
|
cache_type = getattr(self, "cache_type", None)
|
||||||
|
key = self.__class__.get_cache_key(self) if cache_type else None
|
||||||
|
# 执行删除操作
|
||||||
|
await super().delete(using_db=using_db)
|
||||||
|
|
||||||
|
# 清除缓存
|
||||||
|
if cache_type:
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def safe_get_or_none(
|
||||||
|
cls,
|
||||||
|
*args,
|
||||||
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
|
clean_duplicates: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Self | None:
|
||||||
|
"""安全地获取一条记录或None,处理存在多个记录时返回最新的那个
|
||||||
|
注意,默认会删除重复的记录,仅保留最新的
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
using_db: 数据库连接
|
||||||
|
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Self | None: 查询结果,如果不存在返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 先尝试使用 get_or_none 获取单个记录
|
||||||
|
try:
|
||||||
|
return await with_db_timeout(
|
||||||
|
cls.get_or_none(*args, using_db=using_db, **kwargs),
|
||||||
|
operation=f"{cls.__name__}.get_or_none",
|
||||||
|
)
|
||||||
|
except MultipleObjectsReturned:
|
||||||
|
# 如果出现多个记录的情况,进行特殊处理
|
||||||
|
logger.warning(
|
||||||
|
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
|
||||||
|
LOG_COMMAND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 查询所有匹配记录
|
||||||
|
records = await with_db_timeout(
|
||||||
|
cls.filter(*args, **kwargs).all(),
|
||||||
|
operation=f"{cls.__name__}.filter.all",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果需要清理重复记录
|
||||||
|
if clean_duplicates and hasattr(records[0], "id"):
|
||||||
|
# 按 id 排序
|
||||||
|
records = sorted(
|
||||||
|
records, key=lambda x: getattr(x, "id", 0), reverse=True
|
||||||
|
)
|
||||||
|
for record in records[1:]:
|
||||||
|
try:
|
||||||
|
await with_db_timeout(
|
||||||
|
record.delete(),
|
||||||
|
operation=f"{cls.__name__}.delete_duplicate",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"{cls.__name__} 删除重复记录:"
|
||||||
|
f" id={getattr(record, 'id', None)}",
|
||||||
|
LOG_COMMAND,
|
||||||
|
)
|
||||||
|
except Exception as del_e:
|
||||||
|
logger.error(f"删除重复记录失败: {del_e}")
|
||||||
|
return records[0]
|
||||||
|
# 如果不需要清理或没有 id 字段,则返回最新的记录
|
||||||
|
if hasattr(cls, "id"):
|
||||||
|
return await with_db_timeout(
|
||||||
|
cls.filter(*args, **kwargs).order_by("-id").first(),
|
||||||
|
operation=f"{cls.__name__}.filter.order_by.first",
|
||||||
|
)
|
||||||
|
# 如果没有 id 字段,则返回第一个记录
|
||||||
|
return await with_db_timeout(
|
||||||
|
cls.filter(*args, **kwargs).first(),
|
||||||
|
operation=f"{cls.__name__}.filter.first",
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
# 其他类型的错误则继续抛出
|
||||||
|
logger.error(
|
||||||
|
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class DbUrlIsNode(HookPriorityException):
|
class DbUrlIsNode(HookPriorityException):
|
||||||
"""
|
"""
|
||||||
@ -49,6 +340,77 @@ class DbConnectError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
POSTGRESQL_CONFIG = {
|
||||||
|
"max_size": 30, # 最大连接数
|
||||||
|
"min_size": 5, # 最小保持的连接数(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MYSQL_CONFIG = {
|
||||||
|
"max_connections": 20, # 最大连接数
|
||||||
|
"connect_timeout": 30, # 连接超时(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
SQLITE_CONFIG = {
|
||||||
|
"journal_mode": "WAL", # 提高并发写入性能
|
||||||
|
"timeout": 30, # 锁等待超时(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(db_url: str) -> dict:
|
||||||
|
"""获取数据库配置"""
|
||||||
|
parsed = urlparse(BotConfig.db_url)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
config = {
|
||||||
|
"connections": {
|
||||||
|
"default": BotConfig.db_url # 默认直接使用连接字符串
|
||||||
|
},
|
||||||
|
"apps": {
|
||||||
|
"models": {
|
||||||
|
"models": MODELS,
|
||||||
|
"default_connection": "default",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"timezone": "Asia/Shanghai",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据数据库类型应用高级配置
|
||||||
|
if parsed.scheme.startswith("postgres"):
|
||||||
|
config["connections"]["default"] = {
|
||||||
|
"engine": "tortoise.backends.asyncpg",
|
||||||
|
"credentials": {
|
||||||
|
"host": parsed.hostname,
|
||||||
|
"port": parsed.port or 5432,
|
||||||
|
"user": parsed.username,
|
||||||
|
"password": parsed.password,
|
||||||
|
"database": parsed.path[1:],
|
||||||
|
},
|
||||||
|
**POSTGRESQL_CONFIG,
|
||||||
|
}
|
||||||
|
elif parsed.scheme == "mysql":
|
||||||
|
config["connections"]["default"] = {
|
||||||
|
"engine": "tortoise.backends.mysql",
|
||||||
|
"credentials": {
|
||||||
|
"host": parsed.hostname,
|
||||||
|
"port": parsed.port or 3306,
|
||||||
|
"user": parsed.username,
|
||||||
|
"password": parsed.password,
|
||||||
|
"database": parsed.path[1:],
|
||||||
|
},
|
||||||
|
**MYSQL_CONFIG,
|
||||||
|
}
|
||||||
|
elif parsed.scheme == "sqlite":
|
||||||
|
config["connections"]["default"] = {
|
||||||
|
"engine": "tortoise.backends.sqlite",
|
||||||
|
"credentials": {
|
||||||
|
"file_path": parsed.path[1:] or ":memory:",
|
||||||
|
},
|
||||||
|
**SQLITE_CONFIG,
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
@PriorityLifecycle.on_startup(priority=1)
|
@PriorityLifecycle.on_startup(priority=1)
|
||||||
async def init():
|
async def init():
|
||||||
if not BotConfig.db_url:
|
if not BotConfig.db_url:
|
||||||
@ -64,9 +426,7 @@ async def init():
|
|||||||
raise DbUrlIsNode("\n" + error.strip())
|
raise DbUrlIsNode("\n" + error.strip())
|
||||||
try:
|
try:
|
||||||
await Tortoise.init(
|
await Tortoise.init(
|
||||||
db_url=BotConfig.db_url,
|
config=get_config(BotConfig.db_url),
|
||||||
modules={"models": MODELS},
|
|
||||||
timezone="Asia/Shanghai",
|
|
||||||
)
|
)
|
||||||
if SCRIPT_METHOD:
|
if SCRIPT_METHOD:
|
||||||
db = Tortoise.get_connection("default")
|
db = Tortoise.get_connection("default")
|
||||||
@ -85,13 +445,17 @@ async def init():
|
|||||||
for sql in sql_list:
|
for sql in sql_list:
|
||||||
logger.debug(f"执行SQL: {sql}")
|
logger.debug(f"执行SQL: {sql}")
|
||||||
try:
|
try:
|
||||||
await db.execute_query_dict(sql)
|
await asyncio.wait_for(
|
||||||
|
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
# await TestSQL.raw(sql)
|
# await TestSQL.raw(sql)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
||||||
if sql_list:
|
if sql_list:
|
||||||
logger.debug("SCRIPT_METHOD方法执行完毕!")
|
logger.debug("SCRIPT_METHOD方法执行完毕!")
|
||||||
|
logger.debug("开始生成数据库表结构...")
|
||||||
await Tortoise.generate_schemas()
|
await Tortoise.generate_schemas()
|
||||||
|
logger.debug("数据库表结构生成完毕!")
|
||||||
logger.info("Database loaded successfully!")
|
logger.info("Database loaded successfully!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
||||||
|
|||||||
@ -469,7 +469,7 @@ class Notebook:
|
|||||||
template_name="main.html",
|
template_name="main.html",
|
||||||
templates={"elements": self._data},
|
templates={"elements": self._data},
|
||||||
pages={
|
pages={
|
||||||
"viewport": {"width": 700, "height": 1000},
|
"viewport": {"width": 700, "height": 10},
|
||||||
"base_url": f"file://{TEMPLATE_PATH}",
|
"base_url": f"file://{TEMPLATE_PATH}",
|
||||||
},
|
},
|
||||||
wait=2,
|
wait=2,
|
||||||
|
|||||||
@ -53,9 +53,7 @@ class CommonUtils:
|
|||||||
if await GroupConsole.is_block_task(group_id, module):
|
if await GroupConsole.is_block_task(group_id, module):
|
||||||
"""群组是否禁用被动"""
|
"""群组是否禁用被动"""
|
||||||
return True
|
return True
|
||||||
if g := await GroupConsole.get_or_none(
|
if g := await GroupConsole.get_group(group_id=group_id):
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
):
|
|
||||||
"""群组权限是否小于0"""
|
"""群组权限是否小于0"""
|
||||||
if g.level < 0:
|
if g.level < 0:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -44,6 +44,44 @@ class EventLogType(StrEnum):
|
|||||||
"""主动退群"""
|
"""主动退群"""
|
||||||
|
|
||||||
|
|
||||||
|
class CacheType(StrEnum):
|
||||||
|
"""
|
||||||
|
缓存类型
|
||||||
|
"""
|
||||||
|
|
||||||
|
PLUGINS = "GLOBAL_ALL_PLUGINS"
|
||||||
|
"""全局全部插件"""
|
||||||
|
GROUPS = "GLOBAL_ALL_GROUPS"
|
||||||
|
"""全局全部群组"""
|
||||||
|
USERS = "GLOBAL_ALL_USERS"
|
||||||
|
"""全部用户"""
|
||||||
|
BAN = "GLOBAL_ALL_BAN"
|
||||||
|
"""全局ban列表"""
|
||||||
|
BOT = "GLOBAL_BOT"
|
||||||
|
"""全局bot信息"""
|
||||||
|
LEVEL = "GLOBAL_USER_LEVEL"
|
||||||
|
"""用户权限"""
|
||||||
|
LIMIT = "GLOBAL_LIMIT"
|
||||||
|
"""插件限制"""
|
||||||
|
|
||||||
|
|
||||||
|
class DbLockType(StrEnum):
|
||||||
|
"""
|
||||||
|
锁类型
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE = "CREATE"
|
||||||
|
"""创建"""
|
||||||
|
DELETE = "DELETE"
|
||||||
|
"""删除"""
|
||||||
|
UPDATE = "UPDATE"
|
||||||
|
"""更新"""
|
||||||
|
QUERY = "QUERY"
|
||||||
|
"""查询"""
|
||||||
|
UPSERT = "UPSERT"
|
||||||
|
"""创建或更新"""
|
||||||
|
|
||||||
|
|
||||||
class GoldHandle(StrEnum):
|
class GoldHandle(StrEnum):
|
||||||
"""
|
"""
|
||||||
金币处理
|
金币处理
|
||||||
|
|||||||
@ -49,6 +49,9 @@ async def _():
|
|||||||
try:
|
try:
|
||||||
for priority in priority_list:
|
for priority in priority_list:
|
||||||
for func in priority_data[priority]:
|
for func in priority_data[priority]:
|
||||||
|
logger.debug(
|
||||||
|
f"执行优先级 [{priority}] on_startup 方法: {func.__module__}"
|
||||||
|
)
|
||||||
if is_coroutine_callable(func):
|
if is_coroutine_callable(func):
|
||||||
await func()
|
await func()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
import pypinyin
|
import pypinyin
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
@ -13,43 +15,53 @@ from zhenxun.configs.config import Config
|
|||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EntityIDs:
|
||||||
|
user_id: str
|
||||||
|
"""用户id"""
|
||||||
|
group_id: str | None
|
||||||
|
"""群组id"""
|
||||||
|
channel_id: str | None
|
||||||
|
"""频道id"""
|
||||||
|
|
||||||
|
|
||||||
class ResourceDirManager:
|
class ResourceDirManager:
|
||||||
"""
|
"""
|
||||||
临时文件管理器
|
临时文件管理器
|
||||||
"""
|
"""
|
||||||
|
|
||||||
temp_path = [] # noqa: RUF012
|
temp_path: ClassVar[set[Path]] = set()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __tree_append(cls, path: Path):
|
def __tree_append(cls, path: Path, deep: int = 1, current: int = 0):
|
||||||
"""递归添加文件夹
|
"""递归添加文件夹"""
|
||||||
|
if current >= deep and deep != -1:
|
||||||
参数:
|
return
|
||||||
path: 文件夹路径
|
path = path.resolve() # 标准化路径
|
||||||
"""
|
|
||||||
for f in os.listdir(path):
|
for f in os.listdir(path):
|
||||||
file = path / f
|
file = (path / f).resolve() # 标准化子路径
|
||||||
if file.is_dir():
|
if file.is_dir():
|
||||||
if file not in cls.temp_path:
|
if file not in cls.temp_path:
|
||||||
cls.temp_path.append(file)
|
cls.temp_path.add(file)
|
||||||
logger.debug(f"添加临时文件夹: {path}")
|
logger.debug(f"添加临时文件夹: {file}")
|
||||||
cls.__tree_append(file)
|
cls.__tree_append(file, deep, current + 1)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_temp_dir(cls, path: str | Path, tree: bool = False):
|
def add_temp_dir(cls, path: str | Path, tree: bool = False, deep: int = 1):
|
||||||
"""添加临时清理文件夹,这些文件夹会被自动清理
|
"""添加临时清理文件夹,这些文件夹会被自动清理
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
path: 文件夹路径
|
path: 文件夹路径
|
||||||
tree: 是否递归添加文件夹
|
tree: 是否递归添加文件夹
|
||||||
|
deep: 深度, -1 为无限深度
|
||||||
"""
|
"""
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if path not in cls.temp_path:
|
if path not in cls.temp_path:
|
||||||
cls.temp_path.append(path)
|
cls.temp_path.add(path)
|
||||||
logger.debug(f"添加临时文件夹: {path}")
|
logger.debug(f"添加临时文件夹: {path}")
|
||||||
if tree:
|
if tree:
|
||||||
cls.__tree_append(path)
|
cls.__tree_append(path, deep)
|
||||||
|
|
||||||
|
|
||||||
class CountLimiter:
|
class CountLimiter:
|
||||||
@ -230,6 +242,27 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_ids(session: Uninfo) -> EntityIDs:
|
||||||
|
"""获取用户id,群组id,频道id
|
||||||
|
|
||||||
|
参数:
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
返回:
|
||||||
|
EntityIDs: 用户id,群组id,频道id
|
||||||
|
"""
|
||||||
|
user_id = session.user.id
|
||||||
|
group_id = None
|
||||||
|
channel_id = None
|
||||||
|
if session.group:
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
channel_id = session.group.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)
|
||||||
|
|
||||||
|
|
||||||
def is_number(text: str) -> bool:
|
def is_number(text: str) -> bool:
|
||||||
"""是否为数字
|
"""是否为数字
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user