mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⚡ 优化hook权限检测性能
This commit is contained in:
parent
03f0185e46
commit
41dd767724
@ -196,7 +196,7 @@ class PluginManage:
|
|||||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
||||||
default_status=status
|
default_status=status
|
||||||
)
|
)
|
||||||
return f'成功将所有功能进群默认状态修改为: {"开启" if status else "关闭"}'
|
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
||||||
if group_id:
|
if group_id:
|
||||||
if group := await GroupConsole.get_or_none(
|
if group := await GroupConsole.get_or_none(
|
||||||
group_id=group_id, channel_id__isnull=True
|
group_id=group_id, channel_id__isnull=True
|
||||||
@ -213,12 +213,12 @@ class PluginManage:
|
|||||||
module_list = [f"<{module}" for module in module_list]
|
module_list = [f"<{module}" for module in module_list]
|
||||||
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
||||||
await group.save(update_fields=["block_plugin"])
|
await group.save(update_fields=["block_plugin"])
|
||||||
return f'成功将此群组所有功能状态修改为: {"开启" if status else "关闭"}'
|
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
||||||
return "获取群组失败..."
|
return "获取群组失败..."
|
||||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
|
||||||
status=status, block_type=None if status else BlockType.ALL
|
status=status, block_type=None if status else BlockType.ALL
|
||||||
)
|
)
|
||||||
return f'成功将所有功能全局状态修改为: {"开启" if status else "关闭"}'
|
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def is_wake(cls, group_id: str) -> bool:
|
async def is_wake(cls, group_id: str) -> bool:
|
||||||
@ -243,9 +243,11 @@ class PluginManage:
|
|||||||
参数:
|
参数:
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
"""
|
"""
|
||||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
group, _ = await GroupConsole.get_or_create(
|
||||||
status=False
|
group_id=group_id, channel_id__isnull=True
|
||||||
)
|
)
|
||||||
|
group.status = False
|
||||||
|
await group.save(update_fields=["status"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def wake(cls, group_id: str):
|
async def wake(cls, group_id: str):
|
||||||
@ -254,9 +256,11 @@ class PluginManage:
|
|||||||
参数:
|
参数:
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
"""
|
"""
|
||||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
group, _ = await GroupConsole.get_or_create(
|
||||||
status=True
|
group_id=group_id, channel_id__isnull=True
|
||||||
)
|
)
|
||||||
|
group.status = True
|
||||||
|
await group.save(update_fields=["status"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def block(cls, module: str):
|
async def block(cls, module: str):
|
||||||
|
|||||||
@ -1,15 +1,14 @@
|
|||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot_plugin_alconna import At
|
from nonebot_plugin_alconna import At
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.models.level_user import LevelUser
|
from zhenxun.models.level_user import LevelUser
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.cache import Cache
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
from .utils import freq
|
from .exception import SkipPluginException
|
||||||
|
from .utils import send_message
|
||||||
|
|
||||||
|
|
||||||
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
||||||
@ -17,60 +16,33 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
|||||||
|
|
||||||
参数:
|
参数:
|
||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
session: PluginInfo
|
session: Uninfo
|
||||||
"""
|
"""
|
||||||
group_id = None
|
|
||||||
cache = Cache[list[LevelUser]](CacheType.LEVEL)
|
|
||||||
user_level = await cache.get(session.user.id) or []
|
|
||||||
if session.group:
|
|
||||||
if session.group.parent:
|
|
||||||
group_id = session.group.parent.id
|
|
||||||
else:
|
|
||||||
group_id = session.group.id
|
|
||||||
|
|
||||||
if not plugin.admin_level:
|
if not plugin.admin_level:
|
||||||
return
|
return
|
||||||
if group_id:
|
entity = get_entity_ids(session)
|
||||||
user_level += await cache.get(session.user.id, group_id) or []
|
cache = Cache[list[LevelUser]](CacheType.LEVEL)
|
||||||
|
user_level = await cache.get(session.user.id) or []
|
||||||
|
if entity.group_id:
|
||||||
|
user_level += await cache.get(session.user.id, entity.group_id) or []
|
||||||
user = max(user_level, key=lambda x: x.user_level)
|
user = max(user_level, key=lambda x: x.user_level)
|
||||||
if user.user_level < plugin.admin_level:
|
if user.user_level < plugin.admin_level:
|
||||||
try:
|
await send_message(
|
||||||
if freq._flmt.check(session.user.id):
|
session,
|
||||||
freq._flmt.start_cd(session.user.id)
|
[
|
||||||
await MessageUtils.build_message(
|
At(flag="user", target=session.user.id),
|
||||||
[
|
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||||
At(flag="user", target=session.user.id),
|
],
|
||||||
"你的权限不足喔,"
|
entity.user_id,
|
||||||
f"该功能需要的权限等级: {plugin.admin_level}",
|
)
|
||||||
]
|
raise SkipPluginException(
|
||||||
).send(reply_to=True)
|
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
)
|
||||||
raise IgnoredException("管理员权限不足...")
|
|
||||||
elif user_level:
|
elif user_level:
|
||||||
user = max(user_level, key=lambda x: x.user_level)
|
user = max(user_level, key=lambda x: x.user_level)
|
||||||
if user.user_level < plugin.admin_level:
|
if user.user_level < plugin.admin_level:
|
||||||
try:
|
await send_message(
|
||||||
await MessageUtils.build_message(
|
session,
|
||||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||||
).send()
|
)
|
||||||
except Exception as e:
|
raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...")
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("权限不足")
|
|
||||||
|
|||||||
167
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
167
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
from nonebot.adapters import Bot
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot_plugin_alconna import At
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
from tortoise.exceptions import MultipleObjectsReturned
|
||||||
|
|
||||||
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.models.ban_console import BanConsole
|
||||||
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.cache import Cache
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.enum import CacheType, PluginType
|
||||||
|
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
from .utils import send_message
|
||||||
|
|
||||||
|
Config.add_plugin_config(
|
||||||
|
"hook",
|
||||||
|
"BAN_RESULT",
|
||||||
|
"才不会给你发消息.",
|
||||||
|
help="对被ban用户发送的消息",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||||
|
cache = Cache[list[BanConsole]](CacheType.BAN)
|
||||||
|
results = await cache.get(user_id, group_id) or await cache.get(user_id)
|
||||||
|
if not results:
|
||||||
|
return 0
|
||||||
|
for result in results:
|
||||||
|
if result.group_id == group_id and (
|
||||||
|
result.duration > 0 or result.duration == -1
|
||||||
|
):
|
||||||
|
return await BanConsole.check_ban_time(user_id, group_id)
|
||||||
|
if not result.group_id and result.duration == -1:
|
||||||
|
return await BanConsole.check_ban_time(user_id, group_id)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def check_plugin_type(matcher: Matcher) -> bool:
|
||||||
|
"""判断插件类型是否是隐藏插件
|
||||||
|
|
||||||
|
参数:
|
||||||
|
matcher: Matcher
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否为隐藏插件
|
||||||
|
"""
|
||||||
|
if plugin := matcher.plugin:
|
||||||
|
if metadata := plugin.metadata:
|
||||||
|
extra = metadata.extra
|
||||||
|
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(time: float) -> str:
|
||||||
|
"""格式化时间
|
||||||
|
|
||||||
|
参数:
|
||||||
|
time: ban时长
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 格式化时间文本
|
||||||
|
"""
|
||||||
|
if time == -1:
|
||||||
|
return "∞"
|
||||||
|
time = abs(int(time))
|
||||||
|
if time < 60:
|
||||||
|
time_str = f"{time!s} 秒"
|
||||||
|
else:
|
||||||
|
minute = int(time / 60)
|
||||||
|
if minute > 60:
|
||||||
|
hours = minute // 60
|
||||||
|
minute %= 60
|
||||||
|
time_str = f"{hours} 小时 {minute}分钟"
|
||||||
|
else:
|
||||||
|
time_str = f"{minute} 分钟"
|
||||||
|
return time_str
|
||||||
|
|
||||||
|
|
||||||
|
async def group_handle(cache: Cache[list[BanConsole]], group_id: str):
|
||||||
|
"""群组ban检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
cache: cache
|
||||||
|
group_id: 群组id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: 群组处于黑名单
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if await is_ban(None, group_id):
|
||||||
|
raise SkipPluginException("群组处于黑名单中...")
|
||||||
|
except MultipleObjectsReturned:
|
||||||
|
logger.warning(
|
||||||
|
"群组黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND
|
||||||
|
)
|
||||||
|
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
|
||||||
|
"id", flat=True
|
||||||
|
)
|
||||||
|
await BanConsole.filter(id__in=ids[:-1]).delete()
|
||||||
|
await cache.reload()
|
||||||
|
|
||||||
|
|
||||||
|
async def user_handle(
|
||||||
|
module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo
|
||||||
|
):
|
||||||
|
"""用户ban检查
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 插件模块名
|
||||||
|
cache: cache
|
||||||
|
user_id: 用户id
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: 用户处于黑名单
|
||||||
|
"""
|
||||||
|
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||||
|
try:
|
||||||
|
time = await is_ban(entity.user_id, entity.group_id)
|
||||||
|
if not time:
|
||||||
|
return
|
||||||
|
time_str = format_time(time)
|
||||||
|
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
||||||
|
if (
|
||||||
|
db_plugin
|
||||||
|
# and not db_plugin.ignore_prompt
|
||||||
|
and time != -1
|
||||||
|
and ban_result
|
||||||
|
):
|
||||||
|
await send_message(
|
||||||
|
session,
|
||||||
|
[
|
||||||
|
At(flag="user", target=entity.user_id),
|
||||||
|
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||||
|
],
|
||||||
|
entity.user_id,
|
||||||
|
)
|
||||||
|
raise SkipPluginException("用户处于黑名单中...")
|
||||||
|
except MultipleObjectsReturned:
|
||||||
|
logger.warning(
|
||||||
|
"用户黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND
|
||||||
|
)
|
||||||
|
ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list(
|
||||||
|
"id", flat=True
|
||||||
|
)
|
||||||
|
await BanConsole.filter(id__in=ids[:-1]).delete()
|
||||||
|
await cache.reload()
|
||||||
|
|
||||||
|
|
||||||
|
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
|
||||||
|
if not check_plugin_type(matcher):
|
||||||
|
return
|
||||||
|
if not matcher.plugin_name:
|
||||||
|
return
|
||||||
|
entity = get_entity_ids(session)
|
||||||
|
if entity.user_id in bot.config.superusers:
|
||||||
|
return
|
||||||
|
cache = Cache[list[BanConsole]](CacheType.BAN)
|
||||||
|
if entity.group_id:
|
||||||
|
await group_handle(cache, entity.group_id)
|
||||||
|
if entity.user_id:
|
||||||
|
await user_handle(matcher.plugin_name, cache, entity, session)
|
||||||
@ -1,12 +1,11 @@
|
|||||||
from nonebot.exception import IgnoredException
|
|
||||||
|
|
||||||
from zhenxun.models.bot_console import BotConsole
|
from zhenxun.models.bot_console import BotConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.cache import Cache
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.common_utils import CommonUtils
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
||||||
"""bot层面的权限检查
|
"""bot层面的权限检查
|
||||||
@ -16,17 +15,14 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
|
|||||||
bot_id: bot id
|
bot_id: bot id
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
IgnoredException: 忽略插件
|
SkipPluginException: 忽略插件
|
||||||
IgnoredException: 忽略插件
|
SkipPluginException: 忽略插件
|
||||||
"""
|
"""
|
||||||
if cache := Cache[BotConsole](CacheType.BOT):
|
if cache := Cache[BotConsole](CacheType.BOT):
|
||||||
bot = await cache.get(bot_id)
|
bot = await cache.get(bot_id)
|
||||||
if not bot or not bot.status:
|
if not bot or not bot.status:
|
||||||
logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker")
|
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
||||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
|
||||||
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||||
logger.debug(
|
raise SkipPluginException(
|
||||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
||||||
"AuthChecker",
|
|
||||||
)
|
)
|
||||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.user_console import UserConsole
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.message import MessageUtils
|
from .exception import SkipPluginException
|
||||||
|
from .utils import send_message
|
||||||
|
|
||||||
|
|
||||||
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
||||||
@ -19,17 +19,6 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
|
|||||||
"""
|
"""
|
||||||
if user.gold < plugin.cost_gold:
|
if user.gold < plugin.cost_gold:
|
||||||
"""插件消耗金币不足"""
|
"""插件消耗金币不足"""
|
||||||
try:
|
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
|
||||||
await MessageUtils.build_message(
|
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||||
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
|
||||||
).send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 金币限制.."
|
|
||||||
f"该功能需要{plugin.cost_gold}金币..",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
|
||||||
return plugin.cost_gold
|
return plugin.cost_gold
|
||||||
|
|||||||
@ -1,57 +1,35 @@
|
|||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot_plugin_alconna import UniMsg
|
from nonebot_plugin_alconna import UniMsg
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
|
||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.cache import Cache
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
|
from zhenxun.utils.utils import EntityIDs
|
||||||
|
|
||||||
|
from .config import SwitchEnum
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg):
|
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||||
"""群黑名单检测 群总开关检测
|
"""群黑名单检测 群总开关检测
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
session: EventSession
|
entity: EntityIDs
|
||||||
message: UniMsg
|
message: UniMsg
|
||||||
"""
|
"""
|
||||||
if not session.group:
|
if not entity.group_id:
|
||||||
return
|
return
|
||||||
if session.group.parent:
|
|
||||||
group_id = session.group.parent.id
|
|
||||||
else:
|
|
||||||
group_id = session.group.id
|
|
||||||
text = message.extract_plain_text()
|
text = message.extract_plain_text()
|
||||||
group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id)
|
group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id)
|
||||||
if not group:
|
if not group:
|
||||||
"""群不存在"""
|
raise SkipPluginException("群组信息不存在...")
|
||||||
logger.debug(
|
|
||||||
"群组信息不存在...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("群不存在")
|
|
||||||
if group.level < 0:
|
if group.level < 0:
|
||||||
"""群权限小于0"""
|
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
|
||||||
logger.debug(
|
if text.strip() != SwitchEnum.ENABLE and not group.status:
|
||||||
"群黑名单, 群权限-1...",
|
raise SkipPluginException("群组休眠状态...")
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("群黑名单")
|
|
||||||
if not group.status:
|
|
||||||
"""群休眠"""
|
|
||||||
if text.strip() != "醒来":
|
|
||||||
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
|
||||||
raise IgnoredException("群休眠状态")
|
|
||||||
if plugin.level > group.level:
|
if plugin.level > group.level:
|
||||||
"""插件等级大于群等级"""
|
raise SkipPluginException(
|
||||||
logger.debug(
|
f"{plugin.name}({plugin.module}) 群等级限制,"
|
||||||
f"{plugin.name}({plugin.module}) 群等级限制.."
|
f"该功能需要的群等级: {plugin.level}..."
|
||||||
f"该功能需要的群等级: {plugin.level}..",
|
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
)
|
||||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -11,6 +10,9 @@ from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
|||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.message import MessageUtils
|
||||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND
|
||||||
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
class Limit(BaseModel):
|
class Limit(BaseModel):
|
||||||
limit: PluginLimit
|
limit: PluginLimit
|
||||||
@ -69,7 +71,7 @@ class LimitManage:
|
|||||||
key_type = channel_id or group_id
|
key_type = channel_id or group_id
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"解除对象: {key_type} 的block限制",
|
f"解除对象: {key_type} 的block限制",
|
||||||
"AuthChecker",
|
LOGGER_COMMAND,
|
||||||
session=user_id,
|
session=user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
@ -139,16 +141,13 @@ class LimitManage:
|
|||||||
if is_limit and not limiter.check(key_type):
|
if is_limit and not limiter.check(key_type):
|
||||||
if limit.result:
|
if limit.result:
|
||||||
await MessageUtils.build_message(limit.result).send()
|
await MessageUtils.build_message(limit.result).send()
|
||||||
logger.debug(
|
raise SkipPluginException(
|
||||||
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
f"{limit.module}({limit.limit_type}) 正在限制中..."
|
||||||
"AuthChecker",
|
|
||||||
session=session,
|
|
||||||
)
|
)
|
||||||
raise IgnoredException(f"{limit.module} 正在限制中...")
|
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||||
"AuthChecker",
|
LOGGER_COMMAND,
|
||||||
session=user_id,
|
session=user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,17 +1,15 @@
|
|||||||
from nonebot.adapters import Event
|
from nonebot.adapters import Event
|
||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.cache import Cache
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.common_utils import CommonUtils
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import BlockType, CacheType
|
from zhenxun.utils.enum import BlockType, CacheType
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
from .exception import IsSuperuserException
|
from .exception import IsSuperuserException, SkipPluginException
|
||||||
from .utils import freq
|
from .utils import freq, is_poke, send_message
|
||||||
|
|
||||||
|
|
||||||
class GroupCheck:
|
class GroupCheck:
|
||||||
@ -42,16 +40,12 @@ class GroupCheck:
|
|||||||
group = await self.__get_data()
|
group = await self.__get_data()
|
||||||
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
|
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
|
||||||
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
|
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
|
||||||
freq._flmt_s.start_cd(self.group_id)
|
await send_message(
|
||||||
await MessageUtils.build_message("超级管理员禁用了该群此功能...").send(
|
self.session, "超级管理员禁用了该群此功能...", self.group_id
|
||||||
reply_to=True
|
|
||||||
)
|
)
|
||||||
logger.debug(
|
raise SkipPluginException(
|
||||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..."
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
)
|
)
|
||||||
raise IgnoredException("超级管理员禁用了该群此功能...")
|
|
||||||
await self.check_normal_block(self.plugin)
|
await self.check_normal_block(self.plugin)
|
||||||
|
|
||||||
async def check_normal_block(self, plugin: PluginInfo):
|
async def check_normal_block(self, plugin: PluginInfo):
|
||||||
@ -66,16 +60,8 @@ class GroupCheck:
|
|||||||
group = await self.__get_data()
|
group = await self.__get_data()
|
||||||
if group and CommonUtils.format(plugin.module) in group.block_plugin:
|
if group and CommonUtils.format(plugin.module) in group.block_plugin:
|
||||||
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
||||||
freq._flmt_s.start_cd(self.group_id)
|
await send_message(self.session, "该群未开启此功能...", self.group_id)
|
||||||
await MessageUtils.build_message("该群未开启此功能...").send(
|
raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...")
|
||||||
reply_to=True
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
)
|
|
||||||
raise IgnoredException("该群未开启此功能...")
|
|
||||||
await self.check_global_block(self.plugin)
|
await self.check_global_block(self.plugin)
|
||||||
|
|
||||||
async def check_global_block(self, plugin: PluginInfo):
|
async def check_global_block(self, plugin: PluginInfo):
|
||||||
@ -89,25 +75,13 @@ class GroupCheck:
|
|||||||
"""
|
"""
|
||||||
if plugin.block_type == BlockType.GROUP:
|
if plugin.block_type == BlockType.GROUP:
|
||||||
"""全局群组禁用"""
|
"""全局群组禁用"""
|
||||||
try:
|
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
||||||
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
await send_message(
|
||||||
freq._flmt_c.start_cd(self.group_id)
|
self.session, "该功能在群组中已被禁用...", self.group_id
|
||||||
await MessageUtils.build_message("该功能在群组中已被禁用...").send(
|
|
||||||
reply_to=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_plugin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
e=e,
|
|
||||||
)
|
)
|
||||||
logger.debug(
|
raise SkipPluginException(
|
||||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..."
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
)
|
)
|
||||||
raise IgnoredException("该插件在群组中已被禁用...")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginCheck:
|
class PluginCheck:
|
||||||
@ -126,25 +100,11 @@ class PluginCheck:
|
|||||||
IgnoredException: 忽略插件
|
IgnoredException: 忽略插件
|
||||||
"""
|
"""
|
||||||
if plugin.block_type == BlockType.PRIVATE:
|
if plugin.block_type == BlockType.PRIVATE:
|
||||||
try:
|
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
|
||||||
if freq.is_send_limit_message(
|
await send_message(self.session, "该功能在私聊中已被禁用...")
|
||||||
plugin, self.session.user.id, self.is_poke
|
raise SkipPluginException(
|
||||||
):
|
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
|
||||||
freq._flmt_c.start_cd(self.session.user.id)
|
|
||||||
await MessageUtils.build_message("该功能在私聊中已被禁用...").send()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"auth_admin 发送消息失败",
|
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
e=e,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
)
|
)
|
||||||
raise IgnoredException("该插件在私聊中已被禁用...")
|
|
||||||
|
|
||||||
async def check_global(self, plugin: PluginInfo):
|
async def check_global(self, plugin: PluginInfo):
|
||||||
"""全局状态
|
"""全局状态
|
||||||
@ -162,16 +122,10 @@ class PluginCheck:
|
|||||||
if self.group_id and (group := await cache.get(self.group_id)):
|
if self.group_id and (group := await cache.get(self.group_id)):
|
||||||
if group.is_super:
|
if group.is_super:
|
||||||
raise IsSuperuserException()
|
raise IsSuperuserException()
|
||||||
logger.debug(
|
|
||||||
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
|
||||||
"AuthChecker",
|
|
||||||
session=self.session,
|
|
||||||
)
|
|
||||||
sid = self.group_id or self.session.user.id
|
sid = self.group_id or self.session.user.id
|
||||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||||
freq._flmt_s.start_cd(sid)
|
await send_message(self.session, "全局未开启此功能...", sid)
|
||||||
await MessageUtils.build_message("全局未开启此功能...").send()
|
raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...")
|
||||||
raise IgnoredException("全局未开启此功能...")
|
|
||||||
|
|
||||||
|
|
||||||
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||||
@ -180,22 +134,13 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
|||||||
参数:
|
参数:
|
||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
session: Uninfo
|
session: Uninfo
|
||||||
|
event: Event
|
||||||
"""
|
"""
|
||||||
group_id = None
|
entity = get_entity_ids(session)
|
||||||
if session.group:
|
is_poke_event = is_poke(event)
|
||||||
if session.group.parent:
|
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
||||||
group_id = session.group.parent.id
|
if entity.group_id:
|
||||||
else:
|
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
||||||
group_id = session.group.id
|
|
||||||
try:
|
|
||||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
|
||||||
|
|
||||||
is_poke = isinstance(event, PokeNotifyEvent)
|
|
||||||
except ImportError:
|
|
||||||
is_poke = False
|
|
||||||
user_check = PluginCheck(group_id, session, is_poke)
|
|
||||||
if group_id:
|
|
||||||
group_check = GroupCheck(plugin, group_id, session, is_poke)
|
|
||||||
await group_check.check()
|
await group_check.check()
|
||||||
else:
|
else:
|
||||||
await user_check.check_user(plugin)
|
await user_check.check_user(plugin)
|
||||||
|
|||||||
13
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
13
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from enum import StrEnum
|
||||||
|
else:
|
||||||
|
from strenum import StrEnum
|
||||||
|
|
||||||
|
LOGGER_COMMAND = "AuthChecker"
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchEnum(StrEnum):
|
||||||
|
ENABLE = "醒来"
|
||||||
|
DISABLE = "休息吧"
|
||||||
@ -1,2 +1,14 @@
|
|||||||
class IsSuperuserException(Exception):
|
class IsSuperuserException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SkipPluginException(Exception):
|
||||||
|
def __init__(self, info: str, *args: object) -> None:
|
||||||
|
super().__init__(*args)
|
||||||
|
self.info = info
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.info
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return self.info
|
||||||
|
|||||||
@ -1,11 +1,61 @@
|
|||||||
|
import contextlib
|
||||||
|
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import PluginType
|
||||||
|
from zhenxun.utils.message import MessageUtils
|
||||||
from zhenxun.utils.utils import FreqLimiter
|
from zhenxun.utils.utils import FreqLimiter
|
||||||
|
|
||||||
|
from .config import LOGGER_COMMAND
|
||||||
|
|
||||||
base_config = Config.get("hook")
|
base_config = Config.get("hook")
|
||||||
|
|
||||||
|
|
||||||
|
def is_poke(event: Event) -> bool:
|
||||||
|
"""判断是否为poke类型
|
||||||
|
|
||||||
|
参数:
|
||||||
|
event: Event
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否为poke类型
|
||||||
|
"""
|
||||||
|
with contextlib.suppress(ImportError):
|
||||||
|
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||||
|
|
||||||
|
return isinstance(event, PokeNotifyEvent)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
session: Uninfo, message: list | str, check_tag: str | None = None
|
||||||
|
):
|
||||||
|
"""发送消息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
session: Uninfo
|
||||||
|
message: 消息
|
||||||
|
check_tag: cd flag
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not check_tag:
|
||||||
|
await MessageUtils.build_message(message).send(reply_to=True)
|
||||||
|
elif freq._flmt.check(check_tag):
|
||||||
|
freq._flmt.start_cd(check_tag)
|
||||||
|
await MessageUtils.build_message(message).send(reply_to=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"发送消息失败",
|
||||||
|
LOGGER_COMMAND,
|
||||||
|
session=session,
|
||||||
|
e=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FreqUtils:
|
class FreqUtils:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import contextlib
|
import asyncio
|
||||||
|
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.exception import IgnoredException
|
from nonebot.exception import IgnoredException
|
||||||
@ -18,14 +18,107 @@ from zhenxun.utils.enum import (
|
|||||||
)
|
)
|
||||||
from zhenxun.utils.exception import InsufficientGold
|
from zhenxun.utils.exception import InsufficientGold
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
from .auth.auth_admin import auth_admin
|
from .auth.auth_admin import auth_admin
|
||||||
|
from .auth.auth_ban import auth_ban
|
||||||
from .auth.auth_bot import auth_bot
|
from .auth.auth_bot import auth_bot
|
||||||
from .auth.auth_cost import auth_cost
|
from .auth.auth_cost import auth_cost
|
||||||
from .auth.auth_group import auth_group
|
from .auth.auth_group import auth_group
|
||||||
from .auth.auth_limit import LimitManage, auth_limit
|
from .auth.auth_limit import LimitManage, auth_limit
|
||||||
from .auth.auth_plugin import auth_plugin
|
from .auth.auth_plugin import auth_plugin
|
||||||
from .auth.exception import IsSuperuserException
|
from .auth.config import LOGGER_COMMAND
|
||||||
|
from .auth.exception import IsSuperuserException, SkipPluginException
|
||||||
|
|
||||||
|
|
||||||
|
async def get_plugin_and_user(
|
||||||
|
module: str, user_id: str
|
||||||
|
) -> tuple[PluginInfo, UserConsole]:
|
||||||
|
"""获取用户数据和插件信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
module: 模块名
|
||||||
|
user_id: 用户id
|
||||||
|
|
||||||
|
异常:
|
||||||
|
SkipPluginException: 插件数据不存在
|
||||||
|
SkipPluginException: 插件类型为HIDDEN
|
||||||
|
SkipPluginException: 重复创建用户
|
||||||
|
SkipPluginException: 用户数据不存在
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
||||||
|
"""
|
||||||
|
user_cache = Cache[UserConsole](CacheType.USERS)
|
||||||
|
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
||||||
|
if not plugin:
|
||||||
|
raise SkipPluginException(f"插件:{module} 数据不存在,已跳过权限检查...")
|
||||||
|
if plugin.plugin_type == PluginType.HIDDEN:
|
||||||
|
raise SkipPluginException(
|
||||||
|
f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..."
|
||||||
|
)
|
||||||
|
user = None
|
||||||
|
try:
|
||||||
|
user = await user_cache.get(user_id)
|
||||||
|
except IntegrityError as e:
|
||||||
|
raise SkipPluginException("重复创建用户,已跳过该次权限检查...") from e
|
||||||
|
if not user:
|
||||||
|
raise SkipPluginException("用户数据不存在,已跳过权限检查...")
|
||||||
|
return plugin, user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_plugin_cost(
|
||||||
|
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
|
||||||
|
) -> int:
|
||||||
|
"""获取插件费用
|
||||||
|
|
||||||
|
参数:
|
||||||
|
bot: Bot
|
||||||
|
user: 用户数据
|
||||||
|
plugin: 插件数据
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
异常:
|
||||||
|
IsSuperuserException: 超级用户
|
||||||
|
IsSuperuserException: 超级用户
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 调用插件金币费用
|
||||||
|
"""
|
||||||
|
cost_gold = await auth_cost(user, plugin, session)
|
||||||
|
if session.user.id in bot.config.superusers:
|
||||||
|
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
if not plugin.limit_superuser:
|
||||||
|
raise IsSuperuserException()
|
||||||
|
return cost_gold
|
||||||
|
|
||||||
|
|
||||||
|
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
|
||||||
|
"""扣除用户金币
|
||||||
|
|
||||||
|
参数:
|
||||||
|
user_id: 用户id
|
||||||
|
module: 插件模块名称
|
||||||
|
cost_gold: 消耗金币
|
||||||
|
session: Uninfo
|
||||||
|
"""
|
||||||
|
user_cache = Cache[UserConsole](CacheType.USERS)
|
||||||
|
try:
|
||||||
|
await UserConsole.reduce_gold(
|
||||||
|
user_id,
|
||||||
|
cost_gold,
|
||||||
|
GoldHandle.PLUGIN,
|
||||||
|
module,
|
||||||
|
PlatformUtils.get_platform(session),
|
||||||
|
)
|
||||||
|
except InsufficientGold:
|
||||||
|
if u := await UserConsole.get_user(user_id):
|
||||||
|
u.gold = 0
|
||||||
|
await u.save(update_fields=["gold"])
|
||||||
|
# 更新缓存
|
||||||
|
await user_cache.update(user_id)
|
||||||
|
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
||||||
|
|
||||||
|
|
||||||
async def auth(
|
async def auth(
|
||||||
@ -44,85 +137,32 @@ async def auth(
|
|||||||
session: Uninfo
|
session: Uninfo
|
||||||
message: UniMsg
|
message: UniMsg
|
||||||
"""
|
"""
|
||||||
user_id = session.user.id
|
|
||||||
group_id = None
|
|
||||||
channel_id = None
|
|
||||||
if session.group:
|
|
||||||
if session.group.parent:
|
|
||||||
group_id = session.group.parent.id
|
|
||||||
channel_id = session.group.id
|
|
||||||
else:
|
|
||||||
group_id = session.group.id
|
|
||||||
is_ignore = False
|
|
||||||
cost_gold = 0
|
cost_gold = 0
|
||||||
with contextlib.suppress(ImportError):
|
ignore_flag = False
|
||||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
entity = get_entity_ids(session)
|
||||||
|
module = matcher.plugin_name or ""
|
||||||
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
try:
|
||||||
"""过滤除poke外的notice"""
|
if not module:
|
||||||
return
|
raise SkipPluginException("Matcher插件名称不存在...")
|
||||||
user_cache = Cache[UserConsole](CacheType.USERS)
|
plugin, user = await get_plugin_and_user(module, entity.user_id)
|
||||||
if matcher.plugin and (module := matcher.plugin.name):
|
cost_gold = await get_plugin_cost(bot, user, plugin, session)
|
||||||
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
await asyncio.gather(
|
||||||
if not plugin:
|
*[
|
||||||
return logger.debug(f"插件:{module} 数据不存在,已跳过权限检查...")
|
auth_ban(matcher, bot, session),
|
||||||
if plugin.plugin_type == PluginType.HIDDEN:
|
auth_bot(plugin, bot.self_id),
|
||||||
return logger.debug(
|
auth_group(plugin, entity, message),
|
||||||
f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..."
|
auth_admin(plugin, session),
|
||||||
)
|
auth_plugin(plugin, session, event),
|
||||||
user = None
|
]
|
||||||
try:
|
)
|
||||||
user = await user_cache.get(session.user.id)
|
await auth_limit(plugin, session)
|
||||||
except IntegrityError as e:
|
except SkipPluginException as e:
|
||||||
logger.debug(
|
LimitManage.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
|
||||||
"重复创建用户,已跳过该次权限检查...",
|
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||||
"AuthChecker",
|
ignore_flag = True
|
||||||
session=session,
|
except IsSuperuserException:
|
||||||
e=e,
|
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||||
)
|
if not ignore_flag and cost_gold > 0:
|
||||||
if not user:
|
await reduce_gold(entity.user_id, module, cost_gold, session)
|
||||||
return logger.debug(
|
if ignore_flag:
|
||||||
"用户数据不存在,已跳过权限检查...", "AuthChecker", session=session
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
cost_gold = await auth_cost(user, plugin, session)
|
|
||||||
if session.user.id in bot.config.superusers:
|
|
||||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
|
||||||
raise IsSuperuserException()
|
|
||||||
if not plugin.limit_superuser:
|
|
||||||
cost_gold = 0
|
|
||||||
raise IsSuperuserException()
|
|
||||||
await auth_bot(plugin, bot.self_id)
|
|
||||||
await auth_group(plugin, session, message)
|
|
||||||
await auth_admin(plugin, session)
|
|
||||||
await auth_plugin(plugin, session, event)
|
|
||||||
await auth_limit(plugin, session)
|
|
||||||
except IsSuperuserException:
|
|
||||||
logger.debug(
|
|
||||||
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
|
||||||
)
|
|
||||||
except IgnoredException:
|
|
||||||
is_ignore = True
|
|
||||||
LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id)
|
|
||||||
except AssertionError as e:
|
|
||||||
is_ignore = True
|
|
||||||
logger.debug("消息无法发送", session=session, e=e)
|
|
||||||
if cost_gold and user_id:
|
|
||||||
"""花费金币"""
|
|
||||||
try:
|
|
||||||
await UserConsole.reduce_gold(
|
|
||||||
user_id,
|
|
||||||
cost_gold,
|
|
||||||
GoldHandle.PLUGIN,
|
|
||||||
matcher.plugin.name if matcher.plugin else "",
|
|
||||||
PlatformUtils.get_platform(session),
|
|
||||||
)
|
|
||||||
except InsufficientGold:
|
|
||||||
if u := await UserConsole.get_user(user_id):
|
|
||||||
u.gold = 0
|
|
||||||
await u.save(update_fields=["gold"])
|
|
||||||
# 更新缓存
|
|
||||||
await user_cache.update(user_id)
|
|
||||||
logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session)
|
|
||||||
if is_ignore:
|
|
||||||
raise IgnoredException("权限检测 ignore")
|
raise IgnoredException("权限检测 ignore")
|
||||||
|
|||||||
@ -1,15 +1,21 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.message import run_postprocessor, run_preprocessor
|
from nonebot.message import run_postprocessor, run_preprocessor
|
||||||
from nonebot_plugin_alconna import UniMsg
|
from nonebot_plugin_alconna import UniMsg
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .auth.config import LOGGER_COMMAND
|
||||||
from .auth_checker import LimitManage, auth
|
from .auth_checker import LimitManage, auth
|
||||||
|
|
||||||
|
|
||||||
# # 权限检测
|
# # 权限检测
|
||||||
@run_preprocessor
|
@run_preprocessor
|
||||||
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
||||||
|
start_time = time.time()
|
||||||
await auth(
|
await auth(
|
||||||
matcher,
|
matcher,
|
||||||
event,
|
event,
|
||||||
@ -17,6 +23,7 @@ async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message:
|
|||||||
session,
|
session,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
logger.info(f"权限检测耗时:{time.time() - start_time}秒", LOGGER_COMMAND)
|
||||||
|
|
||||||
|
|
||||||
# 解除命令block阻塞
|
# 解除命令block阻塞
|
||||||
|
|||||||
@ -1,104 +0,0 @@
|
|||||||
from nonebot.adapters import Bot
|
|
||||||
from nonebot.exception import IgnoredException
|
|
||||||
from nonebot.matcher import Matcher
|
|
||||||
from nonebot.message import run_preprocessor
|
|
||||||
from nonebot_plugin_alconna import At
|
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
|
||||||
from tortoise.exceptions import MultipleObjectsReturned
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
|
||||||
from zhenxun.models.ban_console import BanConsole
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
|
||||||
from zhenxun.services.cache import Cache
|
|
||||||
from zhenxun.services.log import logger
|
|
||||||
from zhenxun.utils.enum import CacheType, PluginType
|
|
||||||
from zhenxun.utils.message import MessageUtils
|
|
||||||
from zhenxun.utils.utils import FreqLimiter
|
|
||||||
|
|
||||||
Config.add_plugin_config(
|
|
||||||
"hook",
|
|
||||||
"BAN_RESULT",
|
|
||||||
"才不会给你发消息.",
|
|
||||||
help="对被ban用户发送的消息",
|
|
||||||
)
|
|
||||||
|
|
||||||
_flmt = FreqLimiter(300)
|
|
||||||
|
|
||||||
|
|
||||||
async def is_ban(user_id: str | None, group_id: str | None):
|
|
||||||
cache = Cache[list[BanConsole]](CacheType.BAN)
|
|
||||||
result = await cache.get(user_id, group_id) or await cache.get(user_id)
|
|
||||||
return result and result[0].ban_time > 0
|
|
||||||
|
|
||||||
|
|
||||||
# 检查是否被ban
|
|
||||||
@run_preprocessor
|
|
||||||
async def _(matcher: Matcher, bot: Bot, session: Uninfo):
|
|
||||||
if plugin := matcher.plugin:
|
|
||||||
if metadata := plugin.metadata:
|
|
||||||
extra = metadata.extra
|
|
||||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
|
||||||
return
|
|
||||||
user_id = session.user.id
|
|
||||||
group_id = session.group.id if session.group else None
|
|
||||||
cache = Cache[list[BanConsole]](CacheType.BAN)
|
|
||||||
if user_id in bot.config.superusers:
|
|
||||||
return
|
|
||||||
if group_id:
|
|
||||||
try:
|
|
||||||
if await is_ban(None, group_id):
|
|
||||||
logger.debug("群组处于黑名单中...", "BanChecker")
|
|
||||||
raise IgnoredException("群组处于黑名单中...")
|
|
||||||
except MultipleObjectsReturned:
|
|
||||||
logger.warning(
|
|
||||||
"群组黑名单数据重复,过滤该次hook并移除多余数据...", "BanChecker"
|
|
||||||
)
|
|
||||||
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
|
|
||||||
"id", flat=True
|
|
||||||
)
|
|
||||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
|
||||||
await cache.reload()
|
|
||||||
group_cache = Cache[GroupConsole](CacheType.GROUPS)
|
|
||||||
if g := await group_cache.get(group_id):
|
|
||||||
if g.level < 0:
|
|
||||||
logger.debug("群黑名单, 群权限-1...", "BanChecker")
|
|
||||||
raise IgnoredException("群黑名单, 群权限-1..")
|
|
||||||
if user_id:
|
|
||||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
|
||||||
if await is_ban(user_id, group_id):
|
|
||||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
|
||||||
if time == -1:
|
|
||||||
time_str = "∞"
|
|
||||||
else:
|
|
||||||
time = abs(int(time))
|
|
||||||
if time < 60:
|
|
||||||
time_str = f"{time!s} 秒"
|
|
||||||
else:
|
|
||||||
minute = int(time / 60)
|
|
||||||
if minute > 60:
|
|
||||||
hours = minute // 60
|
|
||||||
minute %= 60
|
|
||||||
time_str = f"{hours} 小时 {minute}分钟"
|
|
||||||
else:
|
|
||||||
time_str = f"{minute} 分钟"
|
|
||||||
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(
|
|
||||||
matcher.plugin_name
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
db_plugin
|
|
||||||
# and not db_plugin.ignore_prompt
|
|
||||||
and time != -1
|
|
||||||
and ban_result
|
|
||||||
and _flmt.check(user_id)
|
|
||||||
):
|
|
||||||
_flmt.start_cd(user_id)
|
|
||||||
logger.debug(f"ban检测发送插件: {matcher.plugin_name}")
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
[
|
|
||||||
At(flag="user", target=user_id),
|
|
||||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
|
||||||
]
|
|
||||||
).send()
|
|
||||||
logger.debug("用户处于黑名单中...", "BanChecker")
|
|
||||||
raise IgnoredException("用户处于黑名单中...")
|
|
||||||
@ -42,6 +42,8 @@ def default_with_expiration(
|
|||||||
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
data: dict[str, Any], expire_data: dict[str, int], expire: int
|
||||||
):
|
):
|
||||||
"""默认更新期时间cache方法"""
|
"""默认更新期时间cache方法"""
|
||||||
|
if not data:
|
||||||
|
return {}
|
||||||
keys = {k for k in data if k not in expire_data}
|
keys = {k for k in data if k not in expire_data}
|
||||||
return {k: time.time() + expire for k in keys} if keys else {}
|
return {k: time.time() + expire for k in keys} if keys else {}
|
||||||
|
|
||||||
@ -52,12 +54,6 @@ async def _():
|
|||||||
return {p.module: p for p in data_list}
|
return {p.module: p for p in data_list}
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.PLUGINS)
|
|
||||||
async def _():
|
|
||||||
data_list = await PluginInfo.get_plugins()
|
|
||||||
return {p.module: p for p in data_list}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.updater(CacheType.PLUGINS)
|
@CacheRoot.updater(CacheType.PLUGINS)
|
||||||
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
async def _(data: dict[str, PluginInfo], key: str, value: Any):
|
||||||
if value:
|
if value:
|
||||||
@ -109,21 +105,20 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any):
|
|||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||||
async def _(data: dict[str, GroupConsole] | None, group_id: str):
|
async def _(cache_data: CacheData, group_id: str):
|
||||||
if not data:
|
cache_data.data = cache_data.data or {}
|
||||||
data = {}
|
result = cache_data.data.get(group_id, None)
|
||||||
result = data.get(group_id, None)
|
|
||||||
if not result:
|
if not result:
|
||||||
result = await GroupConsole.get_group(group_id=group_id)
|
result = await GroupConsole.get_group(group_id=group_id)
|
||||||
if result:
|
if result:
|
||||||
data[group_id] = result
|
cache_data.data[group_id] = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.GROUPS)
|
@CacheRoot.with_refresh(CacheType.GROUPS)
|
||||||
async def _(data: dict[str, GroupConsole]):
|
async def _(data: dict[str, GroupConsole]):
|
||||||
groups = await GroupConsole.filter(
|
groups = await GroupConsole.filter(
|
||||||
group_id__in=data.keys(), channel_id__isnull=True, load_status=True
|
group_id__in=data.keys(), channel_id__isnull=True
|
||||||
)
|
)
|
||||||
for group in groups:
|
for group in groups:
|
||||||
data[group.group_id] = group
|
data[group.group_id] = group
|
||||||
@ -154,14 +149,13 @@ async def _(data: dict[str, BotConsole], key: str, value: Any):
|
|||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||||
async def _(data: dict[str, BotConsole] | None, bot_id: str):
|
async def _(cache_data: CacheData, bot_id: str):
|
||||||
if not data:
|
cache_data.data = cache_data.data or {}
|
||||||
data = {}
|
result = cache_data.data.get(bot_id, None)
|
||||||
result = data.get(bot_id, None)
|
|
||||||
if not result:
|
if not result:
|
||||||
result = await BotConsole.get_or_none(bot_id=bot_id)
|
result = await BotConsole.get_or_none(bot_id=bot_id)
|
||||||
if result:
|
if result:
|
||||||
data[bot_id] = result
|
cache_data.data[bot_id] = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -224,7 +218,7 @@ def _(cache_data: CacheData):
|
|||||||
return default_cleanup_expired(cache_data)
|
return default_cleanup_expired(cache_data)
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.LEVEL)
|
@CacheRoot.new(CacheType.LEVEL, False)
|
||||||
async def _():
|
async def _():
|
||||||
return await LevelUser().all()
|
return await LevelUser().all()
|
||||||
|
|
||||||
@ -246,13 +240,13 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BAN)
|
@CacheRoot.new(CacheType.BAN, False, 5)
|
||||||
async def _():
|
async def _():
|
||||||
return await BanConsole.all()
|
return await BanConsole.all()
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||||
def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
||||||
if user_id:
|
if user_id:
|
||||||
return (
|
return (
|
||||||
[
|
[
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, ClassVar, overload
|
from typing import Any, ClassVar, cast, overload
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
@ -11,6 +11,42 @@ from zhenxun.services.db_context import Model
|
|||||||
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||||
|
|
||||||
|
|
||||||
|
def add_disable_marker(name: str) -> str:
|
||||||
|
"""添加模块禁用标记符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 模块名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
添加了禁用标记的模块名 (前缀'<'和后缀',')
|
||||||
|
"""
|
||||||
|
return f"<{name},"
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def convert_module_format(data: str) -> list[str]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def convert_module_format(data: list[str]) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
|
def convert_module_format(data: str | list[str]) -> str | list[str]:
|
||||||
|
"""
|
||||||
|
在 `<aaa,<bbb,<ccc,` 和 `["aaa", "bbb", "ccc"]` (即禁用启用)之间进行相互转换。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data: 要转换的数据
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | list[str]: 根据输入类型返回转换后的数据。
|
||||||
|
"""
|
||||||
|
if isinstance(data, str):
|
||||||
|
return [item.strip(",") for item in data.split("<") if item]
|
||||||
|
else:
|
||||||
|
return "".join(format(item) for item in data)
|
||||||
|
|
||||||
|
|
||||||
class GroupConsole(Model):
|
class GroupConsole(Model):
|
||||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||||
"""自增id"""
|
"""自增id"""
|
||||||
@ -47,7 +83,7 @@ class GroupConsole(Model):
|
|||||||
platform = fields.CharField(255, default="qq", description="所属平台")
|
platform = fields.CharField(255, default="qq", description="所属平台")
|
||||||
"""所属平台"""
|
"""所属平台"""
|
||||||
|
|
||||||
class Meta: # type: ignore
|
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||||
table = "group_console"
|
table = "group_console"
|
||||||
table_description = "群组信息表"
|
table_description = "群组信息表"
|
||||||
unique_together = ("group_id", "channel_id")
|
unique_together = ("group_id", "channel_id")
|
||||||
@ -57,33 +93,34 @@ class GroupConsole(Model):
|
|||||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||||
"""开启锁"""
|
"""开启锁"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def format(name: str) -> str:
|
|
||||||
return f"<{name},"
|
|
||||||
|
|
||||||
@overload
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_module_format(cls, data: str) -> list[str]: ...
|
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
|
||||||
|
"""获取默认禁用的任务模块
|
||||||
@overload
|
|
||||||
@classmethod
|
|
||||||
def convert_module_format(cls, data: list[str]) -> str: ...
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_module_format(cls, data: str | list[str]) -> str | list[str]:
|
|
||||||
"""
|
|
||||||
在 `<aaa,<bbb,<ccc,` 和 `["aaa", "bbb", "ccc"]` 之间进行相互转换。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
data (str | list[str]): 输入数据,可能是格式化字符串或字符串列表。
|
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
str | list[str]: 根据输入类型返回转换后的数据。
|
list[str]: 任务模块列表
|
||||||
"""
|
"""
|
||||||
if isinstance(data, str):
|
return cast(
|
||||||
return [item.strip(",") for item in data.split("<") if item]
|
list[str],
|
||||||
elif isinstance(data, list):
|
await TaskInfo.filter(default_status=default_status).values_list(
|
||||||
return "".join(cls.format(item) for item in data)
|
"module", flat=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _get_plugin_modules(cls, *, default_status: bool) -> list[str]:
|
||||||
|
"""获取默认禁用的插件模块
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[str]: 插件模块列表
|
||||||
|
"""
|
||||||
|
return cast(
|
||||||
|
list[str],
|
||||||
|
await PluginInfo.filter(
|
||||||
|
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
|
||||||
|
default_status=default_status,
|
||||||
|
).values_list("module", flat=True),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@CacheRoot.listener(CacheType.GROUPS)
|
@CacheRoot.listener(CacheType.GROUPS)
|
||||||
@ -92,20 +129,44 @@ class GroupConsole(Model):
|
|||||||
) -> Self:
|
) -> Self:
|
||||||
"""覆盖create方法"""
|
"""覆盖create方法"""
|
||||||
group = await super().create(using_db=using_db, **kwargs)
|
group = await super().create(using_db=using_db, **kwargs)
|
||||||
if modules := await TaskInfo.filter(default_status=False).values_list(
|
|
||||||
"module", flat=True
|
task_modules = await cls._get_task_modules(default_status=False)
|
||||||
):
|
plugin_modules = await cls._get_plugin_modules(default_status=False)
|
||||||
group.block_task = cls.convert_module_format(modules) # type: ignore
|
|
||||||
if modules := await PluginInfo.filter(
|
if task_modules or plugin_modules:
|
||||||
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
default_status=False,
|
|
||||||
).values_list("module", flat=True):
|
|
||||||
group.block_plugin = cls.convert_module_format(modules) # type: ignore
|
|
||||||
await group.save(
|
|
||||||
using_db=using_db, update_fields=["block_plugin", "block_task"]
|
|
||||||
)
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _update_modules(
|
||||||
|
cls,
|
||||||
|
group: Self,
|
||||||
|
task_modules: list[str],
|
||||||
|
plugin_modules: list[str],
|
||||||
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""更新模块设置
|
||||||
|
|
||||||
|
参数:
|
||||||
|
group: 群组实例
|
||||||
|
task_modules: 任务模块列表
|
||||||
|
plugin_modules: 插件模块列表
|
||||||
|
using_db: 数据库连接
|
||||||
|
"""
|
||||||
|
update_fields = []
|
||||||
|
|
||||||
|
if task_modules:
|
||||||
|
group.block_task = convert_module_format(task_modules)
|
||||||
|
update_fields.append("block_task")
|
||||||
|
|
||||||
|
if plugin_modules:
|
||||||
|
group.block_plugin = convert_module_format(plugin_modules)
|
||||||
|
update_fields.append("block_plugin")
|
||||||
|
|
||||||
|
if update_fields:
|
||||||
|
await group.save(using_db=using_db, update_fields=update_fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_or_create(
|
async def get_or_create(
|
||||||
cls,
|
cls,
|
||||||
@ -117,23 +178,19 @@ class GroupConsole(Model):
|
|||||||
group, is_create = await super().get_or_create(
|
group, is_create = await super().get_or_create(
|
||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if is_create and (
|
if not is_create:
|
||||||
modules := await TaskInfo.filter(default_status=False).values_list(
|
return group, is_create
|
||||||
"module", flat=True
|
|
||||||
)
|
task_modules = await cls._get_task_modules(default_status=False)
|
||||||
):
|
plugin_modules = await cls._get_plugin_modules(default_status=False)
|
||||||
group.block_task = cls.convert_module_format(modules) # type: ignore
|
|
||||||
if modules := await PluginInfo.filter(
|
if task_modules or plugin_modules:
|
||||||
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
default_status=False,
|
|
||||||
).values_list("module", flat=True):
|
|
||||||
group.block_plugin = cls.convert_module_format(modules) # type: ignore
|
|
||||||
await group.save(
|
|
||||||
using_db=using_db, update_fields=["block_plugin", "block_task"]
|
|
||||||
)
|
|
||||||
if is_create:
|
if is_create:
|
||||||
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
|
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
|
||||||
await cache.update(group.group_id, group)
|
await cache.update(group.group_id, group)
|
||||||
|
|
||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -148,20 +205,15 @@ class GroupConsole(Model):
|
|||||||
group, is_create = await super().update_or_create(
|
group, is_create = await super().update_or_create(
|
||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if is_create and (
|
if not is_create:
|
||||||
modules := await TaskInfo.filter(default_status=False).values_list(
|
return group, is_create
|
||||||
"module", flat=True
|
|
||||||
)
|
task_modules = await cls._get_task_modules(default_status=False)
|
||||||
):
|
plugin_modules = await cls._get_plugin_modules(default_status=False)
|
||||||
group.block_task = cls.convert_module_format(modules) # type: ignore
|
|
||||||
if modules := await PluginInfo.filter(
|
if task_modules or plugin_modules:
|
||||||
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
default_status=False,
|
|
||||||
).values_list("module", flat=True):
|
|
||||||
group.block_plugin = cls.convert_module_format(modules) # type: ignore
|
|
||||||
await group.save(
|
|
||||||
using_db=using_db, update_fields=["block_plugin", "block_task"]
|
|
||||||
)
|
|
||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -206,7 +258,7 @@ class GroupConsole(Model):
|
|||||||
"""
|
"""
|
||||||
return await cls.exists(
|
return await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
superuser_block_plugin__contains=cls.format(module),
|
superuser_block_plugin__contains=add_disable_marker(module),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -220,10 +272,11 @@ class GroupConsole(Model):
|
|||||||
返回:
|
返回:
|
||||||
bool: 是否禁用插件
|
bool: 是否禁用插件
|
||||||
"""
|
"""
|
||||||
|
module = add_disable_marker(module)
|
||||||
return await cls.exists(
|
return await cls.exists(
|
||||||
group_id=group_id, block_plugin__contains=cls.format(module)
|
group_id=group_id, block_plugin__contains=module
|
||||||
) or await cls.exists(
|
) or await cls.exists(
|
||||||
group_id=group_id, superuser_block_plugin__contains=cls.format(module)
|
group_id=group_id, superuser_block_plugin__contains=module
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -245,12 +298,22 @@ class GroupConsole(Model):
|
|||||||
group, _ = await cls.get_or_create(
|
group, _ = await cls.get_or_create(
|
||||||
group_id=group_id, defaults={"platform": platform}
|
group_id=group_id, defaults={"platform": platform}
|
||||||
)
|
)
|
||||||
|
update_fields = []
|
||||||
if is_superuser:
|
if is_superuser:
|
||||||
if cls.format(module) not in group.superuser_block_plugin:
|
superuser_block_plugin = convert_module_format(group.superuser_block_plugin)
|
||||||
group.superuser_block_plugin += cls.format(module)
|
if module not in superuser_block_plugin:
|
||||||
elif cls.format(module) not in group.block_plugin:
|
superuser_block_plugin.append(module)
|
||||||
group.block_plugin += cls.format(module)
|
group.superuser_block_plugin = convert_module_format(
|
||||||
await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
|
superuser_block_plugin
|
||||||
|
)
|
||||||
|
update_fields.append("superuser_block_plugin")
|
||||||
|
elif add_disable_marker(module) not in group.block_plugin:
|
||||||
|
block_plugin = convert_module_format(group.block_plugin)
|
||||||
|
block_plugin.append(module)
|
||||||
|
group.block_plugin = convert_module_format(block_plugin)
|
||||||
|
update_fields.append("block_plugin")
|
||||||
|
if update_fields:
|
||||||
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_unblock_plugin(
|
async def set_unblock_plugin(
|
||||||
@ -271,14 +334,22 @@ class GroupConsole(Model):
|
|||||||
group, _ = await cls.get_or_create(
|
group, _ = await cls.get_or_create(
|
||||||
group_id=group_id, defaults={"platform": platform}
|
group_id=group_id, defaults={"platform": platform}
|
||||||
)
|
)
|
||||||
|
update_fields = []
|
||||||
if is_superuser:
|
if is_superuser:
|
||||||
if cls.format(module) in group.superuser_block_plugin:
|
superuser_block_plugin = convert_module_format(group.superuser_block_plugin)
|
||||||
group.superuser_block_plugin = group.superuser_block_plugin.replace(
|
if module in superuser_block_plugin:
|
||||||
cls.format(module), ""
|
superuser_block_plugin.remove(module)
|
||||||
|
group.superuser_block_plugin = convert_module_format(
|
||||||
|
superuser_block_plugin
|
||||||
)
|
)
|
||||||
elif cls.format(module) in group.block_plugin:
|
update_fields.append("superuser_block_plugin")
|
||||||
group.block_plugin = group.block_plugin.replace(cls.format(module), "")
|
elif add_disable_marker(module) in group.block_plugin:
|
||||||
await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
|
block_plugin = convert_module_format(group.block_plugin)
|
||||||
|
block_plugin.remove(module)
|
||||||
|
group.block_plugin = convert_module_format(block_plugin)
|
||||||
|
update_fields.append("block_plugin")
|
||||||
|
if update_fields:
|
||||||
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def is_normal_block_plugin(
|
async def is_normal_block_plugin(
|
||||||
@ -297,7 +368,7 @@ class GroupConsole(Model):
|
|||||||
return await cls.exists(
|
return await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
block_plugin__contains=cls.format(module),
|
block_plugin__contains=f"<{module},",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -313,7 +384,7 @@ class GroupConsole(Model):
|
|||||||
"""
|
"""
|
||||||
return await cls.exists(
|
return await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
superuser_block_task__contains=cls.format(task),
|
superuser_block_task__contains=add_disable_marker(task),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -330,24 +401,23 @@ class GroupConsole(Model):
|
|||||||
返回:
|
返回:
|
||||||
bool: 是否禁用被动
|
bool: 是否禁用被动
|
||||||
"""
|
"""
|
||||||
|
task = add_disable_marker(task)
|
||||||
if not channel_id:
|
if not channel_id:
|
||||||
return await cls.exists(
|
return await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
channel_id__isnull=True,
|
channel_id__isnull=True,
|
||||||
block_task__contains=cls.format(task),
|
block_task__contains=task,
|
||||||
) or await cls.exists(
|
) or await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
channel_id__isnull=True,
|
channel_id__isnull=True,
|
||||||
superuser_block_task__contains=cls.format(task),
|
superuser_block_task__contains=task,
|
||||||
)
|
)
|
||||||
return await cls.exists(
|
return await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id, channel_id=channel_id, block_task__contains=task
|
||||||
channel_id=channel_id,
|
|
||||||
block_task__contains=cls.format(task),
|
|
||||||
) or await cls.exists(
|
) or await cls.exists(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
channel_id__isnull=True,
|
channel_id__isnull=True,
|
||||||
superuser_block_task__contains=cls.format(task),
|
superuser_block_task__contains=task,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -369,12 +439,20 @@ class GroupConsole(Model):
|
|||||||
group, _ = await cls.get_or_create(
|
group, _ = await cls.get_or_create(
|
||||||
group_id=group_id, defaults={"platform": platform}
|
group_id=group_id, defaults={"platform": platform}
|
||||||
)
|
)
|
||||||
|
update_fields = []
|
||||||
if is_superuser:
|
if is_superuser:
|
||||||
if cls.format(task) not in group.superuser_block_task:
|
superuser_block_task = convert_module_format(group.superuser_block_task)
|
||||||
group.superuser_block_task += cls.format(task)
|
if task not in group.superuser_block_task:
|
||||||
elif cls.format(task) not in group.block_task:
|
superuser_block_task.append(task)
|
||||||
group.block_task += cls.format(task)
|
group.superuser_block_task = convert_module_format(superuser_block_task)
|
||||||
await group.save(update_fields=["block_task", "superuser_block_task"])
|
update_fields.append("superuser_block_task")
|
||||||
|
elif add_disable_marker(task) not in group.block_task:
|
||||||
|
block_task = convert_module_format(group.block_task)
|
||||||
|
block_task.append(task)
|
||||||
|
group.block_task = convert_module_format(block_task)
|
||||||
|
update_fields.append("block_task")
|
||||||
|
if update_fields:
|
||||||
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_unblock_task(
|
async def set_unblock_task(
|
||||||
@ -395,14 +473,20 @@ class GroupConsole(Model):
|
|||||||
group, _ = await cls.get_or_create(
|
group, _ = await cls.get_or_create(
|
||||||
group_id=group_id, defaults={"platform": platform}
|
group_id=group_id, defaults={"platform": platform}
|
||||||
)
|
)
|
||||||
|
update_fields = []
|
||||||
if is_superuser:
|
if is_superuser:
|
||||||
if cls.format(task) in group.superuser_block_task:
|
superuser_block_task = convert_module_format(group.superuser_block_task)
|
||||||
group.superuser_block_task = group.superuser_block_task.replace(
|
if task in superuser_block_task:
|
||||||
cls.format(task), ""
|
superuser_block_task.remove(task)
|
||||||
)
|
group.superuser_block_task = convert_module_format(superuser_block_task)
|
||||||
elif cls.format(task) in group.block_task:
|
update_fields.append("superuser_block_task")
|
||||||
group.block_task = group.block_task.replace(cls.format(task), "")
|
elif add_disable_marker(task) in group.block_task:
|
||||||
await group.save(update_fields=["block_task", "superuser_block_task"])
|
block_task = convert_module_format(group.block_task)
|
||||||
|
block_task.remove(task)
|
||||||
|
group.block_task = convert_module_format(block_task)
|
||||||
|
update_fields.append("block_task")
|
||||||
|
if update_fields:
|
||||||
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _run_script(cls):
|
def _run_script(cls):
|
||||||
|
|||||||
@ -46,13 +46,12 @@ class CacheGetter(BaseModel, Generic[T]):
|
|||||||
|
|
||||||
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
|
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
|
||||||
"""获取缓存"""
|
"""获取缓存"""
|
||||||
processed_data = (
|
if not self.get_func:
|
||||||
await self.get_func(cache_data, *args, **kwargs)
|
return cache_data.data
|
||||||
if self.get_func and is_coroutine_callable(self.get_func)
|
if is_coroutine_callable(self.get_func):
|
||||||
else self.get_func(cache_data, *args, **kwargs)
|
processed_data = await self.get_func(cache_data, *args, **kwargs)
|
||||||
if self.get_func
|
else:
|
||||||
else cache_data.data
|
processed_data = self.get_func(cache_data, *args, **kwargs)
|
||||||
)
|
|
||||||
return cast(T, processed_data)
|
return cast(T, processed_data)
|
||||||
|
|
||||||
|
|
||||||
@ -81,9 +80,14 @@ class CacheData(BaseModel):
|
|||||||
"""更新时间"""
|
"""更新时间"""
|
||||||
reload_count: int = 0
|
reload_count: int = 0
|
||||||
"""更新次数"""
|
"""更新次数"""
|
||||||
|
incremental_update: bool = True
|
||||||
|
"""是否是增量更新"""
|
||||||
|
|
||||||
async def get(self, *args, **kwargs) -> Any:
|
async def get(self, *args, **kwargs) -> Any:
|
||||||
"""获取单个缓存"""
|
"""获取单个缓存"""
|
||||||
|
if not self.reload_count and not self.incremental_update:
|
||||||
|
# 首次获取时,非增量更新获取全部数据
|
||||||
|
await self.reload()
|
||||||
self.call_cleanup_expired() # 移除过期缓存
|
self.call_cleanup_expired() # 移除过期缓存
|
||||||
if not self.getter:
|
if not self.getter:
|
||||||
return self.data
|
return self.data
|
||||||
@ -210,12 +214,17 @@ class CacheManage:
|
|||||||
id=f"CacheRoot-{cache_data.name}",
|
id=f"CacheRoot-{cache_data.name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def new(self, name: str, expire: int = 60 * 10):
|
def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10):
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
_name = name.upper()
|
_name = name.upper()
|
||||||
if _name in self._data:
|
if _name in self._data:
|
||||||
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
|
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
|
||||||
self._data[_name] = CacheData(name=_name, func=func, expire=expire)
|
self._data[_name] = CacheData(
|
||||||
|
name=_name,
|
||||||
|
func=func,
|
||||||
|
expire=expire,
|
||||||
|
incremental_update=incremental_update,
|
||||||
|
)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@ -253,7 +262,7 @@ class CacheManage:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@validate_name
|
@validate_name
|
||||||
def getter(self, name: str, result_model: type | None = None):
|
def getter(self, name: str, result_model: type):
|
||||||
def wrapper(func: Callable):
|
def wrapper(func: Callable):
|
||||||
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
||||||
|
|
||||||
|
|||||||
@ -69,12 +69,12 @@ class Model(TortoiseModel):
|
|||||||
using_db: BaseDBAsyncClient | None = None,
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[Self, bool]:
|
) -> tuple[Self, bool]:
|
||||||
result = await super().get_or_create(
|
result, is_create = await super().get_or_create(
|
||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if cache_type := cls.get_cache_type():
|
if is_create and (cache_type := cls.get_cache_type()):
|
||||||
await CacheRoot.reload(cache_type)
|
await CacheRoot.reload(cache_type)
|
||||||
return result
|
return (result, is_create)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def update_or_create(
|
async def update_or_create(
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -6,6 +7,7 @@ import time
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
import pypinyin
|
import pypinyin
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
@ -13,6 +15,16 @@ from zhenxun.configs.config import Config
|
|||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EntityIDs:
|
||||||
|
user_id: str
|
||||||
|
"""用户id"""
|
||||||
|
group_id: str | None
|
||||||
|
"""群组id"""
|
||||||
|
channel_id: str | None
|
||||||
|
"""频道id"""
|
||||||
|
|
||||||
|
|
||||||
class ResourceDirManager:
|
class ResourceDirManager:
|
||||||
"""
|
"""
|
||||||
临时文件管理器
|
临时文件管理器
|
||||||
@ -228,3 +240,24 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
|
|||||||
return True
|
return True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_ids(session: Uninfo) -> EntityIDs:
|
||||||
|
"""获取用户id,群组id,频道id
|
||||||
|
|
||||||
|
参数:
|
||||||
|
session: Uninfo
|
||||||
|
|
||||||
|
返回:
|
||||||
|
EntityIDs: 用户id,群组id,频道id
|
||||||
|
"""
|
||||||
|
user_id = session.user.id
|
||||||
|
group_id = None
|
||||||
|
channel_id = None
|
||||||
|
if session.group:
|
||||||
|
if session.group.parent:
|
||||||
|
group_id = session.group.parent.id
|
||||||
|
channel_id = session.group.id
|
||||||
|
else:
|
||||||
|
group_id = session.group.id
|
||||||
|
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user