From 7d962bb4a85515d2906e5572588f11314bb6c28c Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Tue, 8 Jul 2025 16:10:35 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(cache):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=8A=9F=E8=83=BD=E9=85=8D=E7=BD=AE=E9=A1=B9?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E6=96=B0=E5=A2=9E=E6=95=B0=E6=8D=AE=E8=AE=BF?= =?UTF-8?q?=E9=97=AE=E5=B1=82=E4=BB=A5=E6=94=AF=E6=8C=81=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zhenxun/services/cache.py | 2 + zhenxun/services/data_access.py | 210 ++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 zhenxun/services/data_access.py diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index af78b892..04222e95 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -26,6 +26,8 @@ driver = nonebot.get_driver() class Config(BaseModel): + enable_cache: bool = True + """是否开启缓存功能""" redis_host: str | None = None """redis地址""" redis_port: int | None = None diff --git a/zhenxun/services/data_access.py b/zhenxun/services/data_access.py new file mode 100644 index 00000000..52992b28 --- /dev/null +++ b/zhenxun/services/data_access.py @@ -0,0 +1,210 @@ +from typing import Any, Generic, TypeVar, cast + +from zhenxun.services.cache import CacheRoot +from zhenxun.services.cache import config as cache_config +from zhenxun.services.db_context import Model +from zhenxun.services.log import logger + +T = TypeVar("T", bound=Model) + + +class DataAccess(Generic[T]): + """数据访问层,根据配置决定是否使用缓存 + + 使用示例: + ```python + from zhenxun.services import DataAccess + from zhenxun.models.plugin_info import PluginInfo + + # 创建数据访问对象 + plugin_dao = DataAccess(PluginInfo) + + # 获取单个数据 + plugin = await plugin_dao.get(module="example_module") + + # 获取所有数据 + all_plugins = await plugin_dao.all() + + # 筛选数据 + enabled_plugins = await plugin_dao.filter(status=True) + + # 创建数据 + new_plugin = await plugin_dao.create( + module="new_module", + name="新插件", + status=True + ) + ``` + """ + + def __init__(self, model_cls: type[T], cache_type: str | None = None): + """初始化数据访问对象 + + 参数: + model_cls: 模型类 + cache_type: 缓存类型,如果为None则使用模型类的cache_type属性 + """ + self.model_cls = model_cls + self.cache_type = cache_type or getattr(model_cls, "cache_type", None) + + async def get_or_none(self, *args, **kwargs) -> T | None: + """获取单条数据 + + 参数: + *args: 查询参数 + **kwargs: 查询参数 + + 返回: + Optional[T]: 查询结果,如果不存在返回None + """ + # 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取 + if not cache_config.enable_cache or not self.cache_type: + return await self.model_cls.safe_get_or_none(*args, **kwargs) + + # 从缓存获取 + try: + # 生成缓存键 + key = self._generate_cache_key(kwargs) + # 尝试从缓存获取 + data = await CacheRoot.get(self.cache_type, key) + if data: + return cast(T, data) + except Exception as e: + logger.error("从缓存获取数据失败", e=e) + + # 如果缓存中没有,从数据库获取 + return await self.model_cls.safe_get_or_none(*args, **kwargs) + + async def filter(self, *args, **kwargs) -> list[T]: + """筛选数据 + + 参数: + *args: 查询参数 + **kwargs: 查询参数 + + 返回: + List[T]: 查询结果列表 + """ + # 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取 + if not cache_config.enable_cache or not self.cache_type: + return await self.model_cls.filter(*args, **kwargs) + + # 尝试从缓存获取所有数据后筛选 + try: + # 获取缓存数据 + cache_data = await CacheRoot.get_cache_data(self.cache_type) + if isinstance(cache_data, dict) and cache_data: + # 在内存中筛选 + filtered_data = [] + for item in cache_data.values(): + match = not any( + not hasattr(item, k) or getattr(item, k) != v + for k, v in kwargs.items() + ) + if match: + filtered_data.append(item) + return cast(list[T], filtered_data) + except Exception as e: + logger.error("从缓存筛选数据失败", e=e) + + # 如果缓存中没有或筛选失败,从数据库获取 + return await self.model_cls.filter(*args, **kwargs) + + async def all(self) -> list[T]: + """获取所有数据 + + 返回: + List[T]: 所有数据列表 + """ + # 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取 + if not cache_config.enable_cache or not self.cache_type: + return await self.model_cls.all() + + # 尝试从缓存获取所有数据 + try: + # 获取缓存数据 + cache_data = await CacheRoot.get_cache_data(self.cache_type) + if isinstance(cache_data, dict) and cache_data: + return cast(list[T], list(cache_data.values())) + except Exception as e: + logger.error("从缓存获取所有数据失败", e=e) + + # 如果缓存中没有,从数据库获取 + return await self.model_cls.all() + + async def count(self, *args, **kwargs) -> int: + """获取数据数量 + + 参数: + *args: 查询参数 + **kwargs: 查询参数 + + 返回: + int: 数据数量 + """ + # 直接从数据库获取数量 + return await self.model_cls.filter(*args, **kwargs).count() + + async def exists(self, *args, **kwargs) -> bool: + """判断数据是否存在 + + 参数: + *args: 查询参数 + **kwargs: 查询参数 + + 返回: + bool: 是否存在 + """ + # 直接从数据库判断是否存在 + return await self.model_cls.filter(*args, **kwargs).exists() + + async def create(self, **kwargs) -> T: + """创建数据 + + 参数: + **kwargs: 创建参数 + + 返回: + T: 创建的数据 + """ + return await self.model_cls.create(**kwargs) + + async def update_or_create( + self, defaults: dict[str, Any] | None = None, **kwargs + ) -> tuple[T, bool]: + """更新或创建数据 + + 参数: + defaults: 默认值 + **kwargs: 查询参数 + + 返回: + tuple[T, bool]: (数据, 是否创建) + """ + return await self.model_cls.update_or_create(defaults=defaults, **kwargs) + + async def delete(self, *args, **kwargs) -> int: + """删除数据 + + 参数: + *args: 查询参数 + **kwargs: 查询参数 + + 返回: + int: 删除的数据数量 + """ + return await self.model_cls.filter(*args, **kwargs).delete() + + def _generate_cache_key(self, kwargs: dict[str, Any]) -> str: + """根据查询参数生成缓存键 + + 参数: + kwargs: 查询参数 + + 返回: + str: 缓存键 + """ + # 实现一个简单的键生成算法 + if not kwargs: + return "default" + return "_".join(f"{k}:{v}" for k, v in sorted(kwargs.items()))