优化hook权限检测性能

This commit is contained in:
HibiKier 2025-04-08 17:11:44 +08:00
parent 03f0185e46
commit 41dd767724
19 changed files with 718 additions and 530 deletions

View File

@ -196,7 +196,7 @@ class PluginManage:
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update( await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
default_status=status default_status=status
) )
return f'成功将所有功能进群默认状态修改为: {"开启" if status else "关闭"}' return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
if group_id: if group_id:
if group := await GroupConsole.get_or_none( if group := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True group_id=group_id, channel_id__isnull=True
@ -213,12 +213,12 @@ class PluginManage:
module_list = [f"<{module}" for module in module_list] module_list = [f"<{module}" for module in module_list]
group.block_plugin = ",".join(module_list) + "," # type: ignore group.block_plugin = ",".join(module_list) + "," # type: ignore
await group.save(update_fields=["block_plugin"]) await group.save(update_fields=["block_plugin"])
return f'成功将此群组所有功能状态修改为: {"开启" if status else "关闭"}' return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
return "获取群组失败..." return "获取群组失败..."
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update( await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
status=status, block_type=None if status else BlockType.ALL status=status, block_type=None if status else BlockType.ALL
) )
return f'成功将所有功能全局状态修改为: {"开启" if status else "关闭"}' return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
@classmethod @classmethod
async def is_wake(cls, group_id: str) -> bool: async def is_wake(cls, group_id: str) -> bool:
@ -243,9 +243,11 @@ class PluginManage:
参数: 参数:
group_id: 群组id group_id: 群组id
""" """
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update( group, _ = await GroupConsole.get_or_create(
status=False group_id=group_id, channel_id__isnull=True
) )
group.status = False
await group.save(update_fields=["status"])
@classmethod @classmethod
async def wake(cls, group_id: str): async def wake(cls, group_id: str):
@ -254,9 +256,11 @@ class PluginManage:
参数: 参数:
group_id: 群组id group_id: 群组id
""" """
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update( group, _ = await GroupConsole.get_or_create(
status=True group_id=group_id, channel_id__isnull=True
) )
group.status = True
await group.save(update_fields=["status"])
@classmethod @classmethod
async def block(cls, module: str): async def block(cls, module: str):

View File

@ -1,15 +1,14 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_alconna import At from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
from zhenxun.utils.message import MessageUtils from zhenxun.utils.utils import get_entity_ids
from .utils import freq from .exception import SkipPluginException
from .utils import send_message
async def auth_admin(plugin: PluginInfo, session: Uninfo): async def auth_admin(plugin: PluginInfo, session: Uninfo):
@ -17,60 +16,33 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
参数: 参数:
plugin: PluginInfo plugin: PluginInfo
session: PluginInfo session: Uninfo
""" """
group_id = None
cache = Cache[list[LevelUser]](CacheType.LEVEL)
user_level = await cache.get(session.user.id) or []
if session.group:
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
if not plugin.admin_level: if not plugin.admin_level:
return return
if group_id: entity = get_entity_ids(session)
user_level += await cache.get(session.user.id, group_id) or [] cache = Cache[list[LevelUser]](CacheType.LEVEL)
user_level = await cache.get(session.user.id) or []
if entity.group_id:
user_level += await cache.get(session.user.id, entity.group_id) or []
user = max(user_level, key=lambda x: x.user_level) user = max(user_level, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level: if user.user_level < plugin.admin_level:
try: await send_message(
if freq._flmt.check(session.user.id): session,
freq._flmt.start_cd(session.user.id) [
await MessageUtils.build_message( At(flag="user", target=session.user.id),
[ f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
At(flag="user", target=session.user.id), ],
"你的权限不足喔," entity.user_id,
f"该功能需要的权限等级: {plugin.admin_level}", )
] raise SkipPluginException(
).send(reply_to=True) f"{plugin.name}({plugin.module}) 管理员权限不足..."
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
) )
raise IgnoredException("管理员权限不足...")
elif user_level: elif user_level:
user = max(user_level, key=lambda x: x.user_level) user = max(user_level, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level: if user.user_level < plugin.admin_level:
try: await send_message(
await MessageUtils.build_message( session,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}" f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
).send() )
except Exception as e: raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...")
logger.error(
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("权限不足")

View File

@ -0,0 +1,167 @@
from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import MultipleObjectsReturned
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, PluginType
from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
from .utils import send_message
Config.add_plugin_config(
"hook",
"BAN_RESULT",
"才不会给你发消息.",
help="对被ban用户发送的消息",
)
async def is_ban(user_id: str | None, group_id: str | None) -> int:
cache = Cache[list[BanConsole]](CacheType.BAN)
results = await cache.get(user_id, group_id) or await cache.get(user_id)
if not results:
return 0
for result in results:
if result.group_id == group_id and (
result.duration > 0 or result.duration == -1
):
return await BanConsole.check_ban_time(user_id, group_id)
if not result.group_id and result.duration == -1:
return await BanConsole.check_ban_time(user_id, group_id)
return 0
def check_plugin_type(matcher: Matcher) -> bool:
"""判断插件类型是否是隐藏插件
参数:
matcher: Matcher
返回:
bool: 是否为隐藏插件
"""
if plugin := matcher.plugin:
if metadata := plugin.metadata:
extra = metadata.extra
if extra.get("plugin_type") in [PluginType.HIDDEN]:
return False
return True
def format_time(time: float) -> str:
"""格式化时间
参数:
time: ban时长
返回:
str: 格式化时间文本
"""
if time == -1:
return ""
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
return time_str
async def group_handle(cache: Cache[list[BanConsole]], group_id: str):
"""群组ban检查
参数:
cache: cache
group_id: 群组id
异常:
SkipPluginException: 群组处于黑名单
"""
try:
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"群组黑名单数据重复过滤该次hook并移除多余数据...", LOGGER_COMMAND
)
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
async def user_handle(
module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo
):
"""用户ban检查
参数:
module: 插件模块名
cache: cache
user_id: 用户id
session: Uninfo
异常:
SkipPluginException: 用户处于黑名单
"""
ban_result = Config.get_config("hook", "BAN_RESULT")
try:
time = await is_ban(entity.user_id, entity.group_id)
if not time:
return
time_str = format_time(time)
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if (
db_plugin
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
):
await send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
)
raise SkipPluginException("用户处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"用户黑名单数据重复过滤该次hook并移除多余数据...", LOGGER_COMMAND
)
ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
if not check_plugin_type(matcher):
return
if not matcher.plugin_name:
return
entity = get_entity_ids(session)
if entity.user_id in bot.config.superusers:
return
cache = Cache[list[BanConsole]](CacheType.BAN)
if entity.group_id:
await group_handle(cache, entity.group_id)
if entity.user_id:
await user_handle(matcher.plugin_name, cache, entity, session)

