添加增量缓存与缓存过期

This commit is contained in:
HibiKier 2025-03-05 17:37:55 +08:00
parent 76cafea7d4
commit 6eb9bb510a
7 changed files with 478 additions and 204 deletions

View File

@ -4,8 +4,8 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
from zhenxun.utils.message import MessageUtils from zhenxun.utils.message import MessageUtils

View File

@ -2,8 +2,8 @@ from nonebot.exception import IgnoredException
from zhenxun.models.bot_console import BotConsole from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType

View File

@ -9,6 +9,7 @@ from tortoise.exceptions import MultipleObjectsReturned
from zhenxun.configs.config import Config from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole from zhenxun.models.ban_console import BanConsole
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, PluginType from zhenxun.utils.enum import CacheType, PluginType
@ -27,7 +28,8 @@ _flmt = FreqLimiter(300)
async def is_ban(user_id: str | None, group_id: str | None): async def is_ban(user_id: str | None, group_id: str | None):
cache = Cache[list[BanConsole]](CacheType.BAN) cache = Cache[list[BanConsole]](CacheType.BAN)
return await cache.get(user_id, group_id) result = await cache.get(user_id, group_id) or await cache.get(user_id)
return result and result[0].ban_time > 0
# 检查是否被ban # 检查是否被ban
@ -80,8 +82,18 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo):
time_str = f"{hours} 小时 {minute}分钟" time_str = f"{hours} 小时 {minute}分钟"
else: else:
time_str = f"{minute} 分钟" time_str = f"{minute} 分钟"
if time != -1 and ban_result and _flmt.check(user_id): db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(
matcher.plugin_name
)
if (
db_plugin
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
and _flmt.check(user_id)
):
_flmt.start_cd(user_id) _flmt.start_cd(user_id)
logger.debug(f"ban检测发送插件: {matcher.plugin_name}")
await MessageUtils.build_message( await MessageUtils.build_message(
[ [
At(flag="user", target=user_id), At(flag="user", target=user_id),

View File

@ -1,22 +1,21 @@
import os
from pathlib import Path from pathlib import Path
from typing import Any import sys
import nonebot import nonebot
from nonebot.adapters import Bot from nonebot.adapters import Bot
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser from zhenxun.services.cache import DbCacheException
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve())) nonebot.load_plugins(str(Path(__file__).parent.resolve()))
try:
from .__init_cache import driver
except DbCacheException as e:
raise SystemError(f"ERROR{e}")
driver = nonebot.get_driver() driver = nonebot.get_driver()
@ -49,150 +48,3 @@ async def _(bot: Bot):
f"更新Bot: {bot.self_id} 的群认证完成,共创建 {len(create_list)} 条数据," f"更新Bot: {bot.self_id} 的群认证完成,共创建 {len(create_list)} 条数据,"
f"共修改 {len(update_id)} 条数据..." f"共修改 {len(update_id)} 条数据..."
) )
@CacheRoot.new(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo] = {}, key: str | None = None):
if data and key:
if plugin := await PluginInfo.get_plugin(module=key):
data[key] = plugin
else:
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@CacheRoot.updater(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo], key: str, value: Any):
if value:
data[key] = value
elif plugin := await PluginInfo.get_plugin(module=key):
data[key] = plugin
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
async def _(data: dict[str, PluginInfo], module: str):
result = data.get(module, None)
if not result:
result = await PluginInfo.get_plugin(module=module)
if result:
data[module] = result
return result
@CacheRoot.new(CacheType.GROUPS)
async def _():
data_list = await GroupConsole.all()
return {p.group_id: p for p in data_list if not p.channel_id}
@CacheRoot.updater(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole], key: str, value: Any):
if value:
data[key] = value
elif group := await GroupConsole.get_group(group_id=key):
data[key] = group
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(data: dict[str, GroupConsole], group_id: str):
result = data.get(group_id, None)
if not result:
result = await GroupConsole.get_group(group_id=group_id)
if result:
data[group_id] = result
return result
@CacheRoot.new(CacheType.BOT)
async def _():
data_list = await BotConsole.all()
return {p.bot_id: p for p in data_list}
@CacheRoot.updater(CacheType.BOT)
async def _(data: dict[str, BotConsole], key: str, value: Any):
if value:
data[key] = value
elif bot := await BotConsole.get_or_none(bot_id=key):
data[key] = bot
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(data: dict[str, BotConsole], bot_id: str):
result = data.get(bot_id, None)
if not result:
result = await BotConsole.get_or_none(bot_id=bot_id)
if result:
data[bot_id] = result
return result
@CacheRoot.new(CacheType.USERS)
async def _():
data_list = await UserConsole.all()
return {p.user_id: p for p in data_list}
@CacheRoot.updater(CacheType.USERS)
async def _(data: dict[str, UserConsole], key: str, value: Any):
if value:
data[key] = value
elif user := await UserConsole.get_user(user_id=key):
data[key] = user
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
async def _(data: dict[str, UserConsole], user_id: str):
result = data.get(user_id, None)
if not result:
result = await UserConsole.get_user(user_id=user_id)
if result:
data[user_id] = result
return result
@CacheRoot.new(CacheType.LEVEL)
async def _():
return await LevelUser().all()
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
def _(data_list: list[LevelUser], user_id: str, group_id: str | None = None):
if not group_id:
return [
data for data in data_list if data.user_id == user_id and not data.group_id
]
else:
return [
data
for data in data_list
if data.user_id == user_id and data.group_id == group_id
]
@CacheRoot.new(CacheType.BAN)
async def _():
return await BanConsole.all()
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
def _(data_list: list[BanConsole], user_id: str | None, group_id: str | None = None):
if user_id:
return (
[
data
for data in data_list
if data.user_id == user_id and data.group_id == group_id
]
if group_id
else [
data
for data in data_list
if data.user_id == user_id and not data.group_id
]
)
if group_id:
return [
data for data in data_list if not data.user_id and data.group_id == group_id
]
return None

