mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ cache支持redis连接
This commit is contained in:
parent
e63a3692e9
commit
e04bd1eff8
10
.env.dev
10
.env.dev
@ -27,6 +27,16 @@ QBOT_ID_DATA = '{
|
||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||
DB_URL = ""
|
||||
|
||||
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||
# REDIS地址
|
||||
# REDIS_HOST = "127.0.0.1"
|
||||
# REDIS端口
|
||||
# REDIS_PORT = 6379
|
||||
# REDIS密码
|
||||
# REDIS_PASSWORD = ""
|
||||
# REDIS过期时间
|
||||
# REDIS_EXPIRE = 600
|
||||
|
||||
# 系统代理
|
||||
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
||||
|
||||
|
||||
245
poetry.lock
generated
245
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
|
||||
bilireq = "0.2.3post0"
|
||||
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
||||
python-multipart = "^0.0.9"
|
||||
aiocache = ">=0.12, <1.0.0"
|
||||
aiocache = {extras = ["redis"], version = "^0.12.3"}
|
||||
py-cpuinfo = "^9.0.0"
|
||||
nonebot-plugin-alconna = "^0.54.0"
|
||||
tenacity = "^9.0.0"
|
||||
|
||||
@ -5,12 +5,14 @@ import inspect
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
|
||||
from aiocache import Cache as AioCache
|
||||
|
||||
# from aiocache.backends.redis import RedisCache
|
||||
from aiocache.base import BaseCache
|
||||
from aiocache.serializers import JsonSerializer
|
||||
import nonebot
|
||||
from nonebot.compat import model_dump
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from pydantic import BaseModel
|
||||
from tortoise.fields.base import Field
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
@ -18,6 +20,24 @@ __all__ = ["Cache", "CacheData", "CacheRoot"]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
LOG_COMMAND = "cache"
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
redis_host: str | None = None
|
||||
"""redis地址"""
|
||||
redis_port: int | None = None
|
||||
"""redis端口"""
|
||||
redis_password: str | None = None
|
||||
"""redis密码"""
|
||||
redis_expire: int = 600
|
||||
"""redis过期时间"""
|
||||
|
||||
|
||||
config = nonebot.get_plugin_config(Config)
|
||||
|
||||
|
||||
class DbCacheException(Exception):
|
||||
"""缓存相关异常"""
|
||||
@ -99,35 +119,21 @@ class CacheData(BaseModel):
|
||||
expire: int = 600 # 默认10分钟过期
|
||||
reload_count: int = 0
|
||||
lazy_load: bool = True # 默认延迟加载
|
||||
_cache_instance: BaseCache | None = None
|
||||
result_model: type | None = None
|
||||
_keys: set[str] = set() # 存储所有缓存键
|
||||
_cache: BaseCache | AioCache
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
@property
|
||||
def _cache(self) -> BaseCache:
|
||||
"""获取aiocache实例"""
|
||||
if self._cache_instance is None:
|
||||
self._cache_instance = AioCache(
|
||||
AioCache.MEMORY,
|
||||
serializer=JsonSerializer(),
|
||||
namespace="zhenxun_cache",
|
||||
timeout=30, # 操作超时时间
|
||||
ttl=self.expire, # 设置默认过期时间
|
||||
)
|
||||
return self._cache_instance
|
||||
|
||||
def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any:
|
||||
"""反序列化值,将JSON数据转换回原始类型
|
||||
|
||||
Args:
|
||||
参数:
|
||||
value: 需要反序列化的值
|
||||
target_type: 目标类型,用于指导反序列化
|
||||
|
||||
Returns:
|
||||
返回:
|
||||
反序列化后的值
|
||||
"""
|
||||
if value is None:
|
||||
@ -137,39 +143,7 @@ class CacheData(BaseModel):
|
||||
if isinstance(value, dict) and target_type:
|
||||
# 处理Tortoise-ORM Model
|
||||
if hasattr(target_type, "_meta"):
|
||||
# 处理字段值
|
||||
processed_value = {}
|
||||
for field_name, field_value in value.items():
|
||||
field: Field = target_type._meta.fields_map.get(field_name)
|
||||
if field:
|
||||
# 跳过反向关系字段
|
||||
if hasattr(field, "_related_name"):
|
||||
continue
|
||||
processed_value[field_name] = field_value
|
||||
|
||||
logger.debug(f"处理后的值: {processed_value}")
|
||||
|
||||
# 创建模型实例
|
||||
instance = target_type()
|
||||
# 设置字段值
|
||||
for field_name, field_value in processed_value.items():
|
||||
if field_name in target_type._meta.fields_map:
|
||||
field = target_type._meta.fields_map[field_name]
|
||||
# 设置字段值
|
||||
try:
|
||||
if hasattr(field, "to_python_value"):
|
||||
if not field.field_type:
|
||||
logger.debug(f"字段 {field_name} 类型为空")
|
||||
continue
|
||||
field_value = field.to_python_value(field_value)
|
||||
setattr(instance, field_name, field_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"设置字段 {field_name} 失败", e=e)
|
||||
|
||||
# 设置 _saved_in_db 标志
|
||||
instance._saved_in_db = True
|
||||
return instance
|
||||
# 处理Pydantic模型
|
||||
return self._extracted_from__deserialize_value_19(value, target_type)
|
||||
elif hasattr(target_type, "model_validate"):
|
||||
return target_type.model_validate(value)
|
||||
elif hasattr(target_type, "from_dict"):
|
||||
@ -196,12 +170,41 @@ class CacheData(BaseModel):
|
||||
if isinstance(value, dict):
|
||||
return {k: self._deserialize_value(v) for k, v in value.items()}
|
||||
|
||||
# 处理基本类型
|
||||
if isinstance(value, int | float | str | bool):
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def _extracted_from__deserialize_value_19(self, value, target_type):
|
||||
# 处理字段值
|
||||
processed_value = {}
|
||||
for field_name, field_value in value.items():
|
||||
if field := target_type._meta.fields_map.get(field_name):
|
||||
# 跳过反向关系字段
|
||||
if hasattr(field, "_related_name"):
|
||||
continue
|
||||
processed_value[field_name] = field_value
|
||||
|
||||
logger.debug(f"处理后的值: {processed_value}")
|
||||
|
||||
# 创建模型实例
|
||||
instance = target_type()
|
||||
# 设置字段值
|
||||
for field_name, field_value in processed_value.items():
|
||||
if field_name in target_type._meta.fields_map:
|
||||
field = target_type._meta.fields_map[field_name]
|
||||
# 设置字段值
|
||||
try:
|
||||
if hasattr(field, "to_python_value"):
|
||||
if not field.field_type:
|
||||
logger.debug(f"字段 {field_name} 类型为空")
|
||||
continue
|
||||
field_value = field.to_python_value(field_value)
|
||||
setattr(instance, field_name, field_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"设置字段 {field_name} 失败", e=e)
|
||||
|
||||
# 设置 _saved_in_db 标志
|
||||
instance._saved_in_db = True
|
||||
return instance
|
||||
|
||||
async def get_data(self) -> Any:
|
||||
"""从缓存获取数据"""
|
||||
try:
|
||||
@ -245,10 +248,10 @@ class CacheData(BaseModel):
|
||||
def _serialize_value(self, value: Any) -> Any:
|
||||
"""序列化值,将数据转换为JSON可序列化的格式
|
||||
|
||||
Args:
|
||||
参数:
|
||||
value: 需要序列化的值
|
||||
|
||||
Returns:
|
||||
返回:
|
||||
JSON可序列化的值
|
||||
"""
|
||||
if value is None:
|
||||
@ -417,10 +420,10 @@ class CacheData(BaseModel):
|
||||
async def get_key(self, key: str) -> Any:
|
||||
"""获取缓存中指定键的数据
|
||||
|
||||
Args:
|
||||
参数:
|
||||
key: 要获取的键名
|
||||
|
||||
Returns:
|
||||
返回:
|
||||
键对应的值,如果不存在返回None
|
||||
"""
|
||||
cache_key = self._get_cache_key(key)
|
||||
@ -438,10 +441,10 @@ class CacheData(BaseModel):
|
||||
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
||||
"""获取缓存中多个键的数据
|
||||
|
||||
Args:
|
||||
参数:
|
||||
keys: 要获取的键名列表
|
||||
|
||||
Returns:
|
||||
返回:
|
||||
包含所有请求键值的字典,不存在的键值为None
|
||||
"""
|
||||
try:
|
||||
@ -512,11 +515,52 @@ class CacheData(BaseModel):
|
||||
class CacheManager:
|
||||
"""全局缓存管理器"""
|
||||
|
||||
_cache_instance: BaseCache | AioCache | None = None
|
||||
_data: ClassVar[dict[str, CacheData]] = {}
|
||||
|
||||
@property
|
||||
def _cache(self) -> BaseCache | AioCache:
|
||||
"""获取aiocache实例"""
|
||||
if self._cache_instance is None:
|
||||
if config.redis_host:
|
||||
self._cache_instance = AioCache(
|
||||
AioCache.REDIS,
|
||||
serializer=JsonSerializer(),
|
||||
namespace="zhenxun_cache",
|
||||
timeout=30, # 操作超时时间
|
||||
ttl=config.redis_expire, # 设置默认过期时间
|
||||
endpoint=config.redis_host,
|
||||
port=config.redis_port,
|
||||
password=config.redis_password,
|
||||
)
|
||||
else:
|
||||
self._cache_instance = AioCache(
|
||||
AioCache.MEMORY,
|
||||
serializer=JsonSerializer(),
|
||||
namespace="zhenxun_cache",
|
||||
timeout=30, # 操作超时时间
|
||||
ttl=config.redis_expire, # 设置默认过期时间
|
||||
)
|
||||
logger.info("初始化缓存完成...", LOG_COMMAND)
|
||||
return self._cache_instance
|
||||
|
||||
async def close(self):
|
||||
if self._cache_instance:
|
||||
await self._cache_instance.close()
|
||||
|
||||
async def verify_connection(self):
|
||||
"""连接测试"""
|
||||
try:
|
||||
await self._cache.get("__test__")
|
||||
except Exception as e:
|
||||
logger.error("连接失败", LOG_COMMAND, e=e)
|
||||
raise
|
||||
|
||||
async def init_non_lazy_caches(self):
|
||||
"""初始化所有非延迟加载的缓存"""
|
||||
await self.verify_connection()
|
||||
for name, cache in self._data.items():
|
||||
cache._cache = self._cache
|
||||
if not cache.lazy_load:
|
||||
try:
|
||||
await cache.reload()
|
||||
@ -527,7 +571,7 @@ class CacheManager:
|
||||
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
|
||||
"""注册新缓存
|
||||
|
||||
Args:
|
||||
参数:
|
||||
name: 缓存名称
|
||||
lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载
|
||||
expire: 过期时间(秒)
|
||||
@ -543,6 +587,7 @@ class CacheManager:
|
||||
func=func,
|
||||
expire=expire,
|
||||
lazy_load=lazy_load,
|
||||
_cache=self._cache,
|
||||
)
|
||||
return func
|
||||
|
||||
@ -649,3 +694,8 @@ class Cache(Generic[T]):
|
||||
async def reload(self, *args, **kwargs):
|
||||
"""重新加载缓存"""
|
||||
await CacheRoot.reload(self.module, *args, **kwargs)
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _():
|
||||
await CacheRoot.close()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user