From 8865ab0d587c863856c194c53ff1994df186525e Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Fri, 11 Apr 2025 16:46:24 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E6=B7=BB=E5=8A=A0=E6=8F=92?= =?UTF-8?q?=E4=BB=B6limit=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../builtin_plugins/hooks/auth/auth_ban.py | 3 +- .../builtin_plugins/hooks/auth/auth_limit.py | 42 +++---- .../builtin_plugins/hooks/auth/bot_filter.py | 35 ++++++ zhenxun/builtin_plugins/hooks/auth_checker.py | 2 + zhenxun/builtin_plugins/init/__init_cache.py | 108 +++++++++++++++++- zhenxun/services/cache.py | 6 +- zhenxun/utils/enum.py | 2 + 7 files changed, 163 insertions(+), 35 deletions(-) create mode 100644 zhenxun/builtin_plugins/hooks/auth/bot_filter.py diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py index 727ab70b..ed141371 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py @@ -14,7 +14,7 @@ from zhenxun.utils.utils import EntityIDs, get_entity_ids from .config import LOGGER_COMMAND from .exception import SkipPluginException -from .utils import send_message +from .utils import freq, send_message Config.add_plugin_config( "hook", @@ -131,6 +131,7 @@ async def user_handle( # and not db_plugin.ignore_prompt and time != -1 and ban_result + and freq.is_send_limit_message(db_plugin, entity.user_id, False) ): await send_message( session, diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py index e56fa9ed..61dc8705 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py @@ -8,7 +8,12 @@ from zhenxun.models.plugin_limit import PluginLimit from zhenxun.services.log import logger from zhenxun.utils.enum import LimitWatchType, PluginLimitType from zhenxun.utils.message import MessageUtils -from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter +from zhenxun.utils.utils import ( + CountLimiter, + FreqLimiter, + UserBlockLimiter, + get_entity_ids, +) from .config import LOGGER_COMMAND from .exception import SkipPluginException @@ -22,7 +27,7 @@ class Limit(BaseModel): arbitrary_types_allowed = True -class LimitManage: +class LimitManager: add_module: ClassVar[list] = [] cd_limit: ClassVar[dict[str, Limit]] = {} @@ -84,7 +89,6 @@ class LimitManage: user_id: str, group_id: str | None, channel_id: str | None, - session: Uninfo, ): """检测限制 @@ -93,17 +97,16 @@ class LimitManage: user_id: 用户id group_id: 群组id channel_id: 频道id - session: Session 异常: IgnoredException: IgnoredException """ if limit_model := cls.cd_limit.get(module): - await cls.__check(limit_model, user_id, group_id, channel_id, session) + await cls.__check(limit_model, user_id, group_id, channel_id) if limit_model := cls.block_limit.get(module): - await cls.__check(limit_model, user_id, group_id, channel_id, session) + await cls.__check(limit_model, user_id, group_id, channel_id) if limit_model := cls.count_limit.get(module): - await cls.__check(limit_model, user_id, group_id, channel_id, session) + await cls.__check(limit_model, user_id, group_id, channel_id) @classmethod async def __check( @@ -112,7 +115,6 @@ class LimitManage: user_id: str, group_id: str | None, channel_id: str | None, - session: Uninfo, ): """检测限制 @@ -121,7 +123,6 @@ class LimitManage: user_id: 用户id group_id: 群组id channel_id: 频道id - session: Session 异常: IgnoredException: IgnoredException @@ -166,23 +167,14 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo): plugin: PluginInfo session: Uninfo """ - user_id = session.user.id - group_id = None - channel_id = None - if session.group: - if session.group.parent: - group_id = session.group.parent.id - channel_id = session.group.id - else: - group_id = session.group.id - if not group_id: - group_id = channel_id - channel_id = None - if plugin.module not in LimitManage.add_module: + entity = get_entity_ids(session) + if plugin.module not in LimitManager.add_module: limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( status=True ).all() # type: ignore for limit in limit_list: - LimitManage.add_limit(limit) - if user_id: - await LimitManage.check(plugin.module, user_id, group_id, channel_id, session) + LimitManager.add_limit(limit) + if entity.user_id: + await LimitManager.check( + plugin.module, entity.user_id, entity.group_id, entity.channel_id + ) diff --git a/zhenxun/builtin_plugins/hooks/auth/bot_filter.py b/zhenxun/builtin_plugins/hooks/auth/bot_filter.py new file mode 100644 index 00000000..04e47372 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/auth/bot_filter.py @@ -0,0 +1,35 @@ +import nonebot +from nonebot_plugin_uninfo import Uninfo + +from zhenxun.configs.config import Config + +from .exception import SkipPluginException + +Config.add_plugin_config( + "hook", + "FILTER_BOT", + True, + help="过滤当前连接bot(防止bot互相调用)", + default_value=True, + type=bool, +) + + +def bot_filter(session: Uninfo): + """过滤bot调用bot + + 参数: + session: Uninfo + + 异常: + SkipPluginException: bot互相调用 + """ + if not Config.get_config("hook", "FILTER_BOT"): + return + bot_ids = list(nonebot.get_bots().keys()) + if session.user.id == session.self_id: + return + if session.user.id in bot_ids: + raise SkipPluginException( + f"bot:{session.self_id} 尝试调用 bot:{session.user.id}" + ) diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py index c6568983..9e0293e1 100644 --- a/zhenxun/builtin_plugins/hooks/auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -27,6 +27,7 @@ from .auth.auth_cost import auth_cost from .auth.auth_group import auth_group from .auth.auth_limit import LimitManage, auth_limit from .auth.auth_plugin import auth_plugin +from .auth.bot_filter import bot_filter from .auth.config import LOGGER_COMMAND from .auth.exception import ( IsSuperuserException, @@ -152,6 +153,7 @@ async def auth( cost_gold = await get_plugin_cost(bot, user, plugin, session) await asyncio.gather( *[ + bot_filter(session), auth_ban(matcher, bot, session), auth_bot(plugin, bot.self_id), auth_group(plugin, entity, message), diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py index ceced884..49f84f26 100644 --- a/zhenxun/builtin_plugins/init/__init_cache.py +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -8,6 +8,7 @@ from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.plugin_limit import PluginLimit from zhenxun.models.user_console import UserConsole from zhenxun.services.cache import CacheData, CacheRoot from zhenxun.utils.enum import CacheType @@ -38,16 +39,41 @@ def default_cleanup_expired(cache_data: CacheData) -> list[str]: return expire_key +def default_cleanup_expired_1(cache_data: CacheData) -> list[str]: + """默认清理列表过期cache方法""" + if not cache_data.data: + return [] + now = time.time() + expire_key = [] + for k, t in list(cache_data.expire_data.items()): + if t < now: + expire_key.append(k) + cache_data.expire_data.pop(k) + if expire_key: + cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key] + return expire_key + + def default_with_expiration( data: dict[str, Any], expire_data: dict[str, int], expire: int ): - """默认更新期时间cache方法""" + """默认更新过期时间cache方法""" if not data: return {} keys = {k for k in data if k not in expire_data} return {k: time.time() + expire for k in keys} if keys else {} +def default_with_expiration_1( + data: dict[str, Any], expire_data: dict[str, int], expire: int +): + """默认更新过期时间cache方法""" + if not data: + return {} + keys = {repr(k) for k in data if repr(k) not in expire_data} + return {k: time.time() + expire for k in keys} if keys else {} + + @CacheRoot.new(CacheType.PLUGINS) async def _(): data_list = await PluginInfo.get_plugins() @@ -75,7 +101,7 @@ async def _(cache_data: CacheData, module: str): @CacheRoot.with_refresh(CacheType.PLUGINS) async def _(data: dict[str, PluginInfo]): - plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True) + plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all() for plugin in plugins: data[plugin.module] = plugin @@ -119,7 +145,7 @@ async def _(cache_data: CacheData, group_id: str): async def _(data: dict[str, GroupConsole]): groups = await GroupConsole.filter( group_id__in=data.keys(), channel_id__isnull=True - ) + ).all() for group in groups: data[group.group_id] = group @@ -161,7 +187,7 @@ async def _(cache_data: CacheData, bot_id: str): @CacheRoot.with_refresh(CacheType.BOT) async def _(data: dict[str, BotConsole]): - bots = await BotConsole.filter(bot_id__in=data.keys()) + bots = await BotConsole.filter(bot_id__in=data.keys()).all() for bot in bots: data[bot.bot_id] = bot @@ -203,7 +229,7 @@ async def _(cache_data: CacheData, user_id: str): @CacheRoot.with_refresh(CacheType.USERS) async def _(data: dict[str, UserConsole]): - users = await UserConsole.filter(user_id__in=data.keys()) + users = await UserConsole.filter(user_id__in=data.keys()).all() for user in users: data[user.user_id] = user @@ -240,7 +266,17 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): ] -@CacheRoot.new(CacheType.BAN, False, 5) +@CacheRoot.with_expiration(CacheType.LEVEL) +def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int): + return default_with_expiration_1(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.LEVEL) +def _(cache_data: CacheData): + return default_cleanup_expired_1(cache_data) + + +@CacheRoot.new(CacheType.BAN, False) async def _(): return await BanConsole.all() @@ -268,3 +304,63 @@ async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = N if not data.user_id and data.group_id == group_id ] return None + + +@CacheRoot.with_expiration(CacheType.BAN) +def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int): + return default_with_expiration_1(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.BAN) +def _(cache_data: CacheData): + return default_cleanup_expired_1(cache_data) + + +@CacheRoot.new(CacheType.LIMIT) +async def _(): + data_list = await PluginLimit.filter(status=True).all() + result_data = {} + for data in data_list: + if not result_data.get(data.module): + result_data[data.module] = [] + result_data[data.module].append(data) + return result_data + + +@CacheRoot.updater(CacheType.LIMIT) +async def _(data: dict[str, list[PluginLimit]], key: str, value: Any): + if value: + data[key] = value + elif limits := await PluginLimit.filter(module=key, status=True): + data[key] = limits + + +@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) +async def _(cache_data: CacheData, module: str): + cache_data.data = cache_data.data or {} + result = cache_data.data.get(module, None) + if not result: + result = await PluginLimit.filter(module=module, status=True) + if result: + cache_data.data[module] = result + return result + + +@CacheRoot.with_refresh(CacheType.LIMIT) +async def _(data: dict[str, list[PluginLimit]]): + limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() + data.clear() + for limit in limits: + if not data.get(limit.module): + data[limit.module] = [] + data[limit.module].append(limit) + + +@CacheRoot.with_expiration(CacheType.LIMIT) +def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int): + return default_with_expiration(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.LIMIT) +def _(cache_data: CacheData): + return default_cleanup_expired(cache_data) diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 04e60a23..29774447 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -33,7 +33,7 @@ def validate_name(func: Callable): def wrapper(self, name: str, *args, **kwargs): _name = name.upper() - if _name not in CacheManage._data: + if _name not in CacheManager._data: raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...") return func(self, _name, *args, **kwargs) @@ -190,7 +190,7 @@ class CacheData(BaseModel): ) -class CacheManage: +class CacheManager: """全局缓存管理,减少数据库与网络请求查询次数 @@ -324,7 +324,7 @@ class CacheManage: await cache.update(key, value, *args, **kwargs) -CacheRoot = CacheManage() +CacheRoot = CacheManager() class Cache(Generic[T]): diff --git a/zhenxun/utils/enum.py b/zhenxun/utils/enum.py index c6d31572..5b235615 100644 --- a/zhenxun/utils/enum.py +++ b/zhenxun/utils/enum.py @@ -18,6 +18,8 @@ class CacheType(StrEnum): """全局bot信息""" LEVEL = "GLOBAL_USER_LEVEL" """用户权限""" + LIMIT = "GLOBAL_LIMIT" + """插件限制""" class DbLockType(StrEnum):