添加插件limit缓存

This commit is contained in:
HibiKier 2025-04-11 16:46:24 +08:00
parent 43f1e48c14
commit 8865ab0d58
7 changed files with 163 additions and 35 deletions

View File

@ -14,7 +14,7 @@ from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
from .utils import send_message
from .utils import freq, send_message
Config.add_plugin_config(
"hook",
@ -131,6 +131,7 @@ async def user_handle(
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
):
await send_message(
session,

View File

@ -8,7 +8,12 @@ from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
from zhenxun.utils.utils import (
CountLimiter,
FreqLimiter,
UserBlockLimiter,
get_entity_ids,
)
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
@ -22,7 +27,7 @@ class Limit(BaseModel):
arbitrary_types_allowed = True
class LimitManage:
class LimitManager:
add_module: ClassVar[list] = []
cd_limit: ClassVar[dict[str, Limit]] = {}
@ -84,7 +89,6 @@ class LimitManage:
user_id: str,
group_id: str | None,
channel_id: str | None,
session: Uninfo,
):
"""检测限制
@ -93,17 +97,16 @@ class LimitManage:
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
"""
if limit_model := cls.cd_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.block_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
await cls.__check(limit_model, user_id, group_id, channel_id)
if limit_model := cls.count_limit.get(module):
await cls.__check(limit_model, user_id, group_id, channel_id, session)
await cls.__check(limit_model, user_id, group_id, channel_id)
@classmethod
async def __check(
@ -112,7 +115,6 @@ class LimitManage:
user_id: str,
group_id: str | None,
channel_id: str | None,
session: Uninfo,
):
"""检测限制
@ -121,7 +123,6 @@ class LimitManage:
user_id: 用户id
group_id: 群组id
channel_id: 频道id
session: Session
异常:
IgnoredException: IgnoredException
@ -166,23 +167,14 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
plugin: PluginInfo
session: Uninfo
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
if not group_id:
group_id = channel_id
channel_id = None
if plugin.module not in LimitManage.add_module:
entity = get_entity_ids(session)
if plugin.module not in LimitManager.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
status=True
).all() # type: ignore
for limit in limit_list:
LimitManage.add_limit(limit)
if user_id:
await LimitManage.check(plugin.module, user_id, group_id, channel_id, session)
LimitManager.add_limit(limit)
if entity.user_id:
await LimitManager.check(
plugin.module, entity.user_id, entity.group_id, entity.channel_id
)

View File

@ -0,0 +1,35 @@
import nonebot
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from .exception import SkipPluginException
Config.add_plugin_config(
"hook",
"FILTER_BOT",
True,
help="过滤当前连接bot防止bot互相调用",
default_value=True,
type=bool,
)
def bot_filter(session: Uninfo):
"""过滤bot调用bot
参数:
session: Uninfo
异常:
SkipPluginException: bot互相调用
"""
if not Config.get_config("hook", "FILTER_BOT"):
return
bot_ids = list(nonebot.get_bots().keys())
if session.user.id == session.self_id:
return
if session.user.id in bot_ids:
raise SkipPluginException(
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
)

View File

@ -27,6 +27,7 @@ from .auth.auth_cost import auth_cost
from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManage, auth_limit
from .auth.auth_plugin import auth_plugin
from .auth.bot_filter import bot_filter
from .auth.config import LOGGER_COMMAND
from .auth.exception import (
IsSuperuserException,
@ -152,6 +153,7 @@ async def auth(
cost_gold = await get_plugin_cost(bot, user, plugin, session)
await asyncio.gather(
*[
bot_filter(session),
auth_ban(matcher, bot, session),
auth_bot(plugin, bot.self_id),
auth_group(plugin, entity, message),

View File

@ -8,6 +8,7 @@ from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.utils.enum import CacheType
@ -38,16 +39,41 @@ def default_cleanup_expired(cache_data: CacheData) -> list[str]:
return expire_key
def default_cleanup_expired_1(cache_data: CacheData) -> list[str]:
"""默认清理列表过期cache方法"""
if not cache_data.data:
return []
now = time.time()
expire_key = []
for k, t in list(cache_data.expire_data.items()):
if t < now:
expire_key.append(k)
cache_data.expire_data.pop(k)
if expire_key:
cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key]
return expire_key
def default_with_expiration(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新期时间cache方法"""
"""默认更新期时间cache方法"""
if not data:
return {}
keys = {k for k in data if k not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
def default_with_expiration_1(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新过期时间cache方法"""
if not data:
return {}
keys = {repr(k) for k in data if repr(k) not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
@CacheRoot.new(CacheType.PLUGINS)
async def _():
data_list = await PluginInfo.get_plugins()
@ -75,7 +101,7 @@ async def _(cache_data: CacheData, module: str):
@CacheRoot.with_refresh(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo]):
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True)
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
for plugin in plugins:
data[plugin.module] = plugin
@ -119,7 +145,7 @@ async def _(cache_data: CacheData, group_id: str):
async def _(data: dict[str, GroupConsole]):
groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True
)
).all()
for group in groups:
data[group.group_id] = group
@ -161,7 +187,7 @@ async def _(cache_data: CacheData, bot_id: str):
@CacheRoot.with_refresh(CacheType.BOT)
async def _(data: dict[str, BotConsole]):
bots = await BotConsole.filter(bot_id__in=data.keys())
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
for bot in bots:
data[bot.bot_id] = bot
@ -203,7 +229,7 @@ async def _(cache_data: CacheData, user_id: str):
@CacheRoot.with_refresh(CacheType.USERS)
async def _(data: dict[str, UserConsole]):
users = await UserConsole.filter(user_id__in=data.keys())
users = await UserConsole.filter(user_id__in=data.keys()).all()
for user in users:
data[user.user_id] = user
@ -240,7 +266,17 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
]
@CacheRoot.new(CacheType.BAN, False, 5)
@CacheRoot.with_expiration(CacheType.LEVEL)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration_1(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.LEVEL)
def _(cache_data: CacheData):
return default_cleanup_expired_1(cache_data)
@CacheRoot.new(CacheType.BAN, False)
async def _():
return await BanConsole.all()
@ -268,3 +304,63 @@ async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = N
if not data.user_id and data.group_id == group_id
]
return None
@CacheRoot.with_expiration(CacheType.BAN)
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
return default_with_expiration_1(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.BAN)
def _(cache_data: CacheData):
return default_cleanup_expired_1(cache_data)
@CacheRoot.new(CacheType.LIMIT)
async def _():
data_list = await PluginLimit.filter(status=True).all()
result_data = {}
for data in data_list:
if not result_data.get(data.module):
result_data[data.module] = []
result_data[data.module].append(data)
return result_data
@CacheRoot.updater(CacheType.LIMIT)
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
if value:
data[key] = value
elif limits := await PluginLimit.filter(module=key, status=True):
data[key] = limits
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
async def _(cache_data: CacheData, module: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(module, None)
if not result:
result = await PluginLimit.filter(module=module, status=True)
if result:
cache_data.data[module] = result
return result
@CacheRoot.with_refresh(CacheType.LIMIT)
async def _(data: dict[str, list[PluginLimit]]):
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
data.clear()
for limit in limits:
if not data.get(limit.module):
data[limit.module] = []
data[limit.module].append(limit)
@CacheRoot.with_expiration(CacheType.LIMIT)
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
return default_with_expiration(data, expire_data, expire)
@CacheRoot.cleanup_expired(CacheType.LIMIT)
def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)

View File

@ -33,7 +33,7 @@ def validate_name(func: Callable):
def wrapper(self, name: str, *args, **kwargs):
_name = name.upper()
if _name not in CacheManage._data:
if _name not in CacheManager._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
return func(self, _name, *args, **kwargs)
@ -190,7 +190,7 @@ class CacheData(BaseModel):
)
class CacheManage:
class CacheManager:
"""全局缓存管理,减少数据库与网络请求查询次数
@ -324,7 +324,7 @@ class CacheManage:
await cache.update(key, value, *args, **kwargs)
CacheRoot = CacheManage()
CacheRoot = CacheManager()
class Cache(Generic[T]):

View File

@ -18,6 +18,8 @@ class CacheType(StrEnum):
"""全局bot信息"""
LEVEL = "GLOBAL_USER_LEVEL"
"""用户权限"""
LIMIT = "GLOBAL_LIMIT"
"""插件限制"""
class DbLockType(StrEnum):