zhenxun_bot/zhenxun/utils/depends/__init__.py

304 lines
8.9 KiB
Python
Raw Normal View History

from typing import Any, Literal
2024-05-04 13:48:12 +08:00
from nonebot.adapters import Bot, Event
from nonebot.exception import SkippedException
2024-10-18 18:57:55 +08:00
from nonebot.internal.params import Depends
from nonebot.matcher import Matcher
from nonebot.params import Command
from nonebot.permission import SUPERUSER
2024-05-04 13:48:12 +08:00
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
2024-03-09 23:42:59 +08:00
2024-05-04 13:48:12 +08:00
from zhenxun.configs.config import Config
from zhenxun.services import group_settings_service
from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
2024-08-11 15:57:33 +08:00
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.time_utils import TimeUtils
_coolers: dict[str, FreqLimiter] = {}
_rate_limiters: dict[str, RateLimiter] = {}
_concurrency_limiters: dict[str, ConcurrencyLimiter] = {}
def _create_limiter_dependency(
limiter_class: type,
limiter_storage: dict,
limiter_init_args: dict[str, Any],
scope: Literal["user", "group", "global"],
prompt: str,
**kwargs,
):
"""
一个高阶函数用于创建不同类型的限制器依赖
参数:
limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.).
limiter_storage: 用于存储限制器实例的字典.
limiter_init_args: 限制器类的初始化参数.
scope: 限制作用域.
prompt: 触发限制时的提示信息.
**kwargs: 传递给特定限制器逻辑的额外参数.
"""
async def dependency(
matcher: Matcher, session: EventSession, bot: Bot, event: Event
) -> bool:
if await SUPERUSER(bot, event):
return True
handler_id = (
f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}"
)
key: str | None = None
if scope == "user":
key = session.id1
elif scope == "group":
key = session.id3 or session.id2 or session.id1
elif scope == "global":
key = f"global_{handler_id}"
if not key:
return True
if handler_id not in limiter_storage:
limiter_storage[handler_id] = limiter_class(**limiter_init_args)
limiter = limiter_storage[handler_id]
if isinstance(limiter, ConcurrencyLimiter):
await limiter.acquire(key)
matcher.state["_concurrency_limiter_info"] = {
"limiter": limiter,
"key": key,
}
return True
else:
if limiter.check(key):
if isinstance(limiter, FreqLimiter):
limiter.start_cd(
key, kwargs.get("duration_sec", limiter.default_cd)
)
return True
else:
left_time = limiter.left_time(key)
format_kwargs = {
"cd_str": TimeUtils.format_duration(left_time),
**(kwargs.get("prompt_format_kwargs", {})),
}
message = prompt.format(**format_kwargs)
await matcher.finish(message)
return Depends(dependency)
def Cooldown(
duration: str,
*,
scope: Literal["user", "group", "global"] = "user",
prompt: str = "操作过于频繁,请等待 {cd_str}",
) -> bool:
"""声明式冷却检查依赖,限制用户操作频率
参数:
duration: 冷却时间字符串 (e.g., "30s", "10m", "1h")
scope: 冷却作用域
prompt: 自定义的冷却提示消息可使用 {cd_str} 占位符
返回:
bool: 是否允许执行
"""
try:
parsed_seconds = TimeUtils.parse_time_string(duration)
except ValueError as e:
raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}")
return _create_limiter_dependency(
limiter_class=FreqLimiter,
limiter_storage=_coolers,
limiter_init_args={"default_cd_seconds": parsed_seconds},
scope=scope,
prompt=prompt,
duration_sec=parsed_seconds,
)
def RateLimit(
count: int,
duration: str,
*,
scope: Literal["user", "group", "global"] = "user",
prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}",
) -> bool:
"""声明式速率限制依赖,在指定时间窗口内限制操作次数
参数:
count: 在时间窗口内允许的最大调用次数
duration: 时间窗口字符串 (e.g., "1m", "1h")
scope: 限制作用域
prompt: 自定义的提示消息可使用 {cd_str}, {duration_str}, {limit} 占位符
返回:
bool: 是否允许执行
"""
try:
parsed_seconds = TimeUtils.parse_time_string(duration)
except ValueError as e:
raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}")
return _create_limiter_dependency(
limiter_class=RateLimiter,
limiter_storage=_rate_limiters,
limiter_init_args={"max_calls": count, "time_window": parsed_seconds},
scope=scope,
prompt=prompt,
prompt_format_kwargs={"duration_str": duration, "limit": count},
)
def ConcurrencyLimit(
count: int,
*,
scope: Literal["user", "group", "global"] = "global",
prompt: str | None = "当前功能繁忙,请稍后再试...",
) -> bool:
"""声明式并发数限制依赖,控制某个功能同时执行的实例数量
参数:
count: 最大并发数
scope: 限制作用域
prompt: 提示消息暂未使用主要用于未来扩展超时功能
返回:
bool: 是否允许执行
"""
return _create_limiter_dependency(
limiter_class=ConcurrencyLimiter,
limiter_storage=_concurrency_limiters,
limiter_init_args={"max_concurrent": count},
scope=scope,
prompt=prompt or "",
)
2024-05-04 13:48:12 +08:00
def CheckUg(check_user: bool = True, check_group: bool = True):
"""检测群组id和用户id是否存在
参数:
check_user: 检查用户id.
check_group: 检查群组id.
"""
async def dependency(session: EventSession):
if check_user:
user_id = session.id1
if not user_id:
2024-08-11 15:57:33 +08:00
await MessageUtils.build_message("用户id为空").finish()
2024-05-04 13:48:12 +08:00
if check_group:
group_id = session.id3 or session.id2
if not group_id:
2024-08-11 15:57:33 +08:00
await MessageUtils.build_message("群组id为空").finish()
2024-05-04 13:48:12 +08:00
return Depends(dependency)
2024-03-09 23:42:59 +08:00
def OneCommand():
"""
获取单个命令Command
"""
async def dependency(
cmd: tuple[str, ...] = Command(),
):
return cmd[0] if cmd else None
return Depends(dependency)
def UserName():
"""
用户名称
"""
2024-10-18 18:57:55 +08:00
async def dependency(user_info: Uninfo):
return user_info.user.nick or user_info.user.name or ""
2024-05-04 13:48:12 +08:00
return Depends(dependency)
def GetConfig(
module: str | None = None,
config: str = "",
default_value: Any = None,
prompt: str | None = None,
):
"""获取配置项
参数:
module: 模块名为空时默认使用当前插件模块名
config: 配置项名称
default_value: 默认值
prompt: 为空时提示
"""
async def dependency(matcher: Matcher):
module_ = module or matcher.plugin_name
if module_:
value = Config.get_config(module_, config, default_value)
if value is None and prompt:
await matcher.finish(prompt)
return value
return Depends(dependency)
def GetGroupConfig(model: type[Any]):
"""
依赖注入函数用于获取并解析插件的分群配置
"""
async def dependency(matcher: Matcher, session: EventSession):
"""
实际的依赖注入逻辑
"""
plugin_name = matcher.plugin_name
group_id = session.id3 or session.id2
if not plugin_name:
raise SkippedException("无法确定插件名称以获取配置")
if not group_id:
try:
return model()
except Exception:
raise SkippedException("在私聊中无法获取分群配置")
return await group_settings_service.get_all_for_plugin(
group_id, plugin_name, parse_model=model
)
return Depends(dependency)
2024-05-04 13:48:12 +08:00
def CheckConfig(
module: str | None = None,
config: str | list[str] = "",
prompt: str | None = None,
):
"""检测配置项在配置文件中是否填写
参数:
module: 模块名为空时默认使用当前插件模块名
config: 需要检查的配置项名称
prompt: 为空时提示
"""
async def dependency(matcher: Matcher):
module_ = module or matcher.plugin_name
if module_:
config_list = [config] if isinstance(config, str) else config
for c in config_list:
if Config.get_config(module_, c) is None:
await matcher.finish(prompt or f"配置项 {c} 未填写!")
2024-03-09 23:42:59 +08:00
return Depends(dependency)