优化hook权限检测性能

This commit is contained in:
HibiKier 2025-04-08 17:11:44 +08:00
parent 03f0185e46
commit 41dd767724
19 changed files with 718 additions and 530 deletions

View File

@ -196,7 +196,7 @@ class PluginManage:
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
default_status=status
)
return f'成功将所有功能进群默认状态修改为: {"开启" if status else "关闭"}'
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
if group_id:
if group := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
@ -213,12 +213,12 @@ class PluginManage:
module_list = [f"<{module}" for module in module_list]
group.block_plugin = ",".join(module_list) + "," # type: ignore
await group.save(update_fields=["block_plugin"])
return f'成功将此群组所有功能状态修改为: {"开启" if status else "关闭"}'
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
return "获取群组失败..."
await PluginInfo.filter(plugin_type=PluginType.NORMAL).update(
status=status, block_type=None if status else BlockType.ALL
)
return f'成功将所有功能全局状态修改为: {"开启" if status else "关闭"}'
return f"成功将所有功能全局状态修改为: {'开启' if status else '关闭'}"
@classmethod
async def is_wake(cls, group_id: str) -> bool:
@ -243,9 +243,11 @@ class PluginManage:
参数:
group_id: 群组id
"""
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
status=False
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
group.status = False
await group.save(update_fields=["status"])
@classmethod
async def wake(cls, group_id: str):
@ -254,9 +256,11 @@ class PluginManage:
参数:
group_id: 群组id
"""
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
status=True
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
group.status = True
await group.save(update_fields=["status"])
@classmethod
async def block(cls, module: str):

View File

