添加全局cache

This commit is contained in:
HibiKier 2025-01-07 18:29:35 +08:00
parent e124c1dbdb
commit 45649bb29d
6 changed files with 233 additions and 2 deletions

View File

@ -15,8 +15,10 @@ from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import ( from zhenxun.utils.enum import (
BlockType, BlockType,
CacheType,
GoldHandle, GoldHandle,
LimitWatchType, LimitWatchType,
PluginLimitType, PluginLimitType,
@ -248,7 +250,7 @@ class AuthChecker:
e=e, e=e,
) )
return return
if plugin := await PluginInfo.get_or_none(module_path=module_path): if plugin := await Cache.get(CacheType.PLUGINS, module_path):
if plugin.plugin_type == PluginType.HIDDEN: if plugin.plugin_type == PluginType.HIDDEN:
logger.debug( logger.debug(
f"插件: {plugin.name}:{plugin.module} " f"插件: {plugin.name}:{plugin.module} "

View File

@ -0,0 +1,34 @@
from nonebot.exception import IgnoredException
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import CacheType
async def get_bot_status(bot_id: str):
a = await Cache.get(CacheType.BOT, bot_id)
cache_data = await Cache.get_cache(CacheType.BOT)
if cache_data and cache_data.getter:
b = await cache_data.getter.get(cache_data.data)
if bot := await Cache.get(CacheType.BOT, bot_id):
return bot
async def auth_bot(plugin: PluginInfo, bot_id: str):
"""机器人权限
参数:
plugin: PluginInfo
bot_id: bot_id
"""
if not await BotConsole.get_bot_status(bot_id):
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
raise IgnoredException("BotConsole休眠权限检测 ignore")
if await BotConsole.is_block_plugin(bot_id, plugin.module):
logger.debug(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
"AuthChecker",
)
raise IgnoredException("BotConsole插件权限检测 ignore")

View File

@ -3,8 +3,13 @@ from pathlib import Path
import nonebot import nonebot
from nonebot.adapters import Bot from nonebot.adapters import Bot
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.cache_utils import Cache
from zhenxun.utils.enum import CacheType
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve())) nonebot.load_plugins(str(Path(__file__).parent.resolve()))
@ -41,3 +46,66 @@ async def _(bot: Bot):
f"更新Bot: {bot.self_id} 的群认证完成,共创建 {len(create_list)} 条数据," f"更新Bot: {bot.self_id} 的群认证完成,共创建 {len(create_list)} 条数据,"
f"共修改 {len(update_id)} 条数据..." f"共修改 {len(update_id)} 条数据..."
) )
@Cache.listener(CacheType.PLUGINS)
async def _():
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@Cache.getter(CacheType.PLUGINS, result_model=PluginInfo)
def _(data: dict[str, PluginInfo], module: str):
return data.get(module, None)
@Cache.listener(CacheType.GROUPS)
async def _():
data_list = await GroupConsole.all()
return {p.group_id: p for p in data_list}
@Cache.getter(CacheType.GROUPS, result_model=GroupConsole)
def _(data: dict[str, GroupConsole], module: str):
return data.get(module, None)
@Cache.listener(CacheType.BOT)
async def _():
data_list = await BotConsole.all()
return {p.bot_id: p for p in data_list}
@Cache.getter(CacheType.BOT, result_model=BotConsole)
def _(data: dict[str, BotConsole], module: str):
return data.get(module, None)
@Cache.listener(CacheType.BAN)
async def _():
return await BanConsole.all()
@Cache.getter(CacheType.BAN, result_model=list[BanConsole])
def _(data_list: list[BanConsole], user_id: str, group_id: str):
if user_id:
if group_id:
return [
data
for data in data_list
if data.user_id == user_id and data.group_id == group_id
]
else:
return [
data
for data in data_list
if data.user_id == user_id and not data.group_id
]
else:
if group_id:
return [
data
for data in data_list
if not data.user_id and data.group_id == group_id
]
return None

View File

@ -1,5 +1,4 @@
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.adapters.kaiheila.exception import ApiNotAvailable
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.plugin import PluginMetadata from nonebot.plugin import PluginMetadata
from nonebot.rule import to_me from nonebot.rule import to_me

View File

@ -0,0 +1,113 @@
from collections.abc import Callable
import time
from typing import Any, ClassVar, Generic, TypeVar, cast
from nonebot.utils import is_coroutine_callable
from pydantic import BaseModel
__all__ = ["Cache", "CacheData"]
T = TypeVar("T")
class CacheGetter(BaseModel, Generic[T]):
get_func: Callable[..., Any] | None = None
"""获取方法"""
async def get(self, data: Any, *args, **kwargs) -> T:
"""获取缓存"""
processed_data = (
await self.get_func(data, *args, **kwargs)
if self.get_func and is_coroutine_callable(self.get_func)
else self.get_func(data, *args, **kwargs)
if self.get_func
else data
)
return cast(T, processed_data)
class CacheData(BaseModel):
func: Callable[..., Any]
"""更新方法"""
getter: CacheGetter | None = None
"""获取方法"""
data: Any = None
"""缓存数据"""
expire: int
"""缓存过期时间"""
reload_time = time.time()
"""更新时间"""
reload_count: int = 0
"""更新次数"""
async def reload(self):
"""更新缓存"""
self.data = (
await self.func() if is_coroutine_callable(self.func) else self.func()
)
self.reload_time = time.time()
self.reload_count += 1
async def check_expire(self):
if time.time() - self.reload_time > self.expire or not self.reload_count:
await self.reload()
class CacheManage:
_data: ClassVar[dict[str, CacheData]] = {}
def listener(self, name: str, expire: int = 60 * 10):
def wrapper(func: Callable):
_name = name.upper()
if _name in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 已存在...")
self._data[_name] = CacheData(func=func, expire=expire)
return wrapper
def getter(self, name: str, result_model: type | None = None):
def wrapper(func: Callable):
_name = name.upper()
if _name not in self._data:
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
self._data[_name].getter = CacheGetter[result_model](get_func=func)
return wrapper
async def check_expire(self, name: str):
name = name.upper()
if self._data.get(name):
if (
time.time() - self._data[name].reload_time > self._data[name].expire
or not self._data[name].reload_count
):
await self._data[name].reload()
async def get_cache_data(self, name: str) -> CacheData | None:
if cache := await self.get_cache(name):
return cache
return None
async def get_cache(self, name: str):
name = name.upper()
cache = self._data.get(name)
if cache:
await self.check_expire(name)
return cache
return None
async def get(self, name: str, *args, **kwargs) -> T | None:
cache = self._data.get(name.upper())
if cache:
return (
await cache.getter.get(*args, **kwargs) if cache.getter else cache.data
)
return None
async def reload(self, name: str):
cache = self._data.get(name.upper())
if cache:
await cache.reload()
Cache = CacheManage()

View File

@ -1,6 +1,21 @@
from strenum import StrEnum from strenum import StrEnum
class CacheType(StrEnum):
"""
缓存类型
"""
PLUGINS = "GLOBAL_ALL_PLUGINS"
"""全局全部插件"""
GROUPS = "GLOBAL_ALL_GROUPS"
"""全局全部群组"""
BAN = "GLOBAL_ALL_BAN"
"""全局ban列表"""
BOT = "GLOBAL_BOT"
"""全局bot信息"""
class GoldHandle(StrEnum): class GoldHandle(StrEnum):
""" """
金币处理 金币处理