🐛 尝试迁移至aiocache

This commit is contained in:
HibiKier 2025-05-08 18:43:27 +08:00
parent 83942bae26
commit f3d5b77bdc
3 changed files with 385 additions and 420 deletions

View File

@ -11,7 +11,7 @@ from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
try:
from .__init_cache import driver
from . import __init_cache
except DbCacheException as e:
raise SystemError(f"ERROR{e}")

View File

@ -1,8 +1,5 @@
import time
from typing import Any
import nonebot
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
@ -11,77 +8,20 @@ from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
driver = nonebot.get_driver()
@driver.on_startup
async def _():
"""开启cache检测"""
CacheRoot.start_check()
def default_cleanup_expired(cache_data: CacheData) -> list[str]:
"""默认清理过期cache方法"""
if not cache_data.data:
return []
now = time.time()
expire_key = []
for k, t in list(cache_data.expire_data.items()):
if t < now:
expire_key.append(k)
cache_data.expire_data.pop(k)
if expire_key:
cache_data.data = {
k: v for k, v in cache_data.data.items() if k not in expire_key
}
return expire_key
def default_cleanup_expired_1(cache_data: CacheData) -> list[str]:
"""默认清理列表过期cache方法"""
if not cache_data.data:
return []
now = time.time()
expire_key = []
for k, t in list(cache_data.expire_data.items()):
if t < now:
expire_key.append(k)
cache_data.expire_data.pop(k)
if expire_key:
cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key]
return expire_key
def default_with_expiration(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新过期时间cache方法"""
if not data:
return {}
keys = {k for k in data if k not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
def default_with_expiration_1(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新过期时间cache方法"""
if not data:
return {}
keys = {repr(k) for k in data if repr(k) not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
@CacheRoot.new(CacheType.PLUGINS)
async def _():
"""初始化插件缓存"""
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@CacheRoot.updater(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo], key: str, value: Any):
"""更新插件缓存"""
if value:
data[key] = value
elif plugin := await PluginInfo.get_plugin(module=key):
@ -90,42 +30,35 @@ async def _(data: dict[str, PluginInfo], key: str, value: Any):
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
async def _(cache_data: CacheData, module: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(module, None)
if not result:
result = await PluginInfo.get_plugin(module=module)
if result:
cache_data.data[module] = result
return result
"""获取插件缓存"""
data = await cache_data.get_data() or {}
if module not in data:
if plugin := await PluginInfo.get_plugin(module=module):
data[module] = plugin
await cache_data.set_data(data)
logger.debug(f"插件 {module} 数据已设置到缓存")
return data.get(module)
@CacheRoot.with_refresh(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo] | None):
"""刷新插件缓存"""
if not data:
return
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
for plugin in plugins:
data[plugin.module] = plugin
@CacheRoot.with_expiration(CacheType.PLUGINS)
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.PLUGINS)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
data.update({p.module: p for p in plugins})
@CacheRoot.new(CacheType.GROUPS)
async def _():
"""初始化群组缓存"""
data_list = await GroupConsole.all()
return {p.group_id: p for p in data_list if not p.channel_id}
@CacheRoot.updater(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole], key: str, value: Any):
"""更新群组缓存"""
if value:
data[key] = value
elif group := await GroupConsole.get_group(group_id=key):
@ -134,44 +67,36 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(cache_data: CacheData, group_id: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(group_id, None)
if not result:
result = await GroupConsole.get_group(group_id=group_id)
if result:
cache_data.data[group_id] = result
return result
"""获取群组缓存"""
data = await cache_data.get_data() or {}
if group_id not in data:
if group := await GroupConsole.get_group(group_id=group_id):
data[group_id] = group
await cache_data.set_data(data)
return data.get(group_id)
@CacheRoot.with_refresh(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole] | None):
"""刷新群组缓存"""
if not data:
return
groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True
).all()
for group in groups:
data[group.group_id] = group
@CacheRoot.with_expiration(CacheType.GROUPS)
def _(data: dict[str, GroupConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.GROUPS)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
data.update({g.group_id: g for g in groups})
@CacheRoot.new(CacheType.BOT)
async def _():
"""初始化机器人缓存"""
data_list = await BotConsole.all()
return {p.bot_id: p for p in data_list}
@CacheRoot.updater(CacheType.BOT)
async def _(data: dict[str, BotConsole], key: str, value: Any):
"""更新机器人缓存"""
if value:
data[key] = value
elif bot := await BotConsole.get_or_none(bot_id=key):
@ -180,42 +105,34 @@ async def _(data: dict[str, BotConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(cache_data: CacheData, bot_id: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(bot_id, None)
if not result:
result = await BotConsole.get_or_none(bot_id=bot_id)
if result:
cache_data.data[bot_id] = result
return result
"""获取机器人缓存"""
data = await cache_data.get_data() or {}
if bot_id not in data:
if bot := await BotConsole.get_or_none(bot_id=bot_id):
data[bot_id] = bot
await cache_data.set_data(data)
return data.get(bot_id)
@CacheRoot.with_refresh(CacheType.BOT)
async def _(data: dict[str, BotConsole] | None):
"""刷新机器人缓存"""
if not data:
return
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
for bot in bots:
data[bot.bot_id] = bot
@CacheRoot.with_expiration(CacheType.BOT)
def _(data: dict[str, BotConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.BOT)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
data.update({b.bot_id: b for b in bots})
@CacheRoot.new(CacheType.USERS)
async def _():
"""初始化用户缓存"""
data_list = await UserConsole.all()
return {p.user_id: p for p in data_list}
@CacheRoot.updater(CacheType.USERS)
async def _(data: dict[str, UserConsole], key: str, value: Any):
"""更新用户缓存"""
if value:
data[key] = value
elif user := await UserConsole.get_user(user_id=key):
@ -224,108 +141,61 @@ async def _(data: dict[str, UserConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
async def _(cache_data: CacheData, user_id: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(user_id, None)
if not result:
result = await UserConsole.get_user(user_id=user_id)
if result:
cache_data.data[user_id] = result
return result
"""获取用户缓存"""
data = await cache_data.get_data() or {}
if user_id not in data:
if user := await UserConsole.get_user(user_id=user_id):
data[user_id] = user
await cache_data.set_data(data)
return data.get(user_id)
@CacheRoot.with_refresh(CacheType.USERS)
async def _(data: dict[str, UserConsole] | None):
"""刷新用户缓存"""
if not data:
return
users = await UserConsole.filter(user_id__in=data.keys()).all()
for user in users:
data[user.user_id] = user
@CacheRoot.with_expiration(CacheType.USERS)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.USERS)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
data.update({u.user_id: u for u in users})
@CacheRoot.new(CacheType.LEVEL, False)
async def _():
"""初始化等级缓存"""
return await LevelUser().all()
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
cache_data.data = cache_data.data or []
"""获取等级缓存"""
data = await cache_data.get_data() or []
if not group_id:
return [
data
for data in cache_data.data
if data.user_id == user_id and not data.group_id
]
else:
return [
data
for data in cache_data.data
if data.user_id == user_id and data.group_id == group_id
]
@CacheRoot.with_expiration(CacheType.LEVEL)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration_1(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.LEVEL)
def _(cache_data: CacheData):
return default_cleanup_expired_1(cache_data)
return [d for d in data if d.user_id == user_id and not d.group_id]
return [d for d in data if d.user_id == user_id and d.group_id == group_id]
@CacheRoot.new(CacheType.BAN, False)
async def _():
"""初始化封禁缓存"""
return await BanConsole.all()
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
"""获取封禁缓存"""
data = await cache_data.get_data() or []
if user_id:
return (
[
data
for data in cache_data.data
if data.user_id == user_id and data.group_id == group_id
]
if group_id
else [
data
for data in cache_data.data
if data.user_id == user_id and not data.group_id
]
)
if group_id:
return [d for d in data if d.user_id == user_id and d.group_id == group_id]
return [d for d in data if d.user_id == user_id and not d.group_id]
if group_id:
return [
data
for data in cache_data.data
if not data.user_id and data.group_id == group_id
]
return [d for d in data if not d.user_id and d.group_id == group_id]
return None
@CacheRoot.with_expiration(CacheType.BAN)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration_1(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.BAN)
def _(cache_data: CacheData):
return default_cleanup_expired_1(cache_data)
@CacheRoot.new(CacheType.LIMIT)
async def _():
"""初始化限制缓存"""
data_list = await PluginLimit.filter(status=True).all()
result_data = {}
for data in data_list:
@ -337,6 +207,7 @@ async def _():
@CacheRoot.updater(CacheType.LIMIT)
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
"""更新限制缓存"""
if value:
data[key] = value
elif limits := await PluginLimit.filter(module=key, status=True):
@ -345,32 +216,25 @@ async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
async def _(cache_data: CacheData, module: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(module, None)
if not result:
result = await PluginLimit.filter(module=module, status=True)
if result:
cache_data.data[module] = result
return result
"""获取限制缓存"""
data = await cache_data.get_data() or {}
if module not in data:
if limits := await PluginLimit.filter(module=module, status=True):
data[module] = limits
await cache_data.set_data(data)
return data.get(module)
@CacheRoot.with_refresh(CacheType.LIMIT)
async def _(data: dict[str, list[PluginLimit]] | None):
"""刷新限制缓存"""
if not data:
return
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
data.clear()
new_data = {}
for limit in limits:
if not data.get(limit.module):
data[limit.module] = []
data[limit.module].append(limit)
@CacheRoot.with_expiration(CacheType.LIMIT)
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.LIMIT)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
if not new_data.get(limit.module):
new_data[limit.module] = []
new_data[limit.module].append(limit)
data.clear()
data.update(new_data)

View File

@ -1,11 +1,14 @@
from collections.abc import Callable
from datetime import datetime
from functools import wraps
import inspect
import time
from typing import Any, ClassVar, Generic, TypeVar, cast
from typing import Any, ClassVar, Generic, TypeVar
from aiocache import Cache as AioCache
from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer
from nonebot.compat import model_dump
from nonebot.utils import is_coroutine_callable
from nonebot_plugin_apscheduler import scheduler
from pydantic import BaseModel
from zhenxun.services.log import logger
@ -16,168 +19,258 @@ T = TypeVar("T")
class DbCacheException(Exception):
"""缓存相关异常"""
def __init__(self, info: str):
self.info = info
def __repr__(self) -> str:
return super().__repr__()
def __str__(self) -> str:
return self.info
def validate_name(func: Callable):
"""
装饰器验证 name 是否存在于 CacheManage._data
"""
"""验证缓存名称是否存在的装饰器"""
def wrapper(self, name: str, *args, **kwargs):
_name = name.upper()
if _name not in CacheManager._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
raise DbCacheException(f"缓存数据 {name} 不存在")
return func(self, _name, *args, **kwargs)
return wrapper
class CacheGetter(BaseModel, Generic[T]):
"""缓存数据获取器"""
get_func: Callable[..., Any] | None = None
"""获取方法"""
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
"""获取缓存"""
"""获取处理后的缓存数据"""
if not self.get_func:
return cache_data.data
return await cache_data.get_data()
if is_coroutine_callable(self.get_func):
processed_data = await self.get_func(cache_data, *args, **kwargs)
else:
processed_data = self.get_func(cache_data, *args, **kwargs)
return cast(T, processed_data)
return await self.get_func(cache_data, *args, **kwargs)
return self.get_func(cache_data, *args, **kwargs)
class CacheData(BaseModel):
"""缓存数据模型"""
name: str
"""缓存名称"""
func: Callable[..., Any]
"""更新方法"""
getter: CacheGetter | None = None
"""获取方法"""
updater: Callable[..., Any] | None = None
"""更新单个方法"""
with_refresh: Callable[..., Any] | None = None
"""刷新方法"""
with_expiration: Callable[..., Any] | None = None
"""缓存时间初始化方法"""
cleanup_expired: Callable[..., Any] | None = None
"""缓存过期方法"""
data: Any = None
"""缓存数据"""
expire: int
"""缓存过期时间"""
expire_data: dict[str, int | float] = {}
"""缓存过期数据时间记录"""
reload_time: float = time.time()
"""更新时间"""
expire: int = 600 # 默认10分钟过期
reload_count: int = 0
"""更新次数"""
incremental_update: bool = True
"""是否是增量更新"""
async def get(self, *args, **kwargs) -> Any:
"""获取单个缓存"""
if not self.reload_count and not self.incremental_update:
# 首次获取时,非增量更新获取全部数据
await self.reload()
self.call_cleanup_expired() # 移除过期缓存
if not self.getter:
return self.data
result = await self.getter.get(self, *args, **kwargs)
await self.call_with_expiration()
return result
class Config:
arbitrary_types_allowed = True
async def update(self, key: str, value: Any = None, *args, **kwargs):
"""更新单个缓存"""
if not self.updater:
return logger.warning(
f"缓存类型 {self.name} 没有更新方法,无法更新", "CacheRoot"
)
if self.data:
if is_coroutine_callable(self.updater):
await self.updater(self.data, key, value, *args, **kwargs)
else:
self.updater(self.data, key, value, *args, **kwargs)
logger.debug(
f"缓存类型 {self.name} 更新单个缓存 key: {key}value: {value}",
"CacheRoot",
)
self.expire_data[key] = time.time() + self.expire
else:
logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot")
async def refresh(self, *args, **kwargs):
"""刷新缓存,只刷新已缓存的数据"""
if not self.with_refresh:
return await self.reload(*args, **kwargs)
if self.data:
if is_coroutine_callable(self.with_refresh):
await self.with_refresh(self.data, *args, **kwargs)
else:
self.with_refresh(self.data, *args, **kwargs)
logger.debug(
f"缓存类型 {self.name} 刷新全局缓存,共刷新 {len(self.data)} 条数据",
"CacheRoot",
)
async def reload(self, *args, **kwargs):
"""更新全部缓存数据"""
if self.has_args():
self.data = (
await self.func(*args, **kwargs)
if is_coroutine_callable(self.func)
else self.func(*args, **kwargs)
)
else:
self.data = (
await self.func() if is_coroutine_callable(self.func) else self.func()
)
await self.call_with_expiration()
self.reload_time = time.time()
self.reload_count += 1
logger.debug(
f"缓存类型 {self.name} 更新全局缓存,共更新 {len(self.data)} 条数据",
"CacheRoot",
@property
def _cache(self) -> BaseCache:
"""获取aiocache实例"""
return AioCache(
AioCache.MEMORY,
serializer=JsonSerializer(),
namespace="zhenxun_cache",
timeout=30, # 操作超时时间
ttl=self.expire, # 设置默认过期时间
)
def call_cleanup_expired(self):
"""清理过期缓存"""
if self.cleanup_expired:
if result := self.cleanup_expired(self):
logger.debug(
f"成功清理 {self.name} {len(result)} 条过期缓存", "CacheRoot"
async def get_data(self) -> Any:
"""从缓存获取数据"""
try:
data = await self._cache.get(self.name)
logger.debug(f"获取缓存 {self.name} 数据: {data}")
# 如果数据为空,尝试重新加载
if data is None:
logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载")
try:
if self.has_args():
new_data = (
await self.func()
if is_coroutine_callable(self.func)
else self.func()
)
else:
new_data = (
await self.func()
if is_coroutine_callable(self.func)
else self.func()
)
await self.set_data(new_data)
self.reload_count += 1
logger.info(f"重新加载缓存 {self.name} 完成")
return new_data
except Exception as e:
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
return None
return data
except Exception as e:
logger.error(f"获取缓存 {self.name} 失败: {e}")
return None
def _serialize_value(self, value: Any) -> Any:
"""序列化值将数据转换为JSON可序列化的格式
Args:
value: 需要序列化的值
Returns:
JSON可序列化的值
"""
if value is None:
return None
# 处理datetime
if isinstance(value, datetime):
return value.isoformat()
# 处理Tortoise-ORM Model
if hasattr(value, "_meta") and hasattr(value, "__dict__"):
result = {}
for field in value._meta.fields:
try:
field_value = getattr(value, field)
# 跳过反向关系字段
if isinstance(field_value, list | set) and hasattr(
field_value, "_related_name"
):
continue
result[field] = self._serialize_value(field_value)
except AttributeError:
continue
return result
# 处理Pydantic模型
elif isinstance(value, BaseModel):
return model_dump(value)
elif isinstance(value, dict):
# 处理字典
return {str(k): self._serialize_value(v) for k, v in value.items()}
elif isinstance(value, list | tuple | set):
# 处理列表、元组、集合
return [self._serialize_value(item) for item in value]
elif isinstance(value, int | float | str | bool):
# 基本类型直接返回
return value
elif hasattr(value, "__dict__"):
# 处理普通类对象
return self._serialize_value(value.__dict__)
else:
# 其他类型转换为字符串
return str(value)
async def set_data(self, value: Any):
"""设置缓存数据"""
try:
# 1. 序列化数据
serialized_value = self._serialize_value(value)
logger.debug(f"设置缓存 {self.name} 原始数据: {value}")
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
# 2. 删除旧数据
await self._cache.delete(self.name)
logger.debug(f"删除缓存 {self.name} 旧数据")
# 3. 设置新数据
await self._cache.set(self.name, serialized_value, ttl=self.expire)
logger.debug(f"设置缓存 {self.name} 新数据完成")
# 4. 立即验证
cached_data = await self._cache.get(self.name)
if cached_data is None:
logger.error(f"缓存 {self.name} 设置失败:数据验证失败")
# 5. 如果验证失败,尝试重新设置
await self._cache.set(self.name, serialized_value, ttl=self.expire)
cached_data = await self._cache.get(self.name)
if cached_data is None:
logger.error(f"缓存 {self.name} 重试设置失败")
else:
logger.debug(f"缓存 {self.name} 重试设置成功: {cached_data}")
else:
logger.debug(f"缓存 {self.name} 数据验证成功: {cached_data}")
except Exception as e:
logger.error(f"设置缓存 {self.name} 失败: {e}")
raise # 重新抛出异常,让上层处理
async def delete_data(self):
"""删除缓存数据"""
try:
await self._cache.delete(self.name)
except Exception as e:
logger.error(f"删除缓存 {self.name}", e=e)
async def get(self, *args, **kwargs) -> Any:
"""获取缓存"""
if not self.reload_count and not self.incremental_update:
await self.reload(*args, **kwargs)
if not self.getter:
return await self.get_data()
return await self.getter.get(self, *args, **kwargs)
async def update(self, key: str, value: Any = None, *args, **kwargs):
"""更新单个缓存项"""
if not self.updater:
logger.warning(f"缓存 {self.name} 未配置更新方法")
return
current_data = await self.get_data() or {}
if is_coroutine_callable(self.updater):
await self.updater(current_data, key, value, *args, **kwargs)
else:
self.updater(current_data, key, value, *args, **kwargs)
await self.set_data(current_data)
logger.debug(f"更新缓存 {self.name}.{key}")
async def refresh(self, *args, **kwargs):
"""刷新缓存数据"""
if not self.with_refresh:
return await self.reload(*args, **kwargs)
current_data = await self.get_data()
if current_data:
if is_coroutine_callable(self.with_refresh):
await self.with_refresh(current_data, *args, **kwargs)
else:
self.with_refresh(current_data, *args, **kwargs)
await self.set_data(current_data)
logger.debug(f"刷新缓存 {self.name}")
async def reload(self, *args, **kwargs):
"""重新加载全部数据"""
try:
if self.has_args():
new_data = (
await self.func(*args, **kwargs)
if is_coroutine_callable(self.func)
else self.func(*args, **kwargs)
)
else:
new_data = (
await self.func()
if is_coroutine_callable(self.func)
else self.func()
)
async def call_with_expiration(self, is_force: bool = False):
"""缓存时间更新
await self.set_data(new_data)
self.reload_count += 1
logger.info(f"重新加载缓存 {self.name} 完成")
except Exception as e:
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
raise
参数:
is_force: 是否强制更新全部数据缓存时间.
"""
if self.with_expiration:
if is_force:
self.expire_data = {}
expiration_data = (
await self.with_expiration(self.data, self.expire_data, self.expire)
if is_coroutine_callable(self.with_expiration)
else self.with_expiration(self.data, self.expire_data, self.expire)
)
self.expire_data = {**self.expire_data, **expiration_data}
def has_args(self):
"""是否含有参数
返回:
bool: 是否含有参数
"""
def has_args(self) -> bool:
"""检查函数是否需要参数"""
sig = inspect.signature(self.func)
return any(
param.kind
@ -189,66 +282,81 @@ class CacheData(BaseModel):
for param in sig.parameters.values()
)
async def get_key(self, key: str) -> Any:
"""获取缓存中指定键的数据
Args:
key: 要获取的键名
Returns:
键对应的值如果不存在返回None
"""
try:
data = await self.get_data()
return data.get(key) if isinstance(data, dict) else None
except Exception as e:
logger.error(f"获取缓存 {self.name}.{key} 失败: {e}")
return None
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
"""获取缓存中多个键的数据
Args:
keys: 要获取的键名列表
Returns:
包含所有请求键值的字典不存在的键值为None
"""
try:
data = await self.get_data()
if isinstance(data, dict):
return {key: data.get(key) for key in keys}
return dict.fromkeys(keys)
except Exception as e:
logger.error(f"获取缓存 {self.name} 的多个键失败: {e}")
return dict.fromkeys(keys)
class CacheManager:
"""全局缓存管理,减少数据库与网络请求查询次数
异常:
DbCacheException: 数据名称重复
DbCacheException: 数据不存在
"""
"""全局缓存管理器"""
_data: ClassVar[dict[str, CacheData]] = {}
def start_check(self):
"""启动缓存检查"""
for cache_data in self._data.values():
if cache_data.cleanup_expired:
scheduler.add_job(
cache_data.call_cleanup_expired,
"interval",
seconds=cache_data.expire,
args=[],
id=f"CacheRoot-{cache_data.name}",
)
def new(self, name: str, incremental_update: bool = True, expire: int = 600):
"""注册新缓存"""
def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10):
def wrapper(func: Callable):
_name = name.upper()
if _name in self._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
raise DbCacheException(f"缓存 {name} 已存在")
self._data[_name] = CacheData(
name=_name,
func=func,
expire=expire,
incremental_update=incremental_update,
)
return func
return wrapper
def listener(self, name: str):
"""创建缓存监听器"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
if is_coroutine_callable(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
return result
return (
await func(*args, **kwargs)
if is_coroutine_callable(func)
else func(*args, **kwargs)
)
finally:
cache_data = self._data.get(name.upper())
if cache_data and cache_data.with_refresh:
if is_coroutine_callable(cache_data.with_refresh):
await cache_data.with_refresh(cache_data.data)
else:
cache_data.with_refresh(cache_data.data)
await cache_data.call_with_expiration(True)
logger.debug(
f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot"
)
cache = self._data.get(name.upper())
if cache and cache.with_refresh:
await cache.refresh()
logger.debug(f"监听器触发缓存 {name} 刷新")
return wrapper
@ -256,86 +364,79 @@ class CacheManager:
@validate_name
def updater(self, name: str):
"""设置缓存更新方法"""
def wrapper(func: Callable):
self._data[name.upper()].updater = func
self._data[name].updater = func
return func
return wrapper
@validate_name
def getter(self, name: str, result_model: type):
"""设置缓存获取方法"""
def wrapper(func: Callable):
self._data[name].getter = CacheGetter[result_model](get_func=func)
return func
return wrapper
@validate_name
def with_refresh(self, name: str):
"""设置缓存刷新方法"""
def wrapper(func: Callable):
self._data[name.upper()].with_refresh = func
self._data[name].with_refresh = func
return func
return wrapper
@validate_name
def with_expiration(self, name: str):
def wrapper(func: Callable[[Any, int], dict[str, float]]):
self._data[name.upper()].with_expiration = func
return wrapper
@validate_name
def cleanup_expired(self, name: str):
def wrapper(func: Callable[[CacheData], None]):
self._data[name.upper()].cleanup_expired = func
return wrapper
async def check_expire(self, name: str, *args, **kwargs):
name = name.upper()
if self._data.get(name) and (
time.time() - self._data[name].reload_time > self._data[name].expire
or not self._data[name].reload_count
):
await self._data[name].reload(*args, **kwargs)
async def get_cache_data(self, name: str):
return cache.data if (cache := await self.get_cache(name)) else None
async def get_cache(self, name: str, *args, **kwargs) -> CacheData | None:
name = name.upper()
if cache := self._data.get(name):
# await self.check_expire(name, *args, **kwargs)
return cache
return None
async def get(self, name: str, *args, **kwargs):
cache = await self.get_cache(name.upper(), *args, **kwargs)
if cache:
return await cache.get(*args, **kwargs) if cache.getter else cache.data
return None
async def reload(self, name: str, *args, **kwargs):
async def get_cache_data(self, name: str) -> Any | None:
"""获取缓存数据"""
cache = await self.get_cache(name.upper())
if cache:
await cache.refresh(*args, **kwargs)
return await cache.get_data() if cache else None
async def update(self, name: str, key: str, value: Any, *args, **kwargs):
async def get_cache(self, name: str) -> CacheData | None:
"""获取缓存对象"""
return self._data.get(name.upper())
async def get(self, name: str, *args, **kwargs) -> Any:
"""获取缓存内容"""
cache = await self.get_cache(name.upper())
return await cache.get(*args, **kwargs) if cache else None
async def update(self, name: str, key: str, value: Any = None, *args, **kwargs):
"""更新缓存项"""
cache = await self.get_cache(name.upper())
if cache:
await cache.update(key, value, *args, **kwargs)
async def reload(self, name: str, *args, **kwargs):
"""重新加载缓存"""
cache = await self.get_cache(name.upper())
if cache:
await cache.reload(*args, **kwargs)
# 全局缓存管理器实例
CacheRoot = CacheManager()
class Cache(Generic[T]):
"""类型化缓存访问接口"""
def __init__(self, module: str):
self.module = module
self.module = module.upper()
async def get(self, *args, **kwargs) -> T | None:
"""获取缓存"""
return await CacheRoot.get(self.module, *args, **kwargs)
async def update(self, key: str, value: Any = None, *args, **kwargs):
return await CacheRoot.update(self.module, key, value, *args, **kwargs)
"""更新缓存项"""
await CacheRoot.update(self.module, key, value, *args, **kwargs)
async def reload(self, key: str | None = None, *args, **kwargs):
await CacheRoot.reload(self.module, key, *args, **kwargs)
async def reload(self, *args, **kwargs):
"""重新加载缓存"""
await CacheRoot.reload(self.module, *args, **kwargs)