mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
✨ feat(limit, message): 引入声明式限流系统并增强消息格式化功能 (#1978)
- 新增 Cooldown、RateLimit、ConcurrencyLimit 三种限流依赖 - MessageUtils 支持动态格式化字符串 (format_args 参数) - 插件CD限制消息显示精确剩余时间 - 重构限流逻辑至 utils/limiters.py,新增时间工具模块 - 整合时间工具函数并优化时区处理 - 新增 limiter_hook 自动释放资源,CooldownError 优化异常处理 - 冷却提示从固定文本改为动态显示剩余时间 - 示例:总结功能冷却中,请等待 1分30秒 后再试~ Co-authored-by: webjoin111 <455457521@qq.com> Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
This commit is contained in:
parent
d218c569d4
commit
b993450a23
@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||
from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import (
|
||||
CountLimiter,
|
||||
FreqLimiter,
|
||||
UserBlockLimiter,
|
||||
get_entity_ids,
|
||||
)
|
||||
from zhenxun.utils.time_utils import TimeUtils
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||
from .exception import SkipPluginException
|
||||
@ -273,9 +270,16 @@ class LimitManager:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
format_kwargs = {}
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
left_time = limiter.left_time(key_type)
|
||||
cd_str = TimeUtils.format_duration(left_time)
|
||||
format_kwargs = {"cd": cd_str}
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
MessageUtils.build_message(limit.result).send(),
|
||||
MessageUtils.build_message(
|
||||
limit.result, format_args=format_kwargs
|
||||
).send(),
|
||||
timeout=DB_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
15
zhenxun/builtin_plugins/hooks/limiter_hook.py
Normal file
15
zhenxun/builtin_plugins/hooks/limiter_hook.py
Normal file
@ -0,0 +1,15 @@
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_postprocessor
|
||||
|
||||
from zhenxun.utils.limiters import ConcurrencyLimiter
|
||||
|
||||
|
||||
@run_postprocessor
|
||||
async def _concurrency_release_hook(matcher: Matcher):
|
||||
"""
|
||||
后处理器:在事件处理结束后,释放并发限制的信号量。
|
||||
"""
|
||||
if concurrency_info := matcher.state.get("_concurrency_limiter_info"):
|
||||
limiter: ConcurrencyLimiter = concurrency_info["limiter"]
|
||||
key = concurrency_info["key"]
|
||||
limiter.release(key)
|
||||
@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
|
||||
from zhenxun.utils.echart_utils.models import Barh
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.image_utils import BuildImage
|
||||
from zhenxun.utils.utils import TimeUtils
|
||||
from zhenxun.utils.time_utils import TimeUtils
|
||||
|
||||
|
||||
class StatisticsManage:
|
||||
|
||||
@ -1,13 +1,181 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.internal.params import Depends
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Command
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.time_utils import TimeUtils
|
||||
|
||||
_coolers: dict[str, FreqLimiter] = {}
|
||||
_rate_limiters: dict[str, RateLimiter] = {}
|
||||
_concurrency_limiters: dict[str, ConcurrencyLimiter] = {}
|
||||
|
||||
|
||||
def _create_limiter_dependency(
|
||||
limiter_class: type,
|
||||
limiter_storage: dict,
|
||||
limiter_init_args: dict[str, Any],
|
||||
scope: Literal["user", "group", "global"],
|
||||
prompt: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
一个高阶函数,用于创建不同类型的限制器依赖。
|
||||
|
||||
参数:
|
||||
limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.).
|
||||
limiter_storage: 用于存储限制器实例的字典.
|
||||
limiter_init_args: 限制器类的初始化参数.
|
||||
scope: 限制作用域.
|
||||
prompt: 触发限制时的提示信息.
|
||||
**kwargs: 传递给特定限制器逻辑的额外参数.
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
matcher: Matcher, session: EventSession, bot: Bot, event: Event
|
||||
) -> bool:
|
||||
if await SUPERUSER(bot, event):
|
||||
return True
|
||||
|
||||
handler_id = (
|
||||
f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}"
|
||||
)
|
||||
|
||||
key: str | None = None
|
||||
if scope == "user":
|
||||
key = session.id1
|
||||
elif scope == "group":
|
||||
key = session.id3 or session.id2 or session.id1
|
||||
elif scope == "global":
|
||||
key = f"global_{handler_id}"
|
||||
|
||||
if not key:
|
||||
return True
|
||||
|
||||
if handler_id not in limiter_storage:
|
||||
limiter_storage[handler_id] = limiter_class(**limiter_init_args)
|
||||
limiter = limiter_storage[handler_id]
|
||||
|
||||
if isinstance(limiter, ConcurrencyLimiter):
|
||||
await limiter.acquire(key)
|
||||
matcher.state["_concurrency_limiter_info"] = {
|
||||
"limiter": limiter,
|
||||
"key": key,
|
||||
}
|
||||
return True
|
||||
else:
|
||||
if limiter.check(key):
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(
|
||||
key, kwargs.get("duration_sec", limiter.default_cd)
|
||||
)
|
||||
return True
|
||||
else:
|
||||
left_time = limiter.left_time(key)
|
||||
format_kwargs = {
|
||||
"cd_str": TimeUtils.format_duration(left_time),
|
||||
**(kwargs.get("prompt_format_kwargs", {})),
|
||||
}
|
||||
message = prompt.format(**format_kwargs)
|
||||
await matcher.finish(message)
|
||||
|
||||
return Depends(dependency)
|
||||
|
||||
|
||||
def Cooldown(
|
||||
duration: str,
|
||||
*,
|
||||
scope: Literal["user", "group", "global"] = "user",
|
||||
prompt: str = "操作过于频繁,请等待 {cd_str}",
|
||||
) -> bool:
|
||||
"""声明式冷却检查依赖,限制用户操作频率
|
||||
|
||||
参数:
|
||||
duration: 冷却时间字符串 (e.g., "30s", "10m", "1h")
|
||||
scope: 冷却作用域
|
||||
prompt: 自定义的冷却提示消息,可使用 {cd_str} 占位符
|
||||
|
||||
返回:
|
||||
bool: 是否允许执行
|
||||
"""
|
||||
try:
|
||||
parsed_seconds = TimeUtils.parse_time_string(duration)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}")
|
||||
|
||||
return _create_limiter_dependency(
|
||||
limiter_class=FreqLimiter,
|
||||
limiter_storage=_coolers,
|
||||
limiter_init_args={"default_cd_seconds": parsed_seconds},
|
||||
scope=scope,
|
||||
prompt=prompt,
|
||||
duration_sec=parsed_seconds,
|
||||
)
|
||||
|
||||
|
||||
def RateLimit(
|
||||
count: int,
|
||||
duration: str,
|
||||
*,
|
||||
scope: Literal["user", "group", "global"] = "user",
|
||||
prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}",
|
||||
) -> bool:
|
||||
"""声明式速率限制依赖,在指定时间窗口内限制操作次数
|
||||
|
||||
参数:
|
||||
count: 在时间窗口内允许的最大调用次数
|
||||
duration: 时间窗口字符串 (e.g., "1m", "1h")
|
||||
scope: 限制作用域
|
||||
prompt: 自定义的提示消息,可使用 {cd_str}, {duration_str}, {limit} 占位符
|
||||
|
||||
返回:
|
||||
bool: 是否允许执行
|
||||
"""
|
||||
try:
|
||||
parsed_seconds = TimeUtils.parse_time_string(duration)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}")
|
||||
|
||||
return _create_limiter_dependency(
|
||||
limiter_class=RateLimiter,
|
||||
limiter_storage=_rate_limiters,
|
||||
limiter_init_args={"max_calls": count, "time_window": parsed_seconds},
|
||||
scope=scope,
|
||||
prompt=prompt,
|
||||
prompt_format_kwargs={"duration_str": duration, "limit": count},
|
||||
)
|
||||
|
||||
|
||||
def ConcurrencyLimit(
|
||||
count: int,
|
||||
*,
|
||||
scope: Literal["user", "group", "global"] = "global",
|
||||
prompt: str | None = "当前功能繁忙,请稍后再试...",
|
||||
) -> bool:
|
||||
"""声明式并发数限制依赖,控制某个功能同时执行的实例数量
|
||||
|
||||
参数:
|
||||
count: 最大并发数
|
||||
scope: 限制作用域
|
||||
prompt: 提示消息(暂未使用,主要用于未来扩展超时功能)
|
||||
|
||||
返回:
|
||||
bool: 是否允许执行
|
||||
"""
|
||||
return _create_limiter_dependency(
|
||||
limiter_class=ConcurrencyLimiter,
|
||||
limiter_storage=_concurrency_limiters,
|
||||
limiter_init_args={"max_concurrent": count},
|
||||
scope=scope,
|
||||
prompt=prompt or "",
|
||||
)
|
||||
|
||||
|
||||
def CheckUg(check_user: bool = True, check_group: bool = True):
|
||||
@ -75,7 +243,6 @@ def GetConfig(
|
||||
if module_:
|
||||
value = Config.get_config(module_, config, default_value)
|
||||
if value is None and prompt:
|
||||
# await matcher.finish(prompt or f"配置项 {config} 未填写!")
|
||||
await matcher.finish(prompt)
|
||||
return value
|
||||
|
||||
|
||||
@ -1,3 +1,17 @@
|
||||
from nonebot.exception import IgnoredException
|
||||
|
||||
|
||||
class CooldownError(IgnoredException):
|
||||
"""
|
||||
冷却异常,用于在冷却时中断事件处理。
|
||||
继承自 IgnoredException,不会在控制台留下错误堆栈。
|
||||
"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class HookPriorityException(BaseException):
|
||||
"""
|
||||
钩子优先级异常
|
||||
|
||||
140
zhenxun/utils/limiters.py
Normal file
140
zhenxun/utils/limiters.py
Normal file
@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
|
||||
class FreqLimiter:
|
||||
"""
|
||||
命令冷却,检测用户是否处于冷却状态
|
||||
"""
|
||||
|
||||
def __init__(self, default_cd_seconds: int):
|
||||
self.next_time: dict[Any, float] = defaultdict(float)
|
||||
self.default_cd = default_cd_seconds
|
||||
|
||||
def check(self, key: Any) -> bool:
|
||||
return time.time() >= self.next_time[key]
|
||||
|
||||
def start_cd(self, key: Any, cd_time: int = 0):
|
||||
self.next_time[key] = time.time() + (
|
||||
cd_time if cd_time > 0 else self.default_cd
|
||||
)
|
||||
|
||||
def left_time(self, key: Any) -> float:
|
||||
return max(0.0, self.next_time[key] - time.time())
|
||||
|
||||
|
||||
class CountLimiter:
|
||||
"""
|
||||
每日调用命令次数限制
|
||||
"""
|
||||
|
||||
tz = None
|
||||
|
||||
def __init__(self, max_num: int):
|
||||
self.today = -1
|
||||
self.count: dict[Any, int] = defaultdict(int)
|
||||
self.max = max_num
|
||||
|
||||
def check(self, key: Any) -> bool:
|
||||
import datetime
|
||||
|
||||
day = datetime.datetime.now().day
|
||||
if day != self.today:
|
||||
self.today = day
|
||||
self.count.clear()
|
||||
return self.count[key] < self.max
|
||||
|
||||
def get_num(self, key: Any) -> int:
|
||||
return self.count[key]
|
||||
|
||||
def increase(self, key: Any, num: int = 1):
|
||||
self.count[key] += num
|
||||
|
||||
def reset(self, key: Any):
|
||||
self.count[key] = 0
|
||||
|
||||
|
||||
class UserBlockLimiter:
|
||||
"""
|
||||
检测用户是否正在调用命令 (简单阻塞锁)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.flag_data: dict[Any, bool] = defaultdict(bool)
|
||||
self.time: dict[Any, float] = defaultdict(float)
|
||||
|
||||
def set_true(self, key: Any):
|
||||
self.time[key] = time.time()
|
||||
self.flag_data[key] = True
|
||||
|
||||
def set_false(self, key: Any):
|
||||
self.flag_data[key] = False
|
||||
|
||||
def check(self, key: Any) -> bool:
|
||||
if self.flag_data[key] and time.time() - self.time[key] > 30:
|
||||
self.set_false(key)
|
||||
return not self.flag_data[key]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
一个简单的基于时间窗口的速率限制器。
|
||||
"""
|
||||
|
||||
def __init__(self, max_calls: int, time_window: int):
|
||||
self.requests: dict[Any, deque[float]] = defaultdict(deque)
|
||||
self.max_calls = max_calls
|
||||
self.time_window = time_window
|
||||
|
||||
def check(self, key: Any) -> bool:
|
||||
"""检查是否超出速率限制。如果未超出,则记录本次调用。"""
|
||||
now = time.time()
|
||||
|
||||
while self.requests[key] and self.requests[key][0] <= now - self.time_window:
|
||||
self.requests[key].popleft()
|
||||
|
||||
if len(self.requests[key]) < self.max_calls:
|
||||
self.requests[key].append(now)
|
||||
return True
|
||||
return False
|
||||
|
||||
def left_time(self, key: Any) -> float:
|
||||
"""计算距离下次可调用还需等待的时间"""
|
||||
if self.requests[key]:
|
||||
return max(0.0, self.requests[key][0] + self.time_window - time.time())
|
||||
return 0.0
|
||||
|
||||
|
||||
class ConcurrencyLimiter:
|
||||
"""
|
||||
一个基于 asyncio.Semaphore 的并发限制器。
|
||||
"""
|
||||
|
||||
def __init__(self, max_concurrent: int):
|
||||
self._semaphores: dict[Any, asyncio.Semaphore] = {}
|
||||
self.max_concurrent = max_concurrent
|
||||
self._active_tasks: dict[Any, int] = defaultdict(int)
|
||||
|
||||
def _get_semaphore(self, key: Any) -> asyncio.Semaphore:
|
||||
if key not in self._semaphores:
|
||||
self._semaphores[key] = asyncio.Semaphore(self.max_concurrent)
|
||||
return self._semaphores[key]
|
||||
|
||||
async def acquire(self, key: Any):
|
||||
"""获取一个信号量,如果达到并发上限则会阻塞等待。"""
|
||||
semaphore = self._get_semaphore(key)
|
||||
await semaphore.acquire()
|
||||
self._active_tasks[key] += 1
|
||||
|
||||
def release(self, key: Any):
|
||||
"""释放一个信号量。"""
|
||||
if key in self._semaphores:
|
||||
if self._active_tasks[key] > 0:
|
||||
self._semaphores[key].release()
|
||||
self._active_tasks[key] -= 1
|
||||
else:
|
||||
import logging
|
||||
|
||||
logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。")
|
||||
@ -49,11 +49,14 @@ class Config(BaseModel):
|
||||
|
||||
class MessageUtils:
|
||||
@classmethod
|
||||
def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]:
|
||||
def __build_message(
|
||||
cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None
|
||||
) -> list[Text | Image]:
|
||||
"""构造消息
|
||||
|
||||
参数:
|
||||
msg_list: 消息列表
|
||||
format_args: 用于格式化字符串的参数字典.
|
||||
|
||||
返回:
|
||||
list[Text | Text]: 构造完成的消息列表
|
||||
@ -65,7 +68,15 @@ class MessageUtils:
|
||||
if msg.startswith("base64://"):
|
||||
message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:]))))
|
||||
else:
|
||||
message_list.append(Text(msg))
|
||||
formatted_msg = msg
|
||||
if format_args:
|
||||
try:
|
||||
formatted_msg = msg.format_map(format_args)
|
||||
except (KeyError, IndexError) as e:
|
||||
logger.debug(
|
||||
f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。"
|
||||
)
|
||||
message_list.append(Text(formatted_msg))
|
||||
elif isinstance(msg, int | float):
|
||||
message_list.append(Text(str(msg)))
|
||||
elif isinstance(msg, Path):
|
||||
@ -90,12 +101,15 @@ class MessageUtils:
|
||||
|
||||
@classmethod
|
||||
def build_message(
|
||||
cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]]
|
||||
cls,
|
||||
msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]],
|
||||
format_args: dict | None = None,
|
||||
) -> UniMessage:
|
||||
"""构造消息
|
||||
|
||||
参数:
|
||||
msg_list: 消息列表
|
||||
format_args: 用于格式化字符串的参数字典.
|
||||
|
||||
返回:
|
||||
UniMessage: 构造完成的消息列表
|
||||
@ -105,7 +119,7 @@ class MessageUtils:
|
||||
msg_list = [msg_list]
|
||||
for m in msg_list:
|
||||
_data = m if isinstance(m, list) else [m]
|
||||
message_list += cls.__build_message(_data) # type: ignore
|
||||
message_list += cls.__build_message(_data, format_args)
|
||||
return UniMessage(message_list)
|
||||
|
||||
@classmethod
|
||||
|
||||
91
zhenxun/utils/time_utils.py
Normal file
91
zhenxun/utils/time_utils.py
Normal file
@ -0,0 +1,91 @@
|
||||
from datetime import date, datetime
|
||||
import re
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
class TimeUtils:
|
||||
DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
@classmethod
|
||||
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
|
||||
"""获取某天的0点时间
|
||||
|
||||
返回:
|
||||
datetime: 今天某天的0点时间
|
||||
"""
|
||||
if not target_date:
|
||||
target_date = datetime.now(cls.DEFAULT_TIMEZONE)
|
||||
|
||||
if isinstance(target_date, datetime) and target_date.tzinfo is None:
|
||||
target_date = cls.DEFAULT_TIMEZONE.localize(target_date)
|
||||
|
||||
return (
|
||||
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
if isinstance(target_date, datetime)
|
||||
else datetime.combine(
|
||||
target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_valid_date(cls, date_text: str, separator: str = "-") -> bool:
|
||||
"""日期是否合法
|
||||
|
||||
参数:
|
||||
date_text: 日期
|
||||
separator: 分隔符
|
||||
|
||||
返回:
|
||||
bool: 日期是否合法
|
||||
"""
|
||||
try:
|
||||
datetime.strptime(date_text, f"%Y{separator}%m{separator}%d")
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse_time_string(cls, time_str: str) -> int:
|
||||
"""
|
||||
将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数。
|
||||
"""
|
||||
time_str = time_str.lower().strip()
|
||||
match = re.match(r"^(\d+)([smh])$", time_str)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。"
|
||||
)
|
||||
|
||||
value, unit = int(match.group(1)), match.group(2)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
if unit == "m":
|
||||
return value * 60
|
||||
if unit == "h":
|
||||
return value * 3600
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def format_duration(cls, seconds: float) -> str:
|
||||
"""
|
||||
将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒")
|
||||
"""
|
||||
seconds = round(seconds, 1)
|
||||
if seconds < 0.1:
|
||||
return "不到1秒"
|
||||
if seconds < 60:
|
||||
return f"{seconds}秒"
|
||||
|
||||
minutes, sec_remainder = divmod(int(seconds), 60)
|
||||
|
||||
if minutes < 60:
|
||||
if sec_remainder == 0:
|
||||
return f"{minutes}分钟"
|
||||
return f"{minutes}分钟{sec_remainder}秒"
|
||||
|
||||
hours, rem_minutes = divmod(minutes, 60)
|
||||
if rem_minutes == 0:
|
||||
return f"{hours}小时"
|
||||
return f"{hours}小时{rem_minutes}分钟"
|
||||
@ -1,19 +1,19 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from datetime import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
from typing import ClassVar
|
||||
|
||||
import httpx
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
import pypinyin
|
||||
import pytz
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIDs:
|
||||
@ -64,78 +64,6 @@ class ResourceDirManager:
|
||||
cls.__tree_append(path, deep)
|
||||
|
||||
|
||||
class CountLimiter:
|
||||
"""
|
||||
每日调用命令次数限制
|
||||
"""
|
||||
|
||||
tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
def __init__(self, max_num):
|
||||
self.today = -1
|
||||
self.count = defaultdict(int)
|
||||
self.max = max_num
|
||||
|
||||
def check(self, key) -> bool:
|
||||
day = datetime.now(self.tz).day
|
||||
if day != self.today:
|
||||
self.today = day
|
||||
self.count.clear()
|
||||
return self.count[key] < self.max
|
||||
|
||||
def get_num(self, key):
|
||||
return self.count[key]
|
||||
|
||||
def increase(self, key, num=1):
|
||||
self.count[key] += num
|
||||
|
||||
def reset(self, key):
|
||||
self.count[key] = 0
|
||||
|
||||
|
||||
class UserBlockLimiter:
|
||||
"""
|
||||
检测用户是否正在调用命令
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.flag_data = defaultdict(bool)
|
||||
self.time = time.time()
|
||||
|
||||
def set_true(self, key: Any):
|
||||
self.time = time.time()
|
||||
self.flag_data[key] = True
|
||||
|
||||
def set_false(self, key: Any):
|
||||
self.flag_data[key] = False
|
||||
|
||||
def check(self, key: Any) -> bool:
|
||||
if time.time() - self.time > 30:
|
||||
self.set_false(key)
|
||||
return not self.flag_data[key]
|
||||
|
||||
|
||||
class FreqLimiter:
|
||||
"""
|
||||
命令冷却,检测用户是否处于冷却状态
|
||||
"""
|
||||
|
||||
def __init__(self, default_cd_seconds: int):
|
||||
self.next_time = defaultdict(float)
|
||||
self.default_cd = default_cd_seconds
|
||||
|
||||
def check(self, key: Any) -> bool:
|
||||
return time.time() >= self.next_time[key]
|
||||
|
||||
def start_cd(self, key: Any, cd_time: int = 0):
|
||||
self.next_time[key] = time.time() + (
|
||||
cd_time if cd_time > 0 else self.default_cd
|
||||
)
|
||||
|
||||
def left_time(self, key: Any) -> float:
|
||||
return self.next_time[key] - time.time()
|
||||
|
||||
|
||||
def cn2py(word: str) -> str:
|
||||
"""将字符串转化为拼音
|
||||
|
||||
@ -277,20 +205,3 @@ def is_number(text: str) -> bool:
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class TimeUtils:
|
||||
@classmethod
|
||||
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
|
||||
"""获取某天的0点时间
|
||||
|
||||
返回:
|
||||
datetime: 今天某天的0点时间
|
||||
"""
|
||||
if not target_date:
|
||||
target_date = datetime.now()
|
||||
return (
|
||||
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
if isinstance(target_date, datetime)
|
||||
else datetime.combine(target_date, datetime.min.time())
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user