diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py index 177c6037..3bdbe1ef 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py @@ -4,8 +4,8 @@ from nonebot_plugin_uninfo import Uninfo from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.log import logger from zhenxun.services.cache import Cache +from zhenxun.services.log import logger from zhenxun.utils.enum import CacheType from zhenxun.utils.message import MessageUtils diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py index 77d487b1..b10b0079 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py @@ -2,8 +2,8 @@ from nonebot.exception import IgnoredException from zhenxun.models.bot_console import BotConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.log import logger from zhenxun.services.cache import Cache +from zhenxun.services.log import logger from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.enum import CacheType diff --git a/zhenxun/builtin_plugins/hooks/ban_hook.py b/zhenxun/builtin_plugins/hooks/ban_hook.py index a77a7300..3c38678a 100644 --- a/zhenxun/builtin_plugins/hooks/ban_hook.py +++ b/zhenxun/builtin_plugins/hooks/ban_hook.py @@ -9,6 +9,7 @@ from tortoise.exceptions import MultipleObjectsReturned from zhenxun.configs.config import Config from zhenxun.models.ban_console import BanConsole from zhenxun.models.group_console import GroupConsole +from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.cache import Cache from zhenxun.services.log import logger from zhenxun.utils.enum import CacheType, PluginType @@ -27,7 +28,8 @@ _flmt = FreqLimiter(300) async def is_ban(user_id: str | None, group_id: str | None): cache = Cache[list[BanConsole]](CacheType.BAN) - return await cache.get(user_id, group_id) + result = await cache.get(user_id, group_id) or await cache.get(user_id) + return result and result[0].ban_time > 0 # 检查是否被ban @@ -80,8 +82,18 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo): time_str = f"{hours} 小时 {minute}分钟" else: time_str = f"{minute} 分钟" - if time != -1 and ban_result and _flmt.check(user_id): + db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get( + matcher.plugin_name + ) + if ( + db_plugin + # and not db_plugin.ignore_prompt + and time != -1 + and ban_result + and _flmt.check(user_id) + ): _flmt.start_cd(user_id) + logger.debug(f"ban检测发送插件: {matcher.plugin_name}") await MessageUtils.build_message( [ At(flag="user", target=user_id), diff --git a/zhenxun/builtin_plugins/init/__init__.py b/zhenxun/builtin_plugins/init/__init__.py index 4bcdb3b3..3ea493cc 100644 --- a/zhenxun/builtin_plugins/init/__init__.py +++ b/zhenxun/builtin_plugins/init/__init__.py @@ -1,22 +1,21 @@ +import os from pathlib import Path -from typing import Any +import sys import nonebot from nonebot.adapters import Bot -from zhenxun.models.ban_console import BanConsole -from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole -from zhenxun.models.level_user import LevelUser -from zhenxun.models.plugin_info import PluginInfo -from zhenxun.models.user_console import UserConsole -from zhenxun.services.cache import CacheRoot +from zhenxun.services.cache import DbCacheException from zhenxun.services.log import logger -from zhenxun.utils.enum import CacheType from zhenxun.utils.platform import PlatformUtils nonebot.load_plugins(str(Path(__file__).parent.resolve())) +try: + from .__init_cache import driver +except DbCacheException as e: + raise SystemError(f"ERROR:{e}") driver = nonebot.get_driver() @@ -49,150 +48,3 @@ async def _(bot: Bot): f"更新Bot: {bot.self_id} 的群认证完成,共创建 {len(create_list)} 条数据," f"共修改 {len(update_id)} 条数据..." ) - - -@CacheRoot.new(CacheType.PLUGINS) -async def _(data: dict[str, PluginInfo] = {}, key: str | None = None): - if data and key: - if plugin := await PluginInfo.get_plugin(module=key): - data[key] = plugin - else: - data_list = await PluginInfo.get_plugins() - return {p.module: p for p in data_list} - - -@CacheRoot.updater(CacheType.PLUGINS) -async def _(data: dict[str, PluginInfo], key: str, value: Any): - if value: - data[key] = value - elif plugin := await PluginInfo.get_plugin(module=key): - data[key] = plugin - - -@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo) -async def _(data: dict[str, PluginInfo], module: str): - result = data.get(module, None) - if not result: - result = await PluginInfo.get_plugin(module=module) - if result: - data[module] = result - return result - - -@CacheRoot.new(CacheType.GROUPS) -async def _(): - data_list = await GroupConsole.all() - return {p.group_id: p for p in data_list if not p.channel_id} - - -@CacheRoot.updater(CacheType.GROUPS) -async def _(data: dict[str, GroupConsole], key: str, value: Any): - if value: - data[key] = value - elif group := await GroupConsole.get_group(group_id=key): - data[key] = group - - -@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) -async def _(data: dict[str, GroupConsole], group_id: str): - result = data.get(group_id, None) - if not result: - result = await GroupConsole.get_group(group_id=group_id) - if result: - data[group_id] = result - return result - - -@CacheRoot.new(CacheType.BOT) -async def _(): - data_list = await BotConsole.all() - return {p.bot_id: p for p in data_list} - - -@CacheRoot.updater(CacheType.BOT) -async def _(data: dict[str, BotConsole], key: str, value: Any): - if value: - data[key] = value - elif bot := await BotConsole.get_or_none(bot_id=key): - data[key] = bot - - -@CacheRoot.getter(CacheType.BOT, result_model=BotConsole) -async def _(data: dict[str, BotConsole], bot_id: str): - result = data.get(bot_id, None) - if not result: - result = await BotConsole.get_or_none(bot_id=bot_id) - if result: - data[bot_id] = result - return result - - -@CacheRoot.new(CacheType.USERS) -async def _(): - data_list = await UserConsole.all() - return {p.user_id: p for p in data_list} - - -@CacheRoot.updater(CacheType.USERS) -async def _(data: dict[str, UserConsole], key: str, value: Any): - if value: - data[key] = value - elif user := await UserConsole.get_user(user_id=key): - data[key] = user - - -@CacheRoot.getter(CacheType.USERS, result_model=UserConsole) -async def _(data: dict[str, UserConsole], user_id: str): - result = data.get(user_id, None) - if not result: - result = await UserConsole.get_user(user_id=user_id) - if result: - data[user_id] = result - return result - - -@CacheRoot.new(CacheType.LEVEL) -async def _(): - return await LevelUser().all() - - -@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser]) -def _(data_list: list[LevelUser], user_id: str, group_id: str | None = None): - if not group_id: - return [ - data for data in data_list if data.user_id == user_id and not data.group_id - ] - else: - return [ - data - for data in data_list - if data.user_id == user_id and data.group_id == group_id - ] - - -@CacheRoot.new(CacheType.BAN) -async def _(): - return await BanConsole.all() - - -@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) -def _(data_list: list[BanConsole], user_id: str | None, group_id: str | None = None): - if user_id: - return ( - [ - data - for data in data_list - if data.user_id == user_id and data.group_id == group_id - ] - if group_id - else [ - data - for data in data_list - if data.user_id == user_id and not data.group_id - ] - ) - if group_id: - return [ - data for data in data_list if not data.user_id and data.group_id == group_id - ] - return None diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py new file mode 100644 index 00000000..d7cec120 --- /dev/null +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -0,0 +1,276 @@ +import time +from typing import Any + +import nonebot + +from zhenxun.models.ban_console import BanConsole +from zhenxun.models.bot_console import BotConsole +from zhenxun.models.group_console import GroupConsole +from zhenxun.models.level_user import LevelUser +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.user_console import UserConsole +from zhenxun.services.cache import CacheData, CacheRoot +from zhenxun.utils.enum import CacheType + +driver = nonebot.get_driver() + + +@driver.on_startup +async def _(): + """开启cache检测""" + CacheRoot.start_check() + + +def default_cleanup_expired(cache_data: CacheData) -> list[str]: + """默认清理过期cache方法""" + if not cache_data.data: + return [] + now = time.time() + expire_key = [] + for k, t in list(cache_data.expire_data.items()): + if t < now: + expire_key.append(k) + cache_data.expire_data.pop(k) + if expire_key: + cache_data.data = { + k: v for k, v in cache_data.data.items() if k not in expire_key + } + return expire_key + + +def default_with_expiration( + data: dict[str, Any], expire_data: dict[str, int], expire: int +): + """默认更新期时间cache方法""" + keys = {k for k in data if k not in expire_data} + return {k: time.time() + expire for k in keys} if keys else {} + + +@CacheRoot.new(CacheType.PLUGINS) +async def _(): + data_list = await PluginInfo.get_plugins() + return {p.module: p for p in data_list} + + +@CacheRoot.new(CacheType.PLUGINS) +async def _(): + data_list = await PluginInfo.get_plugins() + return {p.module: p for p in data_list} + + +@CacheRoot.updater(CacheType.PLUGINS) +async def _(data: dict[str, PluginInfo], key: str, value: Any): + if value: + data[key] = value + elif plugin := await PluginInfo.get_plugin(module=key): + data[key] = plugin + + +@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo) +async def _(cache_data: CacheData, module: str): + cache_data.data = cache_data.data or {} + result = cache_data.data.get(module, None) + if not result: + result = await PluginInfo.get_plugin(module=module) + if result: + cache_data.data[module] = result + return result + + +@CacheRoot.with_refresh(CacheType.PLUGINS) +async def _(data: dict[str, PluginInfo]): + plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True) + for plugin in plugins: + data[plugin.module] = plugin + + +@CacheRoot.with_expiration(CacheType.PLUGINS) +def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int): + return default_with_expiration(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.PLUGINS) +def _(cache_data: CacheData): + return default_cleanup_expired(cache_data) + + +@CacheRoot.new(CacheType.GROUPS) +async def _(): + data_list = await GroupConsole.all() + return {p.group_id: p for p in data_list if not p.channel_id} + + +@CacheRoot.updater(CacheType.GROUPS) +async def _(data: dict[str, GroupConsole], key: str, value: Any): + if value: + data[key] = value + elif group := await GroupConsole.get_group(group_id=key): + data[key] = group + + +@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) +async def _(data: dict[str, GroupConsole] | None, group_id: str): + if not data: + data = {} + result = data.get(group_id, None) + if not result: + result = await GroupConsole.get_group(group_id=group_id) + if result: + data[group_id] = result + return result + + +@CacheRoot.with_refresh(CacheType.GROUPS) +async def _(data: dict[str, GroupConsole]): + groups = await GroupConsole.filter( + group_id__in=data.keys(), channel_id__isnull=True, load_status=True + ) + for group in groups: + data[group.group_id] = group + + +@CacheRoot.with_expiration(CacheType.GROUPS) +def _(data: dict[str, GroupConsole], expire_data: dict[str, int], expire: int): + return default_with_expiration(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.GROUPS) +def _(cache_data: CacheData): + return default_cleanup_expired(cache_data) + + +@CacheRoot.new(CacheType.BOT) +async def _(): + data_list = await BotConsole.all() + return {p.bot_id: p for p in data_list} + + +@CacheRoot.updater(CacheType.BOT) +async def _(data: dict[str, BotConsole], key: str, value: Any): + if value: + data[key] = value + elif bot := await BotConsole.get_or_none(bot_id=key): + data[key] = bot + + +@CacheRoot.getter(CacheType.BOT, result_model=BotConsole) +async def _(data: dict[str, BotConsole] | None, bot_id: str): + if not data: + data = {} + result = data.get(bot_id, None) + if not result: + result = await BotConsole.get_or_none(bot_id=bot_id) + if result: + data[bot_id] = result + return result + + +@CacheRoot.with_refresh(CacheType.BOT) +async def _(data: dict[str, BotConsole]): + bots = await BotConsole.filter(bot_id__in=data.keys()) + for bot in bots: + data[bot.bot_id] = bot + + +@CacheRoot.with_expiration(CacheType.BOT) +def _(data: dict[str, BotConsole], expire_data: dict[str, int], expire: int): + return default_with_expiration(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.BOT) +def _(cache_data: CacheData): + return default_cleanup_expired(cache_data) + + +@CacheRoot.new(CacheType.USERS) +async def _(): + data_list = await UserConsole.all() + return {p.user_id: p for p in data_list} + + +@CacheRoot.updater(CacheType.USERS) +async def _(data: dict[str, UserConsole], key: str, value: Any): + if value: + data[key] = value + elif user := await UserConsole.get_user(user_id=key): + data[key] = user + + +@CacheRoot.getter(CacheType.USERS, result_model=UserConsole) +async def _(cache_data: CacheData, user_id: str): + cache_data.data = cache_data.data or {} + result = cache_data.data.get(user_id, None) + if not result: + result = await UserConsole.get_user(user_id=user_id) + if result: + cache_data.data[user_id] = result + return result + + +@CacheRoot.with_refresh(CacheType.USERS) +async def _(data: dict[str, UserConsole]): + users = await UserConsole.filter(user_id__in=data.keys()) + for user in users: + data[user.user_id] = user + + +@CacheRoot.with_expiration(CacheType.USERS) +def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int): + return default_with_expiration(data, expire_data, expire) + + +@CacheRoot.cleanup_expired(CacheType.USERS) +def _(cache_data: CacheData): + return default_cleanup_expired(cache_data) + + +@CacheRoot.new(CacheType.LEVEL) +async def _(): + return await LevelUser().all() + + +@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser]) +async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): + cache_data.data = cache_data.data or [] + if not group_id: + return [ + data + for data in cache_data.data + if data.user_id == user_id and not data.group_id + ] + else: + return [ + data + for data in cache_data.data + if data.user_id == user_id and data.group_id == group_id + ] + + +@CacheRoot.new(CacheType.BAN) +async def _(): + return await BanConsole.all() + + +@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) +def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): + if user_id: + return ( + [ + data + for data in cache_data.data + if data.user_id == user_id and data.group_id == group_id + ] + if group_id + else [ + data + for data in cache_data.data + if data.user_id == user_id and not data.group_id + ] + ) + if group_id: + return [ + data + for data in cache_data.data + if not data.user_id and data.group_id == group_id + ] + return None diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 373e5f68..598ac34d 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -107,7 +107,6 @@ class GroupConsole(Model): return group @classmethod - @CacheRoot.listener(CacheType.GROUPS) async def get_or_create( cls, defaults: dict | None = None, @@ -132,6 +131,9 @@ class GroupConsole(Model): await group.save( using_db=using_db, update_fields=["block_plugin", "block_task"] ) + if is_create: + if cache := await CacheRoot.get_cache(CacheType.GROUPS): + await cache.update(group.group_id, group) return group, is_create @classmethod diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 7170bef1..84f128a6 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -1,9 +1,11 @@ from collections.abc import Callable from functools import wraps +import inspect import time from typing import Any, ClassVar, Generic, TypeVar, cast from nonebot.utils import is_coroutine_callable +from nonebot_plugin_apscheduler import scheduler from pydantic import BaseModel from zhenxun.services.log import logger @@ -13,18 +15,43 @@ __all__ = ["Cache", "CacheData", "CacheRoot"] T = TypeVar("T") +class DbCacheException(Exception): + def __init__(self, info: str): + self.info = info + + def __repr__(self) -> str: + return super().__repr__() + + def __str__(self) -> str: + return self.info + + +def validate_name(func: Callable): + """ + 装饰器:验证 name 是否存在于 CacheManage._data 中。 + """ + + def wrapper(self, name: str, *args, **kwargs): + _name = name.upper() + if _name not in CacheManage._data: + raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...") + return func(self, _name, *args, **kwargs) + + return wrapper + + class CacheGetter(BaseModel, Generic[T]): get_func: Callable[..., Any] | None = None """获取方法""" - async def get(self, data: Any, *args, **kwargs) -> T: + async def get(self, cache_data: "CacheData", *args, **kwargs) -> T: """获取缓存""" processed_data = ( - await self.get_func(data, *args, **kwargs) + await self.get_func(cache_data, *args, **kwargs) if self.get_func and is_coroutine_callable(self.get_func) - else self.get_func(data, *args, **kwargs) + else self.get_func(cache_data, *args, **kwargs) if self.get_func - else data + else cache_data.data ) return cast(T, processed_data) @@ -38,10 +65,18 @@ class CacheData(BaseModel): """获取方法""" updater: Callable[..., Any] | None = None """更新单个方法""" + with_refresh: Callable[..., Any] | None = None + """刷新方法""" + with_expiration: Callable[..., Any] | None = None + """缓存时间初始化方法""" + cleanup_expired: Callable[..., Any] | None = None + """缓存过期方法""" data: Any = None """缓存数据""" expire: int """缓存过期时间""" + expire_data: dict[str, int | float] = {} + """缓存过期数据时间记录""" reload_time: float = time.time() """更新时间""" reload_count: int = 0 @@ -49,9 +84,12 @@ class CacheData(BaseModel): async def get(self, *args, **kwargs) -> Any: """获取单个缓存""" + self.call_cleanup_expired() # 移除过期缓存 if not self.getter: return self.data - return await self.getter.get(self.data, *args, **kwargs) + result = await self.getter.get(self, *args, **kwargs) + await self.call_with_expiration() + return result async def update(self, key: str, value: Any = None, *args, **kwargs): """更新单个缓存""" @@ -64,23 +102,88 @@ class CacheData(BaseModel): await self.updater(self.data, key, value, *args, **kwargs) else: self.updater(self.data, key, value, *args, **kwargs) + logger.debug( + f"缓存类型 {self.name} 更新单个缓存 key: {key},value: {value}", + "CacheRoot", + ) + self.expire_data[key] = time.time() + self.expire else: logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot") + async def refresh(self, *args, **kwargs): + """刷新缓存,只刷新已缓存的数据""" + if not self.with_refresh: + return await self.reload(*args, **kwargs) + if self.data: + if is_coroutine_callable(self.with_refresh): + await self.with_refresh(self.data, *args, **kwargs) + else: + self.with_refresh(self.data, *args, **kwargs) + logger.debug( + f"缓存类型 {self.name} 刷新全局缓存,共刷新 {len(self.data)} 条数据", + "CacheRoot", + ) + async def reload(self, *args, **kwargs): - """更新缓存""" - self.data = ( - await self.func(*args, **kwargs) - if is_coroutine_callable(self.func) - else self.func(*args, **kwargs) - ) + """更新全部缓存数据""" + if self.has_args(): + self.data = ( + await self.func(*args, **kwargs) + if is_coroutine_callable(self.func) + else self.func(*args, **kwargs) + ) + else: + self.data = ( + await self.func() if is_coroutine_callable(self.func) else self.func() + ) + await self.call_with_expiration() self.reload_time = time.time() self.reload_count += 1 - logger.debug(f"缓存类型 {self.name} 更新全局缓存", "CacheRoot") + logger.debug( + f"缓存类型 {self.name} 更新全局缓存,共更新 {len(self.data)} 条数据", + "CacheRoot", + ) - async def check_expire(self): - if time.time() - self.reload_time > self.expire or not self.reload_count: - await self.reload() + def call_cleanup_expired(self): + """清理过期缓存""" + if self.cleanup_expired: + if result := self.cleanup_expired(self): + logger.debug( + f"成功清理 {self.name} {len(result)} 条过期缓存", "CacheRoot" + ) + + async def call_with_expiration(self, is_force: bool = False): + """缓存时间更新 + + 参数: + is_force: 是否强制更新全部数据缓存时间. + """ + if self.with_expiration: + if is_force: + self.expire_data = {} + expiration_data = ( + await self.with_expiration(self.data, self.expire_data, self.expire) + if is_coroutine_callable(self.with_expiration) + else self.with_expiration(self.data, self.expire_data, self.expire) + ) + self.expire_data = {**self.expire_data, **expiration_data} + + def has_args(self): + """是否含有参数 + + 返回: + bool: 是否含有参数 + """ + sig = inspect.signature(self.func) + return any( + param.kind + in ( + param.POSITIONAL_OR_KEYWORD, + param.POSITIONAL_ONLY, + param.VAR_POSITIONAL, + ) + for param in sig.parameters.values() + ) class CacheManage: @@ -88,18 +191,30 @@ class CacheManage: 异常: - ValueError: 数据名称重复 - ValueError: 数据不存在 + DbCacheException: 数据名称重复 + DbCacheException: 数据不存在 """ _data: ClassVar[dict[str, CacheData]] = {} + def start_check(self): + """启动缓存检查""" + for cache_data in self._data.values(): + if cache_data.cleanup_expired: + scheduler.add_job( + cache_data.call_cleanup_expired, + "interval", + seconds=cache_data.expire, + args=[], + id=f"CacheRoot-{cache_data.name}", + ) + def new(self, name: str, expire: int = 60 * 10): def wrapper(func: Callable): _name = name.upper() if _name in self._data: - raise ValueError(f"DbCache 缓存数据 {name} 已存在...") + raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...") self._data[_name] = CacheData(name=_name, func=func, expire=expire) return wrapper @@ -116,8 +231,12 @@ class CacheManage: return result finally: cache_data = self._data.get(name.upper()) - if cache_data: - await cache_data.reload() + if cache_data and cache_data.with_refresh: + if is_coroutine_callable(cache_data.with_refresh): + await cache_data.with_refresh(cache_data.data) + else: + cache_data.with_refresh(cache_data.data) + await cache_data.call_with_expiration(True) logger.debug( f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot" ) @@ -126,48 +245,61 @@ class CacheManage: return decorator + @validate_name def updater(self, name: str): def wrapper(func: Callable): - _name = name.upper() - if _name not in self._data: - raise ValueError(f"DbCache 缓存数据 {name} 不存在...") - self._data[_name].updater = func + self._data[name.upper()].updater = func return wrapper + @validate_name def getter(self, name: str, result_model: type | None = None): def wrapper(func: Callable): - _name = name.upper() - if _name not in self._data: - raise ValueError(f"DbCache 缓存数据 {name} 不存在...") - self._data[_name].getter = CacheGetter[result_model](get_func=func) + self._data[name].getter = CacheGetter[result_model](get_func=func) return wrapper - async def check_expire(self, name: str): + @validate_name + def with_refresh(self, name: str): + def wrapper(func: Callable): + self._data[name.upper()].with_refresh = func + + return wrapper + + @validate_name + def with_expiration(self, name: str): + def wrapper(func: Callable[[Any, int], dict[str, float]]): + self._data[name.upper()].with_expiration = func + + return wrapper + + @validate_name + def cleanup_expired(self, name: str): + def wrapper(func: Callable[[CacheData], None]): + self._data[name.upper()].cleanup_expired = func + + return wrapper + + async def check_expire(self, name: str, *args, **kwargs): name = name.upper() - if self._data.get(name): - if ( - time.time() - self._data[name].reload_time > self._data[name].expire - or not self._data[name].reload_count - ): - await self._data[name].reload() + if self._data.get(name) and ( + time.time() - self._data[name].reload_time > self._data[name].expire + or not self._data[name].reload_count + ): + await self._data[name].reload(*args, **kwargs) async def get_cache_data(self, name: str): - if cache := await self.get_cache(name): - return cache.data - return None + return cache.data if (cache := await self.get_cache(name)) else None - async def get_cache(self, name: str) -> CacheData | None: + async def get_cache(self, name: str, *args, **kwargs) -> CacheData | None: name = name.upper() - cache = self._data.get(name) - if cache: - await self.check_expire(name) + if cache := self._data.get(name): + # await self.check_expire(name, *args, **kwargs) return cache return None async def get(self, name: str, *args, **kwargs): - cache = await self.get_cache(name.upper()) + cache = await self.get_cache(name.upper(), *args, **kwargs) if cache: return await cache.get(*args, **kwargs) if cache.getter else cache.data return None @@ -175,7 +307,7 @@ class CacheManage: async def reload(self, name: str, *args, **kwargs): cache = await self.get_cache(name.upper()) if cache: - await cache.reload(*args, **kwargs) + await cache.refresh(*args, **kwargs) async def update(self, name: str, key: str, value: Any, *args, **kwargs): cache = await self.get_cache(name.upper())