diff --git a/zhenxun/builtin_plugins/init/__init__.py b/zhenxun/builtin_plugins/init/__init__.py index ead34817..607e94fa 100644 --- a/zhenxun/builtin_plugins/init/__init__.py +++ b/zhenxun/builtin_plugins/init/__init__.py @@ -11,7 +11,7 @@ from zhenxun.utils.platform import PlatformUtils nonebot.load_plugins(str(Path(__file__).parent.resolve())) try: - from .__init_cache import driver + from . import __init_cache except DbCacheException as e: raise SystemError(f"ERROR:{e}") diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py index 3bb04711..e6ac5bf6 100644 --- a/zhenxun/builtin_plugins/init/__init_cache.py +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -1,8 +1,5 @@ -import time from typing import Any -import nonebot - from zhenxun.models.ban_console import BanConsole from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole @@ -11,77 +8,20 @@ from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_limit import PluginLimit from zhenxun.models.user_console import UserConsole from zhenxun.services.cache import CacheData, CacheRoot +from zhenxun.services.log import logger from zhenxun.utils.enum import CacheType -driver = nonebot.get_driver() - - -@driver.on_startup -async def _(): - """开启cache检测""" - CacheRoot.start_check() - - -def default_cleanup_expired(cache_data: CacheData) -> list[str]: - """默认清理过期cache方法""" - if not cache_data.data: - return [] - now = time.time() - expire_key = [] - for k, t in list(cache_data.expire_data.items()): - if t < now: - expire_key.append(k) - cache_data.expire_data.pop(k) - if expire_key: - cache_data.data = { - k: v for k, v in cache_data.data.items() if k not in expire_key - } - return expire_key - - -def default_cleanup_expired_1(cache_data: CacheData) -> list[str]: - """默认清理列表过期cache方法""" - if not cache_data.data: - return [] - now = time.time() - expire_key = [] - for k, t in list(cache_data.expire_data.items()): - if t < now: - expire_key.append(k) - cache_data.expire_data.pop(k) - if expire_key: - cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key] - return expire_key - - -def default_with_expiration( - data: dict[str, Any], expire_data: dict[str, int], expire: int -): - """默认更新过期时间cache方法""" - if not data: - return {} - keys = {k for k in data if k not in expire_data} - return {k: time.time() + expire for k in keys} if keys else {} - - -def default_with_expiration_1( - data: dict[str, Any], expire_data: dict[str, int], expire: int -): - """默认更新过期时间cache方法""" - if not data: - return {} - keys = {repr(k) for k in data if repr(k) not in expire_data} - return {k: time.time() + expire for k in keys} if keys else {} - @CacheRoot.new(CacheType.PLUGINS) async def _(): + """初始化插件缓存""" data_list = await PluginInfo.get_plugins() return {p.module: p for p in data_list} @CacheRoot.updater(CacheType.PLUGINS) async def _(data: dict[str, PluginInfo], key: str, value: Any): + """更新插件缓存""" if value: data[key] = value elif plugin := await PluginInfo.get_plugin(module=key): @@ -90,42 +30,35 @@ async def _(data: dict[str, PluginInfo], key: str, value: Any): @CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo) async def _(cache_data: CacheData, module: str): - cache_data.data = cache_data.data or {} - result = cache_data.data.get(module, None) - if not result: - result = await PluginInfo.get_plugin(module=module) - if result: - cache_data.data[module] = result - return result + """获取插件缓存""" + data = await cache_data.get_data() or {} + if module not in data: + if plugin := await PluginInfo.get_plugin(module=module): + data[module] = plugin + await cache_data.set_data(data) + logger.debug(f"插件 {module} 数据已设置到缓存") + return data.get(module) @CacheRoot.with_refresh(CacheType.PLUGINS) async def _(data: dict[str, PluginInfo] | None): + """刷新插件缓存""" if not data: return plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all() - for plugin in plugins: - data[plugin.module] = plugin - - -@CacheRoot.with_expiration(CacheType.PLUGINS) -def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int): - return default_with_expiration(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.PLUGINS) -def _(cache_data: CacheData): - return default_cleanup_expired(cache_data) + data.update({p.module: p for p in plugins}) @CacheRoot.new(CacheType.GROUPS) async def _(): + """初始化群组缓存""" data_list = await GroupConsole.all() return {p.group_id: p for p in data_list if not p.channel_id} @CacheRoot.updater(CacheType.GROUPS) async def _(data: dict[str, GroupConsole], key: str, value: Any): + """更新群组缓存""" if value: data[key] = value elif group := await GroupConsole.get_group(group_id=key): @@ -134,44 +67,36 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any): @CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) async def _(cache_data: CacheData, group_id: str): - cache_data.data = cache_data.data or {} - result = cache_data.data.get(group_id, None) - if not result: - result = await GroupConsole.get_group(group_id=group_id) - if result: - cache_data.data[group_id] = result - return result + """获取群组缓存""" + data = await cache_data.get_data() or {} + if group_id not in data: + if group := await GroupConsole.get_group(group_id=group_id): + data[group_id] = group + await cache_data.set_data(data) + return data.get(group_id) @CacheRoot.with_refresh(CacheType.GROUPS) async def _(data: dict[str, GroupConsole] | None): + """刷新群组缓存""" if not data: return groups = await GroupConsole.filter( group_id__in=data.keys(), channel_id__isnull=True ).all() - for group in groups: - data[group.group_id] = group - - -@CacheRoot.with_expiration(CacheType.GROUPS) -def _(data: dict[str, GroupConsole], expire_data: dict[str, int], expire: int): - return default_with_expiration(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.GROUPS) -def _(cache_data: CacheData): - return default_cleanup_expired(cache_data) + data.update({g.group_id: g for g in groups}) @CacheRoot.new(CacheType.BOT) async def _(): + """初始化机器人缓存""" data_list = await BotConsole.all() return {p.bot_id: p for p in data_list} @CacheRoot.updater(CacheType.BOT) async def _(data: dict[str, BotConsole], key: str, value: Any): + """更新机器人缓存""" if value: data[key] = value elif bot := await BotConsole.get_or_none(bot_id=key): @@ -180,42 +105,34 @@ async def _(data: dict[str, BotConsole], key: str, value: Any): @CacheRoot.getter(CacheType.BOT, result_model=BotConsole) async def _(cache_data: CacheData, bot_id: str): - cache_data.data = cache_data.data or {} - result = cache_data.data.get(bot_id, None) - if not result: - result = await BotConsole.get_or_none(bot_id=bot_id) - if result: - cache_data.data[bot_id] = result - return result + """获取机器人缓存""" + data = await cache_data.get_data() or {} + if bot_id not in data: + if bot := await BotConsole.get_or_none(bot_id=bot_id): + data[bot_id] = bot + await cache_data.set_data(data) + return data.get(bot_id) @CacheRoot.with_refresh(CacheType.BOT) async def _(data: dict[str, BotConsole] | None): + """刷新机器人缓存""" if not data: return bots = await BotConsole.filter(bot_id__in=data.keys()).all() - for bot in bots: - data[bot.bot_id] = bot - - -@CacheRoot.with_expiration(CacheType.BOT) -def _(data: dict[str, BotConsole], expire_data: dict[str, int], expire: int): - return default_with_expiration(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.BOT) -def _(cache_data: CacheData): - return default_cleanup_expired(cache_data) + data.update({b.bot_id: b for b in bots}) @CacheRoot.new(CacheType.USERS) async def _(): + """初始化用户缓存""" data_list = await UserConsole.all() return {p.user_id: p for p in data_list} @CacheRoot.updater(CacheType.USERS) async def _(data: dict[str, UserConsole], key: str, value: Any): + """更新用户缓存""" if value: data[key] = value elif user := await UserConsole.get_user(user_id=key): @@ -224,108 +141,61 @@ async def _(data: dict[str, UserConsole], key: str, value: Any): @CacheRoot.getter(CacheType.USERS, result_model=UserConsole) async def _(cache_data: CacheData, user_id: str): - cache_data.data = cache_data.data or {} - result = cache_data.data.get(user_id, None) - if not result: - result = await UserConsole.get_user(user_id=user_id) - if result: - cache_data.data[user_id] = result - return result + """获取用户缓存""" + data = await cache_data.get_data() or {} + if user_id not in data: + if user := await UserConsole.get_user(user_id=user_id): + data[user_id] = user + await cache_data.set_data(data) + return data.get(user_id) @CacheRoot.with_refresh(CacheType.USERS) async def _(data: dict[str, UserConsole] | None): + """刷新用户缓存""" if not data: return users = await UserConsole.filter(user_id__in=data.keys()).all() - for user in users: - data[user.user_id] = user - - -@CacheRoot.with_expiration(CacheType.USERS) -def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int): - return default_with_expiration(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.USERS) -def _(cache_data: CacheData): - return default_cleanup_expired(cache_data) + data.update({u.user_id: u for u in users}) @CacheRoot.new(CacheType.LEVEL, False) async def _(): + """初始化等级缓存""" return await LevelUser().all() @CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser]) async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): - cache_data.data = cache_data.data or [] + """获取等级缓存""" + data = await cache_data.get_data() or [] if not group_id: - return [ - data - for data in cache_data.data - if data.user_id == user_id and not data.group_id - ] - else: - return [ - data - for data in cache_data.data - if data.user_id == user_id and data.group_id == group_id - ] - - -@CacheRoot.with_expiration(CacheType.LEVEL) -def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int): - return default_with_expiration_1(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.LEVEL) -def _(cache_data: CacheData): - return default_cleanup_expired_1(cache_data) + return [d for d in data if d.user_id == user_id and not d.group_id] + return [d for d in data if d.user_id == user_id and d.group_id == group_id] @CacheRoot.new(CacheType.BAN, False) async def _(): + """初始化封禁缓存""" return await BanConsole.all() @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): + """获取封禁缓存""" + data = await cache_data.get_data() or [] if user_id: - return ( - [ - data - for data in cache_data.data - if data.user_id == user_id and data.group_id == group_id - ] - if group_id - else [ - data - for data in cache_data.data - if data.user_id == user_id and not data.group_id - ] - ) + if group_id: + return [d for d in data if d.user_id == user_id and d.group_id == group_id] + return [d for d in data if d.user_id == user_id and not d.group_id] if group_id: - return [ - data - for data in cache_data.data - if not data.user_id and data.group_id == group_id - ] + return [d for d in data if not d.user_id and d.group_id == group_id] return None -@CacheRoot.with_expiration(CacheType.BAN) -def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int): - return default_with_expiration_1(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.BAN) -def _(cache_data: CacheData): - return default_cleanup_expired_1(cache_data) - - @CacheRoot.new(CacheType.LIMIT) async def _(): + """初始化限制缓存""" data_list = await PluginLimit.filter(status=True).all() result_data = {} for data in data_list: @@ -337,6 +207,7 @@ async def _(): @CacheRoot.updater(CacheType.LIMIT) async def _(data: dict[str, list[PluginLimit]], key: str, value: Any): + """更新限制缓存""" if value: data[key] = value elif limits := await PluginLimit.filter(module=key, status=True): @@ -345,32 +216,25 @@ async def _(data: dict[str, list[PluginLimit]], key: str, value: Any): @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) async def _(cache_data: CacheData, module: str): - cache_data.data = cache_data.data or {} - result = cache_data.data.get(module, None) - if not result: - result = await PluginLimit.filter(module=module, status=True) - if result: - cache_data.data[module] = result - return result + """获取限制缓存""" + data = await cache_data.get_data() or {} + if module not in data: + if limits := await PluginLimit.filter(module=module, status=True): + data[module] = limits + await cache_data.set_data(data) + return data.get(module) @CacheRoot.with_refresh(CacheType.LIMIT) async def _(data: dict[str, list[PluginLimit]] | None): + """刷新限制缓存""" if not data: return limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() - data.clear() + new_data = {} for limit in limits: - if not data.get(limit.module): - data[limit.module] = [] - data[limit.module].append(limit) - - -@CacheRoot.with_expiration(CacheType.LIMIT) -def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int): - return default_with_expiration(data, expire_data, expire) - - -@CacheRoot.cleanup_expired(CacheType.LIMIT) -def _(cache_data: CacheData): - return default_cleanup_expired(cache_data) + if not new_data.get(limit.module): + new_data[limit.module] = [] + new_data[limit.module].append(limit) + data.clear() + data.update(new_data) diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 29774447..f7d8ab55 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -1,11 +1,14 @@ from collections.abc import Callable +from datetime import datetime from functools import wraps import inspect -import time -from typing import Any, ClassVar, Generic, TypeVar, cast +from typing import Any, ClassVar, Generic, TypeVar +from aiocache import Cache as AioCache +from aiocache.base import BaseCache +from aiocache.serializers import JsonSerializer +from nonebot.compat import model_dump from nonebot.utils import is_coroutine_callable -from nonebot_plugin_apscheduler import scheduler from pydantic import BaseModel from zhenxun.services.log import logger @@ -16,168 +19,258 @@ T = TypeVar("T") class DbCacheException(Exception): + """缓存相关异常""" + def __init__(self, info: str): self.info = info - def __repr__(self) -> str: - return super().__repr__() - def __str__(self) -> str: return self.info def validate_name(func: Callable): - """ - 装饰器:验证 name 是否存在于 CacheManage._data 中。 - """ + """验证缓存名称是否存在的装饰器""" def wrapper(self, name: str, *args, **kwargs): _name = name.upper() if _name not in CacheManager._data: - raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...") + raise DbCacheException(f"缓存数据 {name} 不存在") return func(self, _name, *args, **kwargs) return wrapper class CacheGetter(BaseModel, Generic[T]): + """缓存数据获取器""" + get_func: Callable[..., Any] | None = None - """获取方法""" async def get(self, cache_data: "CacheData", *args, **kwargs) -> T: - """获取缓存""" + """获取处理后的缓存数据""" if not self.get_func: - return cache_data.data + return await cache_data.get_data() + if is_coroutine_callable(self.get_func): - processed_data = await self.get_func(cache_data, *args, **kwargs) - else: - processed_data = self.get_func(cache_data, *args, **kwargs) - return cast(T, processed_data) + return await self.get_func(cache_data, *args, **kwargs) + return self.get_func(cache_data, *args, **kwargs) class CacheData(BaseModel): + """缓存数据模型""" + name: str - """缓存名称""" func: Callable[..., Any] - """更新方法""" getter: CacheGetter | None = None - """获取方法""" updater: Callable[..., Any] | None = None - """更新单个方法""" with_refresh: Callable[..., Any] | None = None - """刷新方法""" - with_expiration: Callable[..., Any] | None = None - """缓存时间初始化方法""" - cleanup_expired: Callable[..., Any] | None = None - """缓存过期方法""" - data: Any = None - """缓存数据""" - expire: int - """缓存过期时间""" - expire_data: dict[str, int | float] = {} - """缓存过期数据时间记录""" - reload_time: float = time.time() - """更新时间""" + expire: int = 600 # 默认10分钟过期 reload_count: int = 0 - """更新次数""" incremental_update: bool = True - """是否是增量更新""" - async def get(self, *args, **kwargs) -> Any: - """获取单个缓存""" - if not self.reload_count and not self.incremental_update: - # 首次获取时,非增量更新获取全部数据 - await self.reload() - self.call_cleanup_expired() # 移除过期缓存 - if not self.getter: - return self.data - result = await self.getter.get(self, *args, **kwargs) - await self.call_with_expiration() - return result + class Config: + arbitrary_types_allowed = True - async def update(self, key: str, value: Any = None, *args, **kwargs): - """更新单个缓存""" - if not self.updater: - return logger.warning( - f"缓存类型 {self.name} 没有更新方法,无法更新", "CacheRoot" - ) - if self.data: - if is_coroutine_callable(self.updater): - await self.updater(self.data, key, value, *args, **kwargs) - else: - self.updater(self.data, key, value, *args, **kwargs) - logger.debug( - f"缓存类型 {self.name} 更新单个缓存 key: {key},value: {value}", - "CacheRoot", - ) - self.expire_data[key] = time.time() + self.expire - else: - logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot") - - async def refresh(self, *args, **kwargs): - """刷新缓存,只刷新已缓存的数据""" - if not self.with_refresh: - return await self.reload(*args, **kwargs) - if self.data: - if is_coroutine_callable(self.with_refresh): - await self.with_refresh(self.data, *args, **kwargs) - else: - self.with_refresh(self.data, *args, **kwargs) - logger.debug( - f"缓存类型 {self.name} 刷新全局缓存,共刷新 {len(self.data)} 条数据", - "CacheRoot", - ) - - async def reload(self, *args, **kwargs): - """更新全部缓存数据""" - if self.has_args(): - self.data = ( - await self.func(*args, **kwargs) - if is_coroutine_callable(self.func) - else self.func(*args, **kwargs) - ) - else: - self.data = ( - await self.func() if is_coroutine_callable(self.func) else self.func() - ) - await self.call_with_expiration() - self.reload_time = time.time() - self.reload_count += 1 - logger.debug( - f"缓存类型 {self.name} 更新全局缓存,共更新 {len(self.data)} 条数据", - "CacheRoot", + @property + def _cache(self) -> BaseCache: + """获取aiocache实例""" + return AioCache( + AioCache.MEMORY, + serializer=JsonSerializer(), + namespace="zhenxun_cache", + timeout=30, # 操作超时时间 + ttl=self.expire, # 设置默认过期时间 ) - def call_cleanup_expired(self): - """清理过期缓存""" - if self.cleanup_expired: - if result := self.cleanup_expired(self): - logger.debug( - f"成功清理 {self.name} {len(result)} 条过期缓存", "CacheRoot" + async def get_data(self) -> Any: + """从缓存获取数据""" + try: + data = await self._cache.get(self.name) + logger.debug(f"获取缓存 {self.name} 数据: {data}") + + # 如果数据为空,尝试重新加载 + if data is None: + logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载") + try: + if self.has_args(): + new_data = ( + await self.func() + if is_coroutine_callable(self.func) + else self.func() + ) + else: + new_data = ( + await self.func() + if is_coroutine_callable(self.func) + else self.func() + ) + + await self.set_data(new_data) + self.reload_count += 1 + logger.info(f"重新加载缓存 {self.name} 完成") + return new_data + except Exception as e: + logger.error(f"重新加载缓存 {self.name} 失败: {e}") + return None + + return data + except Exception as e: + logger.error(f"获取缓存 {self.name} 失败: {e}") + return None + + def _serialize_value(self, value: Any) -> Any: + """序列化值,将数据转换为JSON可序列化的格式 + + Args: + value: 需要序列化的值 + + Returns: + JSON可序列化的值 + """ + if value is None: + return None + + # 处理datetime + if isinstance(value, datetime): + return value.isoformat() + + # 处理Tortoise-ORM Model + if hasattr(value, "_meta") and hasattr(value, "__dict__"): + result = {} + for field in value._meta.fields: + try: + field_value = getattr(value, field) + # 跳过反向关系字段 + if isinstance(field_value, list | set) and hasattr( + field_value, "_related_name" + ): + continue + result[field] = self._serialize_value(field_value) + except AttributeError: + continue + return result + # 处理Pydantic模型 + elif isinstance(value, BaseModel): + return model_dump(value) + elif isinstance(value, dict): + # 处理字典 + return {str(k): self._serialize_value(v) for k, v in value.items()} + elif isinstance(value, list | tuple | set): + # 处理列表、元组、集合 + return [self._serialize_value(item) for item in value] + elif isinstance(value, int | float | str | bool): + # 基本类型直接返回 + return value + elif hasattr(value, "__dict__"): + # 处理普通类对象 + return self._serialize_value(value.__dict__) + else: + # 其他类型转换为字符串 + return str(value) + + async def set_data(self, value: Any): + """设置缓存数据""" + try: + # 1. 序列化数据 + serialized_value = self._serialize_value(value) + logger.debug(f"设置缓存 {self.name} 原始数据: {value}") + logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}") + + # 2. 删除旧数据 + await self._cache.delete(self.name) + logger.debug(f"删除缓存 {self.name} 旧数据") + + # 3. 设置新数据 + await self._cache.set(self.name, serialized_value, ttl=self.expire) + logger.debug(f"设置缓存 {self.name} 新数据完成") + + # 4. 立即验证 + cached_data = await self._cache.get(self.name) + if cached_data is None: + logger.error(f"缓存 {self.name} 设置失败:数据验证失败") + # 5. 如果验证失败,尝试重新设置 + await self._cache.set(self.name, serialized_value, ttl=self.expire) + cached_data = await self._cache.get(self.name) + if cached_data is None: + logger.error(f"缓存 {self.name} 重试设置失败") + else: + logger.debug(f"缓存 {self.name} 重试设置成功: {cached_data}") + else: + logger.debug(f"缓存 {self.name} 数据验证成功: {cached_data}") + except Exception as e: + logger.error(f"设置缓存 {self.name} 失败: {e}") + raise # 重新抛出异常,让上层处理 + + async def delete_data(self): + """删除缓存数据""" + try: + await self._cache.delete(self.name) + except Exception as e: + logger.error(f"删除缓存 {self.name}", e=e) + + async def get(self, *args, **kwargs) -> Any: + """获取缓存""" + if not self.reload_count and not self.incremental_update: + await self.reload(*args, **kwargs) + + if not self.getter: + return await self.get_data() + + return await self.getter.get(self, *args, **kwargs) + + async def update(self, key: str, value: Any = None, *args, **kwargs): + """更新单个缓存项""" + if not self.updater: + logger.warning(f"缓存 {self.name} 未配置更新方法") + return + + current_data = await self.get_data() or {} + if is_coroutine_callable(self.updater): + await self.updater(current_data, key, value, *args, **kwargs) + else: + self.updater(current_data, key, value, *args, **kwargs) + + await self.set_data(current_data) + logger.debug(f"更新缓存 {self.name}.{key}") + + async def refresh(self, *args, **kwargs): + """刷新缓存数据""" + if not self.with_refresh: + return await self.reload(*args, **kwargs) + + current_data = await self.get_data() + if current_data: + if is_coroutine_callable(self.with_refresh): + await self.with_refresh(current_data, *args, **kwargs) + else: + self.with_refresh(current_data, *args, **kwargs) + await self.set_data(current_data) + logger.debug(f"刷新缓存 {self.name}") + + async def reload(self, *args, **kwargs): + """重新加载全部数据""" + try: + if self.has_args(): + new_data = ( + await self.func(*args, **kwargs) + if is_coroutine_callable(self.func) + else self.func(*args, **kwargs) + ) + else: + new_data = ( + await self.func() + if is_coroutine_callable(self.func) + else self.func() ) - async def call_with_expiration(self, is_force: bool = False): - """缓存时间更新 + await self.set_data(new_data) + self.reload_count += 1 + logger.info(f"重新加载缓存 {self.name} 完成") + except Exception as e: + logger.error(f"重新加载缓存 {self.name} 失败: {e}") + raise - 参数: - is_force: 是否强制更新全部数据缓存时间. - """ - if self.with_expiration: - if is_force: - self.expire_data = {} - expiration_data = ( - await self.with_expiration(self.data, self.expire_data, self.expire) - if is_coroutine_callable(self.with_expiration) - else self.with_expiration(self.data, self.expire_data, self.expire) - ) - self.expire_data = {**self.expire_data, **expiration_data} - - def has_args(self): - """是否含有参数 - - 返回: - bool: 是否含有参数 - """ + def has_args(self) -> bool: + """检查函数是否需要参数""" sig = inspect.signature(self.func) return any( param.kind @@ -189,66 +282,81 @@ class CacheData(BaseModel): for param in sig.parameters.values() ) + async def get_key(self, key: str) -> Any: + """获取缓存中指定键的数据 + + Args: + key: 要获取的键名 + + Returns: + 键对应的值,如果不存在返回None + """ + try: + data = await self.get_data() + return data.get(key) if isinstance(data, dict) else None + except Exception as e: + logger.error(f"获取缓存 {self.name}.{key} 失败: {e}") + return None + + async def get_keys(self, keys: list[str]) -> dict[str, Any]: + """获取缓存中多个键的数据 + + Args: + keys: 要获取的键名列表 + + Returns: + 包含所有请求键值的字典,不存在的键值为None + """ + try: + data = await self.get_data() + if isinstance(data, dict): + return {key: data.get(key) for key in keys} + return dict.fromkeys(keys) + except Exception as e: + logger.error(f"获取缓存 {self.name} 的多个键失败: {e}") + return dict.fromkeys(keys) + class CacheManager: - """全局缓存管理,减少数据库与网络请求查询次数 - - - 异常: - DbCacheException: 数据名称重复 - DbCacheException: 数据不存在 - - """ + """全局缓存管理器""" _data: ClassVar[dict[str, CacheData]] = {} - def start_check(self): - """启动缓存检查""" - for cache_data in self._data.values(): - if cache_data.cleanup_expired: - scheduler.add_job( - cache_data.call_cleanup_expired, - "interval", - seconds=cache_data.expire, - args=[], - id=f"CacheRoot-{cache_data.name}", - ) + def new(self, name: str, incremental_update: bool = True, expire: int = 600): + """注册新缓存""" - def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10): def wrapper(func: Callable): _name = name.upper() if _name in self._data: - raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...") + raise DbCacheException(f"缓存 {name} 已存在") + self._data[_name] = CacheData( name=_name, func=func, expire=expire, incremental_update=incremental_update, ) + return func return wrapper def listener(self, name: str): + """创建缓存监听器""" + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): try: - if is_coroutine_callable(func): - result = await func(*args, **kwargs) - else: - result = func(*args, **kwargs) - return result + return ( + await func(*args, **kwargs) + if is_coroutine_callable(func) + else func(*args, **kwargs) + ) finally: - cache_data = self._data.get(name.upper()) - if cache_data and cache_data.with_refresh: - if is_coroutine_callable(cache_data.with_refresh): - await cache_data.with_refresh(cache_data.data) - else: - cache_data.with_refresh(cache_data.data) - await cache_data.call_with_expiration(True) - logger.debug( - f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot" - ) + cache = self._data.get(name.upper()) + if cache and cache.with_refresh: + await cache.refresh() + logger.debug(f"监听器触发缓存 {name} 刷新") return wrapper @@ -256,86 +364,79 @@ class CacheManager: @validate_name def updater(self, name: str): + """设置缓存更新方法""" + def wrapper(func: Callable): - self._data[name.upper()].updater = func + self._data[name].updater = func + return func return wrapper @validate_name def getter(self, name: str, result_model: type): + """设置缓存获取方法""" + def wrapper(func: Callable): self._data[name].getter = CacheGetter[result_model](get_func=func) + return func return wrapper @validate_name def with_refresh(self, name: str): + """设置缓存刷新方法""" + def wrapper(func: Callable): - self._data[name.upper()].with_refresh = func + self._data[name].with_refresh = func + return func return wrapper - @validate_name - def with_expiration(self, name: str): - def wrapper(func: Callable[[Any, int], dict[str, float]]): - self._data[name.upper()].with_expiration = func - - return wrapper - - @validate_name - def cleanup_expired(self, name: str): - def wrapper(func: Callable[[CacheData], None]): - self._data[name.upper()].cleanup_expired = func - - return wrapper - - async def check_expire(self, name: str, *args, **kwargs): - name = name.upper() - if self._data.get(name) and ( - time.time() - self._data[name].reload_time > self._data[name].expire - or not self._data[name].reload_count - ): - await self._data[name].reload(*args, **kwargs) - - async def get_cache_data(self, name: str): - return cache.data if (cache := await self.get_cache(name)) else None - - async def get_cache(self, name: str, *args, **kwargs) -> CacheData | None: - name = name.upper() - if cache := self._data.get(name): - # await self.check_expire(name, *args, **kwargs) - return cache - return None - - async def get(self, name: str, *args, **kwargs): - cache = await self.get_cache(name.upper(), *args, **kwargs) - if cache: - return await cache.get(*args, **kwargs) if cache.getter else cache.data - return None - - async def reload(self, name: str, *args, **kwargs): + async def get_cache_data(self, name: str) -> Any | None: + """获取缓存数据""" cache = await self.get_cache(name.upper()) - if cache: - await cache.refresh(*args, **kwargs) + return await cache.get_data() if cache else None - async def update(self, name: str, key: str, value: Any, *args, **kwargs): + async def get_cache(self, name: str) -> CacheData | None: + """获取缓存对象""" + return self._data.get(name.upper()) + + async def get(self, name: str, *args, **kwargs) -> Any: + """获取缓存内容""" + cache = await self.get_cache(name.upper()) + return await cache.get(*args, **kwargs) if cache else None + + async def update(self, name: str, key: str, value: Any = None, *args, **kwargs): + """更新缓存项""" cache = await self.get_cache(name.upper()) if cache: await cache.update(key, value, *args, **kwargs) + async def reload(self, name: str, *args, **kwargs): + """重新加载缓存""" + cache = await self.get_cache(name.upper()) + if cache: + await cache.reload(*args, **kwargs) + +# 全局缓存管理器实例 CacheRoot = CacheManager() class Cache(Generic[T]): + """类型化缓存访问接口""" + def __init__(self, module: str): - self.module = module + self.module = module.upper() async def get(self, *args, **kwargs) -> T | None: + """获取缓存""" return await CacheRoot.get(self.module, *args, **kwargs) async def update(self, key: str, value: Any = None, *args, **kwargs): - return await CacheRoot.update(self.module, key, value, *args, **kwargs) + """更新缓存项""" + await CacheRoot.update(self.module, key, value, *args, **kwargs) - async def reload(self, key: str | None = None, *args, **kwargs): - await CacheRoot.reload(self.module, key, *args, **kwargs) + async def reload(self, *args, **kwargs): + """重新加载缓存""" + await CacheRoot.reload(self.module, *args, **kwargs)