mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 添加增量缓存与缓存过期
This commit is contained in:
parent
76cafea7d4
commit
6eb9bb510a
@ -4,8 +4,8 @@ from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
|
||||
@ -2,8 +2,8 @@ from nonebot.exception import IgnoredException
|
||||
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from tortoise.exceptions import MultipleObjectsReturned
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType, PluginType
|
||||
@ -27,7 +28,8 @@ _flmt = FreqLimiter(300)
|
||||
|
||||
async def is_ban(user_id: str | None, group_id: str | None):
|
||||
cache = Cache[list[BanConsole]](CacheType.BAN)
|
||||
return await cache.get(user_id, group_id)
|
||||
result = await cache.get(user_id, group_id) or await cache.get(user_id)
|
||||
return result and result[0].ban_time > 0
|
||||
|
||||
|
||||
# 检查是否被ban
|
||||
@ -80,8 +82,18 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo):
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
if time != -1 and ban_result and _flmt.check(user_id):
|
||||
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(
|
||||
matcher.plugin_name
|
||||
)
|
||||
if (
|
||||
db_plugin
|
||||
# and not db_plugin.ignore_prompt
|
||||
and time != -1
|
||||
and ban_result
|
||||
and _flmt.check(user_id)
|
||||
):
|
||||
_flmt.start_cd(user_id)
|
||||
logger.debug(f"ban检测发送插件: {matcher.plugin_name}")
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
|
||||
@ -1,22 +1,21 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import sys
|
||||
|
||||
import nonebot
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.cache import DbCacheException
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
try:
|
||||
from .__init_cache import driver
|
||||
except DbCacheException as e:
|
||||
raise SystemError(f"ERROR:{e}")
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
@ -49,150 +48,3 @@ async def _(bot: Bot):
|
||||
f"更新Bot: {bot.self_id} 的群认证完成,共创建 {len(create_list)} 条数据,"
|
||||
f"共修改 {len(update_id)} 条数据..."
|
||||
)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.PLUGINS)
|
||||
async def _(data: dict[str, PluginInfo] = {}, key: str | None = None):
|
||||
if data and key:
|
||||
if plugin := await PluginInfo.get_plugin(module=key):
|
||||
data[key] = plugin
|
||||
else:
|
||||
data_list = await PluginInfo.get_plugins()
|
||||
return {p.module: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.PLUGINS)
|
||||
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif plugin := await PluginInfo.get_plugin(module=key):
|
||||
data[key] = plugin
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
||||
async def _(data: dict[str, PluginInfo], module: str):
|
||||
result = data.get(module, None)
|
||||
if not result:
|
||||
result = await PluginInfo.get_plugin(module=module)
|
||||
if result:
|
||||
data[module] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.GROUPS)
|
||||
async def _():
|
||||
data_list = await GroupConsole.all()
|
||||
return {p.group_id: p for p in data_list if not p.channel_id}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.GROUPS)
|
||||
async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif group := await GroupConsole.get_group(group_id=key):
|
||||
data[key] = group
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||
async def _(data: dict[str, GroupConsole], group_id: str):
|
||||
result = data.get(group_id, None)
|
||||
if not result:
|
||||
result = await GroupConsole.get_group(group_id=group_id)
|
||||
if result:
|
||||
data[group_id] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BOT)
|
||||
async def _():
|
||||
data_list = await BotConsole.all()
|
||||
return {p.bot_id: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.BOT)
|
||||
async def _(data: dict[str, BotConsole], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif bot := await BotConsole.get_or_none(bot_id=key):
|
||||
data[key] = bot
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||
async def _(data: dict[str, BotConsole], bot_id: str):
|
||||
result = data.get(bot_id, None)
|
||||
if not result:
|
||||
result = await BotConsole.get_or_none(bot_id=bot_id)
|
||||
if result:
|
||||
data[bot_id] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.USERS)
|
||||
async def _():
|
||||
data_list = await UserConsole.all()
|
||||
return {p.user_id: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.USERS)
|
||||
async def _(data: dict[str, UserConsole], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif user := await UserConsole.get_user(user_id=key):
|
||||
data[key] = user
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
||||
async def _(data: dict[str, UserConsole], user_id: str):
|
||||
result = data.get(user_id, None)
|
||||
if not result:
|
||||
result = await UserConsole.get_user(user_id=user_id)
|
||||
if result:
|
||||
data[user_id] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.LEVEL)
|
||||
async def _():
|
||||
return await LevelUser().all()
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
||||
def _(data_list: list[LevelUser], user_id: str, group_id: str | None = None):
|
||||
if not group_id:
|
||||
return [
|
||||
data for data in data_list if data.user_id == user_id and not data.group_id
|
||||
]
|
||||
else:
|
||||
return [
|
||||
data
|
||||
for data in data_list
|
||||
if data.user_id == user_id and data.group_id == group_id
|
||||
]
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BAN)
|
||||
async def _():
|
||||
return await BanConsole.all()
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||
def _(data_list: list[BanConsole], user_id: str | None, group_id: str | None = None):
|
||||
if user_id:
|
||||
return (
|
||||
[
|
||||
data
|
||||
for data in data_list
|
||||
if data.user_id == user_id and data.group_id == group_id
|
||||
]
|
||||
if group_id
|
||||
else [
|
||||
data
|
||||
for data in data_list
|
||||
if data.user_id == user_id and not data.group_id
|
||||
]
|
||||
)
|
||||
if group_id:
|
||||
return [
|
||||
data for data in data_list if not data.user_id and data.group_id == group_id
|
||||
]
|
||||
return None
|
||||
|
||||
276
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
276
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
@ -0,0 +1,276 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import nonebot
|
||||
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheData, CacheRoot
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
"""开启cache检测"""
|
||||
CacheRoot.start_check()
|
||||
|
||||
|
||||
def default_cleanup_expired(cache_data: CacheData) -> list[str]:
|
||||
"""默认清理过期cache方法"""
|
||||
if not cache_data.data:
|
||||
return []
|
||||
now = time.time()
|
||||
expire_key = []
|
||||
for k, t in list(cache_data.expire_data.items()):
|
||||
if t < now:
|
||||
expire_key.append(k)
|
||||
cache_data.expire_data.pop(k)
|
||||
if expire_key:
|
||||
cache_data.data = {
|
||||
k: v for k, v in cache_data.data.items() if k not in expire_key
|
||||
}
|
||||
return expire_key
|
||||
|
||||
|
||||
def default_with_expiration(
|
||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
||||
):
|
||||
"""默认更新期时间cache方法"""
|
||||
keys = {k for k in data if k not in expire_data}
|
||||
return {k: time.time() + expire for k in keys} if keys else {}
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.PLUGINS)
|
||||
async def _():
|
||||
data_list = await PluginInfo.get_plugins()
|
||||
return {p.module: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.PLUGINS)
|
||||
async def _():
|
||||
data_list = await PluginInfo.get_plugins()
|
||||
return {p.module: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.PLUGINS)
|
||||
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif plugin := await PluginInfo.get_plugin(module=key):
|
||||
data[key] = plugin
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
||||
async def _(cache_data: CacheData, module: str):
|
||||
cache_data.data = cache_data.data or {}
|
||||
result = cache_data.data.get(module, None)
|
||||
if not result:
|
||||
result = await PluginInfo.get_plugin(module=module)
|
||||
if result:
|
||||
cache_data.data[module] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
||||
async def _(data: dict[str, PluginInfo]):
|
||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True)
|
||||
for plugin in plugins:
|
||||
data[plugin.module] = plugin
|
||||
|
||||
|
||||
@CacheRoot.with_expiration(CacheType.PLUGINS)
|
||||
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.PLUGINS)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired(cache_data)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.GROUPS)
|
||||
async def _():
|
||||
data_list = await GroupConsole.all()
|
||||
return {p.group_id: p for p in data_list if not p.channel_id}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.GROUPS)
|
||||
async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif group := await GroupConsole.get_group(group_id=key):
|
||||
data[key] = group
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||
async def _(data: dict[str, GroupConsole] | None, group_id: str):
|
||||
if not data:
|
||||
data = {}
|
||||
result = data.get(group_id, None)
|
||||
if not result:
|
||||
result = await GroupConsole.get_group(group_id=group_id)
|
||||
if result:
|
||||
data[group_id] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.GROUPS)
|
||||
async def _(data: dict[str, GroupConsole]):
|
||||
groups = await GroupConsole.filter(
|
||||
group_id__in=data.keys(), channel_id__isnull=True, load_status=True
|
||||
)
|
||||
for group in groups:
|
||||
data[group.group_id] = group
|
||||
|
||||
|
||||
@CacheRoot.with_expiration(CacheType.GROUPS)
|
||||
def _(data: dict[str, GroupConsole], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.GROUPS)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired(cache_data)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BOT)
|
||||
async def _():
|
||||
data_list = await BotConsole.all()
|
||||
return {p.bot_id: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.BOT)
|
||||
async def _(data: dict[str, BotConsole], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif bot := await BotConsole.get_or_none(bot_id=key):
|
||||
data[key] = bot
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||
async def _(data: dict[str, BotConsole] | None, bot_id: str):
|
||||
if not data:
|
||||
data = {}
|
||||
result = data.get(bot_id, None)
|
||||
if not result:
|
||||
result = await BotConsole.get_or_none(bot_id=bot_id)
|
||||
if result:
|
||||
data[bot_id] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.BOT)
|
||||
async def _(data: dict[str, BotConsole]):
|
||||
bots = await BotConsole.filter(bot_id__in=data.keys())
|
||||
for bot in bots:
|
||||
data[bot.bot_id] = bot
|
||||
|
||||
|
||||
@CacheRoot.with_expiration(CacheType.BOT)
|
||||
def _(data: dict[str, BotConsole], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.BOT)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired(cache_data)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.USERS)
|
||||
async def _():
|
||||
data_list = await UserConsole.all()
|
||||
return {p.user_id: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.USERS)
|
||||
async def _(data: dict[str, UserConsole], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif user := await UserConsole.get_user(user_id=key):
|
||||
data[key] = user
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
||||
async def _(cache_data: CacheData, user_id: str):
|
||||
cache_data.data = cache_data.data or {}
|
||||
result = cache_data.data.get(user_id, None)
|
||||
if not result:
|
||||
result = await UserConsole.get_user(user_id=user_id)
|
||||
if result:
|
||||
cache_data.data[user_id] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.USERS)
|
||||
async def _(data: dict[str, UserConsole]):
|
||||
users = await UserConsole.filter(user_id__in=data.keys())
|
||||
for user in users:
|
||||
data[user.user_id] = user
|
||||
|
||||
|
||||
@CacheRoot.with_expiration(CacheType.USERS)
|
||||
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.USERS)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired(cache_data)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.LEVEL)
|
||||
async def _():
|
||||
return await LevelUser().all()
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
||||
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
||||
cache_data.data = cache_data.data or []
|
||||
if not group_id:
|
||||
return [
|
||||
data
|
||||
for data in cache_data.data
|
||||
if data.user_id == user_id and not data.group_id
|
||||
]
|
||||
else:
|
||||
return [
|
||||
data
|
||||
for data in cache_data.data
|
||||
if data.user_id == user_id and data.group_id == group_id
|
||||
]
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BAN)
|
||||
async def _():
|
||||
return await BanConsole.all()
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||
def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
||||
if user_id:
|
||||
return (
|
||||
[
|
||||
data
|
||||
for data in cache_data.data
|
||||
if data.user_id == user_id and data.group_id == group_id
|
||||
]
|
||||
if group_id
|
||||
else [
|
||||
data
|
||||
for data in cache_data.data
|
||||
if data.user_id == user_id and not data.group_id
|
||||
]
|
||||
)
|
||||
if group_id:
|
||||
return [
|
||||
data
|
||||
for data in cache_data.data
|
||||
if not data.user_id and data.group_id == group_id
|
||||
]
|
||||
return None
|
||||
@ -107,7 +107,6 @@ class GroupConsole(Model):
|
||||
return group
|
||||
|
||||
@classmethod
|
||||
@CacheRoot.listener(CacheType.GROUPS)
|
||||
async def get_or_create(
|
||||
cls,
|
||||
defaults: dict | None = None,
|
||||
@ -132,6 +131,9 @@ class GroupConsole(Model):
|
||||
await group.save(
|
||||
using_db=using_db, update_fields=["block_plugin", "block_task"]
|
||||
)
|
||||
if is_create:
|
||||
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
|
||||
await cache.update(group.group_id, group)
|
||||
return group, is_create
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import time
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
@ -13,18 +15,43 @@ __all__ = ["Cache", "CacheData", "CacheRoot"]
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DbCacheException(Exception):
|
||||
def __init__(self, info: str):
|
||||
self.info = info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return super().__repr__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
|
||||
def validate_name(func: Callable):
|
||||
"""
|
||||
装饰器:验证 name 是否存在于 CacheManage._data 中。
|
||||
"""
|
||||
|
||||
def wrapper(self, name: str, *args, **kwargs):
|
||||
_name = name.upper()
|
||||
if _name not in CacheManage._data:
|
||||
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
|
||||
return func(self, _name, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CacheGetter(BaseModel, Generic[T]):
|
||||
get_func: Callable[..., Any] | None = None
|
||||
"""获取方法"""
|
||||
|
||||
async def get(self, data: Any, *args, **kwargs) -> T:
|
||||
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
|
||||
"""获取缓存"""
|
||||
processed_data = (
|
||||
await self.get_func(data, *args, **kwargs)
|
||||
await self.get_func(cache_data, *args, **kwargs)
|
||||
if self.get_func and is_coroutine_callable(self.get_func)
|
||||
else self.get_func(data, *args, **kwargs)
|
||||
else self.get_func(cache_data, *args, **kwargs)
|
||||
if self.get_func
|
||||
else data
|
||||
else cache_data.data
|
||||
)
|
||||
return cast(T, processed_data)
|
||||
|
||||
@ -38,10 +65,18 @@ class CacheData(BaseModel):
|
||||
"""获取方法"""
|
||||
updater: Callable[..., Any] | None = None
|
||||
"""更新单个方法"""
|
||||
with_refresh: Callable[..., Any] | None = None
|
||||
"""刷新方法"""
|
||||
with_expiration: Callable[..., Any] | None = None
|
||||
"""缓存时间初始化方法"""
|
||||
cleanup_expired: Callable[..., Any] | None = None
|
||||
"""缓存过期方法"""
|
||||
data: Any = None
|
||||
"""缓存数据"""
|
||||
expire: int
|
||||
"""缓存过期时间"""
|
||||
expire_data: dict[str, int | float] = {}
|
||||
"""缓存过期数据时间记录"""
|
||||
reload_time: float = time.time()
|
||||
"""更新时间"""
|
||||
reload_count: int = 0
|
||||
@ -49,9 +84,12 @@ class CacheData(BaseModel):
|
||||
|
||||
async def get(self, *args, **kwargs) -> Any:
|
||||
"""获取单个缓存"""
|
||||
self.call_cleanup_expired() # 移除过期缓存
|
||||
if not self.getter:
|
||||
return self.data
|
||||
return await self.getter.get(self.data, *args, **kwargs)
|
||||
result = await self.getter.get(self, *args, **kwargs)
|
||||
await self.call_with_expiration()
|
||||
return result
|
||||
|
||||
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
||||
"""更新单个缓存"""
|
||||
@ -64,23 +102,88 @@ class CacheData(BaseModel):
|
||||
await self.updater(self.data, key, value, *args, **kwargs)
|
||||
else:
|
||||
self.updater(self.data, key, value, *args, **kwargs)
|
||||
logger.debug(
|
||||
f"缓存类型 {self.name} 更新单个缓存 key: {key},value: {value}",
|
||||
"CacheRoot",
|
||||
)
|
||||
self.expire_data[key] = time.time() + self.expire
|
||||
else:
|
||||
logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot")
|
||||
|
||||
async def refresh(self, *args, **kwargs):
|
||||
"""刷新缓存,只刷新已缓存的数据"""
|
||||
if not self.with_refresh:
|
||||
return await self.reload(*args, **kwargs)
|
||||
if self.data:
|
||||
if is_coroutine_callable(self.with_refresh):
|
||||
await self.with_refresh(self.data, *args, **kwargs)
|
||||
else:
|
||||
self.with_refresh(self.data, *args, **kwargs)
|
||||
logger.debug(
|
||||
f"缓存类型 {self.name} 刷新全局缓存,共刷新 {len(self.data)} 条数据",
|
||||
"CacheRoot",
|
||||
)
|
||||
|
||||
async def reload(self, *args, **kwargs):
|
||||
"""更新缓存"""
|
||||
"""更新全部缓存数据"""
|
||||
if self.has_args():
|
||||
self.data = (
|
||||
await self.func(*args, **kwargs)
|
||||
if is_coroutine_callable(self.func)
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
else:
|
||||
self.data = (
|
||||
await self.func() if is_coroutine_callable(self.func) else self.func()
|
||||
)
|
||||
await self.call_with_expiration()
|
||||
self.reload_time = time.time()
|
||||
self.reload_count += 1
|
||||
logger.debug(f"缓存类型 {self.name} 更新全局缓存", "CacheRoot")
|
||||
logger.debug(
|
||||
f"缓存类型 {self.name} 更新全局缓存,共更新 {len(self.data)} 条数据",
|
||||
"CacheRoot",
|
||||
)
|
||||
|
||||
async def check_expire(self):
|
||||
if time.time() - self.reload_time > self.expire or not self.reload_count:
|
||||
await self.reload()
|
||||
def call_cleanup_expired(self):
|
||||
"""清理过期缓存"""
|
||||
if self.cleanup_expired:
|
||||
if result := self.cleanup_expired(self):
|
||||
logger.debug(
|
||||
f"成功清理 {self.name} {len(result)} 条过期缓存", "CacheRoot"
|
||||
)
|
||||
|
||||
async def call_with_expiration(self, is_force: bool = False):
|
||||
"""缓存时间更新
|
||||
|
||||
参数:
|
||||
is_force: 是否强制更新全部数据缓存时间.
|
||||
"""
|
||||
if self.with_expiration:
|
||||
if is_force:
|
||||
self.expire_data = {}
|
||||
expiration_data = (
|
||||
await self.with_expiration(self.data, self.expire_data, self.expire)
|
||||
if is_coroutine_callable(self.with_expiration)
|
||||
else self.with_expiration(self.data, self.expire_data, self.expire)
|
||||
)
|
||||
self.expire_data = {**self.expire_data, **expiration_data}
|
||||
|
||||
def has_args(self):
|
||||
"""是否含有参数
|
||||
|
||||
返回:
|
||||
bool: 是否含有参数
|
||||
"""
|
||||
sig = inspect.signature(self.func)
|
||||
return any(
|
||||
param.kind
|
||||
in (
|
||||
param.POSITIONAL_OR_KEYWORD,
|
||||
param.POSITIONAL_ONLY,
|
||||
param.VAR_POSITIONAL,
|
||||
)
|
||||
for param in sig.parameters.values()
|
||||
)
|
||||
|
||||
|
||||
class CacheManage:
|
||||
@ -88,18 +191,30 @@ class CacheManage:
|
||||
|
||||
|
||||
异常:
|
||||
ValueError: 数据名称重复
|
||||
ValueError: 数据不存在
|
||||
DbCacheException: 数据名称重复
|
||||
DbCacheException: 数据不存在
|
||||
|
||||
"""
|
||||
|
||||
_data: ClassVar[dict[str, CacheData]] = {}
|
||||
|
||||
def start_check(self):
|
||||
"""启动缓存检查"""
|
||||
for cache_data in self._data.values():
|
||||
if cache_data.cleanup_expired:
|
||||
scheduler.add_job(
|
||||
cache_data.call_cleanup_expired,
|
||||
"interval",
|
||||
seconds=cache_data.expire,
|
||||
args=[],
|
||||
id=f"CacheRoot-{cache_data.name}",
|
||||
)
|
||||
|
||||
def new(self, name: str, expire: int = 60 * 10):
|
||||
def wrapper(func: Callable):
|
||||
_name = name.upper()
|
||||
if _name in self._data:
|
||||
raise ValueError(f"DbCache 缓存数据 {name} 已存在...")
|
||||
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
|
||||
self._data[_name] = CacheData(name=_name, func=func, expire=expire)
|
||||
|
||||
return wrapper
|
||||
@ -116,8 +231,12 @@ class CacheManage:
|
||||
return result
|
||||
finally:
|
||||
cache_data = self._data.get(name.upper())
|
||||
if cache_data:
|
||||
await cache_data.reload()
|
||||
if cache_data and cache_data.with_refresh:
|
||||
if is_coroutine_callable(cache_data.with_refresh):
|
||||
await cache_data.with_refresh(cache_data.data)
|
||||
else:
|
||||
cache_data.with_refresh(cache_data.data)
|
||||
await cache_data.call_with_expiration(True)
|
||||
logger.debug(
|
||||
f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot"
|
||||
)
|
||||
@ -126,48 +245,61 @@ class CacheManage:
|
||||
|
||||
return decorator
|
||||
|
||||
@validate_name
|
||||
def updater(self, name: str):
|
||||
def wrapper(func: Callable):
|
||||
_name = name.upper()
|
||||
if _name not in self._data:
|
||||
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
|
||||
self._data[_name].updater = func
|
||||
self._data[name.upper()].updater = func
|
||||
|
||||
return wrapper
|
||||
|
||||
@validate_name
|
||||
def getter(self, name: str, result_model: type | None = None):
|
||||
def wrapper(func: Callable):
|
||||
_name = name.upper()
|
||||
if _name not in self._data:
|
||||
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
|
||||
self._data[_name].getter = CacheGetter[result_model](get_func=func)
|
||||
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
||||
|
||||
return wrapper
|
||||
|
||||
async def check_expire(self, name: str):
|
||||
@validate_name
|
||||
def with_refresh(self, name: str):
|
||||
def wrapper(func: Callable):
|
||||
self._data[name.upper()].with_refresh = func
|
||||
|
||||
return wrapper
|
||||
|
||||
@validate_name
|
||||
def with_expiration(self, name: str):
|
||||
def wrapper(func: Callable[[Any, int], dict[str, float]]):
|
||||
self._data[name.upper()].with_expiration = func
|
||||
|
||||
return wrapper
|
||||
|
||||
@validate_name
|
||||
def cleanup_expired(self, name: str):
|
||||
def wrapper(func: Callable[[CacheData], None]):
|
||||
self._data[name.upper()].cleanup_expired = func
|
||||
|
||||
return wrapper
|
||||
|
||||
async def check_expire(self, name: str, *args, **kwargs):
|
||||
name = name.upper()
|
||||
if self._data.get(name):
|
||||
if (
|
||||
if self._data.get(name) and (
|
||||
time.time() - self._data[name].reload_time > self._data[name].expire
|
||||
or not self._data[name].reload_count
|
||||
):
|
||||
await self._data[name].reload()
|
||||
await self._data[name].reload(*args, **kwargs)
|
||||
|
||||
async def get_cache_data(self, name: str):
|
||||
if cache := await self.get_cache(name):
|
||||
return cache.data
|
||||
return None
|
||||
return cache.data if (cache := await self.get_cache(name)) else None
|
||||
|
||||
async def get_cache(self, name: str) -> CacheData | None:
|
||||
async def get_cache(self, name: str, *args, **kwargs) -> CacheData | None:
|
||||
name = name.upper()
|
||||
cache = self._data.get(name)
|
||||
if cache:
|
||||
await self.check_expire(name)
|
||||
if cache := self._data.get(name):
|
||||
# await self.check_expire(name, *args, **kwargs)
|
||||
return cache
|
||||
return None
|
||||
|
||||
async def get(self, name: str, *args, **kwargs):
|
||||
cache = await self.get_cache(name.upper())
|
||||
cache = await self.get_cache(name.upper(), *args, **kwargs)
|
||||
if cache:
|
||||
return await cache.get(*args, **kwargs) if cache.getter else cache.data
|
||||
return None
|
||||
@ -175,7 +307,7 @@ class CacheManage:
|
||||
async def reload(self, name: str, *args, **kwargs):
|
||||
cache = await self.get_cache(name.upper())
|
||||
if cache:
|
||||
await cache.reload(*args, **kwargs)
|
||||
await cache.refresh(*args, **kwargs)
|
||||
|
||||
async def update(self, name: str, key: str, value: Any, *args, **kwargs):
|
||||
cache = await self.get_cache(name.upper())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user