mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ cache支持redis连接
This commit is contained in:
parent
e63a3692e9
commit
e04bd1eff8
10
.env.dev
10
.env.dev
@ -27,6 +27,16 @@ QBOT_ID_DATA = '{
|
|||||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||||
DB_URL = ""
|
DB_URL = ""
|
||||||
|
|
||||||
|
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||||
|
# REDIS地址
|
||||||
|
# REDIS_HOST = "127.0.0.1"
|
||||||
|
# REDIS端口
|
||||||
|
# REDIS_PORT = 6379
|
||||||
|
# REDIS密码
|
||||||
|
# REDIS_PASSWORD = ""
|
||||||
|
# REDIS过期时间
|
||||||
|
# REDIS_EXPIRE = 600
|
||||||
|
|
||||||
# 系统代理
|
# 系统代理
|
||||||
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
||||||
|
|
||||||
|
|||||||
245
poetry.lock
generated
245
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
|
|||||||
bilireq = "0.2.3post0"
|
bilireq = "0.2.3post0"
|
||||||
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
||||||
python-multipart = "^0.0.9"
|
python-multipart = "^0.0.9"
|
||||||
aiocache = ">=0.12, <1.0.0"
|
aiocache = {extras = ["redis"], version = "^0.12.3"}
|
||||||
py-cpuinfo = "^9.0.0"
|
py-cpuinfo = "^9.0.0"
|
||||||
nonebot-plugin-alconna = "^0.54.0"
|
nonebot-plugin-alconna = "^0.54.0"
|
||||||
tenacity = "^9.0.0"
|
tenacity = "^9.0.0"
|
||||||
|
|||||||
@ -5,12 +5,14 @@ import inspect
|
|||||||
from typing import Any, ClassVar, Generic, TypeVar
|
from typing import Any, ClassVar, Generic, TypeVar
|
||||||
|
|
||||||
from aiocache import Cache as AioCache
|
from aiocache import Cache as AioCache
|
||||||
|
|
||||||
|
# from aiocache.backends.redis import RedisCache
|
||||||
from aiocache.base import BaseCache
|
from aiocache.base import BaseCache
|
||||||
from aiocache.serializers import JsonSerializer
|
from aiocache.serializers import JsonSerializer
|
||||||
|
import nonebot
|
||||||
from nonebot.compat import model_dump
|
from nonebot.compat import model_dump
|
||||||
from nonebot.utils import is_coroutine_callable
|
from nonebot.utils import is_coroutine_callable
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tortoise.fields.base import Field
|
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
@ -18,6 +20,24 @@ __all__ = ["Cache", "CacheData", "CacheRoot"]
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
LOG_COMMAND = "cache"
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseModel):
|
||||||
|
redis_host: str | None = None
|
||||||
|
"""redis地址"""
|
||||||
|
redis_port: int | None = None
|
||||||
|
"""redis端口"""
|
||||||
|
redis_password: str | None = None
|
||||||
|
"""redis密码"""
|
||||||
|
redis_expire: int = 600
|
||||||
|
"""redis过期时间"""
|
||||||
|
|
||||||
|
|
||||||
|
config = nonebot.get_plugin_config(Config)
|
||||||
|
|
||||||
|
|
||||||
class DbCacheException(Exception):
|
class DbCacheException(Exception):
|
||||||
"""缓存相关异常"""
|
"""缓存相关异常"""
|
||||||
@ -99,35 +119,21 @@ class CacheData(BaseModel):
|
|||||||
expire: int = 600 # 默认10分钟过期
|
expire: int = 600 # 默认10分钟过期
|
||||||
reload_count: int = 0
|
reload_count: int = 0
|
||||||
lazy_load: bool = True # 默认延迟加载
|
lazy_load: bool = True # 默认延迟加载
|
||||||
_cache_instance: BaseCache | None = None
|
|
||||||
result_model: type | None = None
|
result_model: type | None = None
|
||||||
_keys: set[str] = set() # 存储所有缓存键
|
_keys: set[str] = set() # 存储所有缓存键
|
||||||
|
_cache: BaseCache | AioCache
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
underscore_attrs_are_private = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _cache(self) -> BaseCache:
|
|
||||||
"""获取aiocache实例"""
|
|
||||||
if self._cache_instance is None:
|
|
||||||
self._cache_instance = AioCache(
|
|
||||||
AioCache.MEMORY,
|
|
||||||
serializer=JsonSerializer(),
|
|
||||||
namespace="zhenxun_cache",
|
|
||||||
timeout=30, # 操作超时时间
|
|
||||||
ttl=self.expire, # 设置默认过期时间
|
|
||||||
)
|
|
||||||
return self._cache_instance
|
|
||||||
|
|
||||||
def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any:
|
def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any:
|
||||||
"""反序列化值,将JSON数据转换回原始类型
|
"""反序列化值,将JSON数据转换回原始类型
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
value: 需要反序列化的值
|
value: 需要反序列化的值
|
||||||
target_type: 目标类型,用于指导反序列化
|
target_type: 目标类型,用于指导反序列化
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
反序列化后的值
|
反序列化后的值
|
||||||
"""
|
"""
|
||||||
if value is None:
|
if value is None:
|
||||||
@ -137,11 +143,40 @@ class CacheData(BaseModel):
|
|||||||
if isinstance(value, dict) and target_type:
|
if isinstance(value, dict) and target_type:
|
||||||
# 处理Tortoise-ORM Model
|
# 处理Tortoise-ORM Model
|
||||||
if hasattr(target_type, "_meta"):
|
if hasattr(target_type, "_meta"):
|
||||||
|
return self._extracted_from__deserialize_value_19(value, target_type)
|
||||||
|
elif hasattr(target_type, "model_validate"):
|
||||||
|
return target_type.model_validate(value)
|
||||||
|
elif hasattr(target_type, "from_dict"):
|
||||||
|
return target_type.from_dict(value)
|
||||||
|
elif hasattr(target_type, "parse_obj"):
|
||||||
|
return target_type.parse_obj(value)
|
||||||
|
else:
|
||||||
|
return target_type(**value)
|
||||||
|
|
||||||
|
# 处理列表类型
|
||||||
|
if isinstance(value, list):
|
||||||
|
if not value:
|
||||||
|
return value
|
||||||
|
if (
|
||||||
|
target_type
|
||||||
|
and hasattr(target_type, "__origin__")
|
||||||
|
and target_type.__origin__ is list
|
||||||
|
):
|
||||||
|
item_type = target_type.__args__[0]
|
||||||
|
return [self._deserialize_value(item, item_type) for item in value]
|
||||||
|
return [self._deserialize_value(item) for item in value]
|
||||||
|
|
||||||
|
# 处理字典类型
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: self._deserialize_value(v) for k, v in value.items()}
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _extracted_from__deserialize_value_19(self, value, target_type):
|
||||||
# 处理字段值
|
# 处理字段值
|
||||||
processed_value = {}
|
processed_value = {}
|
||||||
for field_name, field_value in value.items():
|
for field_name, field_value in value.items():
|
||||||
field: Field = target_type._meta.fields_map.get(field_name)
|
if field := target_type._meta.fields_map.get(field_name):
|
||||||
if field:
|
|
||||||
# 跳过反向关系字段
|
# 跳过反向关系字段
|
||||||
if hasattr(field, "_related_name"):
|
if hasattr(field, "_related_name"):
|
||||||
continue
|
continue
|
||||||
@ -169,38 +204,6 @@ class CacheData(BaseModel):
|
|||||||
# 设置 _saved_in_db 标志
|
# 设置 _saved_in_db 标志
|
||||||
instance._saved_in_db = True
|
instance._saved_in_db = True
|
||||||
return instance
|
return instance
|
||||||
# 处理Pydantic模型
|
|
||||||
elif hasattr(target_type, "model_validate"):
|
|
||||||
return target_type.model_validate(value)
|
|
||||||
elif hasattr(target_type, "from_dict"):
|
|
||||||
return target_type.from_dict(value)
|
|
||||||
elif hasattr(target_type, "parse_obj"):
|
|
||||||
return target_type.parse_obj(value)
|
|
||||||
else:
|
|
||||||
return target_type(**value)
|
|
||||||
|
|
||||||
# 处理列表类型
|
|
||||||
if isinstance(value, list):
|
|
||||||
if not value:
|
|
||||||
return value
|
|
||||||
if (
|
|
||||||
target_type
|
|
||||||
and hasattr(target_type, "__origin__")
|
|
||||||
and target_type.__origin__ is list
|
|
||||||
):
|
|
||||||
item_type = target_type.__args__[0]
|
|
||||||
return [self._deserialize_value(item, item_type) for item in value]
|
|
||||||
return [self._deserialize_value(item) for item in value]
|
|
||||||
|
|
||||||
# 处理字典类型
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return {k: self._deserialize_value(v) for k, v in value.items()}
|
|
||||||
|
|
||||||
# 处理基本类型
|
|
||||||
if isinstance(value, int | float | str | bool):
|
|
||||||
return value
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def get_data(self) -> Any:
|
async def get_data(self) -> Any:
|
||||||
"""从缓存获取数据"""
|
"""从缓存获取数据"""
|
||||||
@ -245,10 +248,10 @@ class CacheData(BaseModel):
|
|||||||
def _serialize_value(self, value: Any) -> Any:
|
def _serialize_value(self, value: Any) -> Any:
|
||||||
"""序列化值,将数据转换为JSON可序列化的格式
|
"""序列化值,将数据转换为JSON可序列化的格式
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
value: 需要序列化的值
|
value: 需要序列化的值
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
JSON可序列化的值
|
JSON可序列化的值
|
||||||
"""
|
"""
|
||||||
if value is None:
|
if value is None:
|
||||||
@ -417,10 +420,10 @@ class CacheData(BaseModel):
|
|||||||
async def get_key(self, key: str) -> Any:
|
async def get_key(self, key: str) -> Any:
|
||||||
"""获取缓存中指定键的数据
|
"""获取缓存中指定键的数据
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
key: 要获取的键名
|
key: 要获取的键名
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
键对应的值,如果不存在返回None
|
键对应的值,如果不存在返回None
|
||||||
"""
|
"""
|
||||||
cache_key = self._get_cache_key(key)
|
cache_key = self._get_cache_key(key)
|
||||||
@ -438,10 +441,10 @@ class CacheData(BaseModel):
|
|||||||
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
||||||
"""获取缓存中多个键的数据
|
"""获取缓存中多个键的数据
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
keys: 要获取的键名列表
|
keys: 要获取的键名列表
|
||||||
|
|
||||||
Returns:
|
返回:
|
||||||
包含所有请求键值的字典,不存在的键值为None
|
包含所有请求键值的字典,不存在的键值为None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@ -512,11 +515,52 @@ class CacheData(BaseModel):
|
|||||||
class CacheManager:
|
class CacheManager:
|
||||||
"""全局缓存管理器"""
|
"""全局缓存管理器"""
|
||||||
|
|
||||||
|
_cache_instance: BaseCache | AioCache | None = None
|
||||||
_data: ClassVar[dict[str, CacheData]] = {}
|
_data: ClassVar[dict[str, CacheData]] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cache(self) -> BaseCache | AioCache:
|
||||||
|
"""获取aiocache实例"""
|
||||||
|
if self._cache_instance is None:
|
||||||
|
if config.redis_host:
|
||||||
|
self._cache_instance = AioCache(
|
||||||
|
AioCache.REDIS,
|
||||||
|
serializer=JsonSerializer(),
|
||||||
|
namespace="zhenxun_cache",
|
||||||
|
timeout=30, # 操作超时时间
|
||||||
|
ttl=config.redis_expire, # 设置默认过期时间
|
||||||
|
endpoint=config.redis_host,
|
||||||
|
port=config.redis_port,
|
||||||
|
password=config.redis_password,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._cache_instance = AioCache(
|
||||||
|
AioCache.MEMORY,
|
||||||
|
serializer=JsonSerializer(),
|
||||||
|
namespace="zhenxun_cache",
|
||||||
|
timeout=30, # 操作超时时间
|
||||||
|
ttl=config.redis_expire, # 设置默认过期时间
|
||||||
|
)
|
||||||
|
logger.info("初始化缓存完成...", LOG_COMMAND)
|
||||||
|
return self._cache_instance
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if self._cache_instance:
|
||||||
|
await self._cache_instance.close()
|
||||||
|
|
||||||
|
async def verify_connection(self):
|
||||||
|
"""连接测试"""
|
||||||
|
try:
|
||||||
|
await self._cache.get("__test__")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("连接失败", LOG_COMMAND, e=e)
|
||||||
|
raise
|
||||||
|
|
||||||
async def init_non_lazy_caches(self):
|
async def init_non_lazy_caches(self):
|
||||||
"""初始化所有非延迟加载的缓存"""
|
"""初始化所有非延迟加载的缓存"""
|
||||||
|
await self.verify_connection()
|
||||||
for name, cache in self._data.items():
|
for name, cache in self._data.items():
|
||||||
|
cache._cache = self._cache
|
||||||
if not cache.lazy_load:
|
if not cache.lazy_load:
|
||||||
try:
|
try:
|
||||||
await cache.reload()
|
await cache.reload()
|
||||||
@ -527,7 +571,7 @@ class CacheManager:
|
|||||||
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
|
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
|
||||||
"""注册新缓存
|
"""注册新缓存
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
name: 缓存名称
|
name: 缓存名称
|
||||||
lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载
|
lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载
|
||||||
expire: 过期时间(秒)
|
expire: 过期时间(秒)
|
||||||
@ -543,6 +587,7 @@ class CacheManager:
|
|||||||
func=func,
|
func=func,
|
||||||
expire=expire,
|
expire=expire,
|
||||||
lazy_load=lazy_load,
|
lazy_load=lazy_load,
|
||||||
|
_cache=self._cache,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@ -649,3 +694,8 @@ class Cache(Generic[T]):
|
|||||||
async def reload(self, *args, **kwargs):
|
async def reload(self, *args, **kwargs):
|
||||||
"""重新加载缓存"""
|
"""重新加载缓存"""
|
||||||
await CacheRoot.reload(self.module, *args, **kwargs)
|
await CacheRoot.reload(self.module, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
|
async def _():
|
||||||
|
await CacheRoot.close()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user