View File

@ -1,12 +1,11 @@
from nonebot.exception import IgnoredException
from zhenxun.models.bot_console import BotConsole from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
from .exception import SkipPluginException
async def auth_bot(plugin: PluginInfo, bot_id: str): async def auth_bot(plugin: PluginInfo, bot_id: str):
"""bot层面的权限检查 """bot层面的权限检查
@ -16,17 +15,14 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
bot_id: bot id bot_id: bot id
异常: 异常:
IgnoredException: 忽略插件 SkipPluginException: 忽略插件
IgnoredException: 忽略插件 SkipPluginException: 忽略插件
""" """
if cache := Cache[BotConsole](CacheType.BOT): if cache := Cache[BotConsole](CacheType.BOT):
bot = await cache.get(bot_id) bot = await cache.get(bot_id)
if not bot or not bot.status: if not bot or not bot.status:
logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker") raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
raise IgnoredException("BotConsole休眠权限检测 ignore")
if CommonUtils.format(plugin.module) in bot.block_plugins: if CommonUtils.format(plugin.module) in bot.block_plugins:
logger.debug( raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...", f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
"AuthChecker",
) )
raise IgnoredException("BotConsole插件权限检测 ignore")

View File

@ -1,10 +1,10 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.message import MessageUtils from .exception import SkipPluginException
from .utils import send_message
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int: async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
@ -19,17 +19,6 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
""" """
if user.gold < plugin.cost_gold: if user.gold < plugin.cost_gold:
"""插件消耗金币不足""" """插件消耗金币不足"""
try: await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
await MessageUtils.build_message( raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
f"金币不足..该功能需要{plugin.cost_gold}金币.."
).send()
except Exception as e:
logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e)
logger.debug(
f"{plugin.name}({plugin.module}) 金币限制.."
f"该功能需要{plugin.cost_gold}金币..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold return plugin.cost_gold

View File

@ -1,57 +1,35 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_alconna import UniMsg from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType
from zhenxun.utils.utils import EntityIDs
from .config import SwitchEnum
from .exception import SkipPluginException
async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg): async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
"""群黑名单检测 群总开关检测 """群黑名单检测 群总开关检测
参数: 参数:
plugin: PluginInfo plugin: PluginInfo
session: EventSession entity: EntityIDs
message: UniMsg message: UniMsg
""" """
if not session.group: if not entity.group_id:
return return
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
text = message.extract_plain_text() text = message.extract_plain_text()
group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id) group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id)
if not group: if not group:
"""群不存在""" raise SkipPluginException("群组信息不存在...")
logger.debug(
"群组信息不存在...",
"AuthChecker",
session=session,
)
raise IgnoredException("群不存在")
if group.level < 0: if group.level < 0:
"""群权限小于0""" raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
logger.debug( if text.strip() != SwitchEnum.ENABLE and not group.status:
"群黑名单, 群权限-1...", raise SkipPluginException("群组休眠状态...")
"AuthChecker",
session=session,
)
raise IgnoredException("群黑名单")
if not group.status:
"""群休眠"""
if text.strip() != "醒来":
logger.debug("群休眠状态...", "AuthChecker", session=session)
raise IgnoredException("群休眠状态")
if plugin.level > group.level: if plugin.level > group.level:
"""插件等级大于群等级""" raise SkipPluginException(
logger.debug( f"{plugin.name}({plugin.module}) 群等级限制,"
f"{plugin.name}({plugin.module}) 群等级限制.." f"该功能需要的群等级: {plugin.level}..."
f"该功能需要的群等级: {plugin.level}..",
"AuthChecker",
session=session,
) )
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")

View File

@ -1,6 +1,5 @@
from typing import ClassVar from typing import ClassVar
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel from pydantic import BaseModel
@ -11,6 +10,9 @@ from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.message import MessageUtils from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
class Limit(BaseModel): class Limit(BaseModel):
limit: PluginLimit limit: PluginLimit
@ -69,7 +71,7 @@ class LimitManage:
key_type = channel_id or group_id key_type = channel_id or group_id
logger.debug( logger.debug(
f"解除对象: {key_type} 的block限制", f"解除对象: {key_type} 的block限制",
"AuthChecker", LOGGER_COMMAND,
session=user_id, session=user_id,
group_id=group_id, group_id=group_id,
) )
@ -139,16 +141,13 @@ class LimitManage:
if is_limit and not limiter.check(key_type): if is_limit and not limiter.check(key_type):
if limit.result: if limit.result:
await MessageUtils.build_message(limit.result).send() await MessageUtils.build_message(limit.result).send()
logger.debug( raise SkipPluginException(
f"{limit.module}({limit.limit_type}) 正在限制中...", f"{limit.module}({limit.limit_type}) 正在限制中..."
"AuthChecker",
session=session,
) )
raise IgnoredException(f"{limit.module} 正在限制中...")
else: else:
logger.debug( logger.debug(
f"开始进行限制 {limit.module}({limit.limit_type})...", f"开始进行限制 {limit.module}({limit.limit_type})...",
"AuthChecker", LOGGER_COMMAND,
session=user_id, session=user_id,
group_id=group_id, group_id=group_id,
) )

