From 41dd7677243aa57a8ecdb20ce970abe84aa5e48e Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Tue, 8 Apr 2025 17:11:44 +0800 Subject: [PATCH] =?UTF-8?q?:zap:=20=E4=BC=98=E5=8C=96hook=E6=9D=83?= =?UTF-8?q?=E9=99=90=E6=A3=80=E6=B5=8B=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../admin/plugin_switch/_data_source.py | 18 +- .../builtin_plugins/hooks/auth/auth_admin.py | 76 ++--- .../builtin_plugins/hooks/auth/auth_ban.py | 167 +++++++++++ .../builtin_plugins/hooks/auth/auth_bot.py | 18 +- .../builtin_plugins/hooks/auth/auth_cost.py | 21 +- .../builtin_plugins/hooks/auth/auth_group.py | 52 +--- .../builtin_plugins/hooks/auth/auth_limit.py | 15 +- .../builtin_plugins/hooks/auth/auth_plugin.py | 107 ++----- zhenxun/builtin_plugins/hooks/auth/config.py | 13 + .../builtin_plugins/hooks/auth/exception.py | 12 + zhenxun/builtin_plugins/hooks/auth/utils.py | 50 ++++ zhenxun/builtin_plugins/hooks/auth_checker.py | 204 ++++++++----- zhenxun/builtin_plugins/hooks/auth_hook.py | 7 + zhenxun/builtin_plugins/hooks/ban_hook.py | 104 ------- zhenxun/builtin_plugins/init/__init_cache.py | 34 +-- zhenxun/models/group_console.py | 282 ++++++++++++------ zhenxun/services/cache.py | 29 +- zhenxun/services/db_context.py | 6 +- zhenxun/utils/utils.py | 33 ++ 19 files changed, 718 insertions(+), 530 deletions(-) create mode 100644 zhenxun/builtin_plugins/hooks/auth/auth_ban.py create mode 100644 zhenxun/builtin_plugins/hooks/auth/config.py delete mode 100644 zhenxun/builtin_plugins/hooks/ban_hook.py diff --git a/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py b/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py index 7af8dc29..8767a78f 100644 --- a/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py +++ b/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py @@ -196,7 +196,7 @@ class PluginManage: await PluginInfo.filter(plugin_type=PluginType.NORMAL).update( default_status=status ) - return f'成功将所有功能进群默认状态修改为: {"开启" if status else "关闭"}' + return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}" if group_id: if group := await GroupConsole.get_or_none( group_id=group_id, channel_id__isnull=True @@ -213,12 +213,12 @@ class PluginManage: module_list = [f"<{module}" for module in module_list] group.block_plugin = ",".join(module_list) + "," # type: ignore await group.save(update_fields=["block_plugin"]) - return f'成功将此群组所有功能状态修改为: {"开启" if status else "关闭"}' + return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}" return "获取群组失败..." await PluginInfo.filter(plugin_type=PluginType.NORMAL).update( status=status, block_type=None if status else BlockType.ALL ) - return f'成功将所有功能全局状态修改为: {"开启" if status else "关闭"}' + return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}" @classmethod async def is_wake(cls, group_id: str) -> bool: @@ -243,9 +243,11 @@ class PluginManage: 参数: group_id: 群组id """ - await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update( - status=False + group, _ = await GroupConsole.get_or_create( + group_id=group_id, channel_id__isnull=True ) + group.status = False + await group.save(update_fields=["status"]) @classmethod async def wake(cls, group_id: str): @@ -254,9 +256,11 @@ class PluginManage: 参数: group_id: 群组id """ - await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update( - status=True + group, _ = await GroupConsole.get_or_create( + group_id=group_id, channel_id__isnull=True ) + group.status = True + await group.save(update_fields=["status"]) @classmethod async def block(cls, module: str): diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py index 3bdbe1ef..3d22f6b0 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py @@ -1,15 +1,14 @@ -from nonebot.exception import IgnoredException from nonebot_plugin_alconna import At from nonebot_plugin_uninfo import Uninfo from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.cache import Cache -from zhenxun.services.log import logger from zhenxun.utils.enum import CacheType -from zhenxun.utils.message import MessageUtils +from zhenxun.utils.utils import get_entity_ids -from .utils import freq +from .exception import SkipPluginException +from .utils import send_message async def auth_admin(plugin: PluginInfo, session: Uninfo): @@ -17,60 +16,33 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo): 参数: plugin: PluginInfo - session: PluginInfo + session: Uninfo """ - group_id = None - cache = Cache[list[LevelUser]](CacheType.LEVEL) - user_level = await cache.get(session.user.id) or [] - if session.group: - if session.group.parent: - group_id = session.group.parent.id - else: - group_id = session.group.id - if not plugin.admin_level: return - if group_id: - user_level += await cache.get(session.user.id, group_id) or [] + entity = get_entity_ids(session) + cache = Cache[list[LevelUser]](CacheType.LEVEL) + user_level = await cache.get(session.user.id) or [] + if entity.group_id: + user_level += await cache.get(session.user.id, entity.group_id) or [] user = max(user_level, key=lambda x: x.user_level) if user.user_level < plugin.admin_level: - try: - if freq._flmt.check(session.user.id): - freq._flmt.start_cd(session.user.id) - await MessageUtils.build_message( - [ - At(flag="user", target=session.user.id), - "你的权限不足喔," - f"该功能需要的权限等级: {plugin.admin_level}", - ] - ).send(reply_to=True) - except Exception as e: - logger.error( - "auth_admin 发送消息失败", - "AuthChecker", - session=session, - e=e, - ) - logger.debug( - f"{plugin.name}({plugin.module}) 管理员权限不足...", - "AuthChecker", - session=session, + await send_message( + session, + [ + At(flag="user", target=session.user.id), + f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}", + ], + entity.user_id, + ) + raise SkipPluginException( + f"{plugin.name}({plugin.module}) 管理员权限不足..." ) - raise IgnoredException("管理员权限不足...") elif user_level: user = max(user_level, key=lambda x: x.user_level) if user.user_level < plugin.admin_level: - try: - await MessageUtils.build_message( - f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}" - ).send() - except Exception as e: - logger.error( - "auth_admin 发送消息失败", "AuthChecker", session=session, e=e - ) - logger.debug( - f"{plugin.name}({plugin.module}) 管理员权限不足...", - "AuthChecker", - session=session, - ) - raise IgnoredException("权限不足") + await send_message( + session, + f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}", + ) + raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...") diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py new file mode 100644 index 00000000..727ab70b --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py @@ -0,0 +1,167 @@ +from nonebot.adapters import Bot +from nonebot.matcher import Matcher +from nonebot_plugin_alconna import At +from nonebot_plugin_uninfo import Uninfo +from tortoise.exceptions import MultipleObjectsReturned + +from zhenxun.configs.config import Config +from zhenxun.models.ban_console import BanConsole +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.services.cache import Cache +from zhenxun.services.log import logger +from zhenxun.utils.enum import CacheType, PluginType +from zhenxun.utils.utils import EntityIDs, get_entity_ids + +from .config import LOGGER_COMMAND +from .exception import SkipPluginException +from .utils import send_message + +Config.add_plugin_config( + "hook", + "BAN_RESULT", + "才不会给你发消息.", + help="对被ban用户发送的消息", +) + + +async def is_ban(user_id: str | None, group_id: str | None) -> int: + cache = Cache[list[BanConsole]](CacheType.BAN) + results = await cache.get(user_id, group_id) or await cache.get(user_id) + if not results: + return 0 + for result in results: + if result.group_id == group_id and ( + result.duration > 0 or result.duration == -1 + ): + return await BanConsole.check_ban_time(user_id, group_id) + if not result.group_id and result.duration == -1: + return await BanConsole.check_ban_time(user_id, group_id) + return 0 + + +def check_plugin_type(matcher: Matcher) -> bool: + """判断插件类型是否是隐藏插件 + + 参数: + matcher: Matcher + + 返回: + bool: 是否为隐藏插件 + """ + if plugin := matcher.plugin: + if metadata := plugin.metadata: + extra = metadata.extra + if extra.get("plugin_type") in [PluginType.HIDDEN]: + return False + return True + + +def format_time(time: float) -> str: + """格式化时间 + + 参数: + time: ban时长 + + 返回: + str: 格式化时间文本 + """ + if time == -1: + return "∞" + time = abs(int(time)) + if time < 60: + time_str = f"{time!s} 秒" + else: + minute = int(time / 60) + if minute > 60: + hours = minute // 60 + minute %= 60 + time_str = f"{hours} 小时 {minute}分钟" + else: + time_str = f"{minute} 分钟" + return time_str + + +async def group_handle(cache: Cache[list[BanConsole]], group_id: str): + """群组ban检查 + + 参数: + cache: cache + group_id: 群组id + + 异常: + SkipPluginException: 群组处于黑名单 + """ + try: + if await is_ban(None, group_id): + raise SkipPluginException("群组处于黑名单中...") + except MultipleObjectsReturned: + logger.warning( + "群组黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND + ) + ids = await BanConsole.filter(user_id="", group_id=group_id).values_list( + "id", flat=True + ) + await BanConsole.filter(id__in=ids[:-1]).delete() + await cache.reload() + + +async def user_handle( + module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo +): + """用户ban检查 + + 参数: + module: 插件模块名 + cache: cache + user_id: 用户id + session: Uninfo + + 异常: + SkipPluginException: 用户处于黑名单 + """ + ban_result = Config.get_config("hook", "BAN_RESULT") + try: + time = await is_ban(entity.user_id, entity.group_id) + if not time: + return + time_str = format_time(time) + db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) + if ( + db_plugin + # and not db_plugin.ignore_prompt + and time != -1 + and ban_result + ): + await send_message( + session, + [ + At(flag="user", target=entity.user_id), + f"{ban_result}\n在..在 {time_str} 后才会理你喔", + ], + entity.user_id, + ) + raise SkipPluginException("用户处于黑名单中...") + except MultipleObjectsReturned: + logger.warning( + "用户黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND + ) + ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list( + "id", flat=True + ) + await BanConsole.filter(id__in=ids[:-1]).delete() + await cache.reload() + + +async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo): + if not check_plugin_type(matcher): + return + if not matcher.plugin_name: + return + entity = get_entity_ids(session) + if entity.user_id in bot.config.superusers: + return + cache = Cache[list[BanConsole]](CacheType.BAN) + if entity.group_id: + await group_handle(cache, entity.group_id) + if entity.user_id: + await user_handle(matcher.plugin_name, cache, entity, session) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py index b10b0079..2427223f 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py @@ -1,12 +1,11 @@ -from nonebot.exception import IgnoredException - from zhenxun.models.bot_console import BotConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.cache import Cache -from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.enum import CacheType +from .exception import SkipPluginException + async def auth_bot(plugin: PluginInfo, bot_id: str): """bot层面的权限检查 @@ -16,17 +15,14 @@ async def auth_bot(plugin: PluginInfo, bot_id: str): bot_id: bot id 异常: - IgnoredException: 忽略插件 - IgnoredException: 忽略插件 + SkipPluginException: 忽略插件 + SkipPluginException: 忽略插件 """ if cache := Cache[BotConsole](CacheType.BOT): bot = await cache.get(bot_id) if not bot or not bot.status: - logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker") - raise IgnoredException("BotConsole休眠权限检测 ignore") + raise SkipPluginException("Bot不存在或休眠中阻断权限检测...") if CommonUtils.format(plugin.module) in bot.block_plugins: - logger.debug( - f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...", - "AuthChecker", + raise SkipPluginException( + f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..." ) - raise IgnoredException("BotConsole插件权限检测 ignore") diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_cost.py b/zhenxun/builtin_plugins/hooks/auth/auth_cost.py index 9a5e9a48..7a971085 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_cost.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_cost.py @@ -1,10 +1,10 @@ -from nonebot.exception import IgnoredException from nonebot_plugin_uninfo import Uninfo from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.user_console import UserConsole -from zhenxun.services.log import logger -from zhenxun.utils.message import MessageUtils + +from .exception import SkipPluginException +from .utils import send_message async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int: @@ -19,17 +19,6 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i """ if user.gold < plugin.cost_gold: """插件消耗金币不足""" - try: - await MessageUtils.build_message( - f"金币不足..该功能需要{plugin.cost_gold}金币.." - ).send() - except Exception as e: - logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e) - logger.debug( - f"{plugin.name}({plugin.module}) 金币限制.." - f"该功能需要{plugin.cost_gold}金币..", - "AuthChecker", - session=session, - ) - raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...") + await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..") + raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...") return plugin.cost_gold diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_group.py b/zhenxun/builtin_plugins/hooks/auth/auth_group.py index 313c2bd5..290a3ad9 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_group.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_group.py @@ -1,57 +1,35 @@ -from nonebot.exception import IgnoredException from nonebot_plugin_alconna import UniMsg -from nonebot_plugin_uninfo import Uninfo from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.cache import Cache -from zhenxun.services.log import logger from zhenxun.utils.enum import CacheType +from zhenxun.utils.utils import EntityIDs + +from .config import SwitchEnum +from .exception import SkipPluginException -async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg): +async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg): """群黑名单检测 群总开关检测 参数: plugin: PluginInfo - session: EventSession + entity: EntityIDs message: UniMsg """ - if not session.group: + if not entity.group_id: return - if session.group.parent: - group_id = session.group.parent.id - else: - group_id = session.group.id text = message.extract_plain_text() - group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id) + group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id) if not group: - """群不存在""" - logger.debug( - "群组信息不存在...", - "AuthChecker", - session=session, - ) - raise IgnoredException("群不存在") + raise SkipPluginException("群组信息不存在...") if group.level < 0: - """群权限小于0""" - logger.debug( - "群黑名单, 群权限-1...", - "AuthChecker", - session=session, - ) - raise IgnoredException("群黑名单") - if not group.status: - """群休眠""" - if text.strip() != "醒来": - logger.debug("群休眠状态...", "AuthChecker", session=session) - raise IgnoredException("群休眠状态") + raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...") + if text.strip() != SwitchEnum.ENABLE and not group.status: + raise SkipPluginException("群组休眠状态...") if plugin.level > group.level: - """插件等级大于群等级""" - logger.debug( - f"{plugin.name}({plugin.module}) 群等级限制.." - f"该功能需要的群等级: {plugin.level}..", - "AuthChecker", - session=session, + raise SkipPluginException( + f"{plugin.name}({plugin.module}) 群等级限制," + f"该功能需要的群等级: {plugin.level}..." ) - raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...") diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py index 0da55b72..e56fa9ed 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py @@ -1,6 +1,5 @@ from typing import ClassVar -from nonebot.exception import IgnoredException from nonebot_plugin_uninfo import Uninfo from pydantic import BaseModel @@ -11,6 +10,9 @@ from zhenxun.utils.enum import LimitWatchType, PluginLimitType from zhenxun.utils.message import MessageUtils from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter +from .config import LOGGER_COMMAND +from .exception import SkipPluginException + class Limit(BaseModel): limit: PluginLimit @@ -69,7 +71,7 @@ class LimitManage: key_type = channel_id or group_id logger.debug( f"解除对象: {key_type} 的block限制", - "AuthChecker", + LOGGER_COMMAND, session=user_id, group_id=group_id, ) @@ -139,16 +141,13 @@ class LimitManage: if is_limit and not limiter.check(key_type): if limit.result: await MessageUtils.build_message(limit.result).send() - logger.debug( - f"{limit.module}({limit.limit_type}) 正在限制中...", - "AuthChecker", - session=session, + raise SkipPluginException( + f"{limit.module}({limit.limit_type}) 正在限制中..." ) - raise IgnoredException(f"{limit.module} 正在限制中...") else: logger.debug( f"开始进行限制 {limit.module}({limit.limit_type})...", - "AuthChecker", + LOGGER_COMMAND, session=user_id, group_id=group_id, ) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py index 40a355da..ebfe7be1 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py @@ -1,17 +1,15 @@ from nonebot.adapters import Event -from nonebot.exception import IgnoredException from nonebot_plugin_uninfo import Uninfo from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.cache import Cache -from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.enum import BlockType, CacheType -from zhenxun.utils.message import MessageUtils +from zhenxun.utils.utils import get_entity_ids -from .exception import IsSuperuserException -from .utils import freq +from .exception import IsSuperuserException, SkipPluginException +from .utils import freq, is_poke, send_message class GroupCheck: @@ -42,16 +40,12 @@ class GroupCheck: group = await self.__get_data() if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin: if freq.is_send_limit_message(plugin, group.group_id, self.is_poke): - freq._flmt_s.start_cd(self.group_id) - await MessageUtils.build_message("超级管理员禁用了该群此功能...").send( - reply_to=True + await send_message( + self.session, "超级管理员禁用了该群此功能...", self.group_id ) - logger.debug( - f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...", - "AuthChecker", - session=self.session, + raise SkipPluginException( + f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..." ) - raise IgnoredException("超级管理员禁用了该群此功能...") await self.check_normal_block(self.plugin) async def check_normal_block(self, plugin: PluginInfo): @@ -66,16 +60,8 @@ class GroupCheck: group = await self.__get_data() if group and CommonUtils.format(plugin.module) in group.block_plugin: if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): - freq._flmt_s.start_cd(self.group_id) - await MessageUtils.build_message("该群未开启此功能...").send( - reply_to=True - ) - logger.debug( - f"{plugin.name}({plugin.module}) 未开启此功能...", - "AuthChecker", - session=self.session, - ) - raise IgnoredException("该群未开启此功能...") + await send_message(self.session, "该群未开启此功能...", self.group_id) + raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...") await self.check_global_block(self.plugin) async def check_global_block(self, plugin: PluginInfo): @@ -89,25 +75,13 @@ class GroupCheck: """ if plugin.block_type == BlockType.GROUP: """全局群组禁用""" - try: - if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): - freq._flmt_c.start_cd(self.group_id) - await MessageUtils.build_message("该功能在群组中已被禁用...").send( - reply_to=True - ) - except Exception as e: - logger.error( - "auth_plugin 发送消息失败", - "AuthChecker", - session=self.session, - e=e, + if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): + await send_message( + self.session, "该功能在群组中已被禁用...", self.group_id ) - logger.debug( - f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...", - "AuthChecker", - session=self.session, + raise SkipPluginException( + f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..." ) - raise IgnoredException("该插件在群组中已被禁用...") class PluginCheck: @@ -126,25 +100,11 @@ class PluginCheck: IgnoredException: 忽略插件 """ if plugin.block_type == BlockType.PRIVATE: - try: - if freq.is_send_limit_message( - plugin, self.session.user.id, self.is_poke - ): - freq._flmt_c.start_cd(self.session.user.id) - await MessageUtils.build_message("该功能在私聊中已被禁用...").send() - except Exception as e: - logger.error( - "auth_admin 发送消息失败", - "AuthChecker", - session=self.session, - e=e, - ) - logger.debug( - f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...", - "AuthChecker", - session=self.session, + if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke): + await send_message(self.session, "该功能在私聊中已被禁用...") + raise SkipPluginException( + f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..." ) - raise IgnoredException("该插件在私聊中已被禁用...") async def check_global(self, plugin: PluginInfo): """全局状态 @@ -162,16 +122,10 @@ class PluginCheck: if self.group_id and (group := await cache.get(self.group_id)): if group.is_super: raise IsSuperuserException() - logger.debug( - f"{plugin.name}({plugin.module}) 全局未开启此功能...", - "AuthChecker", - session=self.session, - ) sid = self.group_id or self.session.user.id if freq.is_send_limit_message(plugin, sid, self.is_poke): - freq._flmt_s.start_cd(sid) - await MessageUtils.build_message("全局未开启此功能...").send() - raise IgnoredException("全局未开启此功能...") + await send_message(self.session, "全局未开启此功能...", sid) + raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...") async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event): @@ -180,22 +134,13 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event): 参数: plugin: PluginInfo session: Uninfo + event: Event """ - group_id = None - if session.group: - if session.group.parent: - group_id = session.group.parent.id - else: - group_id = session.group.id - try: - from nonebot.adapters.onebot.v11 import PokeNotifyEvent - - is_poke = isinstance(event, PokeNotifyEvent) - except ImportError: - is_poke = False - user_check = PluginCheck(group_id, session, is_poke) - if group_id: - group_check = GroupCheck(plugin, group_id, session, is_poke) + entity = get_entity_ids(session) + is_poke_event = is_poke(event) + user_check = PluginCheck(entity.group_id, session, is_poke_event) + if entity.group_id: + group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event) await group_check.check() else: await user_check.check_user(plugin) diff --git a/zhenxun/builtin_plugins/hooks/auth/config.py b/zhenxun/builtin_plugins/hooks/auth/config.py new file mode 100644 index 00000000..d68b7d00 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/config.py @@ -0,0 +1,13 @@ +import sys + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + +LOGGER_COMMAND = "AuthChecker" + + +class SwitchEnum(StrEnum): + ENABLE = "醒来" + DISABLE = "休息吧" diff --git a/zhenxun/builtin_plugins/hooks/auth/exception.py b/zhenxun/builtin_plugins/hooks/auth/exception.py index cc0d3fde..195b29d2 100644 --- a/zhenxun/builtin_plugins/hooks/auth/exception.py +++ b/zhenxun/builtin_plugins/hooks/auth/exception.py @@ -1,2 +1,14 @@ class IsSuperuserException(Exception): pass + + +class SkipPluginException(Exception): + def __init__(self, info: str, *args: object) -> None: + super().__init__(*args) + self.info = info + + def __str__(self) -> str: + return self.info + + def __repr__(self) -> str: + return self.info diff --git a/zhenxun/builtin_plugins/hooks/auth/utils.py b/zhenxun/builtin_plugins/hooks/auth/utils.py index 5ce3c7fa..0f925590 100644 --- a/zhenxun/builtin_plugins/hooks/auth/utils.py +++ b/zhenxun/builtin_plugins/hooks/auth/utils.py @@ -1,11 +1,61 @@ +import contextlib + +from nonebot.adapters import Event +from nonebot_plugin_uninfo import Uninfo + from zhenxun.configs.config import Config from zhenxun.models.plugin_info import PluginInfo +from zhenxun.services.log import logger from zhenxun.utils.enum import PluginType +from zhenxun.utils.message import MessageUtils from zhenxun.utils.utils import FreqLimiter +from .config import LOGGER_COMMAND + base_config = Config.get("hook") +def is_poke(event: Event) -> bool: + """判断是否为poke类型 + + 参数: + event: Event + + 返回: + bool: 是否为poke类型 + """ + with contextlib.suppress(ImportError): + from nonebot.adapters.onebot.v11 import PokeNotifyEvent + + return isinstance(event, PokeNotifyEvent) + return False + + +async def send_message( + session: Uninfo, message: list | str, check_tag: str | None = None +): + """发送消息 + + 参数: + session: Uninfo + message: 消息 + check_tag: cd flag + """ + try: + if not check_tag: + await MessageUtils.build_message(message).send(reply_to=True) + elif freq._flmt.check(check_tag): + freq._flmt.start_cd(check_tag) + await MessageUtils.build_message(message).send(reply_to=True) + except Exception as e: + logger.error( + "发送消息失败", + LOGGER_COMMAND, + session=session, + e=e, + ) + + class FreqUtils: def __init__(self): check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD") diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py index 061f1cdf..b65e8bfa 100644 --- a/zhenxun/builtin_plugins/hooks/auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -1,4 +1,4 @@ -import contextlib +import asyncio from nonebot.adapters import Bot, Event from nonebot.exception import IgnoredException @@ -18,14 +18,107 @@ from zhenxun.utils.enum import ( ) from zhenxun.utils.exception import InsufficientGold from zhenxun.utils.platform import PlatformUtils +from zhenxun.utils.utils import get_entity_ids from .auth.auth_admin import auth_admin +from .auth.auth_ban import auth_ban from .auth.auth_bot import auth_bot from .auth.auth_cost import auth_cost from .auth.auth_group import auth_group from .auth.auth_limit import LimitManage, auth_limit from .auth.auth_plugin import auth_plugin -from .auth.exception import IsSuperuserException +from .auth.config import LOGGER_COMMAND +from .auth.exception import IsSuperuserException, SkipPluginException + + +async def get_plugin_and_user( + module: str, user_id: str +) -> tuple[PluginInfo, UserConsole]: + """获取用户数据和插件信息 + + 参数: + module: 模块名 + user_id: 用户id + + 异常: + SkipPluginException: 插件数据不存在 + SkipPluginException: 插件类型为HIDDEN + SkipPluginException: 重复创建用户 + SkipPluginException: 用户数据不存在 + + 返回: + tuple[PluginInfo, UserConsole]: 插件信息,用户信息 + """ + user_cache = Cache[UserConsole](CacheType.USERS) + plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) + if not plugin: + raise SkipPluginException(f"插件:{module} 数据不存在,已跳过权限检查...") + if plugin.plugin_type == PluginType.HIDDEN: + raise SkipPluginException( + f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..." + ) + user = None + try: + user = await user_cache.get(user_id) + except IntegrityError as e: + raise SkipPluginException("重复创建用户,已跳过该次权限检查...") from e + if not user: + raise SkipPluginException("用户数据不存在,已跳过权限检查...") + return plugin, user + + +async def get_plugin_cost( + bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo +) -> int: + """获取插件费用 + + 参数: + bot: Bot + user: 用户数据 + plugin: 插件数据 + session: Uninfo + + 异常: + IsSuperuserException: 超级用户 + IsSuperuserException: 超级用户 + + 返回: + int: 调用插件金币费用 + """ + cost_gold = await auth_cost(user, plugin, session) + if session.user.id in bot.config.superusers: + if plugin.plugin_type == PluginType.SUPERUSER: + raise IsSuperuserException() + if not plugin.limit_superuser: + raise IsSuperuserException() + return cost_gold + + +async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo): + """扣除用户金币 + + 参数: + user_id: 用户id + module: 插件模块名称 + cost_gold: 消耗金币 + session: Uninfo + """ + user_cache = Cache[UserConsole](CacheType.USERS) + try: + await UserConsole.reduce_gold( + user_id, + cost_gold, + GoldHandle.PLUGIN, + module, + PlatformUtils.get_platform(session), + ) + except InsufficientGold: + if u := await UserConsole.get_user(user_id): + u.gold = 0 + await u.save(update_fields=["gold"]) + # 更新缓存 + await user_cache.update(user_id) + logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session) async def auth( @@ -44,85 +137,32 @@ async def auth( session: Uninfo message: UniMsg """ - user_id = session.user.id - group_id = None - channel_id = None - if session.group: - if session.group.parent: - group_id = session.group.parent.id - channel_id = session.group.id - else: - group_id = session.group.id - is_ignore = False cost_gold = 0 - with contextlib.suppress(ImportError): - from nonebot.adapters.onebot.v11 import PokeNotifyEvent - - if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent): - """过滤除poke外的notice""" - return - user_cache = Cache[UserConsole](CacheType.USERS) - if matcher.plugin and (module := matcher.plugin.name): - plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) - if not plugin: - return logger.debug(f"插件:{module} 数据不存在,已跳过权限检查...") - if plugin.plugin_type == PluginType.HIDDEN: - return logger.debug( - f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..." - ) - user = None - try: - user = await user_cache.get(session.user.id) - except IntegrityError as e: - logger.debug( - "重复创建用户,已跳过该次权限检查...", - "AuthChecker", - session=session, - e=e, - ) - if not user: - return logger.debug( - "用户数据不存在,已跳过权限检查...", "AuthChecker", session=session - ) - try: - cost_gold = await auth_cost(user, plugin, session) - if session.user.id in bot.config.superusers: - if plugin.plugin_type == PluginType.SUPERUSER: - raise IsSuperuserException() - if not plugin.limit_superuser: - cost_gold = 0 - raise IsSuperuserException() - await auth_bot(plugin, bot.self_id) - await auth_group(plugin, session, message) - await auth_admin(plugin, session) - await auth_plugin(plugin, session, event) - await auth_limit(plugin, session) - except IsSuperuserException: - logger.debug( - "超级用户或被ban跳过权限检测...", "AuthChecker", session=session - ) - except IgnoredException: - is_ignore = True - LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id) - except AssertionError as e: - is_ignore = True - logger.debug("消息无法发送", session=session, e=e) - if cost_gold and user_id: - """花费金币""" - try: - await UserConsole.reduce_gold( - user_id, - cost_gold, - GoldHandle.PLUGIN, - matcher.plugin.name if matcher.plugin else "", - PlatformUtils.get_platform(session), - ) - except InsufficientGold: - if u := await UserConsole.get_user(user_id): - u.gold = 0 - await u.save(update_fields=["gold"]) - # 更新缓存 - await user_cache.update(user_id) - logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session) - if is_ignore: + ignore_flag = False + entity = get_entity_ids(session) + module = matcher.plugin_name or "" + try: + if not module: + raise SkipPluginException("Matcher插件名称不存在...") + plugin, user = await get_plugin_and_user(module, entity.user_id) + cost_gold = await get_plugin_cost(bot, user, plugin, session) + await asyncio.gather( + *[ + auth_ban(matcher, bot, session), + auth_bot(plugin, bot.self_id), + auth_group(plugin, entity, message), + auth_admin(plugin, session), + auth_plugin(plugin, session, event), + ] + ) + await auth_limit(plugin, session) + except SkipPluginException as e: + LimitManage.unblock(module, entity.user_id, entity.group_id, entity.channel_id) + logger.info(str(e), LOGGER_COMMAND, session=session) + ignore_flag = True + except IsSuperuserException: + logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session) + if not ignore_flag and cost_gold > 0: + await reduce_gold(entity.user_id, module, cost_gold, session) + if ignore_flag: raise IgnoredException("权限检测 ignore") diff --git a/zhenxun/builtin_plugins/hooks/auth_hook.py b/zhenxun/builtin_plugins/hooks/auth_hook.py index b5553b9b..a53935fe 100644 --- a/zhenxun/builtin_plugins/hooks/auth_hook.py +++ b/zhenxun/builtin_plugins/hooks/auth_hook.py @@ -1,15 +1,21 @@ +import time + from nonebot.adapters import Bot, Event from nonebot.matcher import Matcher from nonebot.message import run_postprocessor, run_preprocessor from nonebot_plugin_alconna import UniMsg from nonebot_plugin_uninfo import Uninfo +from zhenxun.services.log import logger + +from .auth.config import LOGGER_COMMAND from .auth_checker import LimitManage, auth # # 权限检测 @run_preprocessor async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg): + start_time = time.time() await auth( matcher, event, @@ -17,6 +23,7 @@ async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: session, message, ) + logger.info(f"权限检测耗时:{time.time() - start_time}秒", LOGGER_COMMAND) # 解除命令block阻塞 diff --git a/zhenxun/builtin_plugins/hooks/ban_hook.py b/zhenxun/builtin_plugins/hooks/ban_hook.py deleted file mode 100644 index 3c38678a..00000000 --- a/zhenxun/builtin_plugins/hooks/ban_hook.py +++ /dev/null @@ -1,104 +0,0 @@ -from nonebot.adapters import Bot -from nonebot.exception import IgnoredException -from nonebot.matcher import Matcher -from nonebot.message import run_preprocessor -from nonebot_plugin_alconna import At -from nonebot_plugin_uninfo import Uninfo -from tortoise.exceptions import MultipleObjectsReturned - -from zhenxun.configs.config import Config -from zhenxun.models.ban_console import BanConsole -from zhenxun.models.group_console import GroupConsole -from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.cache import Cache -from zhenxun.services.log import logger -from zhenxun.utils.enum import CacheType, PluginType -from zhenxun.utils.message import MessageUtils -from zhenxun.utils.utils import FreqLimiter - -Config.add_plugin_config( - "hook", - "BAN_RESULT", - "才不会给你发消息.", - help="对被ban用户发送的消息", -) - -_flmt = FreqLimiter(300) - - -async def is_ban(user_id: str | None, group_id: str | None): - cache = Cache[list[BanConsole]](CacheType.BAN) - result = await cache.get(user_id, group_id) or await cache.get(user_id) - return result and result[0].ban_time > 0 - - -# 检查是否被ban -@run_preprocessor -async def _(matcher: Matcher, bot: Bot, session: Uninfo): - if plugin := matcher.plugin: - if metadata := plugin.metadata: - extra = metadata.extra - if extra.get("plugin_type") in [PluginType.HIDDEN]: - return - user_id = session.user.id - group_id = session.group.id if session.group else None - cache = Cache[list[BanConsole]](CacheType.BAN) - if user_id in bot.config.superusers: - return - if group_id: - try: - if await is_ban(None, group_id): - logger.debug("群组处于黑名单中...", "BanChecker") - raise IgnoredException("群组处于黑名单中...") - except MultipleObjectsReturned: - logger.warning( - "群组黑名单数据重复,过滤该次hook并移除多余数据...", "BanChecker" - ) - ids = await BanConsole.filter(user_id="", group_id=group_id).values_list( - "id", flat=True - ) - await BanConsole.filter(id__in=ids[:-1]).delete() - await cache.reload() - group_cache = Cache[GroupConsole](CacheType.GROUPS) - if g := await group_cache.get(group_id): - if g.level < 0: - logger.debug("群黑名单, 群权限-1...", "BanChecker") - raise IgnoredException("群黑名单, 群权限-1..") - if user_id: - ban_result = Config.get_config("hook", "BAN_RESULT") - if await is_ban(user_id, group_id): - time = await BanConsole.check_ban_time(user_id, group_id) - if time == -1: - time_str = "∞" - else: - time = abs(int(time)) - if time < 60: - time_str = f"{time!s} 秒" - else: - minute = int(time / 60) - if minute > 60: - hours = minute // 60 - minute %= 60 - time_str = f"{hours} 小时 {minute}分钟" - else: - time_str = f"{minute} 分钟" - db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get( - matcher.plugin_name - ) - if ( - db_plugin - # and not db_plugin.ignore_prompt - and time != -1 - and ban_result - and _flmt.check(user_id) - ): - _flmt.start_cd(user_id) - logger.debug(f"ban检测发送插件: {matcher.plugin_name}") - await MessageUtils.build_message( - [ - At(flag="user", target=user_id), - f"{ban_result}\n在..在 {time_str} 后才会理你喔", - ] - ).send() - logger.debug("用户处于黑名单中...", "BanChecker") - raise IgnoredException("用户处于黑名单中...") diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py index d7cec120..ceced884 100644 --- a/zhenxun/builtin_plugins/init/__init_cache.py +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -42,6 +42,8 @@ def default_with_expiration( data: dict[str, Any], expire_data: dict[str, int], expire: int ): """默认更新期时间cache方法""" + if not data: + return {} keys = {k for k in data if k not in expire_data} return {k: time.time() + expire for k in keys} if keys else {} @@ -52,12 +54,6 @@ async def _(): return {p.module: p for p in data_list} -@CacheRoot.new(CacheType.PLUGINS) -async def _(): - data_list = await PluginInfo.get_plugins() - return {p.module: p for p in data_list} - - @CacheRoot.updater(CacheType.PLUGINS) async def _(data: dict[str, PluginInfo], key: str, value: Any): if value: @@ -109,21 +105,20 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any): @CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) -async def _(data: dict[str, GroupConsole] | None, group_id: str): - if not data: - data = {} - result = data.get(group_id, None) +async def _(cache_data: CacheData, group_id: str): + cache_data.data = cache_data.data or {} + result = cache_data.data.get(group_id, None) if not result: result = await GroupConsole.get_group(group_id=group_id) if result: - data[group_id] = result + cache_data.data[group_id] = result return result @CacheRoot.with_refresh(CacheType.GROUPS) async def _(data: dict[str, GroupConsole]): groups = await GroupConsole.filter( - group_id__in=data.keys(), channel_id__isnull=True, load_status=True + group_id__in=data.keys(), channel_id__isnull=True ) for group in groups: data[group.group_id] = group @@ -154,14 +149,13 @@ async def _(data: dict[str, BotConsole], key: str, value: Any): @CacheRoot.getter(CacheType.BOT, result_model=BotConsole) -async def _(data: dict[str, BotConsole] | None, bot_id: str): - if not data: - data = {} - result = data.get(bot_id, None) +async def _(cache_data: CacheData, bot_id: str): + cache_data.data = cache_data.data or {} + result = cache_data.data.get(bot_id, None) if not result: result = await BotConsole.get_or_none(bot_id=bot_id) if result: - data[bot_id] = result + cache_data.data[bot_id] = result return result @@ -224,7 +218,7 @@ def _(cache_data: CacheData): return default_cleanup_expired(cache_data) -@CacheRoot.new(CacheType.LEVEL) +@CacheRoot.new(CacheType.LEVEL, False) async def _(): return await LevelUser().all() @@ -246,13 +240,13 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): ] -@CacheRoot.new(CacheType.BAN) +@CacheRoot.new(CacheType.BAN, False, 5) async def _(): return await BanConsole.all() @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) -def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): +async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): if user_id: return ( [ diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 598ac34d..a85ed1f8 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, overload +from typing import Any, ClassVar, cast, overload from typing_extensions import Self from tortoise import fields @@ -11,6 +11,42 @@ from zhenxun.services.db_context import Model from zhenxun.utils.enum import CacheType, DbLockType, PluginType +def add_disable_marker(name: str) -> str: + """添加模块禁用标记符 + + Args: + name: 模块名称 + + Returns: + 添加了禁用标记的模块名 (前缀'<'和后缀',') + """ + return f"<{name}," + + +@overload +def convert_module_format(data: str) -> list[str]: ... + + +@overload +def convert_module_format(data: list[str]) -> str: ... + + +def convert_module_format(data: str | list[str]) -> str | list[str]: + """ + 在 ` str: - return f"<{name}," - - @overload @classmethod - def convert_module_format(cls, data: str) -> list[str]: ... - - @overload - @classmethod - def convert_module_format(cls, data: list[str]) -> str: ... - - @classmethod - def convert_module_format(cls, data: str | list[str]) -> str | list[str]: - """ - 在 ` list[str]: + """获取默认禁用的任务模块 返回: - str | list[str]: 根据输入类型返回转换后的数据。 + list[str]: 任务模块列表 """ - if isinstance(data, str): - return [item.strip(",") for item in data.split("<") if item] - elif isinstance(data, list): - return "".join(cls.format(item) for item in data) + return cast( + list[str], + await TaskInfo.filter(default_status=default_status).values_list( + "module", flat=True + ), + ) + + @classmethod + async def _get_plugin_modules(cls, *, default_status: bool) -> list[str]: + """获取默认禁用的插件模块 + + 返回: + list[str]: 插件模块列表 + """ + return cast( + list[str], + await PluginInfo.filter( + plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], + default_status=default_status, + ).values_list("module", flat=True), + ) @classmethod @CacheRoot.listener(CacheType.GROUPS) @@ -92,20 +129,44 @@ class GroupConsole(Model): ) -> Self: """覆盖create方法""" group = await super().create(using_db=using_db, **kwargs) - if modules := await TaskInfo.filter(default_status=False).values_list( - "module", flat=True - ): - group.block_task = cls.convert_module_format(modules) # type: ignore - if modules := await PluginInfo.filter( - plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], - default_status=False, - ).values_list("module", flat=True): - group.block_plugin = cls.convert_module_format(modules) # type: ignore - await group.save( - using_db=using_db, update_fields=["block_plugin", "block_task"] - ) + + task_modules = await cls._get_task_modules(default_status=False) + plugin_modules = await cls._get_plugin_modules(default_status=False) + + if task_modules or plugin_modules: + await cls._update_modules(group, task_modules, plugin_modules, using_db) + return group + @classmethod + async def _update_modules( + cls, + group: Self, + task_modules: list[str], + plugin_modules: list[str], + using_db: BaseDBAsyncClient | None = None, + ) -> None: + """更新模块设置 + + 参数: + group: 群组实例 + task_modules: 任务模块列表 + plugin_modules: 插件模块列表 + using_db: 数据库连接 + """ + update_fields = [] + + if task_modules: + group.block_task = convert_module_format(task_modules) + update_fields.append("block_task") + + if plugin_modules: + group.block_plugin = convert_module_format(plugin_modules) + update_fields.append("block_plugin") + + if update_fields: + await group.save(using_db=using_db, update_fields=update_fields) + @classmethod async def get_or_create( cls, @@ -117,23 +178,19 @@ class GroupConsole(Model): group, is_create = await super().get_or_create( defaults=defaults, using_db=using_db, **kwargs ) - if is_create and ( - modules := await TaskInfo.filter(default_status=False).values_list( - "module", flat=True - ) - ): - group.block_task = cls.convert_module_format(modules) # type: ignore - if modules := await PluginInfo.filter( - plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], - default_status=False, - ).values_list("module", flat=True): - group.block_plugin = cls.convert_module_format(modules) # type: ignore - await group.save( - using_db=using_db, update_fields=["block_plugin", "block_task"] - ) + if not is_create: + return group, is_create + + task_modules = await cls._get_task_modules(default_status=False) + plugin_modules = await cls._get_plugin_modules(default_status=False) + + if task_modules or plugin_modules: + await cls._update_modules(group, task_modules, plugin_modules, using_db) + if is_create: if cache := await CacheRoot.get_cache(CacheType.GROUPS): await cache.update(group.group_id, group) + return group, is_create @classmethod @@ -148,20 +205,15 @@ class GroupConsole(Model): group, is_create = await super().update_or_create( defaults=defaults, using_db=using_db, **kwargs ) - if is_create and ( - modules := await TaskInfo.filter(default_status=False).values_list( - "module", flat=True - ) - ): - group.block_task = cls.convert_module_format(modules) # type: ignore - if modules := await PluginInfo.filter( - plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], - default_status=False, - ).values_list("module", flat=True): - group.block_plugin = cls.convert_module_format(modules) # type: ignore - await group.save( - using_db=using_db, update_fields=["block_plugin", "block_task"] - ) + if not is_create: + return group, is_create + + task_modules = await cls._get_task_modules(default_status=False) + plugin_modules = await cls._get_plugin_modules(default_status=False) + + if task_modules or plugin_modules: + await cls._update_modules(group, task_modules, plugin_modules, using_db) + return group, is_create @classmethod @@ -206,7 +258,7 @@ class GroupConsole(Model): """ return await cls.exists( group_id=group_id, - superuser_block_plugin__contains=cls.format(module), + superuser_block_plugin__contains=add_disable_marker(module), ) @classmethod @@ -220,10 +272,11 @@ class GroupConsole(Model): 返回: bool: 是否禁用插件 """ + module = add_disable_marker(module) return await cls.exists( - group_id=group_id, block_plugin__contains=cls.format(module) + group_id=group_id, block_plugin__contains=module ) or await cls.exists( - group_id=group_id, superuser_block_plugin__contains=cls.format(module) + group_id=group_id, superuser_block_plugin__contains=module ) @classmethod @@ -245,12 +298,22 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if cls.format(module) not in group.superuser_block_plugin: - group.superuser_block_plugin += cls.format(module) - elif cls.format(module) not in group.block_plugin: - group.block_plugin += cls.format(module) - await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) + superuser_block_plugin = convert_module_format(group.superuser_block_plugin) + if module not in superuser_block_plugin: + superuser_block_plugin.append(module) + group.superuser_block_plugin = convert_module_format( + superuser_block_plugin + ) + update_fields.append("superuser_block_plugin") + elif add_disable_marker(module) not in group.block_plugin: + block_plugin = convert_module_format(group.block_plugin) + block_plugin.append(module) + group.block_plugin = convert_module_format(block_plugin) + update_fields.append("block_plugin") + if update_fields: + await group.save(update_fields=update_fields) @classmethod async def set_unblock_plugin( @@ -271,14 +334,22 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if cls.format(module) in group.superuser_block_plugin: - group.superuser_block_plugin = group.superuser_block_plugin.replace( - cls.format(module), "" + superuser_block_plugin = convert_module_format(group.superuser_block_plugin) + if module in superuser_block_plugin: + superuser_block_plugin.remove(module) + group.superuser_block_plugin = convert_module_format( + superuser_block_plugin ) - elif cls.format(module) in group.block_plugin: - group.block_plugin = group.block_plugin.replace(cls.format(module), "") - await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) + update_fields.append("superuser_block_plugin") + elif add_disable_marker(module) in group.block_plugin: + block_plugin = convert_module_format(group.block_plugin) + block_plugin.remove(module) + group.block_plugin = convert_module_format(block_plugin) + update_fields.append("block_plugin") + if update_fields: + await group.save(update_fields=update_fields) @classmethod async def is_normal_block_plugin( @@ -297,7 +368,7 @@ class GroupConsole(Model): return await cls.exists( group_id=group_id, channel_id=channel_id, - block_plugin__contains=cls.format(module), + block_plugin__contains=f"<{module},", ) @classmethod @@ -313,7 +384,7 @@ class GroupConsole(Model): """ return await cls.exists( group_id=group_id, - superuser_block_task__contains=cls.format(task), + superuser_block_task__contains=add_disable_marker(task), ) @classmethod @@ -330,24 +401,23 @@ class GroupConsole(Model): 返回: bool: 是否禁用被动 """ + task = add_disable_marker(task) if not channel_id: return await cls.exists( group_id=group_id, channel_id__isnull=True, - block_task__contains=cls.format(task), + block_task__contains=task, ) or await cls.exists( group_id=group_id, channel_id__isnull=True, - superuser_block_task__contains=cls.format(task), + superuser_block_task__contains=task, ) return await cls.exists( - group_id=group_id, - channel_id=channel_id, - block_task__contains=cls.format(task), + group_id=group_id, channel_id=channel_id, block_task__contains=task ) or await cls.exists( group_id=group_id, channel_id__isnull=True, - superuser_block_task__contains=cls.format(task), + superuser_block_task__contains=task, ) @classmethod @@ -369,12 +439,20 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if cls.format(task) not in group.superuser_block_task: - group.superuser_block_task += cls.format(task) - elif cls.format(task) not in group.block_task: - group.block_task += cls.format(task) - await group.save(update_fields=["block_task", "superuser_block_task"]) + superuser_block_task = convert_module_format(group.superuser_block_task) + if task not in group.superuser_block_task: + superuser_block_task.append(task) + group.superuser_block_task = convert_module_format(superuser_block_task) + update_fields.append("superuser_block_task") + elif add_disable_marker(task) not in group.block_task: + block_task = convert_module_format(group.block_task) + block_task.append(task) + group.block_task = convert_module_format(block_task) + update_fields.append("block_task") + if update_fields: + await group.save(update_fields=update_fields) @classmethod async def set_unblock_task( @@ -395,14 +473,20 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if cls.format(task) in group.superuser_block_task: - group.superuser_block_task = group.superuser_block_task.replace( - cls.format(task), "" - ) - elif cls.format(task) in group.block_task: - group.block_task = group.block_task.replace(cls.format(task), "") - await group.save(update_fields=["block_task", "superuser_block_task"]) + superuser_block_task = convert_module_format(group.superuser_block_task) + if task in superuser_block_task: + superuser_block_task.remove(task) + group.superuser_block_task = convert_module_format(superuser_block_task) + update_fields.append("superuser_block_task") + elif add_disable_marker(task) in group.block_task: + block_task = convert_module_format(group.block_task) + block_task.remove(task) + group.block_task = convert_module_format(block_task) + update_fields.append("block_task") + if update_fields: + await group.save(update_fields=update_fields) @classmethod def _run_script(cls): diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 84f128a6..04e60a23 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -46,13 +46,12 @@ class CacheGetter(BaseModel, Generic[T]): async def get(self, cache_data: "CacheData", *args, **kwargs) -> T: """获取缓存""" - processed_data = ( - await self.get_func(cache_data, *args, **kwargs) - if self.get_func and is_coroutine_callable(self.get_func) - else self.get_func(cache_data, *args, **kwargs) - if self.get_func - else cache_data.data - ) + if not self.get_func: + return cache_data.data + if is_coroutine_callable(self.get_func): + processed_data = await self.get_func(cache_data, *args, **kwargs) + else: + processed_data = self.get_func(cache_data, *args, **kwargs) return cast(T, processed_data) @@ -81,9 +80,14 @@ class CacheData(BaseModel): """更新时间""" reload_count: int = 0 """更新次数""" + incremental_update: bool = True + """是否是增量更新""" async def get(self, *args, **kwargs) -> Any: """获取单个缓存""" + if not self.reload_count and not self.incremental_update: + # 首次获取时,非增量更新获取全部数据 + await self.reload() self.call_cleanup_expired() # 移除过期缓存 if not self.getter: return self.data @@ -210,12 +214,17 @@ class CacheManage: id=f"CacheRoot-{cache_data.name}", ) - def new(self, name: str, expire: int = 60 * 10): + def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10): def wrapper(func: Callable): _name = name.upper() if _name in self._data: raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...") - self._data[_name] = CacheData(name=_name, func=func, expire=expire) + self._data[_name] = CacheData( + name=_name, + func=func, + expire=expire, + incremental_update=incremental_update, + ) return wrapper @@ -253,7 +262,7 @@ class CacheManage: return wrapper @validate_name - def getter(self, name: str, result_model: type | None = None): + def getter(self, name: str, result_model: type): def wrapper(func: Callable): self._data[name].getter = CacheGetter[result_model](get_func=func) diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 359fc5f2..65697f32 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -69,12 +69,12 @@ class Model(TortoiseModel): using_db: BaseDBAsyncClient | None = None, **kwargs: Any, ) -> tuple[Self, bool]: - result = await super().get_or_create( + result, is_create = await super().get_or_create( defaults=defaults, using_db=using_db, **kwargs ) - if cache_type := cls.get_cache_type(): + if is_create and (cache_type := cls.get_cache_type()): await CacheRoot.reload(cache_type) - return result + return (result, is_create) @classmethod async def update_or_create( diff --git a/zhenxun/utils/utils.py b/zhenxun/utils/utils.py index 3bdf2df1..16175449 100644 --- a/zhenxun/utils/utils.py +++ b/zhenxun/utils/utils.py @@ -1,4 +1,5 @@ from collections import defaultdict +from dataclasses import dataclass from datetime import datetime import os from pathlib import Path @@ -6,6 +7,7 @@ import time from typing import Any import httpx +from nonebot_plugin_uninfo import Uninfo import pypinyin import pytz @@ -13,6 +15,16 @@ from zhenxun.configs.config import Config from zhenxun.services.log import logger +@dataclass +class EntityIDs: + user_id: str + """用户id""" + group_id: str | None + """群组id""" + channel_id: str | None + """频道id""" + + class ResourceDirManager: """ 临时文件管理器 @@ -228,3 +240,24 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool: return True except ValueError: return False + + +def get_entity_ids(session: Uninfo) -> EntityIDs: + """获取用户id,群组id,频道id + + 参数: + session: Uninfo + + 返回: + EntityIDs: 用户id,群组id,频道id + """ + user_id = session.user.id + group_id = None + channel_id = None + if session.group: + if session.group.parent: + group_id = session.group.parent.id + channel_id = session.group.id + else: + group_id = session.group.id + return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)