mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
🐛 完善aiocache缓存
This commit is contained in:
parent
f3d5b77bdc
commit
92d118685d
@ -1,5 +1,3 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from zhenxun.models.ban_console import BanConsole
|
from zhenxun.models.ban_console import BanConsole
|
||||||
from zhenxun.models.bot_console import BotConsole
|
from zhenxun.models.bot_console import BotConsole
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
@ -19,34 +17,26 @@ async def _():
|
|||||||
return {p.module: p for p in data_list}
|
return {p.module: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.PLUGINS)
|
|
||||||
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
|
||||||
"""更新插件缓存"""
|
|
||||||
if value:
|
|
||||||
data[key] = value
|
|
||||||
elif plugin := await PluginInfo.get_plugin(module=key):
|
|
||||||
data[key] = plugin
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
||||||
async def _(cache_data: CacheData, module: str):
|
async def _(cache_data: CacheData, module: str):
|
||||||
"""获取插件缓存"""
|
"""获取插件缓存"""
|
||||||
data = await cache_data.get_data() or {}
|
data = await cache_data.get_key(module)
|
||||||
if module not in data:
|
if not data:
|
||||||
if plugin := await PluginInfo.get_plugin(module=module):
|
if plugin := await PluginInfo.get_plugin(module=module):
|
||||||
data[module] = plugin
|
await cache_data.set_key(module, plugin)
|
||||||
await cache_data.set_data(data)
|
|
||||||
logger.debug(f"插件 {module} 数据已设置到缓存")
|
logger.debug(f"插件 {module} 数据已设置到缓存")
|
||||||
return data.get(module)
|
return plugin
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
||||||
async def _(data: dict[str, PluginInfo] | None):
|
async def _(cache_data: CacheData, data: dict[str, PluginInfo] | None):
|
||||||
"""刷新插件缓存"""
|
"""刷新插件缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
||||||
data.update({p.module: p for p in plugins})
|
for plugin in plugins:
|
||||||
|
await cache_data.set_key(plugin.module, plugin)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.GROUPS)
|
@CacheRoot.new(CacheType.GROUPS)
|
||||||
@ -56,35 +46,27 @@ async def _():
|
|||||||
return {p.group_id: p for p in data_list if not p.channel_id}
|
return {p.group_id: p for p in data_list if not p.channel_id}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.GROUPS)
|
|
||||||
async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
|
||||||
"""更新群组缓存"""
|
|
||||||
if value:
|
|
||||||
data[key] = value
|
|
||||||
elif group := await GroupConsole.get_group(group_id=key):
|
|
||||||
data[key] = group
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||||
async def _(cache_data: CacheData, group_id: str):
|
async def _(cache_data: CacheData, group_id: str):
|
||||||
"""获取群组缓存"""
|
"""获取群组缓存"""
|
||||||
data = await cache_data.get_data() or {}
|
data = await cache_data.get_key(group_id)
|
||||||
if group_id not in data:
|
if not data:
|
||||||
if group := await GroupConsole.get_group(group_id=group_id):
|
if group := await GroupConsole.get_group(group_id=group_id):
|
||||||
data[group_id] = group
|
await cache_data.set_key(group_id, group)
|
||||||
await cache_data.set_data(data)
|
return group
|
||||||
return data.get(group_id)
|
return data
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.GROUPS)
|
@CacheRoot.with_refresh(CacheType.GROUPS)
|
||||||
async def _(data: dict[str, GroupConsole] | None):
|
async def _(cache_data: CacheData, data: dict[str, GroupConsole] | None):
|
||||||
"""刷新群组缓存"""
|
"""刷新群组缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
groups = await GroupConsole.filter(
|
groups = await GroupConsole.filter(
|
||||||
group_id__in=data.keys(), channel_id__isnull=True
|
group_id__in=data.keys(), channel_id__isnull=True
|
||||||
).all()
|
).all()
|
||||||
data.update({g.group_id: g for g in groups})
|
for group in groups:
|
||||||
|
await cache_data.set_key(group.group_id, group)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BOT)
|
@CacheRoot.new(CacheType.BOT)
|
||||||
@ -94,33 +76,25 @@ async def _():
|
|||||||
return {p.bot_id: p for p in data_list}
|
return {p.bot_id: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.BOT)
|
|
||||||
async def _(data: dict[str, BotConsole], key: str, value: Any):
|
|
||||||
"""更新机器人缓存"""
|
|
||||||
if value:
|
|
||||||
data[key] = value
|
|
||||||
elif bot := await BotConsole.get_or_none(bot_id=key):
|
|
||||||
data[key] = bot
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||||
async def _(cache_data: CacheData, bot_id: str):
|
async def _(cache_data: CacheData, bot_id: str):
|
||||||
"""获取机器人缓存"""
|
"""获取机器人缓存"""
|
||||||
data = await cache_data.get_data() or {}
|
data = await cache_data.get_key(bot_id)
|
||||||
if bot_id not in data:
|
if not data:
|
||||||
if bot := await BotConsole.get_or_none(bot_id=bot_id):
|
if bot := await BotConsole.get_or_none(bot_id=bot_id):
|
||||||
data[bot_id] = bot
|
await cache_data.set_key(bot_id, bot)
|
||||||
await cache_data.set_data(data)
|
return bot
|
||||||
return data.get(bot_id)
|
return data
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.BOT)
|
@CacheRoot.with_refresh(CacheType.BOT)
|
||||||
async def _(data: dict[str, BotConsole] | None):
|
async def _(cache_data: CacheData, data: dict[str, BotConsole] | None):
|
||||||
"""刷新机器人缓存"""
|
"""刷新机器人缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
||||||
data.update({b.bot_id: b for b in bots})
|
for bot in bots:
|
||||||
|
await cache_data.set_key(bot.bot_id, bot)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.USERS)
|
@CacheRoot.new(CacheType.USERS)
|
||||||
@ -130,111 +104,100 @@ async def _():
|
|||||||
return {p.user_id: p for p in data_list}
|
return {p.user_id: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.USERS)
|
|
||||||
async def _(data: dict[str, UserConsole], key: str, value: Any):
|
|
||||||
"""更新用户缓存"""
|
|
||||||
if value:
|
|
||||||
data[key] = value
|
|
||||||
elif user := await UserConsole.get_user(user_id=key):
|
|
||||||
data[key] = user
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
||||||
async def _(cache_data: CacheData, user_id: str):
|
async def _(cache_data: CacheData, user_id: str):
|
||||||
"""获取用户缓存"""
|
"""获取用户缓存"""
|
||||||
data = await cache_data.get_data() or {}
|
data = await cache_data.get_key(user_id)
|
||||||
if user_id not in data:
|
if not data:
|
||||||
if user := await UserConsole.get_user(user_id=user_id):
|
if user := await UserConsole.get_user(user_id=user_id):
|
||||||
data[user_id] = user
|
await cache_data.set_key(user_id, user)
|
||||||
await cache_data.set_data(data)
|
return user
|
||||||
return data.get(user_id)
|
return data
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.USERS)
|
@CacheRoot.with_refresh(CacheType.USERS)
|
||||||
async def _(data: dict[str, UserConsole] | None):
|
async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
|
||||||
"""刷新用户缓存"""
|
"""刷新用户缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
||||||
data.update({u.user_id: u for u in users})
|
for user in users:
|
||||||
|
await cache_data.set_key(user.user_id, user)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.LEVEL, False)
|
@CacheRoot.new(CacheType.LEVEL, False)
|
||||||
async def _():
|
async def _():
|
||||||
"""初始化等级缓存"""
|
"""初始化等级缓存"""
|
||||||
return await LevelUser().all()
|
data_list = await LevelUser().all()
|
||||||
|
return {f"{d.user_id}:{d.group_id or ''}": d for d in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
||||||
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
||||||
"""获取等级缓存"""
|
"""获取等级缓存"""
|
||||||
data = await cache_data.get_data() or []
|
key = f"{user_id}:{group_id or ''}"
|
||||||
if not group_id:
|
data = await cache_data.get_key(key)
|
||||||
return [d for d in data if d.user_id == user_id and not d.group_id]
|
if not data:
|
||||||
return [d for d in data if d.user_id == user_id and d.group_id == group_id]
|
if group_id:
|
||||||
|
data = await LevelUser.filter(user_id=user_id, group_id=group_id).all()
|
||||||
|
else:
|
||||||
|
data = await LevelUser.filter(user_id=user_id, group_id__isnull=True).all()
|
||||||
|
if data:
|
||||||
|
await cache_data.set_key(key, data)
|
||||||
|
return data
|
||||||
|
return data or []
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BAN, False)
|
@CacheRoot.new(CacheType.BAN, False)
|
||||||
async def _():
|
async def _():
|
||||||
"""初始化封禁缓存"""
|
"""初始化封禁缓存"""
|
||||||
return await BanConsole.all()
|
data_list = await BanConsole.all()
|
||||||
|
return {f"{d.user_id or ''}:{d.group_id or ''}": d for d in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||||
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
||||||
"""获取封禁缓存"""
|
"""获取封禁缓存"""
|
||||||
data = await cache_data.get_data() or []
|
key = f"{user_id or ''}:{group_id or ''}"
|
||||||
if user_id:
|
data = await cache_data.get_key(key)
|
||||||
if group_id:
|
if not data:
|
||||||
return [d for d in data if d.user_id == user_id and d.group_id == group_id]
|
if user_id and group_id:
|
||||||
return [d for d in data if d.user_id == user_id and not d.group_id]
|
data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
|
||||||
if group_id:
|
elif user_id:
|
||||||
return [d for d in data if not d.user_id and d.group_id == group_id]
|
data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
|
||||||
return None
|
elif group_id:
|
||||||
|
data = await BanConsole.filter(
|
||||||
|
user_id__isnull=True, group_id=group_id
|
||||||
|
).all()
|
||||||
|
if data:
|
||||||
|
await cache_data.set_key(key, data)
|
||||||
|
return data
|
||||||
|
return data or []
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.LIMIT)
|
@CacheRoot.new(CacheType.LIMIT)
|
||||||
async def _():
|
async def _():
|
||||||
"""初始化限制缓存"""
|
"""初始化限制缓存"""
|
||||||
data_list = await PluginLimit.filter(status=True).all()
|
data_list = await PluginLimit.filter(status=True).all()
|
||||||
result_data = {}
|
return {data.module: data for data in data_list}
|
||||||
for data in data_list:
|
|
||||||
if not result_data.get(data.module):
|
|
||||||
result_data[data.module] = []
|
|
||||||
result_data[data.module].append(data)
|
|
||||||
return result_data
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.LIMIT)
|
|
||||||
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
|
|
||||||
"""更新限制缓存"""
|
|
||||||
if value:
|
|
||||||
data[key] = value
|
|
||||||
elif limits := await PluginLimit.filter(module=key, status=True):
|
|
||||||
data[key] = limits
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||||
async def _(cache_data: CacheData, module: str):
|
async def _(cache_data: CacheData, module: str):
|
||||||
"""获取限制缓存"""
|
"""获取限制缓存"""
|
||||||
data = await cache_data.get_data() or {}
|
data = await cache_data.get_key(module)
|
||||||
if module not in data:
|
if not data:
|
||||||
if limits := await PluginLimit.filter(module=module, status=True):
|
if limits := await PluginLimit.filter(module=module, status=True):
|
||||||
data[module] = limits
|
await cache_data.set_key(module, limits)
|
||||||
await cache_data.set_data(data)
|
return limits
|
||||||
return data.get(module)
|
return data or []
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.LIMIT)
|
@CacheRoot.with_refresh(CacheType.LIMIT)
|
||||||
async def _(data: dict[str, list[PluginLimit]] | None):
|
async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
|
||||||
"""刷新限制缓存"""
|
"""刷新限制缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||||
new_data = {}
|
|
||||||
for limit in limits:
|
for limit in limits:
|
||||||
if not new_data.get(limit.module):
|
await cache_data.set_key(limit.module, limit)
|
||||||
new_data[limit.module] = []
|
|
||||||
new_data[limit.module].append(limit)
|
|
||||||
data.clear()
|
|
||||||
data.update(new_data)
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from aiocache.serializers import JsonSerializer
|
|||||||
from nonebot.compat import model_dump
|
from nonebot.compat import model_dump
|
||||||
from nonebot.utils import is_coroutine_callable
|
from nonebot.utils import is_coroutine_callable
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from tortoise.fields.base import Field
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
@ -44,15 +45,47 @@ class CacheGetter(BaseModel, Generic[T]):
|
|||||||
"""缓存数据获取器"""
|
"""缓存数据获取器"""
|
||||||
|
|
||||||
get_func: Callable[..., Any] | None = None
|
get_func: Callable[..., Any] | None = None
|
||||||
|
get_all_func: Callable[..., Any] | None = None
|
||||||
|
|
||||||
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
|
async def get(self, cache_data: "CacheData", key: str, *args, **kwargs) -> T:
|
||||||
"""获取处理后的缓存数据"""
|
"""获取单个缓存数据"""
|
||||||
if not self.get_func:
|
if not self.get_func:
|
||||||
return await cache_data.get_data()
|
data = await cache_data.get_key(key)
|
||||||
|
if cache_data.result_model:
|
||||||
|
return cache_data._deserialize_value(data, cache_data.result_model)
|
||||||
|
return data
|
||||||
|
|
||||||
if is_coroutine_callable(self.get_func):
|
if is_coroutine_callable(self.get_func):
|
||||||
return await self.get_func(cache_data, *args, **kwargs)
|
data = await self.get_func(cache_data, key, *args, **kwargs)
|
||||||
return self.get_func(cache_data, *args, **kwargs)
|
else:
|
||||||
|
data = self.get_func(cache_data, key, *args, **kwargs)
|
||||||
|
|
||||||
|
if cache_data.result_model:
|
||||||
|
return cache_data._deserialize_value(data, cache_data.result_model)
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def get_all(self, cache_data: "CacheData", *args, **kwargs) -> dict[str, T]:
|
||||||
|
"""获取所有缓存数据"""
|
||||||
|
if not self.get_all_func:
|
||||||
|
data = await cache_data.get_all_data()
|
||||||
|
if cache_data.result_model:
|
||||||
|
return {
|
||||||
|
k: cache_data._deserialize_value(v, cache_data.result_model)
|
||||||
|
for k, v in data.items()
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
if is_coroutine_callable(self.get_all_func):
|
||||||
|
data = await self.get_all_func(cache_data, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
data = self.get_all_func(cache_data, *args, **kwargs)
|
||||||
|
|
||||||
|
if cache_data.result_model:
|
||||||
|
return {
|
||||||
|
k: cache_data._deserialize_value(v, cache_data.result_model)
|
||||||
|
for k, v in data.items()
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class CacheData(BaseModel):
|
class CacheData(BaseModel):
|
||||||
@ -66,20 +99,117 @@ class CacheData(BaseModel):
|
|||||||
expire: int = 600 # 默认10分钟过期
|
expire: int = 600 # 默认10分钟过期
|
||||||
reload_count: int = 0
|
reload_count: int = 0
|
||||||
incremental_update: bool = True
|
incremental_update: bool = True
|
||||||
|
_cache_instance: BaseCache | None = None
|
||||||
|
result_model: type | None = None
|
||||||
|
_keys: set[str] = set() # 存储所有缓存键
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
underscore_attrs_are_private = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _cache(self) -> BaseCache:
|
def _cache(self) -> BaseCache:
|
||||||
"""获取aiocache实例"""
|
"""获取aiocache实例"""
|
||||||
return AioCache(
|
if self._cache_instance is None:
|
||||||
AioCache.MEMORY,
|
self._cache_instance = AioCache(
|
||||||
serializer=JsonSerializer(),
|
AioCache.MEMORY,
|
||||||
namespace="zhenxun_cache",
|
serializer=JsonSerializer(),
|
||||||
timeout=30, # 操作超时时间
|
namespace="zhenxun_cache",
|
||||||
ttl=self.expire, # 设置默认过期时间
|
timeout=30, # 操作超时时间
|
||||||
)
|
ttl=self.expire, # 设置默认过期时间
|
||||||
|
)
|
||||||
|
return self._cache_instance
|
||||||
|
|
||||||
|
def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any:
|
||||||
|
"""反序列化值,将JSON数据转换回原始类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 需要反序列化的值
|
||||||
|
target_type: 目标类型,用于指导反序列化
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
反序列化后的值
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果是字典且指定了目标类型
|
||||||
|
if isinstance(value, dict) and target_type:
|
||||||
|
# 处理Tortoise-ORM Model
|
||||||
|
if hasattr(target_type, "_meta"):
|
||||||
|
# 处理字段值
|
||||||
|
processed_value = {}
|
||||||
|
for field_name, field_value in value.items():
|
||||||
|
field: Field = target_type._meta.fields_map.get(field_name)
|
||||||
|
if field:
|
||||||
|
# 跳过反向关系字段
|
||||||
|
if hasattr(field, "_related_name"):
|
||||||
|
continue
|
||||||
|
# 处理 CharEnumField
|
||||||
|
if hasattr(field, "enum_class"):
|
||||||
|
try:
|
||||||
|
processed_value[field_name] = field.enum_class(
|
||||||
|
field_value
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
processed_value[field_name] = None
|
||||||
|
else:
|
||||||
|
processed_value[field_name] = field_value
|
||||||
|
|
||||||
|
logger.debug(f"处理后的值: {processed_value}")
|
||||||
|
|
||||||
|
# 创建模型实例
|
||||||
|
instance = target_type()
|
||||||
|
# 设置字段值
|
||||||
|
for field_name, field_value in processed_value.items():
|
||||||
|
if field_name in target_type._meta.fields_map:
|
||||||
|
field = target_type._meta.fields_map[field_name]
|
||||||
|
# 设置字段值
|
||||||
|
try:
|
||||||
|
if hasattr(field, "to_python_value"):
|
||||||
|
if not field.field_type:
|
||||||
|
logger.warning(f"字段 {field_name} 类型为空")
|
||||||
|
continue
|
||||||
|
field_value = field.to_python_value(field_value)
|
||||||
|
setattr(instance, field_name, field_value)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"设置字段 {field_name} 失败: {e}")
|
||||||
|
|
||||||
|
# 设置 _saved_in_db 标志
|
||||||
|
instance._saved_in_db = True
|
||||||
|
return instance
|
||||||
|
# 处理Pydantic模型
|
||||||
|
elif hasattr(target_type, "model_validate"):
|
||||||
|
return target_type.model_validate(value)
|
||||||
|
elif hasattr(target_type, "from_dict"):
|
||||||
|
return target_type.from_dict(value)
|
||||||
|
elif hasattr(target_type, "parse_obj"):
|
||||||
|
return target_type.parse_obj(value)
|
||||||
|
else:
|
||||||
|
return target_type(**value)
|
||||||
|
|
||||||
|
# 处理列表类型
|
||||||
|
if isinstance(value, list):
|
||||||
|
if not value:
|
||||||
|
return value
|
||||||
|
if (
|
||||||
|
target_type
|
||||||
|
and hasattr(target_type, "__origin__")
|
||||||
|
and target_type.__origin__ is list
|
||||||
|
):
|
||||||
|
item_type = target_type.__args__[0]
|
||||||
|
return [self._deserialize_value(item, item_type) for item in value]
|
||||||
|
return [self._deserialize_value(item) for item in value]
|
||||||
|
|
||||||
|
# 处理字典类型
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: self._deserialize_value(v) for k, v in value.items()}
|
||||||
|
|
||||||
|
# 处理基本类型
|
||||||
|
if isinstance(value, int | float | str | bool):
|
||||||
|
return value
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
async def get_data(self) -> Any:
|
async def get_data(self) -> Any:
|
||||||
"""从缓存获取数据"""
|
"""从缓存获取数据"""
|
||||||
@ -88,29 +218,33 @@ class CacheData(BaseModel):
|
|||||||
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
||||||
|
|
||||||
# 如果数据为空,尝试重新加载
|
# 如果数据为空,尝试重新加载
|
||||||
if data is None:
|
# if data is None:
|
||||||
logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载")
|
# logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载")
|
||||||
try:
|
# try:
|
||||||
if self.has_args():
|
# if self.has_args():
|
||||||
new_data = (
|
# new_data = (
|
||||||
await self.func()
|
# await self.func()
|
||||||
if is_coroutine_callable(self.func)
|
# if is_coroutine_callable(self.func)
|
||||||
else self.func()
|
# else self.func()
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
new_data = (
|
# new_data = (
|
||||||
await self.func()
|
# await self.func()
|
||||||
if is_coroutine_callable(self.func)
|
# if is_coroutine_callable(self.func)
|
||||||
else self.func()
|
# else self.func()
|
||||||
)
|
# )
|
||||||
|
|
||||||
await self.set_data(new_data)
|
# await self.set_data(new_data)
|
||||||
self.reload_count += 1
|
# self.reload_count += 1
|
||||||
logger.info(f"重新加载缓存 {self.name} 完成")
|
# logger.info(f"重新加载缓存 {self.name} 完成")
|
||||||
return new_data
|
# return new_data
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
|
# logger.error(f"重新加载缓存 {self.name} 失败: {e}")
|
||||||
return None
|
# return None
|
||||||
|
|
||||||
|
# 使用 result_model 进行反序列化
|
||||||
|
if self.result_model:
|
||||||
|
return self._deserialize_value(data, self.result_model)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -144,10 +278,16 @@ class CacheData(BaseModel):
|
|||||||
field_value, "_related_name"
|
field_value, "_related_name"
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
# 跳过外键关系字段
|
||||||
|
if hasattr(field_value, "_meta"):
|
||||||
|
field_value = getattr(
|
||||||
|
field_value, value._meta.fields[field].related_name or "id"
|
||||||
|
)
|
||||||
result[field] = self._serialize_value(field_value)
|
result[field] = self._serialize_value(field_value)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 处理Pydantic模型
|
# 处理Pydantic模型
|
||||||
elif isinstance(value, BaseModel):
|
elif isinstance(value, BaseModel):
|
||||||
return model_dump(value)
|
return model_dump(value)
|
||||||
@ -160,9 +300,6 @@ class CacheData(BaseModel):
|
|||||||
elif isinstance(value, int | float | str | bool):
|
elif isinstance(value, int | float | str | bool):
|
||||||
# 基本类型直接返回
|
# 基本类型直接返回
|
||||||
return value
|
return value
|
||||||
elif hasattr(value, "__dict__"):
|
|
||||||
# 处理普通类对象
|
|
||||||
return self._serialize_value(value.__dict__)
|
|
||||||
else:
|
else:
|
||||||
# 其他类型转换为字符串
|
# 其他类型转换为字符串
|
||||||
return str(value)
|
return str(value)
|
||||||
@ -183,19 +320,6 @@ class CacheData(BaseModel):
|
|||||||
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
||||||
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
||||||
|
|
||||||
# 4. 立即验证
|
|
||||||
cached_data = await self._cache.get(self.name)
|
|
||||||
if cached_data is None:
|
|
||||||
logger.error(f"缓存 {self.name} 设置失败:数据验证失败")
|
|
||||||
# 5. 如果验证失败,尝试重新设置
|
|
||||||
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
|
||||||
cached_data = await self._cache.get(self.name)
|
|
||||||
if cached_data is None:
|
|
||||||
logger.error(f"缓存 {self.name} 重试设置失败")
|
|
||||||
else:
|
|
||||||
logger.debug(f"缓存 {self.name} 重试设置成功: {cached_data}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"缓存 {self.name} 数据验证成功: {cached_data}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"设置缓存 {self.name} 失败: {e}")
|
logger.error(f"设置缓存 {self.name} 失败: {e}")
|
||||||
raise # 重新抛出异常,让上层处理
|
raise # 重新抛出异常,让上层处理
|
||||||
@ -207,15 +331,25 @@ class CacheData(BaseModel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除缓存 {self.name}", e=e)
|
logger.error(f"删除缓存 {self.name}", e=e)
|
||||||
|
|
||||||
async def get(self, *args, **kwargs) -> Any:
|
async def get(self, key: str, *args, **kwargs) -> Any:
|
||||||
"""获取缓存"""
|
"""获取缓存"""
|
||||||
if not self.reload_count and not self.incremental_update:
|
if not self.reload_count and not self.incremental_update:
|
||||||
await self.reload(*args, **kwargs)
|
await self.reload(*args, **kwargs)
|
||||||
|
|
||||||
if not self.getter:
|
if not self.getter:
|
||||||
return await self.get_data()
|
return await self.get_key(key)
|
||||||
|
|
||||||
return await self.getter.get(self, *args, **kwargs)
|
return await self.getter.get(self, key, *args, **kwargs)
|
||||||
|
|
||||||
|
async def get_all(self, *args, **kwargs) -> dict[str, Any]:
|
||||||
|
"""获取所有缓存数据"""
|
||||||
|
if not self.reload_count and not self.incremental_update:
|
||||||
|
await self.reload(*args, **kwargs)
|
||||||
|
|
||||||
|
if not self.getter:
|
||||||
|
return await self.get_all_data()
|
||||||
|
|
||||||
|
return await self.getter.get_all(self, *args, **kwargs)
|
||||||
|
|
||||||
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
||||||
"""更新单个缓存项"""
|
"""更新单个缓存项"""
|
||||||
@ -223,13 +357,13 @@ class CacheData(BaseModel):
|
|||||||
logger.warning(f"缓存 {self.name} 未配置更新方法")
|
logger.warning(f"缓存 {self.name} 未配置更新方法")
|
||||||
return
|
return
|
||||||
|
|
||||||
current_data = await self.get_data() or {}
|
current_data = await self.get_key(key) or {}
|
||||||
if is_coroutine_callable(self.updater):
|
if is_coroutine_callable(self.updater):
|
||||||
await self.updater(current_data, key, value, *args, **kwargs)
|
await self.updater(current_data, key, value, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.updater(current_data, key, value, *args, **kwargs)
|
self.updater(current_data, key, value, *args, **kwargs)
|
||||||
|
|
||||||
await self.set_data(current_data)
|
await self.set_key(key, current_data)
|
||||||
logger.debug(f"更新缓存 {self.name}.{key}")
|
logger.debug(f"更新缓存 {self.name}.{key}")
|
||||||
|
|
||||||
async def refresh(self, *args, **kwargs):
|
async def refresh(self, *args, **kwargs):
|
||||||
@ -262,7 +396,14 @@ class CacheData(BaseModel):
|
|||||||
else self.func()
|
else self.func()
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.set_data(new_data)
|
# 如果是字典,则分别存储每个键值对
|
||||||
|
if isinstance(new_data, dict):
|
||||||
|
for key, value in new_data.items():
|
||||||
|
await self.set_key(key, value)
|
||||||
|
else:
|
||||||
|
# 如果不是字典,则存储为单个键值对
|
||||||
|
await self.set_key("default", new_data)
|
||||||
|
|
||||||
self.reload_count += 1
|
self.reload_count += 1
|
||||||
logger.info(f"重新加载缓存 {self.name} 完成")
|
logger.info(f"重新加载缓存 {self.name} 完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -291,11 +432,16 @@ class CacheData(BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
键对应的值,如果不存在返回None
|
键对应的值,如果不存在返回None
|
||||||
"""
|
"""
|
||||||
|
cache_key = self._get_cache_key(key)
|
||||||
try:
|
try:
|
||||||
data = await self.get_data()
|
data = await self._cache.get(cache_key)
|
||||||
return data.get(key) if isinstance(data, dict) else None
|
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
|
||||||
|
|
||||||
|
if self.result_model:
|
||||||
|
return self._deserialize_value(data, self.result_model)
|
||||||
|
return data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取缓存 {self.name}.{key} 失败: {e}")
|
logger.error(f"获取缓存 {cache_key} 失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
||||||
@ -316,6 +462,61 @@ class CacheData(BaseModel):
|
|||||||
logger.error(f"获取缓存 {self.name} 的多个键失败: {e}")
|
logger.error(f"获取缓存 {self.name} 的多个键失败: {e}")
|
||||||
return dict.fromkeys(keys)
|
return dict.fromkeys(keys)
|
||||||
|
|
||||||
|
def _get_cache_key(self, key: str) -> str:
|
||||||
|
"""获取缓存键名"""
|
||||||
|
return f"{self.name}:{key}"
|
||||||
|
|
||||||
|
async def get_all_data(self) -> dict[str, Any]:
|
||||||
|
"""获取所有缓存数据"""
|
||||||
|
try:
|
||||||
|
result = {}
|
||||||
|
for key in self._keys:
|
||||||
|
# 提取原始键名(去掉前缀)
|
||||||
|
original_key = key.split(":", 1)[1]
|
||||||
|
data = await self._cache.get(key)
|
||||||
|
if self.result_model:
|
||||||
|
result[original_key] = self._deserialize_value(
|
||||||
|
data, self.result_model
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result[original_key] = data
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有缓存数据失败: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def set_key(self, key: str, value: Any):
|
||||||
|
"""设置指定键的缓存数据"""
|
||||||
|
cache_key = self._get_cache_key(key)
|
||||||
|
try:
|
||||||
|
serialized_value = self._serialize_value(value)
|
||||||
|
await self._cache.set(cache_key, serialized_value, ttl=self.expire)
|
||||||
|
self._keys.add(cache_key) # 添加到键列表
|
||||||
|
logger.debug(f"设置缓存 {cache_key} 数据完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"设置缓存 {cache_key} 失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_key(self, key: str):
|
||||||
|
"""删除指定键的缓存数据"""
|
||||||
|
cache_key = self._get_cache_key(key)
|
||||||
|
try:
|
||||||
|
await self._cache.delete(cache_key)
|
||||||
|
self._keys.discard(cache_key) # 从键列表中移除
|
||||||
|
logger.debug(f"删除缓存 {cache_key} 完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除缓存 {cache_key} 失败: {e}")
|
||||||
|
|
||||||
|
async def clear(self):
|
||||||
|
"""清除所有缓存数据"""
|
||||||
|
try:
|
||||||
|
for key in list(self._keys): # 使用列表复制避免在迭代时修改
|
||||||
|
await self._cache.delete(key)
|
||||||
|
self._keys.clear()
|
||||||
|
logger.debug(f"清除缓存 {self.name} 完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清除缓存 {self.name} 失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
class CacheManager:
|
class CacheManager:
|
||||||
"""全局缓存管理器"""
|
"""全局缓存管理器"""
|
||||||
@ -378,6 +579,7 @@ class CacheManager:
|
|||||||
|
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
||||||
|
self._data[name].result_model = result_model
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user