添加插件limit缓存

This commit is contained in:
HibiKier 2025-04-11 16:46:24 +08:00
parent 43f1e48c14
commit 8865ab0d58
7 changed files with 163 additions and 35 deletions

View File

@ -14,7 +14,7 @@ from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND from .config import LOGGER_COMMAND
from .exception import SkipPluginException from .exception import SkipPluginException
from .utils import send_message from .utils import freq, send_message
Config.add_plugin_config( Config.add_plugin_config(
"hook", "hook",
@ -131,6 +131,7 @@ async def user_handle(
# and not db_plugin.ignore_prompt # and not db_plugin.ignore_prompt
and time != -1 and time != -1
and ban_result and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
): ):
await send_message( await send_message(
session, session,

View File

@ -8,7 +8,12 @@ from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.message import MessageUtils from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter from zhenxun.utils.utils import (
CountLimiter,
FreqLimiter,
UserBlockLimiter,
get_entity_ids,
)
from .config import LOGGER_COMMAND from .config import LOGGER_COMMAND
from .exception import SkipPluginException from .exception import SkipPluginException
@ -22,7 +27,7 @@ class Limit(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
class LimitManage: class LimitManager:
add_module: ClassVar[list] = [] add_module: ClassVar[list] = []
cd_limit: ClassVar[dict[str, Limit]] = {} cd_limit: ClassVar[dict[str, Limit]] = {}
@ -84,7 +89,6 @@ class LimitManage:
user_id: str, user_id: str,
group_id: str | None, group_id: str | None,
channel_id: str | None, channel_id: str | None,
session: Uninfo,
): ):
"""检测限制 """检测限制
@ -93,17 +97,16 @@ class LimitManage:
user_id: 用户id user_id: 用户id
group_id: 群组id group_id: 群组id
channel_id: 频道id channel_id: 频道id
session: Session
异常: 异常:
IgnoredException: IgnoredException IgnoredException: IgnoredException
""" """
if limit_model := cls.cd_limit.get(module): if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session) await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.block_limit.get(module): if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session) await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.count_limit.get(module): if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session) await cls.__check(limit_model, user_id, group_id, channel_id)
@classmethod @classmethod
async def __check( async def __check(
@ -112,7 +115,6 @@ class LimitManage:
user_id: str, user_id: str,
group_id: str | None, group_id: str | None,
channel_id: str | None, channel_id: str | None,
session: Uninfo,
): ):
"""检测限制 """检测限制
@ -121,7 +123,6 @@ class LimitManage:
user_id: 用户id user_id: 用户id
group_id: 群组id group_id: 群组id
channel_id: 频道id channel_id: 频道id
session: Session
异常: 异常:
IgnoredException: IgnoredException IgnoredException: IgnoredException
@ -166,23 +167,14 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
plugin: PluginInfo plugin: PluginInfo
session: Uninfo session: Uninfo
""" """
user_id = session.user.id entity = get_entity_ids(session)
group_id = None if plugin.module not in LimitManager.add_module:
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
if not group_id:
group_id = channel_id
channel_id = None
if plugin.module not in LimitManage.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
status=True status=True
).all() # type: ignore ).all() # type: ignore
for limit in limit_list: for limit in limit_list:
LimitManage.add_limit(limit) LimitManager.add_limit(limit)
if user_id: if entity.user_id:
await LimitManage.check(plugin.module, user_id, group_id, channel_id, session) await LimitManager.check(
plugin.module, entity.user_id, entity.group_id, entity.channel_id
)

View File

@ -0,0 +1,35 @@
import nonebot
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from .exception import SkipPluginException
Config.add_plugin_config(
"hook",
"FILTER_BOT",
True,
help="过滤当前连接bot防止bot互相调用",
default_value=True,
type=bool,
)
def bot_filter(session: Uninfo):
"""过滤bot调用bot
参数:
session: Uninfo
异常:
SkipPluginException: bot互相调用
"""
if not Config.get_config("hook", "FILTER_BOT"):
return
bot_ids = list(nonebot.get_bots().keys())
if session.user.id == session.self_id:
return
if session.user.id in bot_ids:
raise SkipPluginException(
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
)

View File