View File

@ -1,17 +1,15 @@
from nonebot.adapters import Event from nonebot.adapters import Event
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType from zhenxun.utils.enum import BlockType, CacheType
from zhenxun.utils.message import MessageUtils from zhenxun.utils.utils import get_entity_ids
from .exception import IsSuperuserException from .exception import IsSuperuserException, SkipPluginException
from .utils import freq from .utils import freq, is_poke, send_message
class GroupCheck: class GroupCheck:
@ -42,16 +40,12 @@ class GroupCheck:
group = await self.__get_data() group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin: if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke): if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
freq._flmt_s.start_cd(self.group_id) await send_message(
await MessageUtils.build_message("超级管理员禁用了该群此功能...").send( self.session, "超级管理员禁用了该群此功能...", self.group_id
reply_to=True
) )
logger.debug( raise SkipPluginException(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...", f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..."
"AuthChecker",
session=self.session,
) )
raise IgnoredException("超级管理员禁用了该群此功能...")
await self.check_normal_block(self.plugin) await self.check_normal_block(self.plugin)
async def check_normal_block(self, plugin: PluginInfo): async def check_normal_block(self, plugin: PluginInfo):
@ -66,16 +60,8 @@ class GroupCheck:
group = await self.__get_data() group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.block_plugin: if group and CommonUtils.format(plugin.module) in group.block_plugin:
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
freq._flmt_s.start_cd(self.group_id) await send_message(self.session, "该群未开启此功能...", self.group_id)
await MessageUtils.build_message("该群未开启此功能...").send( raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...")
reply_to=True
)
logger.debug(
f"{plugin.name}({plugin.module}) 未开启此功能...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该群未开启此功能...")
await self.check_global_block(self.plugin) await self.check_global_block(self.plugin)
async def check_global_block(self, plugin: PluginInfo): async def check_global_block(self, plugin: PluginInfo):
@ -89,25 +75,13 @@ class GroupCheck:
""" """
if plugin.block_type == BlockType.GROUP: if plugin.block_type == BlockType.GROUP:
"""全局群组禁用""" """全局群组禁用"""
try: if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke): await send_message(
freq._flmt_c.start_cd(self.group_id) self.session, "该功能在群组中已被禁用...", self.group_id
await MessageUtils.build_message("该功能在群组中已被禁用...").send(
reply_to=True
)
except Exception as e:
logger.error(
"auth_plugin 发送消息失败",
"AuthChecker",
session=self.session,
e=e,
) )
logger.debug( raise SkipPluginException(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...", f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..."
"AuthChecker",
session=self.session,
) )
raise IgnoredException("该插件在群组中已被禁用...")
class PluginCheck: class PluginCheck:
@ -126,25 +100,11 @@ class PluginCheck:
IgnoredException: 忽略插件 IgnoredException: 忽略插件
""" """
if plugin.block_type == BlockType.PRIVATE: if plugin.block_type == BlockType.PRIVATE:
try: if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
if freq.is_send_limit_message( await send_message(self.session, "该功能在私聊中已被禁用...")
plugin, self.session.user.id, self.is_poke raise SkipPluginException(
): f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
freq._flmt_c.start_cd(self.session.user.id)
await MessageUtils.build_message("该功能在私聊中已被禁用...").send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=self.session,
e=e,
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
"AuthChecker",
session=self.session,
) )
raise IgnoredException("该插件在私聊中已被禁用...")
async def check_global(self, plugin: PluginInfo): async def check_global(self, plugin: PluginInfo):
"""全局状态 """全局状态
@ -162,16 +122,10 @@ class PluginCheck:
if self.group_id and (group := await cache.get(self.group_id)): if self.group_id and (group := await cache.get(self.group_id)):
if group.is_super: if group.is_super:
raise IsSuperuserException() raise IsSuperuserException()
logger.debug(
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
"AuthChecker",
session=self.session,
)
sid = self.group_id or self.session.user.id sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke): if freq.is_send_limit_message(plugin, sid, self.is_poke):
freq._flmt_s.start_cd(sid) await send_message(self.session, "全局未开启此功能...", sid)
await MessageUtils.build_message("全局未开启此功能...").send() raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...")
raise IgnoredException("全局未开启此功能...")
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event): async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
@ -180,22 +134,13 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
参数: 参数:
plugin: PluginInfo plugin: PluginInfo
session: Uninfo session: Uninfo
event: Event
""" """
group_id = None entity = get_entity_ids(session)
if session.group: is_poke_event = is_poke(event)
if session.group.parent: user_check = PluginCheck(entity.group_id, session, is_poke_event)
group_id = session.group.parent.id if entity.group_id:
else: group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
group_id = session.group.id
try:
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
is_poke = isinstance(event, PokeNotifyEvent)
except ImportError:
is_poke = False
user_check = PluginCheck(group_id, session, is_poke)
if group_id:
group_check = GroupCheck(plugin, group_id, session, is_poke)
await group_check.check() await group_check.check()
else: else:
await user_check.check_user(plugin) await user_check.check_user(plugin)

View File

@ -0,0 +1,13 @@
import sys
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
LOGGER_COMMAND = "AuthChecker"
class SwitchEnum(StrEnum):
ENABLE = "醒来"
DISABLE = "休息吧"

View File

@ -1,2 +1,14 @@
class IsSuperuserException(Exception): class IsSuperuserException(Exception):
pass pass
class SkipPluginException(Exception):
def __init__(self, info: str, *args: object) -> None:
super().__init__(*args)
self.info = info
def __str__(self) -> str:
return self.info
def __repr__(self) -> str:
return self.info

View File

