mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 构建缓存,hook使用缓存
This commit is contained in:
parent
45649bb29d
commit
521bcaceeb
@ -1,580 +0,0 @@
|
|||||||
from nonebot.adapters import Bot, Event
|
|
||||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
|
||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot.matcher import Matcher
|
|
||||||
from nonebot_plugin_alconna import At, UniMsg
|
|
||||||
from nonebot_plugin_session import EventSession
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from tortoise.exceptions import IntegrityError
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
|
||||||
from zhenxun.models.bot_console import BotConsole
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
|
||||||
from zhenxun.models.level_user import LevelUser
|
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
|
||||||
from zhenxun.models.plugin_limit import PluginLimit
|
|
||||||
from zhenxun.models.user_console import UserConsole
|
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.cache_utils import Cache
|
|
||||||
from zhenxun.utils.enum import (
|
|
||||||
BlockType,
|
|
||||||
CacheType,
|
|
||||||
GoldHandle,
|
|
||||||
LimitWatchType,
|
|
||||||
PluginLimitType,
|
|
||||||
PluginType,
|
|
||||||
)
|
|
||||||
from zhenxun.utils.exception import InsufficientGold
|
|
||||||
from zhenxun.utils.message import MessageUtils
|
|
||||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
|
||||||
|
|
||||||
base_config = Config.get("hook")
|
|
||||||
|
|
||||||
|
|
||||||
class Limit(BaseModel):
|
|
||||||
limit: PluginLimit
|
|
||||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class LimitManage:
|
|
||||||
add_module = [] # noqa: RUF012
|
|
||||||
|
|
||||||
cd_limit: dict[str, Limit] = {} # noqa: RUF012
|
|
||||||
block_limit: dict[str, Limit] = {} # noqa: RUF012
|
|
||||||
count_limit: dict[str, Limit] = {} # noqa: RUF012
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def add_limit(cls, limit: PluginLimit):
|
|
||||||
"""添加限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
limit: PluginLimit
|
|
||||||
"""
|
|
||||||
if limit.module not in cls.add_module:
|
|
||||||
cls.add_module.append(limit.module)
|
|
||||||
if limit.limit_type == PluginLimitType.BLOCK:
|
|
||||||
cls.block_limit[limit.module] = Limit(
|
|
||||||
limit=limit, limiter=UserBlockLimiter()
|
|
||||||
)
|
|
||||||
elif limit.limit_type == PluginLimitType.CD:
|
|
||||||
cls.cd_limit[limit.module] = Limit(
|
|
||||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
|
||||||
)
|
|
||||||
elif limit.limit_type == PluginLimitType.COUNT:
|
|
||||||
cls.count_limit[limit.module] = Limit(
|
|
||||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def unblock(
|
|
||||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
|
||||||
):
|
|
||||||
"""解除插件block
|
|
||||||
|
|
||||||
参数:
|
|
||||||
module: 模块名
|
|
||||||
user_id: 用户id
|
|
||||||
group_id: 群组id
|
|
||||||
channel_id: 频道id
|
|
||||||
"""
|
|
||||||
if limit_model := cls.block_limit.get(module):
|
|
||||||
limit = limit_model.limit
|
|
||||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
|
||||||
key_type = user_id
|
|
||||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
|
||||||
key_type = channel_id or group_id
|
|
||||||
logger.debug(
|
|
||||||
f"解除对象: {key_type} 的block限制",
|
|
||||||
"AuthChecker",
|
|
||||||
session=user_id,
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
limiter.set_false(key_type)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def check(
|
|
||||||
cls,
|
|
||||||
module: str,
|
|
||||||
user_id: str,
|
|
||||||
group_id: str | None,
|
|
||||||
channel_id: str | None,
|
|
||||||
session: EventSession,
|
|
||||||
):
|
|
||||||
"""检测限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
module: 模块名
|
|
||||||
user_id: 用户id
|
|
||||||
group_id: 群组id
|
|
||||||
channel_id: 频道id
|
|
||||||
session: Session
|
|
||||||
|
|
||||||
异常:
|
|
||||||
IgnoredException: IgnoredException
|
|
||||||
"""
|
|
||||||
if limit_model := cls.cd_limit.get(module):
|
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
|
||||||
if limit_model := cls.block_limit.get(module):
|
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
|
||||||
if limit_model := cls.count_limit.get(module):
|
|
||||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def __check(
|
|
||||||
cls,
|
|
||||||
limit_model: Limit | None,
|
|
||||||
user_id: str,
|
|
||||||
group_id: str | None,
|
|
||||||
channel_id: str | None,
|
|
||||||
session: EventSession,
|
|
||||||
):
|
|
||||||
"""检测限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
limit_model: Limit
|
|
||||||
user_id: 用户id
|
|
||||||
group_id: 群组id
|
|
||||||
channel_id: 频道id
|
|
||||||
session: Session
|
|
||||||
|
|
||||||
异常:
|
|
||||||
IgnoredException: IgnoredException
|
|
||||||
"""
|
|
||||||
if not limit_model:
|
|
||||||
return
|
|
||||||
limit = limit_model.limit
|
|
||||||
limiter = limit_model.limiter
|
|
||||||
is_limit = (
|
|
||||||
LimitWatchType.ALL
|
|
||||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
|
||||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
|
||||||
)
|
|
||||||
key_type = user_id
|
|
||||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
|
||||||
key_type = channel_id or group_id
|
|
||||||
if is_limit and not limiter.check(key_type):
|
|
||||||
if limit.result:
|
|
||||||
await MessageUtils.build_message(limit.result).send()
|
|
||||||
logger.debug(
|
|
||||||
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{limit.module} 正在限制中...")
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=user_id,
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
if isinstance(limiter, FreqLimiter):
|
|
||||||
limiter.start_cd(key_type)
|
|
||||||
if isinstance(limiter, UserBlockLimiter):
|
|
||||||
limiter.set_true(key_type)
|
|
||||||
if isinstance(limiter, CountLimiter):
|
|
||||||
limiter.increase(key_type)
|
|
||||||
|
|
||||||
|
|
||||||
class IsSuperuserException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AuthChecker:
|
|
||||||
"""
|
|
||||||
权限检查
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
|
||||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
|
||||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
|
||||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
|
||||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
|
||||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
|
||||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
|
||||||
|
|
||||||
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
|
|
||||||
"""是否发送提示消息
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
|
|
||||||
返回:
|
|
||||||
bool: 是否发送提示消息
|
|
||||||
"""
|
|
||||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
|
||||||
return False
|
|
||||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
|
||||||
return False
|
|
||||||
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
|
||||||
|
|
||||||
async def auth(
|
|
||||||
self,
|
|
||||||
matcher: Matcher,
|
|
||||||
event: Event,
|
|
||||||
bot: Bot,
|
|
||||||
session: EventSession,
|
|
||||||
message: UniMsg,
|
|
||||||
):
|
|
||||||
"""权限检查
|
|
||||||
|
|
||||||
参数:
|
|
||||||
matcher: matcher
|
|
||||||
bot: bot
|
|
||||||
session: EventSession
|
|
||||||
message: UniMsg
|
|
||||||
"""
|
|
||||||
is_ignore = False
|
|
||||||
cost_gold = 0
|
|
||||||
user_id = session.id1
|
|
||||||
group_id = session.id3
|
|
||||||
channel_id = session.id2
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
|
||||||
"""过滤除poke外的notice"""
|
|
||||||
return
|
|
||||||
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
|
|
||||||
try:
|
|
||||||
user = await UserConsole.get_user(user_id, session.platform)
|
|
||||||
except IntegrityError as e:
|
|
||||||
logger.debug(
|
|
||||||
"重复创建用户,已跳过该次权限...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if plugin := await Cache.get(CacheType.PLUGINS, module_path):
|
|
||||||
if plugin.plugin_type == PluginType.HIDDEN:
|
|
||||||
logger.debug(
|
|
||||||
f"插件: {plugin.name}:{plugin.module} "
|
|
||||||
"为HIDDEN,已跳过权限检查..."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
cost_gold = await self.auth_cost(user, plugin, session)
|
|
||||||
if session.id1 in bot.config.superusers:
|
|
||||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
|
||||||
raise IsSuperuserException()
|
|
||||||
if not plugin.limit_superuser:
|
|
||||||
cost_gold = 0
|
|
||||||
raise IsSuperuserException()
|
|
||||||
await self.auth_bot(plugin, bot.self_id)
|
|
||||||
await self.auth_group(plugin, session, message)
|
|
||||||
await self.auth_admin(plugin, session)
|
|
||||||
await self.auth_plugin(plugin, session, event)
|
|
||||||
await self.auth_limit(plugin, session)
|
|
||||||
except IsSuperuserException:
|
|
||||||
logger.debug(
|
|
||||||
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
|
||||||
)
|
|
||||||
except IgnoredException:
|
|
||||||
is_ignore = True
|
|
||||||
LimitManage.unblock(
|
|
||||||
matcher.plugin.name, user_id, group_id, channel_id
|
|
||||||
)
|
|
||||||
except AssertionError as e:
|
|
||||||
is_ignore = True
|
|
||||||
logger.debug("消息无法发送", session=session, e=e)
|
|
||||||
if cost_gold and user_id:
|
|
||||||
"""花费金币"""
|
|
||||||
try:
|
|
||||||
await UserConsole.reduce_gold(
|
|
||||||
user_id,
|
|
||||||
cost_gold,
|
|
||||||
GoldHandle.PLUGIN,
|
|
||||||
matcher.plugin.name if matcher.plugin else "",
|
|
||||||
session.platform,
|
|
||||||
)
|
|
||||||
except InsufficientGold:
|
|
||||||
if u := await UserConsole.get_user(user_id):
|
|
||||||
u.gold = 0
|
|
||||||
await u.save(update_fields=["gold"])
|
|
||||||
logger.debug(
|
|
||||||
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
|
|
||||||
)
|
|
||||||
if is_ignore:
|
|
||||||
raise IgnoredException("权限检测 ignore")
|
|
||||||
|
|
||||||
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
|
|
||||||
"""机器人权限
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
bot_id: bot_id
|
|
||||||
"""
|
|
||||||
if not await BotConsole.get_bot_status(bot_id):
|
|
||||||
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
|
|
||||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
|
||||||
if await BotConsole.is_block_plugin(bot_id, plugin.module):
|
|
||||||
logger.debug(
|
|
||||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
|
||||||
"AuthChecker",
|
|
||||||
)
|
|
||||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
|
||||||
|
|
||||||
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
|
|
||||||
"""插件限制
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
"""
|
|
||||||
user_id = session.id1
|
|
||||||
group_id = session.id3
|
|
||||||
channel_id = session.id2
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if plugin.module not in LimitManage.add_module:
|
|
||||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
|
||||||
status=True
|
|
||||||
).all() # type: ignore
|
|
||||||
for limit in limit_list:
|
|
||||||
LimitManage.add_limit(limit)
|
|
||||||
if user_id:
|
|
||||||
await LimitManage.check(
|
|
||||||
plugin.module, user_id, group_id, channel_id, session
|
|
||||||
)
|
|
||||||
|
|
||||||
async def auth_plugin(
|
|
||||||
self, plugin: PluginInfo, session: EventSession, event: Event
|
|
||||||
):
|
|
||||||
"""插件状态
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
"""
|
|
||||||
group_id = session.id3
|
|
||||||
channel_id = session.id2
|
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if user_id := session.id1:
|
|
||||||
is_poke = isinstance(event, PokeNotifyEvent)
|
|
||||||
if group_id:
|
|
||||||
sid = group_id or user_id
|
|
||||||
if await GroupConsole.is_superuser_block_plugin(
|
|
||||||
group_id, plugin.module
|
|
||||||
):
|
|
||||||
"""超级用户群组插件状态"""
|
|
||||||
if self.is_send_limit_message(plugin, sid) and not is_poke:
|
|
||||||
self._flmt_s.start_cd(group_id or user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
"超级管理员禁用了该群此功能..."
|
|
||||||
).send(reply_to=True)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("超级管理员禁用了该群此功能...")
|
|
||||||
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
|
|
||||||
"""群组插件状态"""
|
|
||||||
if self.is_send_limit_message(plugin, sid) and not is_poke:
|
|
||||||
self._flmt_s.start_cd(group_id or user_id)
|
|
||||||
await MessageUtils.build_message("该群未开启此功能...").send(
|
|
||||||
reply_to=True
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该群未开启此功能...")
|
|
||||||
if plugin.block_type == BlockType.GROUP:
|
|
||||||
"""全局群组禁用"""
|
|
||||||
try:
|
|
||||||
if self.is_send_limit_message(plugin, sid) and not is_poke:
|
|
||||||
self._flmt_c.start_cd(group_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
"该功能在群组中已被禁用..."
|
|
||||||
).send(reply_to=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_plugin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该插件在群组中已被禁用...")
|
|
||||||
else:
|
|
||||||
sid = user_id
|
|
||||||
if plugin.block_type == BlockType.PRIVATE:
|
|
||||||
"""全局私聊禁用"""
|
|
||||||
try:
|
|
||||||
if self.is_send_limit_message(plugin, sid) and not is_poke:
|
|
||||||
self._flmt_c.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
"该功能在私聊中已被禁用..."
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该插件在私聊中已被禁用...")
|
|
||||||
if not plugin.status and plugin.block_type == BlockType.ALL:
|
|
||||||
"""全局状态"""
|
|
||||||
if group_id and await GroupConsole.is_super_group(group_id):
|
|
||||||
raise IsSuperuserException()
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
if self.is_send_limit_message(plugin, sid) and not is_poke:
|
|
||||||
self._flmt_s.start_cd(group_id or user_id)
|
|
||||||
await MessageUtils.build_message("全局未开启此功能...").send()
|
|
||||||
raise IgnoredException("全局未开启此功能...")
|
|
||||||
|
|
||||||
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
|
|
||||||
"""管理员命令 个人权限
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
"""
|
|
||||||
user_id = session.id1
|
|
||||||
if user_id and plugin.admin_level:
|
|
||||||
if group_id := session.id3 or session.id2:
|
|
||||||
if not await LevelUser.check_level(
|
|
||||||
user_id, group_id, plugin.admin_level
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
if self._flmt.check(user_id):
|
|
||||||
self._flmt.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
[
|
|
||||||
At(flag="user", target=user_id),
|
|
||||||
f"你的权限不足喔,"
|
|
||||||
f"该功能需要的权限等级: {plugin.admin_level}",
|
|
||||||
]
|
|
||||||
).send(reply_to=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("管理员权限不足...")
|
|
||||||
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
|
|
||||||
try:
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("权限不足")
|
|
||||||
|
|
||||||
async def auth_group(
|
|
||||||
self, plugin: PluginInfo, session: EventSession, message: UniMsg
|
|
||||||
):
|
|
||||||
"""群黑名单检测 群总开关检测
|
|
||||||
|
|
||||||
参数:
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
message: UniMsg
|
|
||||||
"""
|
|
||||||
if not (group_id := session.id3 or session.id2):
|
|
||||||
return
|
|
||||||
text = message.extract_plain_text()
|
|
||||||
group = await GroupConsole.get_group(group_id)
|
|
||||||
if not group:
|
|
||||||
"""群不存在"""
|
|
||||||
logger.debug(
|
|
||||||
"群组信息不存在...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("群不存在")
|
|
||||||
if group.level < 0:
|
|
||||||
"""群权限小于0"""
|
|
||||||
logger.debug(
|
|
||||||
"群黑名单, 群权限-1...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("群黑名单")
|
|
||||||
if not group.status:
|
|
||||||
"""群休眠"""
|
|
||||||
if text.strip() != "醒来":
|
|
||||||
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
|
||||||
raise IgnoredException("群休眠状态")
|
|
||||||
if plugin.level > group.level:
|
|
||||||
"""插件等级大于群等级"""
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 群等级限制.."
|
|
||||||
f"该功能需要的群等级: {plugin.level}..",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
|
||||||
|
|
||||||
async def auth_cost(
|
|
||||||
self, user: UserConsole, plugin: PluginInfo, session: EventSession
|
|
||||||
) -> int:
|
|
||||||
"""检测是否满足金币条件
|
|
||||||
|
|
||||||
参数:
|
|
||||||
user: UserConsole
|
|
||||||
plugin: PluginInfo
|
|
||||||
session: EventSession
|
|
||||||
|
|
||||||
返回:
|
|
||||||
int: 需要消耗的金币
|
|
||||||
"""
|
|
||||||
if user.gold < plugin.cost_gold:
|
|
||||||
"""插件消耗金币不足"""
|
|
||||||
try:
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 金币限制.."
|
|
||||||
f"该功能需要{plugin.cost_gold}金币..",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
|
||||||
return plugin.cost_gold
|
|
||||||
|
|
||||||
|
|
||||||
checker = AuthChecker()
|
|
||||||
76
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
76
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot_plugin_alconna import At
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.level_user import LevelUser
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.cache_utils import Cache
|
||||||
|
from zhenxun.utils.enum import CacheType
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
|
||||||
|
from .utils import freq
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
||||||
|
"""管理员命令 个人权限
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: PluginInfo
|
||||||
|
"""
|
||||||
|
group_id = None
|
||||||
|
cache = Cache[list[LevelUser]](CacheType.LEVEL)
|
||||||
|
user_level = await cache.get(session.user.id) or []
|
||||||
|
if session.group:
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
|
||||||
|
if not plugin.admin_level:
|
||||||
|
return
|
||||||
|
if group_id:
|
||||||
|
user_level += await cache.get(session.user.id, group_id) or []
|
||||||
|
user = max(user_level, key=lambda x: x.user_level)
|
||||||
|
if user.user_level < plugin.admin_level:
|
||||||
|
try:
|
||||||
|
if freq._flmt.check(session.user.id):
|
||||||
|
freq._flmt.start_cd(session.user.id)
|
||||||
|
await MessageUtils.build_message(
|
||||||
|
[
|
||||||
|
At(flag="user", target=session.user.id),
|
||||||
|
"你的权限不足喔,"
|
||||||
|
f"该功能需要的权限等级: {plugin.admin_level}",
|
||||||
|
]
|
||||||
|
).send(reply_to=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"auth_admin 发送消息失败",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("管理员权限不足...")
|
||||||
|
elif user_level:
|
||||||
|
user = max(user_level, key=lambda x: x.user_level)
|
||||||
|
if user.user_level < plugin.admin_level:
|
||||||
|
try:
|
||||||
|
await MessageUtils.build_message(
|
||||||
|
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
||||||
|
).send()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("权限不足")
|
||||||
@ -4,31 +4,29 @@ from zhenxun.models.bot_console import BotConsole
|
|||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.cache_utils import Cache
|
from zhenxun.utils.cache_utils import Cache
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
|
||||||
async def get_bot_status(bot_id: str):
|
|
||||||
a = await Cache.get(CacheType.BOT, bot_id)
|
|
||||||
cache_data = await Cache.get_cache(CacheType.BOT)
|
|
||||||
if cache_data and cache_data.getter:
|
|
||||||
b = await cache_data.getter.get(cache_data.data)
|
|
||||||
if bot := await Cache.get(CacheType.BOT, bot_id):
|
|
||||||
return bot
|
|
||||||
|
|
||||||
|
|
||||||
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
||||||
"""机器人权限
|
"""bot层面的权限检查
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
bot_id: bot_id
|
bot_id: bot id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
IgnoredException: 忽略插件
|
||||||
"""
|
"""
|
||||||
if not await BotConsole.get_bot_status(bot_id):
|
if cache := Cache[BotConsole](CacheType.BOT):
|
||||||
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
|
bot = await cache.get(bot_id)
|
||||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
if not bot or not bot.status:
|
||||||
if await BotConsole.is_block_plugin(bot_id, plugin.module):
|
logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker")
|
||||||
logger.debug(
|
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
||||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||||
"AuthChecker",
|
logger.debug(
|
||||||
)
|
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
||||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
"AuthChecker",
|
||||||
|
)
|
||||||
|
raise IgnoredException("BotConsole插件权限检测 ignore")
|
||||||
|
|||||||
35
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.user_console import UserConsole
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
||||||
|
"""检测是否满足金币条件
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 需要消耗的金币
|
||||||
|
"""
|
||||||
|
if user.gold < plugin.cost_gold:
|
||||||
|
"""插件消耗金币不足"""
|
||||||
|
try:
|
||||||
|
await MessageUtils.build_message(
|
||||||
|
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
||||||
|
).send()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 金币限制.."
|
||||||
|
f"该功能需要{plugin.cost_gold}金币..",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||||
|
return plugin.cost_gold
|
||||||
57
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
57
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot_plugin_alconna import UniMsg
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.cache_utils import Cache
|
||||||
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg):
|
||||||
|
"""群黑名单检测 群总开关检测
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: EventSession
|
||||||
|
message: UniMsg
|
||||||
|
"""
|
||||||
|
if not session.group:
|
||||||
|
return
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
text = message.extract_plain_text()
|
||||||
|
group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id)
|
||||||
|
if not group:
|
||||||
|
"""群不存在"""
|
||||||
|
logger.debug(
|
||||||
|
"群组信息不存在...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("群不存在")
|
||||||
|
if group.level < 0:
|
||||||
|
"""群权限小于0"""
|
||||||
|
logger.debug(
|
||||||
|
"群黑名单, 群权限-1...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("群黑名单")
|
||||||
|
if not group.status:
|
||||||
|
"""群休眠"""
|
||||||
|
if text.strip() != "醒来":
|
||||||
|
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
||||||
|
raise IgnoredException("群休眠状态")
|
||||||
|
if plugin.level > group.level:
|
||||||
|
"""插件等级大于群等级"""
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 群等级限制.."
|
||||||
|
f"该功能需要的群等级: {plugin.level}..",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
||||||
189
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
189
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.plugin_limit import PluginLimit
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||||
|
|
||||||
|
|
||||||
|
class Limit(BaseModel):
|
||||||
|
limit: PluginLimit
|
||||||
|
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class LimitManage:
|
||||||
|
add_module: ClassVar[list] = []
|
||||||
|
|
||||||
|
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
|
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
|
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_limit(cls, limit: PluginLimit):
|
||||||
|
"""添加限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
limit: PluginLimit
|
||||||
|
"""
|
||||||
|
if limit.module not in cls.add_module:
|
||||||
|
cls.add_module.append(limit.module)
|
||||||
|
if limit.limit_type == PluginLimitType.BLOCK:
|
||||||
|
cls.block_limit[limit.module] = Limit(
|
||||||
|
limit=limit, limiter=UserBlockLimiter()
|
||||||
|
)
|
||||||
|
elif limit.limit_type == PluginLimitType.CD:
|
||||||
|
cls.cd_limit[limit.module] = Limit(
|
||||||
|
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||||
|
)
|
||||||
|
elif limit.limit_type == PluginLimitType.COUNT:
|
||||||
|
cls.count_limit[limit.module] = Limit(
|
||||||
|
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unblock(
|
||||||
|
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||||
|
):
|
||||||
|
"""解除插件block
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
channel_id: 频道id
|
||||||
|
"""
|
||||||
|
if limit_model := cls.block_limit.get(module):
|
||||||
|
limit = limit_model.limit
|
||||||
|
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||||
|
key_type = user_id
|
||||||
|
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||||
|
key_type = channel_id or group_id
|
||||||
|
logger.debug(
|
||||||
|
f"解除对象: {key_type} 的block限制",
|
||||||
|
"AuthChecker",
|
||||||
|
session=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
limiter.set_false(key_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check(
|
||||||
|
cls,
|
||||||
|
module: str,
|
||||||
|
user_id: str,
|
||||||
|
group_id: str | None,
|
||||||
|
channel_id: str | None,
|
||||||
|
session: Uninfo,
|
||||||
|
):
|
||||||
|
"""检测限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
channel_id: 频道id
|
||||||
|
session: Session
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: IgnoredException
|
||||||
|
"""
|
||||||
|
if limit_model := cls.cd_limit.get(module):
|
||||||
|
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||||
|
if limit_model := cls.block_limit.get(module):
|
||||||
|
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||||
|
if limit_model := cls.count_limit.get(module):
|
||||||
|
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def __check(
|
||||||
|
cls,
|
||||||
|
limit_model: Limit | None,
|
||||||
|
user_id: str,
|
||||||
|
group_id: str | None,
|
||||||
|
channel_id: str | None,
|
||||||
|
session: Uninfo,
|
||||||
|
):
|
||||||
|
"""检测限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
limit_model: Limit
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
channel_id: 频道id
|
||||||
|
session: Session
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: IgnoredException
|
||||||
|
"""
|
||||||
|
if not limit_model:
|
||||||
|
return
|
||||||
|
limit = limit_model.limit
|
||||||
|
limiter = limit_model.limiter
|
||||||
|
is_limit = (
|
||||||
|
LimitWatchType.ALL
|
||||||
|
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||||
|
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||||
|
)
|
||||||
|
key_type = user_id
|
||||||
|
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||||
|
key_type = channel_id or group_id
|
||||||
|
if is_limit and not limiter.check(key_type):
|
||||||
|
if limit.result:
|
||||||
|
await MessageUtils.build_message(limit.result).send()
|
||||||
|
logger.debug(
|
||||||
|
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
raise IgnoredException(f"{limit.module} 正在限制中...")
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
if isinstance(limiter, FreqLimiter):
|
||||||
|
limiter.start_cd(key_type)
|
||||||
|
if isinstance(limiter, UserBlockLimiter):
|
||||||
|
limiter.set_true(key_type)
|
||||||
|
if isinstance(limiter, CountLimiter):
|
||||||
|
limiter.increase(key_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||||
|
"""插件限制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
user_id = session.user.id
|
||||||
|
group_id = None
|
||||||
|
channel_id = None
|
||||||
|
if session.group:
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
channel_id = session.group.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
if not group_id:
|
||||||
|
group_id = channel_id
|
||||||
|
channel_id = None
|
||||||
|
if plugin.module not in LimitManage.add_module:
|
||||||
|
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||||
|
status=True
|
||||||
|
).all() # type: ignore
|
||||||
|
for limit in limit_list:
|
||||||
|
LimitManage.add_limit(limit)
|
||||||
|
if user_id:
|
||||||
|
await LimitManage.check(plugin.module, user_id, group_id, channel_id, session)
|
||||||
201
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
201
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.cache_utils import Cache
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
|
from zhenxun.utils.enum import BlockType, CacheType
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
|
||||||
|
from .exception import IsSuperuserException
|
||||||
|
from .utils import freq
|
||||||
|
|
||||||
|
|
||||||
|
class GroupCheck:
|
||||||
|
def __init__(
|
||||||
|
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
||||||
|
) -> None:
|
||||||
|
self.group_id = group_id
|
||||||
|
self.session = session
|
||||||
|
self.is_poke = is_poke
|
||||||
|
self.plugin = plugin
|
||||||
|
|
||||||
|
async def __get_data(self):
|
||||||
|
cache = Cache[GroupConsole](CacheType.GROUPS)
|
||||||
|
return await cache.get(self.group_id)
|
||||||
|
|
||||||
|
async def check(self):
|
||||||
|
await self.check_superuser_block(self.plugin)
|
||||||
|
|
||||||
|
async def check_superuser_block(self, plugin: PluginInfo):
|
||||||
|
"""超级用户禁用群组插件检测
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
group = await self.__get_data()
|
||||||
|
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
|
||||||
|
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
|
||||||
|
freq._flmt_s.start_cd(self.group_id)
|
||||||
|
await MessageUtils.build_message("超级管理员禁用了该群此功能...").send(
|
||||||
|
reply_to=True
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("超级管理员禁用了该群此功能...")
|
||||||
|
await self.check_normal_block(self.plugin)
|
||||||
|
|
||||||
|
async def check_normal_block(self, plugin: PluginInfo):
|
||||||
|
"""群组插件状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
group = await self.__get_data()
|
||||||
|
if group and CommonUtils.format(plugin.module) in group.block_plugin:
|
||||||
|
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
||||||
|
freq._flmt_s.start_cd(self.group_id)
|
||||||
|
await MessageUtils.build_message("该群未开启此功能...").send(
|
||||||
|
reply_to=True
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("该群未开启此功能...")
|
||||||
|
await self.check_global_block(self.plugin)
|
||||||
|
|
||||||
|
async def check_global_block(self, plugin: PluginInfo):
|
||||||
|
"""全局禁用插件检测
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
if plugin.block_type == BlockType.GROUP:
|
||||||
|
"""全局群组禁用"""
|
||||||
|
try:
|
||||||
|
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
||||||
|
freq._flmt_c.start_cd(self.group_id)
|
||||||
|
await MessageUtils.build_message("该功能在群组中已被禁用...").send(
|
||||||
|
reply_to=True
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"auth_plugin 发送消息失败",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("该插件在群组中已被禁用...")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginCheck:
|
||||||
|
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
||||||
|
self.session = session
|
||||||
|
self.is_poke = is_poke
|
||||||
|
self.group_id = group_id
|
||||||
|
|
||||||
|
async def check_user(self, plugin: PluginInfo):
|
||||||
|
"""全局私聊禁用检测
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
if plugin.block_type == BlockType.PRIVATE:
|
||||||
|
try:
|
||||||
|
if freq.is_send_limit_message(
|
||||||
|
plugin, self.session.user.id, self.is_poke
|
||||||
|
):
|
||||||
|
freq._flmt_c.start_cd(self.session.user.id)
|
||||||
|
await MessageUtils.build_message("该功能在私聊中已被禁用...").send()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"auth_admin 发送消息失败",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
)
|
||||||
|
raise IgnoredException("该插件在私聊中已被禁用...")
|
||||||
|
|
||||||
|
async def check_global(self, plugin: PluginInfo):
|
||||||
|
"""全局状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IgnoredException: 忽略插件
|
||||||
|
"""
|
||||||
|
if not plugin.status and plugin.block_type == BlockType.ALL:
|
||||||
|
"""全局状态"""
|
||||||
|
cache = Cache[GroupConsole](CacheType.GROUPS)
|
||||||
|
if self.group_id and (group := await cache.get(self.group_id)):
|
||||||
|
if group.is_super:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
logger.debug(
|
||||||
|
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=self.session,
|
||||||
|
)
|
||||||
|
sid = self.group_id or self.session.user.id
|
||||||
|
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||||
|
freq._flmt_s.start_cd(sid)
|
||||||
|
await MessageUtils.build_message("全局未开启此功能...").send()
|
||||||
|
raise IgnoredException("全局未开启此功能...")
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||||
|
"""插件状态
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
group_id = None
|
||||||
|
if session.group:
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
try:
|
||||||
|
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||||
|
|
||||||
|
is_poke = isinstance(event, PokeNotifyEvent)
|
||||||
|
except ImportError:
|
||||||
|
is_poke = False
|
||||||
|
user_check = PluginCheck(group_id, session, is_poke)
|
||||||
|
if group_id:
|
||||||
|
group_check = GroupCheck(plugin, group_id, session, is_poke)
|
||||||
|
await group_check.check()
|
||||||
|
else:
|
||||||
|
await user_check.check_user(plugin)
|
||||||
|
await user_check.check_global(plugin)
|
||||||
2
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
2
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
class IsSuperuserException(Exception):
|
||||||
|
pass
|
||||||
41
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
41
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.utils.enum import PluginType
|
||||||
|
from zhenxun.utils.utils import FreqLimiter
|
||||||
|
|
||||||
|
base_config = Config.get("hook")
|
||||||
|
|
||||||
|
|
||||||
|
class FreqUtils:
|
||||||
|
def __init__(self):
|
||||||
|
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||||
|
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||||
|
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||||
|
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||||
|
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||||
|
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||||
|
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||||
|
|
||||||
|
def is_send_limit_message(
|
||||||
|
self, plugin: PluginInfo, sid: str, is_poke: bool
|
||||||
|
) -> bool:
|
||||||
|
"""是否发送提示消息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
plugin: PluginInfo
|
||||||
|
sid: 检测键
|
||||||
|
is_poke: 是否是戳一戳
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否发送提示消息
|
||||||
|
"""
|
||||||
|
if is_poke:
|
||||||
|
return False
|
||||||
|
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||||
|
return False
|
||||||
|
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||||
|
return False
|
||||||
|
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
||||||
|
|
||||||
|
|
||||||
|
freq = FreqUtils()
|
||||||
125
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
125
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from nonebot.adapters import Bot, Event
|
||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot_plugin_alconna import UniMsg
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
from tortoise.exceptions import IntegrityError
|
||||||
|
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.user_console import UserConsole
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.cache_utils import Cache
|
||||||
|
from zhenxun.utils.enum import (
|
||||||
|
CacheType,
|
||||||
|
GoldHandle,
|
||||||
|
PluginType,
|
||||||
|
)
|
||||||
|
from zhenxun.utils.exception import InsufficientGold
|
||||||
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
|
from .auth.auth_admin import auth_admin
|
||||||
|
from .auth.auth_bot import auth_bot
|
||||||
|
from .auth.auth_cost import auth_cost
|
||||||
|
from .auth.auth_group import auth_group
|
||||||
|
from .auth.auth_limit import LimitManage, auth_limit
|
||||||
|
from .auth.auth_plugin import auth_plugin
|
||||||
|
from .auth.exception import IsSuperuserException
|
||||||
|
|
||||||
|
|
||||||
|
async def auth(
|
||||||
|
matcher: Matcher,
|
||||||
|
event: Event,
|
||||||
|
bot: Bot,
|
||||||
|
session: Uninfo,
|
||||||
|
message: UniMsg,
|
||||||
|
):
|
||||||
|
"""权限检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
matcher: matcher
|
||||||
|
event: Event
|
||||||
|
bot: bot
|
||||||
|
session: Uninfo
|
||||||
|
message: UniMsg
|
||||||
|
"""
|
||||||
|
user_id = session.user.id
|
||||||
|
group_id = None
|
||||||
|
channel_id = None
|
||||||
|
if session.group:
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
channel_id = session.group.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
is_ignore = False
|
||||||
|
cost_gold = 0
|
||||||
|
try:
|
||||||
|
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||||
|
|
||||||
|
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
||||||
|
"""过滤除poke外的notice"""
|
||||||
|
return
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
user_cache = Cache[UserConsole](CacheType.USERS)
|
||||||
|
if matcher.plugin and (module := matcher.plugin.name):
|
||||||
|
try:
|
||||||
|
user = await user_cache.get(session.user.id)
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.debug(
|
||||||
|
"重复创建用户,已跳过该次权限检查...",
|
||||||
|
"AuthChecker",
|
||||||
|
session=session,
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
||||||
|
if user and plugin:
|
||||||
|
if plugin.plugin_type == PluginType.HIDDEN:
|
||||||
|
logger.debug(
|
||||||
|
f"插件: {plugin.name}:{plugin.module} "
|
||||||
|
"为HIDDEN,已跳过权限检查..."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
cost_gold = await auth_cost(user, plugin, session)
|
||||||
|
if session.user.id in bot.config.superusers:
|
||||||
|
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
if not plugin.limit_superuser:
|
||||||
|
cost_gold = 0
|
||||||
|
raise IsSuperuserException()
|
||||||
|
await auth_bot(plugin, bot.self_id)
|
||||||
|
await auth_group(plugin, session, message)
|
||||||
|
await auth_admin(plugin, session)
|
||||||
|
await auth_plugin(plugin, session, event)
|
||||||
|
await auth_limit(plugin, session)
|
||||||
|
except IsSuperuserException:
|
||||||
|
logger.debug(
|
||||||
|
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
||||||
|
)
|
||||||
|
except IgnoredException:
|
||||||
|
is_ignore = True
|
||||||
|
LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id)
|
||||||
|
except AssertionError as e:
|
||||||
|
is_ignore = True
|
||||||
|
logger.debug("消息无法发送", session=session, e=e)
|
||||||
|
if cost_gold and user_id:
|
||||||
|
"""花费金币"""
|
||||||
|
try:
|
||||||
|
await UserConsole.reduce_gold(
|
||||||
|
user_id,
|
||||||
|
cost_gold,
|
||||||
|
GoldHandle.PLUGIN,
|
||||||
|
matcher.plugin.name if matcher.plugin else "",
|
||||||
|
PlatformUtils.get_platform(session),
|
||||||
|
)
|
||||||
|
except InsufficientGold:
|
||||||
|
if u := await UserConsole.get_user(user_id):
|
||||||
|
u.gold = 0
|
||||||
|
await u.save(update_fields=["gold"])
|
||||||
|
# 更新缓存
|
||||||
|
await user_cache.update(user_id)
|
||||||
|
logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session)
|
||||||
|
if is_ignore:
|
||||||
|
raise IgnoredException("权限检测 ignore")
|
||||||
@ -1,18 +1,16 @@
|
|||||||
from nonebot.adapters.onebot.v11 import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.message import run_postprocessor, run_preprocessor
|
from nonebot.message import run_postprocessor, run_preprocessor
|
||||||
from nonebot_plugin_alconna import UniMsg
|
from nonebot_plugin_alconna import UniMsg
|
||||||
from nonebot_plugin_session import EventSession
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from ._auth_checker import LimitManage, checker
|
from .auth_checker import LimitManage, auth
|
||||||
|
|
||||||
|
|
||||||
# # 权限检测
|
# # 权限检测
|
||||||
@run_preprocessor
|
@run_preprocessor
|
||||||
async def _(
|
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
||||||
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg
|
await auth(
|
||||||
):
|
|
||||||
await checker.auth(
|
|
||||||
matcher,
|
matcher,
|
||||||
event,
|
event,
|
||||||
bot,
|
bot,
|
||||||
@ -23,19 +21,16 @@ async def _(
|
|||||||
|
|
||||||
# 解除命令block阻塞
|
# 解除命令block阻塞
|
||||||
@run_postprocessor
|
@run_postprocessor
|
||||||
async def _(
|
async def _(matcher: Matcher, session: Uninfo):
|
||||||
matcher: Matcher,
|
user_id = session.user.id
|
||||||
exception: Exception | None,
|
group_id = None
|
||||||
bot: Bot,
|
channel_id = None
|
||||||
event: Event,
|
if session.group:
|
||||||
session: EventSession,
|
if session.group.parent:
|
||||||
):
|
group_id = session.group.parent.id
|
||||||
user_id = session.id1
|
channel_id = session.group.id
|
||||||
group_id = session.id3
|
else:
|
||||||
channel_id = session.id2
|
group_id = session.group.id
|
||||||
if not group_id:
|
|
||||||
group_id = channel_id
|
|
||||||
channel_id = None
|
|
||||||
if user_id and matcher.plugin:
|
if user_id and matcher.plugin:
|
||||||
module = matcher.plugin.name
|
module = matcher.plugin.name
|
||||||
LimitManage.unblock(module, user_id, group_id, channel_id)
|
LimitManage.unblock(module, user_id, group_id, channel_id)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import nonebot
|
import nonebot
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
@ -6,9 +7,11 @@ from nonebot.adapters import Bot
|
|||||||
from zhenxun.models.ban_console import BanConsole
|
from zhenxun.models.ban_console import BanConsole
|
||||||
from zhenxun.models.bot_console import BotConsole
|
from zhenxun.models.bot_console import BotConsole
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
|
from zhenxun.models.level_user import LevelUser
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.cache_utils import Cache
|
from zhenxun.utils.cache_utils import CacheRoot
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
@ -48,45 +51,131 @@ async def _(bot: Bot):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@Cache.listener(CacheType.PLUGINS)
|
@CacheRoot.new(CacheType.PLUGINS)
|
||||||
async def _():
|
async def _(data: dict[str, PluginInfo] = {}, key: str | None = None):
|
||||||
data_list = await PluginInfo.get_plugins()
|
if data and key:
|
||||||
return {p.module: p for p in data_list}
|
if plugin := await PluginInfo.get_plugin(module=key):
|
||||||
|
data[key] = plugin
|
||||||
|
else:
|
||||||
|
data_list = await PluginInfo.get_plugins()
|
||||||
|
return {p.module: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@Cache.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
@CacheRoot.updater(CacheType.PLUGINS)
|
||||||
def _(data: dict[str, PluginInfo], module: str):
|
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
||||||
return data.get(module, None)
|
if value:
|
||||||
|
data[key] = value
|
||||||
|
elif plugin := await PluginInfo.get_plugin(module=key):
|
||||||
|
data[key] = plugin
|
||||||
|
|
||||||
|
|
||||||
@Cache.listener(CacheType.GROUPS)
|
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
||||||
|
async def _(data: dict[str, PluginInfo], module: str):
|
||||||
|
result = data.get(module, None)
|
||||||
|
if not result:
|
||||||
|
result = await PluginInfo.get_plugin(module=module)
|
||||||
|
if result:
|
||||||
|
data[module] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.GROUPS)
|
||||||
async def _():
|
async def _():
|
||||||
data_list = await GroupConsole.all()
|
data_list = await GroupConsole.all()
|
||||||
return {p.group_id: p for p in data_list}
|
return {p.group_id: p for p in data_list if not p.channel_id}
|
||||||
|
|
||||||
|
|
||||||
@Cache.getter(CacheType.GROUPS, result_model=GroupConsole)
|
@CacheRoot.updater(CacheType.GROUPS)
|
||||||
def _(data: dict[str, GroupConsole], module: str):
|
async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
||||||
return data.get(module, None)
|
if value:
|
||||||
|
data[key] = value
|
||||||
|
elif group := await GroupConsole.get_group(group_id=key):
|
||||||
|
data[key] = group
|
||||||
|
|
||||||
|
|
||||||
@Cache.listener(CacheType.BOT)
|
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||||
|
async def _(data: dict[str, GroupConsole], group_id: str):
|
||||||
|
result = data.get(group_id, None)
|
||||||
|
if not result:
|
||||||
|
result = await GroupConsole.get_group(group_id=group_id)
|
||||||
|
if result:
|
||||||
|
data[group_id] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.BOT)
|
||||||
async def _():
|
async def _():
|
||||||
data_list = await BotConsole.all()
|
data_list = await BotConsole.all()
|
||||||
return {p.bot_id: p for p in data_list}
|
return {p.bot_id: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@Cache.getter(CacheType.BOT, result_model=BotConsole)
|
@CacheRoot.updater(CacheType.BOT)
|
||||||
def _(data: dict[str, BotConsole], module: str):
|
async def _(data: dict[str, BotConsole], key: str, value: Any):
|
||||||
return data.get(module, None)
|
if value:
|
||||||
|
data[key] = value
|
||||||
|
elif bot := await BotConsole.get_or_none(bot_id=key):
|
||||||
|
data[key] = bot
|
||||||
|
|
||||||
|
|
||||||
@Cache.listener(CacheType.BAN)
|
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||||
|
async def _(data: dict[str, BotConsole], bot_id: str):
|
||||||
|
result = data.get(bot_id, None)
|
||||||
|
if not result:
|
||||||
|
result = await BotConsole.get_or_none(bot_id=bot_id)
|
||||||
|
if result:
|
||||||
|
data[bot_id] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.USERS)
|
||||||
|
async def _():
|
||||||
|
data_list = await UserConsole.all()
|
||||||
|
return {p.user_id: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.updater(CacheType.USERS)
|
||||||
|
async def _(data: dict[str, UserConsole], key: str, value: Any):
|
||||||
|
if value:
|
||||||
|
data[key] = value
|
||||||
|
elif user := await UserConsole.get_user(user_id=key):
|
||||||
|
data[key] = user
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
||||||
|
async def _(data: dict[str, UserConsole], user_id: str):
|
||||||
|
result = data.get(user_id, None)
|
||||||
|
if not result:
|
||||||
|
result = await UserConsole.get_user(user_id=user_id)
|
||||||
|
if result:
|
||||||
|
data[user_id] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.LEVEL)
|
||||||
|
async def _():
|
||||||
|
return await LevelUser().all()
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
||||||
|
def _(data_list: list[LevelUser], user_id: str, group_id: str | None = None):
|
||||||
|
if not group_id:
|
||||||
|
return [
|
||||||
|
data for data in data_list if data.user_id == user_id and not data.group_id
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
data
|
||||||
|
for data in data_list
|
||||||
|
if data.user_id == user_id and data.group_id == group_id
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@CacheRoot.new(CacheType.BAN)
|
||||||
async def _():
|
async def _():
|
||||||
return await BanConsole.all()
|
return await BanConsole.all()
|
||||||
|
|
||||||
|
|
||||||
@Cache.getter(CacheType.BAN, result_model=list[BanConsole])
|
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||||
def _(data_list: list[BanConsole], user_id: str, group_id: str):
|
def _(data_list: list[BanConsole], user_id: str, group_id: str):
|
||||||
if user_id:
|
if user_id:
|
||||||
if group_id:
|
if group_id:
|
||||||
|
|||||||
@ -7,7 +7,8 @@ from tortoise.backends.base.client import BaseDBAsyncClient
|
|||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.task_info import TaskInfo
|
from zhenxun.models.task_info import TaskInfo
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.cache_utils import CacheRoot
|
||||||
|
from zhenxun.utils.enum import CacheType, PluginType
|
||||||
|
|
||||||
|
|
||||||
class GroupConsole(Model):
|
class GroupConsole(Model):
|
||||||
@ -46,8 +47,7 @@ class GroupConsole(Model):
|
|||||||
platform = fields.CharField(255, default="qq", description="所属平台")
|
platform = fields.CharField(255, default="qq", description="所属平台")
|
||||||
"""所属平台"""
|
"""所属平台"""
|
||||||
|
|
||||||
class Meta: # type: ignore
|
class Meta: # type: ignore table = "group_console"
|
||||||
table = "group_console"
|
|
||||||
table_description = "群组信息表"
|
table_description = "群组信息表"
|
||||||
unique_together = ("group_id", "channel_id")
|
unique_together = ("group_id", "channel_id")
|
||||||
|
|
||||||
@ -80,6 +80,7 @@ class GroupConsole(Model):
|
|||||||
return "".join(cls.format(item) for item in data)
|
return "".join(cls.format(item) for item in data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def create(
|
async def create(
|
||||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||||
) -> Self:
|
) -> Self:
|
||||||
@ -100,6 +101,7 @@ class GroupConsole(Model):
|
|||||||
return group
|
return group
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def get_or_create(
|
async def get_or_create(
|
||||||
cls,
|
cls,
|
||||||
defaults: dict | None = None,
|
defaults: dict | None = None,
|
||||||
@ -127,6 +129,7 @@ class GroupConsole(Model):
|
|||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def update_or_create(
|
async def update_or_create(
|
||||||
cls,
|
cls,
|
||||||
defaults: dict | None = None,
|
defaults: dict | None = None,
|
||||||
@ -216,6 +219,7 @@ class GroupConsole(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def set_block_plugin(
|
async def set_block_plugin(
|
||||||
cls,
|
cls,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
@ -242,6 +246,7 @@ class GroupConsole(Model):
|
|||||||
await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
|
await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def set_unblock_plugin(
|
async def set_unblock_plugin(
|
||||||
cls,
|
cls,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
@ -338,6 +343,7 @@ class GroupConsole(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def set_block_task(
|
async def set_block_task(
|
||||||
cls,
|
cls,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
@ -364,6 +370,7 @@ class GroupConsole(Model):
|
|||||||
await group.save(update_fields=["block_task", "superuser_block_task"])
|
await group.save(update_fields=["block_task", "superuser_block_task"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
async def set_unblock_task(
|
async def set_unblock_task(
|
||||||
cls,
|
cls,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
|
from typing import Any
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
from tortoise.backends.base.client import BaseDBAsyncClient
|
||||||
|
|
||||||
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.utils.enum import BlockType, PluginType
|
from zhenxun.utils.cache_utils import CacheRoot
|
||||||
|
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||||
|
|
||||||
|
|
||||||
class PluginInfo(Model):
|
class PluginInfo(Model):
|
||||||
@ -79,6 +82,13 @@ class PluginInfo(Model):
|
|||||||
"""
|
"""
|
||||||
return await cls.filter(load_status=load_status, **kwargs).all()
|
return await cls.filter(load_status=load_status, **kwargs).all()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@CacheRoot.listener(CacheType.PLUGINS)
|
||||||
|
async def create(
|
||||||
|
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||||
|
) -> Self:
|
||||||
|
return await super().create(using_db=using_db, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _run_script(cls):
|
async def _run_script(cls):
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
import time
|
import time
|
||||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||||
|
|
||||||
from nonebot.utils import is_coroutine_callable
|
from nonebot.utils import is_coroutine_callable
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
__all__ = ["Cache", "CacheData"]
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
__all__ = ["Cache", "CacheData", "CacheRoot"]
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@ -27,10 +30,14 @@ class CacheGetter(BaseModel, Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
class CacheData(BaseModel):
|
class CacheData(BaseModel):
|
||||||
|
name: str
|
||||||
|
"""缓存名称"""
|
||||||
func: Callable[..., Any]
|
func: Callable[..., Any]
|
||||||
"""更新方法"""
|
"""更新方法"""
|
||||||
getter: CacheGetter | None = None
|
getter: CacheGetter | None = None
|
||||||
"""获取方法"""
|
"""获取方法"""
|
||||||
|
updater: Callable[..., Any] | None = None
|
||||||
|
"""更新单个方法"""
|
||||||
data: Any = None
|
data: Any = None
|
||||||
"""缓存数据"""
|
"""缓存数据"""
|
||||||
expire: int
|
expire: int
|
||||||
@ -40,13 +47,36 @@ class CacheData(BaseModel):
|
|||||||
reload_count: int = 0
|
reload_count: int = 0
|
||||||
"""更新次数"""
|
"""更新次数"""
|
||||||
|
|
||||||
async def reload(self):
|
async def get(self, *args, **kwargs) -> Any:
|
||||||
|
"""获取单个缓存"""
|
||||||
|
if not self.getter:
|
||||||
|
return self.data
|
||||||
|
return await self.getter.get(self.data, *args, **kwargs)
|
||||||
|
|
||||||
|
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
||||||
|
"""更新单个缓存"""
|
||||||
|
if not self.updater:
|
||||||
|
return logger.warning(
|
||||||
|
f"缓存类型 {self.name} 没有更新方法,无法更新", "CacheRoot"
|
||||||
|
)
|
||||||
|
if self.data:
|
||||||
|
if is_coroutine_callable(self.updater):
|
||||||
|
await self.updater(self.data, key, value, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
self.updater(self.data, key, value, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
logger.warning(f"缓存类型 {self.name} 为空,无法更新", "CacheRoot")
|
||||||
|
|
||||||
|
async def reload(self, *args, **kwargs):
|
||||||
"""更新缓存"""
|
"""更新缓存"""
|
||||||
self.data = (
|
self.data = (
|
||||||
await self.func() if is_coroutine_callable(self.func) else self.func()
|
await self.func(*args, **kwargs)
|
||||||
|
if is_coroutine_callable(self.func)
|
||||||
|
else self.func(*args, **kwargs)
|
||||||
)
|
)
|
||||||
self.reload_time = time.time()
|
self.reload_time = time.time()
|
||||||
self.reload_count += 1
|
self.reload_count += 1
|
||||||
|
logger.debug(f"缓存类型 {self.name} 更新全局缓存", "CacheRoot")
|
||||||
|
|
||||||
async def check_expire(self):
|
async def check_expire(self):
|
||||||
if time.time() - self.reload_time > self.expire or not self.reload_count:
|
if time.time() - self.reload_time > self.expire or not self.reload_count:
|
||||||
@ -54,14 +84,54 @@ class CacheData(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class CacheManage:
|
class CacheManage:
|
||||||
|
"""全局缓存管理,减少数据库与网络请求查询次数
|
||||||
|
|
||||||
|
|
||||||
|
异常:
|
||||||
|
ValueError: 数据名称重复
|
||||||
|
ValueError: 数据不存在
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
_data: ClassVar[dict[str, CacheData]] = {}
|
_data: ClassVar[dict[str, CacheData]] = {}
|
||||||
|
|
||||||
def listener(self, name: str, expire: int = 60 * 10):
|
def new(self, name: str, expire: int = 60 * 10):
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
_name = name.upper()
|
_name = name.upper()
|
||||||
if _name in self._data:
|
if _name in self._data:
|
||||||
raise ValueError(f"DbCache 缓存数据 {name} 已存在...")
|
raise ValueError(f"DbCache 缓存数据 {name} 已存在...")
|
||||||
self._data[_name] = CacheData(func=func, expire=expire)
|
self._data[_name] = CacheData(name=_name, func=func, expire=expire)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def listener(self, name: str):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
if is_coroutine_callable(func):
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
cache_data = self._data.get(name.upper())
|
||||||
|
if cache_data:
|
||||||
|
await cache_data.reload()
|
||||||
|
logger.debug(
|
||||||
|
f"缓存类型 {name.upper()} 进行监听更新...", "CacheRoot"
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def updater(self, name: str):
|
||||||
|
def wrapper(func: Callable):
|
||||||
|
_name = name.upper()
|
||||||
|
if _name not in self._data:
|
||||||
|
raise ValueError(f"DbCache 缓存数据 {name} 不存在...")
|
||||||
|
self._data[_name].updater = func
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@ -83,12 +153,12 @@ class CacheManage:
|
|||||||
):
|
):
|
||||||
await self._data[name].reload()
|
await self._data[name].reload()
|
||||||
|
|
||||||
async def get_cache_data(self, name: str) -> CacheData | None:
|
async def get_cache_data(self, name: str):
|
||||||
if cache := await self.get_cache(name):
|
if cache := await self.get_cache(name):
|
||||||
return cache
|
return cache.data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_cache(self, name: str):
|
async def get_cache(self, name: str) -> CacheData | None:
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
cache = self._data.get(name)
|
cache = self._data.get(name)
|
||||||
if cache:
|
if cache:
|
||||||
@ -96,18 +166,35 @@ class CacheManage:
|
|||||||
return cache
|
return cache
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get(self, name: str, *args, **kwargs) -> T | None:
|
async def get(self, name: str, *args, **kwargs):
|
||||||
cache = self._data.get(name.upper())
|
cache = await self.get_cache(name.upper())
|
||||||
if cache:
|
if cache:
|
||||||
return (
|
return await cache.get(*args, **kwargs) if cache.getter else cache.data
|
||||||
await cache.getter.get(*args, **kwargs) if cache.getter else cache.data
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def reload(self, name: str):
|
async def reload(self, name: str, *args, **kwargs):
|
||||||
cache = self._data.get(name.upper())
|
cache = await self.get_cache(name.upper())
|
||||||
if cache:
|
if cache:
|
||||||
await cache.reload()
|
await cache.reload(*args, **kwargs)
|
||||||
|
|
||||||
|
async def update(self, name: str, key: str, value: Any, *args, **kwargs):
|
||||||
|
cache = await self.get_cache(name.upper())
|
||||||
|
if cache:
|
||||||
|
await cache.update(key, value, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
Cache = CacheManage()
|
CacheRoot = CacheManage()
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(Generic[T]):
|
||||||
|
def __init__(self, module: str):
|
||||||
|
self.module = module
|
||||||
|
|
||||||
|
async def get(self, *args, **kwargs) -> T | None:
|
||||||
|
return await CacheRoot.get(self.module, *args, **kwargs)
|
||||||
|
|
||||||
|
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
||||||
|
return await CacheRoot.update(self.module, key, value, *args, **kwargs)
|
||||||
|
|
||||||
|
async def reload(self, key: str | None = None, *args, **kwargs):
|
||||||
|
await CacheRoot.reload(self.module, key, *args, **kwargs)
|
||||||
|
|||||||
@ -10,10 +10,14 @@ class CacheType(StrEnum):
|
|||||||
"""全局全部插件"""
|
"""全局全部插件"""
|
||||||
GROUPS = "GLOBAL_ALL_GROUPS"
|
GROUPS = "GLOBAL_ALL_GROUPS"
|
||||||
"""全局全部群组"""
|
"""全局全部群组"""
|
||||||
|
USERS = "GLOBAL_ALL_USERS"
|
||||||
|
"""全部用户"""
|
||||||
BAN = "GLOBAL_ALL_BAN"
|
BAN = "GLOBAL_ALL_BAN"
|
||||||
"""全局ban列表"""
|
"""全局ban列表"""
|
||||||
BOT = "GLOBAL_BOT"
|
BOT = "GLOBAL_BOT"
|
||||||
"""全局bot信息"""
|
"""全局bot信息"""
|
||||||
|
LEVEL = "GLOBAL_USER_LEVEL"
|
||||||
|
"""用户权限"""
|
||||||
|
|
||||||
|
|
||||||
class GoldHandle(StrEnum):
|
class GoldHandle(StrEnum):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user