diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py index e6ac5bf6..6cba00d4 100644 --- a/zhenxun/builtin_plugins/init/__init_cache.py +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -1,5 +1,3 @@ -from typing import Any - from zhenxun.models.ban_console import BanConsole from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole @@ -19,34 +17,26 @@ async def _(): return {p.module: p for p in data_list} -@CacheRoot.updater(CacheType.PLUGINS) -async def _(data: dict[str, PluginInfo], key: str, value: Any): - """更新插件缓存""" - if value: - data[key] = value - elif plugin := await PluginInfo.get_plugin(module=key): - data[key] = plugin - - @CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo) async def _(cache_data: CacheData, module: str): """获取插件缓存""" - data = await cache_data.get_data() or {} - if module not in data: + data = await cache_data.get_key(module) + if not data: if plugin := await PluginInfo.get_plugin(module=module): - data[module] = plugin - await cache_data.set_data(data) + await cache_data.set_key(module, plugin) logger.debug(f"插件 {module} 数据已设置到缓存") - return data.get(module) + return plugin + return data @CacheRoot.with_refresh(CacheType.PLUGINS) -async def _(data: dict[str, PluginInfo] | None): +async def _(cache_data: CacheData, data: dict[str, PluginInfo] | None): """刷新插件缓存""" if not data: return plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all() - data.update({p.module: p for p in plugins}) + for plugin in plugins: + await cache_data.set_key(plugin.module, plugin) @CacheRoot.new(CacheType.GROUPS) @@ -56,35 +46,27 @@ async def _(): return {p.group_id: p for p in data_list if not p.channel_id} -@CacheRoot.updater(CacheType.GROUPS) -async def _(data: dict[str, GroupConsole], key: str, value: Any): - """更新群组缓存""" - if value: - data[key] = value - elif group := await GroupConsole.get_group(group_id=key): - data[key] = group - - @CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) async def _(cache_data: CacheData, group_id: str): """获取群组缓存""" - data = await cache_data.get_data() or {} - if group_id not in data: + data = await cache_data.get_key(group_id) + if not data: if group := await GroupConsole.get_group(group_id=group_id): - data[group_id] = group - await cache_data.set_data(data) - return data.get(group_id) + await cache_data.set_key(group_id, group) + return group + return data @CacheRoot.with_refresh(CacheType.GROUPS) -async def _(data: dict[str, GroupConsole] | None): +async def _(cache_data: CacheData, data: dict[str, GroupConsole] | None): """刷新群组缓存""" if not data: return groups = await GroupConsole.filter( group_id__in=data.keys(), channel_id__isnull=True ).all() - data.update({g.group_id: g for g in groups}) + for group in groups: + await cache_data.set_key(group.group_id, group) @CacheRoot.new(CacheType.BOT) @@ -94,33 +76,25 @@ async def _(): return {p.bot_id: p for p in data_list} -@CacheRoot.updater(CacheType.BOT) -async def _(data: dict[str, BotConsole], key: str, value: Any): - """更新机器人缓存""" - if value: - data[key] = value - elif bot := await BotConsole.get_or_none(bot_id=key): - data[key] = bot - - @CacheRoot.getter(CacheType.BOT, result_model=BotConsole) async def _(cache_data: CacheData, bot_id: str): """获取机器人缓存""" - data = await cache_data.get_data() or {} - if bot_id not in data: + data = await cache_data.get_key(bot_id) + if not data: if bot := await BotConsole.get_or_none(bot_id=bot_id): - data[bot_id] = bot - await cache_data.set_data(data) - return data.get(bot_id) + await cache_data.set_key(bot_id, bot) + return bot + return data @CacheRoot.with_refresh(CacheType.BOT) -async def _(data: dict[str, BotConsole] | None): +async def _(cache_data: CacheData, data: dict[str, BotConsole] | None): """刷新机器人缓存""" if not data: return bots = await BotConsole.filter(bot_id__in=data.keys()).all() - data.update({b.bot_id: b for b in bots}) + for bot in bots: + await cache_data.set_key(bot.bot_id, bot) @CacheRoot.new(CacheType.USERS) @@ -130,111 +104,100 @@ async def _(): return {p.user_id: p for p in data_list} -@CacheRoot.updater(CacheType.USERS) -async def _(data: dict[str, UserConsole], key: str, value: Any): - """更新用户缓存""" - if value: - data[key] = value - elif user := await UserConsole.get_user(user_id=key): - data[key] = user - - @CacheRoot.getter(CacheType.USERS, result_model=UserConsole) async def _(cache_data: CacheData, user_id: str): """获取用户缓存""" - data = await cache_data.get_data() or {} - if user_id not in data: + data = await cache_data.get_key(user_id) + if not data: if user := await UserConsole.get_user(user_id=user_id): - data[user_id] = user - await cache_data.set_data(data) - return data.get(user_id) + await cache_data.set_key(user_id, user) + return user + return data @CacheRoot.with_refresh(CacheType.USERS) -async def _(data: dict[str, UserConsole] | None): +async def _(cache_data: CacheData, data: dict[str, UserConsole] | None): """刷新用户缓存""" if not data: return users = await UserConsole.filter(user_id__in=data.keys()).all() - data.update({u.user_id: u for u in users}) + for user in users: + await cache_data.set_key(user.user_id, user) @CacheRoot.new(CacheType.LEVEL, False) async def _(): """初始化等级缓存""" - return await LevelUser().all() + data_list = await LevelUser().all() + return {f"{d.user_id}:{d.group_id or ''}": d for d in data_list} @CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser]) async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): """获取等级缓存""" - data = await cache_data.get_data() or [] - if not group_id: - return [d for d in data if d.user_id == user_id and not d.group_id] - return [d for d in data if d.user_id == user_id and d.group_id == group_id] + key = f"{user_id}:{group_id or ''}" + data = await cache_data.get_key(key) + if not data: + if group_id: + data = await LevelUser.filter(user_id=user_id, group_id=group_id).all() + else: + data = await LevelUser.filter(user_id=user_id, group_id__isnull=True).all() + if data: + await cache_data.set_key(key, data) + return data + return data or [] @CacheRoot.new(CacheType.BAN, False) async def _(): """初始化封禁缓存""" - return await BanConsole.all() + data_list = await BanConsole.all() + return {f"{d.user_id or ''}:{d.group_id or ''}": d for d in data_list} @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): """获取封禁缓存""" - data = await cache_data.get_data() or [] - if user_id: - if group_id: - return [d for d in data if d.user_id == user_id and d.group_id == group_id] - return [d for d in data if d.user_id == user_id and not d.group_id] - if group_id: - return [d for d in data if not d.user_id and d.group_id == group_id] - return None + key = f"{user_id or ''}:{group_id or ''}" + data = await cache_data.get_key(key) + if not data: + if user_id and group_id: + data = await BanConsole.filter(user_id=user_id, group_id=group_id).all() + elif user_id: + data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all() + elif group_id: + data = await BanConsole.filter( + user_id__isnull=True, group_id=group_id + ).all() + if data: + await cache_data.set_key(key, data) + return data + return data or [] @CacheRoot.new(CacheType.LIMIT) async def _(): """初始化限制缓存""" data_list = await PluginLimit.filter(status=True).all() - result_data = {} - for data in data_list: - if not result_data.get(data.module): - result_data[data.module] = [] - result_data[data.module].append(data) - return result_data - - -@CacheRoot.updater(CacheType.LIMIT) -async def _(data: dict[str, list[PluginLimit]], key: str, value: Any): - """更新限制缓存""" - if value: - data[key] = value - elif limits := await PluginLimit.filter(module=key, status=True): - data[key] = limits + return {data.module: data for data in data_list} @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) async def _(cache_data: CacheData, module: str): """获取限制缓存""" - data = await cache_data.get_data() or {} - if module not in data: + data = await cache_data.get_key(module) + if not data: if limits := await PluginLimit.filter(module=module, status=True): - data[module] = limits - await cache_data.set_data(data) - return data.get(module) + await cache_data.set_key(module, limits) + return limits + return data or [] @CacheRoot.with_refresh(CacheType.LIMIT) -async def _(data: dict[str, list[PluginLimit]] | None): +async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None): """刷新限制缓存""" if not data: return limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() - new_data = {} for limit in limits: - if not new_data.get(limit.module): - new_data[limit.module] = [] - new_data[limit.module].append(limit) - data.clear() - data.update(new_data) + await cache_data.set_key(limit.module, limit) diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index f7d8ab55..7c3378bf 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -10,6 +10,7 @@ from aiocache.serializers import JsonSerializer from nonebot.compat import model_dump from nonebot.utils import is_coroutine_callable from pydantic import BaseModel +from tortoise.fields.base import Field from zhenxun.services.log import logger @@ -44,15 +45,47 @@ class CacheGetter(BaseModel, Generic[T]): """缓存数据获取器""" get_func: Callable[..., Any] | None = None + get_all_func: Callable[..., Any] | None = None - async def get(self, cache_data: "CacheData", *args, **kwargs) -> T: - """获取处理后的缓存数据""" + async def get(self, cache_data: "CacheData", key: str, *args, **kwargs) -> T: + """获取单个缓存数据""" if not self.get_func: - return await cache_data.get_data() + data = await cache_data.get_key(key) + if cache_data.result_model: + return cache_data._deserialize_value(data, cache_data.result_model) + return data if is_coroutine_callable(self.get_func): - return await self.get_func(cache_data, *args, **kwargs) - return self.get_func(cache_data, *args, **kwargs) + data = await self.get_func(cache_data, key, *args, **kwargs) + else: + data = self.get_func(cache_data, key, *args, **kwargs) + + if cache_data.result_model: + return cache_data._deserialize_value(data, cache_data.result_model) + return data + + async def get_all(self, cache_data: "CacheData", *args, **kwargs) -> dict[str, T]: + """获取所有缓存数据""" + if not self.get_all_func: + data = await cache_data.get_all_data() + if cache_data.result_model: + return { + k: cache_data._deserialize_value(v, cache_data.result_model) + for k, v in data.items() + } + return data + + if is_coroutine_callable(self.get_all_func): + data = await self.get_all_func(cache_data, *args, **kwargs) + else: + data = self.get_all_func(cache_data, *args, **kwargs) + + if cache_data.result_model: + return { + k: cache_data._deserialize_value(v, cache_data.result_model) + for k, v in data.items() + } + return data class CacheData(BaseModel): @@ -66,20 +99,117 @@ class CacheData(BaseModel): expire: int = 600 # 默认10分钟过期 reload_count: int = 0 incremental_update: bool = True + _cache_instance: BaseCache | None = None + result_model: type | None = None + _keys: set[str] = set() # 存储所有缓存键 class Config: arbitrary_types_allowed = True + underscore_attrs_are_private = True @property def _cache(self) -> BaseCache: """获取aiocache实例""" - return AioCache( - AioCache.MEMORY, - serializer=JsonSerializer(), - namespace="zhenxun_cache", - timeout=30, # 操作超时时间 - ttl=self.expire, # 设置默认过期时间 - ) + if self._cache_instance is None: + self._cache_instance = AioCache( + AioCache.MEMORY, + serializer=JsonSerializer(), + namespace="zhenxun_cache", + timeout=30, # 操作超时时间 + ttl=self.expire, # 设置默认过期时间 + ) + return self._cache_instance + + def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any: + """反序列化值,将JSON数据转换回原始类型 + + Args: + value: 需要反序列化的值 + target_type: 目标类型,用于指导反序列化 + + Returns: + 反序列化后的值 + """ + if value is None: + return None + + # 如果是字典且指定了目标类型 + if isinstance(value, dict) and target_type: + # 处理Tortoise-ORM Model + if hasattr(target_type, "_meta"): + # 处理字段值 + processed_value = {} + for field_name, field_value in value.items(): + field: Field = target_type._meta.fields_map.get(field_name) + if field: + # 跳过反向关系字段 + if hasattr(field, "_related_name"): + continue + # 处理 CharEnumField + if hasattr(field, "enum_class"): + try: + processed_value[field_name] = field.enum_class( + field_value + ) + except ValueError: + processed_value[field_name] = None + else: + processed_value[field_name] = field_value + + logger.debug(f"处理后的值: {processed_value}") + + # 创建模型实例 + instance = target_type() + # 设置字段值 + for field_name, field_value in processed_value.items(): + if field_name in target_type._meta.fields_map: + field = target_type._meta.fields_map[field_name] + # 设置字段值 + try: + if hasattr(field, "to_python_value"): + if not field.field_type: + logger.warning(f"字段 {field_name} 类型为空") + continue + field_value = field.to_python_value(field_value) + setattr(instance, field_name, field_value) + except Exception as e: + logger.warning(f"设置字段 {field_name} 失败: {e}") + + # 设置 _saved_in_db 标志 + instance._saved_in_db = True + return instance + # 处理Pydantic模型 + elif hasattr(target_type, "model_validate"): + return target_type.model_validate(value) + elif hasattr(target_type, "from_dict"): + return target_type.from_dict(value) + elif hasattr(target_type, "parse_obj"): + return target_type.parse_obj(value) + else: + return target_type(**value) + + # 处理列表类型 + if isinstance(value, list): + if not value: + return value + if ( + target_type + and hasattr(target_type, "__origin__") + and target_type.__origin__ is list + ): + item_type = target_type.__args__[0] + return [self._deserialize_value(item, item_type) for item in value] + return [self._deserialize_value(item) for item in value] + + # 处理字典类型 + if isinstance(value, dict): + return {k: self._deserialize_value(v) for k, v in value.items()} + + # 处理基本类型 + if isinstance(value, int | float | str | bool): + return value + + return value async def get_data(self) -> Any: """从缓存获取数据""" @@ -88,29 +218,33 @@ class CacheData(BaseModel): logger.debug(f"获取缓存 {self.name} 数据: {data}") # 如果数据为空,尝试重新加载 - if data is None: - logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载") - try: - if self.has_args(): - new_data = ( - await self.func() - if is_coroutine_callable(self.func) - else self.func() - ) - else: - new_data = ( - await self.func() - if is_coroutine_callable(self.func) - else self.func() - ) + # if data is None: + # logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载") + # try: + # if self.has_args(): + # new_data = ( + # await self.func() + # if is_coroutine_callable(self.func) + # else self.func() + # ) + # else: + # new_data = ( + # await self.func() + # if is_coroutine_callable(self.func) + # else self.func() + # ) - await self.set_data(new_data) - self.reload_count += 1 - logger.info(f"重新加载缓存 {self.name} 完成") - return new_data - except Exception as e: - logger.error(f"重新加载缓存 {self.name} 失败: {e}") - return None + # await self.set_data(new_data) + # self.reload_count += 1 + # logger.info(f"重新加载缓存 {self.name} 完成") + # return new_data + # except Exception as e: + # logger.error(f"重新加载缓存 {self.name} 失败: {e}") + # return None + + # 使用 result_model 进行反序列化 + if self.result_model: + return self._deserialize_value(data, self.result_model) return data except Exception as e: @@ -144,10 +278,16 @@ class CacheData(BaseModel): field_value, "_related_name" ): continue + # 跳过外键关系字段 + if hasattr(field_value, "_meta"): + field_value = getattr( + field_value, value._meta.fields[field].related_name or "id" + ) result[field] = self._serialize_value(field_value) except AttributeError: continue return result + # 处理Pydantic模型 elif isinstance(value, BaseModel): return model_dump(value) @@ -160,9 +300,6 @@ class CacheData(BaseModel): elif isinstance(value, int | float | str | bool): # 基本类型直接返回 return value - elif hasattr(value, "__dict__"): - # 处理普通类对象 - return self._serialize_value(value.__dict__) else: # 其他类型转换为字符串 return str(value) @@ -183,19 +320,6 @@ class CacheData(BaseModel): await self._cache.set(self.name, serialized_value, ttl=self.expire) logger.debug(f"设置缓存 {self.name} 新数据完成") - # 4. 立即验证 - cached_data = await self._cache.get(self.name) - if cached_data is None: - logger.error(f"缓存 {self.name} 设置失败:数据验证失败") - # 5. 如果验证失败,尝试重新设置 - await self._cache.set(self.name, serialized_value, ttl=self.expire) - cached_data = await self._cache.get(self.name) - if cached_data is None: - logger.error(f"缓存 {self.name} 重试设置失败") - else: - logger.debug(f"缓存 {self.name} 重试设置成功: {cached_data}") - else: - logger.debug(f"缓存 {self.name} 数据验证成功: {cached_data}") except Exception as e: logger.error(f"设置缓存 {self.name} 失败: {e}") raise # 重新抛出异常,让上层处理 @@ -207,15 +331,25 @@ class CacheData(BaseModel): except Exception as e: logger.error(f"删除缓存 {self.name}", e=e) - async def get(self, *args, **kwargs) -> Any: + async def get(self, key: str, *args, **kwargs) -> Any: """获取缓存""" if not self.reload_count and not self.incremental_update: await self.reload(*args, **kwargs) if not self.getter: - return await self.get_data() + return await self.get_key(key) - return await self.getter.get(self, *args, **kwargs) + return await self.getter.get(self, key, *args, **kwargs) + + async def get_all(self, *args, **kwargs) -> dict[str, Any]: + """获取所有缓存数据""" + if not self.reload_count and not self.incremental_update: + await self.reload(*args, **kwargs) + + if not self.getter: + return await self.get_all_data() + + return await self.getter.get_all(self, *args, **kwargs) async def update(self, key: str, value: Any = None, *args, **kwargs): """更新单个缓存项""" @@ -223,13 +357,13 @@ class CacheData(BaseModel): logger.warning(f"缓存 {self.name} 未配置更新方法") return - current_data = await self.get_data() or {} + current_data = await self.get_key(key) or {} if is_coroutine_callable(self.updater): await self.updater(current_data, key, value, *args, **kwargs) else: self.updater(current_data, key, value, *args, **kwargs) - await self.set_data(current_data) + await self.set_key(key, current_data) logger.debug(f"更新缓存 {self.name}.{key}") async def refresh(self, *args, **kwargs): @@ -262,7 +396,14 @@ class CacheData(BaseModel): else self.func() ) - await self.set_data(new_data) + # 如果是字典,则分别存储每个键值对 + if isinstance(new_data, dict): + for key, value in new_data.items(): + await self.set_key(key, value) + else: + # 如果不是字典,则存储为单个键值对 + await self.set_key("default", new_data) + self.reload_count += 1 logger.info(f"重新加载缓存 {self.name} 完成") except Exception as e: @@ -291,11 +432,16 @@ class CacheData(BaseModel): Returns: 键对应的值,如果不存在返回None """ + cache_key = self._get_cache_key(key) try: - data = await self.get_data() - return data.get(key) if isinstance(data, dict) else None + data = await self._cache.get(cache_key) + logger.debug(f"获取缓存 {cache_key} 数据: {data}") + + if self.result_model: + return self._deserialize_value(data, self.result_model) + return data except Exception as e: - logger.error(f"获取缓存 {self.name}.{key} 失败: {e}") + logger.error(f"获取缓存 {cache_key} 失败: {e}") return None async def get_keys(self, keys: list[str]) -> dict[str, Any]: @@ -316,6 +462,61 @@ class CacheData(BaseModel): logger.error(f"获取缓存 {self.name} 的多个键失败: {e}") return dict.fromkeys(keys) + def _get_cache_key(self, key: str) -> str: + """获取缓存键名""" + return f"{self.name}:{key}" + + async def get_all_data(self) -> dict[str, Any]: + """获取所有缓存数据""" + try: + result = {} + for key in self._keys: + # 提取原始键名(去掉前缀) + original_key = key.split(":", 1)[1] + data = await self._cache.get(key) + if self.result_model: + result[original_key] = self._deserialize_value( + data, self.result_model + ) + else: + result[original_key] = data + return result + except Exception as e: + logger.error(f"获取所有缓存数据失败: {e}") + return {} + + async def set_key(self, key: str, value: Any): + """设置指定键的缓存数据""" + cache_key = self._get_cache_key(key) + try: + serialized_value = self._serialize_value(value) + await self._cache.set(cache_key, serialized_value, ttl=self.expire) + self._keys.add(cache_key) # 添加到键列表 + logger.debug(f"设置缓存 {cache_key} 数据完成") + except Exception as e: + logger.error(f"设置缓存 {cache_key} 失败: {e}") + raise + + async def delete_key(self, key: str): + """删除指定键的缓存数据""" + cache_key = self._get_cache_key(key) + try: + await self._cache.delete(cache_key) + self._keys.discard(cache_key) # 从键列表中移除 + logger.debug(f"删除缓存 {cache_key} 完成") + except Exception as e: + logger.error(f"删除缓存 {cache_key} 失败: {e}") + + async def clear(self): + """清除所有缓存数据""" + try: + for key in list(self._keys): # 使用列表复制避免在迭代时修改 + await self._cache.delete(key) + self._keys.clear() + logger.debug(f"清除缓存 {self.name} 完成") + except Exception as e: + logger.error(f"清除缓存 {self.name} 失败: {e}") + class CacheManager: """全局缓存管理器""" @@ -378,6 +579,7 @@ class CacheManager: def wrapper(func: Callable): self._data[name].getter = CacheGetter[result_model](get_func=func) + self._data[name].result_model = result_model return func return wrapper