@ -1,11 +1,61 @@
import contextlib
from nonebot.adapters import Event
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config from zhenxun.configs.config import Config
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import FreqLimiter from zhenxun.utils.utils import FreqLimiter
from .config import LOGGER_COMMAND
base_config = Config.get("hook") base_config = Config.get("hook")
def is_poke(event: Event) -> bool:
"""判断是否为poke类型
参数:
event: Event
返回:
bool: 是否为poke类型
"""
with contextlib.suppress(ImportError):
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
return isinstance(event, PokeNotifyEvent)
return False
async def send_message(
session: Uninfo, message: list | str, check_tag: str | None = None
):
"""发送消息
参数:
session: Uninfo
message: 消息
check_tag: cd flag
"""
try:
if not check_tag:
await MessageUtils.build_message(message).send(reply_to=True)
elif freq._flmt.check(check_tag):
freq._flmt.start_cd(check_tag)
await MessageUtils.build_message(message).send(reply_to=True)
except Exception as e:
logger.error(
"发送消息失败",
LOGGER_COMMAND,
session=session,
e=e,
)
class FreqUtils: class FreqUtils:
def __init__(self): def __init__(self):
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD") check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")

View File

@ -1,4 +1,4 @@
import contextlib import asyncio
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException from nonebot.exception import IgnoredException
@ -18,14 +18,107 @@ from zhenxun.utils.enum import (
) )
from zhenxun.utils.exception import InsufficientGold from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
from zhenxun.utils.utils import get_entity_ids
from .auth.auth_admin import auth_admin from .auth.auth_admin import auth_admin
from .auth.auth_ban import auth_ban
from .auth.auth_bot import auth_bot from .auth.auth_bot import auth_bot
from .auth.auth_cost import auth_cost from .auth.auth_cost import auth_cost
from .auth.auth_group import auth_group from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManage, auth_limit from .auth.auth_limit import LimitManage, auth_limit
from .auth.auth_plugin import auth_plugin from .auth.auth_plugin import auth_plugin
from .auth.exception import IsSuperuserException from .auth.config import LOGGER_COMMAND
from .auth.exception import IsSuperuserException, SkipPluginException
async def get_plugin_and_user(
module: str, user_id: str
) -> tuple[PluginInfo, UserConsole]:
"""获取用户数据和插件信息
参数:
module: 模块名
user_id: 用户id
异常:
SkipPluginException: 插件数据不存在
SkipPluginException: 插件类型为HIDDEN
SkipPluginException: 重复创建用户
SkipPluginException: 用户数据不存在
返回:
tuple[PluginInfo, UserConsole]: 插件信息用户信息
"""
user_cache = Cache[UserConsole](CacheType.USERS)
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if not plugin:
raise SkipPluginException(f"插件:{module} 数据不存在,已跳过权限检查...")
if plugin.plugin_type == PluginType.HIDDEN:
raise SkipPluginException(
f"插件: {plugin.name}:{plugin.module} 为HIDDEN已跳过权限检查..."
)
user = None
try:
user = await user_cache.get(user_id)
except IntegrityError as e:
raise SkipPluginException("重复创建用户,已跳过该次权限检查...") from e
if not user:
raise SkipPluginException("用户数据不存在,已跳过权限检查...")
return plugin, user
async def get_plugin_cost(
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
) -> int:
"""获取插件费用
参数:
bot: Bot
user: 用户数据
plugin: 插件数据
session: Uninfo
异常:
IsSuperuserException: 超级用户
IsSuperuserException: 超级用户
返回:
int: 调用插件金币费用
"""
cost_gold = await auth_cost(user, plugin, session)
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
raise IsSuperuserException()
return cost_gold
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
"""扣除用户金币
参数:
user_id: 用户id
module: 插件模块名称
cost_gold: 消耗金币
session: Uninfo
"""
user_cache = Cache[UserConsole](CacheType.USERS)
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
module,
PlatformUtils.get_platform(session),
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
# 更新缓存
await user_cache.update(user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
async def auth( async def auth(
@ -44,85 +137,32 @@ async def auth(
session: Uninfo session: Uninfo
message: UniMsg message: UniMsg
""" """
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
is_ignore = False
cost_gold = 0 cost_gold = 0
with contextlib.suppress(ImportError): ignore_flag = False
from nonebot.adapters.onebot.v11 import PokeNotifyEvent entity = get_entity_ids(session)
module = matcher.plugin_name or ""
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent): try:
"""过滤除poke外的notice""" if not module:
return raise SkipPluginException("Matcher插件名称不存在...")
user_cache = Cache[UserConsole](CacheType.USERS) plugin, user = await get_plugin_and_user(module, entity.user_id)
if matcher.plugin and (module := matcher.plugin.name): cost_gold = await get_plugin_cost(bot, user, plugin, session)
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) await asyncio.gather(
if not plugin: *[
return logger.debug(f"插件:{module} 数据不存在,已跳过权限检查...") auth_ban(matcher, bot, session),
if plugin.plugin_type == PluginType.HIDDEN: auth_bot(plugin, bot.self_id),
return logger.debug( auth_group(plugin, entity, message),
f"插件: {plugin.name}:{plugin.module} 为HIDDEN已跳过权限检查..." auth_admin(plugin, session),
) auth_plugin(plugin, session, event),
user = None ]
try: )
user = await user_cache.get(session.user.id) await auth_limit(plugin, session)
except IntegrityError as e: except SkipPluginException as e:
logger.debug( LimitManage.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
"重复创建用户,已跳过该次权限检查...", logger.info(str(e), LOGGER_COMMAND, session=session)
"AuthChecker", ignore_flag = True
session=session, except IsSuperuserException:
e=e, logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
) if not ignore_flag and cost_gold > 0:
if not user: await reduce_gold(entity.user_id, module, cost_gold, session)
return logger.debug( if ignore_flag:
"用户数据不存在,已跳过权限检查...", "AuthChecker", session=session
)
try:
cost_gold = await auth_cost(user, plugin, session)
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
cost_gold = 0
raise IsSuperuserException()
await auth_bot(plugin, bot.self_id)
await auth_group(plugin, session, message)
await auth_admin(plugin, session)
await auth_plugin(plugin, session, event)
await auth_limit(plugin, session)
except IsSuperuserException:
logger.debug(
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
)
except IgnoredException:
is_ignore = True
LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id)
except AssertionError as e:
is_ignore = True
logger.debug("消息无法发送", session=session, e=e)
if cost_gold and user_id:
"""花费金币"""
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
matcher.plugin.name if matcher.plugin else "",
PlatformUtils.get_platform(session),
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
# 更新缓存
await user_cache.update(user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session)
if is_ignore:
raise IgnoredException("权限检测 ignore") raise IgnoredException("权限检测 ignore")

View File

@ -1,15 +1,21 @@
import time
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.message import run_postprocessor, run_preprocessor from nonebot.message import run_postprocessor, run_preprocessor
from nonebot_plugin_alconna import UniMsg from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from zhenxun.services.log import logger
from .auth.config import LOGGER_COMMAND
from .auth_checker import LimitManage, auth from .auth_checker import LimitManage, auth
# # 权限检测 # # 权限检测
@run_preprocessor @run_preprocessor
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg): async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
start_time = time.time()
await auth( await auth(
matcher, matcher,
event, event,
@ -17,6 +23,7 @@ async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message:
session, session,
message, message,
) )
logger.info(f"权限检测耗时:{time.time() - start_time}", LOGGER_COMMAND)
# 解除命令block阻塞 # 解除命令block阻塞