@ -1,15 +1,14 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import get_entity_ids
from .utils import freq
from .exception import SkipPluginException
from .utils import send_message
async def auth_admin(plugin: PluginInfo, session: Uninfo):
@ -17,60 +16,33 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
参数:
plugin: PluginInfo
session: PluginInfo
session: Uninfo
"""
group_id = None
cache = Cache[list[LevelUser]](CacheType.LEVEL)
user_level = await cache.get(session.user.id) or []
if session.group:
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
if not plugin.admin_level:
return
if group_id:
user_level += await cache.get(session.user.id, group_id) or []
entity = get_entity_ids(session)
cache = Cache[list[LevelUser]](CacheType.LEVEL)
user_level = await cache.get(session.user.id) or []
if entity.group_id:
user_level += await cache.get(session.user.id, entity.group_id) or []
user = max(user_level, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level:
try:
if freq._flmt.check(session.user.id):
freq._flmt.start_cd(session.user.id)
await MessageUtils.build_message(
await send_message(
session,
[
At(flag="user", target=session.user.id),
"你的权限不足喔,"
f"该功能需要的权限等级: {plugin.admin_level}",
]
).send(reply_to=True)
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=session,
e=e,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
],
entity.user_id,
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
raise IgnoredException("管理员权限不足...")
elif user_level:
user = max(user_level, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level:
try:
await MessageUtils.build_message(
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
).send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
await send_message(
session,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
)
logger.debug(
f"{plugin.name}({plugin.module}) 管理员权限不足...",
"AuthChecker",
session=session,
)
raise IgnoredException("权限不足")
raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...")

View File

@ -0,0 +1,167 @@
from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import MultipleObjectsReturned
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, PluginType
from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
from .utils import send_message
Config.add_plugin_config(
"hook",
"BAN_RESULT",
"才不会给你发消息.",
help="对被ban用户发送的消息",
)
async def is_ban(user_id: str | None, group_id: str | None) -> int:
cache = Cache[list[BanConsole]](CacheType.BAN)
results = await cache.get(user_id, group_id) or await cache.get(user_id)
if not results:
return 0
for result in results:
if result.group_id == group_id and (
result.duration > 0 or result.duration == -1
):
return await BanConsole.check_ban_time(user_id, group_id)
if not result.group_id and result.duration == -1:
return await BanConsole.check_ban_time(user_id, group_id)
return 0
def check_plugin_type(matcher: Matcher) -> bool:
"""判断插件类型是否是隐藏插件
参数:
matcher: Matcher
返回:
bool: 是否为隐藏插件
"""
if plugin := matcher.plugin:
if metadata := plugin.metadata:
extra = metadata.extra
if extra.get("plugin_type") in [PluginType.HIDDEN]:
return False
return True
def format_time(time: float) -> str:
"""格式化时间
参数:
time: ban时长
返回:
str: 格式化时间文本
"""
if time == -1:
return ""
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
return time_str
async def group_handle(cache: Cache[list[BanConsole]], group_id: str):
"""群组ban检查
参数:
cache: cache
group_id: 群组id
异常:
SkipPluginException: 群组处于黑名单
"""
try:
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"群组黑名单数据重复过滤该次hook并移除多余数据...", LOGGER_COMMAND
)
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
async def user_handle(
module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo
):
"""用户ban检查
参数:
module: 插件模块名
cache: cache
user_id: 用户id
session: Uninfo
异常:
SkipPluginException: 用户处于黑名单
"""
ban_result = Config.get_config("hook", "BAN_RESULT")
try:
time = await is_ban(entity.user_id, entity.group_id)
if not time:
return
time_str = format_time(time)
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if (
db_plugin
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
):
await send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
)
raise SkipPluginException("用户处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"用户黑名单数据重复过滤该次hook并移除多余数据...", LOGGER_COMMAND
)
ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
if not check_plugin_type(matcher):
return
if not matcher.plugin_name:
return
entity = get_entity_ids(session)
if entity.user_id in bot.config.superusers:
return
cache = Cache[list[BanConsole]](CacheType.BAN)
if entity.group_id:
await group_handle(cache, entity.group_id)
if entity.user_id:
await user_handle(matcher.plugin_name, cache, entity, session)

View File

@ -1,12 +1,11 @@
from nonebot.exception import IgnoredException
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import CacheType
from .exception import SkipPluginException
async def auth_bot(plugin: PluginInfo, bot_id: str):
"""bot层面的权限检查
@ -16,17 +15,14 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
bot_id: bot id
异常:
IgnoredException: 忽略插件
IgnoredException: 忽略插件
SkipPluginException: 忽略插件
SkipPluginException: 忽略插件
"""
if cache := Cache[BotConsole](CacheType.BOT):
bot = await cache.get(bot_id)
if not bot or not bot.status:
logger.debug("Bot不存在或休眠中阻断权限检测...", "AuthChecker")
raise IgnoredException("BotConsole休眠权限检测 ignore")
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
if CommonUtils.format(plugin.module) in bot.block_plugins:
logger.debug(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
"AuthChecker",
raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
)
raise IgnoredException("BotConsole插件权限检测 ignore")

View File

@ -1,10 +1,10 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.message import MessageUtils
from .exception import SkipPluginException
from .utils import send_message
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
@ -19,17 +19,6 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
"""
if user.gold < plugin.cost_gold:
"""插件消耗金币不足"""
try:
await MessageUtils.build_message(
f"金币不足..该功能需要{plugin.cost_gold}金币.."
).send()
except Exception as e:
logger.error("auth_cost 发送消息失败", "AuthChecker", session=session, e=e)
logger.debug(
f"{plugin.name}({plugin.module}) 金币限制.."
f"该功能需要{plugin.cost_gold}金币..",
"AuthChecker",
session=session,
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
return plugin.cost_gold

View File

@ -1,57 +1,35 @@
from nonebot.exception import IgnoredException
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
from zhenxun.utils.utils import EntityIDs
from .config import SwitchEnum
from .exception import SkipPluginException
async def auth_group(plugin: PluginInfo, session: Uninfo, message: UniMsg):
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
"""群黑名单检测 群总开关检测
参数:
plugin: PluginInfo
session: EventSession
entity: EntityIDs
message: UniMsg
"""
if not session.group:
if not entity.group_id:
return
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
text = message.extract_plain_text()
group = await Cache[GroupConsole](CacheType.GROUPS).get(group_id)
group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id)
if not group:
"""群不存在"""
logger.debug(
"群组信息不存在...",
"AuthChecker",
session=session,
)
raise IgnoredException("群不存在")
raise SkipPluginException("群组信息不存在...")
if group.level < 0:
"""群权限小于0"""
logger.debug(
"群黑名单, 群权限-1...",
"AuthChecker",
session=session,
)
raise IgnoredException("群黑名单")
if not group.status:
"""群休眠"""
if text.strip() != "醒来":
logger.debug("群休眠状态...", "AuthChecker", session=session)
raise IgnoredException("群休眠状态")
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
if text.strip() != SwitchEnum.ENABLE and not group.status:
raise SkipPluginException("群组休眠状态...")
if plugin.level > group.level:
"""插件等级大于群等级"""
logger.debug(
f"{plugin.name}({plugin.module}) 群等级限制.."
f"该功能需要的群等级: {plugin.level}..",
"AuthChecker",
session=session,
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 群等级限制,"
f"该功能需要的群等级: {plugin.level}..."
)
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")

View File

