feat(limit, message): 引入声明式限流系统并增强消息格式化功能 (#1978)

- 新增 Cooldown、RateLimit、ConcurrencyLimit 三种限流依赖
- MessageUtils 支持动态格式化字符串 (format_args 参数)
- 插件CD限制消息显示精确剩余时间

- 重构限流逻辑至 utils/limiters.py,新增时间工具模块
- 整合时间工具函数并优化时区处理
- 新增 limiter_hook 自动释放资源,CooldownError 优化异常处理

- 冷却提示从固定文本改为动态显示剩余时间
- 示例:总结功能冷却中,请等待 1分30秒 后再试~

Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
This commit is contained in:
Rumio 2025-07-15 17:13:33 +08:00 committed by GitHub
parent d218c569d4
commit b993450a23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 463 additions and 107 deletions

View File

@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import (
CountLimiter,
FreqLimiter,
UserBlockLimiter,
get_entity_ids,
)
from zhenxun.utils.time_utils import TimeUtils
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
@ -273,9 +270,16 @@ class LimitManager:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
format_kwargs = {}
if isinstance(limiter, FreqLimiter):
left_time = limiter.left_time(key_type)
cd_str = TimeUtils.format_duration(left_time)
format_kwargs = {"cd": cd_str}
try:
await asyncio.wait_for(
MessageUtils.build_message(limit.result).send(),
MessageUtils.build_message(
limit.result, format_args=format_kwargs
).send(),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:

View File

@ -0,0 +1,15 @@
from nonebot.matcher import Matcher
from nonebot.message import run_postprocessor
from zhenxun.utils.limiters import ConcurrencyLimiter
@run_postprocessor
async def _concurrency_release_hook(matcher: Matcher):
"""
后处理器在事件处理结束后释放并发限制的信号量
"""
if concurrency_info := matcher.state.get("_concurrency_limiter_info"):
limiter: ConcurrencyLimiter = concurrency_info["limiter"]
key = concurrency_info["key"]
limiter.release(key)

View File

@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
from zhenxun.utils.echart_utils.models import Barh
from zhenxun.utils.enum import PluginType
from zhenxun.utils.image_utils import BuildImage
from zhenxun.utils.utils import TimeUtils
from zhenxun.utils.time_utils import TimeUtils
class StatisticsManage:

View File

@ -1,13 +1,181 @@
from typing import Any
from typing import Any, Literal
from nonebot.adapters import Bot, Event
from nonebot.internal.params import Depends
from nonebot.matcher import Matcher
from nonebot.params import Command
from nonebot.permission import SUPERUSER
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.time_utils import TimeUtils
_coolers: dict[str, FreqLimiter] = {}
_rate_limiters: dict[str, RateLimiter] = {}
_concurrency_limiters: dict[str, ConcurrencyLimiter] = {}
def _create_limiter_dependency(
limiter_class: type,
limiter_storage: dict,
limiter_init_args: dict[str, Any],
scope: Literal["user", "group", "global"],
prompt: str,
**kwargs,
):
"""
一个高阶函数用于创建不同类型的限制器依赖
参数:
limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.).
limiter_storage: 用于存储限制器实例的字典.
limiter_init_args: 限制器类的初始化参数.
scope: 限制作用域.
prompt: 触发限制时的提示信息.
**kwargs: 传递给特定限制器逻辑的额外参数.
"""
async def dependency(
matcher: Matcher, session: EventSession, bot: Bot, event: Event
) -> bool:
if await SUPERUSER(bot, event):
return True
handler_id = (
f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}"
)
key: str | None = None
if scope == "user":
key = session.id1
elif scope == "group":
key = session.id3 or session.id2 or session.id1
elif scope == "global":
key = f"global_{handler_id}"
if not key:
return True
if handler_id not in limiter_storage:
limiter_storage[handler_id] = limiter_class(**limiter_init_args)
limiter = limiter_storage[handler_id]
if isinstance(limiter, ConcurrencyLimiter):
await limiter.acquire(key)
matcher.state["_concurrency_limiter_info"] = {
"limiter": limiter,
"key": key,
}
return True
else:
if limiter.check(key):
if isinstance(limiter, FreqLimiter):
limiter.start_cd(
key, kwargs.get("duration_sec", limiter.default_cd)
)
return True
else:
left_time = limiter.left_time(key)
format_kwargs = {
"cd_str": TimeUtils.format_duration(left_time),
**(kwargs.get("prompt_format_kwargs", {})),
}
message = prompt.format(**format_kwargs)
await matcher.finish(message)
return Depends(dependency)
def Cooldown(
duration: str,
*,
scope: Literal["user", "group", "global"] = "user",
prompt: str = "操作过于频繁,请等待 {cd_str}",
) -> bool:
"""声明式冷却检查依赖,限制用户操作频率
参数:
duration: 冷却时间字符串 (e.g., "30s", "10m", "1h")
scope: 冷却作用域
prompt: 自定义的冷却提示消息可使用 {cd_str} 占位符
返回:
bool: 是否允许执行
"""
try:
parsed_seconds = TimeUtils.parse_time_string(duration)
except ValueError as e:
raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}")
return _create_limiter_dependency(
limiter_class=FreqLimiter,
limiter_storage=_coolers,
limiter_init_args={"default_cd_seconds": parsed_seconds},
scope=scope,
prompt=prompt,
duration_sec=parsed_seconds,
)
def RateLimit(
count: int,
duration: str,
*,
scope: Literal["user", "group", "global"] = "user",
prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}",
) -> bool:
"""声明式速率限制依赖,在指定时间窗口内限制操作次数
参数:
count: 在时间窗口内允许的最大调用次数
duration: 时间窗口字符串 (e.g., "1m", "1h")
scope: 限制作用域
prompt: 自定义的提示消息可使用 {cd_str}, {duration_str}, {limit} 占位符
返回:
bool: 是否允许执行
"""
try:
parsed_seconds = TimeUtils.parse_time_string(duration)
except ValueError as e:
raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}")
return _create_limiter_dependency(
limiter_class=RateLimiter,
limiter_storage=_rate_limiters,
limiter_init_args={"max_calls": count, "time_window": parsed_seconds},
scope=scope,
prompt=prompt,
prompt_format_kwargs={"duration_str": duration, "limit": count},
)
def ConcurrencyLimit(
count: int,
*,
scope: Literal["user", "group", "global"] = "global",
prompt: str | None = "当前功能繁忙,请稍后再试...",
) -> bool:
"""声明式并发数限制依赖,控制某个功能同时执行的实例数量
参数:
count: 最大并发数
scope: 限制作用域
prompt: 提示消息暂未使用主要用于未来扩展超时功能
返回:
bool: 是否允许执行
"""
return _create_limiter_dependency(
limiter_class=ConcurrencyLimiter,
limiter_storage=_concurrency_limiters,
limiter_init_args={"max_concurrent": count},
scope=scope,
prompt=prompt or "",
)
def CheckUg(check_user: bool = True, check_group: bool = True):
@ -75,7 +243,6 @@ def GetConfig(
if module_:
value = Config.get_config(module_, config, default_value)
if value is None and prompt:
# await matcher.finish(prompt or f"配置项 {config} 未填写!")
await matcher.finish(prompt)
return value

View File

@ -1,3 +1,17 @@
from nonebot.exception import IgnoredException
class CooldownError(IgnoredException):
"""
冷却异常用于在冷却时中断事件处理
继承自 IgnoredException不会在控制台留下错误堆栈
"""
def __init__(self, message: str):
self.message = message
super().__init__(message)
class HookPriorityException(BaseException):
"""
钩子优先级异常

140
zhenxun/utils/limiters.py Normal file
View File

@ -0,0 +1,140 @@
import asyncio
from collections import defaultdict, deque
import time
from typing import Any
class FreqLimiter:
"""
命令冷却检测用户是否处于冷却状态
"""
def __init__(self, default_cd_seconds: int):
self.next_time: dict[Any, float] = defaultdict(float)
self.default_cd = default_cd_seconds
def check(self, key: Any) -> bool:
return time.time() >= self.next_time[key]
def start_cd(self, key: Any, cd_time: int = 0):
self.next_time[key] = time.time() + (
cd_time if cd_time > 0 else self.default_cd
)
def left_time(self, key: Any) -> float:
return max(0.0, self.next_time[key] - time.time())
class CountLimiter:
"""
每日调用命令次数限制
"""
tz = None
def __init__(self, max_num: int):
self.today = -1
self.count: dict[Any, int] = defaultdict(int)
self.max = max_num
def check(self, key: Any) -> bool:
import datetime
day = datetime.datetime.now().day
if day != self.today:
self.today = day
self.count.clear()
return self.count[key] < self.max
def get_num(self, key: Any) -> int:
return self.count[key]
def increase(self, key: Any, num: int = 1):
self.count[key] += num
def reset(self, key: Any):
self.count[key] = 0
class UserBlockLimiter:
"""
检测用户是否正在调用命令 (简单阻塞锁)
"""
def __init__(self):
self.flag_data: dict[Any, bool] = defaultdict(bool)
self.time: dict[Any, float] = defaultdict(float)
def set_true(self, key: Any):
self.time[key] = time.time()
self.flag_data[key] = True
def set_false(self, key: Any):
self.flag_data[key] = False
def check(self, key: Any) -> bool:
if self.flag_data[key] and time.time() - self.time[key] > 30:
self.set_false(key)
return not self.flag_data[key]
class RateLimiter:
"""
一个简单的基于时间窗口的速率限制器
"""
def __init__(self, max_calls: int, time_window: int):
self.requests: dict[Any, deque[float]] = defaultdict(deque)
self.max_calls = max_calls
self.time_window = time_window
def check(self, key: Any) -> bool:
"""检查是否超出速率限制。如果未超出,则记录本次调用。"""
now = time.time()
while self.requests[key] and self.requests[key][0] <= now - self.time_window:
self.requests[key].popleft()
if len(self.requests[key]) < self.max_calls:
self.requests[key].append(now)
return True
return False
def left_time(self, key: Any) -> float:
"""计算距离下次可调用还需等待的时间"""
if self.requests[key]:
return max(0.0, self.requests[key][0] + self.time_window - time.time())
return 0.0
class ConcurrencyLimiter:
"""
一个基于 asyncio.Semaphore 的并发限制器
"""
def __init__(self, max_concurrent: int):
self._semaphores: dict[Any, asyncio.Semaphore] = {}
self.max_concurrent = max_concurrent
self._active_tasks: dict[Any, int] = defaultdict(int)
def _get_semaphore(self, key: Any) -> asyncio.Semaphore:
if key not in self._semaphores:
self._semaphores[key] = asyncio.Semaphore(self.max_concurrent)
return self._semaphores[key]
async def acquire(self, key: Any):
"""获取一个信号量,如果达到并发上限则会阻塞等待。"""
semaphore = self._get_semaphore(key)
await semaphore.acquire()
self._active_tasks[key] += 1
def release(self, key: Any):
"""释放一个信号量。"""
if key in self._semaphores:
if self._active_tasks[key] > 0:
self._semaphores[key].release()
self._active_tasks[key] -= 1
else:
import logging
logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。")

View File

@ -49,11 +49,14 @@ class Config(BaseModel):
class MessageUtils:
@classmethod
def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]:
def __build_message(
cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None
) -> list[Text | Image]:
"""构造消息
参数:
msg_list: 消息列表
format_args: 用于格式化字符串的参数字典.
返回:
list[Text | Text]: 构造完成的消息列表
@ -65,7 +68,15 @@ class MessageUtils:
if msg.startswith("base64://"):
message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:]))))
else:
message_list.append(Text(msg))
formatted_msg = msg
if format_args:
try:
formatted_msg = msg.format_map(format_args)
except (KeyError, IndexError) as e:
logger.debug(
f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。"
)
message_list.append(Text(formatted_msg))
elif isinstance(msg, int | float):
message_list.append(Text(str(msg)))
elif isinstance(msg, Path):
@ -90,12 +101,15 @@ class MessageUtils:
@classmethod
def build_message(
cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]]
cls,
msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]],
format_args: dict | None = None,
) -> UniMessage:
"""构造消息
参数:
msg_list: 消息列表
format_args: 用于格式化字符串的参数字典.
返回:
UniMessage: 构造完成的消息列表
@ -105,7 +119,7 @@ class MessageUtils:
msg_list = [msg_list]
for m in msg_list:
_data = m if isinstance(m, list) else [m]
message_list += cls.__build_message(_data) # type: ignore
message_list += cls.__build_message(_data, format_args)
return UniMessage(message_list)
@classmethod

View File

@ -0,0 +1,91 @@
from datetime import date, datetime
import re
import pytz
class TimeUtils:
DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai")
@classmethod
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
"""获取某天的0点时间
返回:
datetime: 今天某天的0点时间
"""
if not target_date:
target_date = datetime.now(cls.DEFAULT_TIMEZONE)
if isinstance(target_date, datetime) and target_date.tzinfo is None:
target_date = cls.DEFAULT_TIMEZONE.localize(target_date)
return (
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
if isinstance(target_date, datetime)
else datetime.combine(
target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE
)
)
@classmethod
def is_valid_date(cls, date_text: str, separator: str = "-") -> bool:
"""日期是否合法
参数:
date_text: 日期
separator: 分隔符
返回:
bool: 日期是否合法
"""
try:
datetime.strptime(date_text, f"%Y{separator}%m{separator}%d")
return True
except ValueError:
return False
@classmethod
def parse_time_string(cls, time_str: str) -> int:
"""
将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数
"""
time_str = time_str.lower().strip()
match = re.match(r"^(\d+)([smh])$", time_str)
if not match:
raise ValueError(
f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。"
)
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return value
if unit == "m":
return value * 60
if unit == "h":
return value * 3600
return 0
@classmethod
def format_duration(cls, seconds: float) -> str:
"""
将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒")
"""
seconds = round(seconds, 1)
if seconds < 0.1:
return "不到1秒"
if seconds < 60:
return f"{seconds}"
minutes, sec_remainder = divmod(int(seconds), 60)
if minutes < 60:
if sec_remainder == 0:
return f"{minutes}分钟"
return f"{minutes}分钟{sec_remainder}"
hours, rem_minutes = divmod(minutes, 60)
if rem_minutes == 0:
return f"{hours}小时"
return f"{hours}小时{rem_minutes}分钟"

View File

@ -1,19 +1,19 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import date, datetime
from datetime import datetime
import os
from pathlib import Path
import time
from typing import Any, ClassVar
from typing import ClassVar
import httpx
from nonebot_plugin_uninfo import Uninfo
import pypinyin
import pytz
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401
@dataclass
class EntityIDs:
@ -64,78 +64,6 @@ class ResourceDirManager:
cls.__tree_append(path, deep)
class CountLimiter:
"""
每日调用命令次数限制
"""
tz = pytz.timezone("Asia/Shanghai")
def __init__(self, max_num):
self.today = -1
self.count = defaultdict(int)
self.max = max_num
def check(self, key) -> bool:
day = datetime.now(self.tz).day
if day != self.today:
self.today = day
self.count.clear()
return self.count[key] < self.max
def get_num(self, key):
return self.count[key]
def increase(self, key, num=1):
self.count[key] += num
def reset(self, key):
self.count[key] = 0
class UserBlockLimiter:
"""
检测用户是否正在调用命令
"""
def __init__(self):
self.flag_data = defaultdict(bool)
self.time = time.time()
def set_true(self, key: Any):
self.time = time.time()
self.flag_data[key] = True
def set_false(self, key: Any):
self.flag_data[key] = False
def check(self, key: Any) -> bool:
if time.time() - self.time > 30:
self.set_false(key)
return not self.flag_data[key]
class FreqLimiter:
"""
命令冷却检测用户是否处于冷却状态
"""
def __init__(self, default_cd_seconds: int):
self.next_time = defaultdict(float)
self.default_cd = default_cd_seconds
def check(self, key: Any) -> bool:
return time.time() >= self.next_time[key]
def start_cd(self, key: Any, cd_time: int = 0):
self.next_time[key] = time.time() + (
cd_time if cd_time > 0 else self.default_cd
)
def left_time(self, key: Any) -> float:
return self.next_time[key] - time.time()
def cn2py(word: str) -> str:
"""将字符串转化为拼音
@ -277,20 +205,3 @@ def is_number(text: str) -> bool:
return True
except ValueError:
return False
class TimeUtils:
@classmethod
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
"""获取某天的0点时间
返回:
datetime: 今天某天的0点时间
"""
if not target_date:
target_date = datetime.now()
return (
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
if isinstance(target_date, datetime)
else datetime.combine(target_date, datetime.min.time())
)