View File

@ -1,104 +0,0 @@
from nonebot.adapters import Bot
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot.message import run_preprocessor
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import MultipleObjectsReturned
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, PluginType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import FreqLimiter
Config.add_plugin_config(
"hook",
"BAN_RESULT",
"才不会给你发消息.",
help="对被ban用户发送的消息",
)
_flmt = FreqLimiter(300)
async def is_ban(user_id: str | None, group_id: str | None):
cache = Cache[list[BanConsole]](CacheType.BAN)
result = await cache.get(user_id, group_id) or await cache.get(user_id)
return result and result[0].ban_time > 0
# 检查是否被ban
@run_preprocessor
async def _(matcher: Matcher, bot: Bot, session: Uninfo):
if plugin := matcher.plugin:
if metadata := plugin.metadata:
extra = metadata.extra
if extra.get("plugin_type") in [PluginType.HIDDEN]:
return
user_id = session.user.id
group_id = session.group.id if session.group else None
cache = Cache[list[BanConsole]](CacheType.BAN)
if user_id in bot.config.superusers:
return
if group_id:
try:
if await is_ban(None, group_id):
logger.debug("群组处于黑名单中...", "BanChecker")
raise IgnoredException("群组处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"群组黑名单数据重复过滤该次hook并移除多余数据...", "BanChecker"
)
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
group_cache = Cache[GroupConsole](CacheType.GROUPS)
if g := await group_cache.get(group_id):
if g.level < 0:
logger.debug("群黑名单, 群权限-1...", "BanChecker")
raise IgnoredException("群黑名单, 群权限-1..")
if user_id:
ban_result = Config.get_config("hook", "BAN_RESULT")
if await is_ban(user_id, group_id):
time = await BanConsole.check_ban_time(user_id, group_id)
if time == -1:
time_str = ""
else:
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(
matcher.plugin_name
)
if (
db_plugin
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
and _flmt.check(user_id)
):
_flmt.start_cd(user_id)
logger.debug(f"ban检测发送插件: {matcher.plugin_name}")
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "BanChecker")
raise IgnoredException("用户处于黑名单中...")

View File

