diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py index 4121d8ff..f1259c9c 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py @@ -1,3 +1,5 @@ +import asyncio + from nonebot.adapters import Bot from nonebot.matcher import Matcher from nonebot_plugin_alconna import At @@ -25,8 +27,17 @@ Config.add_plugin_config( async def is_ban(user_id: str | None, group_id: str | None) -> int: + if not user_id and not group_id: + return 0 cache = Cache[list[BanConsole]](CacheType.BAN) - results = await cache.get(user_id, group_id) or await cache.get(user_id) + group_user, user = await asyncio.gather( + cache.get(user_id, group_id), cache.get(user_id) + ) + results = [] + if group_user: + results.extend(group_user) + if user: + results.extend(user) if not results: return 0 for result in results: diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py index 61dc8705..fe29ebc4 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py @@ -1,5 +1,6 @@ from typing import ClassVar +import nonebot from nonebot_plugin_uninfo import Uninfo from pydantic import BaseModel @@ -18,6 +19,14 @@ from zhenxun.utils.utils import ( from .config import LOGGER_COMMAND from .exception import SkipPluginException +driver = nonebot.get_driver() + + +@driver.on_startup +async def _(): + """初始化限制""" + await LimitManager.init_limit() + class Limit(BaseModel): limit: PluginLimit @@ -34,6 +43,13 @@ class LimitManager: block_limit: ClassVar[dict[str, Limit]] = {} count_limit: ClassVar[dict[str, Limit]] = {} + @classmethod + async def init_limit(cls): + """初始化限制""" + limit_list = await PluginLimit.filter(status=True).all() + for limit in limit_list: + cls.add_limit(limit) + @classmethod def add_limit(cls, limit: PluginLimit): """添加限制 @@ -169,9 +185,7 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo): """ entity = get_entity_ids(session) if plugin.module not in LimitManager.add_module: - limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( - status=True - ).all() # type: ignore + limit_list = await PluginLimit.filter(module=plugin.module, status=True).all() for limit in limit_list: LimitManager.add_limit(limit) if entity.user_id: diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py index 554ec29e..0e0d5c64 100644 --- a/zhenxun/builtin_plugins/hooks/auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -159,9 +159,9 @@ async def auth( auth_group(plugin, entity, message), auth_admin(plugin, session), auth_plugin(plugin, session, event), + auth_limit(plugin, session), ] ) - await auth_limit(plugin, session) except SkipPluginException as e: LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id) logger.info(str(e), LOGGER_COMMAND, session=session) diff --git a/zhenxun/builtin_plugins/init/__init__.py b/zhenxun/builtin_plugins/init/__init__.py index 607e94fa..7c78b019 100644 --- a/zhenxun/builtin_plugins/init/__init__.py +++ b/zhenxun/builtin_plugins/init/__init__.py @@ -11,13 +11,18 @@ from zhenxun.utils.platform import PlatformUtils nonebot.load_plugins(str(Path(__file__).parent.resolve())) try: - from . import __init_cache + from .__init_cache import CacheRoot except DbCacheException as e: raise SystemError(f"ERROR:{e}") driver = nonebot.get_driver() +@driver.on_startup +async def _(): + await CacheRoot.init_non_lazy_caches() + + @driver.on_bot_connect async def _(bot: Bot): """将bot已存在的群组添加群认证 diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py index 6cba00d4..aced954c 100644 --- a/zhenxun/builtin_plugins/init/__init_cache.py +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -3,7 +3,6 @@ from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo -from zhenxun.models.plugin_limit import PluginLimit from zhenxun.models.user_console import UserConsole from zhenxun.services.cache import CacheData, CacheRoot from zhenxun.services.log import logger @@ -125,7 +124,7 @@ async def _(cache_data: CacheData, data: dict[str, UserConsole] | None): await cache_data.set_key(user.user_id, user) -@CacheRoot.new(CacheType.LEVEL, False) +@CacheRoot.new(CacheType.LEVEL) async def _(): """初始化等级缓存""" data_list = await LevelUser().all() @@ -152,52 +151,61 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): async def _(): """初始化封禁缓存""" data_list = await BanConsole.all() - return {f"{d.user_id or ''}:{d.group_id or ''}": d for d in data_list} + return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list} @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): """获取封禁缓存""" - key = f"{user_id or ''}:{group_id or ''}" + if not user_id and not group_id: + return [] + key = f"{group_id or ''}:{user_id or ''}" + logger.info(f"获取封禁缓存: {key}") data = await cache_data.get_key(key) - if not data: - if user_id and group_id: - data = await BanConsole.filter(user_id=user_id, group_id=group_id).all() - elif user_id: - data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all() - elif group_id: - data = await BanConsole.filter( - user_id__isnull=True, group_id=group_id - ).all() - if data: - await cache_data.set_key(key, data) - return data + if data: + logger.info(f"已存在缓存: {key}:{data}") + # if not data: + # start = time.time() + # if user_id and group_id: + # data = await BanConsole.filter(user_id=user_id, group_id=group_id).all() + # elif user_id: + # data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all() + # elif group_id: + # data = await BanConsole.filter( + # user_id__isnull=True, group_id=group_id + # ).all() + # logger.info( + # f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}" + # ) + # if data: + # await cache_data.set_key(key, data) + # return data return data or [] -@CacheRoot.new(CacheType.LIMIT) -async def _(): - """初始化限制缓存""" - data_list = await PluginLimit.filter(status=True).all() - return {data.module: data for data in data_list} +# @CacheRoot.new(CacheType.LIMIT) +# async def _(): +# """初始化限制缓存""" +# data_list = await PluginLimit.filter(status=True).all() +# return {data.module: data for data in data_list} -@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) -async def _(cache_data: CacheData, module: str): - """获取限制缓存""" - data = await cache_data.get_key(module) - if not data: - if limits := await PluginLimit.filter(module=module, status=True): - await cache_data.set_key(module, limits) - return limits - return data or [] +# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) +# async def _(cache_data: CacheData, module: str): +# """获取限制缓存""" +# data = await cache_data.get_key(module) +# if not data: +# if limits := await PluginLimit.filter(module=module, status=True): +# await cache_data.set_key(module, limits) +# return limits +# return data or [] -@CacheRoot.with_refresh(CacheType.LIMIT) -async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None): - """刷新限制缓存""" - if not data: - return - limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() - for limit in limits: - await cache_data.set_key(limit.module, limit) +# @CacheRoot.with_refresh(CacheType.LIMIT) +# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None): +# """刷新限制缓存""" +# if not data: +# return +# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() +# for limit in limits: +# await cache_data.set_key(limit.module, limit) diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 7c3378bf..49d03eec 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -98,7 +98,7 @@ class CacheData(BaseModel): with_refresh: Callable[..., Any] | None = None expire: int = 600 # 默认10分钟过期 reload_count: int = 0 - incremental_update: bool = True + lazy_load: bool = True # 默认延迟加载 _cache_instance: BaseCache | None = None result_model: type | None = None _keys: set[str] = set() # 存储所有缓存键 @@ -168,12 +168,12 @@ class CacheData(BaseModel): try: if hasattr(field, "to_python_value"): if not field.field_type: - logger.warning(f"字段 {field_name} 类型为空") + logger.debug(f"字段 {field_name} 类型为空") continue field_value = field.to_python_value(field_value) setattr(instance, field_name, field_value) except Exception as e: - logger.warning(f"设置字段 {field_name} 失败: {e}") + logger.warning(f"设置字段 {field_name} 失败", e=e) # 设置 _saved_in_db 标志 instance._saved_in_db = True @@ -333,7 +333,7 @@ class CacheData(BaseModel): async def get(self, key: str, *args, **kwargs) -> Any: """获取缓存""" - if not self.reload_count and not self.incremental_update: + if not self.reload_count and not self.lazy_load: await self.reload(*args, **kwargs) if not self.getter: @@ -343,7 +343,7 @@ class CacheData(BaseModel): async def get_all(self, *args, **kwargs) -> dict[str, Any]: """获取所有缓存数据""" - if not self.reload_count and not self.incremental_update: + if not self.reload_count and not self.lazy_load: await self.reload(*args, **kwargs) if not self.getter: @@ -523,8 +523,24 @@ class CacheManager: _data: ClassVar[dict[str, CacheData]] = {} - def new(self, name: str, incremental_update: bool = True, expire: int = 600): - """注册新缓存""" + async def init_non_lazy_caches(self): + """初始化所有非延迟加载的缓存""" + for name, cache in self._data.items(): + if not cache.lazy_load: + try: + await cache.reload() + logger.info(f"初始化缓存 {name} 完成") + except Exception as e: + logger.error(f"初始化缓存 {name} 失败: {e}") + + def new(self, name: str, lazy_load: bool = True, expire: int = 600): + """注册新缓存 + + Args: + name: 缓存名称 + lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载 + expire: 过期时间(秒) + """ def wrapper(func: Callable): _name = name.upper() @@ -535,7 +551,7 @@ class CacheManager: name=_name, func=func, expire=expire, - incremental_update=incremental_update, + lazy_load=lazy_load, ) return func