View File

@ -0,0 +1,276 @@
import time
from typing import Any
import nonebot
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.utils.enum import CacheType
driver = nonebot.get_driver()
@driver.on_startup
async def _():
"""开启cache检测"""
CacheRoot.start_check()
def default_cleanup_expired(cache_data: CacheData) -> list[str]:
"""默认清理过期cache方法"""
if not cache_data.data:
return []
now = time.time()
expire_key = []
for k, t in list(cache_data.expire_data.items()):
if t < now:
expire_key.append(k)
cache_data.expire_data.pop(k)
if expire_key:
cache_data.data = {
k: v for k, v in cache_data.data.items() if k not in expire_key
}
return expire_key
def default_with_expiration(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新期时间cache方法"""
keys = {k for k in data if k not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
@CacheRoot.new(CacheType.PLUGINS)
async def _():
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@CacheRoot.new(CacheType.PLUGINS)
async def _():
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@CacheRoot.updater(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo], key: str, value: Any):
if value:
data[key] = value
elif plugin := await PluginInfo.get_plugin(module=key):
data[key] = plugin
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
async def _(cache_data: CacheData, module: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(module, None)
if not result:
result = await PluginInfo.get_plugin(module=module)
if result:
cache_data.data[module] = result
return result
@CacheRoot.with_refresh(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo]):
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True)
for plugin in plugins:
data[plugin.module] = plugin
@CacheRoot.with_expiration(CacheType.PLUGINS)
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.PLUGINS)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
@CacheRoot.new(CacheType.GROUPS)
async def _():
data_list = await GroupConsole.all()
return {p.group_id: p for p in data_list if not p.channel_id}
@CacheRoot.updater(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole], key: str, value: Any):
if value:
data[key] = value
elif group := await GroupConsole.get_group(group_id=key):
data[key] = group
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(data: dict[str, GroupConsole] | None, group_id: str):
if not data:
data = {}
result = data.get(group_id, None)
if not result:
result = await GroupConsole.get_group(group_id=group_id)
if result:
data[group_id] = result
return result
@CacheRoot.with_refresh(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole]):
groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True, load_status=True
)
for group in groups:
data[group.group_id] = group
@CacheRoot.with_expiration(CacheType.GROUPS)
def _(data: dict[str, GroupConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.GROUPS)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
@CacheRoot.new(CacheType.BOT)
async def _():
data_list = await BotConsole.all()
return {p.bot_id: p for p in data_list}
@CacheRoot.updater(CacheType.BOT)
async def _(data: dict[str, BotConsole], key: str, value: Any):
if value:
data[key] = value
elif bot := await BotConsole.get_or_none(bot_id=key):
data[key] = bot
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(data: dict[str, BotConsole] | None, bot_id: str):
if not data:
data = {}
result = data.get(bot_id, None)
if not result:
result = await BotConsole.get_or_none(bot_id=bot_id)
if result:
data[bot_id] = result
return result
@CacheRoot.with_refresh(CacheType.BOT)
async def _(data: dict[str, BotConsole]):
bots = await BotConsole.filter(bot_id__in=data.keys())
for bot in bots:
data[bot.bot_id] = bot
@CacheRoot.with_expiration(CacheType.BOT)
def _(data: dict[str, BotConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.BOT)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
@CacheRoot.new(CacheType.USERS)
async def _():
data_list = await UserConsole.all()
return {p.user_id: p for p in data_list}
@CacheRoot.updater(CacheType.USERS)
async def _(data: dict[str, UserConsole], key: str, value: Any):
if value:
data[key] = value
elif user := await UserConsole.get_user(user_id=key):
data[key] = user
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
async def _(cache_data: CacheData, user_id: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(user_id, None)
if not result:
result = await UserConsole.get_user(user_id=user_id)
if result:
cache_data.data[user_id] = result
return result
@CacheRoot.with_refresh(CacheType.USERS)
async def _(data: dict[str, UserConsole]):
users = await UserConsole.filter(user_id__in=data.keys())
for user in users:
data[user.user_id] = user
@CacheRoot.with_expiration(CacheType.USERS)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.USERS)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
@CacheRoot.new(CacheType.LEVEL)
async def _():
return await LevelUser().all()
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
cache_data.data = cache_data.data or []
if not group_id:
return [
data
for data in cache_data.data
if data.user_id == user_id and not data.group_id
]
else:
return [
data
for data in cache_data.data
if data.user_id == user_id and data.group_id == group_id
]
@CacheRoot.new(CacheType.BAN)
async def _():
return await BanConsole.all()
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
if user_id:
return (
[
data
for data in cache_data.data
if data.user_id == user_id and data.group_id == group_id
]
if group_id
else [
data
for data in cache_data.data
if data.user_id == user_id and not data.group_id
]
)
if group_id:
return [
data
for data in cache_data.data
if not data.user_id and data.group_id == group_id
]
return None

View File

@ -107,7 +107,6 @@ class GroupConsole(Model):
return group return group
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def get_or_create( async def get_or_create(
cls, cls,
defaults: dict | None = None, defaults: dict | None = None,
@ -132,6 +131,9 @@ class GroupConsole(Model):
await group.save( await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"] using_db=using_db, update_fields=["block_plugin", "block_task"]
) )
if is_create:
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
await cache.update(group.group_id, group)
return group, is_create return group, is_create
@classmethod @classmethod

View File

@ -1,9 +1,11 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
import inspect
import time import time
from typing import Any, ClassVar, Generic, TypeVar, cast from typing import Any, ClassVar, Generic, TypeVar, cast
from nonebot.utils import is_coroutine_callable from nonebot.utils import is_coroutine_callable
from nonebot_plugin_apscheduler import scheduler
from pydantic import BaseModel from pydantic import BaseModel
from zhenxun.services.log import logger from zhenxun.services.log import logger
@ -13,18 +15,43 @@ __all__ = ["Cache", "CacheData", "CacheRoot"]
T = TypeVar("T") T = TypeVar("T")
class DbCacheException(Exception):
def __init__(self, info: str):
self.info = info
def __repr__(self) -> str:
return super().__repr__()
def __str__(self) -> str:
return self.info
def validate_name(func: Callable):
"""
装饰器验证 name 是否存在于 CacheManage._data
"""
def wrapper(self, name: str, *args, **kwargs):
_name = name.upper()
if _name not in CacheManage._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
return func(self, _name, *args, **kwargs)
return wrapper
class CacheGetter(BaseModel, Generic[T]): class CacheGetter(BaseModel, Generic[T]):
get_func: Callable[..., Any] | None = None get_func: Callable[..., Any] | None = None
"""获取方法""" """获取方法"""
async def get(self, data: Any, *args, **kwargs) -> T: async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
"""获取缓存""" """获取缓存"""
processed_data = ( processed_data = (
await self.get_func(data, *args, **kwargs) await self.get_func(cache_data, *args, **kwargs)
if self.get_func and is_coroutine_callable(self.get_func) if self.get_func and is_coroutine_callable(self.get_func)
else self.get_func(data, *args, **kwargs) else self.get_func(cache_data, *args, **kwargs)
if self.get_func if self.get_func
else data else cache_data.data
) )
return cast(T, processed_data) return cast(T, processed_data)
@ -38,10 +65,18 @@ class CacheData(BaseModel):
"""获取方法""" """获取方法"""
updater: Callable[..., Any] | None = None updater: Callable[..., Any] | None = None
"""更新单个方法""" """更新单个方法"""
with_refresh: Callable[..., Any] | None = None
"""刷新方法"""
with_expiration: Callable[..., Any] | None = None
"""缓存时间初始化方法"""
cleanup_expired: Callable[..., Any] | None = None
"""缓存过期方法"""
data: Any = None data: Any = None
"""缓存数据""" """缓存数据"""
expire: int expire: int
"""缓存过期时间""" """缓存过期时间"""
expire_data: dict[str, int | float] = {}
"""缓存过期数据时间记录"""
reload_time: float = time.time() reload_time: float = time.time()
"""更新时间""" """更新时间"""
reload_count: int = 0 reload_count: int = 0
@ -49,9 +84,12 @@ class CacheData(BaseModel):
async def get(self, *args, **kwargs) -> Any: async def get(self, *args, **kwargs) -> Any:
"""获取单个缓存""" """获取单个缓存"""
self.call_cleanup_expired() # 移除过期缓存
if not self.getter: if not self.getter:
return self.data return self.data
return await self.getter.get(self.data, *args, **kwargs) result = await self.getter.get(self, *args, **kwargs)
await self.call_with_expiration()
return result
async def update(self, key: str, value: Any = None, *args, **kwargs): async def update(self, key: str, value: Any = None, *args, **kwargs):
"""更新单个缓存""" """更新单个缓存"""
@ -64,23 +102,88 @@ class CacheData(BaseModel):
await self.updater(self.data, key, value, *args, **kwargs) await self.updater(self.data, key, value, *args, **kwargs)
else: else:
self.updater(self.data, key, value, *args, **kwargs) self.updater(self.data, key, value, *args, **kwargs)
logger.debug(
f"缓存类型 {self.name} 更新单个缓存 key: {key}value: {value}",
"CacheRoot",
)
self.expire_data[key] = time.time() + self.expire
else: else:
logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot") logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot")
async def refresh(self, *args, **kwargs):
"""刷新缓存,只刷新已缓存的数据"""
if not self.with_refresh:
return await self.reload(*args, **kwargs)
if self.data:
if is_coroutine_callable(self.with_refresh):
await self.with_refresh(self.data, *args, **kwargs)
else:
self.with_refresh(self.data, *args, **kwargs)
logger.debug(
f"缓存类型 {self.name} 刷新全局缓存,共刷新 {len(self.data)} 条数据",
"CacheRoot",
)
async def reload(self, *args, **kwargs): async def reload(self, *args, **kwargs):
"""更新缓存""" """更新全部缓存数据"""
self.data = ( if self.has_args():
await self.func(*args, **kwargs) self.data = (
if is_coroutine_callable(self.func) await self.func(*args, **kwargs)
else self.func(*args, **kwargs) if is_coroutine_callable(self.func)
) else self.func(*args, **kwargs)
)
else:
self.data = (
await self.func() if is_coroutine_callable(self.func) else self.func()
)
await self.call_with_expiration()
self.reload_time = time.time() self.reload_time = time.time()
self.reload_count += 1 self.reload_count += 1
logger.debug(f"缓存类型 {self.name} 更新全局缓存", "CacheRoot") logger.debug(
f"缓存类型 {self.name} 更新全局缓存,共更新 {len(self.data)} 条数据",
"CacheRoot",
)
async def check_expire(self): def call_cleanup_expired(self):
if time.time() - self.reload_time > self.expire or not self.reload_count: """清理过期缓存"""
await self.reload() if self.cleanup_expired:
if result := self.cleanup_expired(self):
logger.debug(
f"成功清理 {self.name} {len(result)} 条过期缓存", "CacheRoot"
)
async def call_with_expiration(self, is_force: bool = False):
"""缓存时间更新
参数:
is_force: 是否强制更新全部数据缓存时间.
"""
if self.with_expiration:
if is_force:
self.expire_data = {}
expiration_data = (
await self.with_expiration(self.data, self.expire_data, self.expire)
if is_coroutine_callable(self.with_expiration)
else self.with_expiration(self.data, self.expire_data, self.expire)
)
self.expire_data = {**self.expire_data, **expiration_data}
def has_args(self):
"""是否含有参数
返回:
bool: 是否含有参数
"""
sig = inspect.signature(self.func)
return any(
param.kind
in (
param.POSITIONAL_OR_KEYWORD,
param.POSITIONAL_ONLY,
param.VAR_POSITIONAL,
)
for param in sig.parameters.values()
)
class CacheManage: class CacheManage:
@ -88,18 +191,30 @@ class CacheManage:
异常: 异常:
ValueError: 数据名称重复 DbCacheException: 数据名称重复
ValueError: 数据不存在 DbCacheException: 数据不存在
""" """
_data: ClassVar[dict[str, CacheData]] = {} _data: ClassVar[dict[str, CacheData]] = {}
def start_check(self):
"""启动缓存检查"""
for cache_data in self._data.values():
if cache_data.cleanup_expired:
scheduler.add_job(
cache_data.call_cleanup_expired,
"interval",
seconds=cache_data.expire,
args=[],
id=f"CacheRoot-{cache_data.name}",
)
def new(self, name: str, expire: int = 60 * 10): def new(self, name: str, expire: int = 60 * 10):
def wrapper(func: Callable): def wrapper(func: Callable):
_name = name.upper() _name = name.upper()
if _name in self._data: if _name in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 已存在...") raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
self._data[_name] = CacheData(name=_name, func=func, expire=expire) self._data[_name] = CacheData(name=_name, func=func, expire=expire)
return wrapper return wrapper
@ -116,8 +231,12 @@ class CacheManage:
return result return result
finally: finally:
cache_data = self._data.get(name.upper()) cache_data = self._data.get(name.upper())
if cache_data: if cache_data and cache_data.with_refresh:
await cache_data.reload() if is_coroutine_callable(cache_data.with_refresh):
await cache_data.with_refresh(cache_data.data)
else:
cache_data.with_refresh(cache_data.data)
await cache_data.call_with_expiration(True)
logger.debug( logger.debug(
f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot" f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot"
) )
@ -126,48 +245,61 @@ class CacheManage:
return decorator return decorator
@validate_name
def updater(self, name: str): def updater(self, name: str):
def wrapper(func: Callable): def wrapper(func: Callable):
_name = name.upper() self._data[name.upper()].updater = func
if _name not in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
self._data[_name].updater = func
return wrapper return wrapper
@validate_name
def getter(self, name: str, result_model: type | None = None): def getter(self, name: str, result_model: type | None = None):
def wrapper(func: Callable): def wrapper(func: Callable):
_name = name.upper() self._data[name].getter = CacheGetter[result_model](get_func=func)
if _name not in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
self._data[_name].getter = CacheGetter[result_model](get_func=func)
return wrapper return wrapper
async def check_expire(self, name: str): @validate_name
def with_refresh(self, name: str):
def wrapper(func: Callable):
self._data[name.upper()].with_refresh = func
return wrapper
@validate_name
def with_expiration(self, name: str):
def wrapper(func: Callable[[Any, int], dict[str, float]]):
self._data[name.upper()].with_expiration = func
return wrapper
@validate_name
def cleanup_expired(self, name: str):
def wrapper(func: Callable[[CacheData], None]):
self._data[name.upper()].cleanup_expired = func
return wrapper
async def check_expire(self, name: str, *args, **kwargs):
name = name.upper() name = name.upper()
if self._data.get(name): if self._data.get(name) and (
if ( time.time() - self._data[name].reload_time > self._data[name].expire
time.time() - self._data[name].reload_time > self._data[name].expire or not self._data[name].reload_count
or not self._data[name].reload_count ):
): await self._data[name].reload(*args, **kwargs)
await self._data[name].reload()
async def get_cache_data(self, name: str): async def get_cache_data(self, name: str):
if cache := await self.get_cache(name): return cache.data if (cache := await self.get_cache(name)) else None
return cache.data
return None
async def get_cache(self, name: str) -> CacheData | None: async def get_cache(self, name: str, *args, **kwargs) -> CacheData | None:
name = name.upper() name = name.upper()
cache = self._data.get(name) if cache := self._data.get(name):
if cache: # await self.check_expire(name, *args, **kwargs)
await self.check_expire(name)
return cache return cache
return None return None
async def get(self, name: str, *args, **kwargs): async def get(self, name: str, *args, **kwargs):
cache = await self.get_cache(name.upper()) cache = await self.get_cache(name.upper(), *args, **kwargs)
if cache: if cache:
return await cache.get(*args, **kwargs) if cache.getter else cache.data return await cache.get(*args, **kwargs) if cache.getter else cache.data
return None return None
@ -175,7 +307,7 @@ class CacheManage:
async def reload(self, name: str, *args, **kwargs): async def reload(self, name: str, *args, **kwargs):
cache = await self.get_cache(name.upper()) cache = await self.get_cache(name.upper())
if cache: if cache:
await cache.reload(*args, **kwargs) await cache.refresh(*args, **kwargs)
async def update(self, name: str, key: str, value: Any, *args, **kwargs): async def update(self, name: str, key: str, value: Any, *args, **kwargs):
cache = await self.get_cache(name.upper()) cache = await self.get_cache(name.upper())