@ -27,6 +27,7 @@ from .auth.auth_cost import auth_cost
from .auth.auth_group import auth_group from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManage, auth_limit from .auth.auth_limit import LimitManage, auth_limit
from .auth.auth_plugin import auth_plugin from .auth.auth_plugin import auth_plugin
from .auth.bot_filter import bot_filter
from .auth.config import LOGGER_COMMAND from .auth.config import LOGGER_COMMAND
from .auth.exception import ( from .auth.exception import (
IsSuperuserException, IsSuperuserException,
@ -152,6 +153,7 @@ async def auth(
cost_gold = await get_plugin_cost(bot, user, plugin, session) cost_gold = await get_plugin_cost(bot, user, plugin, session)
await asyncio.gather( await asyncio.gather(
*[ *[
bot_filter(session),
auth_ban(matcher, bot, session), auth_ban(matcher, bot, session),
auth_bot(plugin, bot.self_id), auth_bot(plugin, bot.self_id),
auth_group(plugin, entity, message), auth_group(plugin, entity, message),

View File

@ -8,6 +8,7 @@ from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
@ -38,16 +39,41 @@ def default_cleanup_expired(cache_data: CacheData) -> list[str]:
return expire_key return expire_key
def default_cleanup_expired_1(cache_data: CacheData) -> list[str]:
"""默认清理列表过期cache方法"""
if not cache_data.data:
return []
now = time.time()
expire_key = []
for k, t in list(cache_data.expire_data.items()):
if t < now:
expire_key.append(k)
cache_data.expire_data.pop(k)
if expire_key:
cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key]
return expire_key
def default_with_expiration( def default_with_expiration(
data: dict[str, Any], expire_data: dict[str, int], expire: int data: dict[str, Any], expire_data: dict[str, int], expire: int
): ):
"""默认更新期时间cache方法""" """默认更新期时间cache方法"""
if not data: if not data:
return {} return {}
keys = {k for k in data if k not in expire_data} keys = {k for k in data if k not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {} return {k: time.time() + expire for k in keys} if keys else {}
def default_with_expiration_1(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新过期时间cache方法"""
if not data:
return {}
keys = {repr(k) for k in data if repr(k) not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
@CacheRoot.new(CacheType.PLUGINS) @CacheRoot.new(CacheType.PLUGINS)
async def _(): async def _():
data_list = await PluginInfo.get_plugins() data_list = await PluginInfo.get_plugins()
@ -75,7 +101,7 @@ async def _(cache_data: CacheData, module: str):
@CacheRoot.with_refresh(CacheType.PLUGINS) @CacheRoot.with_refresh(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo]): async def _(data: dict[str, PluginInfo]):
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True) plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
for plugin in plugins: for plugin in plugins:
data[plugin.module] = plugin data[plugin.module] = plugin
@ -119,7 +145,7 @@ async def _(cache_data: CacheData, group_id: str):
async def _(data: dict[str, GroupConsole]): async def _(data: dict[str, GroupConsole]):
groups = await GroupConsole.filter( groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True group_id__in=data.keys(), channel_id__isnull=True
) ).all()
for group in groups: for group in groups:
data[group.group_id] = group data[group.group_id] = group
@ -161,7 +187,7 @@ async def _(cache_data: CacheData, bot_id: str):
@CacheRoot.with_refresh(CacheType.BOT) @CacheRoot.with_refresh(CacheType.BOT)
async def _(data: dict[str, BotConsole]): async def _(data: dict[str, BotConsole]):
bots = await BotConsole.filter(bot_id__in=data.keys()) bots = await BotConsole.filter(bot_id__in=data.keys()).all()
for bot in bots: for bot in bots:
data[bot.bot_id] = bot data[bot.bot_id] = bot
@ -203,7 +229,7 @@ async def _(cache_data: CacheData, user_id: str):
@CacheRoot.with_refresh(CacheType.USERS) @CacheRoot.with_refresh(CacheType.USERS)
async def _(data: dict[str, UserConsole]): async def _(data: dict[str, UserConsole]):
users = await UserConsole.filter(user_id__in=data.keys()) users = await UserConsole.filter(user_id__in=data.keys()).all()
for user in users: for user in users:
data[user.user_id] = user data[user.user_id] = user
@ -240,7 +266,17 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
] ]
@CacheRoot.new(CacheType.BAN, False, 5) @CacheRoot.with_expiration(CacheType.LEVEL)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration_1(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.LEVEL)
def _(cache_data: CacheData):
return default_cleanup_expired_1(cache_data)
@CacheRoot.new(CacheType.BAN, False)
async def _(): async def _():
return await BanConsole.all() return await BanConsole.all()
@ -268,3 +304,63 @@ async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = N
if not data.user_id and data.group_id == group_id if not data.user_id and data.group_id == group_id
] ]
return None return None
@CacheRoot.with_expiration(CacheType.BAN)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration_1(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.BAN)
def _(cache_data: CacheData):
return default_cleanup_expired_1(cache_data)
@CacheRoot.new(CacheType.LIMIT)
async def _():
data_list = await PluginLimit.filter(status=True).all()
result_data = {}
for data in data_list:
if not result_data.get(data.module):
result_data[data.module] = []
result_data[data.module].append(data)
return result_data
@CacheRoot.updater(CacheType.LIMIT)
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
if value:
data[key] = value
elif limits := await PluginLimit.filter(module=key, status=True):
data[key] = limits
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
async def _(cache_data: CacheData, module: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(module, None)
if not result:
result = await PluginLimit.filter(module=module, status=True)
if result:
cache_data.data[module] = result
return result
@CacheRoot.with_refresh(CacheType.LIMIT)
async def _(data: dict[str, list[PluginLimit]]):
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
data.clear()
for limit in limits:
if not data.get(limit.module):
data[limit.module] = []
data[limit.module].append(limit)
@CacheRoot.with_expiration(CacheType.LIMIT)
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.LIMIT)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)

View File

@ -33,7 +33,7 @@ def validate_name(func: Callable):
def wrapper(self, name: str, *args, **kwargs): def wrapper(self, name: str, *args, **kwargs):
_name = name.upper() _name = name.upper()
if _name not in CacheManage._data: if _name not in CacheManager._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...") raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
return func(self, _name, *args, **kwargs) return func(self, _name, *args, **kwargs)
@ -190,7 +190,7 @@ class CacheData(BaseModel):
) )
class CacheManage: class CacheManager:
"""全局缓存管理,减少数据库与网络请求查询次数 """全局缓存管理,减少数据库与网络请求查询次数
@ -324,7 +324,7 @@ class CacheManage:
await cache.update(key, value, *args, **kwargs) await cache.update(key, value, *args, **kwargs)
CacheRoot = CacheManage() CacheRoot = CacheManager()
class Cache(Generic[T]): class Cache(Generic[T]):

View File

@ -18,6 +18,8 @@ class CacheType(StrEnum):
"""全局bot信息""" """全局bot信息"""
LEVEL = "GLOBAL_USER_LEVEL" LEVEL = "GLOBAL_USER_LEVEL"
"""用户权限""" """用户权限"""
LIMIT = "GLOBAL_LIMIT"
"""插件限制"""
class DbLockType(StrEnum): class DbLockType(StrEnum):