mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 添加插件limit缓存
This commit is contained in:
parent
43f1e48c14
commit
8865ab0d58
@ -14,7 +14,7 @@ from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
|||||||
|
|
||||||
from .config import LOGGER_COMMAND
|
from .config import LOGGER_COMMAND
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
from .utils import send_message
|
from .utils import freq, send_message
|
||||||
|
|
||||||
Config.add_plugin_config(
|
Config.add_plugin_config(
|
||||||
"hook",
|
"hook",
|
||||||
@ -131,6 +131,7 @@ async def user_handle(
|
|||||||
# and not db_plugin.ignore_prompt
|
# and not db_plugin.ignore_prompt
|
||||||
and time != -1
|
and time != -1
|
||||||
and ban_result
|
and ban_result
|
||||||
|
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||||
):
|
):
|
||||||
await send_message(
|
await send_message(
|
||||||
session,
|
session,
|
||||||
|
|||||||
@ -8,7 +8,12 @@ from zhenxun.models.plugin_limit import PluginLimit
|
|||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.message import MessageUtils
|
||||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
from zhenxun.utils.utils import (
|
||||||
|
CountLimiter,
|
||||||
|
FreqLimiter,
|
||||||
|
UserBlockLimiter,
|
||||||
|
get_entity_ids,
|
||||||
|
)
|
||||||
|
|
||||||
from .config import LOGGER_COMMAND
|
from .config import LOGGER_COMMAND
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
@ -22,7 +27,7 @@ class Limit(BaseModel):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class LimitManage:
|
class LimitManager:
|
||||||
add_module: ClassVar[list] = []
|
add_module: ClassVar[list] = []
|
||||||
|
|
||||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
@ -84,7 +89,6 @@ class LimitManage:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
group_id: str | None,
|
group_id: str | None,
|
||||||
channel_id: str | None,
|
channel_id: str | None,
|
||||||
session: Uninfo,
|
|
||||||
):
|
):
|
||||||
"""检测限制
|
"""检测限制
|
||||||
|
|
||||||
@ -93,17 +97,16 @@ class LimitManage:
|
|||||||
user_id: 用户id
|
user_id: 用户id
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
channel_id: 频道id
|
channel_id: 频道id
|
||||||
session: Session
|
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
IgnoredException: IgnoredException
|
IgnoredException: IgnoredException
|
||||||
"""
|
"""
|
||||||
if limit_model := cls.cd_limit.get(module):
|
if limit_model := cls.cd_limit.get(module):
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||||
if limit_model := cls.block_limit.get(module):
|
if limit_model := cls.block_limit.get(module):
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||||
if limit_model := cls.count_limit.get(module):
|
if limit_model := cls.count_limit.get(module):
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def __check(
|
async def __check(
|
||||||
@ -112,7 +115,6 @@ class LimitManage:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
group_id: str | None,
|
group_id: str | None,
|
||||||
channel_id: str | None,
|
channel_id: str | None,
|
||||||
session: Uninfo,
|
|
||||||
):
|
):
|
||||||
"""检测限制
|
"""检测限制
|
||||||
|
|
||||||
@ -121,7 +123,6 @@ class LimitManage:
|
|||||||
user_id: 用户id
|
user_id: 用户id
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
channel_id: 频道id
|
channel_id: 频道id
|
||||||
session: Session
|
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
IgnoredException: IgnoredException
|
IgnoredException: IgnoredException
|
||||||
@ -166,23 +167,14 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
|||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
session: Uninfo
|
session: Uninfo
|
||||||
"""
|
"""
|
||||||
user_id = session.user.id
|
entity = get_entity_ids(session)
|
||||||
group_id = None
|
if plugin.module not in LimitManager.add_module:
|
||||||
channel_id = None
|
|
||||||
if session.group:
|
|
||||||
if session.group.parent:
|
|
||||||
group_id = session.group.parent.id
|
|
||||||
channel_id = session.group.id
|
|
||||||
else:
|
|
||||||
group_id = session.group.id
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if plugin.module not in LimitManage.add_module:
|
|
||||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||||
status=True
|
status=True
|
||||||
).all() # type: ignore
|
).all() # type: ignore
|
||||||
for limit in limit_list:
|
for limit in limit_list:
|
||||||
LimitManage.add_limit(limit)
|
LimitManager.add_limit(limit)
|
||||||
if user_id:
|
if entity.user_id:
|
||||||
await LimitManage.check(plugin.module, user_id, group_id, channel_id, session)
|
await LimitManager.check(
|
||||||
|
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
||||||
|
)
|
||||||
|
|||||||
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import nonebot
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
Config.add_plugin_config(
|
||||||
|
"hook",
|
||||||
|
"FILTER_BOT",
|
||||||
|
True,
|
||||||
|
help="过滤当前连接bot(防止bot互相调用)",
|
||||||
|
default_value=True,
|
||||||
|
type=bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bot_filter(session: Uninfo):
|
||||||
|
"""过滤bot调用bot
|
||||||
|
|
||||||
|
参数:
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: bot互相调用
|
||||||
|
"""
|
||||||
|
if not Config.get_config("hook", "FILTER_BOT"):
|
||||||
|
return
|
||||||
|
bot_ids = list(nonebot.get_bots().keys())
|
||||||
|
if session.user.id == session.self_id:
|
||||||
|
return
|
||||||
|
if session.user.id in bot_ids:
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
|
||||||
|
)
|
||||||
@ -27,6 +27,7 @@ from .auth.auth_cost import auth_cost
|
|||||||
from .auth.auth_group import auth_group
|
from .auth.auth_group import auth_group
|
||||||
from .auth.auth_limit import LimitManage, auth_limit
|
from .auth.auth_limit import LimitManage, auth_limit
|
||||||
from .auth.auth_plugin import auth_plugin
|
from .auth.auth_plugin import auth_plugin
|
||||||
|
from .auth.bot_filter import bot_filter
|
||||||
from .auth.config import LOGGER_COMMAND
|
from .auth.config import LOGGER_COMMAND
|
||||||
from .auth.exception import (
|
from .auth.exception import (
|
||||||
IsSuperuserException,
|
IsSuperuserException,
|
||||||
@ -152,6 +153,7 @@ async def auth(
|
|||||||
cost_gold = await get_plugin_cost(bot, user, plugin, session)
|
cost_gold = await get_plugin_cost(bot, user, plugin, session)
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
|
bot_filter(session),
|
||||||
auth_ban(matcher, bot, session),
|
auth_ban(matcher, bot, session),
|
||||||
auth_bot(plugin, bot.self_id),
|
auth_bot(plugin, bot.self_id),
|
||||||
auth_group(plugin, entity, message),
|
auth_group(plugin, entity, message),
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from zhenxun.models.bot_console import BotConsole
|
|||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.level_user import LevelUser
|
from zhenxun.models.level_user import LevelUser
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.plugin_limit import PluginLimit
|
||||||
from zhenxun.models.user_console import UserConsole
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.cache import CacheData, CacheRoot
|
from zhenxun.services.cache import CacheData, CacheRoot
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
@ -38,16 +39,41 @@ def default_cleanup_expired(cache_data: CacheData) -> list[str]:
|
|||||||
return expire_key
|
return expire_key
|
||||||
|
|
||||||
|
|
||||||
|
def default_cleanup_expired_1(cache_data: CacheData) -> list[str]:
|
||||||
|
"""默认清理列表过期cache方法"""
|
||||||
|
if not cache_data.data:
|
||||||
|
return []
|
||||||
|
now = time.time()
|
||||||
|
expire_key = []
|
||||||
|
for k, t in list(cache_data.expire_data.items()):
|
||||||
|
if t < now:
|
||||||
|
expire_key.append(k)
|
||||||
|
cache_data.expire_data.pop(k)
|
||||||
|
if expire_key:
|
||||||
|
cache_data.data = [k for k in cache_data.data if repr(k) not in expire_key]
|
||||||
|
return expire_key
|
||||||
|
|
||||||
|
|
||||||
def default_with_expiration(
|
def default_with_expiration(
|
||||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
||||||
):
|
):
|
||||||
"""默认更新期时间cache方法"""
|
"""默认更新过期时间cache方法"""
|
||||||
if not data:
|
if not data:
|
||||||
return {}
|
return {}
|
||||||
keys = {k for k in data if k not in expire_data}
|
keys = {k for k in data if k not in expire_data}
|
||||||
return {k: time.time() + expire for k in keys} if keys else {}
|
return {k: time.time() + expire for k in keys} if keys else {}
|
||||||
|
|
||||||
|
|
||||||
|
def default_with_expiration_1(
|
||||||
|
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
||||||
|
):
|
||||||
|
"""默认更新过期时间cache方法"""
|
||||||
|
if not data:
|
||||||
|
return {}
|
||||||
|
keys = {repr(k) for k in data if repr(k) not in expire_data}
|
||||||
|
return {k: time.time() + expire for k in keys} if keys else {}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.PLUGINS)
|
@CacheRoot.new(CacheType.PLUGINS)
|
||||||
async def _():
|
async def _():
|
||||||
data_list = await PluginInfo.get_plugins()
|
data_list = await PluginInfo.get_plugins()
|
||||||
@ -75,7 +101,7 @@ async def _(cache_data: CacheData, module: str):
|
|||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
||||||
async def _(data: dict[str, PluginInfo]):
|
async def _(data: dict[str, PluginInfo]):
|
||||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True)
|
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
||||||
for plugin in plugins:
|
for plugin in plugins:
|
||||||
data[plugin.module] = plugin
|
data[plugin.module] = plugin
|
||||||
|
|
||||||
@ -119,7 +145,7 @@ async def _(cache_data: CacheData, group_id: str):
|
|||||||
async def _(data: dict[str, GroupConsole]):
|
async def _(data: dict[str, GroupConsole]):
|
||||||
groups = await GroupConsole.filter(
|
groups = await GroupConsole.filter(
|
||||||
group_id__in=data.keys(), channel_id__isnull=True
|
group_id__in=data.keys(), channel_id__isnull=True
|
||||||
)
|
).all()
|
||||||
for group in groups:
|
for group in groups:
|
||||||
data[group.group_id] = group
|
data[group.group_id] = group
|
||||||
|
|
||||||
@ -161,7 +187,7 @@ async def _(cache_data: CacheData, bot_id: str):
|
|||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.BOT)
|
@CacheRoot.with_refresh(CacheType.BOT)
|
||||||
async def _(data: dict[str, BotConsole]):
|
async def _(data: dict[str, BotConsole]):
|
||||||
bots = await BotConsole.filter(bot_id__in=data.keys())
|
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
||||||
for bot in bots:
|
for bot in bots:
|
||||||
data[bot.bot_id] = bot
|
data[bot.bot_id] = bot
|
||||||
|
|
||||||
@ -203,7 +229,7 @@ async def _(cache_data: CacheData, user_id: str):
|
|||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.USERS)
|
@CacheRoot.with_refresh(CacheType.USERS)
|
||||||
async def _(data: dict[str, UserConsole]):
|
async def _(data: dict[str, UserConsole]):
|
||||||
users = await UserConsole.filter(user_id__in=data.keys())
|
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
||||||
for user in users:
|
for user in users:
|
||||||
data[user.user_id] = user
|
data[user.user_id] = user
|
||||||
|
|
||||||
@ -240,7 +266,17 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BAN, False, 5)
|
@CacheRoot.with_expiration(CacheType.LEVEL)
|
||||||
|
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
||||||
|
return default_with_expiration_1(data, expire_data, expire)
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.cleanup_expired(CacheType.LEVEL)
|
||||||
|
def _(cache_data: CacheData):
|
||||||
|
return default_cleanup_expired_1(cache_data)
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.BAN, False)
|
||||||
async def _():
|
async def _():
|
||||||
return await BanConsole.all()
|
return await BanConsole.all()
|
||||||
|
|
||||||
@ -268,3 +304,63 @@ async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = N
|
|||||||
if not data.user_id and data.group_id == group_id
|
if not data.user_id and data.group_id == group_id
|
||||||
]
|
]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.with_expiration(CacheType.BAN)
|
||||||
|
def _(data: dict[str, UserConsole], expire_data: dict[str, int], expire: int):
|
||||||
|
return default_with_expiration_1(data, expire_data, expire)
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.cleanup_expired(CacheType.BAN)
|
||||||
|
def _(cache_data: CacheData):
|
||||||
|
return default_cleanup_expired_1(cache_data)
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.LIMIT)
|
||||||
|
async def _():
|
||||||
|
data_list = await PluginLimit.filter(status=True).all()
|
||||||
|
result_data = {}
|
||||||
|
for data in data_list:
|
||||||
|
if not result_data.get(data.module):
|
||||||
|
result_data[data.module] = []
|
||||||
|
result_data[data.module].append(data)
|
||||||
|
return result_data
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.updater(CacheType.LIMIT)
|
||||||
|
async def _(data: dict[str, list[PluginLimit]], key: str, value: Any):
|
||||||
|
if value:
|
||||||
|
data[key] = value
|
||||||
|
elif limits := await PluginLimit.filter(module=key, status=True):
|
||||||
|
data[key] = limits
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||||
|
async def _(cache_data: CacheData, module: str):
|
||||||
|
cache_data.data = cache_data.data or {}
|
||||||
|
result = cache_data.data.get(module, None)
|
||||||
|
if not result:
|
||||||
|
result = await PluginLimit.filter(module=module, status=True)
|
||||||
|
if result:
|
||||||
|
cache_data.data[module] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.with_refresh(CacheType.LIMIT)
|
||||||
|
async def _(data: dict[str, list[PluginLimit]]):
|
||||||
|
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||||
|
data.clear()
|
||||||
|
for limit in limits:
|
||||||
|
if not data.get(limit.module):
|
||||||
|
data[limit.module] = []
|
||||||
|
data[limit.module].append(limit)
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.with_expiration(CacheType.LIMIT)
|
||||||
|
def _(data: dict[str, PluginInfo], expire_data: dict[str, int], expire: int):
|
||||||
|
return default_with_expiration(data, expire_data, expire)
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.cleanup_expired(CacheType.LIMIT)
|
||||||
|
def _(cache_data: CacheData):
|
||||||
|
return default_cleanup_expired(cache_data)
|
||||||
|
|||||||
@ -33,7 +33,7 @@ def validate_name(func: Callable):
|
|||||||
|
|
||||||
def wrapper(self, name: str, *args, **kwargs):
|
def wrapper(self, name: str, *args, **kwargs):
|
||||||
_name = name.upper()
|
_name = name.upper()
|
||||||
if _name not in CacheManage._data:
|
if _name not in CacheManager._data:
|
||||||
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
|
raise DbCacheException(f"DbCache 缓存数据 {name} 不存在...")
|
||||||
return func(self, _name, *args, **kwargs)
|
return func(self, _name, *args, **kwargs)
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ class CacheData(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CacheManage:
|
class CacheManager:
|
||||||
"""全局缓存管理,减少数据库与网络请求查询次数
|
"""全局缓存管理,减少数据库与网络请求查询次数
|
||||||
|
|
||||||
|
|
||||||
@ -324,7 +324,7 @@ class CacheManage:
|
|||||||
await cache.update(key, value, *args, **kwargs)
|
await cache.update(key, value, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
CacheRoot = CacheManage()
|
CacheRoot = CacheManager()
|
||||||
|
|
||||||
|
|
||||||
class Cache(Generic[T]):
|
class Cache(Generic[T]):
|
||||||
|
|||||||
@ -18,6 +18,8 @@ class CacheType(StrEnum):
|
|||||||
"""全局bot信息"""
|
"""全局bot信息"""
|
||||||
LEVEL = "GLOBAL_USER_LEVEL"
|
LEVEL = "GLOBAL_USER_LEVEL"
|
||||||
"""用户权限"""
|
"""用户权限"""
|
||||||
|
LIMIT = "GLOBAL_LIMIT"
|
||||||
|
"""插件限制"""
|
||||||
|
|
||||||
|
|
||||||
class DbLockType(StrEnum):
|
class DbLockType(StrEnum):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user