diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py index d199ff0d..80650472 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py @@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit from zhenxun.services.db_context import DB_TIMEOUT_SECONDS from zhenxun.services.log import logger from zhenxun.utils.enum import LimitWatchType, PluginLimitType +from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.message import MessageUtils -from zhenxun.utils.utils import ( - CountLimiter, - FreqLimiter, - UserBlockLimiter, - get_entity_ids, -) +from zhenxun.utils.time_utils import TimeUtils +from zhenxun.utils.utils import get_entity_ids from .config import LOGGER_COMMAND, WARNING_THRESHOLD from .exception import SkipPluginException @@ -273,9 +270,16 @@ class LimitManager: key_type = channel_id or group_id if is_limit and not limiter.check(key_type): if limit.result: + format_kwargs = {} + if isinstance(limiter, FreqLimiter): + left_time = limiter.left_time(key_type) + cd_str = TimeUtils.format_duration(left_time) + format_kwargs = {"cd": cd_str} try: await asyncio.wait_for( - MessageUtils.build_message(limit.result).send(), + MessageUtils.build_message( + limit.result, format_args=format_kwargs + ).send(), timeout=DB_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: diff --git a/zhenxun/builtin_plugins/hooks/limiter_hook.py b/zhenxun/builtin_plugins/hooks/limiter_hook.py new file mode 100644 index 00000000..22680941 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/limiter_hook.py @@ -0,0 +1,15 @@ +from nonebot.matcher import Matcher +from nonebot.message import run_postprocessor + +from zhenxun.utils.limiters import ConcurrencyLimiter + + +@run_postprocessor +async def _concurrency_release_hook(matcher: Matcher): + """ + 后处理器:在事件处理结束后,释放并发限制的信号量。 + """ + if concurrency_info := matcher.state.get("_concurrency_limiter_info"): + limiter: ConcurrencyLimiter = concurrency_info["limiter"] + key = concurrency_info["key"] + limiter.release(key) diff --git a/zhenxun/builtin_plugins/statistics/_data_source.py b/zhenxun/builtin_plugins/statistics/_data_source.py index 81e2b035..ab426ae6 100644 --- a/zhenxun/builtin_plugins/statistics/_data_source.py +++ b/zhenxun/builtin_plugins/statistics/_data_source.py @@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils from zhenxun.utils.echart_utils.models import Barh from zhenxun.utils.enum import PluginType from zhenxun.utils.image_utils import BuildImage -from zhenxun.utils.utils import TimeUtils +from zhenxun.utils.time_utils import TimeUtils class StatisticsManage: diff --git a/zhenxun/utils/depends/__init__.py b/zhenxun/utils/depends/__init__.py index 887813dd..e52dfd48 100644 --- a/zhenxun/utils/depends/__init__.py +++ b/zhenxun/utils/depends/__init__.py @@ -1,13 +1,181 @@ -from typing import Any +from typing import Any, Literal +from nonebot.adapters import Bot, Event from nonebot.internal.params import Depends from nonebot.matcher import Matcher from nonebot.params import Command +from nonebot.permission import SUPERUSER from nonebot_plugin_session import EventSession from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.config import Config +from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter from zhenxun.utils.message import MessageUtils +from zhenxun.utils.time_utils import TimeUtils + +_coolers: dict[str, FreqLimiter] = {} +_rate_limiters: dict[str, RateLimiter] = {} +_concurrency_limiters: dict[str, ConcurrencyLimiter] = {} + + +def _create_limiter_dependency( + limiter_class: type, + limiter_storage: dict, + limiter_init_args: dict[str, Any], + scope: Literal["user", "group", "global"], + prompt: str, + **kwargs, +): + """ + 一个高阶函数,用于创建不同类型的限制器依赖。 + + 参数: + limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.). + limiter_storage: 用于存储限制器实例的字典. + limiter_init_args: 限制器类的初始化参数. + scope: 限制作用域. + prompt: 触发限制时的提示信息. + **kwargs: 传递给特定限制器逻辑的额外参数. + """ + + async def dependency( + matcher: Matcher, session: EventSession, bot: Bot, event: Event + ) -> bool: + if await SUPERUSER(bot, event): + return True + + handler_id = ( + f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}" + ) + + key: str | None = None + if scope == "user": + key = session.id1 + elif scope == "group": + key = session.id3 or session.id2 or session.id1 + elif scope == "global": + key = f"global_{handler_id}" + + if not key: + return True + + if handler_id not in limiter_storage: + limiter_storage[handler_id] = limiter_class(**limiter_init_args) + limiter = limiter_storage[handler_id] + + if isinstance(limiter, ConcurrencyLimiter): + await limiter.acquire(key) + matcher.state["_concurrency_limiter_info"] = { + "limiter": limiter, + "key": key, + } + return True + else: + if limiter.check(key): + if isinstance(limiter, FreqLimiter): + limiter.start_cd( + key, kwargs.get("duration_sec", limiter.default_cd) + ) + return True + else: + left_time = limiter.left_time(key) + format_kwargs = { + "cd_str": TimeUtils.format_duration(left_time), + **(kwargs.get("prompt_format_kwargs", {})), + } + message = prompt.format(**format_kwargs) + await matcher.finish(message) + + return Depends(dependency) + + +def Cooldown( + duration: str, + *, + scope: Literal["user", "group", "global"] = "user", + prompt: str = "操作过于频繁,请等待 {cd_str}", +) -> bool: + """声明式冷却检查依赖,限制用户操作频率 + + 参数: + duration: 冷却时间字符串 (e.g., "30s", "10m", "1h") + scope: 冷却作用域 + prompt: 自定义的冷却提示消息,可使用 {cd_str} 占位符 + + 返回: + bool: 是否允许执行 + """ + try: + parsed_seconds = TimeUtils.parse_time_string(duration) + except ValueError as e: + raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}") + + return _create_limiter_dependency( + limiter_class=FreqLimiter, + limiter_storage=_coolers, + limiter_init_args={"default_cd_seconds": parsed_seconds}, + scope=scope, + prompt=prompt, + duration_sec=parsed_seconds, + ) + + +def RateLimit( + count: int, + duration: str, + *, + scope: Literal["user", "group", "global"] = "user", + prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}", +) -> bool: + """声明式速率限制依赖,在指定时间窗口内限制操作次数 + + 参数: + count: 在时间窗口内允许的最大调用次数 + duration: 时间窗口字符串 (e.g., "1m", "1h") + scope: 限制作用域 + prompt: 自定义的提示消息,可使用 {cd_str}, {duration_str}, {limit} 占位符 + + 返回: + bool: 是否允许执行 + """ + try: + parsed_seconds = TimeUtils.parse_time_string(duration) + except ValueError as e: + raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}") + + return _create_limiter_dependency( + limiter_class=RateLimiter, + limiter_storage=_rate_limiters, + limiter_init_args={"max_calls": count, "time_window": parsed_seconds}, + scope=scope, + prompt=prompt, + prompt_format_kwargs={"duration_str": duration, "limit": count}, + ) + + +def ConcurrencyLimit( + count: int, + *, + scope: Literal["user", "group", "global"] = "global", + prompt: str | None = "当前功能繁忙,请稍后再试...", +) -> bool: + """声明式并发数限制依赖,控制某个功能同时执行的实例数量 + + 参数: + count: 最大并发数 + scope: 限制作用域 + prompt: 提示消息(暂未使用,主要用于未来扩展超时功能) + + 返回: + bool: 是否允许执行 + """ + return _create_limiter_dependency( + limiter_class=ConcurrencyLimiter, + limiter_storage=_concurrency_limiters, + limiter_init_args={"max_concurrent": count}, + scope=scope, + prompt=prompt or "", + ) def CheckUg(check_user: bool = True, check_group: bool = True): @@ -75,7 +243,6 @@ def GetConfig( if module_: value = Config.get_config(module_, config, default_value) if value is None and prompt: - # await matcher.finish(prompt or f"配置项 {config} 未填写!") await matcher.finish(prompt) return value diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py index 9ab664f4..8b3ec282 100644 --- a/zhenxun/utils/exception.py +++ b/zhenxun/utils/exception.py @@ -1,3 +1,17 @@ +from nonebot.exception import IgnoredException + + +class CooldownError(IgnoredException): + """ + 冷却异常,用于在冷却时中断事件处理。 + 继承自 IgnoredException,不会在控制台留下错误堆栈。 + """ + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + class HookPriorityException(BaseException): """ 钩子优先级异常 diff --git a/zhenxun/utils/limiters.py b/zhenxun/utils/limiters.py new file mode 100644 index 00000000..1bd5f662 --- /dev/null +++ b/zhenxun/utils/limiters.py @@ -0,0 +1,140 @@ +import asyncio +from collections import defaultdict, deque +import time +from typing import Any + + +class FreqLimiter: + """ + 命令冷却,检测用户是否处于冷却状态 + """ + + def __init__(self, default_cd_seconds: int): + self.next_time: dict[Any, float] = defaultdict(float) + self.default_cd = default_cd_seconds + + def check(self, key: Any) -> bool: + return time.time() >= self.next_time[key] + + def start_cd(self, key: Any, cd_time: int = 0): + self.next_time[key] = time.time() + ( + cd_time if cd_time > 0 else self.default_cd + ) + + def left_time(self, key: Any) -> float: + return max(0.0, self.next_time[key] - time.time()) + + +class CountLimiter: + """ + 每日调用命令次数限制 + """ + + tz = None + + def __init__(self, max_num: int): + self.today = -1 + self.count: dict[Any, int] = defaultdict(int) + self.max = max_num + + def check(self, key: Any) -> bool: + import datetime + + day = datetime.datetime.now().day + if day != self.today: + self.today = day + self.count.clear() + return self.count[key] < self.max + + def get_num(self, key: Any) -> int: + return self.count[key] + + def increase(self, key: Any, num: int = 1): + self.count[key] += num + + def reset(self, key: Any): + self.count[key] = 0 + + +class UserBlockLimiter: + """ + 检测用户是否正在调用命令 (简单阻塞锁) + """ + + def __init__(self): + self.flag_data: dict[Any, bool] = defaultdict(bool) + self.time: dict[Any, float] = defaultdict(float) + + def set_true(self, key: Any): + self.time[key] = time.time() + self.flag_data[key] = True + + def set_false(self, key: Any): + self.flag_data[key] = False + + def check(self, key: Any) -> bool: + if self.flag_data[key] and time.time() - self.time[key] > 30: + self.set_false(key) + return not self.flag_data[key] + + +class RateLimiter: + """ + 一个简单的基于时间窗口的速率限制器。 + """ + + def __init__(self, max_calls: int, time_window: int): + self.requests: dict[Any, deque[float]] = defaultdict(deque) + self.max_calls = max_calls + self.time_window = time_window + + def check(self, key: Any) -> bool: + """检查是否超出速率限制。如果未超出,则记录本次调用。""" + now = time.time() + + while self.requests[key] and self.requests[key][0] <= now - self.time_window: + self.requests[key].popleft() + + if len(self.requests[key]) < self.max_calls: + self.requests[key].append(now) + return True + return False + + def left_time(self, key: Any) -> float: + """计算距离下次可调用还需等待的时间""" + if self.requests[key]: + return max(0.0, self.requests[key][0] + self.time_window - time.time()) + return 0.0 + + +class ConcurrencyLimiter: + """ + 一个基于 asyncio.Semaphore 的并发限制器。 + """ + + def __init__(self, max_concurrent: int): + self._semaphores: dict[Any, asyncio.Semaphore] = {} + self.max_concurrent = max_concurrent + self._active_tasks: dict[Any, int] = defaultdict(int) + + def _get_semaphore(self, key: Any) -> asyncio.Semaphore: + if key not in self._semaphores: + self._semaphores[key] = asyncio.Semaphore(self.max_concurrent) + return self._semaphores[key] + + async def acquire(self, key: Any): + """获取一个信号量,如果达到并发上限则会阻塞等待。""" + semaphore = self._get_semaphore(key) + await semaphore.acquire() + self._active_tasks[key] += 1 + + def release(self, key: Any): + """释放一个信号量。""" + if key in self._semaphores: + if self._active_tasks[key] > 0: + self._semaphores[key].release() + self._active_tasks[key] -= 1 + else: + import logging + + logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。") diff --git a/zhenxun/utils/message.py b/zhenxun/utils/message.py index 927b050c..5fec2213 100644 --- a/zhenxun/utils/message.py +++ b/zhenxun/utils/message.py @@ -49,11 +49,14 @@ class Config(BaseModel): class MessageUtils: @classmethod - def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]: + def __build_message( + cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None + ) -> list[Text | Image]: """构造消息 参数: msg_list: 消息列表 + format_args: 用于格式化字符串的参数字典. 返回: list[Text | Text]: 构造完成的消息列表 @@ -65,7 +68,15 @@ class MessageUtils: if msg.startswith("base64://"): message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:])))) else: - message_list.append(Text(msg)) + formatted_msg = msg + if format_args: + try: + formatted_msg = msg.format_map(format_args) + except (KeyError, IndexError) as e: + logger.debug( + f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。" + ) + message_list.append(Text(formatted_msg)) elif isinstance(msg, int | float): message_list.append(Text(str(msg))) elif isinstance(msg, Path): @@ -90,12 +101,15 @@ class MessageUtils: @classmethod def build_message( - cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]] + cls, + msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]], + format_args: dict | None = None, ) -> UniMessage: """构造消息 参数: msg_list: 消息列表 + format_args: 用于格式化字符串的参数字典. 返回: UniMessage: 构造完成的消息列表 @@ -105,7 +119,7 @@ class MessageUtils: msg_list = [msg_list] for m in msg_list: _data = m if isinstance(m, list) else [m] - message_list += cls.__build_message(_data) # type: ignore + message_list += cls.__build_message(_data, format_args) return UniMessage(message_list) @classmethod diff --git a/zhenxun/utils/time_utils.py b/zhenxun/utils/time_utils.py new file mode 100644 index 00000000..f478625d --- /dev/null +++ b/zhenxun/utils/time_utils.py @@ -0,0 +1,91 @@ +from datetime import date, datetime +import re + +import pytz + + +class TimeUtils: + DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai") + + @classmethod + def get_day_start(cls, target_date: date | datetime | None = None) -> datetime: + """获取某天的0点时间 + + 返回: + datetime: 今天某天的0点时间 + """ + if not target_date: + target_date = datetime.now(cls.DEFAULT_TIMEZONE) + + if isinstance(target_date, datetime) and target_date.tzinfo is None: + target_date = cls.DEFAULT_TIMEZONE.localize(target_date) + + return ( + target_date.replace(hour=0, minute=0, second=0, microsecond=0) + if isinstance(target_date, datetime) + else datetime.combine( + target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE + ) + ) + + @classmethod + def is_valid_date(cls, date_text: str, separator: str = "-") -> bool: + """日期是否合法 + + 参数: + date_text: 日期 + separator: 分隔符 + + 返回: + bool: 日期是否合法 + """ + try: + datetime.strptime(date_text, f"%Y{separator}%m{separator}%d") + return True + except ValueError: + return False + + @classmethod + def parse_time_string(cls, time_str: str) -> int: + """ + 将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数。 + """ + time_str = time_str.lower().strip() + match = re.match(r"^(\d+)([smh])$", time_str) + if not match: + raise ValueError( + f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。" + ) + + value, unit = int(match.group(1)), match.group(2) + + if unit == "s": + return value + if unit == "m": + return value * 60 + if unit == "h": + return value * 3600 + return 0 + + @classmethod + def format_duration(cls, seconds: float) -> str: + """ + 将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒") + """ + seconds = round(seconds, 1) + if seconds < 0.1: + return "不到1秒" + if seconds < 60: + return f"{seconds}秒" + + minutes, sec_remainder = divmod(int(seconds), 60) + + if minutes < 60: + if sec_remainder == 0: + return f"{minutes}分钟" + return f"{minutes}分钟{sec_remainder}秒" + + hours, rem_minutes = divmod(minutes, 60) + if rem_minutes == 0: + return f"{hours}小时" + return f"{hours}小时{rem_minutes}分钟" diff --git a/zhenxun/utils/utils.py b/zhenxun/utils/utils.py index 44dcd672..fc6b4096 100644 --- a/zhenxun/utils/utils.py +++ b/zhenxun/utils/utils.py @@ -1,19 +1,19 @@ -from collections import defaultdict from dataclasses import dataclass -from datetime import date, datetime +from datetime import datetime import os from pathlib import Path import time -from typing import Any, ClassVar +from typing import ClassVar import httpx from nonebot_plugin_uninfo import Uninfo import pypinyin -import pytz from zhenxun.configs.config import Config from zhenxun.services.log import logger +from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401 + @dataclass class EntityIDs: @@ -64,78 +64,6 @@ class ResourceDirManager: cls.__tree_append(path, deep) -class CountLimiter: - """ - 每日调用命令次数限制 - """ - - tz = pytz.timezone("Asia/Shanghai") - - def __init__(self, max_num): - self.today = -1 - self.count = defaultdict(int) - self.max = max_num - - def check(self, key) -> bool: - day = datetime.now(self.tz).day - if day != self.today: - self.today = day - self.count.clear() - return self.count[key] < self.max - - def get_num(self, key): - return self.count[key] - - def increase(self, key, num=1): - self.count[key] += num - - def reset(self, key): - self.count[key] = 0 - - -class UserBlockLimiter: - """ - 检测用户是否正在调用命令 - """ - - def __init__(self): - self.flag_data = defaultdict(bool) - self.time = time.time() - - def set_true(self, key: Any): - self.time = time.time() - self.flag_data[key] = True - - def set_false(self, key: Any): - self.flag_data[key] = False - - def check(self, key: Any) -> bool: - if time.time() - self.time > 30: - self.set_false(key) - return not self.flag_data[key] - - -class FreqLimiter: - """ - 命令冷却,检测用户是否处于冷却状态 - """ - - def __init__(self, default_cd_seconds: int): - self.next_time = defaultdict(float) - self.default_cd = default_cd_seconds - - def check(self, key: Any) -> bool: - return time.time() >= self.next_time[key] - - def start_cd(self, key: Any, cd_time: int = 0): - self.next_time[key] = time.time() + ( - cd_time if cd_time > 0 else self.default_cd - ) - - def left_time(self, key: Any) -> float: - return self.next_time[key] - time.time() - - def cn2py(word: str) -> str: """将字符串转化为拼音 @@ -277,20 +205,3 @@ def is_number(text: str) -> bool: return True except ValueError: return False - - -class TimeUtils: - @classmethod - def get_day_start(cls, target_date: date | datetime | None = None) -> datetime: - """获取某天的0点时间 - - 返回: - datetime: 今天某天的0点时间 - """ - if not target_date: - target_date = datetime.now() - return ( - target_date.replace(hour=0, minute=0, second=0, microsecond=0) - if isinstance(target_date, datetime) - else datetime.combine(target_date, datetime.min.time()) - )