From 521bcaceeb6dd172d68a1a48d09085ff2799f76a Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Wed, 8 Jan 2025 15:23:10 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E6=9E=84=E5=BB=BA=E7=BC=93?= =?UTF-8?q?=E5=AD=98=EF=BC=8Chook=E4=BD=BF=E7=94=A8=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../builtin_plugins/hooks/_auth_checker.py | 580 ------------------ .../builtin_plugins/hooks/auth/auth_admin.py | 76 +++ .../builtin_plugins/hooks/auth/auth_bot.py | 38 +- .../builtin_plugins/hooks/auth/auth_cost.py | 35 ++ .../builtin_plugins/hooks/auth/auth_group.py | 57 ++ .../builtin_plugins/hooks/auth/auth_limit.py | 189 ++++++ .../builtin_plugins/hooks/auth/auth_plugin.py | 201 ++++++ .../builtin_plugins/hooks/auth/exception.py | 2 + zhenxun/builtin_plugins/hooks/auth/utils.py | 41 ++ zhenxun/builtin_plugins/hooks/auth_checker.py | 125 ++++ zhenxun/builtin_plugins/hooks/auth_hook.py | 35 +- zhenxun/builtin_plugins/init/__init__.py | 127 +++- zhenxun/models/group_console.py | 13 +- zhenxun/models/plugin_info.py | 12 +- zhenxun/utils/cache_utils.py | 121 +++- zhenxun/utils/enum.py | 4 + 16 files changed, 996 insertions(+), 660 deletions(-) delete mode 100644 zhenxun/builtin_plugins/hooks/_auth_checker.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/auth_admin.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/auth_cost.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/auth_group.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/auth_limit.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/auth_plugin.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/exception.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/utils.py create mode 100644 zhenxun/builtin_plugins/hooks/auth_checker.py diff --git a/zhenxun/builtin_plugins/hooks/_auth_checker.py b/zhenxun/builtin_plugins/hooks/_auth_checker.py deleted file mode 100644 index b4ce15fb..00000000 --- a/zhenxun/builtin_plugins/hooks/_auth_checker.py +++ /dev/null @@ -1,580 +0,0 @@ -from nonebot.adapters import Bot, Event -from nonebot.adapters.onebot.v11 import PokeNotifyEvent -from nonebot.exception import IgnoredException -from nonebot.matcher import Matcher -from nonebot_plugin_alconna import At, UniMsg -from nonebot_plugin_session import EventSession -from pydantic import BaseModel -from tortoise.exceptions import IntegrityError - -from zhenxun.configs.config import Config -from zhenxun.models.bot_console import BotConsole -from zhenxun.models.group_console import GroupConsole -from zhenxun.models.level_user import LevelUser -from zhenxun.models.plugin_info import PluginInfo -from zhenxun.models.plugin_limit import PluginLimit -from zhenxun.models.user_console import UserConsole -from zhenxun.services.log import logger -from zhenxun.utils.cache_utils import Cache -from zhenxun.utils.enum import ( - BlockType, - CacheType, - GoldHandle, - LimitWatchType, - PluginLimitType, - PluginType, -) -from zhenxun.utils.exception import InsufficientGold -from zhenxun.utils.message import MessageUtils -from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter - -base_config = Config.get("hook") - - -class Limit(BaseModel): - limit: PluginLimit - limiter: FreqLimiter | UserBlockLimiter | CountLimiter - - class Config: - arbitrary_types_allowed = True - - -class LimitManage: - add_module = [] # noqa: RUF012 - - cd_limit: dict[str, Limit] = {} # noqa: RUF012 - block_limit: dict[str, Limit] = {} # noqa: RUF012 - count_limit: dict[str, Limit] = {} # noqa: RUF012 - - @classmethod - def add_limit(cls, limit: PluginLimit): - """添加限制 - - 参数: - limit: PluginLimit - """ - if limit.module not in cls.add_module: - cls.add_module.append(limit.module) - if limit.limit_type == PluginLimitType.BLOCK: - cls.block_limit[limit.module] = Limit( - limit=limit, limiter=UserBlockLimiter() - ) - elif limit.limit_type == PluginLimitType.CD: - cls.cd_limit[limit.module] = Limit( - limit=limit, limiter=FreqLimiter(limit.cd) - ) - elif limit.limit_type == PluginLimitType.COUNT: - cls.count_limit[limit.module] = Limit( - limit=limit, limiter=CountLimiter(limit.max_count) - ) - - @classmethod - def unblock( - cls, module: str, user_id: str, group_id: str | None, channel_id: str | None - ): - """解除插件block - - 参数: - module: 模块名 - user_id: 用户id - group_id: 群组id - channel_id: 频道id - """ - if limit_model := cls.block_limit.get(module): - limit = limit_model.limit - limiter: UserBlockLimiter = limit_model.limiter # type: ignore - key_type = user_id - if group_id and limit.watch_type == LimitWatchType.GROUP: - key_type = channel_id or group_id - logger.debug( - f"解除对象: {key_type} 的block限制", - "AuthChecker", - session=user_id, - group_id=group_id, - ) - limiter.set_false(key_type) - - @classmethod - async def check( - cls, - module: str, - user_id: str, - group_id: str | None, - channel_id: str | None, - session: EventSession, - ): - """检测限制 - - 参数: - module: 模块名 - user_id: 用户id - group_id: 群组id - channel_id: 频道id - session: Session - - 异常: - IgnoredException: IgnoredException - """ - if limit_model := cls.cd_limit.get(module): - await cls.__check(limit_model, user_id, group_id, channel_id, session) - if limit_model := cls.block_limit.get(module): - await cls.__check(limit_model, user_id, group_id, channel_id, session) - if limit_model := cls.count_limit.get(module): - await cls.__check(limit_model, user_id, group_id, channel_id, session) - - @classmethod - async def __check( - cls, - limit_model: Limit | None, - user_id: str, - group_id: str | None, - channel_id: str | None, - session: EventSession, - ): - """检测限制 - - 参数: - limit_model: Limit - user_id: 用户id - group_id: 群组id - channel_id: 频道id - session: Session - - 异常: - IgnoredException: IgnoredException - """ - if not limit_model: - return - limit = limit_model.limit - limiter = limit_model.limiter - is_limit = ( - LimitWatchType.ALL - or (group_id and limit.watch_type == LimitWatchType.GROUP) - or (not group_id and limit.watch_type == LimitWatchType.USER) - ) - key_type = user_id - if group_id and limit.watch_type == LimitWatchType.GROUP: - key_type = channel_id or group_id - if is_limit and not limiter.check(key_type): - if limit.result: - await MessageUtils.build_message(limit.result).send() - logger.debug( - f"{limit.module}({limit.limit_type}) 正在限制中...", - "AuthChecker", - session=session, - ) - raise IgnoredException(f"{limit.module} 正在限制中...") - else: - logger.debug( - f"开始进行限制 {limit.module}({limit.limit_type})...", - "AuthChecker", - session=user_id, - group_id=group_id, - ) - if isinstance(limiter, FreqLimiter): - limiter.start_cd(key_type) - if isinstance(limiter, UserBlockLimiter): - limiter.set_true(key_type) - if isinstance(limiter, CountLimiter): - limiter.increase(key_type) - - -class IsSuperuserException(Exception): - pass - - -class AuthChecker: - """ - 权限检查 - """ - - def __init__(self): - check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD") - if check_notice_info_cd is None or check_notice_info_cd < 0: - raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0") - self._flmt = FreqLimiter(check_notice_info_cd) - self._flmt_g = FreqLimiter(check_notice_info_cd) - self._flmt_s = FreqLimiter(check_notice_info_cd) - self._flmt_c = FreqLimiter(check_notice_info_cd) - - def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool: - """是否发送提示消息 - - 参数: - plugin: PluginInfo - - 返回: - bool: 是否发送提示消息 - """ - if not base_config.get("IS_SEND_TIP_MESSAGE"): - return False - if plugin.plugin_type == PluginType.DEPENDANT: - return False - return plugin.module != "ai" if self._flmt_s.check(sid) else False - - async def auth( - self, - matcher: Matcher, - event: Event, - bot: Bot, - session: EventSession, - message: UniMsg, - ): - """权限检查 - - 参数: - matcher: matcher - bot: bot - session: EventSession - message: UniMsg - """ - is_ignore = False - cost_gold = 0 - user_id = session.id1 - group_id = session.id3 - channel_id = session.id2 - if not group_id: - group_id = channel_id - channel_id = None - if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent): - """过滤除poke外的notice""" - return - if user_id and matcher.plugin and (module_path := matcher.plugin.module_name): - try: - user = await UserConsole.get_user(user_id, session.platform) - except IntegrityError as e: - logger.debug( - "重复创建用户,已跳过该次权限...", - "AuthChecker", - session=session, - e=e, - ) - return - if plugin := await Cache.get(CacheType.PLUGINS, module_path): - if plugin.plugin_type == PluginType.HIDDEN: - logger.debug( - f"插件: {plugin.name}:{plugin.module} " - "为HIDDEN,已跳过权限检查..." - ) - return - try: - cost_gold = await self.auth_cost(user, plugin, session) - if session.id1 in bot.config.superusers: - if plugin.plugin_type == PluginType.SUPERUSER: - raise IsSuperuserException() - if not plugin.limit_superuser: - cost_gold = 0 - raise IsSuperuserException() - await self.auth_bot(plugin, bot.self_id) - await self.auth_group(plugin, session, message) - await self.auth_admin(plugin, session) - await self.auth_plugin(plugin, session, event) - await self.auth_limit(plugin, session) - except IsSuperuserException: - logger.debug( - "超级用户或被ban跳过权限检测...", "AuthChecker", session=session - ) - except IgnoredException: - is_ignore = True - LimitManage.unblock( - matcher.plugin.name, user_id, group_id, channel_id - ) - except AssertionError as e: - is_ignore = True - logger.debug("消息无法发送", session=session, e=e) - if cost_gold and user_id: - """花费金币""" - try: - await UserConsole.reduce_gold( - user_id, - cost_gold, - GoldHandle.PLUGIN, - matcher.plugin.name if matcher.plugin else "", - session.platform, - ) - except InsufficientGold: - if u := await UserConsole.get_user(user_id): - u.gold = 0 - await u.save(update_fields=["gold"]) - logger.debug( - f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session - ) - if is_ignore: - raise IgnoredException("权限检测 ignore") - - async def auth_bot(self, plugin: PluginInfo, bot_id: str): - """机器人权限 - - 参数: - plugin: PluginInfo - bot_id: bot_id - """ - if not await BotConsole.get_bot_status(bot_id): - logger.debug("Bot休眠中阻断权限检测...", "AuthChecker") - raise IgnoredException("BotConsole休眠权限检测 ignore") - if await BotConsole.is_block_plugin(bot_id, plugin.module): - logger.debug( - f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...", - "AuthChecker", - ) - raise IgnoredException("BotConsole插件权限检测 ignore") - - async def auth_limit(self, plugin: PluginInfo, session: EventSession): - """插件限制 - - 参数: - plugin: PluginInfo - session: EventSession - """ - user_id = session.id1 - group_id = session.id3 - channel_id = session.id2 - if not group_id: - group_id = channel_id - channel_id = None - if plugin.module not in LimitManage.add_module: - limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( - status=True - ).all() # type: ignore - for limit in limit_list: - LimitManage.add_limit(limit) - if user_id: - await LimitManage.check( - plugin.module, user_id, group_id, channel_id, session - ) - - async def auth_plugin( - self, plugin: PluginInfo, session: EventSession, event: Event - ): - """插件状态 - - 参数: - plugin: PluginInfo - session: EventSession - """ - group_id = session.id3 - channel_id = session.id2 - if not group_id: - group_id = channel_id - channel_id = None - if user_id := session.id1: - is_poke = isinstance(event, PokeNotifyEvent) - if group_id: - sid = group_id or user_id - if await GroupConsole.is_superuser_block_plugin( - group_id, plugin.module - ): - """超级用户群组插件状态""" - if self.is_send_limit_message(plugin, sid) and not is_poke: - self._flmt_s.start_cd(group_id or user_id) - await MessageUtils.build_message( - "超级管理员禁用了该群此功能..." - ).send(reply_to=True) - logger.debug( - f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...", - "AuthChecker", - session=session, - ) - raise IgnoredException("超级管理员禁用了该群此功能...") - if await GroupConsole.is_normal_block_plugin(group_id, plugin.module): - """群组插件状态""" - if self.is_send_limit_message(plugin, sid) and not is_poke: - self._flmt_s.start_cd(group_id or user_id) - await MessageUtils.build_message("该群未开启此功能...").send( - reply_to=True - ) - logger.debug( - f"{plugin.name}({plugin.module}) 未开启此功能...", - "AuthChecker", - session=session, - ) - raise IgnoredException("该群未开启此功能...") - if plugin.block_type == BlockType.GROUP: - """全局群组禁用""" - try: - if self.is_send_limit_message(plugin, sid) and not is_poke: - self._flmt_c.start_cd(group_id) - await MessageUtils.build_message( - "该功能在群组中已被禁用..." - ).send(reply_to=True) - except Exception as e: - logger.error( - "auth_plugin 发送消息失败", - "AuthChecker", - session=session, - e=e, - ) - logger.debug( - f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...", - "AuthChecker", - session=session, - ) - raise IgnoredException("该插件在群组中已被禁用...") - else: - sid = user_id - if plugin.block_type == BlockType.PRIVATE: - """全局私聊禁用""" - try: - if self.is_send_limit_message(plugin, sid) and not is_poke: - self._flmt_c.start_cd(user_id) - await MessageUtils.build_message( - "该功能在私聊中已被禁用..." - ).send() - except Exception as e: - logger.error( - "auth_admin 发送消息失败", - "AuthChecker", - session=session, - e=e, - ) - logger.debug( - f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...", - "AuthChecker", - session=session, - ) - raise IgnoredException("该插件在私聊中已被禁用...") - if not plugin.status and plugin.block_type == BlockType.ALL: - """全局状态""" - if group_id and await GroupConsole.is_super_group(group_id): - raise IsSuperuserException() - logger.debug( - f"{plugin.name}({plugin.module}) 全局未开启此功能...", - "AuthChecker", - session=session, - ) - if self.is_send_limit_message(plugin, sid) and not is_poke: - self._flmt_s.start_cd(group_id or user_id) - await MessageUtils.build_message("全局未开启此功能...").send() - raise IgnoredException("全局未开启此功能...") - - async def auth_admin(self, plugin: PluginInfo, session: EventSession): - """管理员命令 个人权限 - - 参数: - plugin: PluginInfo - session: EventSession - """ - user_id = session.id1 - if user_id and plugin.admin_level: - if group_id := session.id3 or session.id2: - if not await LevelUser.check_level( - user_id, group_id, plugin.admin_level - ): - try: - if self._flmt.check(user_id): - self._flmt.start_cd(user_id) - await MessageUtils.build_message( - [ - At(flag="user", target=user_id), - f"你的权限不足喔," - f"该功能需要的权限等级: {plugin.admin_level}", - ] - ).send(reply_to=True) - except Exception as e: - logger.error( - "auth_admin 发送消息失败", - "AuthChecker", - session=session, - e=e, - ) - logger.debug( - f"{plugin.name}({plugin.module}) 管理员权限不足...", - "AuthChecker", - session=session, - ) - raise IgnoredException("管理员权限不足...") - elif not await LevelUser.check_level(user_id, None, plugin.admin_level): - try: - await MessageUtils.build_message( - f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}" - ).send() - except Exception as e: - logger.error( - "auth_admin 发送消息失败", "AuthChecker", session=session, e=e - ) - logger.debug( - f"{plugin.name}({plugin.module}) 管理员权限不足...", - "AuthChecker", - session=session, - ) - raise IgnoredException("权限不足") - - async def auth_group( - self, plugin: PluginInfo, session: EventSession, message: UniMsg - ): - """群黑名单检测 群总开关检测 - - 参数: - plugin: PluginInfo - session: EventSession - message: UniMsg - """ - if not (group_id := session.id3 or session.id2): - return - text = message.extract_plain_text() - group = await GroupConsole.get_group(group_id) - if not group: - """群不存在""" - logger.debug( - "群组信息不存在...", - "AuthChecker", - session=session, - ) - raise IgnoredException("群不存在") - if group.level < 0: - """群权限小于0""" - logger.debug( - "群黑名单, 群权限-1...", - "AuthChecker", - session=session, - ) - raise IgnoredException("群黑名单") - if not group.status: - """群休眠""" - if text.strip() != "醒来": - logger.debug("群休眠状态...", "AuthChecker", session=session) - raise IgnoredException("群休眠状态") - if plugin.level > group.level: - """插件等级大于群等级""" - logger.debug( - f"{plugin.name}({plugin.module}) 群等级限制.." - f"该功能需要的群等级: {plugin.level}..", - "AuthChecker", - session=session, - ) - raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...") - - async def auth_cost( - self, user: UserConsole, plugin: PluginInfo, session: EventSession - ) -> int: - """检测是否满足金币条件 - - 参数: - user: UserConsole - plugin: PluginInfo - session: EventSession - - 返回: - int: 需要消耗的金币 - """ - if user.gold < plugin.cost_gold: - """插件消耗金币不足""" - try: - await MessageUtils.build_message( - f"金币不足..该功能需要{plugin.cost_gold}金币.." - ).send() - except Exception as e: - logger.error( - "auth_cost 发送消息失败", "AuthChecker", session=session, e=e - ) - logger.debug( - f"{plugin.name}({plugin.module}) 金币限制.." - f"该功能需要{plugin.cost_gold}金币..", - "AuthChecker", - session=session, - ) - raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...") - return plugin.cost_gold - - -checker = AuthChecker() diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py new file mode 100644 index 00000000..294d1c3d --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py @@ -0,0 +1,76 @@ +from nonebot.exception import IgnoredException +from nonebot_plugin_alconna import At +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.models.level_user import LevelUser +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.services.log import logger +from zhenxun.utils.cache_utils import Cache +from zhenxun.utils.enum import CacheType +from zhenxun.utils.message import MessageUtils + +from .utils import freq + + +async def auth_admin(plugin: PluginInfo, session: Uninfo): + """管理员命令 个人权限 + + 参数: + plugin: PluginInfo + session: PluginInfo + """ + group_id = None + cache = Cache[list[LevelUser]](CacheType.LEVEL) + user_level = await cache.get(session.user.id) or [] + if session.group: + if session.group.parent: + group_id = session.group.parent.id + else: + group_id = session.group.id + + if not plugin.admin_level: + return + if group_id: + user_level += await cache.get(session.user.id, group_id) or [] + user = max(user_level, key=lambda x: x.user_level) + if user.user_level < plugin.admin_level: + try: + if freq._flmt.check(session.user.id): + freq._flmt.start_cd(session.user.id) + await MessageUtils.build_message( + [ + At(flag="user", target=session.user.id), + "你的权限不足喔," + f"该功能需要的权限等级: {plugin.admin_level}", + ] + ).send(reply_to=True) + except Exception as e: + logger.error( + "auth_admin 发送消息失败", + "AuthChecker", + session=session, + e=e, + ) + logger.debug( + f"{plugin.name}({plugin.module}) 管理员权限不足...", + "AuthChecker", + session=session, + ) + raise IgnoredException("管理员权限不足...") + elif user_level: + user = max(user_level, key=lambda x: x.user_level) + if user.user_level < plugin.admin_level: + try: + await MessageUtils.build_message( + f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}" + ).send() + except Exception as e: + logger.error( + "auth_admin 发送消息失败", "AuthChecker", session=session, e=e + ) + logger.debug( + f"{plugin.name}({plugin.module}) 管理员权限不足...", + "AuthChecker", + session=session, + ) + raise IgnoredException("权限不足") diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py index 76fd109c..7afbf5a0 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py @@ -4,31 +4,29 @@ from zhenxun.models.bot_console import BotConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.log import logger from zhenxun.utils.cache_utils import Cache +from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.enum import CacheType -async def get_bot_status(bot_id: str): - a = await Cache.get(CacheType.BOT, bot_id) - cache_data = await Cache.get_cache(CacheType.BOT) - if cache_data and cache_data.getter: - b = await cache_data.getter.get(cache_data.data) - if bot := await Cache.get(CacheType.BOT, bot_id): - return bot - - async def auth_bot(plugin: PluginInfo, bot_id: str): - """机器人权限 + """bot层面的权限检查 参数: plugin: PluginInfo - bot_id: bot_id + bot_id: bot id + + 异常: + IgnoredException: 忽略插件 + IgnoredException: 忽略插件 """ - if not await BotConsole.get_bot_status(bot_id): - logger.debug("Bot休眠中阻断权限检测...", "AuthChecker") - raise IgnoredException("BotConsole休眠权限检测 ignore") - if await BotConsole.is_block_plugin(bot_id, plugin.module): - logger.debug( - f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...", - "AuthChecker", - ) - raise IgnoredException("BotConsole插件权限检测 ignore") + if cache := Cache[BotConsole](CacheType.BOT): + bot = await cache.get(bot_id) + if not bot or not bot.status: + logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker") + raise IgnoredException("BotConsole休眠权限检测 ignore") + if CommonUtils.format(plugin.module) in bot.block_plugins: + logger.debug( + f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...", + "AuthChecker", + ) + raise IgnoredException("BotConsole插件权限检测 ignore") diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_cost.py b/zhenxun/builtin_plugins/hooks/auth/auth_cost.py new file mode 100644 index 00000000..9a5e9a48 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/auth_cost.py @@ -0,0 +1,35 @@ +from nonebot.exception import IgnoredException +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.user_console import UserConsole +from zhenxun.services.log import logger +from zhenxun.utils.message import MessageUtils + + +async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int: + """检测是否满足金币条件 + + 参数: + plugin: PluginInfo + session: Uninfo + + 返回: + int: 需要消耗的金币 + """ + if user.gold < plugin.cost_gold: + """插件消耗金币不足""" + try: + await MessageUtils.build_message( + f"金币不足..该功能需要{plugin.cost_gold}金币.." + ).send() + except Exception as e: + logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e) + logger.debug( + f"{plugin.name}({plugin.module}) 金币限制.." + f"该功能需要{plugin.cost_gold}金币..", + "AuthChecker", + session=session, + ) + raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...") + return plugin.cost_gold diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_group.py b/zhenxun/builtin_plugins/hooks/auth/auth_group.py new file mode 100644 index 00000000..81ce7565 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/auth_group.py @@ -0,0 +1,57 @@ +from nonebot.exception import IgnoredException +from nonebot_plugin_alconna import UniMsg +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.models.group_console import GroupConsole +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.services.log import logger +from zhenxun.utils.cache_utils import Cache +from zhenxun.utils.enum import CacheType + + +async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg): + """群黑名单检测 群总开关检测 + + 参数: + plugin: PluginInfo + session: EventSession + message: UniMsg + """ + if not session.group: + return + if session.group.parent: + group_id = session.group.parent.id + else: + group_id = session.group.id + text = message.extract_plain_text() + group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id) + if not group: + """群不存在""" + logger.debug( + "群组信息不存在...", + "AuthChecker", + session=session, + ) + raise IgnoredException("群不存在") + if group.level < 0: + """群权限小于0""" + logger.debug( + "群黑名单, 群权限-1...", + "AuthChecker", + session=session, + ) + raise IgnoredException("群黑名单") + if not group.status: + """群休眠""" + if text.strip() != "醒来": + logger.debug("群休眠状态...", "AuthChecker", session=session) + raise IgnoredException("群休眠状态") + if plugin.level > group.level: + """插件等级大于群等级""" + logger.debug( + f"{plugin.name}({plugin.module}) 群等级限制.." + f"该功能需要的群等级: {plugin.level}..", + "AuthChecker", + session=session, + ) + raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...") diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py new file mode 100644 index 00000000..0da55b72 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py @@ -0,0 +1,189 @@ +from typing import ClassVar + +from nonebot.exception import IgnoredException +from nonebot_plugin_uninfo import Uninfo +from pydantic import BaseModel + +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.plugin_limit import PluginLimit +from zhenxun.services.log import logger +from zhenxun.utils.enum import LimitWatchType, PluginLimitType +from zhenxun.utils.message import MessageUtils +from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter + + +class Limit(BaseModel): + limit: PluginLimit + limiter: FreqLimiter | UserBlockLimiter | CountLimiter + + class Config: + arbitrary_types_allowed = True + + +class LimitManage: + add_module: ClassVar[list] = [] + + cd_limit: ClassVar[dict[str, Limit]] = {} + block_limit: ClassVar[dict[str, Limit]] = {} + count_limit: ClassVar[dict[str, Limit]] = {} + + @classmethod + def add_limit(cls, limit: PluginLimit): + """添加限制 + + 参数: + limit: PluginLimit + """ + if limit.module not in cls.add_module: + cls.add_module.append(limit.module) + if limit.limit_type == PluginLimitType.BLOCK: + cls.block_limit[limit.module] = Limit( + limit=limit, limiter=UserBlockLimiter() + ) + elif limit.limit_type == PluginLimitType.CD: + cls.cd_limit[limit.module] = Limit( + limit=limit, limiter=FreqLimiter(limit.cd) + ) + elif limit.limit_type == PluginLimitType.COUNT: + cls.count_limit[limit.module] = Limit( + limit=limit, limiter=CountLimiter(limit.max_count) + ) + + @classmethod + def unblock( + cls, module: str, user_id: str, group_id: str | None, channel_id: str | None + ): + """解除插件block + + 参数: + module: 模块名 + user_id: 用户id + group_id: 群组id + channel_id: 频道id + """ + if limit_model := cls.block_limit.get(module): + limit = limit_model.limit + limiter: UserBlockLimiter = limit_model.limiter # type: ignore + key_type = user_id + if group_id and limit.watch_type == LimitWatchType.GROUP: + key_type = channel_id or group_id + logger.debug( + f"解除对象: {key_type} 的block限制", + "AuthChecker", + session=user_id, + group_id=group_id, + ) + limiter.set_false(key_type) + + @classmethod + async def check( + cls, + module: str, + user_id: str, + group_id: str | None, + channel_id: str | None, + session: Uninfo, + ): + """检测限制 + + 参数: + module: 模块名 + user_id: 用户id + group_id: 群组id + channel_id: 频道id + session: Session + + 异常: + IgnoredException: IgnoredException + """ + if limit_model := cls.cd_limit.get(module): + await cls.__check(limit_model, user_id, group_id, channel_id, session) + if limit_model := cls.block_limit.get(module): + await cls.__check(limit_model, user_id, group_id, channel_id, session) + if limit_model := cls.count_limit.get(module): + await cls.__check(limit_model, user_id, group_id, channel_id, session) + + @classmethod + async def __check( + cls, + limit_model: Limit | None, + user_id: str, + group_id: str | None, + channel_id: str | None, + session: Uninfo, + ): + """检测限制 + + 参数: + limit_model: Limit + user_id: 用户id + group_id: 群组id + channel_id: 频道id + session: Session + + 异常: + IgnoredException: IgnoredException + """ + if not limit_model: + return + limit = limit_model.limit + limiter = limit_model.limiter + is_limit = ( + LimitWatchType.ALL + or (group_id and limit.watch_type == LimitWatchType.GROUP) + or (not group_id and limit.watch_type == LimitWatchType.USER) + ) + key_type = user_id + if group_id and limit.watch_type == LimitWatchType.GROUP: + key_type = channel_id or group_id + if is_limit and not limiter.check(key_type): + if limit.result: + await MessageUtils.build_message(limit.result).send() + logger.debug( + f"{limit.module}({limit.limit_type}) 正在限制中...", + "AuthChecker", + session=session, + ) + raise IgnoredException(f"{limit.module} 正在限制中...") + else: + logger.debug( + f"开始进行限制 {limit.module}({limit.limit_type})...", + "AuthChecker", + session=user_id, + group_id=group_id, + ) + if isinstance(limiter, FreqLimiter): + limiter.start_cd(key_type) + if isinstance(limiter, UserBlockLimiter): + limiter.set_true(key_type) + if isinstance(limiter, CountLimiter): + limiter.increase(key_type) + + +async def auth_limit(plugin: PluginInfo, session: Uninfo): + """插件限制 + + 参数: + plugin: PluginInfo + session: Uninfo + """ + user_id = session.user.id + group_id = None + channel_id = None + if session.group: + if session.group.parent: + group_id = session.group.parent.id + channel_id = session.group.id + else: + group_id = session.group.id + if not group_id: + group_id = channel_id + channel_id = None + if plugin.module not in LimitManage.add_module: + limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( + status=True + ).all() # type: ignore + for limit in limit_list: + LimitManage.add_limit(limit) + if user_id: + await LimitManage.check(plugin.module, user_id, group_id, channel_id, session) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py new file mode 100644 index 00000000..f014c1cd --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py @@ -0,0 +1,201 @@ +from nonebot.adapters import Event +from nonebot.exception import IgnoredException +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.models.group_console import GroupConsole +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.services.log import logger +from zhenxun.utils.cache_utils import Cache +from zhenxun.utils.common_utils import CommonUtils +from zhenxun.utils.enum import BlockType, CacheType +from zhenxun.utils.message import MessageUtils + +from .exception import IsSuperuserException +from .utils import freq + + +class GroupCheck: + def __init__( + self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool + ) -> None: + self.group_id = group_id + self.session = session + self.is_poke = is_poke + self.plugin = plugin + + async def __get_data(self): + cache = Cache[GroupConsole](CacheType.GROUPS) + return await cache.get(self.group_id) + + async def check(self): + await self.check_superuser_block(self.plugin) + + async def check_superuser_block(self, plugin: PluginInfo): + """超级用户禁用群组插件检测 + + 参数: + plugin: PluginInfo + + 异常: + IgnoredException: 忽略插件 + """ + group = await self.__get_data() + if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin: + if freq.is_send_limit_message(plugin, group.group_id, self.is_poke): + freq._flmt_s.start_cd(self.group_id) + await MessageUtils.build_message("超级管理员禁用了该群此功能...").send( + reply_to=True + ) + logger.debug( + f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...", + "AuthChecker", + session=self.session, + ) + raise IgnoredException("超级管理员禁用了该群此功能...") + await self.check_normal_block(self.plugin) + + async def check_normal_block(self, plugin: PluginInfo): + """群组插件状态 + + 参数: + plugin: PluginInfo + + 异常: + IgnoredException: 忽略插件 + """ + group = await self.__get_data() + if group and CommonUtils.format(plugin.module) in group.block_plugin: + if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): + freq._flmt_s.start_cd(self.group_id) + await MessageUtils.build_message("该群未开启此功能...").send( + reply_to=True + ) + logger.debug( + f"{plugin.name}({plugin.module}) 未开启此功能...", + "AuthChecker", + session=self.session, + ) + raise IgnoredException("该群未开启此功能...") + await self.check_global_block(self.plugin) + + async def check_global_block(self, plugin: PluginInfo): + """全局禁用插件检测 + + 参数: + plugin: PluginInfo + + 异常: + IgnoredException: 忽略插件 + """ + if plugin.block_type == BlockType.GROUP: + """全局群组禁用""" + try: + if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): + freq._flmt_c.start_cd(self.group_id) + await MessageUtils.build_message("该功能在群组中已被禁用...").send( + reply_to=True + ) + except Exception as e: + logger.error( + "auth_plugin 发送消息失败", + "AuthChecker", + session=self.session, + e=e, + ) + logger.debug( + f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...", + "AuthChecker", + session=self.session, + ) + raise IgnoredException("该插件在群组中已被禁用...") + + +class PluginCheck: + def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool): + self.session = session + self.is_poke = is_poke + self.group_id = group_id + + async def check_user(self, plugin: PluginInfo): + """全局私聊禁用检测 + + 参数: + plugin: PluginInfo + + 异常: + IgnoredException: 忽略插件 + """ + if plugin.block_type == BlockType.PRIVATE: + try: + if freq.is_send_limit_message( + plugin, self.session.user.id, self.is_poke + ): + freq._flmt_c.start_cd(self.session.user.id) + await MessageUtils.build_message("该功能在私聊中已被禁用...").send() + except Exception as e: + logger.error( + "auth_admin 发送消息失败", + "AuthChecker", + session=self.session, + e=e, + ) + logger.debug( + f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...", + "AuthChecker", + session=self.session, + ) + raise IgnoredException("该插件在私聊中已被禁用...") + + async def check_global(self, plugin: PluginInfo): + """全局状态 + + 参数: + plugin: PluginInfo + + 异常: + IgnoredException: 忽略插件 + """ + if not plugin.status and plugin.block_type == BlockType.ALL: + """全局状态""" + cache = Cache[GroupConsole](CacheType.GROUPS) + if self.group_id and (group := await cache.get(self.group_id)): + if group.is_super: + raise IsSuperuserException() + logger.debug( + f"{plugin.name}({plugin.module}) 全局未开启此功能...", + "AuthChecker", + session=self.session, + ) + sid = self.group_id or self.session.user.id + if freq.is_send_limit_message(plugin, sid, self.is_poke): + freq._flmt_s.start_cd(sid) + await MessageUtils.build_message("全局未开启此功能...").send() + raise IgnoredException("全局未开启此功能...") + + +async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event): + """插件状态 + + 参数: + plugin: PluginInfo + session: Uninfo + """ + group_id = None + if session.group: + if session.group.parent: + group_id = session.group.parent.id + else: + group_id = session.group.id + try: + from nonebot.adapters.onebot.v11 import PokeNotifyEvent + + is_poke = isinstance(event, PokeNotifyEvent) + except ImportError: + is_poke = False + user_check = PluginCheck(group_id, session, is_poke) + if group_id: + group_check = GroupCheck(plugin, group_id, session, is_poke) + await group_check.check() + else: + await user_check.check_user(plugin) + await user_check.check_global(plugin) diff --git a/zhenxun/builtin_plugins/hooks/auth/exception.py b/zhenxun/builtin_plugins/hooks/auth/exception.py new file mode 100644 index 00000000..cc0d3fde --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/exception.py @@ -0,0 +1,2 @@ +class IsSuperuserException(Exception): + pass diff --git a/zhenxun/builtin_plugins/hooks/auth/utils.py b/zhenxun/builtin_plugins/hooks/auth/utils.py new file mode 100644 index 00000000..5ce3c7fa --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/utils.py @@ -0,0 +1,41 @@ +from zhenxun.configs.config import Config +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.utils.enum import PluginType +from zhenxun.utils.utils import FreqLimiter + +base_config = Config.get("hook") + + +class FreqUtils: + def __init__(self): + check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD") + if check_notice_info_cd is None or check_notice_info_cd < 0: + raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0") + self._flmt = FreqLimiter(check_notice_info_cd) + self._flmt_g = FreqLimiter(check_notice_info_cd) + self._flmt_s = FreqLimiter(check_notice_info_cd) + self._flmt_c = FreqLimiter(check_notice_info_cd) + + def is_send_limit_message( + self, plugin: PluginInfo, sid: str, is_poke: bool + ) -> bool: + """是否发送提示消息 + + 参数: + plugin: PluginInfo + sid: 检测键 + is_poke: 是否是戳一戳 + + 返回: + bool: 是否发送提示消息 + """ + if is_poke: + return False + if not base_config.get("IS_SEND_TIP_MESSAGE"): + return False + if plugin.plugin_type == PluginType.DEPENDANT: + return False + return plugin.module != "ai" if self._flmt_s.check(sid) else False + + +freq = FreqUtils() diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py new file mode 100644 index 00000000..c65ea1a9 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -0,0 +1,125 @@ +from nonebot.adapters import Bot, Event +from nonebot.exception import IgnoredException +from nonebot.matcher import Matcher +from nonebot_plugin_alconna import UniMsg +from nonebot_plugin_uninfo import Uninfo +from tortoise.exceptions import IntegrityError + +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.user_console import UserConsole +from zhenxun.services.log import logger +from zhenxun.utils.cache_utils import Cache +from zhenxun.utils.enum import ( + CacheType, + GoldHandle, + PluginType, +) +from zhenxun.utils.exception import InsufficientGold +from zhenxun.utils.platform import PlatformUtils + +from .auth.auth_admin import auth_admin +from .auth.auth_bot import auth_bot +from .auth.auth_cost import auth_cost +from .auth.auth_group import auth_group +from .auth.auth_limit import LimitManage, auth_limit +from .auth.auth_plugin import auth_plugin +from .auth.exception import IsSuperuserException + + +async def auth( + matcher: Matcher, + event: Event, + bot: Bot, + session: Uninfo, + message: UniMsg, +): + """权限检查 + + 参数: + matcher: matcher + event: Event + bot: bot + session: Uninfo + message: UniMsg + """ + user_id = session.user.id + group_id = None + channel_id = None + if session.group: + if session.group.parent: + group_id = session.group.parent.id + channel_id = session.group.id + else: + group_id = session.group.id + is_ignore = False + cost_gold = 0 + try: + from nonebot.adapters.onebot.v11 import PokeNotifyEvent + + if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent): + """过滤除poke外的notice""" + return + except ImportError: + pass + user_cache = Cache[UserConsole](CacheType.USERS) + if matcher.plugin and (module := matcher.plugin.name): + try: + user = await user_cache.get(session.user.id) + except IntegrityError as e: + logger.debug( + "重复创建用户,已跳过该次权限检查...", + "AuthChecker", + session=session, + e=e, + ) + return + plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) + if user and plugin: + if plugin.plugin_type == PluginType.HIDDEN: + logger.debug( + f"插件: {plugin.name}:{plugin.module} " + "为HIDDEN,已跳过权限检查..." + ) + return + try: + cost_gold = await auth_cost(user, plugin, session) + if session.user.id in bot.config.superusers: + if plugin.plugin_type == PluginType.SUPERUSER: + raise IsSuperuserException() + if not plugin.limit_superuser: + cost_gold = 0 + raise IsSuperuserException() + await auth_bot(plugin, bot.self_id) + await auth_group(plugin, session, message) + await auth_admin(plugin, session) + await auth_plugin(plugin, session, event) + await auth_limit(plugin, session) + except IsSuperuserException: + logger.debug( + "超级用户或被ban跳过权限检测...", "AuthChecker", session=session + ) + except IgnoredException: + is_ignore = True + LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id) + except AssertionError as e: + is_ignore = True + logger.debug("消息无法发送", session=session, e=e) + if cost_gold and user_id: + """花费金币""" + try: + await UserConsole.reduce_gold( + user_id, + cost_gold, + GoldHandle.PLUGIN, + matcher.plugin.name if matcher.plugin else "", + PlatformUtils.get_platform(session), + ) + except InsufficientGold: + if u := await UserConsole.get_user(user_id): + u.gold = 0 + await u.save(update_fields=["gold"]) + # 更新缓存 + await user_cache.update(user_id) + logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session) + if is_ignore: + raise IgnoredException("权限检测 ignore") diff --git a/zhenxun/builtin_plugins/hooks/auth_hook.py b/zhenxun/builtin_plugins/hooks/auth_hook.py index 0ccca75c..b5553b9b 100644 --- a/zhenxun/builtin_plugins/hooks/auth_hook.py +++ b/zhenxun/builtin_plugins/hooks/auth_hook.py @@ -1,18 +1,16 @@ -from nonebot.adapters.onebot.v11 import Bot, Event +from nonebot.adapters import Bot, Event from nonebot.matcher import Matcher from nonebot.message import run_postprocessor, run_preprocessor from nonebot_plugin_alconna import UniMsg -from nonebot_plugin_session import EventSession +from nonebot_plugin_uninfo import Uninfo -from ._auth_checker import LimitManage, checker +from .auth_checker import LimitManage, auth # # 权限检测 @run_preprocessor -async def _( - matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg -): - await checker.auth( +async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg): + await auth( matcher, event, bot, @@ -23,19 +21,16 @@ async def _( # 解除命令block阻塞 @run_postprocessor -async def _( - matcher: Matcher, - exception: Exception | None, - bot: Bot, - event: Event, - session: EventSession, -): - user_id = session.id1 - group_id = session.id3 - channel_id = session.id2 - if not group_id: - group_id = channel_id - channel_id = None +async def _(matcher: Matcher, session: Uninfo): + user_id = session.user.id + group_id = None + channel_id = None + if session.group: + if session.group.parent: + group_id = session.group.parent.id + channel_id = session.group.id + else: + group_id = session.group.id if user_id and matcher.plugin: module = matcher.plugin.name LimitManage.unblock(module, user_id, group_id, channel_id) diff --git a/zhenxun/builtin_plugins/init/__init__.py b/zhenxun/builtin_plugins/init/__init__.py index 870440c8..37df2a02 100644 --- a/zhenxun/builtin_plugins/init/__init__.py +++ b/zhenxun/builtin_plugins/init/__init__.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Any import nonebot from nonebot.adapters import Bot @@ -6,9 +7,11 @@ from nonebot.adapters import Bot from zhenxun.models.ban_console import BanConsole from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole +from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.user_console import UserConsole from zhenxun.services.log import logger -from zhenxun.utils.cache_utils import Cache +from zhenxun.utils.cache_utils import CacheRoot from zhenxun.utils.enum import CacheType from zhenxun.utils.platform import PlatformUtils @@ -48,45 +51,131 @@ async def _(bot: Bot): ) -@Cache.listener(CacheType.PLUGINS) -async def _(): - data_list = await PluginInfo.get_plugins() - return {p.module: p for p in data_list} +@CacheRoot.new(CacheType.PLUGINS) +async def _(data: dict[str, PluginInfo] = {}, key: str | None = None): + if data and key: + if plugin := await PluginInfo.get_plugin(module=key): + data[key] = plugin + else: + data_list = await PluginInfo.get_plugins() + return {p.module: p for p in data_list} -@Cache.getter(CacheType.PLUGINS, result_model=PluginInfo) -def _(data: dict[str, PluginInfo], module: str): - return data.get(module, None) +@CacheRoot.updater(CacheType.PLUGINS) +async def _(data: dict[str, PluginInfo], key: str, value: Any): + if value: + data[key] = value + elif plugin := await PluginInfo.get_plugin(module=key): + data[key] = plugin -@Cache.listener(CacheType.GROUPS) +@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo) +async def _(data: dict[str, PluginInfo], module: str): + result = data.get(module, None) + if not result: + result = await PluginInfo.get_plugin(module=module) + if result: + data[module] = result + return result + + +@CacheRoot.new(CacheType.GROUPS) async def _(): data_list = await GroupConsole.all() - return {p.group_id: p for p in data_list} + return {p.group_id: p for p in data_list if not p.channel_id} -@Cache.getter(CacheType.GROUPS, result_model=GroupConsole) -def _(data: dict[str, GroupConsole], module: str): - return data.get(module, None) +@CacheRoot.updater(CacheType.GROUPS) +async def _(data: dict[str, GroupConsole], key: str, value: Any): + if value: + data[key] = value + elif group := await GroupConsole.get_group(group_id=key): + data[key] = group -@Cache.listener(CacheType.BOT) +@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) +async def _(data: dict[str, GroupConsole], group_id: str): + result = data.get(group_id, None) + if not result: + result = await GroupConsole.get_group(group_id=group_id) + if result: + data[group_id] = result + return result + + +@CacheRoot.new(CacheType.BOT) async def _(): data_list = await BotConsole.all() return {p.bot_id: p for p in data_list} -@Cache.getter(CacheType.BOT, result_model=BotConsole) -def _(data: dict[str, BotConsole], module: str): - return data.get(module, None) +@CacheRoot.updater(CacheType.BOT) +async def _(data: dict[str, BotConsole], key: str, value: Any): + if value: + data[key] = value + elif bot := await BotConsole.get_or_none(bot_id=key): + data[key] = bot -@Cache.listener(CacheType.BAN) +@CacheRoot.getter(CacheType.BOT, result_model=BotConsole) +async def _(data: dict[str, BotConsole], bot_id: str): + result = data.get(bot_id, None) + if not result: + result = await BotConsole.get_or_none(bot_id=bot_id) + if result: + data[bot_id] = result + return result + + +@CacheRoot.new(CacheType.USERS) +async def _(): + data_list = await UserConsole.all() + return {p.user_id: p for p in data_list} + + +@CacheRoot.updater(CacheType.USERS) +async def _(data: dict[str, UserConsole], key: str, value: Any): + if value: + data[key] = value + elif user := await UserConsole.get_user(user_id=key): + data[key] = user + + +@CacheRoot.getter(CacheType.USERS, result_model=UserConsole) +async def _(data: dict[str, UserConsole], user_id: str): + result = data.get(user_id, None) + if not result: + result = await UserConsole.get_user(user_id=user_id) + if result: + data[user_id] = result + return result + + +@CacheRoot.new(CacheType.LEVEL) +async def _(): + return await LevelUser().all() + + +@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser]) +def _(data_list: list[LevelUser], user_id: str, group_id: str | None = None): + if not group_id: + return [ + data for data in data_list if data.user_id == user_id and not data.group_id + ] + else: + return [ + data + for data in data_list + if data.user_id == user_id and data.group_id == group_id + ] + + +@CacheRoot.new(CacheType.BAN) async def _(): return await BanConsole.all() -@Cache.getter(CacheType.BAN, result_model=list[BanConsole]) +@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) def _(data_list: list[BanConsole], user_id: str, group_id: str): if user_id: if group_id: diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 8520b93d..e6689d1c 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -7,7 +7,8 @@ from tortoise.backends.base.client import BaseDBAsyncClient from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo from zhenxun.services.db_context import Model -from zhenxun.utils.enum import PluginType +from zhenxun.utils.cache_utils import CacheRoot +from zhenxun.utils.enum import CacheType, PluginType class GroupConsole(Model): @@ -46,8 +47,7 @@ class GroupConsole(Model): platform = fields.CharField(255, default="qq", description="所属平台") """所属平台""" - class Meta: # type: ignore - table = "group_console" + class Meta: # type: ignore table = "group_console" table_description = "群组信息表" unique_together = ("group_id", "channel_id") @@ -80,6 +80,7 @@ class GroupConsole(Model): return "".join(cls.format(item) for item in data) @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def create( cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any ) -> Self: @@ -100,6 +101,7 @@ class GroupConsole(Model): return group @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def get_or_create( cls, defaults: dict | None = None, @@ -127,6 +129,7 @@ class GroupConsole(Model): return group, is_create @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def update_or_create( cls, defaults: dict | None = None, @@ -216,6 +219,7 @@ class GroupConsole(Model): ) @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def set_block_plugin( cls, group_id: str, @@ -242,6 +246,7 @@ class GroupConsole(Model): await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def set_unblock_plugin( cls, group_id: str, @@ -338,6 +343,7 @@ class GroupConsole(Model): ) @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def set_block_task( cls, group_id: str, @@ -364,6 +370,7 @@ class GroupConsole(Model): await group.save(update_fields=["block_task", "superuser_block_task"]) @classmethod + @CacheRoot.listener(CacheType.GROUPS) async def set_unblock_task( cls, group_id: str, diff --git a/zhenxun/models/plugin_info.py b/zhenxun/models/plugin_info.py index aea208bd..c8c7f0af 100644 --- a/zhenxun/models/plugin_info.py +++ b/zhenxun/models/plugin_info.py @@ -1,10 +1,13 @@ +from typing import Any from typing_extensions import Self from tortoise import fields +from tortoise.backends.base.client import BaseDBAsyncClient from zhenxun.models.plugin_limit import PluginLimit # noqa: F401 from zhenxun.services.db_context import Model -from zhenxun.utils.enum import BlockType, PluginType +from zhenxun.utils.cache_utils import CacheRoot +from zhenxun.utils.enum import BlockType, CacheType, PluginType class PluginInfo(Model): @@ -79,6 +82,13 @@ class PluginInfo(Model): """ return await cls.filter(load_status=load_status, **kwargs).all() + @classmethod + @CacheRoot.listener(CacheType.PLUGINS) + async def create( + cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any + ) -> Self: + return await super().create(using_db=using_db, **kwargs) + @classmethod async def _run_script(cls): return [ diff --git a/zhenxun/utils/cache_utils.py b/zhenxun/utils/cache_utils.py index 62c34c3f..dccbcff4 100644 --- a/zhenxun/utils/cache_utils.py +++ b/zhenxun/utils/cache_utils.py @@ -1,11 +1,14 @@ from collections.abc import Callable +from functools import wraps import time from typing import Any, ClassVar, Generic, TypeVar, cast from nonebot.utils import is_coroutine_callable from pydantic import BaseModel -__all__ = ["Cache", "CacheData"] +from zhenxun.services.log import logger + +__all__ = ["Cache", "CacheData", "CacheRoot"] T = TypeVar("T") @@ -27,10 +30,14 @@ class CacheGetter(BaseModel, Generic[T]): class CacheData(BaseModel): + name: str + """缓存名称""" func: Callable[..., Any] """更新方法""" getter: CacheGetter | None = None """获取方法""" + updater: Callable[..., Any] | None = None + """更新单个方法""" data: Any = None """缓存数据""" expire: int @@ -40,13 +47,36 @@ class CacheData(BaseModel): reload_count: int = 0 """更新次数""" - async def reload(self): + async def get(self, *args, **kwargs) -> Any: + """获取单个缓存""" + if not self.getter: + return self.data + return await self.getter.get(self.data, *args, **kwargs) + + async def update(self, key: str, value: Any = None, *args, **kwargs): + """更新单个缓存""" + if not self.updater: + return logger.warning( + f"缓存类型 {self.name} 没有更新方法,无法更新", "CacheRoot" + ) + if self.data: + if is_coroutine_callable(self.updater): + await self.updater(self.data, key, value, *args, **kwargs) + else: + self.updater(self.data, key, value, *args, **kwargs) + else: + logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot") + + async def reload(self, *args, **kwargs): """更新缓存""" self.data = ( - await self.func() if is_coroutine_callable(self.func) else self.func() + await self.func(*args, **kwargs) + if is_coroutine_callable(self.func) + else self.func(*args, **kwargs) ) self.reload_time = time.time() self.reload_count += 1 + logger.debug(f"缓存类型 {self.name} 更新全局缓存", "CacheRoot") async def check_expire(self): if time.time() - self.reload_time > self.expire or not self.reload_count: @@ -54,14 +84,54 @@ class CacheData(BaseModel): class CacheManage: + """全局缓存管理,减少数据库与网络请求查询次数 + + + 异常: + ValueError: 数据名称重复 + ValueError: 数据不存在 + + """ + _data: ClassVar[dict[str, CacheData]] = {} - def listener(self, name: str, expire: int = 60 * 10): + def new(self, name: str, expire: int = 60 * 10): def wrapper(func: Callable): _name = name.upper() if _name in self._data: raise ValueError(f"DbCache 缓存数据 {name} 已存在...") - self._data[_name] = CacheData(func=func, expire=expire) + self._data[_name] = CacheData(name=_name, func=func, expire=expire) + + return wrapper + + def listener(self, name: str): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + try: + if is_coroutine_callable(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + return result + finally: + cache_data = self._data.get(name.upper()) + if cache_data: + await cache_data.reload() + logger.debug( + f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot" + ) + + return wrapper + + return decorator + + def updater(self, name: str): + def wrapper(func: Callable): + _name = name.upper() + if _name not in self._data: + raise ValueError(f"DbCache 缓存数据 {name} 不存在...") + self._data[_name].updater = func return wrapper @@ -83,12 +153,12 @@ class CacheManage: ): await self._data[name].reload() - async def get_cache_data(self, name: str) -> CacheData | None: + async def get_cache_data(self, name: str): if cache := await self.get_cache(name): - return cache + return cache.data return None - async def get_cache(self, name: str): + async def get_cache(self, name: str) -> CacheData | None: name = name.upper() cache = self._data.get(name) if cache: @@ -96,18 +166,35 @@ class CacheManage: return cache return None - async def get(self, name: str, *args, **kwargs) -> T | None: - cache = self._data.get(name.upper()) + async def get(self, name: str, *args, **kwargs): + cache = await self.get_cache(name.upper()) if cache: - return ( - await cache.getter.get(*args, **kwargs) if cache.getter else cache.data - ) + return await cache.get(*args, **kwargs) if cache.getter else cache.data return None - async def reload(self, name: str): - cache = self._data.get(name.upper()) + async def reload(self, name: str, *args, **kwargs): + cache = await self.get_cache(name.upper()) if cache: - await cache.reload() + await cache.reload(*args, **kwargs) + + async def update(self, name: str, key: str, value: Any, *args, **kwargs): + cache = await self.get_cache(name.upper()) + if cache: + await cache.update(key, value, *args, **kwargs) -Cache = CacheManage() +CacheRoot = CacheManage() + + +class Cache(Generic[T]): + def __init__(self, module: str): + self.module = module + + async def get(self, *args, **kwargs) -> T | None: + return await CacheRoot.get(self.module, *args, **kwargs) + + async def update(self, key: str, value: Any = None, *args, **kwargs): + return await CacheRoot.update(self.module, key, value, *args, **kwargs) + + async def reload(self, key: str | None = None, *args, **kwargs): + await CacheRoot.reload(self.module, key, *args, **kwargs) diff --git a/zhenxun/utils/enum.py b/zhenxun/utils/enum.py index 5cc39e82..001dae6f 100644 --- a/zhenxun/utils/enum.py +++ b/zhenxun/utils/enum.py @@ -10,10 +10,14 @@ class CacheType(StrEnum): """全局全部插件""" GROUPS = "GLOBAL_ALL_GROUPS" """全局全部群组""" + USERS = "GLOBAL_ALL_USERS" + """全部用户""" BAN = "GLOBAL_ALL_BAN" """全局ban列表""" BOT = "GLOBAL_BOT" """全局bot信息""" + LEVEL = "GLOBAL_USER_LEVEL" + """用户权限""" class GoldHandle(StrEnum):