mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
🐛 尝试迁移至aiocache
This commit is contained in:
parent
83942bae26
commit
f3d5b77bdc
@ -11,7 +11,7 @@ from zhenxun.utils.platform import PlatformUtils
|
|||||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .__init_cache import driver
|
from . import __init_cache
|
||||||
except DbCacheException as e:
|
except DbCacheException as e:
|
||||||
raise SystemError(f"ERROR:{e}")
|
raise SystemError(f"ERROR:{e}")
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import nonebot
|
|
||||||
|
|
||||||
from zhenxun.models.ban_console import BanConsole
|
from zhenxun.models.ban_console import BanConsole
|
||||||
from zhenxun.models.bot_console import BotConsole
|
from zhenxun.models.bot_console import BotConsole
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
@ -11,77 +8,20 @@ from zhenxun.models.plugin_info import PluginInfo
|
|||||||
from zhenxun.models.plugin_limit import PluginLimit
|
from zhenxun.models.plugin_limit import PluginLimit
|
||||||
from zhenxun.models.user_console import UserConsole
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.cache import CacheData, CacheRoot
|
from zhenxun.services.cache import CacheData, CacheRoot
|
||||||
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
driver = nonebot.get_driver()
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def _():
|
|
||||||
"""开启cache检测"""
|
|
||||||
CacheRoot.start_check()
|
|
||||||
|
|
||||||
|
|
||||||
def default_cleanup_expired(cache_data: CacheData) -> list[str]:
|
|
||||||
"""默认清理过期cache方法"""
|
|
||||||
if not cache_data.data:
|
|
||||||
return []
|
|
||||||
now = time.time()
|
|
||||||
expire_key = []
|
|
||||||
for k, t in list(cache_data.expire_data.items()):
|
|
||||||
if t < now:
|
|
||||||
expire_key.append(k)
|
|
||||||
cache_data.expire_data.pop(k)
|
|
||||||
if expire_key:
|
|
||||||
cache_data.data = {
|
|
||||||
k: v for k, v in cache_data.data.items() if k not in expire_key
|
|
||||||
}
|
|
||||||
return expire_key
|
|
||||||
|
|
||||||
|
|
||||||
def default_cleanup_expired_1(cache_data: CacheData) -> list[str]:
|
|
||||||
"""默认清理列表过期cache方法"""
|
|
||||||
if not cache_data.data:
|
|
||||||
return []
|
|
||||||
now = time.time()
|
|
||||||
expire_key = []
|
|
||||||
for k, t in list(cache_data.expire_data.items()):
|
|
||||||
if t < now:
|
|
||||||
expire_key.append(k)
|
|
||||||
cache_data.expire_data.pop(k)
|
|
||||||
if expire_key:
|
|
||||||
cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key]
|
|
||||||
return expire_key
|
|
||||||
|
|
||||||
|
|
||||||
def default_with_expiration(
|
|
||||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
|
||||||
):
|
|
||||||
"""默认更新过期时间cache方法"""
|
|
||||||
if not data:
|
|
||||||
return {}
|
|
||||||
keys = {k for k in data if k not in expire_data}
|
|
||||||
return {k: time.time() + expire for k in keys} if keys else {}
|
|
||||||
|
|
||||||
|
|
||||||
def default_with_expiration_1(
|
|
||||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
|
||||||
):
|
|
||||||
"""默认更新过期时间cache方法"""
|
|
||||||
if not data:
|
|
||||||
return {}
|
|
||||||
keys = {repr(k) for k in data if repr(k) not in expire_data}
|
|
||||||
return {k: time.time() + expire for k in keys} if keys else {}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.PLUGINS)
|
@CacheRoot.new(CacheType.PLUGINS)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化插件缓存"""
|
||||||
data_list = await PluginInfo.get_plugins()
|
data_list = await PluginInfo.get_plugins()
|
||||||
return {p.module: p for p in data_list}
|
return {p.module: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.PLUGINS)
|
@CacheRoot.updater(CacheType.PLUGINS)
|
||||||
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
||||||
|
"""更新插件缓存"""
|
||||||
if value:
|
if value:
|
||||||
data[key] = value
|
data[key] = value
|
||||||
elif plugin := await PluginInfo.get_plugin(module=key):
|
elif plugin := await PluginInfo.get_plugin(module=key):
|
||||||
@ -90,42 +30,35 @@ async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
|||||||
|
|
||||||
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
||||||
async def _(cache_data: CacheData, module: str):
|
async def _(cache_data: CacheData, module: str):
|
||||||
cache_data.data = cache_data.data or {}
|
"""获取插件缓存"""
|
||||||
result = cache_data.data.get(module, None)
|
data = await cache_data.get_data() or {}
|
||||||
if not result:
|
if module not in data:
|
||||||
result = await PluginInfo.get_plugin(module=module)
|
if plugin := await PluginInfo.get_plugin(module=module):
|
||||||
if result:
|
data[module] = plugin
|
||||||
cache_data.data[module] = result
|
await cache_data.set_data(data)
|
||||||
return result
|
logger.debug(f"插件 {module} 数据已设置到缓存")
|
||||||
|
return data.get(module)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
||||||
async def _(data: dict[str, PluginInfo] | None):
|
async def _(data: dict[str, PluginInfo] | None):
|
||||||
|
"""刷新插件缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
||||||
for plugin in plugins:
|
data.update({p.module: p for p in plugins})
|
||||||
data[plugin.module] = plugin
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_expiration(CacheType.PLUGINS)
|
|
||||||
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.PLUGINS)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired(cache_data)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.GROUPS)
|
@CacheRoot.new(CacheType.GROUPS)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化群组缓存"""
|
||||||
data_list = await GroupConsole.all()
|
data_list = await GroupConsole.all()
|
||||||
return {p.group_id: p for p in data_list if not p.channel_id}
|
return {p.group_id: p for p in data_list if not p.channel_id}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.GROUPS)
|
@CacheRoot.updater(CacheType.GROUPS)
|
||||||
async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
||||||
|
"""更新群组缓存"""
|
||||||
if value:
|
if value:
|
||||||
data[key] = value
|
data[key] = value
|
||||||
elif group := await GroupConsole.get_group(group_id=key):
|
elif group := await GroupConsole.get_group(group_id=key):
|
||||||
@ -134,44 +67,36 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
|||||||
|
|
||||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||||
async def _(cache_data: CacheData, group_id: str):
|
async def _(cache_data: CacheData, group_id: str):
|
||||||
cache_data.data = cache_data.data or {}
|
"""获取群组缓存"""
|
||||||
result = cache_data.data.get(group_id, None)
|
data = await cache_data.get_data() or {}
|
||||||
if not result:
|
if group_id not in data:
|
||||||
result = await GroupConsole.get_group(group_id=group_id)
|
if group := await GroupConsole.get_group(group_id=group_id):
|
||||||
if result:
|
data[group_id] = group
|
||||||
cache_data.data[group_id] = result
|
await cache_data.set_data(data)
|
||||||
return result
|
return data.get(group_id)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.GROUPS)
|
@CacheRoot.with_refresh(CacheType.GROUPS)
|
||||||
async def _(data: dict[str, GroupConsole] | None):
|
async def _(data: dict[str, GroupConsole] | None):
|
||||||
|
"""刷新群组缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
groups = await GroupConsole.filter(
|
groups = await GroupConsole.filter(
|
||||||
group_id__in=data.keys(), channel_id__isnull=True
|
group_id__in=data.keys(), channel_id__isnull=True
|
||||||
).all()
|
).all()
|
||||||
for group in groups:
|
data.update({g.group_id: g for g in groups})
|
||||||
data[group.group_id] = group
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_expiration(CacheType.GROUPS)
|
|
||||||
def _(data: dict[str, GroupConsole], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.GROUPS)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired(cache_data)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BOT)
|
@CacheRoot.new(CacheType.BOT)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化机器人缓存"""
|
||||||
data_list = await BotConsole.all()
|
data_list = await BotConsole.all()
|
||||||
return {p.bot_id: p for p in data_list}
|
return {p.bot_id: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.BOT)
|
@CacheRoot.updater(CacheType.BOT)
|
||||||
async def _(data: dict[str, BotConsole], key: str, value: Any):
|
async def _(data: dict[str, BotConsole], key: str, value: Any):
|
||||||
|
"""更新机器人缓存"""
|
||||||
if value:
|
if value:
|
||||||
data[key] = value
|
data[key] = value
|
||||||
elif bot := await BotConsole.get_or_none(bot_id=key):
|
elif bot := await BotConsole.get_or_none(bot_id=key):
|
||||||
@ -180,42 +105,34 @@ async def _(data: dict[str, BotConsole], key: str, value: Any):
|
|||||||
|
|
||||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||||
async def _(cache_data: CacheData, bot_id: str):
|
async def _(cache_data: CacheData, bot_id: str):
|
||||||
cache_data.data = cache_data.data or {}
|
"""获取机器人缓存"""
|
||||||
result = cache_data.data.get(bot_id, None)
|
data = await cache_data.get_data() or {}
|
||||||
if not result:
|
if bot_id not in data:
|
||||||
result = await BotConsole.get_or_none(bot_id=bot_id)
|
if bot := await BotConsole.get_or_none(bot_id=bot_id):
|
||||||
if result:
|
data[bot_id] = bot
|
||||||
cache_data.data[bot_id] = result
|
await cache_data.set_data(data)
|
||||||
return result
|
return data.get(bot_id)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.BOT)
|
@CacheRoot.with_refresh(CacheType.BOT)
|
||||||
async def _(data: dict[str, BotConsole] | None):
|
async def _(data: dict[str, BotConsole] | None):
|
||||||
|
"""刷新机器人缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
||||||
for bot in bots:
|
data.update({b.bot_id: b for b in bots})
|
||||||
data[bot.bot_id] = bot
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_expiration(CacheType.BOT)
|
|
||||||
def _(data: dict[str, BotConsole], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.BOT)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired(cache_data)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.USERS)
|
@CacheRoot.new(CacheType.USERS)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化用户缓存"""
|
||||||
data_list = await UserConsole.all()
|
data_list = await UserConsole.all()
|
||||||
return {p.user_id: p for p in data_list}
|
return {p.user_id: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.USERS)
|
@CacheRoot.updater(CacheType.USERS)
|
||||||
async def _(data: dict[str, UserConsole], key: str, value: Any):
|
async def _(data: dict[str, UserConsole], key: str, value: Any):
|
||||||
|
"""更新用户缓存"""
|
||||||
if value:
|
if value:
|
||||||
data[key] = value
|
data[key] = value
|
||||||
elif user := await UserConsole.get_user(user_id=key):
|
elif user := await UserConsole.get_user(user_id=key):
|
||||||
@ -224,108 +141,61 @@ async def _(data: dict[str, UserConsole], key: str, value: Any):
|
|||||||
|
|
||||||
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
||||||
async def _(cache_data: CacheData, user_id: str):
|
async def _(cache_data: CacheData, user_id: str):
|
||||||
cache_data.data = cache_data.data or {}
|
"""获取用户缓存"""
|
||||||
result = cache_data.data.get(user_id, None)
|
data = await cache_data.get_data() or {}
|
||||||
if not result:
|
if user_id not in data:
|
||||||
result = await UserConsole.get_user(user_id=user_id)
|
if user := await UserConsole.get_user(user_id=user_id):
|
||||||
if result:
|
data[user_id] = user
|
||||||
cache_data.data[user_id] = result
|
await cache_data.set_data(data)
|
||||||
return result
|
return data.get(user_id)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.USERS)
|
@CacheRoot.with_refresh(CacheType.USERS)
|
||||||
async def _(data: dict[str, UserConsole] | None):
|
async def _(data: dict[str, UserConsole] | None):
|
||||||
|
"""刷新用户缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
||||||
for user in users:
|
data.update({u.user_id: u for u in users})
|
||||||
data[user.user_id] = user
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_expiration(CacheType.USERS)
|
|
||||||
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.USERS)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired(cache_data)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.LEVEL, False)
|
@CacheRoot.new(CacheType.LEVEL, False)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化等级缓存"""
|
||||||
return await LevelUser().all()
|
return await LevelUser().all()
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
||||||
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
||||||
cache_data.data = cache_data.data or []
|
"""获取等级缓存"""
|
||||||
|
data = await cache_data.get_data() or []
|
||||||
if not group_id:
|
if not group_id:
|
||||||
return [
|
return [d for d in data if d.user_id == user_id and not d.group_id]
|
||||||
data
|
return [d for d in data if d.user_id == user_id and d.group_id == group_id]
|
||||||
for data in cache_data.data
|
|
||||||
if data.user_id == user_id and not data.group_id
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return [
|
|
||||||
data
|
|
||||||
for data in cache_data.data
|
|
||||||
if data.user_id == user_id and data.group_id == group_id
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_expiration(CacheType.LEVEL)
|
|
||||||
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration_1(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.LEVEL)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired_1(cache_data)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BAN, False)
|
@CacheRoot.new(CacheType.BAN, False)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化封禁缓存"""
|
||||||
return await BanConsole.all()
|
return await BanConsole.all()
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||||
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
||||||
|
"""获取封禁缓存"""
|
||||||
|
data = await cache_data.get_data() or []
|
||||||
if user_id:
|
if user_id:
|
||||||
return (
|
|
||||||
[
|
|
||||||
data
|
|
||||||
for data in cache_data.data
|
|
||||||
if data.user_id == user_id and data.group_id == group_id
|
|
||||||
]
|
|
||||||
if group_id
|
|
||||||
else [
|
|
||||||
data
|
|
||||||
for data in cache_data.data
|
|
||||||
if data.user_id == user_id and not data.group_id
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if group_id:
|
if group_id:
|
||||||
return [
|
return [d for d in data if d.user_id == user_id and d.group_id == group_id]
|
||||||
data
|
return [d for d in data if d.user_id == user_id and not d.group_id]
|
||||||
for data in cache_data.data
|
if group_id:
|
||||||
if not data.user_id and data.group_id == group_id
|
return [d for d in data if not d.user_id and d.group_id == group_id]
|
||||||
]
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_expiration(CacheType.BAN)
|
|
||||||
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration_1(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.BAN)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired_1(cache_data)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.LIMIT)
|
@CacheRoot.new(CacheType.LIMIT)
|
||||||
async def _():
|
async def _():
|
||||||
|
"""初始化限制缓存"""
|
||||||
data_list = await PluginLimit.filter(status=True).all()
|
data_list = await PluginLimit.filter(status=True).all()
|
||||||
result_data = {}
|
result_data = {}
|
||||||
for data in data_list:
|
for data in data_list:
|
||||||
@ -337,6 +207,7 @@ async def _():
|
|||||||
|
|
||||||
@CacheRoot.updater(CacheType.LIMIT)
|
@CacheRoot.updater(CacheType.LIMIT)
|
||||||
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
|
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
|
||||||
|
"""更新限制缓存"""
|
||||||
if value:
|
if value:
|
||||||
data[key] = value
|
data[key] = value
|
||||||
elif limits := await PluginLimit.filter(module=key, status=True):
|
elif limits := await PluginLimit.filter(module=key, status=True):
|
||||||
@ -345,32 +216,25 @@ async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
|
|||||||
|
|
||||||
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||||
async def _(cache_data: CacheData, module: str):
|
async def _(cache_data: CacheData, module: str):
|
||||||
cache_data.data = cache_data.data or {}
|
"""获取限制缓存"""
|
||||||
result = cache_data.data.get(module, None)
|
data = await cache_data.get_data() or {}
|
||||||
if not result:
|
if module not in data:
|
||||||
result = await PluginLimit.filter(module=module, status=True)
|
if limits := await PluginLimit.filter(module=module, status=True):
|
||||||
if result:
|
data[module] = limits
|
||||||
cache_data.data[module] = result
|
await cache_data.set_data(data)
|
||||||
return result
|
return data.get(module)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.LIMIT)
|
@CacheRoot.with_refresh(CacheType.LIMIT)
|
||||||
async def _(data: dict[str, list[PluginLimit]] | None):
|
async def _(data: dict[str, list[PluginLimit]] | None):
|
||||||
|
"""刷新限制缓存"""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||||
data.clear()
|
new_data = {}
|
||||||
for limit in limits:
|
for limit in limits:
|
||||||
if not data.get(limit.module):
|
if not new_data.get(limit.module):
|
||||||
data[limit.module] = []
|
new_data[limit.module] = []
|
||||||
data[limit.module].append(limit)
|
new_data[limit.module].append(limit)
|
||||||
|
data.clear()
|
||||||
|
data.update(new_data)
|
||||||
@CacheRoot.with_expiration(CacheType.LIMIT)
|
|
||||||
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
|
|
||||||
return default_with_expiration(data, expire_data, expire)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.cleanup_expired(CacheType.LIMIT)
|
|
||||||
def _(cache_data: CacheData):
|
|
||||||
return default_cleanup_expired(cache_data)
|
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
from typing import Any, ClassVar, Generic, TypeVar
|
||||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
|
||||||
|
|
||||||
|
from aiocache import Cache as AioCache
|
||||||
|
from aiocache.base import BaseCache
|
||||||
|
from aiocache.serializers import JsonSerializer
|
||||||
|
from nonebot.compat import model_dump
|
||||||
from nonebot.utils import is_coroutine_callable
|
from nonebot.utils import is_coroutine_callable
|
||||||
from nonebot_plugin_apscheduler import scheduler
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
@ -16,168 +19,258 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
class DbCacheException(Exception):
|
class DbCacheException(Exception):
|
||||||
|
"""缓存相关异常"""
|
||||||
|
|
||||||
def __init__(self, info: str):
|
def __init__(self, info: str):
|
||||||
self.info = info
|
self.info = info
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return super().__repr__()
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.info
|
return self.info
|
||||||
|
|
||||||
|
|
||||||
def validate_name(func: Callable):
|
def validate_name(func: Callable):
|
||||||
"""
|
"""验证缓存名称是否存在的装饰器"""
|
||||||
装饰器:验证 name 是否存在于 CacheManage._data 中。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapper(self, name: str, *args, **kwargs):
|
def wrapper(self, name: str, *args, **kwargs):
|
||||||
_name = name.upper()
|
_name = name.upper()
|
||||||
if _name not in CacheManager._data:
|
if _name not in CacheManager._data:
|
||||||
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
|
raise DbCacheException(f"缓存数据 {name} 不存在")
|
||||||
return func(self, _name, *args, **kwargs)
|
return func(self, _name, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class CacheGetter(BaseModel, Generic[T]):
|
class CacheGetter(BaseModel, Generic[T]):
|
||||||
|
"""缓存数据获取器"""
|
||||||
|
|
||||||
get_func: Callable[..., Any] | None = None
|
get_func: Callable[..., Any] | None = None
|
||||||
"""获取方法"""
|
|
||||||
|
|
||||||
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
|
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
|
||||||
"""获取缓存"""
|
"""获取处理后的缓存数据"""
|
||||||
if not self.get_func:
|
if not self.get_func:
|
||||||
return cache_data.data
|
return await cache_data.get_data()
|
||||||
|
|
||||||
if is_coroutine_callable(self.get_func):
|
if is_coroutine_callable(self.get_func):
|
||||||
processed_data = await self.get_func(cache_data, *args, **kwargs)
|
return await self.get_func(cache_data, *args, **kwargs)
|
||||||
else:
|
return self.get_func(cache_data, *args, **kwargs)
|
||||||
processed_data = self.get_func(cache_data, *args, **kwargs)
|
|
||||||
return cast(T, processed_data)
|
|
||||||
|
|
||||||
|
|
||||||
class CacheData(BaseModel):
|
class CacheData(BaseModel):
|
||||||
|
"""缓存数据模型"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
"""缓存名称"""
|
|
||||||
func: Callable[..., Any]
|
func: Callable[..., Any]
|
||||||
"""更新方法"""
|
|
||||||
getter: CacheGetter | None = None
|
getter: CacheGetter | None = None
|
||||||
"""获取方法"""
|
|
||||||
updater: Callable[..., Any] | None = None
|
updater: Callable[..., Any] | None = None
|
||||||
"""更新单个方法"""
|
|
||||||
with_refresh: Callable[..., Any] | None = None
|
with_refresh: Callable[..., Any] | None = None
|
||||||
"""刷新方法"""
|
expire: int = 600 # 默认10分钟过期
|
||||||
with_expiration: Callable[..., Any] | None = None
|
|
||||||
"""缓存时间初始化方法"""
|
|
||||||
cleanup_expired: Callable[..., Any] | None = None
|
|
||||||
"""缓存过期方法"""
|
|
||||||
data: Any = None
|
|
||||||
"""缓存数据"""
|
|
||||||
expire: int
|
|
||||||
"""缓存过期时间"""
|
|
||||||
expire_data: dict[str, int | float] = {}
|
|
||||||
"""缓存过期数据时间记录"""
|
|
||||||
reload_time: float = time.time()
|
|
||||||
"""更新时间"""
|
|
||||||
reload_count: int = 0
|
reload_count: int = 0
|
||||||
"""更新次数"""
|
|
||||||
incremental_update: bool = True
|
incremental_update: bool = True
|
||||||
"""是否是增量更新"""
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cache(self) -> BaseCache:
|
||||||
|
"""获取aiocache实例"""
|
||||||
|
return AioCache(
|
||||||
|
AioCache.MEMORY,
|
||||||
|
serializer=JsonSerializer(),
|
||||||
|
namespace="zhenxun_cache",
|
||||||
|
timeout=30, # 操作超时时间
|
||||||
|
ttl=self.expire, # 设置默认过期时间
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_data(self) -> Any:
|
||||||
|
"""从缓存获取数据"""
|
||||||
|
try:
|
||||||
|
data = await self._cache.get(self.name)
|
||||||
|
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
||||||
|
|
||||||
|
# 如果数据为空,尝试重新加载
|
||||||
|
if data is None:
|
||||||
|
logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载")
|
||||||
|
try:
|
||||||
|
if self.has_args():
|
||||||
|
new_data = (
|
||||||
|
await self.func()
|
||||||
|
if is_coroutine_callable(self.func)
|
||||||
|
else self.func()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_data = (
|
||||||
|
await self.func()
|
||||||
|
if is_coroutine_callable(self.func)
|
||||||
|
else self.func()
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.set_data(new_data)
|
||||||
|
self.reload_count += 1
|
||||||
|
logger.info(f"重新加载缓存 {self.name} 完成")
|
||||||
|
return new_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取缓存 {self.name} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _serialize_value(self, value: Any) -> Any:
|
||||||
|
"""序列化值,将数据转换为JSON可序列化的格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 需要序列化的值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON可序列化的值
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 处理datetime
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return value.isoformat()
|
||||||
|
|
||||||
|
# 处理Tortoise-ORM Model
|
||||||
|
if hasattr(value, "_meta") and hasattr(value, "__dict__"):
|
||||||
|
result = {}
|
||||||
|
for field in value._meta.fields:
|
||||||
|
try:
|
||||||
|
field_value = getattr(value, field)
|
||||||
|
# 跳过反向关系字段
|
||||||
|
if isinstance(field_value, list | set) and hasattr(
|
||||||
|
field_value, "_related_name"
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
result[field] = self._serialize_value(field_value)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
return result
|
||||||
|
# 处理Pydantic模型
|
||||||
|
elif isinstance(value, BaseModel):
|
||||||
|
return model_dump(value)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
# 处理字典
|
||||||
|
return {str(k): self._serialize_value(v) for k, v in value.items()}
|
||||||
|
elif isinstance(value, list | tuple | set):
|
||||||
|
# 处理列表、元组、集合
|
||||||
|
return [self._serialize_value(item) for item in value]
|
||||||
|
elif isinstance(value, int | float | str | bool):
|
||||||
|
# 基本类型直接返回
|
||||||
|
return value
|
||||||
|
elif hasattr(value, "__dict__"):
|
||||||
|
# 处理普通类对象
|
||||||
|
return self._serialize_value(value.__dict__)
|
||||||
|
else:
|
||||||
|
# 其他类型转换为字符串
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
async def set_data(self, value: Any):
|
||||||
|
"""设置缓存数据"""
|
||||||
|
try:
|
||||||
|
# 1. 序列化数据
|
||||||
|
serialized_value = self._serialize_value(value)
|
||||||
|
logger.debug(f"设置缓存 {self.name} 原始数据: {value}")
|
||||||
|
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
|
||||||
|
|
||||||
|
# 2. 删除旧数据
|
||||||
|
await self._cache.delete(self.name)
|
||||||
|
logger.debug(f"删除缓存 {self.name} 旧数据")
|
||||||
|
|
||||||
|
# 3. 设置新数据
|
||||||
|
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
||||||
|
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
||||||
|
|
||||||
|
# 4. 立即验证
|
||||||
|
cached_data = await self._cache.get(self.name)
|
||||||
|
if cached_data is None:
|
||||||
|
logger.error(f"缓存 {self.name} 设置失败:数据验证失败")
|
||||||
|
# 5. 如果验证失败,尝试重新设置
|
||||||
|
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
||||||
|
cached_data = await self._cache.get(self.name)
|
||||||
|
if cached_data is None:
|
||||||
|
logger.error(f"缓存 {self.name} 重试设置失败")
|
||||||
|
else:
|
||||||
|
logger.debug(f"缓存 {self.name} 重试设置成功: {cached_data}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"缓存 {self.name} 数据验证成功: {cached_data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"设置缓存 {self.name} 失败: {e}")
|
||||||
|
raise # 重新抛出异常,让上层处理
|
||||||
|
|
||||||
|
async def delete_data(self):
|
||||||
|
"""删除缓存数据"""
|
||||||
|
try:
|
||||||
|
await self._cache.delete(self.name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除缓存 {self.name}", e=e)
|
||||||
|
|
||||||
async def get(self, *args, **kwargs) -> Any:
|
async def get(self, *args, **kwargs) -> Any:
|
||||||
"""获取单个缓存"""
|
"""获取缓存"""
|
||||||
if not self.reload_count and not self.incremental_update:
|
if not self.reload_count and not self.incremental_update:
|
||||||
# 首次获取时,非增量更新获取全部数据
|
await self.reload(*args, **kwargs)
|
||||||
await self.reload()
|
|
||||||
self.call_cleanup_expired() # 移除过期缓存
|
|
||||||
if not self.getter:
|
if not self.getter:
|
||||||
return self.data
|
return await self.get_data()
|
||||||
result = await self.getter.get(self, *args, **kwargs)
|
|
||||||
await self.call_with_expiration()
|
return await self.getter.get(self, *args, **kwargs)
|
||||||
return result
|
|
||||||
|
|
||||||
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
||||||
"""更新单个缓存"""
|
"""更新单个缓存项"""
|
||||||
if not self.updater:
|
if not self.updater:
|
||||||
return logger.warning(
|
logger.warning(f"缓存 {self.name} 未配置更新方法")
|
||||||
f"缓存类型 {self.name} 没有更新方法,无法更新", "CacheRoot"
|
return
|
||||||
)
|
|
||||||
if self.data:
|
current_data = await self.get_data() or {}
|
||||||
if is_coroutine_callable(self.updater):
|
if is_coroutine_callable(self.updater):
|
||||||
await self.updater(self.data, key, value, *args, **kwargs)
|
await self.updater(current_data, key, value, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.updater(self.data, key, value, *args, **kwargs)
|
self.updater(current_data, key, value, *args, **kwargs)
|
||||||
logger.debug(
|
|
||||||
f"缓存类型 {self.name} 更新单个缓存 key: {key},value: {value}",
|
await self.set_data(current_data)
|
||||||
"CacheRoot",
|
logger.debug(f"更新缓存 {self.name}.{key}")
|
||||||
)
|
|
||||||
self.expire_data[key] = time.time() + self.expire
|
|
||||||
else:
|
|
||||||
logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot")
|
|
||||||
|
|
||||||
async def refresh(self, *args, **kwargs):
|
async def refresh(self, *args, **kwargs):
|
||||||
"""刷新缓存,只刷新已缓存的数据"""
|
"""刷新缓存数据"""
|
||||||
if not self.with_refresh:
|
if not self.with_refresh:
|
||||||
return await self.reload(*args, **kwargs)
|
return await self.reload(*args, **kwargs)
|
||||||
if self.data:
|
|
||||||
|
current_data = await self.get_data()
|
||||||
|
if current_data:
|
||||||
if is_coroutine_callable(self.with_refresh):
|
if is_coroutine_callable(self.with_refresh):
|
||||||
await self.with_refresh(self.data, *args, **kwargs)
|
await self.with_refresh(current_data, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.with_refresh(self.data, *args, **kwargs)
|
self.with_refresh(current_data, *args, **kwargs)
|
||||||
logger.debug(
|
await self.set_data(current_data)
|
||||||
f"缓存类型 {self.name} 刷新全局缓存,共刷新 {len(self.data)} 条数据",
|
logger.debug(f"刷新缓存 {self.name}")
|
||||||
"CacheRoot",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def reload(self, *args, **kwargs):
|
async def reload(self, *args, **kwargs):
|
||||||
"""更新全部缓存数据"""
|
"""重新加载全部数据"""
|
||||||
|
try:
|
||||||
if self.has_args():
|
if self.has_args():
|
||||||
self.data = (
|
new_data = (
|
||||||
await self.func(*args, **kwargs)
|
await self.func(*args, **kwargs)
|
||||||
if is_coroutine_callable(self.func)
|
if is_coroutine_callable(self.func)
|
||||||
else self.func(*args, **kwargs)
|
else self.func(*args, **kwargs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.data = (
|
new_data = (
|
||||||
await self.func() if is_coroutine_callable(self.func) else self.func()
|
await self.func()
|
||||||
|
if is_coroutine_callable(self.func)
|
||||||
|
else self.func()
|
||||||
)
|
)
|
||||||
await self.call_with_expiration()
|
|
||||||
self.reload_time = time.time()
|
await self.set_data(new_data)
|
||||||
self.reload_count += 1
|
self.reload_count += 1
|
||||||
logger.debug(
|
logger.info(f"重新加载缓存 {self.name} 完成")
|
||||||
f"缓存类型 {self.name} 更新全局缓存,共更新 {len(self.data)} 条数据",
|
except Exception as e:
|
||||||
"CacheRoot",
|
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
|
||||||
)
|
raise
|
||||||
|
|
||||||
def call_cleanup_expired(self):
|
def has_args(self) -> bool:
|
||||||
"""清理过期缓存"""
|
"""检查函数是否需要参数"""
|
||||||
if self.cleanup_expired:
|
|
||||||
if result := self.cleanup_expired(self):
|
|
||||||
logger.debug(
|
|
||||||
f"成功清理 {self.name} {len(result)} 条过期缓存", "CacheRoot"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def call_with_expiration(self, is_force: bool = False):
|
|
||||||
"""缓存时间更新
|
|
||||||
|
|
||||||
参数:
|
|
||||||
is_force: 是否强制更新全部数据缓存时间.
|
|
||||||
"""
|
|
||||||
if self.with_expiration:
|
|
||||||
if is_force:
|
|
||||||
self.expire_data = {}
|
|
||||||
expiration_data = (
|
|
||||||
await self.with_expiration(self.data, self.expire_data, self.expire)
|
|
||||||
if is_coroutine_callable(self.with_expiration)
|
|
||||||
else self.with_expiration(self.data, self.expire_data, self.expire)
|
|
||||||
)
|
|
||||||
self.expire_data = {**self.expire_data, **expiration_data}
|
|
||||||
|
|
||||||
def has_args(self):
|
|
||||||
"""是否含有参数
|
|
||||||
|
|
||||||
返回:
|
|
||||||
bool: 是否含有参数
|
|
||||||
"""
|
|
||||||
sig = inspect.signature(self.func)
|
sig = inspect.signature(self.func)
|
||||||
return any(
|
return any(
|
||||||
param.kind
|
param.kind
|
||||||
@ -189,66 +282,81 @@ class CacheData(BaseModel):
|
|||||||
for param in sig.parameters.values()
|
for param in sig.parameters.values()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_key(self, key: str) -> Any:
|
||||||
|
"""获取缓存中指定键的数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 要获取的键名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
键对应的值,如果不存在返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await self.get_data()
|
||||||
|
return data.get(key) if isinstance(data, dict) else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取缓存 {self.name}.{key} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
||||||
|
"""获取缓存中多个键的数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys: 要获取的键名列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含所有请求键值的字典,不存在的键值为None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await self.get_data()
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return {key: data.get(key) for key in keys}
|
||||||
|
return dict.fromkeys(keys)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取缓存 {self.name} 的多个键失败: {e}")
|
||||||
|
return dict.fromkeys(keys)
|
||||||
|
|
||||||
|
|
||||||
class CacheManager:
|
class CacheManager:
|
||||||
"""全局缓存管理,减少数据库与网络请求查询次数
|
"""全局缓存管理器"""
|
||||||
|
|
||||||
|
|
||||||
异常:
|
|
||||||
DbCacheException: 数据名称重复
|
|
||||||
DbCacheException: 数据不存在
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
_data: ClassVar[dict[str, CacheData]] = {}
|
_data: ClassVar[dict[str, CacheData]] = {}
|
||||||
|
|
||||||
def start_check(self):
|
def new(self, name: str, incremental_update: bool = True, expire: int = 600):
|
||||||
"""启动缓存检查"""
|
"""注册新缓存"""
|
||||||
for cache_data in self._data.values():
|
|
||||||
if cache_data.cleanup_expired:
|
|
||||||
scheduler.add_job(
|
|
||||||
cache_data.call_cleanup_expired,
|
|
||||||
"interval",
|
|
||||||
seconds=cache_data.expire,
|
|
||||||
args=[],
|
|
||||||
id=f"CacheRoot-{cache_data.name}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10):
|
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
_name = name.upper()
|
_name = name.upper()
|
||||||
if _name in self._data:
|
if _name in self._data:
|
||||||
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
|
raise DbCacheException(f"缓存 {name} 已存在")
|
||||||
|
|
||||||
self._data[_name] = CacheData(
|
self._data[_name] = CacheData(
|
||||||
name=_name,
|
name=_name,
|
||||||
func=func,
|
func=func,
|
||||||
expire=expire,
|
expire=expire,
|
||||||
incremental_update=incremental_update,
|
incremental_update=incremental_update,
|
||||||
)
|
)
|
||||||
|
return func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def listener(self, name: str):
|
def listener(self, name: str):
|
||||||
|
"""创建缓存监听器"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
if is_coroutine_callable(func):
|
return (
|
||||||
result = await func(*args, **kwargs)
|
await func(*args, **kwargs)
|
||||||
else:
|
if is_coroutine_callable(func)
|
||||||
result = func(*args, **kwargs)
|
else func(*args, **kwargs)
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
cache_data = self._data.get(name.upper())
|
|
||||||
if cache_data and cache_data.with_refresh:
|
|
||||||
if is_coroutine_callable(cache_data.with_refresh):
|
|
||||||
await cache_data.with_refresh(cache_data.data)
|
|
||||||
else:
|
|
||||||
cache_data.with_refresh(cache_data.data)
|
|
||||||
await cache_data.call_with_expiration(True)
|
|
||||||
logger.debug(
|
|
||||||
f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot"
|
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
cache = self._data.get(name.upper())
|
||||||
|
if cache and cache.with_refresh:
|
||||||
|
await cache.refresh()
|
||||||
|
logger.debug(f"监听器触发缓存 {name} 刷新")
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@ -256,86 +364,79 @@ class CacheManager:
|
|||||||
|
|
||||||
@validate_name
|
@validate_name
|
||||||
def updater(self, name: str):
|
def updater(self, name: str):
|
||||||
|
"""设置缓存更新方法"""
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
self._data[name.upper()].updater = func
|
self._data[name].updater = func
|
||||||
|
return func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@validate_name
|
@validate_name
|
||||||
def getter(self, name: str, result_model: type):
|
def getter(self, name: str, result_model: type):
|
||||||
|
"""设置缓存获取方法"""
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
||||||
|
return func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@validate_name
|
@validate_name
|
||||||
def with_refresh(self, name: str):
|
def with_refresh(self, name: str):
|
||||||
|
"""设置缓存刷新方法"""
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
self._data[name.upper()].with_refresh = func
|
self._data[name].with_refresh = func
|
||||||
|
return func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@validate_name
|
async def get_cache_data(self, name: str) -> Any | None:
|
||||||
def with_expiration(self, name: str):
|
"""获取缓存数据"""
|
||||||
def wrapper(func: Callable[[Any, int], dict[str, float]]):
|
|
||||||
self._data[name.upper()].with_expiration = func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
@validate_name
|
|
||||||
def cleanup_expired(self, name: str):
|
|
||||||
def wrapper(func: Callable[[CacheData], None]):
|
|
||||||
self._data[name.upper()].cleanup_expired = func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
async def check_expire(self, name: str, *args, **kwargs):
|
|
||||||
name = name.upper()
|
|
||||||
if self._data.get(name) and (
|
|
||||||
time.time() - self._data[name].reload_time > self._data[name].expire
|
|
||||||
or not self._data[name].reload_count
|
|
||||||
):
|
|
||||||
await self._data[name].reload(*args, **kwargs)
|
|
||||||
|
|
||||||
async def get_cache_data(self, name: str):
|
|
||||||
return cache.data if (cache := await self.get_cache(name)) else None
|
|
||||||
|
|
||||||
async def get_cache(self, name: str, *args, **kwargs) -> CacheData | None:
|
|
||||||
name = name.upper()
|
|
||||||
if cache := self._data.get(name):
|
|
||||||
# await self.check_expire(name, *args, **kwargs)
|
|
||||||
return cache
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get(self, name: str, *args, **kwargs):
|
|
||||||
cache = await self.get_cache(name.upper(), *args, **kwargs)
|
|
||||||
if cache:
|
|
||||||
return await cache.get(*args, **kwargs) if cache.getter else cache.data
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def reload(self, name: str, *args, **kwargs):
|
|
||||||
cache = await self.get_cache(name.upper())
|
cache = await self.get_cache(name.upper())
|
||||||
if cache:
|
return await cache.get_data() if cache else None
|
||||||
await cache.refresh(*args, **kwargs)
|
|
||||||
|
|
||||||
async def update(self, name: str, key: str, value: Any, *args, **kwargs):
|
async def get_cache(self, name: str) -> CacheData | None:
|
||||||
|
"""获取缓存对象"""
|
||||||
|
return self._data.get(name.upper())
|
||||||
|
|
||||||
|
async def get(self, name: str, *args, **kwargs) -> Any:
|
||||||
|
"""获取缓存内容"""
|
||||||
|
cache = await self.get_cache(name.upper())
|
||||||
|
return await cache.get(*args, **kwargs) if cache else None
|
||||||
|
|
||||||
|
async def update(self, name: str, key: str, value: Any = None, *args, **kwargs):
|
||||||
|
"""更新缓存项"""
|
||||||
cache = await self.get_cache(name.upper())
|
cache = await self.get_cache(name.upper())
|
||||||
if cache:
|
if cache:
|
||||||
await cache.update(key, value, *args, **kwargs)
|
await cache.update(key, value, *args, **kwargs)
|
||||||
|
|
||||||
|
async def reload(self, name: str, *args, **kwargs):
|
||||||
|
"""重新加载缓存"""
|
||||||
|
cache = await self.get_cache(name.upper())
|
||||||
|
if cache:
|
||||||
|
await cache.reload(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局缓存管理器实例
|
||||||
CacheRoot = CacheManager()
|
CacheRoot = CacheManager()
|
||||||
|
|
||||||
|
|
||||||
class Cache(Generic[T]):
|
class Cache(Generic[T]):
|
||||||
|
"""类型化缓存访问接口"""
|
||||||
|
|
||||||
def __init__(self, module: str):
|
def __init__(self, module: str):
|
||||||
self.module = module
|
self.module = module.upper()
|
||||||
|
|
||||||
async def get(self, *args, **kwargs) -> T | None:
|
async def get(self, *args, **kwargs) -> T | None:
|
||||||
|
"""获取缓存"""
|
||||||
return await CacheRoot.get(self.module, *args, **kwargs)
|
return await CacheRoot.get(self.module, *args, **kwargs)
|
||||||
|
|
||||||
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
||||||
return await CacheRoot.update(self.module, key, value, *args, **kwargs)
|
"""更新缓存项"""
|
||||||
|
await CacheRoot.update(self.module, key, value, *args, **kwargs)
|
||||||
|
|
||||||
async def reload(self, key: str | None = None, *args, **kwargs):
|
async def reload(self, *args, **kwargs):
|
||||||
await CacheRoot.reload(self.module, key, *args, **kwargs)
|
"""重新加载缓存"""
|
||||||
|
await CacheRoot.reload(self.module, *args, **kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user