构建缓存,hook使用缓存

This commit is contained in:
HibiKier 2025-01-08 15:23:10 +08:00
parent 45649bb29d
commit 521bcaceeb
16 changed files with 996 additions and 660 deletions

View File

@ -1,580 +0,0 @@
from nonebot.adapters import Bot, Event
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At, UniMsg
from nonebot_plugin_session import EventSession
from pydantic import BaseModel
from tortoise.exceptions import IntegrityError
from zhenxun.configs.config import Config
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import (
BlockType,
CacheType,
GoldHandle,
LimitWatchType,
PluginLimitType,
PluginType,
)
from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
base_config = Config.get("hook")
class Limit(BaseModel):
limit: PluginLimit
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
class Config:
arbitrary_types_allowed = True
class LimitManage:
add_module = [] # noqa: RUF012
cd_limit: dict[str, Limit] = {} # noqa: RUF012
block_limit: dict[str, Limit] = {} # noqa: RUF012
count_limit: dict[str, Limit] = {} # noqa: RUF012
@classmethod
def add_limit(cls, limit: PluginLimit):
"""添加限制
参数:
limit: PluginLimit
"""
if limit.module not in cls.add_module:
cls.add_module.append(limit.module)
if limit.limit_type == PluginLimitType.BLOCK:
cls.block_limit[limit.module] = Limit(
limit=limit, limiter=UserBlockLimiter()
)
elif limit.limit_type == PluginLimitType.CD:
cls.cd_limit[limit.module] = Limit(
limit=limit, limiter=FreqLimiter(limit.cd)
)
elif limit.limit_type == PluginLimitType.COUNT:
cls.count_limit[limit.module] = Limit(
limit=limit, limiter=CountLimiter(limit.max_count)
)
@classmethod
def unblock(
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
):
"""解除插件block
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
"""
if limit_model := cls.block_limit.get(module):
limit = limit_model.limit
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
logger.debug(
f"解除对象: {key_type} 的block限制",
"AuthChecker",
session=user_id,
group_id=group_id,
)
limiter.set_false(key_type)
@classmethod
async def check(
cls,
module: str,
user_id: str,
group_id: str | None,
channel_id: str | None,
session: EventSession,
):
"""检测限制
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
@classmethod
async def __check(
cls,
limit_model: Limit | None,
user_id: str,
group_id: str | None,
channel_id: str | None,
session: EventSession,
):
"""检测限制
参数:
limit_model: Limit
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if not limit_model:
return
limit = limit_model.limit
limiter = limit_model.limiter
is_limit = (
LimitWatchType.ALL
or (group_id and limit.watch_type == LimitWatchType.GROUP)
or (not group_id and limit.watch_type == LimitWatchType.USER)
)
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
await MessageUtils.build_message(limit.result).send()
logger.debug(
f"{limit.module}({limit.limit_type}) 正在限制中...",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{limit.module} 正在限制中...")
else:
logger.debug(
f"开始进行限制 {limit.module}({limit.limit_type})...",
"AuthChecker",
session=user_id,
group_id=group_id,
)
if isinstance(limiter, FreqLimiter):
limiter.start_cd(key_type)
if isinstance(limiter, UserBlockLimiter):
limiter.set_true(key_type)
if isinstance(limiter, CountLimiter):
limiter.increase(key_type)
class IsSuperuserException(Exception):
pass
class AuthChecker:
"""
权限检查
"""
def __init__(self):
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
if check_notice_info_cd is None or check_notice_info_cd < 0:
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
self._flmt = FreqLimiter(check_notice_info_cd)
self._flmt_g = FreqLimiter(check_notice_info_cd)
self._flmt_s = FreqLimiter(check_notice_info_cd)
self._flmt_c = FreqLimiter(check_notice_info_cd)
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
"""是否发送提示消息
参数:
plugin: PluginInfo
返回:
bool: 是否发送提示消息
"""
if not base_config.get("IS_SEND_TIP_MESSAGE"):
return False
if plugin.plugin_type == PluginType.DEPENDANT:
return False
return plugin.module != "ai" if self._flmt_s.check(sid) else False
async def auth(
self,
matcher: Matcher,
event: Event,
bot: Bot,
session: EventSession,
message: UniMsg,
):
"""权限检查
参数:
matcher: matcher
bot: bot
session: EventSession
message: UniMsg
"""
is_ignore = False
cost_gold = 0
user_id = session.id1
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
"""过滤除poke外的notice"""
return
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
try:
user = await UserConsole.get_user(user_id, session.platform)
except IntegrityError as e:
logger.debug(
"重复创建用户,已跳过该次权限...",
"AuthChecker",
session=session,
e=e,
)
return
if plugin := await Cache.get(CacheType.PLUGINS, module_path):
if plugin.plugin_type == PluginType.HIDDEN:
logger.debug(
f"插件: {plugin.name}:{plugin.module} "
"为HIDDEN已跳过权限检查..."
)
return
try:
cost_gold = await self.auth_cost(user, plugin, session)
if session.id1 in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
cost_gold = 0
raise IsSuperuserException()
await self.auth_bot(plugin, bot.self_id)
await self.auth_group(plugin, session, message)
await self.auth_admin(plugin, session)
await self.auth_plugin(plugin, session, event)
await self.auth_limit(plugin, session)
except IsSuperuserException:
logger.debug(
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
)
except IgnoredException:
is_ignore = True
LimitManage.unblock(
matcher.plugin.name, user_id, group_id, channel_id
)
except AssertionError as e:
is_ignore = True
logger.debug("消息无法发送", session=session, e=e)
if cost_gold and user_id:
"""花费金币"""
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
matcher.plugin.name if matcher.plugin else "",
session.platform,
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
logger.debug(
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
)
if is_ignore:
raise IgnoredException("权限检测 ignore")
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
"""机器人权限
参数:
plugin: PluginInfo
bot_id: bot_id
"""
if not await BotConsole.get_bot_status(bot_id):
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
raise IgnoredException("BotConsole休眠权限检测 ignore")
if await BotConsole.is_block_plugin(bot_id, plugin.module):
logger.debug(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
"AuthChecker",
)
raise IgnoredException("BotConsole插件权限检测 ignore")
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
"""插件限制
参数:
plugin: PluginInfo
session: EventSession
"""
user_id = session.id1
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
if plugin.module not in LimitManage.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
status=True
).all() # type: ignore
for limit in limit_list:
LimitManage.add_limit(limit)
if user_id:
await LimitManage.check(
plugin.module, user_id, group_id, channel_id, session
)
async def auth_plugin(
self, plugin: PluginInfo, session: EventSession, event: Event
):
"""插件状态
参数:
plugin: PluginInfo
session: EventSession
"""
group_id = session.id3
channel_id = session.id2
if not group_id:
group_id = channel_id
channel_id = None
if user_id := session.id1:
is_poke = isinstance(event, PokeNotifyEvent)
if group_id:
sid = group_id or user_id
if await GroupConsole.is_superuser_block_plugin(
group_id, plugin.module
):
"""超级用户群组插件状态"""
if self.is_send_limit_message(plugin, sid) and not is_poke:
self._flmt_s.start_cd(group_id or user_id)
await MessageUtils.build_message(
"超级管理员禁用了该群此功能..."
).send(reply_to=True)
logger.debug(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
"AuthChecker",
session=session,
)
raise IgnoredException("超级管理员禁用了该群此功能...")
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
"""群组插件状态"""
if self.is_send_limit_message(plugin, sid) and not is_poke:
self._flmt_s.start_cd(group_id or user_id)
await MessageUtils.build_message("该群未开启此功能...").send(
reply_to=True
)
logger.debug(
f"{plugin.name}({plugin.module}) 未开启此功能...",
"AuthChecker",
session=session,
)
raise IgnoredException("该群未开启此功能...")
if plugin.block_type == BlockType.GROUP:
"""全局群组禁用"""
try:
if self.is_send_limit_message(plugin, sid) and not is_poke:
self._flmt_c.start_cd(group_id)
await MessageUtils.build_message(
"该功能在群组中已被禁用..."
).send(reply_to=True)
except Exception as e:
logger.error(
"auth_plugin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
"AuthChecker",
session=session,
)
raise IgnoredException("该插件在群组中已被禁用...")
else:
sid = user_id
if plugin.block_type == BlockType.PRIVATE:
"""全局私聊禁用"""
try:
if self.is_send_limit_message(plugin, sid) and not is_poke:
self._flmt_c.start_cd(user_id)
await MessageUtils.build_message(
"该功能在私聊中已被禁用..."
).send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
"AuthChecker",
session=session,
)
raise IgnoredException("该插件在私聊中已被禁用...")
if not plugin.status and plugin.block_type == BlockType.ALL:
"""全局状态"""
if group_id and await GroupConsole.is_super_group(group_id):
raise IsSuperuserException()
logger.debug(
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
"AuthChecker",
session=session,
)
if self.is_send_limit_message(plugin, sid) and not is_poke:
self._flmt_s.start_cd(group_id or user_id)
await MessageUtils.build_message("全局未开启此功能...").send()
raise IgnoredException("全局未开启此功能...")
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
"""管理员命令 个人权限
参数:
plugin: PluginInfo
session: EventSession
"""
user_id = session.id1
if user_id and plugin.admin_level:
if group_id := session.id3 or session.id2:
if not await LevelUser.check_level(
user_id, group_id, plugin.admin_level
):
try:
if self._flmt.check(user_id):
self._flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"你的权限不足喔,"
f"该功能需要的权限等级: {plugin.admin_level}",
]
).send(reply_to=True)
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("管理员权限不足...")
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
try:
await MessageUtils.build_message(
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
).send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("权限不足")
async def auth_group(
self, plugin: PluginInfo, session: EventSession, message: UniMsg
):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
session: EventSession
message: UniMsg
"""
if not (group_id := session.id3 or session.id2):
return
text = message.extract_plain_text()
group = await GroupConsole.get_group(group_id)
if not group:
"""群不存在"""
logger.debug(
"群组信息不存在...",
"AuthChecker",
session=session,
)
raise IgnoredException("群不存在")
if group.level < 0:
"""群权限小于0"""
logger.debug(
"群黑名单, 群权限-1...",
"AuthChecker",
session=session,
)
raise IgnoredException("群黑名单")
if not group.status:
"""群休眠"""
if text.strip() != "醒来":
logger.debug("群休眠状态...", "AuthChecker", session=session)
raise IgnoredException("群休眠状态")
if plugin.level > group.level:
"""插件等级大于群等级"""
logger.debug(
f"{plugin.name}({plugin.module}) 群等级限制.."
f"该功能需要的群等级: {plugin.level}..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
async def auth_cost(
self, user: UserConsole, plugin: PluginInfo, session: EventSession
) -> int:
"""检测是否满足金币条件
参数:
user: UserConsole
plugin: PluginInfo
session: EventSession
返回:
int: 需要消耗的金币
"""
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
try:
await MessageUtils.build_message(
f"金币不足..该功能需要{plugin.cost_gold}金币.."
).send()
except Exception as e:
logger.error(
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
)
logger.debug(
f"{plugin.name}({plugin.module}) 金币限制.."
f"该功能需要{plugin.cost_gold}金币..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold
checker = AuthChecker()

View File

@ -0,0 +1,76 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import CacheType
from zhenxun.utils.message import MessageUtils
from .utils import freq
async def auth_admin(plugin: PluginInfo, session: Uninfo):
"""管理员命令 个人权限
参数:
plugin: PluginInfo
session: PluginInfo
"""
group_id = None
cache = Cache[list[LevelUser]](CacheType.LEVEL)
user_level = await cache.get(session.user.id) or []
if session.group:
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
if not plugin.admin_level:
return
if group_id:
user_level += await cache.get(session.user.id, group_id) or []
user = max(user_level, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level:
try:
if freq._flmt.check(session.user.id):
freq._flmt.start_cd(session.user.id)
await MessageUtils.build_message(
[
At(flag="user", target=session.user.id),
"你的权限不足喔,"
f"该功能需要的权限等级: {plugin.admin_level}",
]
).send(reply_to=True)
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("管理员权限不足...")
elif user_level:
user = max(user_level, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level:
try:
await MessageUtils.build_message(
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
).send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("权限不足")

View File

@ -4,31 +4,29 @@ from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
async def get_bot_status(bot_id: str):
a = await Cache.get(CacheType.BOT, bot_id)
cache_data = await Cache.get_cache(CacheType.BOT)
if cache_data and cache_data.getter:
b = await cache_data.getter.get(cache_data.data)
if bot := await Cache.get(CacheType.BOT, bot_id):
return bot
async def auth_bot(plugin: PluginInfo, bot_id: str): async def auth_bot(plugin: PluginInfo, bot_id: str):
"""机器人权限 """bot层面的权限检查
参数: 参数:
plugin: PluginInfo plugin: PluginInfo
bot_id: bot_id bot_id: bot id
异常:
IgnoredException: 忽略插件
IgnoredException: 忽略插件
""" """
if not await BotConsole.get_bot_status(bot_id): if cache := Cache[BotConsole](CacheType.BOT):
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker") bot = await cache.get(bot_id)
raise IgnoredException("BotConsole休眠权限检测 ignore") if not bot or not bot.status:
if await BotConsole.is_block_plugin(bot_id, plugin.module): logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker")
logger.debug( raise IgnoredException("BotConsole休眠权限检测 ignore")
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...", if CommonUtils.format(plugin.module) in bot.block_plugins:
"AuthChecker", logger.debug(
) f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
raise IgnoredException("BotConsole插件权限检测 ignore") "AuthChecker",
)
raise IgnoredException("BotConsole插件权限检测 ignore")

View File

@ -0,0 +1,35 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.message import MessageUtils
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
"""检测是否满足金币条件
参数:
plugin: PluginInfo
session: Uninfo
返回:
int: 需要消耗的金币
"""
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
try:
await MessageUtils.build_message(
f"金币不足..该功能需要{plugin.cost_gold}金币.."
).send()
except Exception as e:
logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e)
logger.debug(
f"{plugin.name}({plugin.module}) 金币限制.."
f"该功能需要{plugin.cost_gold}金币..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold

View File

@ -0,0 +1,57 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import CacheType
async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
session: EventSession
message: UniMsg
"""
if not session.group:
return
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
text = message.extract_plain_text()
group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id)
if not group:
"""群不存在"""
logger.debug(
"群组信息不存在...",
"AuthChecker",
session=session,
)
raise IgnoredException("群不存在")
if group.level < 0:
"""群权限小于0"""
logger.debug(
"群黑名单, 群权限-1...",
"AuthChecker",
session=session,
)
raise IgnoredException("群黑名单")
if not group.status:
"""群休眠"""
if text.strip() != "醒来":
logger.debug("群休眠状态...", "AuthChecker", session=session)
raise IgnoredException("群休眠状态")
if plugin.level > group.level:
"""插件等级大于群等级"""
logger.debug(
f"{plugin.name}({plugin.module}) 群等级限制.."
f"该功能需要的群等级: {plugin.level}..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")