@ -1,6 +1,5 @@
from typing import ClassVar
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel
@ -11,6 +10,9 @@ from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
class Limit(BaseModel):
limit: PluginLimit
@ -69,7 +71,7 @@ class LimitManage:
key_type = channel_id or group_id
logger.debug(
f"解除对象: {key_type} 的block限制",
"AuthChecker",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)
@ -139,16 +141,13 @@ class LimitManage:
if is_limit and not limiter.check(key_type):
if limit.result:
await MessageUtils.build_message(limit.result).send()
logger.debug(
f"{limit.module}({limit.limit_type}) 正在限制中...",
"AuthChecker",
session=session,
raise SkipPluginException(
f"{limit.module}({limit.limit_type}) 正在限制中..."
)
raise IgnoredException(f"{limit.module} 正在限制中...")
else:
logger.debug(
f"开始进行限制 {limit.module}({limit.limit_type})...",
"AuthChecker",
LOGGER_COMMAND,
session=user_id,
group_id=group_id,
)

View File

@ -1,17 +1,15 @@
from nonebot.adapters import Event
from nonebot.exception import IgnoredException
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import get_entity_ids
from .exception import IsSuperuserException
from .utils import freq
from .exception import IsSuperuserException, SkipPluginException
from .utils import freq, is_poke, send_message
class GroupCheck:
@ -42,16 +40,12 @@ class GroupCheck:
group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
freq._flmt_s.start_cd(self.group_id)
await MessageUtils.build_message("超级管理员禁用了该群此功能...").send(
reply_to=True
await send_message(
self.session, "超级管理员禁用了该群此功能...", self.group_id
)
logger.debug(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
"AuthChecker",
session=self.session,
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..."
)
raise IgnoredException("超级管理员禁用了该群此功能...")
await self.check_normal_block(self.plugin)
async def check_normal_block(self, plugin: PluginInfo):
@ -66,16 +60,8 @@ class GroupCheck:
group = await self.__get_data()
if group and CommonUtils.format(plugin.module) in group.block_plugin:
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
freq._flmt_s.start_cd(self.group_id)
await MessageUtils.build_message("该群未开启此功能...").send(
reply_to=True
)
logger.debug(
f"{plugin.name}({plugin.module}) 未开启此功能...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该群未开启此功能...")
await send_message(self.session, "该群未开启此功能...", self.group_id)
raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...")
await self.check_global_block(self.plugin)
async def check_global_block(self, plugin: PluginInfo):
@ -89,25 +75,13 @@ class GroupCheck:
"""
if plugin.block_type == BlockType.GROUP:
"""全局群组禁用"""
try:
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
freq._flmt_c.start_cd(self.group_id)
await MessageUtils.build_message("该功能在群组中已被禁用...").send(
reply_to=True
await send_message(
self.session, "该功能在群组中已被禁用...", self.group_id
)
except Exception as e:
logger.error(
"auth_plugin 发送消息失败",
"AuthChecker",
session=self.session,
e=e,
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..."
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该插件在群组中已被禁用...")
class PluginCheck:
@ -126,25 +100,11 @@ class PluginCheck:
IgnoredException: 忽略插件
"""
if plugin.block_type == BlockType.PRIVATE:
try:
if freq.is_send_limit_message(
plugin, self.session.user.id, self.is_poke
):
freq._flmt_c.start_cd(self.session.user.id)
await MessageUtils.build_message("该功能在私聊中已被禁用...").send()
except Exception as e:
logger.error(
"auth_admin 发送消息失败",
"AuthChecker",
session=self.session,
e=e,
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
await send_message(self.session, "该功能在私聊中已被禁用...")
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
)
logger.debug(
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
"AuthChecker",
session=self.session,
)
raise IgnoredException("该插件在私聊中已被禁用...")
async def check_global(self, plugin: PluginInfo):
"""全局状态
@ -162,16 +122,10 @@ class PluginCheck:
if self.group_id and (group := await cache.get(self.group_id)):
if group.is_super:
raise IsSuperuserException()
logger.debug(
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
"AuthChecker",
session=self.session,
)
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):
freq._flmt_s.start_cd(sid)
await MessageUtils.build_message("全局未开启此功能...").send()
raise IgnoredException("全局未开启此功能...")
await send_message(self.session, "全局未开启此功能...", sid)
raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...")
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
@ -180,22 +134,13 @@ async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
参数:
plugin: PluginInfo
session: Uninfo
event: Event
"""
group_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
else:
group_id = session.group.id
try:
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
is_poke = isinstance(event, PokeNotifyEvent)
except ImportError:
is_poke = False
user_check = PluginCheck(group_id, session, is_poke)
if group_id:
group_check = GroupCheck(plugin, group_id, session, is_poke)
entity = get_entity_ids(session)
is_poke_event = is_poke(event)
user_check = PluginCheck(entity.group_id, session, is_poke_event)
if entity.group_id:
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
await group_check.check()
else:
await user_check.check_user(plugin)

View File

@ -0,0 +1,13 @@
import sys
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
LOGGER_COMMAND = "AuthChecker"
class SwitchEnum(StrEnum):
ENABLE = "醒来"
DISABLE = "休息吧"

View File

@ -1,2 +1,14 @@
class IsSuperuserException(Exception):
pass
class SkipPluginException(Exception):
def __init__(self, info: str, *args: object) -> None:
super().__init__(*args)
self.info = info
def __str__(self) -> str:
return self.info
def __repr__(self) -> str:
return self.info

View File

@ -1,11 +1,61 @@
import contextlib
from nonebot.adapters import Event
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import FreqLimiter
from .config import LOGGER_COMMAND
base_config = Config.get("hook")
def is_poke(event: Event) -> bool:
"""判断是否为poke类型
参数:
event: Event
返回:
bool: 是否为poke类型
"""
with contextlib.suppress(ImportError):
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
return isinstance(event, PokeNotifyEvent)
return False
async def send_message(
session: Uninfo, message: list | str, check_tag: str | None = None
):
"""发送消息
参数:
session: Uninfo
message: 消息
check_tag: cd flag
"""
try:
if not check_tag:
await MessageUtils.build_message(message).send(reply_to=True)
elif freq._flmt.check(check_tag):
freq._flmt.start_cd(check_tag)
await MessageUtils.build_message(message).send(reply_to=True)
except Exception as e:
logger.error(
"发送消息失败",
LOGGER_COMMAND,
session=session,
e=e,
)
class FreqUtils:
def __init__(self):
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")

View File

@ -1,4 +1,4 @@
import contextlib
import asyncio
from nonebot.adapters import Bot, Event
from nonebot.exception import IgnoredException
@ -18,14 +18,107 @@ from zhenxun.utils.enum import (
)
from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.platform import PlatformUtils
from zhenxun.utils.utils import get_entity_ids
from .auth.auth_admin import auth_admin
from .auth.auth_ban import auth_ban
from .auth.auth_bot import auth_bot
from .auth.auth_cost import auth_cost
from .auth.auth_group import auth_group
from .auth.auth_limit import LimitManage, auth_limit
from .auth.auth_plugin import auth_plugin
from .auth.exception import IsSuperuserException
from .auth.config import LOGGER_COMMAND
from .auth.exception import IsSuperuserException, SkipPluginException
async def get_plugin_and_user(
module: str, user_id: str
) -> tuple[PluginInfo, UserConsole]:
"""获取用户数据和插件信息
参数:
module: 模块名
user_id: 用户id
异常:
SkipPluginException: 插件数据不存在
SkipPluginException: 插件类型为HIDDEN
SkipPluginException: 重复创建用户
SkipPluginException: 用户数据不存在
返回:
tuple[PluginInfo, UserConsole]: 插件信息用户信息
"""
user_cache = Cache[UserConsole](CacheType.USERS)
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if not plugin:
raise SkipPluginException(f"插件:{module} 数据不存在,已跳过权限检查...")
if plugin.plugin_type == PluginType.HIDDEN:
raise SkipPluginException(
f"插件: {plugin.name}:{plugin.module} 为HIDDEN已跳过权限检查..."
)
user = None
try:
user = await user_cache.get(user_id)
except IntegrityError as e:
raise SkipPluginException("重复创建用户,已跳过该次权限检查...") from e
if not user:
raise SkipPluginException("用户数据不存在,已跳过权限检查...")
return plugin, user
async def get_plugin_cost(
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
) -> int:
"""获取插件费用
参数:
bot: Bot
user: 用户数据
plugin: 插件数据
session: Uninfo
异常:
IsSuperuserException: 超级用户
IsSuperuserException: 超级用户
返回:
int: 调用插件金币费用
"""
cost_gold = await auth_cost(user, plugin, session)
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
raise IsSuperuserException()
return cost_gold
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
"""扣除用户金币
参数:
user_id: 用户id
module: 插件模块名称
cost_gold: 消耗金币
session: Uninfo
"""
user_cache = Cache[UserConsole](CacheType.USERS)
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
module,
PlatformUtils.get_platform(session),
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
# 更新缓存
await user_cache.update(user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
async def auth(
@ -44,85 +137,32 @@ async def auth(
session: Uninfo
message: UniMsg
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
is_ignore = False
cost_gold = 0
with contextlib.suppress(ImportError):
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
"""过滤除poke外的notice"""
return
user_cache = Cache[UserConsole](CacheType.USERS)
if matcher.plugin and (module := matcher.plugin.name):
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if not plugin:
return logger.debug(f"插件:{module} 数据不存在,已跳过权限检查...")
if plugin.plugin_type == PluginType.HIDDEN:
return logger.debug(
f"插件: {plugin.name}:{plugin.module} 为HIDDEN已跳过权限检查..."
)
user = None
ignore_flag = False
entity = get_entity_ids(session)
module = matcher.plugin_name or ""
try:
user = await user_cache.get(session.user.id)
except IntegrityError as e:
logger.debug(
"重复创建用户,已跳过该次权限检查...",
"AuthChecker",
session=session,
e=e,
if not module:
raise SkipPluginException("Matcher插件名称不存在...")
plugin, user = await get_plugin_and_user(module, entity.user_id)
cost_gold = await get_plugin_cost(bot, user, plugin, session)
await asyncio.gather(
*[
auth_ban(matcher, bot, session),
auth_bot(plugin, bot.self_id),
auth_group(plugin, entity, message),
auth_admin(plugin, session),
auth_plugin(plugin, session, event),
]
)
if not user:
return logger.debug(
"用户数据不存在,已跳过权限检查...", "AuthChecker", session=session
)
try:
cost_gold = await auth_cost(user, plugin, session)
if session.user.id in bot.config.superusers:
if plugin.plugin_type == PluginType.SUPERUSER:
raise IsSuperuserException()
if not plugin.limit_superuser:
cost_gold = 0
raise IsSuperuserException()
await auth_bot(plugin, bot.self_id)
await auth_group(plugin, session, message)
await auth_admin(plugin, session)
await auth_plugin(plugin, session, event)
await auth_limit(plugin, session)
except SkipPluginException as e:
LimitManage.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
logger.info(str(e), LOGGER_COMMAND, session=session)
ignore_flag = True
except IsSuperuserException:
logger.debug(
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
)
except IgnoredException:
is_ignore = True
LimitManage.unblock(matcher.plugin.name, user_id, group_id, channel_id)
except AssertionError as e:
is_ignore = True
logger.debug("消息无法发送", session=session, e=e)
if cost_gold and user_id:
"""花费金币"""
try:
await UserConsole.reduce_gold(
user_id,
cost_gold,
GoldHandle.PLUGIN,
matcher.plugin.name if matcher.plugin else "",
PlatformUtils.get_platform(session),
)
except InsufficientGold:
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
# 更新缓存
await user_cache.update(user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session)
if is_ignore:
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
if not ignore_flag and cost_gold > 0:
await reduce_gold(entity.user_id, module, cost_gold, session)
if ignore_flag:
raise IgnoredException("权限检测 ignore")

View File

@ -1,15 +1,21 @@
import time
from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher
from nonebot.message import run_postprocessor, run_preprocessor
from nonebot_plugin_alconna import UniMsg
from nonebot_plugin_uninfo import Uninfo
from zhenxun.services.log import logger
from .auth.config import LOGGER_COMMAND
from .auth_checker import LimitManage, auth
# # 权限检测
@run_preprocessor
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
start_time = time.time()
await auth(
matcher,
event,
@ -17,6 +23,7 @@ async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message:
session,
message,
)
logger.info(f"权限检测耗时:{time.time() - start_time}", LOGGER_COMMAND)
# 解除命令block阻塞

View File

@ -1,104 +0,0 @@
from nonebot.adapters import Bot
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot.message import run_preprocessor
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import MultipleObjectsReturned
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, PluginType
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import FreqLimiter
Config.add_plugin_config(
"hook",
"BAN_RESULT",
"才不会给你发消息.",
help="对被ban用户发送的消息",
)
_flmt = FreqLimiter(300)
async def is_ban(user_id: str | None, group_id: str | None):
cache = Cache[list[BanConsole]](CacheType.BAN)
result = await cache.get(user_id, group_id) or await cache.get(user_id)
return result and result[0].ban_time > 0
# 检查是否被ban
@run_preprocessor
async def _(matcher: Matcher, bot: Bot, session: Uninfo):
if plugin := matcher.plugin:
if metadata := plugin.metadata:
extra = metadata.extra
if extra.get("plugin_type") in [PluginType.HIDDEN]:
return
user_id = session.user.id
group_id = session.group.id if session.group else None
cache = Cache[list[BanConsole]](CacheType.BAN)
if user_id in bot.config.superusers:
return
if group_id:
try:
if await is_ban(None, group_id):
logger.debug("群组处于黑名单中...", "BanChecker")
raise IgnoredException("群组处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"群组黑名单数据重复过滤该次hook并移除多余数据...", "BanChecker"
)
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
group_cache = Cache[GroupConsole](CacheType.GROUPS)
if g := await group_cache.get(group_id):
if g.level < 0:
logger.debug("群黑名单, 群权限-1...", "BanChecker")
raise IgnoredException("群黑名单, 群权限-1..")
if user_id:
ban_result = Config.get_config("hook", "BAN_RESULT")
if await is_ban(user_id, group_id):
time = await BanConsole.check_ban_time(user_id, group_id)
if time == -1:
time_str = ""
else:
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(
matcher.plugin_name
)
if (
db_plugin
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
and _flmt.check(user_id)
):
_flmt.start_cd(user_id)
logger.debug(f"ban检测发送插件: {matcher.plugin_name}")
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "BanChecker")
raise IgnoredException("用户处于黑名单中...")

View File

@ -42,6 +42,8 @@ def default_with_expiration(
data: dict[str, Any], expire_data: dict[str, int], expire: int
):
"""默认更新期时间cache方法"""
if not data:
return {}
keys = {k for k in data if k not in expire_data}
return {k: time.time() + expire for k in keys} if keys else {}
@ -52,12 +54,6 @@ async def _():
return {p.module: p for p in data_list}
@CacheRoot.new(CacheType.PLUGINS)
async def _():
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
@CacheRoot.updater(CacheType.PLUGINS)
async def _(data: dict[str, PluginInfo], key: str, value: Any):
if value:
@ -109,21 +105,20 @@ async def _(data: dict[str, GroupConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(data: dict[str, GroupConsole] | None, group_id: str):
if not data:
data = {}
result = data.get(group_id, None)
async def _(cache_data: CacheData, group_id: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(group_id, None)
if not result:
result = await GroupConsole.get_group(group_id=group_id)
if result:
data[group_id] = result
cache_data.data[group_id] = result
return result
@CacheRoot.with_refresh(CacheType.GROUPS)
async def _(data: dict[str, GroupConsole]):
groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True, load_status=True
group_id__in=data.keys(), channel_id__isnull=True
)
for group in groups:
data[group.group_id] = group
@ -154,14 +149,13 @@ async def _(data: dict[str, BotConsole], key: str, value: Any):
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(data: dict[str, BotConsole] | None, bot_id: str):
if not data:
data = {}
result = data.get(bot_id, None)
async def _(cache_data: CacheData, bot_id: str):
cache_data.data = cache_data.data or {}
result = cache_data.data.get(bot_id, None)
if not result:
result = await BotConsole.get_or_none(bot_id=bot_id)
if result:
data[bot_id] = result
cache_data.data[bot_id] = result
return result
@ -224,7 +218,7 @@ def _(cache_data: CacheData):
return default_cleanup_expired(cache_data)
@CacheRoot.new(CacheType.LEVEL)
@CacheRoot.new(CacheType.LEVEL, False)
async def _():
return await LevelUser().all()
@ -246,13 +240,13 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
]
@CacheRoot.new(CacheType.BAN)
@CacheRoot.new(CacheType.BAN, False, 5)
async def _():
return await BanConsole.all()
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
if user_id:
return (
[

View File

@ -1,4 +1,4 @@
from typing import Any, ClassVar, overload
from typing import Any, ClassVar, cast, overload
from typing_extensions import Self
from tortoise import fields
@ -11,6 +11,42 @@ from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
def add_disable_marker(name: str) -> str:
"""添加模块禁用标记符
Args:
name: 模块名称
Returns:
添加了禁用标记的模块名 (前缀'<'和后缀',')
"""
return f"<{name},"
@overload
def convert_module_format(data: str) -> list[str]: ...
@overload
def convert_module_format(data: list[str]) -> str: ...
def convert_module_format(data: str | list[str]) -> str | list[str]:
"""
`<aaa,<bbb,<ccc,` `["aaa", "bbb", "ccc"]` (即禁用启用)之间进行相互转换
参数:
data: 要转换的数据
返回:
str | list[str]: 根据输入类型返回转换后的数据
"""
if isinstance(data, str):
return [item.strip(",") for item in data.split("<") if item]
else:
return "".join(format(item) for item in data)
class GroupConsole(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
"""自增id"""
@ -47,7 +83,7 @@ class GroupConsole(Model):
platform = fields.CharField(255, default="qq", description="所属平台")
"""所属平台"""
class Meta: # type: ignore
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "group_console"
table_description = "群组信息表"
unique_together = ("group_id", "channel_id")
@ -57,33 +93,34 @@ class GroupConsole(Model):
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""
@staticmethod
def format(name: str) -> str:
return f"<{name},"
@overload
@classmethod
def convert_module_format(cls, data: str) -> list[str]: ...
@overload
@classmethod
def convert_module_format(cls, data: list[str]) -> str: ...
@classmethod
def convert_module_format(cls, data: str | list[str]) -> str | list[str]:
"""
`<aaa,<bbb,<ccc,` `["aaa", "bbb", "ccc"]` 之间进行相互转换
参数:
data (str | list[str]): 输入数据可能是格式化字符串或字符串列表
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
"""获取默认禁用的任务模块
返回:
str | list[str]: 根据输入类型返回转换后的数据
list[str]: 任务模块列表
"""
if isinstance(data, str):
return [item.strip(",") for item in data.split("<") if item]
elif isinstance(data, list):
return "".join(cls.format(item) for item in data)
return cast(
list[str],
await TaskInfo.filter(default_status=default_status).values_list(
"module", flat=True
),
)
@classmethod
async def _get_plugin_modules(cls, *, default_status: bool) -> list[str]:
"""获取默认禁用的插件模块
返回:
list[str]: 插件模块列表
"""
return cast(
list[str],
await PluginInfo.filter(
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
default_status=default_status,
).values_list("module", flat=True),
)
@classmethod
@CacheRoot.listener(CacheType.GROUPS)
@ -92,20 +129,44 @@ class GroupConsole(Model):
) -> Self:
"""覆盖create方法"""
group = await super().create(using_db=using_db, **kwargs)
if modules := await TaskInfo.filter(default_status=False).values_list(
"module", flat=True
):
group.block_task = cls.convert_module_format(modules) # type: ignore
if modules := await PluginInfo.filter(
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
default_status=False,
).values_list("module", flat=True):
group.block_plugin = cls.convert_module_format(modules) # type: ignore
await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"]
)
task_modules = await cls._get_task_modules(default_status=False)
plugin_modules = await cls._get_plugin_modules(default_status=False)
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
return group
@classmethod
async def _update_modules(
cls,
group: Self,
task_modules: list[str],
plugin_modules: list[str],
using_db: BaseDBAsyncClient | None = None,
) -> None:
"""更新模块设置
参数:
group: 群组实例
task_modules: 任务模块列表
plugin_modules: 插件模块列表
using_db: 数据库连接
"""
update_fields = []
if task_modules:
group.block_task = convert_module_format(task_modules)
update_fields.append("block_task")
if plugin_modules:
group.block_plugin = convert_module_format(plugin_modules)
update_fields.append("block_plugin")
if update_fields:
await group.save(using_db=using_db, update_fields=update_fields)
@classmethod
async def get_or_create(
cls,
@ -117,23 +178,19 @@ class GroupConsole(Model):
group, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (
modules := await TaskInfo.filter(default_status=False).values_list(
"module", flat=True
)
):
group.block_task = cls.convert_module_format(modules) # type: ignore
if modules := await PluginInfo.filter(
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
default_status=False,
).values_list("module", flat=True):
group.block_plugin = cls.convert_module_format(modules) # type: ignore
await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"]
)
if not is_create:
return group, is_create
task_modules = await cls._get_task_modules(default_status=False)
plugin_modules = await cls._get_plugin_modules(default_status=False)
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
if is_create:
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
await cache.update(group.group_id, group)
return group, is_create
@classmethod
@ -148,20 +205,15 @@ class GroupConsole(Model):
group, is_create = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (
modules := await TaskInfo.filter(default_status=False).values_list(
"module", flat=True
)
):
group.block_task = cls.convert_module_format(modules) # type: ignore
if modules := await PluginInfo.filter(
plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT],
default_status=False,
).values_list("module", flat=True):
group.block_plugin = cls.convert_module_format(modules) # type: ignore
await group.save(
using_db=using_db, update_fields=["block_plugin", "block_task"]
)
if not is_create:
return group, is_create
task_modules = await cls._get_task_modules(default_status=False)
plugin_modules = await cls._get_plugin_modules(default_status=False)
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
return group, is_create
@classmethod
@ -206,7 +258,7 @@ class GroupConsole(Model):
"""
return await cls.exists(
group_id=group_id,
superuser_block_plugin__contains=cls.format(module),
superuser_block_plugin__contains=add_disable_marker(module),
)
@classmethod
@ -220,10 +272,11 @@ class GroupConsole(Model):
返回:
bool: 是否禁用插件
"""
module = add_disable_marker(module)
return await cls.exists(
group_id=group_id, block_plugin__contains=cls.format(module)
group_id=group_id, block_plugin__contains=module
) or await cls.exists(
group_id=group_id, superuser_block_plugin__contains=cls.format(module)
group_id=group_id, superuser_block_plugin__contains=module
)
@classmethod
@ -245,12 +298,22 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform}
)
update_fields = []
if is_superuser:
if cls.format(module) not in group.superuser_block_plugin:
group.superuser_block_plugin += cls.format(module)
elif cls.format(module) not in group.block_plugin:
group.block_plugin += cls.format(module)
await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
superuser_block_plugin = convert_module_format(group.superuser_block_plugin)
if module not in superuser_block_plugin:
superuser_block_plugin.append(module)
group.superuser_block_plugin = convert_module_format(
superuser_block_plugin
)
update_fields.append("superuser_block_plugin")
elif add_disable_marker(module) not in group.block_plugin:
block_plugin = convert_module_format(group.block_plugin)
block_plugin.append(module)
group.block_plugin = convert_module_format(block_plugin)
update_fields.append("block_plugin")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod
async def set_unblock_plugin(
@ -271,14 +334,22 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform}
)
update_fields = []
if is_superuser:
if cls.format(module) in group.superuser_block_plugin:
group.superuser_block_plugin = group.superuser_block_plugin.replace(
cls.format(module), ""
superuser_block_plugin = convert_module_format(group.superuser_block_plugin)
if module in superuser_block_plugin:
superuser_block_plugin.remove(module)
group.superuser_block_plugin = convert_module_format(
superuser_block_plugin
)
elif cls.format(module) in group.block_plugin:
group.block_plugin = group.block_plugin.replace(cls.format(module), "")
await group.save(update_fields=["block_plugin", "superuser_block_plugin"])
update_fields.append("superuser_block_plugin")
elif add_disable_marker(module) in group.block_plugin:
block_plugin = convert_module_format(group.block_plugin)
block_plugin.remove(module)
group.block_plugin = convert_module_format(block_plugin)
update_fields.append("block_plugin")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod
async def is_normal_block_plugin(
@ -297,7 +368,7 @@ class GroupConsole(Model):
return await cls.exists(
group_id=group_id,
channel_id=channel_id,
block_plugin__contains=cls.format(module),
block_plugin__contains=f"<{module},",
)
@classmethod
@ -313,7 +384,7 @@ class GroupConsole(Model):
"""
return await cls.exists(
group_id=group_id,
superuser_block_task__contains=cls.format(task),
superuser_block_task__contains=add_disable_marker(task),
)
@classmethod
@ -330,24 +401,23 @@ class GroupConsole(Model):
返回:
bool: 是否禁用被动
"""
task = add_disable_marker(task)
if not channel_id:
return await cls.exists(
group_id=group_id,
channel_id__isnull=True,
block_task__contains=cls.format(task),
block_task__contains=task,
) or await cls.exists(
group_id=group_id,
channel_id__isnull=True,
superuser_block_task__contains=cls.format(task),
superuser_block_task__contains=task,
)
return await cls.exists(
group_id=group_id,
channel_id=channel_id,
block_task__contains=cls.format(task),
group_id=group_id, channel_id=channel_id, block_task__contains=task
) or await cls.exists(
group_id=group_id,
channel_id__isnull=True,
superuser_block_task__contains=cls.format(task),
superuser_block_task__contains=task,
)
@classmethod
@ -369,12 +439,20 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform}
)
update_fields = []
if is_superuser:
if cls.format(task) not in group.superuser_block_task:
group.superuser_block_task += cls.format(task)
elif cls.format(task) not in group.block_task:
group.block_task += cls.format(task)
await group.save(update_fields=["block_task", "superuser_block_task"])
superuser_block_task = convert_module_format(group.superuser_block_task)
if task not in group.superuser_block_task:
superuser_block_task.append(task)
group.superuser_block_task = convert_module_format(superuser_block_task)
update_fields.append("superuser_block_task")
elif add_disable_marker(task) not in group.block_task:
block_task = convert_module_format(group.block_task)
block_task.append(task)
group.block_task = convert_module_format(block_task)
update_fields.append("block_task")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod
async def set_unblock_task(
@ -395,14 +473,20 @@ class GroupConsole(Model):
group, _ = await cls.get_or_create(
group_id=group_id, defaults={"platform": platform}
)
update_fields = []
if is_superuser:
if cls.format(task) in group.superuser_block_task:
group.superuser_block_task = group.superuser_block_task.replace(
cls.format(task), ""
)
elif cls.format(task) in group.block_task:
group.block_task = group.block_task.replace(cls.format(task), "")
await group.save(update_fields=["block_task", "superuser_block_task"])
superuser_block_task = convert_module_format(group.superuser_block_task)
if task in superuser_block_task:
superuser_block_task.remove(task)
group.superuser_block_task = convert_module_format(superuser_block_task)
update_fields.append("superuser_block_task")
elif add_disable_marker(task) in group.block_task:
block_task = convert_module_format(group.block_task)
block_task.remove(task)
group.block_task = convert_module_format(block_task)
update_fields.append("block_task")
if update_fields:
await group.save(update_fields=update_fields)
@classmethod
def _run_script(cls):

View File

@ -46,13 +46,12 @@ class CacheGetter(BaseModel, Generic[T]):
async def get(self, cache_data: "CacheData", *args, **kwargs) -> T:
"""获取缓存"""
processed_data = (
await self.get_func(cache_data, *args, **kwargs)
if self.get_func and is_coroutine_callable(self.get_func)
else self.get_func(cache_data, *args, **kwargs)
if self.get_func
else cache_data.data
)
if not self.get_func:
return cache_data.data
if is_coroutine_callable(self.get_func):
processed_data = await self.get_func(cache_data, *args, **kwargs)
else:
processed_data = self.get_func(cache_data, *args, **kwargs)
return cast(T, processed_data)
@ -81,9 +80,14 @@ class CacheData(BaseModel):
"""更新时间"""
reload_count: int = 0
"""更新次数"""
incremental_update: bool = True
"""是否是增量更新"""
async def get(self, *args, **kwargs) -> Any:
"""获取单个缓存"""
if not self.reload_count and not self.incremental_update:
# 首次获取时,非增量更新获取全部数据
await self.reload()
self.call_cleanup_expired() # 移除过期缓存
if not self.getter:
return self.data
@ -210,12 +214,17 @@ class CacheManage:
id=f"CacheRoot-{cache_data.name}",
)
def new(self, name: str, expire: int = 60 * 10):
def new(self, name: str, incremental_update: bool = True, expire: int = 60 * 10):
def wrapper(func: Callable):
_name = name.upper()
if _name in self._data:
raise DbCacheException(f"DbCache 缓存数据 {name} 已存在...")
self._data[_name] = CacheData(name=_name, func=func, expire=expire)
self._data[_name] = CacheData(
name=_name,
func=func,
expire=expire,
incremental_update=incremental_update,
)
return wrapper
@ -253,7 +262,7 @@ class CacheManage:
return wrapper
@validate_name
def getter(self, name: str, result_model: type | None = None):
def getter(self, name: str, result_model: type):
def wrapper(func: Callable):
self._data[name].getter = CacheGetter[result_model](get_func=func)

View File

@ -69,12 +69,12 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
result = await super().get_or_create(
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
return result
return (result, is_create)
@classmethod
async def update_or_create(

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
import os
from pathlib import Path
@ -6,6 +7,7 @@ import time
from typing import Any
import httpx
from nonebot_plugin_uninfo import Uninfo
import pypinyin
import pytz
@ -13,6 +15,16 @@ from zhenxun.configs.config import Config
from zhenxun.services.log import logger
@dataclass
class EntityIDs:
user_id: str
"""用户id"""
group_id: str | None
"""群组id"""
channel_id: str | None
"""频道id"""
class ResourceDirManager:
"""
临时文件管理器
@ -228,3 +240,24 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
return True
except ValueError:
return False
def get_entity_ids(session: Uninfo) -> EntityIDs:
"""获取用户id群组id频道id
参数:
session: Uninfo
返回:
EntityIDs: 用户id群组id频道id
"""
user_id = session.user.id
group_id = None
channel_id = None
if session.group:
if session.group.parent:
group_id = session.group.parent.id
channel_id = session.group.id
else:
group_id = session.group.id
return EntityIDs(user_id=user_id, group_id=group_id, channel_id=channel_id)