mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 添加插件limit缓存
This commit is contained in:
parent
43f1e48c14
commit
8865ab0d58
@ -14,7 +14,7 @@ from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
from .utils import freq, send_message
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
@ -131,6 +131,7 @@ async def user_handle(
|
||||
# and not db_plugin.ignore_prompt
|
||||
and time != -1
|
||||
and ban_result
|
||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||
):
|
||||
await send_message(
|
||||
session,
|
||||
|
||||
@ -8,7 +8,12 @@ from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||
from zhenxun.utils.utils import (
|
||||
CountLimiter,
|
||||
FreqLimiter,
|
||||
UserBlockLimiter,
|
||||
get_entity_ids,
|
||||
)
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
from .exception import SkipPluginException
|
||||
@ -22,7 +27,7 @@ class Limit(BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManage:
|
||||
class LimitManager:
|
||||
add_module: ClassVar[list] = []
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
@ -84,7 +89,6 @@ class LimitManage:
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: Uninfo,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
@ -93,17 +97,16 @@ class LimitManage:
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
@ -112,7 +115,6 @@ class LimitManage:
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: Uninfo,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
@ -121,7 +123,6 @@ class LimitManage:
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
@ -166,23 +167,14 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
user_id = session.user.id
|
||||
group_id = None
|
||||
channel_id = None
|
||||
if session.group:
|
||||
if session.group.parent:
|
||||
group_id = session.group.parent.id
|
||||
channel_id = session.group.id
|
||||
else:
|
||||
group_id = session.group.id
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if plugin.module not in LimitManage.add_module:
|
||||
entity = get_entity_ids(session)
|
||||
if plugin.module not in LimitManager.add_module:
|
||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||
status=True
|
||||
).all() # type: ignore
|
||||
for limit in limit_list:
|
||||
LimitManage.add_limit(limit)
|
||||
if user_id:
|
||||
await LimitManage.check(plugin.module, user_id, group_id, channel_id, session)
|
||||
LimitManager.add_limit(limit)
|
||||
if entity.user_id:
|
||||
await LimitManager.check(
|
||||
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
||||
)
|
||||
|
||||
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
@ -0,0 +1,35 @@
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
|
||||
from .exception import SkipPluginException
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"FILTER_BOT",
|
||||
True,
|
||||
help="过滤当前连接bot(防止bot互相调用)",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
|
||||
|
||||
def bot_filter(session: Uninfo):
|
||||
"""过滤bot调用bot
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: bot互相调用
|
||||
"""
|
||||
if not Config.get_config("hook", "FILTER_BOT"):
|
||||
return
|
||||
bot_ids = list(nonebot.get_bots().keys())
|
||||
if session.user.id == session.self_id:
|
||||
return
|
||||
if session.user.id in bot_ids:
|
||||
raise SkipPluginException(
|
||||
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
|
||||
)
|
||||
@ -27,6 +27,7 @@ from .auth.auth_cost import auth_cost
|
||||
from .auth.auth_group import auth_group
|
||||
from .auth.auth_limit import LimitManage, auth_limit
|
||||
from .auth.auth_plugin import auth_plugin
|
||||
from .auth.bot_filter import bot_filter
|
||||
from .auth.config import LOGGER_COMMAND
|
||||
from .auth.exception import (
|
||||
IsSuperuserException,
|
||||
@ -152,6 +153,7 @@ async def auth(
|
||||
cost_gold = await get_plugin_cost(bot, user, plugin, session)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
bot_filter(session),
|
||||
auth_ban(matcher, bot, session),
|
||||
auth_bot(plugin, bot.self_id),
|
||||
auth_group(plugin, entity, message),
|
||||
|
||||
@ -8,6 +8,7 @@ from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheData, CacheRoot
|
||||
from zhenxun.utils.enum import CacheType
|
||||
@ -38,16 +39,41 @@ def default_cleanup_expired(cache_data: CacheData) -> list[str]:
|
||||
return expire_key
|
||||
|
||||
|
||||
def default_cleanup_expired_1(cache_data: CacheData) -> list[str]:
|
||||
"""默认清理列表过期cache方法"""
|
||||
if not cache_data.data:
|
||||
return []
|
||||
now = time.time()
|
||||
expire_key = []
|
||||
for k, t in list(cache_data.expire_data.items()):
|
||||
if t < now:
|
||||
expire_key.append(k)
|
||||
cache_data.expire_data.pop(k)
|
||||
if expire_key:
|
||||
cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key]
|
||||
return expire_key
|
||||
|
||||
|
||||
def default_with_expiration(
|
||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
||||
):
|
||||
"""默认更新期时间cache方法"""
|
||||
"""默认更新过期时间cache方法"""
|
||||
if not data:
|
||||
return {}
|
||||
keys = {k for k in data if k not in expire_data}
|
||||
return {k: time.time() + expire for k in keys} if keys else {}
|
||||
|
||||
|
||||
def default_with_expiration_1(
|
||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
||||
):
|
||||
"""默认更新过期时间cache方法"""
|
||||
if not data:
|
||||
return {}
|
||||
keys = {repr(k) for k in data if repr(k) not in expire_data}
|
||||
return {k: time.time() + expire for k in keys} if keys else {}
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.PLUGINS)
|
||||
async def _():
|
||||
data_list = await PluginInfo.get_plugins()
|
||||
@ -75,7 +101,7 @@ async def _(cache_data: CacheData, module: str):
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
||||
async def _(data: dict[str, PluginInfo]):
|
||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True)
|
||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
||||
for plugin in plugins:
|
||||
data[plugin.module] = plugin
|
||||
|
||||
@ -119,7 +145,7 @@ async def _(cache_data: CacheData, group_id: str):
|
||||
async def _(data: dict[str, GroupConsole]):
|
||||
groups = await GroupConsole.filter(
|
||||
group_id__in=data.keys(), channel_id__isnull=True
|
||||
)
|
||||
).all()
|
||||
for group in groups:
|
||||
data[group.group_id] = group
|
||||
|
||||
@ -161,7 +187,7 @@ async def _(cache_data: CacheData, bot_id: str):
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.BOT)
|
||||
async def _(data: dict[str, BotConsole]):
|
||||
bots = await BotConsole.filter(bot_id__in=data.keys())
|
||||
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
||||
for bot in bots:
|
||||
data[bot.bot_id] = bot
|
||||
|
||||
@ -203,7 +229,7 @@ async def _(cache_data: CacheData, user_id: str):
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.USERS)
|
||||
async def _(data: dict[str, UserConsole]):
|
||||
users = await UserConsole.filter(user_id__in=data.keys())
|
||||
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
||||
for user in users:
|
||||
data[user.user_id] = user
|
||||
|
||||
@ -240,7 +266,17 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
||||
]
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BAN, False, 5)
|
||||
@CacheRoot.with_expiration(CacheType.LEVEL)
|
||||
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration_1(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.LEVEL)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired_1(cache_data)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BAN, False)
|
||||
async def _():
|
||||
return await BanConsole.all()
|
||||
|
||||
@ -268,3 +304,63 @@ async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = N
|
||||
if not data.user_id and data.group_id == group_id
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
@CacheRoot.with_expiration(CacheType.BAN)
|
||||
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration_1(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.BAN)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired_1(cache_data)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.LIMIT)
|
||||
async def _():
|
||||
data_list = await PluginLimit.filter(status=True).all()
|
||||
result_data = {}
|
||||
for data in data_list:
|
||||
if not result_data.get(data.module):
|
||||
result_data[data.module] = []
|
||||
result_data[data.module].append(data)
|
||||
return result_data
|
||||
|
||||
|
||||
@CacheRoot.updater(CacheType.LIMIT)
|
||||
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
|
||||
if value:
|
||||
data[key] = value
|
||||
elif limits := await PluginLimit.filter(module=key, status=True):
|
||||
data[key] = limits
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||
async def _(cache_data: CacheData, module: str):
|
||||
cache_data.data = cache_data.data or {}
|
||||
result = cache_data.data.get(module, None)
|
||||
if not result:
|
||||
result = await PluginLimit.filter(module=module, status=True)
|
||||
if result:
|
||||
cache_data.data[module] = result
|
||||
return result
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.LIMIT)
|
||||
async def _(data: dict[str, list[PluginLimit]]):
|
||||
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||
data.clear()
|
||||
for limit in limits:
|
||||
if not data.get(limit.module):
|
||||
data[limit.module] = []
|
||||
data[limit.module].append(limit)
|
||||
|
||||
|
||||
@CacheRoot.with_expiration(CacheType.LIMIT)
|
||||
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
|
||||
return default_with_expiration(data, expire_data, expire)
|
||||
|
||||
|
||||
@CacheRoot.cleanup_expired(CacheType.LIMIT)
|
||||
def _(cache_data: CacheData):
|
||||
return default_cleanup_expired(cache_data)
|
||||
|
||||
@ -33,7 +33,7 @@ def validate_name(func: Callable):
|
||||
|
||||
def wrapper(self, name: str, *args, **kwargs):
|
||||
_name = name.upper()
|
||||
if _name not in CacheManage._data:
|
||||
if _name not in CacheManager._data:
|
||||
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
|
||||
return func(self, _name, *args, **kwargs)
|
||||
|
||||
@ -190,7 +190,7 @@ class CacheData(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CacheManage:
|
||||
class CacheManager:
|
||||
"""全局缓存管理,减少数据库与网络请求查询次数
|
||||
|
||||
|
||||
@ -324,7 +324,7 @@ class CacheManage:
|
||||
await cache.update(key, value, *args, **kwargs)
|
||||
|
||||
|
||||
CacheRoot = CacheManage()
|
||||
CacheRoot = CacheManager()
|
||||
|
||||
|
||||
class Cache(Generic[T]):
|
||||
|
||||
@ -18,6 +18,8 @@ class CacheType(StrEnum):
|
||||
"""全局bot信息"""
|
||||
LEVEL = "GLOBAL_USER_LEVEL"
|
||||
"""用户权限"""
|
||||
LIMIT = "GLOBAL_LIMIT"
|
||||
"""插件限制"""
|
||||
|
||||
|
||||
class DbLockType(StrEnum):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user