View File

@ -0,0 +1,189 @@
from typing import ClassVar
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
class Limit(BaseModel):
limit: PluginLimit
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
class Config:
arbitrary_types_allowed = True
class LimitManage:
add_module: ClassVar[list] = []
cd_limit: ClassVar[dict[str, Limit]] = {}
block_limit: ClassVar[dict[str, Limit]] = {}
count_limit: ClassVar[dict[str, Limit]] = {}
@classmethod
def add_limit(cls, limit: PluginLimit):
"""添加限制
参数:
limit: PluginLimit
"""
if limit.module not in cls.add_module:
cls.add_module.append(limit.module)
if limit.limit_type == PluginLimitType.BLOCK:
cls.block_limit[limit.module] = Limit(
limit=limit, limiter=UserBlockLimiter()
)
elif limit.limit_type == PluginLimitType.CD:
cls.cd_limit[limit.module] = Limit(
limit=limit, limiter=FreqLimiter(limit.cd)
)
elif limit.limit_type == PluginLimitType.COUNT:
cls.count_limit[limit.module] = Limit(
limit=limit, limiter=CountLimiter(limit.max_count)
)
@classmethod
def unblock(
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
):
"""解除插件block
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
"""
if limit_model := cls.block_limit.get(module):
limit = limit_model.limit
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
logger.debug(
f"解除对象: {key_type} 的block限制",
"AuthChecker",
session=user_id,
group_id=group_id,
)
limiter.set_false(key_type)
@classmethod
async def check(
cls,
module: str,
user_id: str,
group_id: str | None,
channel_id: str | None,
session: Uninfo,
):
"""检测限制
参数:
module: 模块名
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
@classmethod
async def __check(
cls,
limit_model: Limit | None,
user_id: str,
group_id: str | None,
channel_id: str | None,
session: Uninfo,
):
"""检测限制
参数:
limit_model: Limit
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if not limit_model:
return
limit = limit_model.limit
limiter = limit_model.limiter
is_limit = (
LimitWatchType.ALL
or (group_id and limit.watch_type == LimitWatchType.GROUP)
or (not group_id and limit.watch_type == LimitWatchType.USER)
)
key_type = user_id
if group_id and limit.watch_type == LimitWatchType.GROUP:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
await MessageUtils.build_message(limit.result).send()
logger.debug(
f"{limit.module}({limit.limit_type}) 正在限制中...",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{limit.module} 正在限制中...")
else:
logger.debug(
f"开始进行限制 {limit.module}({limit.limit_type})...",
"AuthChecker",
session=user_id,
group_id=group_id,
)
if isinstance(limiter, FreqLimiter):
limiter.start_cd(key_type)
if isinstance(limiter, UserBlockLimiter):
limiter.set_true(key_type)
if isinstance(limiter, CountLimiter):
limiter.increase(key_type)
async def auth_limit(plugin: PluginInfo, session: Uninfo):
"""插件限制
参数:
plugin: PluginInfo
session: Uninfo
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
if not group_id:
group_id = channel_id
channel_id = None
if plugin.module not in LimitManage.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
status=True
).all() # type: ignore
for limit in limit_list:
LimitManage.add_limit(limit)
if user_id:
await LimitManage.check(plugin.module, user_id, group_id, channel_id, session)

View File

@ -0,0 +1,201 @@
from nonebot.adapters import Event
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType
from zhenxun.utils.message import MessageUtils
from .exception import IsSuperuserException
from .utils import freq
class GroupCheck:
def __init__(
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
) -> None:
self.group_id = group_id
self.session = session
self.is_poke = is_poke
self.plugin = plugin
async def __get_data(self):
cache = Cache[GroupConsole](CacheType.GROUPS)
return await cache.get(self.group_id)
async def check(self):
await self.check_superuser_block(self.plugin)
async def check_superuser_block(self, plugin: PluginInfo):
"""超级用户禁用群组插件检测
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
freq._flmt_s.start_cd(self.group_id)
await MessageUtils.build_message("超级管理员禁用了该群此功能...").send(
reply_to=True
)
logger.debug(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("超级管理员禁用了该群此功能...")
await self.check_normal_block(self.plugin)
async def check_normal_block(self, plugin: PluginInfo):
"""群组插件状态
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.block_plugin:
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
freq._flmt_s.start_cd(self.group_id)
await MessageUtils.build_message("该群未开启此功能...").send(
reply_to=True
)
logger.debug(
f"{plugin.name}({plugin.module}) 未开启此功能...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该群未开启此功能...")
await self.check_global_block(self.plugin)
async def check_global_block(self, plugin: PluginInfo):
"""全局禁用插件检测
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
if plugin.block_type == BlockType.GROUP:
"""全局群组禁用"""
try:
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
freq._flmt_c.start_cd(self.group_id)
await MessageUtils.build_message("该功能在群组中已被禁用...").send(
reply_to=True
)
except Exception as e:
logger.error(
"auth_plugin 发送消息失败",
"AuthChecker",
session=self.session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该插件在群组中已被禁用...")
class PluginCheck:
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
self.session = session
self.is_poke = is_poke
self.group_id = group_id
async def check_user(self, plugin: PluginInfo):
"""全局私聊禁用检测
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
if plugin.block_type == BlockType.PRIVATE:
try:
if freq.is_send_limit_message(
plugin, self.session.user.id, self.is_poke
):
freq._flmt_c.start_cd(self.session.user.id)
await MessageUtils.build_message("该功能在私聊中已被禁用...").send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=self.session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该插件在私聊中已被禁用...")
async def check_global(self, plugin: PluginInfo):
"""全局状态
参数:
plugin: PluginInfo
异常:
IgnoredException: 忽略插件
"""
if not plugin.status and plugin.block_type == BlockType.ALL:
"""全局状态"""
cache = Cache[GroupConsole](CacheType.GROUPS)
if self.group_id and (group := await cache.get(self.group_id)):
if group.is_super:
raise IsSuperuserException()
logger.debug(
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
"AuthChecker",
session=self.session,
)
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):
freq._flmt_s.start_cd(sid)
await MessageUtils.build_message("全局未开启此功能...").send()
raise IgnoredException("全局未开启此功能...")
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
"""插件状态
参数:
plugin: PluginInfo
session: Uninfo
"""
group_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
try:
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
is_poke = isinstance(event, PokeNotifyEvent)
except ImportError:
is_poke = False
user_check = PluginCheck(group_id, session, is_poke)
if group_id:
group_check = GroupCheck(plugin, group_id, session, is_poke)
await group_check.check()
else:
await user_check.check_user(plugin)
await user_check.check_global(plugin)

View File

@ -0,0 +1,2 @@
class IsSuperuserException(Exception):
pass

View File

@ -0,0 +1,41 @@
from zhenxun.configs.config import Config
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.utils.enum import PluginType
from zhenxun.utils.utils import FreqLimiter
base_config = Config.get("hook")
class FreqUtils:
def __init__(self):
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
if check_notice_info_cd is None or check_notice_info_cd < 0:
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
self._flmt = FreqLimiter(check_notice_info_cd)
self._flmt_g = FreqLimiter(check_notice_info_cd)
self._flmt_s = FreqLimiter(check_notice_info_cd)
self._flmt_c = FreqLimiter(check_notice_info_cd)
def is_send_limit_message(
self, plugin: PluginInfo, sid: str, is_poke: bool
) -> bool:
"""是否发送提示消息
参数:
plugin: PluginInfo
sid: 检测键
is_poke: 是否是戳一戳
返回:
bool: 是否发送提示消息
"""
if is_poke:
return False
if not base_config.get("IS_SEND_TIP_MESSAGE"):
return False
if plugin.plugin_type == PluginType.DEPENDANT:
return False
return plugin.module != "ai" if self._flmt_s.check(sid) else False
freq = FreqUtils()

View File

@ -0,0 +1,125 @@
from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import IntegrityError
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import (
CacheType,
GoldHandle,
PluginType,
)
from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.platform import PlatformUtils
from .auth.auth_admin import auth_admin
from .auth.auth_bot import auth_bot
from .auth.auth_cost import auth_cost
from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManage, auth_limit
from .auth.auth_plugin import auth_plugin
from .auth.exception import IsSuperuserException
async def auth(
matcher: Matcher,
event: Event,
bot: Bot,
session: Uninfo,
message: UniMsg,
):
"""权限检查
参数:
matcher: matcher
event: Event
bot: bot
session: Uninfo
message: UniMsg
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
is_ignore = False
cost_gold = 0
try:
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
"""过滤除poke外的notice"""
return
except ImportError:
pass
user_cache = Cache[UserConsole](CacheType.USERS)
if matcher.plugin and (module := matcher.plugin.name):
try:
user = await user_cache.get(session.user.id)
except IntegrityError as e:
logger.debug(
"重复创建用户,已跳过该次权限检查...",
"AuthChecker",
session=session,
e=e,
)
return
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if user and plugin:
if plugin.plugin_type == PluginType.HIDDEN:
logger.debug(
f"插件: {plugin.name}:{plugin.module} "
"为HIDDEN已跳过权限检查..."
)
return
try:
cost_gold = await auth_cost(user, plugin, session)
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
cost_gold = 0
raise IsSuperuserException()
await auth_bot(plugin, bot.self_id)
await auth_group(plugin, session, message)
await auth_admin(plugin, session)
await auth_plugin(plugin, session, event)
await auth_limit(plugin, session)
except IsSuperuserException:
logger.debug(
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
)
except IgnoredException:
is_ignore = True
LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id)
except AssertionError as e:
is_ignore = True
logger.debug("消息无法发送", session=session, e=e)
if cost_gold and user_id:
"""花费金币"""
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
matcher.plugin.name if matcher.plugin else "",
PlatformUtils.get_platform(session),
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
# 更新缓存
await user_cache.update(user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session)
if is_ignore:
raise IgnoredException("权限检测 ignore")

View File

@ -1,18 +1,16 @@
from nonebot.adapters.onebot.v11 import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.message import run_postprocessor, run_preprocessor from nonebot.message import run_postprocessor, run_preprocessor
from nonebot_plugin_alconna import UniMsg from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_session import EventSession from nonebot_plugin_uninfo import Uninfo
from ._auth_checker import LimitManage, checker from .auth_checker import LimitManage, auth
# # 权限检测 # # 权限检测
@run_preprocessor @run_preprocessor
async def _( async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg await auth(
):
await checker.auth(
matcher, matcher,
event, event,
bot, bot,
@ -23,19 +21,16 @@ async def _(
# 解除命令block阻塞 # 解除命令block阻塞
@run_postprocessor @run_postprocessor
async def _( async def _(matcher: Matcher, session: Uninfo):
matcher: Matcher, user_id = session.user.id
exception: Exception | None, group_id = None
bot: Bot, channel_id = None
event: Event, if session.group:
session: EventSession, if session.group.parent:
): group_id = session.group.parent.id
user_id = session.id1 channel_id = session.group.id
group_id = session.id3 else:
channel_id = session.id2 group_id = session.group.id
if not group_id:
group_id = channel_id
channel_id = None
if user_id and matcher.plugin: if user_id and matcher.plugin:
module = matcher.plugin.name module = matcher.plugin.name
LimitManage.unblock(module, user_id, group_id, channel_id) LimitManage.unblock(module, user_id, group_id, channel_id)

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any
import nonebot import nonebot
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -6,9 +7,11 @@ from nonebot.adapters import Bot
from zhenxun.models.ban_console import BanConsole from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache from zhenxun.utils.cache_utils import CacheRoot
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
@ -48,45 +51,131 @@ async def _(bot: Bot):
) )
@Cache.listener(CacheType.PLUGINS) @CacheRoot.new(CacheType.PLUGINS)
async def _(): async def _(data: dict[str, PluginInfo] = {}, key: str | None = None):
data_list = await PluginInfo.get_plugins() if data and key:
return {p.module: p for p in data_list} if plugin := await PluginInfo.get_plugin(module=key):
data[key] = plugin
else:
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@Cache.getter(CacheType.PLUGINS, result_model=PluginInfo) @CacheRoot.updater(CacheType.PLUGINS)
def _(data: dict[str, PluginInfo], module: str): async def _(data: dict[str, PluginInfo], key: str, value: Any):
return data.get(module, None) if value:
data[key] = value
elif plugin := await PluginInfo.get_plugin(module=key):
data[key] = plugin
@Cache.listener(CacheType.GROUPS) @CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
async def _(data: dict[str, PluginInfo], module: str):
result = data.get(module, None)
if not result:
result = await PluginInfo.get_plugin(module=module)
if result:
data[module] = result
return result
@CacheRoot.new(CacheType.GROUPS)
async def _(): async def _():
data_list = await GroupConsole.all() data_list = await GroupConsole.all()
return {p.group_id: p for p in data_list} return {p.group_id: p for p in data_list if not p.channel_id}
@Cache.getter(CacheType.GROUPS, result_model=GroupConsole) @CacheRoot.updater(CacheType.GROUPS)
def _(data: dict[str, GroupConsole], module: str): async def _(data: dict[str, GroupConsole], key: str, value: Any):
return data.get(module, None) if value:
data[key] = value
elif group := await GroupConsole.get_group(group_id=key):
data[key] = group
@Cache.listener(CacheType.BOT) @CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(data: dict[str, GroupConsole], group_id: str):
result = data.get(group_id, None)
if not result:
result = await GroupConsole.get_group(group_id=group_id)
if result:
data[group_id] = result
return result
@CacheRoot.new(CacheType.BOT)
async def _(): async def _():
data_list = await BotConsole.all() data_list = await BotConsole.all()
return {p.bot_id: p for p in data_list} return {p.bot_id: p for p in data_list}
@Cache.getter(CacheType.BOT, result_model=BotConsole) @CacheRoot.updater(CacheType.BOT)
def _(data: dict[str, BotConsole], module: str): async def _(data: dict[str, BotConsole], key: str, value: Any):
return data.get(module, None) if value:
data[key] = value
elif bot := await BotConsole.get_or_none(bot_id=key):
data[key] = bot
@Cache.listener(CacheType.BAN) @CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(data: dict[str, BotConsole], bot_id: str):
result = data.get(bot_id, None)
if not result:
result = await BotConsole.get_or_none(bot_id=bot_id)
if result:
data[bot_id] = result
return result
@CacheRoot.new(CacheType.USERS)
async def _():
data_list = await UserConsole.all()
return {p.user_id: p for p in data_list}
@CacheRoot.updater(CacheType.USERS)
async def _(data: dict[str, UserConsole], key: str, value: Any):
if value:
data[key] = value
elif user := await UserConsole.get_user(user_id=key):
data[key] = user
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
async def _(data: dict[str, UserConsole], user_id: str):
result = data.get(user_id, None)
if not result:
result = await UserConsole.get_user(user_id=user_id)
if result:
data[user_id] = result
return result
@CacheRoot.new(CacheType.LEVEL)
async def _():
return await LevelUser().all()
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
def _(data_list: list[LevelUser], user_id: str, group_id: str | None = None):
if not group_id:
return [
data for data in data_list if data.user_id == user_id and not data.group_id
]
else:
return [
data
for data in data_list
if data.user_id == user_id and data.group_id == group_id
]
@CacheRoot.new(CacheType.BAN)
async def _(): async def _():
return await BanConsole.all() return await BanConsole.all()
@Cache.getter(CacheType.BAN, result_model=list[BanConsole]) @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
def _(data_list: list[BanConsole], user_id: str, group_id: str): def _(data_list: list[BanConsole], user_id: str, group_id: str):
if user_id: if user_id:
if group_id: if group_id:

View File

@ -7,7 +7,8 @@ from tortoise.backends.base.client import BaseDBAsyncClient
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo from zhenxun.models.task_info import TaskInfo
from zhenxun.services.db_context import Model from zhenxun.services.db_context import Model
from zhenxun.utils.enum import PluginType from zhenxun.utils.cache_utils import CacheRoot
from zhenxun.utils.enum import CacheType, PluginType
class GroupConsole(Model): class GroupConsole(Model):
@ -46,8 +47,7 @@ class GroupConsole(Model):
platform = fields.CharField(255, default="qq", description="所属平台") platform = fields.CharField(255, default="qq", description="所属平台")
"""所属平台""" """所属平台"""
class Meta: # type: ignore class Meta: # type: ignore table = "group_console"
table = "group_console"
table_description = "群组信息表" table_description = "群组信息表"
unique_together = ("group_id", "channel_id") unique_together = ("group_id", "channel_id")
@ -80,6 +80,7 @@ class GroupConsole(Model):
return "".join(cls.format(item) for item in data) return "".join(cls.format(item) for item in data)
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def create( async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self: ) -> Self:
@ -100,6 +101,7 @@ class GroupConsole(Model):
return group return group
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def get_or_create( async def get_or_create(
cls, cls,
defaults: dict | None = None, defaults: dict | None = None,
@ -127,6 +129,7 @@ class GroupConsole(Model):
return group, is_create return group, is_create
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def update_or_create( async def update_or_create(
cls, cls,
defaults: dict | None = None, defaults: dict | None = None,
@ -216,6 +219,7 @@ class GroupConsole(Model):
) )
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def set_block_plugin( async def set_block_plugin(
cls, cls,
group_id: str, group_id: str,
@ -242,6 +246,7 @@ class GroupConsole(Model):
await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def set_unblock_plugin( async def set_unblock_plugin(
cls, cls,
group_id: str, group_id: str,
@ -338,6 +343,7 @@ class GroupConsole(Model):
) )
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def set_block_task( async def set_block_task(
cls, cls,
group_id: str, group_id: str,
@ -364,6 +370,7 @@ class GroupConsole(Model):
await group.save(update_fields=["block_task", "superuser_block_task"]) await group.save(update_fields=["block_task", "superuser_block_task"])
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def set_unblock_task( async def set_unblock_task(
cls, cls,
group_id: str, group_id: str,

View File

@ -1,10 +1,13 @@
from typing import Any
from typing_extensions import Self from typing_extensions import Self
from tortoise import fields from tortoise import fields
from tortoise.backends.base.client import BaseDBAsyncClient
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401 from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
from zhenxun.services.db_context import Model from zhenxun.services.db_context import Model
from zhenxun.utils.enum import BlockType, PluginType from zhenxun.utils.cache_utils import CacheRoot
from zhenxun.utils.enum import BlockType, CacheType, PluginType
class PluginInfo(Model): class PluginInfo(Model):
@ -79,6 +82,13 @@ class PluginInfo(Model):
""" """
return await cls.filter(load_status=load_status, **kwargs).all() return await cls.filter(load_status=load_status, **kwargs).all()
@classmethod
@CacheRoot.listener(CacheType.PLUGINS)
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self:
return await super().create(using_db=using_db, **kwargs)
@classmethod @classmethod
async def _run_script(cls): async def _run_script(cls):
return [ return [

View File

@ -1,11 +1,14 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps
import time import time
from typing import Any, ClassVar, Generic, TypeVar, cast from typing import Any, ClassVar, Generic, TypeVar, cast
from nonebot.utils import is_coroutine_callable from nonebot.utils import is_coroutine_callable
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ["Cache", "CacheData"] from zhenxun.services.log import logger
__all__ = ["Cache", "CacheData", "CacheRoot"]
T = TypeVar("T") T = TypeVar("T")
@ -27,10 +30,14 @@ class CacheGetter(BaseModel, Generic[T]):
class CacheData(BaseModel): class CacheData(BaseModel):
name: str
"""缓存名称"""
func: Callable[..., Any] func: Callable[..., Any]
"""更新方法""" """更新方法"""
getter: CacheGetter | None = None getter: CacheGetter | None = None
"""获取方法""" """获取方法"""
updater: Callable[..., Any] | None = None
"""更新单个方法"""
data: Any = None data: Any = None
"""缓存数据""" """缓存数据"""
expire: int expire: int
@ -40,13 +47,36 @@ class CacheData(BaseModel):
reload_count: int = 0 reload_count: int = 0
"""更新次数""" """更新次数"""
async def reload(self): async def get(self, *args, **kwargs) -> Any:
"""获取单个缓存"""
if not self.getter:
return self.data
return await self.getter.get(self.data, *args, **kwargs)
async def update(self, key: str, value: Any = None, *args, **kwargs):
"""更新单个缓存"""
if not self.updater:
return logger.warning(
f"缓存类型 {self.name} 没有更新方法,无法更新", "CacheRoot"
)
if self.data:
if is_coroutine_callable(self.updater):
await self.updater(self.data, key, value, *args, **kwargs)
else:
self.updater(self.data, key, value, *args, **kwargs)
else:
logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot")
async def reload(self, *args, **kwargs):
"""更新缓存""" """更新缓存"""
self.data = ( self.data = (
await self.func() if is_coroutine_callable(self.func) else self.func() await self.func(*args, **kwargs)
if is_coroutine_callable(self.func)
else self.func(*args, **kwargs)
) )
self.reload_time = time.time() self.reload_time = time.time()
self.reload_count += 1 self.reload_count += 1
logger.debug(f"缓存类型 {self.name} 更新全局缓存", "CacheRoot")
async def check_expire(self): async def check_expire(self):
if time.time() - self.reload_time > self.expire or not self.reload_count: if time.time() - self.reload_time > self.expire or not self.reload_count:
@ -54,14 +84,54 @@ class CacheData(BaseModel):
class CacheManage: class CacheManage:
"""全局缓存管理,减少数据库与网络请求查询次数
异常:
ValueError: 数据名称重复
ValueError: 数据不存在
"""
_data: ClassVar[dict[str, CacheData]] = {} _data: ClassVar[dict[str, CacheData]] = {}
def listener(self, name: str, expire: int = 60 * 10): def new(self, name: str, expire: int = 60 * 10):
def wrapper(func: Callable): def wrapper(func: Callable):
_name = name.upper() _name = name.upper()
if _name in self._data: if _name in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 已存在...") raise ValueError(f"DbCache 缓存数据 {name} 已存在...")
self._data[_name] = CacheData(func=func, expire=expire) self._data[_name] = CacheData(name=_name, func=func, expire=expire)
return wrapper
def listener(self, name: str):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
if is_coroutine_callable(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
return result
finally:
cache_data = self._data.get(name.upper())
if cache_data:
await cache_data.reload()
logger.debug(
f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot"
)
return wrapper
return decorator
def updater(self, name: str):
def wrapper(func: Callable):
_name = name.upper()
if _name not in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
self._data[_name].updater = func
return wrapper return wrapper
@ -83,12 +153,12 @@ class CacheManage:
): ):
await self._data[name].reload() await self._data[name].reload()
async def get_cache_data(self, name: str) -> CacheData | None: async def get_cache_data(self, name: str):
if cache := await self.get_cache(name): if cache := await self.get_cache(name):
return cache return cache.data
return None return None
async def get_cache(self, name: str): async def get_cache(self, name: str) -> CacheData | None:
name = name.upper() name = name.upper()
cache = self._data.get(name) cache = self._data.get(name)
if cache: if cache:
@ -96,18 +166,35 @@ class CacheManage:
return cache return cache
return None return None
async def get(self, name: str, *args, **kwargs) -> T | None: async def get(self, name: str, *args, **kwargs):
cache = self._data.get(name.upper()) cache = await self.get_cache(name.upper())
if cache: if cache:
return ( return await cache.get(*args, **kwargs) if cache.getter else cache.data
await cache.getter.get(*args, **kwargs) if cache.getter else cache.data
)
return None return None
async def reload(self, name: str): async def reload(self, name: str, *args, **kwargs):
cache = self._data.get(name.upper()) cache = await self.get_cache(name.upper())
if cache: if cache:
await cache.reload() await cache.reload(*args, **kwargs)
async def update(self, name: str, key: str, value: Any, *args, **kwargs):
cache = await self.get_cache(name.upper())
if cache:
await cache.update(key, value, *args, **kwargs)
Cache = CacheManage() CacheRoot = CacheManage()
class Cache(Generic[T]):
def __init__(self, module: str):
self.module = module
async def get(self, *args, **kwargs) -> T | None:
return await CacheRoot.get(self.module, *args, **kwargs)
async def update(self, key: str, value: Any = None, *args, **kwargs):
return await CacheRoot.update(self.module, key, value, *args, **kwargs)
async def reload(self, key: str | None = None, *args, **kwargs):
await CacheRoot.reload(self.module, key, *args, **kwargs)

View File

@ -10,10 +10,14 @@ class CacheType(StrEnum):
"""全局全部插件""" """全局全部插件"""
GROUPS = "GLOBAL_ALL_GROUPS" GROUPS = "GLOBAL_ALL_GROUPS"
"""全局全部群组""" """全局全部群组"""
USERS = "GLOBAL_ALL_USERS"
"""全部用户"""
BAN = "GLOBAL_ALL_BAN" BAN = "GLOBAL_ALL_BAN"
"""全局ban列表""" """全局ban列表"""
BOT = "GLOBAL_BOT" BOT = "GLOBAL_BOT"
"""全局bot信息""" """全局bot信息"""
LEVEL = "GLOBAL_USER_LEVEL"
"""用户权限"""
class GoldHandle(StrEnum): class GoldHandle(StrEnum):