Merge branch 'main' into feature/bot-profile

This commit is contained in:
HibiKier 2025-07-16 02:36:46 +08:00 committed by GitHub
commit e3e8f5b89b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 756 additions and 355 deletions

View File

@ -5,7 +5,7 @@ import nonebot
from nonebot.adapters import Bot
from nonebot.drivers import Driver
from tortoise import Tortoise
from tortoise.exceptions import OperationalError
from tortoise.exceptions import IntegrityError, OperationalError
import ujson as json
from zhenxun.models.bot_connect_log import BotConnectLog
@ -30,9 +30,12 @@ async def _(bot: Bot):
bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
)
if not await BotConsole.exists(bot_id=bot.self_id):
await BotConsole.create(
bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
)
try:
await BotConsole.create(
bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
)
except IntegrityError as e:
logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e)
@driver.on_bot_disconnect

View File

@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.utils import (
CountLimiter,
FreqLimiter,
UserBlockLimiter,
get_entity_ids,
)
from zhenxun.utils.time_utils import TimeUtils
from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
@ -273,9 +270,16 @@ class LimitManager:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
format_kwargs = {}
if isinstance(limiter, FreqLimiter):
left_time = limiter.left_time(key_type)
cd_str = TimeUtils.format_duration(left_time)
format_kwargs = {"cd": cd_str}
try:
await asyncio.wait_for(
MessageUtils.build_message(limit.result).send(),
MessageUtils.build_message(
limit.result, format_args=format_kwargs
).send(),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:

View File

@ -0,0 +1,15 @@
from nonebot.matcher import Matcher
from nonebot.message import run_postprocessor
from zhenxun.utils.limiters import ConcurrencyLimiter
@run_postprocessor
async def _concurrency_release_hook(matcher: Matcher):
"""
后处理器在事件处理结束后释放并发限制的信号量
"""
if concurrency_info := matcher.state.get("_concurrency_limiter_info"):
limiter: ConcurrencyLimiter = concurrency_info["limiter"]
key = concurrency_info["key"]
limiter.release(key)

View File

@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
from zhenxun.utils.echart_utils.models import Barh
from zhenxun.utils.enum import PluginType
from zhenxun.utils.image_utils import BuildImage
from zhenxun.utils.utils import TimeUtils
from zhenxun.utils.time_utils import TimeUtils
class StatisticsManage:

View File

@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender")
require("nonebot_plugin_uninfo")
require("nonebot_plugin_waiter")
from .db_context import Model, disconnect
from .db_context import Model, disconnect, with_db_timeout
from .llm import (
AI,
AIConfig,
@ -80,4 +80,5 @@ __all__ = [
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"with_db_timeout",
]

View File

@ -63,6 +63,7 @@ from functools import wraps
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
from aiocache import Cache as AioCache
from aiocache import SimpleMemoryCache
from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer
import nonebot
@ -392,22 +393,14 @@ class CacheManager:
def cache_backend(self) -> BaseCache | AioCache:
"""获取缓存后端"""
if self._cache_backend is None:
try:
from aiocache import RedisCache, SimpleMemoryCache
ttl = cache_config.redis_expire
if cache_config.cache_mode == CacheMode.NONE:
ttl = 0
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host:
try:
from aiocache import RedisCache
if cache_config.cache_mode == CacheMode.NONE:
# 使用内存缓存但禁用持久化
self._cache_backend = SimpleMemoryCache(
serializer=JsonSerializer(),
namespace=CACHE_KEY_PREFIX,
timeout=30,
ttl=0, # 设置为0不缓存
)
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
elif (
cache_config.cache_mode == CacheMode.REDIS
and cache_config.redis_host
):
# 使用Redis缓存
self._cache_backend = RedisCache(
serializer=JsonSerializer(),
@ -419,27 +412,25 @@ class CacheManager:
password=cache_config.redis_password,
)
logger.info(
f"使用Redis缓存地址: {cache_config.redis_host}", LOG_COMMAND
f"使用Redis缓存地址: {cache_config.redis_host}",
LOG_COMMAND,
)
else:
# 默认使用内存缓存
self._cache_backend = SimpleMemoryCache(
serializer=JsonSerializer(),
namespace=CACHE_KEY_PREFIX,
timeout=30,
ttl=cache_config.redis_expire,
return self._cache_backend
except ImportError as e:
logger.error(
"导入aiocache[redis]失败,将默认使用内存缓存...",
LOG_COMMAND,
e=e,
)
logger.info("使用内存缓存", LOG_COMMAND)
except ImportError:
logger.error("导入aiocache模块失败使用内存缓存", LOG_COMMAND)
# 使用内存缓存
self._cache_backend = AioCache(
cache_class=AioCache.MEMORY,
serializer=JsonSerializer(),
namespace=CACHE_KEY_PREFIX,
timeout=30,
ttl=cache_config.redis_expire,
)
else:
logger.info("使用内存缓存", LOG_COMMAND)
# 默认使用内存缓存
self._cache_backend = SimpleMemoryCache(
serializer=JsonSerializer(),
namespace=CACHE_KEY_PREFIX,
timeout=30,
ttl=ttl,
)
return self._cache_backend
@property

View File

@ -54,11 +54,7 @@ class CacheDict:
key: 字典键
value: 字典值
"""
# 计算过期时间
expire_time = 0
if self.expire > 0:
expire_time = time.time() + self.expire
expire_time = time.time() + self.expire if self.expire > 0 else 0
self._data[key] = CacheData(value=value, expire_time=expire_time)
def __delitem__(self, key: str) -> None:
@ -274,12 +270,11 @@ class CacheList:
self.clear()
raise IndexError(f"列表索引 {index} 超出范围")
if 0 <= index < len(self._data):
del self._data[index]
# 更新过期时间
self._update_expire_time()
else:
if not 0 <= index < len(self._data):
raise IndexError(f"列表索引 {index} 超出范围")
del self._data[index]
# 更新过期时间
self._update_expire_time()
def __len__(self) -> int:
"""获取列表长度
@ -427,6 +422,7 @@ class CacheList:
self.clear()
return 0
# sourcery skip: simplify-constant-sum
return sum(1 for item in self._data if item.value == value)
def _is_expired(self) -> bool:
@ -435,10 +431,7 @@ class CacheList:
def _update_expire_time(self) -> None:
"""更新过期时间"""
if self.expire > 0:
self._expire_time = time.time() + self.expire
else:
self._expire_time = 0
self._expire_time = time.time() + self.expire if self.expire > 0 else 0
def __str__(self) -> str:
"""字符串表示

View File

@ -0,0 +1,146 @@
import asyncio
from urllib.parse import urlparse
import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.connection import connections
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .base_model import Model
from .config import (
DB_TIMEOUT_SECONDS,
MYSQL_CONFIG,
POSTGRESQL_CONFIG,
SLOW_QUERY_THRESHOLD,
SQLITE_CONFIG,
db_model,
prompt,
)
from .exceptions import DbConnectError, DbUrlIsNode
from .utils import with_db_timeout
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
__all__ = [
"DB_TIMEOUT_SECONDS",
"MODELS",
"SCRIPT_METHOD",
"SLOW_QUERY_THRESHOLD",
"DbConnectError",
"DbUrlIsNode",
"Model",
"disconnect",
"init",
"with_db_timeout",
]
driver = nonebot.get_driver()
def get_config() -> dict:
"""获取数据库配置"""
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": db_model.models,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path or ":memory:",
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
global MODELS, SCRIPT_METHOD
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
if not BotConfig.db_url:
error = prompt.format(host=driver.config.host, port=driver.config.port)
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
config=get_config(),
)
if db_model.script_method:
db = Tortoise.get_connection("default")
logger.debug(
"即将运行SCRIPT_METHOD方法, 合计 "
f"<u><y>{len(db_model.script_method)}</y></u> 个..."
)
sql_list = []
for module, func in db_model.script_method:
try:
sql = await func() if is_coroutine_callable(func) else func()
if sql:
sql_list += sql
except Exception as e:
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
await connections.close_all()

View File

@ -1,56 +1,20 @@
import asyncio
from collections.abc import Iterable
import contextlib
import time
from typing import Any, ClassVar
from typing_extensions import Self
from urllib.parse import urlparse
from nonebot import get_driver
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.connection import connections
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction
from zhenxun.configs.config import BotConfig
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver = get_driver()
SCRIPT_METHOD = []
MODELS: list[str] = []
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise
from .config import LOG_COMMAND, db_model
from .utils import with_db_timeout
class Model(TortoiseModel):
@ -63,11 +27,11 @@ class Model(TortoiseModel):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__module__ not in MODELS:
MODELS.append(cls.__module__)
if cls.__module__ not in db_model.models:
db_model.models.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func))
db_model.script_method.append((cls.__module__, func))
@classmethod
def get_cache_type(cls) -> str | None:
@ -322,144 +286,3 @@ class Model(TortoiseModel):
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
)
raise
class DbUrlIsNode(HookPriorityException):
"""
数据库链接地址为空
"""
pass
class DbConnectError(Exception):
"""
数据库连接错误
"""
pass
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
def get_config() -> dict:
"""获取数据库配置"""
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path or ":memory:",
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
# raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
error = f"""
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{driver.config.host}:{driver.config.port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
config=get_config(),
)
if SCRIPT_METHOD:
db = Tortoise.get_connection("default")
logger.debug(
"即将运行SCRIPT_METHOD方法, 合计 "
f"<u><y>{len(SCRIPT_METHOD)}</y></u> 个..."
)
sql_list = []
for module, func in SCRIPT_METHOD:
try:
sql = await func() if is_coroutine_callable(func) else func()
if sql:
sql_list += sql
except Exception as e:
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
await connections.close_all()

View File

@ -0,0 +1,46 @@
from collections.abc import Callable
from pydantic import BaseModel
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
class DbModel(BaseModel):
script_method: list[tuple[str, Callable]] = []
models: list[str] = []
db_model = DbModel()
prompt = """
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{host}:{port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""

View File

@ -0,0 +1,14 @@
class DbUrlIsNode(Exception):
"""
数据库链接地址为空
"""
pass
class DbConnectError(Exception):
"""
数据库连接错误
"""
pass

View File

@ -0,0 +1,27 @@
import asyncio
import time
from zhenxun.services.log import logger
from .config import (
DB_TIMEOUT_SECONDS,
LOG_COMMAND,
SLOW_QUERY_THRESHOLD,
)
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise

View File

@ -54,19 +54,20 @@ class PluginInitManager:
@classmethod
async def install_all(cls):
"""运行所有插件安装方法"""
if cls.plugins:
for module_path, model in cls.plugins.items():
if model.install:
class_ = model.class_()
try:
logger.debug(f"开始执行: {module_path}:install 方法")
if is_coroutine_callable(class_.install):
await class_.install()
else:
class_.install() # type: ignore
logger.debug(f"执行: {module_path}:install 完成")
except Exception as e:
logger.error(f"执行: {module_path}:install 失败", e=e)
if not cls.plugins:
return
for module_path, model in cls.plugins.items():
if model.install:
class_ = model.class_()
try:
logger.debug(f"开始执行: {module_path}:install 方法")
if is_coroutine_callable(class_.install):
await class_.install()
else:
class_.install() # type: ignore
logger.debug(f"执行: {module_path}:install 完成")
except Exception as e:
logger.error(f"执行: {module_path}:install 失败", e=e)
@classmethod
async def install(cls, module_path: str):

View File

@ -1,13 +1,181 @@
from typing import Any
from typing import Any, Literal
from nonebot.adapters import Bot, Event
from nonebot.internal.params import Depends
from nonebot.matcher import Matcher
from nonebot.params import Command
from nonebot.permission import SUPERUSER
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.time_utils import TimeUtils
_coolers: dict[str, FreqLimiter] = {}
_rate_limiters: dict[str, RateLimiter] = {}
_concurrency_limiters: dict[str, ConcurrencyLimiter] = {}
def _create_limiter_dependency(
limiter_class: type,
limiter_storage: dict,
limiter_init_args: dict[str, Any],
scope: Literal["user", "group", "global"],
prompt: str,
**kwargs,
):
"""
一个高阶函数用于创建不同类型的限制器依赖
参数:
limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.).
limiter_storage: 用于存储限制器实例的字典.
limiter_init_args: 限制器类的初始化参数.
scope: 限制作用域.
prompt: 触发限制时的提示信息.
**kwargs: 传递给特定限制器逻辑的额外参数.
"""
async def dependency(
matcher: Matcher, session: EventSession, bot: Bot, event: Event
) -> bool:
if await SUPERUSER(bot, event):
return True
handler_id = (
f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}"
)
key: str | None = None
if scope == "user":
key = session.id1
elif scope == "group":
key = session.id3 or session.id2 or session.id1
elif scope == "global":
key = f"global_{handler_id}"
if not key:
return True
if handler_id not in limiter_storage:
limiter_storage[handler_id] = limiter_class(**limiter_init_args)
limiter = limiter_storage[handler_id]
if isinstance(limiter, ConcurrencyLimiter):
await limiter.acquire(key)
matcher.state["_concurrency_limiter_info"] = {
"limiter": limiter,
"key": key,
}
return True
else:
if limiter.check(key):
if isinstance(limiter, FreqLimiter):
limiter.start_cd(
key, kwargs.get("duration_sec", limiter.default_cd)
)
return True
else:
left_time = limiter.left_time(key)
format_kwargs = {
"cd_str": TimeUtils.format_duration(left_time),
**(kwargs.get("prompt_format_kwargs", {})),
}
message = prompt.format(**format_kwargs)
await matcher.finish(message)
return Depends(dependency)
def Cooldown(
duration: str,
*,
scope: Literal["user", "group", "global"] = "user",
prompt: str = "操作过于频繁,请等待 {cd_str}",
) -> bool:
"""声明式冷却检查依赖,限制用户操作频率
参数:
duration: 冷却时间字符串 (e.g., "30s", "10m", "1h")
scope: 冷却作用域
prompt: 自定义的冷却提示消息可使用 {cd_str} 占位符
返回:
bool: 是否允许执行
"""
try:
parsed_seconds = TimeUtils.parse_time_string(duration)
except ValueError as e:
raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}")
return _create_limiter_dependency(
limiter_class=FreqLimiter,
limiter_storage=_coolers,
limiter_init_args={"default_cd_seconds": parsed_seconds},
scope=scope,
prompt=prompt,
duration_sec=parsed_seconds,
)
def RateLimit(
count: int,
duration: str,
*,
scope: Literal["user", "group", "global"] = "user",
prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}",
) -> bool:
"""声明式速率限制依赖,在指定时间窗口内限制操作次数
参数:
count: 在时间窗口内允许的最大调用次数
duration: 时间窗口字符串 (e.g., "1m", "1h")
scope: 限制作用域
prompt: 自定义的提示消息可使用 {cd_str}, {duration_str}, {limit} 占位符
返回:
bool: 是否允许执行
"""
try:
parsed_seconds = TimeUtils.parse_time_string(duration)
except ValueError as e:
raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}")
return _create_limiter_dependency(
limiter_class=RateLimiter,
limiter_storage=_rate_limiters,
limiter_init_args={"max_calls": count, "time_window": parsed_seconds},
scope=scope,
prompt=prompt,
prompt_format_kwargs={"duration_str": duration, "limit": count},
)
def ConcurrencyLimit(
count: int,
*,
scope: Literal["user", "group", "global"] = "global",
prompt: str | None = "当前功能繁忙,请稍后再试...",
) -> bool:
"""声明式并发数限制依赖,控制某个功能同时执行的实例数量
参数:
count: 最大并发数
scope: 限制作用域
prompt: 提示消息暂未使用主要用于未来扩展超时功能
返回:
bool: 是否允许执行
"""
return _create_limiter_dependency(
limiter_class=ConcurrencyLimiter,
limiter_storage=_concurrency_limiters,
limiter_init_args={"max_concurrent": count},
scope=scope,
prompt=prompt or "",
)
def CheckUg(check_user: bool = True, check_group: bool = True):
@ -75,7 +243,6 @@ def GetConfig(
if module_:
value = Config.get_config(module_, config, default_value)
if value is None and prompt:
# await matcher.finish(prompt or f"配置项 {config} 未填写!")
await matcher.finish(prompt)
return value

View File

@ -1,3 +1,17 @@
from nonebot.exception import IgnoredException
class CooldownError(IgnoredException):
"""
冷却异常用于在冷却时中断事件处理
继承自 IgnoredException不会在控制台留下错误堆栈
"""
def __init__(self, message: str):
self.message = message
super().__init__(message)
class HookPriorityException(BaseException):
"""
钩子优先级异常

140
zhenxun/utils/limiters.py Normal file
View File

@ -0,0 +1,140 @@
import asyncio
from collections import defaultdict, deque
import time
from typing import Any
class FreqLimiter:
"""
命令冷却检测用户是否处于冷却状态
"""
def __init__(self, default_cd_seconds: int):
self.next_time: dict[Any, float] = defaultdict(float)
self.default_cd = default_cd_seconds
def check(self, key: Any) -> bool:
return time.time() >= self.next_time[key]
def start_cd(self, key: Any, cd_time: int = 0):
self.next_time[key] = time.time() + (
cd_time if cd_time > 0 else self.default_cd
)
def left_time(self, key: Any) -> float:
return max(0.0, self.next_time[key] - time.time())
class CountLimiter:
"""
每日调用命令次数限制
"""
tz = None
def __init__(self, max_num: int):
self.today = -1
self.count: dict[Any, int] = defaultdict(int)
self.max = max_num
def check(self, key: Any) -> bool:
import datetime
day = datetime.datetime.now().day
if day != self.today:
self.today = day
self.count.clear()
return self.count[key] < self.max
def get_num(self, key: Any) -> int:
return self.count[key]
def increase(self, key: Any, num: int = 1):
self.count[key] += num
def reset(self, key: Any):
self.count[key] = 0
class UserBlockLimiter:
"""
检测用户是否正在调用命令 (简单阻塞锁)
"""
def __init__(self):
self.flag_data: dict[Any, bool] = defaultdict(bool)
self.time: dict[Any, float] = defaultdict(float)
def set_true(self, key: Any):
self.time[key] = time.time()
self.flag_data[key] = True
def set_false(self, key: Any):
self.flag_data[key] = False
def check(self, key: Any) -> bool:
if self.flag_data[key] and time.time() - self.time[key] > 30:
self.set_false(key)
return not self.flag_data[key]
class RateLimiter:
"""
一个简单的基于时间窗口的速率限制器
"""
def __init__(self, max_calls: int, time_window: int):
self.requests: dict[Any, deque[float]] = defaultdict(deque)
self.max_calls = max_calls
self.time_window = time_window
def check(self, key: Any) -> bool:
"""检查是否超出速率限制。如果未超出,则记录本次调用。"""
now = time.time()
while self.requests[key] and self.requests[key][0] <= now - self.time_window:
self.requests[key].popleft()
if len(self.requests[key]) < self.max_calls:
self.requests[key].append(now)
return True
return False
def left_time(self, key: Any) -> float:
"""计算距离下次可调用还需等待的时间"""
if self.requests[key]:
return max(0.0, self.requests[key][0] + self.time_window - time.time())
return 0.0
class ConcurrencyLimiter:
"""
一个基于 asyncio.Semaphore 的并发限制器
"""
def __init__(self, max_concurrent: int):
self._semaphores: dict[Any, asyncio.Semaphore] = {}
self.max_concurrent = max_concurrent
self._active_tasks: dict[Any, int] = defaultdict(int)
def _get_semaphore(self, key: Any) -> asyncio.Semaphore:
if key not in self._semaphores:
self._semaphores[key] = asyncio.Semaphore(self.max_concurrent)
return self._semaphores[key]
async def acquire(self, key: Any):
"""获取一个信号量,如果达到并发上限则会阻塞等待。"""
semaphore = self._get_semaphore(key)
await semaphore.acquire()
self._active_tasks[key] += 1
def release(self, key: Any):
"""释放一个信号量。"""
if key in self._semaphores:
if self._active_tasks[key] > 0:
self._semaphores[key].release()
self._active_tasks[key] -= 1
else:
import logging
logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。")

View File

@ -49,11 +49,14 @@ class Config(BaseModel):
class MessageUtils:
@classmethod
def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]:
def __build_message(
cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None
) -> list[Text | Image]:
"""构造消息
参数:
msg_list: 消息列表
format_args: 用于格式化字符串的参数字典.
返回:
list[Text | Text]: 构造完成的消息列表
@ -65,7 +68,15 @@ class MessageUtils:
if msg.startswith("base64://"):
message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:]))))
else:
message_list.append(Text(msg))
formatted_msg = msg
if format_args:
try:
formatted_msg = msg.format_map(format_args)
except (KeyError, IndexError) as e:
logger.debug(
f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。"
)
message_list.append(Text(formatted_msg))
elif isinstance(msg, int | float):
message_list.append(Text(str(msg)))
elif isinstance(msg, Path):
@ -90,12 +101,15 @@ class MessageUtils:
@classmethod
def build_message(
cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]]
cls,
msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]],
format_args: dict | None = None,
) -> UniMessage:
"""构造消息
参数:
msg_list: 消息列表
format_args: 用于格式化字符串的参数字典.
返回:
UniMessage: 构造完成的消息列表
@ -105,7 +119,7 @@ class MessageUtils:
msg_list = [msg_list]
for m in msg_list:
_data = m if isinstance(m, list) else [m]
message_list += cls.__build_message(_data) # type: ignore
message_list += cls.__build_message(_data, format_args)
return UniMessage(message_list)
@classmethod

View File

@ -0,0 +1,91 @@
from datetime import date, datetime
import re
import pytz
class TimeUtils:
DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai")
@classmethod
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
"""获取某天的0点时间
返回:
datetime: 今天某天的0点时间
"""
if not target_date:
target_date = datetime.now(cls.DEFAULT_TIMEZONE)
if isinstance(target_date, datetime) and target_date.tzinfo is None:
target_date = cls.DEFAULT_TIMEZONE.localize(target_date)
return (
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
if isinstance(target_date, datetime)
else datetime.combine(
target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE
)
)
@classmethod
def is_valid_date(cls, date_text: str, separator: str = "-") -> bool:
"""日期是否合法
参数:
date_text: 日期
separator: 分隔符
返回:
bool: 日期是否合法
"""
try:
datetime.strptime(date_text, f"%Y{separator}%m{separator}%d")
return True
except ValueError:
return False
@classmethod
def parse_time_string(cls, time_str: str) -> int:
"""
将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数
"""
time_str = time_str.lower().strip()
match = re.match(r"^(\d+)([smh])$", time_str)
if not match:
raise ValueError(
f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。"
)
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return value
if unit == "m":
return value * 60
if unit == "h":
return value * 3600
return 0
@classmethod
def format_duration(cls, seconds: float) -> str:
"""
将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒")
"""
seconds = round(seconds, 1)
if seconds < 0.1:
return "不到1秒"
if seconds < 60:
return f"{seconds}"
minutes, sec_remainder = divmod(int(seconds), 60)
if minutes < 60:
if sec_remainder == 0:
return f"{minutes}分钟"
return f"{minutes}分钟{sec_remainder}"
hours, rem_minutes = divmod(minutes, 60)
if rem_minutes == 0:
return f"{hours}小时"
return f"{hours}小时{rem_minutes}分钟"

View File

@ -1,19 +1,19 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import date, datetime
from datetime import datetime
import os
from pathlib import Path
import time
from typing import Any, ClassVar
from typing import ClassVar
import httpx
from nonebot_plugin_uninfo import Uninfo
import pypinyin
import pytz
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401
@dataclass
class EntityIDs:
@ -64,78 +64,6 @@ class ResourceDirManager:
cls.__tree_append(path, deep)
class CountLimiter:
"""
每日调用命令次数限制
"""
tz = pytz.timezone("Asia/Shanghai")
def __init__(self, max_num):
self.today = -1
self.count = defaultdict(int)
self.max = max_num
def check(self, key) -> bool:
day = datetime.now(self.tz).day
if day != self.today:
self.today = day
self.count.clear()
return self.count[key] < self.max
def get_num(self, key):
return self.count[key]
def increase(self, key, num=1):
self.count[key] += num
def reset(self, key):
self.count[key] = 0
class UserBlockLimiter:
"""
检测用户是否正在调用命令
"""
def __init__(self):
self.flag_data = defaultdict(bool)
self.time = time.time()
def set_true(self, key: Any):
self.time = time.time()
self.flag_data[key] = True
def set_false(self, key: Any):
self.flag_data[key] = False
def check(self, key: Any) -> bool:
if time.time() - self.time > 30:
self.set_false(key)
return not self.flag_data[key]
class FreqLimiter:
"""
命令冷却检测用户是否处于冷却状态
"""
def __init__(self, default_cd_seconds: int):
self.next_time = defaultdict(float)
self.default_cd = default_cd_seconds
def check(self, key: Any) -> bool:
return time.time() >= self.next_time[key]
def start_cd(self, key: Any, cd_time: int = 0):
self.next_time[key] = time.time() + (
cd_time if cd_time > 0 else self.default_cd
)
def left_time(self, key: Any) -> float:
return self.next_time[key] - time.time()
def cn2py(word: str) -> str:
"""将字符串转化为拼音
@ -277,20 +205,3 @@ def is_number(text: str) -> bool:
return True
except ValueError:
return False
class TimeUtils:
@classmethod
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
"""获取某天的0点时间
返回:
datetime: 今天某天的0点时间
"""
if not target_date:
target_date = datetime.now()
return (
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
if isinstance(target_date, datetime)
else datetime.combine(target_date, datetime.min.time())
)