@ -42,6 +42,8 @@ def default_with_expiration(
data: dict[str, Any], expire_data: dict[str, int], expire: int data: dict[str, Any], expire_data: dict[str, int], expire: int
): ):
"""默认更新期时间cache方法""" """默认更新期时间cache方法"""
if not data:
return {}
keys = {k for k in data if k not in expire_data} keys = {k for k in data if k not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {} return {k: time.time() + expire for k in keys} if keys else {}
@ -52,12 +54,6 @@ async def _():
return {p.module: p for p in data_list} return {p.module: p for p in data_list}
@CacheRoot.new(CacheType.PLUGINS)
async def _():
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@CacheRoot.updater(CacheType.PLUGINS) @CacheRoot.updater(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo], key: str, value: Any): async def _(data: dict[str, PluginInfo], key: str, value: Any):
if value: if value:
@ -109,21 +105,20 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) @CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(data: dict[str, GroupConsole] | None, group_id: str): async def _(cache_data: CacheData, group_id: str):
if not data: cache_data.data = cache_data.data or {}
data = {} result = cache_data.data.get(group_id, None)
result = data.get(group_id, None)
if not result: if not result:
result = await GroupConsole.get_group(group_id=group_id) result = await GroupConsole.get_group(group_id=group_id)
if result: if result:
data[group_id] = result cache_data.data[group_id] = result
return result return result
@CacheRoot.with_refresh(CacheType.GROUPS) @CacheRoot.with_refresh(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole]): async def _(data: dict[str, GroupConsole]):
groups = await GroupConsole.filter( groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True, load_status=True group_id__in=data.keys(), channel_id__isnull=True
) )
for group in groups: for group in groups:
data[group.group_id] = group data[group.group_id] = group
@ -154,14 +149,13 @@ async def _(data: dict[str, BotConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole) @CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(data: dict[str, BotConsole] | None, bot_id: str): async def _(cache_data: CacheData, bot_id: str):
if not data: cache_data.data = cache_data.data or {}
data = {} result = cache_data.data.get(bot_id, None)
result = data.get(bot_id, None)
if not result: if not result:
result = await BotConsole.get_or_none(bot_id=bot_id) result = await BotConsole.get_or_none(bot_id=bot_id)
if result: if result:
data[bot_id] = result cache_data.data[bot_id] = result
return result return result
@ -224,7 +218,7 @@ def _(cache_data: CacheData):
return default_cleanup_expired(cache_data) return default_cleanup_expired(cache_data)
@CacheRoot.new(CacheType.LEVEL) @CacheRoot.new(CacheType.LEVEL, False)
async def _(): async def _():
return await LevelUser().all() return await LevelUser().all()
@ -246,13 +240,13 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
] ]
@CacheRoot.new(CacheType.BAN) @CacheRoot.new(CacheType.BAN, False, 5)
async def _(): async def _():
return await BanConsole.all() return await BanConsole.all()
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
if user_id: if user_id:
return ( return (
[ [

View File

@ -1,4 +1,4 @@
from typing import Any, ClassVar, overload from typing import Any, ClassVar, cast, overload
from typing_extensions import Self from typing_extensions import Self
from tortoise import fields from tortoise import fields
@ -11,6 +11,42 @@ from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType, DbLockType, PluginType from zhenxun.utils.enum import CacheType, DbLockType, PluginType
def add_disable_marker(name: str) -> str:
"""添加模块禁用标记符
Args:
name: 模块名称
Returns:
添加了禁用标记的模块名 (前缀'<'和后缀',')
"""
return f"<{name},"
@overload
def convert_module_format(data: str) -> list[str]: ...
@overload
def convert_module_format(data: list[str]) -> str: ...
def convert_module_format(data: str | list[str]) -> str | list[str]:
"""
`<aaa,<bbb,<ccc,` `["aaa", "bbb", "ccc"]` (即禁用启用)之间进行相互转换
参数:
data: 要转换的数据
返回:
str | list[str]: 根据输入类型返回转换后的数据
"""
if isinstance(data, str):
return [item.strip(",") for item in data.split("<") if item]
else:
return "".join(format(item) for item in data)
class GroupConsole(Model): class GroupConsole(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True) id = fields.IntField(pk=True, generated=True, auto_increment=True)
"""自增id""" """自增id"""
@ -47,7 +83,7 @@ class GroupConsole(Model):
platform = fields.CharField(255, default="qq", description="所属平台") platform = fields.CharField(255, default="qq", description="所属平台")
"""所属平台""" """所属平台"""
class Meta: # type: ignore class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "group_console" table = "group_console"
table_description = "群组信息表" table_description = "群组信息表"
unique_together = ("group_id", "channel_id") unique_together = ("group_id", "channel_id")
@ -57,33 +93,34 @@ class GroupConsole(Model):
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE] enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁""" """开启锁"""
@staticmethod
def format(name: str) -> str:
return f"<{name},"
@overload
@classmethod @classmethod
def convert_module_format(cls, data: str) -> list[str]: ... async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
"""获取默认禁用的任务模块
@overload
@classmethod
def convert_module_format(cls, data: list[str]) -> str: ...
@classmethod
def convert_module_format(cls, data: str | list[str]) -> str | list[str]:
"""
`<aaa,<bbb,<ccc,` `["aaa", "bbb", "ccc"]` 之间进行相互转换
参数:
data (str | list[str]): 输入数据可能是格式化字符串或字符串列表
返回: 返回:
str | list[str]: 根据输入类型返回转换后的数据 list[str]: 任务模块列表
""" """
if isinstance(data, str): return cast(
return [item.strip(",") for item in data.split("<") if item] list[str],
elif isinstance(data, list): await TaskInfo.filter(default_status=default_status).values_list(
return "".join(cls.format(item) for item in data) "module", flat=True
),
)
@classmethod
async def _get_plugin_modules(cls, *, default_status: bool) -> list[str]:
"""获取默认禁用的插件模块
返回:
list[str]: 插件模块列表
"""
return cast(
list[str],
await PluginInfo.filter(
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
default_status=default_status,
).values_list("module", flat=True),
)
@classmethod @classmethod
@CacheRoot.listener(CacheType.GROUPS) @CacheRoot.listener(CacheType.GROUPS)
@ -92,20 +129,44 @@ class GroupConsole(Model):
) -> Self: ) -> Self:
"""覆盖create方法""" """覆盖create方法"""
group = await super().create(using_db=using_db, **kwargs) group = await super().create(using_db=using_db, **kwargs)
if modules := await TaskInfo.filter(default_status=False).values_list(
"module", flat=True task_modules = await cls._get_task_modules(default_status=False)
): plugin_modules = await cls._get_plugin_modules(default_status=False)
group.block_task = cls.convert_module_format(modules) # type: ignore
if modules := await PluginInfo.filter( if task_modules or plugin_modules:
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], await cls._update_modules(group, task_modules, plugin_modules, using_db)
default_status=False,
).values_list("module", flat=True):
group.block_plugin = cls.convert_module_format(modules) # type: ignore
await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"]
)
return group return group
@classmethod
async def _update_modules(
cls,
group: Self,
task_modules: list[str],
plugin_modules: list[str],
using_db: BaseDBAsyncClient | None = None,
) -> None:
"""更新模块设置
参数:
group: 群组实例
task_modules: 任务模块列表
plugin_modules: 插件模块列表
using_db: 数据库连接
"""
update_fields = []
if task_modules:
group.block_task = convert_module_format(task_modules)
update_fields.append("block_task")
if plugin_modules:
group.block_plugin = convert_module_format(plugin_modules)
update_fields.append("block_plugin")
if update_fields:
await group.save(using_db=using_db, update_fields=update_fields)
@classmethod @classmethod
async def get_or_create( async def get_or_create(
cls, cls,
@ -117,23 +178,19 @@ class GroupConsole(Model):
group, is_create = await super().get_or_create( group, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs defaults=defaults, using_db=using_db, **kwargs
) )
if is_create and ( if not is_create:
modules := await TaskInfo.filter(default_status=False).values_list( return group, is_create
"module", flat=True
) task_modules = await cls._get_task_modules(default_status=False)
): plugin_modules = await cls._get_plugin_modules(default_status=False)
group.block_task = cls.convert_module_format(modules) # type: ignore
if modules := await PluginInfo.filter( if task_modules or plugin_modules:
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], await cls._update_modules(group, task_modules, plugin_modules, using_db)
default_status=False,
).values_list("module", flat=True):
group.block_plugin = cls.convert_module_format(modules) # type: ignore
await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"]
)
if is_create: if is_create:
if cache := await CacheRoot.get_cache(CacheType.GROUPS): if cache := await CacheRoot.get_cache(CacheType.GROUPS):
await cache.update(group.group_id, group) await cache.update(group.group_id, group)
return group, is_create return group, is_create
@classmethod @classmethod
@ -148,20 +205,15 @@ class GroupConsole(Model):
group, is_create = await super().update_or_create( group, is_create = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs defaults=defaults, using_db=using_db, **kwargs
) )
if is_create and ( if not is_create:
modules := await TaskInfo.filter(default_status=False).values_list( return group, is_create
"module", flat=True
) task_modules = await cls._get_task_modules(default_status=False)
): plugin_modules = await cls._get_plugin_modules(default_status=False)
group.block_task = cls.convert_module_format(modules) # type: ignore
if modules := await PluginInfo.filter( if task_modules or plugin_modules:
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], await cls._update_modules(group, task_modules, plugin_modules, using_db)
default_status=False,
).values_list("module", flat=True):
group.block_plugin = cls.convert_module_format(modules) # type: ignore
await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"]
)
return group, is_create return group, is_create
@classmethod @classmethod
@ -206,7 +258,7 @@ class GroupConsole(Model):
""" """
return await cls.exists( return await cls.exists(
group_id=group_id, group_id=group_id,
superuser_block_plugin__contains=cls.format(module), superuser_block_plugin__contains=add_disable_marker(module),
) )
@classmethod @classmethod
@ -220,10 +272,11 @@ class GroupConsole(Model):
返回: 返回:
bool: 是否禁用插件 bool: 是否禁用插件
""" """
module = add_disable_marker(module)
return await cls.exists( return await cls.exists(
group_id=group_id, block_plugin__contains=cls.format(module) group_id=group_id, block_plugin__contains=module
) or await cls.exists( ) or await cls.exists(
group_id=group_id, superuser_block_plugin__contains=cls.format(module) group_id=group_id, superuser_block_plugin__contains=module
) )
@classmethod @classmethod
@ -245,12 +298,22 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create( group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform} group_id=group_id, defaults={"platform": platform}
) )
update_fields = []
if is_superuser: if is_superuser:
if cls.format(module) not in group.superuser_block_plugin: superuser_block_plugin = convert_module_format(group.superuser_block_plugin)
group.superuser_block_plugin += cls.format(module) if module not in superuser_block_plugin:
elif cls.format(module) not in group.block_plugin: superuser_block_plugin.append(module)
group.block_plugin += cls.format(module) group.superuser_block_plugin = convert_module_format(
await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) superuser_block_plugin
)
update_fields.append("superuser_block_plugin")
elif add_disable_marker(module) not in group.block_plugin:
block_plugin = convert_module_format(group.block_plugin)
block_plugin.append(module)
group.block_plugin = convert_module_format(block_plugin)
update_fields.append("block_plugin")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod @classmethod
async def set_unblock_plugin( async def set_unblock_plugin(
@ -271,14 +334,22 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create( group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform} group_id=group_id, defaults={"platform": platform}
) )
update_fields = []
if is_superuser: if is_superuser:
if cls.format(module) in group.superuser_block_plugin: superuser_block_plugin = convert_module_format(group.superuser_block_plugin)
group.superuser_block_plugin = group.superuser_block_plugin.replace( if module in superuser_block_plugin:
cls.format(module), "" superuser_block_plugin.remove(module)
group.superuser_block_plugin = convert_module_format(
superuser_block_plugin
) )
elif cls.format(module) in group.block_plugin: update_fields.append("superuser_block_plugin")
group.block_plugin = group.block_plugin.replace(cls.format(module), "") elif add_disable_marker(module) in group.block_plugin:
await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) block_plugin = convert_module_format(group.block_plugin)
block_plugin.remove(module)
group.block_plugin = convert_module_format(block_plugin)
update_fields.append("block_plugin")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod @classmethod
async def is_normal_block_plugin( async def is_normal_block_plugin(
@ -297,7 +368,7 @@ class GroupConsole(Model):
return await cls.exists( return await cls.exists(
group_id=group_id, group_id=group_id,
channel_id=channel_id, channel_id=channel_id,
block_plugin__contains=cls.format(module), block_plugin__contains=f"<{module},",
) )
@classmethod @classmethod
@ -313,7 +384,7 @@ class GroupConsole(Model):
""" """
return await cls.exists( return await cls.exists(
group_id=group_id, group_id=group_id,
superuser_block_task__contains=cls.format(task), superuser_block_task__contains=add_disable_marker(task),
) )
@classmethod @classmethod
@ -330,24 +401,23 @@ class GroupConsole(Model):
返回: 返回:
bool: 是否禁用被动 bool: 是否禁用被动
""" """
task = add_disable_marker(task)
if not channel_id: if not channel_id:
return await cls.exists( return await cls.exists(
group_id=group_id, group_id=group_id,
channel_id__isnull=True, channel_id__isnull=True,
block_task__contains=cls.format(task), block_task__contains=task,
) or await cls.exists( ) or await cls.exists(
group_id=group_id, group_id=group_id,
channel_id__isnull=True, channel_id__isnull=True,
superuser_block_task__contains=cls.format(task), superuser_block_task__contains=task,
) )
return await cls.exists( return await cls.exists(
group_id=group_id, group_id=group_id, channel_id=channel_id, block_task__contains=task
channel_id=channel_id,
block_task__contains=cls.format(task),
) or await cls.exists( ) or await cls.exists(
group_id=group_id, group_id=group_id,
channel_id__isnull=True, channel_id__isnull=True,
superuser_block_task__contains=cls.format(task), superuser_block_task__contains=task,
) )
@classmethod @classmethod
@ -369,12 +439,20 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create( group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform} group_id=group_id, defaults={"platform": platform}
) )
update_fields = []
if is_superuser: if is_superuser:
if cls.format(task) not in group.superuser_block_task: superuser_block_task = convert_module_format(group.superuser_block_task)
group.superuser_block_task += cls.format(task) if task not in group.superuser_block_task:
elif cls.format(task) not in group.block_task: superuser_block_task.append(task)
group.block_task += cls.format(task) group.superuser_block_task = convert_module_format(superuser_block_task)
await group.save(update_fields=["block_task", "superuser_block_task"]) update_fields.append("superuser_block_task")
elif add_disable_marker(task) not in group.block_task:
block_task = convert_module_format(group.block_task)
block_task.append(task)
group.block_task = convert_module_format(block_task)
update_fields.append("block_task")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod @classmethod
async def set_unblock_task( async def set_unblock_task(
@ -395,14 +473,20 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create( group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform} group_id=group_id, defaults={"platform": platform}
) )
update_fields = []
if is_superuser: if is_superuser:
if cls.format(task) in group.superuser_block_task: superuser_block_task = convert_module_format(group.superuser_block_task)
group.superuser_block_task = group.superuser_block_task.replace( if task in superuser_block_task:
cls.format(task), "" superuser_block_task.remove(task)
) group.superuser_block_task = convert_module_format(superuser_block_task)
elif cls.format(task) in group.block_task: update_fields.append("superuser_block_task")
group.block_task = group.block_task.replace(cls.format(task), "") elif add_disable_marker(task) in group.block_task:
await group.save(update_fields=["block_task", "superuser_block_task"]) block_task = convert_module_format(group.block_task)
block_task.remove(task)
group.block_task = convert_module_format(block_task)
update_fields.append("block_task")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod @classmethod
def _run_script(cls): def _run_script(cls):

View File

@ -46,13 +46,12 @@ class CacheGetter(BaseModel, Generic[T]):
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T: async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
"""获取缓存""" """获取缓存"""
processed_data = ( if not self.get_func:
await self.get_func(cache_data, *args, **kwargs) return cache_data.data
if self.get_func and is_coroutine_callable(self.get_func) if is_coroutine_callable(self.get_func):
else self.get_func(cache_data, *args, **kwargs) processed_data = await self.get_func(cache_data, *args, **kwargs)
if self.get_func else:
else cache_data.data processed_data = self.get_func(cache_data, *args, **kwargs)
)
return cast(T, processed_data) return cast(T, processed_data)
@ -81,9 +80,14 @@ class CacheData(BaseModel):
"""更新时间""" """更新时间"""
reload_count: int = 0 reload_count: int = 0
"""更新次数""" """更新次数"""
incremental_update: bool = True
"""是否是增量更新"""
async def get(self, *args, **kwargs) -> Any: async def get(self, *args, **kwargs) -> Any:
"""获取单个缓存""" """获取单个缓存"""
if not self.reload_count and not self.incremental_update:
# 首次获取时,非增量更新获取全部数据
await self.reload()
self.call_cleanup_expired() # 移除过期缓存 self.call_cleanup_expired() # 移除过期缓存
if not self.getter: if not self.getter:
return self.data return self.data
@ -210,12 +214,17 @@ class CacheManage:
id=f"CacheRoot-{cache_data.name}", id=f"CacheRoot-{cache_data.name}",
) )
def new(self, name: str, expire: int = 60 * 10): def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10):
def wrapper(func: Callable): def wrapper(func: Callable):
_name = name.upper() _name = name.upper()
if _name in self._data: if _name in self._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...") raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
self._data[_name] = CacheData(name=_name, func=func, expire=expire) self._data[_name] = CacheData(
name=_name,
func=func,
expire=expire,
incremental_update=incremental_update,
)
return wrapper return wrapper
@ -253,7 +262,7 @@ class CacheManage:
return wrapper return wrapper
@validate_name @validate_name
def getter(self, name: str, result_model: type | None = None): def getter(self, name: str, result_model: type):
def wrapper(func: Callable): def wrapper(func: Callable):
self._data[name].getter = CacheGetter[result_model](get_func=func) self._data[name].getter = CacheGetter[result_model](get_func=func)

View File

@ -69,12 +69,12 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None, using_db: BaseDBAsyncClient | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[Self, bool]: ) -> tuple[Self, bool]:
result = await super().get_or_create( result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs defaults=defaults, using_db=using_db, **kwargs
) )
if cache_type := cls.get_cache_type(): if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type) await CacheRoot.reload(cache_type)
return result return (result, is_create)
@classmethod @classmethod
async def update_or_create( async def update_or_create(

View File

@ -1,4 +1,5 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
import os import os
from pathlib import Path from pathlib import Path
@ -6,6 +7,7 @@ import time
from typing import Any from typing import Any
import httpx import httpx
from nonebot_plugin_uninfo import Uninfo
import pypinyin import pypinyin
import pytz import pytz
@ -13,6 +15,16 @@ from zhenxun.configs.config import Config
from zhenxun.services.log import logger from zhenxun.services.log import logger
@dataclass
class EntityIDs:
user_id: str
"""用户id"""
group_id: str | None
"""群组id"""
channel_id: str | None
"""频道id"""
class ResourceDirManager: class ResourceDirManager:
""" """
临时文件管理器 临时文件管理器
@ -228,3 +240,24 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
return True return True
except ValueError: except ValueError:
return False return False
def get_entity_ids(session: Uninfo) -> EntityIDs:
"""获取用户id群组id频道id
参数:
session: Uninfo
返回:
EntityIDs: 用户id群组id频道id
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)