From d218c569d46760e630cef14a5cbf12421b612891 Mon Sep 17 00:00:00 2001 From: HibiKier <45528451+HibiKier@users.noreply.github.com> Date: Tue, 15 Jul 2025 17:08:42 +0800 Subject: [PATCH 1/2] =?UTF-8?q?:sparkles:=20=E6=A0=BC=E5=BC=8F=E5=8C=96db?= =?UTF-8?q?=5Fcontext=20(#1980)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :sparkles: 格式化db_context * :fire: 移除旧db-context * :zap: 添加旧版本兼容 --- zhenxun/builtin_plugins/__init__.py | 11 +- zhenxun/services/__init__.py | 3 +- zhenxun/services/cache/__init__.py | 59 +++--- zhenxun/services/cache/cache_containers.py | 21 +- zhenxun/services/db_context/__init__.py | 146 ++++++++++++++ .../base_model.py} | 187 +----------------- zhenxun/services/db_context/config.py | 46 +++++ zhenxun/services/db_context/exceptions.py | 14 ++ zhenxun/services/db_context/utils.py | 27 +++ zhenxun/services/plugin_init.py | 27 +-- 10 files changed, 293 insertions(+), 248 deletions(-) create mode 100644 zhenxun/services/db_context/__init__.py rename zhenxun/services/{db_context.py => db_context/base_model.py} (64%) create mode 100644 zhenxun/services/db_context/config.py create mode 100644 zhenxun/services/db_context/exceptions.py create mode 100644 zhenxun/services/db_context/utils.py diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index 4003e506..a5aa7a4b 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -5,7 +5,7 @@ import nonebot from nonebot.adapters import Bot from nonebot.drivers import Driver from tortoise import Tortoise -from tortoise.exceptions import OperationalError +from tortoise.exceptions import IntegrityError, OperationalError import ujson as json from zhenxun.models.bot_connect_log import BotConnectLog @@ -30,9 +30,12 @@ async def _(bot: Bot): bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1 ) if not await BotConsole.exists(bot_id=bot.self_id): - await BotConsole.create( - bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot) - ) + try: + await BotConsole.create( + bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot) + ) + except IntegrityError as e: + logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e) @driver.on_bot_disconnect diff --git a/zhenxun/services/__init__.py b/zhenxun/services/__init__.py index 4c820b87..b3dc292e 100644 --- a/zhenxun/services/__init__.py +++ b/zhenxun/services/__init__.py @@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender") require("nonebot_plugin_uninfo") require("nonebot_plugin_waiter") -from .db_context import Model, disconnect +from .db_context import Model, disconnect, with_db_timeout from .llm import ( AI, AIConfig, @@ -80,4 +80,5 @@ __all__ = [ "search_multimodal", "set_global_default_model_name", "tool_registry", + "with_db_timeout", ] diff --git a/zhenxun/services/cache/__init__.py b/zhenxun/services/cache/__init__.py index 76b05a5c..e2ce9c30 100644 --- a/zhenxun/services/cache/__init__.py +++ b/zhenxun/services/cache/__init__.py @@ -63,6 +63,7 @@ from functools import wraps from typing import Any, ClassVar, Generic, TypeVar, get_type_hints from aiocache import Cache as AioCache +from aiocache import SimpleMemoryCache from aiocache.base import BaseCache from aiocache.serializers import JsonSerializer import nonebot @@ -392,22 +393,14 @@ class CacheManager: def cache_backend(self) -> BaseCache | AioCache: """获取缓存后端""" if self._cache_backend is None: - try: - from aiocache import RedisCache, SimpleMemoryCache + ttl = cache_config.redis_expire + if cache_config.cache_mode == CacheMode.NONE: + ttl = 0 + logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND) + elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host: + try: + from aiocache import RedisCache - if cache_config.cache_mode == CacheMode.NONE: - # 使用内存缓存但禁用持久化 - self._cache_backend = SimpleMemoryCache( - serializer=JsonSerializer(), - namespace=CACHE_KEY_PREFIX, - timeout=30, - ttl=0, # 设置为0,不缓存 - ) - logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND) - elif ( - cache_config.cache_mode == CacheMode.REDIS - and cache_config.redis_host - ): # 使用Redis缓存 self._cache_backend = RedisCache( serializer=JsonSerializer(), @@ -419,27 +412,25 @@ class CacheManager: password=cache_config.redis_password, ) logger.info( - f"使用Redis缓存,地址: {cache_config.redis_host}", LOG_COMMAND + f"使用Redis缓存,地址: {cache_config.redis_host}", + LOG_COMMAND, ) - else: - # 默认使用内存缓存 - self._cache_backend = SimpleMemoryCache( - serializer=JsonSerializer(), - namespace=CACHE_KEY_PREFIX, - timeout=30, - ttl=cache_config.redis_expire, + return self._cache_backend + except ImportError as e: + logger.error( + "导入aiocache[redis]失败,将默认使用内存缓存...", + LOG_COMMAND, + e=e, ) - logger.info("使用内存缓存", LOG_COMMAND) - except ImportError: - logger.error("导入aiocache模块失败,使用内存缓存", LOG_COMMAND) - # 使用内存缓存 - self._cache_backend = AioCache( - cache_class=AioCache.MEMORY, - serializer=JsonSerializer(), - namespace=CACHE_KEY_PREFIX, - timeout=30, - ttl=cache_config.redis_expire, - ) + else: + logger.info("使用内存缓存", LOG_COMMAND) + # 默认使用内存缓存 + self._cache_backend = SimpleMemoryCache( + serializer=JsonSerializer(), + namespace=CACHE_KEY_PREFIX, + timeout=30, + ttl=ttl, + ) return self._cache_backend @property diff --git a/zhenxun/services/cache/cache_containers.py b/zhenxun/services/cache/cache_containers.py index b0efe3fb..42b35239 100644 --- a/zhenxun/services/cache/cache_containers.py +++ b/zhenxun/services/cache/cache_containers.py @@ -54,11 +54,7 @@ class CacheDict: key: 字典键 value: 字典值 """ - # 计算过期时间 - expire_time = 0 - if self.expire > 0: - expire_time = time.time() + self.expire - + expire_time = time.time() + self.expire if self.expire > 0 else 0 self._data[key] = CacheData(value=value, expire_time=expire_time) def __delitem__(self, key: str) -> None: @@ -274,12 +270,11 @@ class CacheList: self.clear() raise IndexError(f"列表索引 {index} 超出范围") - if 0 <= index < len(self._data): - del self._data[index] - # 更新过期时间 - self._update_expire_time() - else: + if not 0 <= index < len(self._data): raise IndexError(f"列表索引 {index} 超出范围") + del self._data[index] + # 更新过期时间 + self._update_expire_time() def __len__(self) -> int: """获取列表长度 @@ -427,6 +422,7 @@ class CacheList: self.clear() return 0 + # sourcery skip: simplify-constant-sum return sum(1 for item in self._data if item.value == value) def _is_expired(self) -> bool: @@ -435,10 +431,7 @@ class CacheList: def _update_expire_time(self) -> None: """更新过期时间""" - if self.expire > 0: - self._expire_time = time.time() + self.expire - else: - self._expire_time = 0 + self._expire_time = time.time() + self.expire if self.expire > 0 else 0 def __str__(self) -> str: """字符串表示 diff --git a/zhenxun/services/db_context/__init__.py b/zhenxun/services/db_context/__init__.py new file mode 100644 index 00000000..26fd9bcd --- /dev/null +++ b/zhenxun/services/db_context/__init__.py @@ -0,0 +1,146 @@ +import asyncio +from urllib.parse import urlparse + +import nonebot +from nonebot.utils import is_coroutine_callable +from tortoise import Tortoise +from tortoise.connection import connections + +from zhenxun.configs.config import BotConfig +from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle + +from .base_model import Model +from .config import ( + DB_TIMEOUT_SECONDS, + MYSQL_CONFIG, + POSTGRESQL_CONFIG, + SLOW_QUERY_THRESHOLD, + SQLITE_CONFIG, + db_model, + prompt, +) +from .exceptions import DbConnectError, DbUrlIsNode +from .utils import with_db_timeout + +MODELS = db_model.models +SCRIPT_METHOD = db_model.script_method + +__all__ = [ + "DB_TIMEOUT_SECONDS", + "MODELS", + "SCRIPT_METHOD", + "SLOW_QUERY_THRESHOLD", + "DbConnectError", + "DbUrlIsNode", + "Model", + "disconnect", + "init", + "with_db_timeout", +] + +driver = nonebot.get_driver() + + +def get_config() -> dict: + """获取数据库配置""" + parsed = urlparse(BotConfig.db_url) + + # 基础配置 + config = { + "connections": { + "default": BotConfig.db_url # 默认直接使用连接字符串 + }, + "apps": { + "models": { + "models": db_model.models, + "default_connection": "default", + } + }, + "timezone": "Asia/Shanghai", + } + + # 根据数据库类型应用高级配置 + if parsed.scheme.startswith("postgres"): + config["connections"]["default"] = { + "engine": "tortoise.backends.asyncpg", + "credentials": { + "host": parsed.hostname, + "port": parsed.port or 5432, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path[1:], + }, + **POSTGRESQL_CONFIG, + } + elif parsed.scheme == "mysql": + config["connections"]["default"] = { + "engine": "tortoise.backends.mysql", + "credentials": { + "host": parsed.hostname, + "port": parsed.port or 3306, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path[1:], + }, + **MYSQL_CONFIG, + } + elif parsed.scheme == "sqlite": + config["connections"]["default"] = { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": parsed.path or ":memory:", + }, + **SQLITE_CONFIG, + } + return config + + +@PriorityLifecycle.on_startup(priority=1) +async def init(): + global MODELS, SCRIPT_METHOD + + MODELS = db_model.models + SCRIPT_METHOD = db_model.script_method + if not BotConfig.db_url: + error = prompt.format(host=driver.config.host, port=driver.config.port) + raise DbUrlIsNode("\n" + error.strip()) + try: + await Tortoise.init( + config=get_config(), + ) + if db_model.script_method: + db = Tortoise.get_connection("default") + logger.debug( + "即将运行SCRIPT_METHOD方法, 合计 " + f"{len(db_model.script_method)} 个..." + ) + sql_list = [] + for module, func in db_model.script_method: + try: + sql = await func() if is_coroutine_callable(func) else func() + if sql: + sql_list += sql + except Exception as e: + logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e) + for sql in sql_list: + logger.debug(f"执行SQL: {sql}") + try: + await asyncio.wait_for( + db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS + ) + # await TestSQL.raw(sql) + except Exception as e: + logger.debug(f"执行SQL: {sql} 错误...", e=e) + if sql_list: + logger.debug("SCRIPT_METHOD方法执行完毕!") + logger.debug("开始生成数据库表结构...") + await Tortoise.generate_schemas() + logger.debug("数据库表结构生成完毕!") + logger.info("Database loaded successfully!") + except Exception as e: + raise DbConnectError(f"数据库连接错误... e:{e}") from e + + +async def disconnect(): + await connections.close_all() diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context/base_model.py similarity index 64% rename from zhenxun/services/db_context.py rename to zhenxun/services/db_context/base_model.py index e6c42472..3e0e23ef 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context/base_model.py @@ -1,56 +1,20 @@ import asyncio from collections.abc import Iterable import contextlib -import time from typing import Any, ClassVar from typing_extensions import Self -from urllib.parse import urlparse -from nonebot import get_driver -from nonebot.utils import is_coroutine_callable -from tortoise import Tortoise from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.connection import connections from tortoise.exceptions import IntegrityError, MultipleObjectsReturned from tortoise.models import Model as TortoiseModel from tortoise.transactions import in_transaction -from zhenxun.configs.config import BotConfig from zhenxun.services.cache import CacheRoot from zhenxun.services.log import logger from zhenxun.utils.enum import DbLockType -from zhenxun.utils.exception import HookPriorityException -from zhenxun.utils.manager.priority_manager import PriorityLifecycle -driver = get_driver() - -SCRIPT_METHOD = [] -MODELS: list[str] = [] - -# 数据库操作超时设置(秒) -DB_TIMEOUT_SECONDS = 3.0 - -# 性能监控阈值(秒) -SLOW_QUERY_THRESHOLD = 0.5 - -LOG_COMMAND = "DbContext" - - -async def with_db_timeout( - coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None -): - """带超时控制的数据库操作""" - start_time = time.time() - try: - result = await asyncio.wait_for(coro, timeout=timeout) - elapsed = time.time() - start_time - if elapsed > SLOW_QUERY_THRESHOLD and operation: - logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND) - return result - except asyncio.TimeoutError: - if operation: - logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND) - raise +from .config import LOG_COMMAND, db_model +from .utils import with_db_timeout class Model(TortoiseModel): @@ -63,11 +27,11 @@ class Model(TortoiseModel): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if cls.__module__ not in MODELS: - MODELS.append(cls.__module__) + if cls.__module__ not in db_model.models: + db_model.models.append(cls.__module__) if func := getattr(cls, "_run_script", None): - SCRIPT_METHOD.append((cls.__module__, func)) + db_model.script_method.append((cls.__module__, func)) @classmethod def get_cache_type(cls) -> str | None: @@ -322,144 +286,3 @@ class Model(TortoiseModel): f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND ) raise - - -class DbUrlIsNode(HookPriorityException): - """ - 数据库链接地址为空 - """ - - pass - - -class DbConnectError(Exception): - """ - 数据库连接错误 - """ - - pass - - -POSTGRESQL_CONFIG = { - "max_size": 30, # 最大连接数 - "min_size": 5, # 最小保持的连接数(可选) -} - - -MYSQL_CONFIG = { - "max_connections": 20, # 最大连接数 - "connect_timeout": 30, # 连接超时(可选) -} - -SQLITE_CONFIG = { - "journal_mode": "WAL", # 提高并发写入性能 - "timeout": 30, # 锁等待超时(可选) -} - - -def get_config() -> dict: - """获取数据库配置""" - parsed = urlparse(BotConfig.db_url) - - # 基础配置 - config = { - "connections": { - "default": BotConfig.db_url # 默认直接使用连接字符串 - }, - "apps": { - "models": { - "models": MODELS, - "default_connection": "default", - } - }, - "timezone": "Asia/Shanghai", - } - - # 根据数据库类型应用高级配置 - if parsed.scheme.startswith("postgres"): - config["connections"]["default"] = { - "engine": "tortoise.backends.asyncpg", - "credentials": { - "host": parsed.hostname, - "port": parsed.port or 5432, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path[1:], - }, - **POSTGRESQL_CONFIG, - } - elif parsed.scheme == "mysql": - config["connections"]["default"] = { - "engine": "tortoise.backends.mysql", - "credentials": { - "host": parsed.hostname, - "port": parsed.port or 3306, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path[1:], - }, - **MYSQL_CONFIG, - } - elif parsed.scheme == "sqlite": - config["connections"]["default"] = { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": parsed.path or ":memory:", - }, - **SQLITE_CONFIG, - } - return config - - -@PriorityLifecycle.on_startup(priority=1) -async def init(): - if not BotConfig.db_url: - # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") - error = f""" -********************************************************************** -🌟 **************************** 配置为空 ************************* 🌟 -🚀 请打开 WebUi 进行基础配置 🚀 -🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐 -*********************************************************************** -*********************************************************************** - """ - raise DbUrlIsNode("\n" + error.strip()) - try: - await Tortoise.init( - config=get_config(), - ) - if SCRIPT_METHOD: - db = Tortoise.get_connection("default") - logger.debug( - "即将运行SCRIPT_METHOD方法, 合计 " - f"{len(SCRIPT_METHOD)} 个..." - ) - sql_list = [] - for module, func in SCRIPT_METHOD: - try: - sql = await func() if is_coroutine_callable(func) else func() - if sql: - sql_list += sql - except Exception as e: - logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e) - for sql in sql_list: - logger.debug(f"执行SQL: {sql}") - try: - await asyncio.wait_for( - db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS - ) - # await TestSQL.raw(sql) - except Exception as e: - logger.debug(f"执行SQL: {sql} 错误...", e=e) - if sql_list: - logger.debug("SCRIPT_METHOD方法执行完毕!") - logger.debug("开始生成数据库表结构...") - await Tortoise.generate_schemas() - logger.debug("数据库表结构生成完毕!") - logger.info("Database loaded successfully!") - except Exception as e: - raise DbConnectError(f"数据库连接错误... e:{e}") from e - - -async def disconnect(): - await connections.close_all() diff --git a/zhenxun/services/db_context/config.py b/zhenxun/services/db_context/config.py new file mode 100644 index 00000000..ae6d6b8c --- /dev/null +++ b/zhenxun/services/db_context/config.py @@ -0,0 +1,46 @@ +from collections.abc import Callable + +from pydantic import BaseModel + +# 数据库操作超时设置(秒) +DB_TIMEOUT_SECONDS = 3.0 + +# 性能监控阈值(秒) +SLOW_QUERY_THRESHOLD = 0.5 + +LOG_COMMAND = "DbContext" + + +POSTGRESQL_CONFIG = { + "max_size": 30, # 最大连接数 + "min_size": 5, # 最小保持的连接数(可选) +} + + +MYSQL_CONFIG = { + "max_connections": 20, # 最大连接数 + "connect_timeout": 30, # 连接超时(可选) +} + +SQLITE_CONFIG = { + "journal_mode": "WAL", # 提高并发写入性能 + "timeout": 30, # 锁等待超时(可选) +} + + +class DbModel(BaseModel): + script_method: list[tuple[str, Callable]] = [] + models: list[str] = [] + + +db_model = DbModel() + + +prompt = """ +********************************************************************** +🌟 **************************** 配置为空 ************************* 🌟 +🚀 请打开 WebUi 进行基础配置 🚀 +🌐 配置地址:http://{host}:{port}/#/configure 🌐 +*********************************************************************** +*********************************************************************** +""" diff --git a/zhenxun/services/db_context/exceptions.py b/zhenxun/services/db_context/exceptions.py new file mode 100644 index 00000000..163f92e2 --- /dev/null +++ b/zhenxun/services/db_context/exceptions.py @@ -0,0 +1,14 @@ +class DbUrlIsNode(Exception): + """ + 数据库链接地址为空 + """ + + pass + + +class DbConnectError(Exception): + """ + 数据库连接错误 + """ + + pass diff --git a/zhenxun/services/db_context/utils.py b/zhenxun/services/db_context/utils.py new file mode 100644 index 00000000..47db548f --- /dev/null +++ b/zhenxun/services/db_context/utils.py @@ -0,0 +1,27 @@ +import asyncio +import time + +from zhenxun.services.log import logger + +from .config import ( + DB_TIMEOUT_SECONDS, + LOG_COMMAND, + SLOW_QUERY_THRESHOLD, +) + + +async def with_db_timeout( + coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None +): + """带超时控制的数据库操作""" + start_time = time.time() + try: + result = await asyncio.wait_for(coro, timeout=timeout) + elapsed = time.time() - start_time + if elapsed > SLOW_QUERY_THRESHOLD and operation: + logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND) + return result + except asyncio.TimeoutError: + if operation: + logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND) + raise diff --git a/zhenxun/services/plugin_init.py b/zhenxun/services/plugin_init.py index a622a9e8..a7b8685a 100644 --- a/zhenxun/services/plugin_init.py +++ b/zhenxun/services/plugin_init.py @@ -54,19 +54,20 @@ class PluginInitManager: @classmethod async def install_all(cls): """运行所有插件安装方法""" - if cls.plugins: - for module_path, model in cls.plugins.items(): - if model.install: - class_ = model.class_() - try: - logger.debug(f"开始执行: {module_path}:install 方法") - if is_coroutine_callable(class_.install): - await class_.install() - else: - class_.install() # type: ignore - logger.debug(f"执行: {module_path}:install 完成") - except Exception as e: - logger.error(f"执行: {module_path}:install 失败", e=e) + if not cls.plugins: + return + for module_path, model in cls.plugins.items(): + if model.install: + class_ = model.class_() + try: + logger.debug(f"开始执行: {module_path}:install 方法") + if is_coroutine_callable(class_.install): + await class_.install() + else: + class_.install() # type: ignore + logger.debug(f"执行: {module_path}:install 完成") + except Exception as e: + logger.error(f"执行: {module_path}:install 失败", e=e) @classmethod async def install(cls, module_path: str): From b993450a234e43f0518ab1dfbabe6af95dace8ec Mon Sep 17 00:00:00 2001 From: Rumio <32546670+webjoin111@users.noreply.github.com> Date: Tue, 15 Jul 2025 17:13:33 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20feat(limit,=20message):=20?= =?UTF-8?q?=E5=BC=95=E5=85=A5=E5=A3=B0=E6=98=8E=E5=BC=8F=E9=99=90=E6=B5=81?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=E5=B9=B6=E5=A2=9E=E5=BC=BA=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E5=8A=9F=E8=83=BD=20(#1978)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Cooldown、RateLimit、ConcurrencyLimit 三种限流依赖 - MessageUtils 支持动态格式化字符串 (format_args 参数) - 插件CD限制消息显示精确剩余时间 - 重构限流逻辑至 utils/limiters.py,新增时间工具模块 - 整合时间工具函数并优化时区处理 - 新增 limiter_hook 自动释放资源,CooldownError 优化异常处理 - 冷却提示从固定文本改为动态显示剩余时间 - 示例:总结功能冷却中,请等待 1分30秒 后再试~ Co-authored-by: webjoin111 <455457521@qq.com> Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com> --- .../builtin_plugins/hooks/auth/auth_limit.py | 18 +- zhenxun/builtin_plugins/hooks/limiter_hook.py | 15 ++ .../statistics/_data_source.py | 2 +- zhenxun/utils/depends/__init__.py | 171 +++++++++++++++++- zhenxun/utils/exception.py | 14 ++ zhenxun/utils/limiters.py | 140 ++++++++++++++ zhenxun/utils/message.py | 22 ++- zhenxun/utils/time_utils.py | 91 ++++++++++ zhenxun/utils/utils.py | 97 +--------- 9 files changed, 463 insertions(+), 107 deletions(-) create mode 100644 zhenxun/builtin_plugins/hooks/limiter_hook.py create mode 100644 zhenxun/utils/limiters.py create mode 100644 zhenxun/utils/time_utils.py diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py index d199ff0d..80650472 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py @@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit from zhenxun.services.db_context import DB_TIMEOUT_SECONDS from zhenxun.services.log import logger from zhenxun.utils.enum import LimitWatchType, PluginLimitType +from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.message import MessageUtils -from zhenxun.utils.utils import ( - CountLimiter, - FreqLimiter, - UserBlockLimiter, - get_entity_ids, -) +from zhenxun.utils.time_utils import TimeUtils +from zhenxun.utils.utils import get_entity_ids from .config import LOGGER_COMMAND, WARNING_THRESHOLD from .exception import SkipPluginException @@ -273,9 +270,16 @@ class LimitManager: key_type = channel_id or group_id if is_limit and not limiter.check(key_type): if limit.result: + format_kwargs = {} + if isinstance(limiter, FreqLimiter): + left_time = limiter.left_time(key_type) + cd_str = TimeUtils.format_duration(left_time) + format_kwargs = {"cd": cd_str} try: await asyncio.wait_for( - MessageUtils.build_message(limit.result).send(), + MessageUtils.build_message( + limit.result, format_args=format_kwargs + ).send(), timeout=DB_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: diff --git a/zhenxun/builtin_plugins/hooks/limiter_hook.py b/zhenxun/builtin_plugins/hooks/limiter_hook.py new file mode 100644 index 00000000..22680941 --- /dev/null +++ b/zhenxun/builtin_plugins/hooks/limiter_hook.py @@ -0,0 +1,15 @@ +from nonebot.matcher import Matcher +from nonebot.message import run_postprocessor + +from zhenxun.utils.limiters import ConcurrencyLimiter + + +@run_postprocessor +async def _concurrency_release_hook(matcher: Matcher): + """ + 后处理器:在事件处理结束后,释放并发限制的信号量。 + """ + if concurrency_info := matcher.state.get("_concurrency_limiter_info"): + limiter: ConcurrencyLimiter = concurrency_info["limiter"] + key = concurrency_info["key"] + limiter.release(key) diff --git a/zhenxun/builtin_plugins/statistics/_data_source.py b/zhenxun/builtin_plugins/statistics/_data_source.py index 81e2b035..ab426ae6 100644 --- a/zhenxun/builtin_plugins/statistics/_data_source.py +++ b/zhenxun/builtin_plugins/statistics/_data_source.py @@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils from zhenxun.utils.echart_utils.models import Barh from zhenxun.utils.enum import PluginType from zhenxun.utils.image_utils import BuildImage -from zhenxun.utils.utils import TimeUtils +from zhenxun.utils.time_utils import TimeUtils class StatisticsManage: diff --git a/zhenxun/utils/depends/__init__.py b/zhenxun/utils/depends/__init__.py index 887813dd..e52dfd48 100644 --- a/zhenxun/utils/depends/__init__.py +++ b/zhenxun/utils/depends/__init__.py @@ -1,13 +1,181 @@ -from typing import Any +from typing import Any, Literal +from nonebot.adapters import Bot, Event from nonebot.internal.params import Depends from nonebot.matcher import Matcher from nonebot.params import Command +from nonebot.permission import SUPERUSER from nonebot_plugin_session import EventSession from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.config import Config +from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter from zhenxun.utils.message import MessageUtils +from zhenxun.utils.time_utils import TimeUtils + +_coolers: dict[str, FreqLimiter] = {} +_rate_limiters: dict[str, RateLimiter] = {} +_concurrency_limiters: dict[str, ConcurrencyLimiter] = {} + + +def _create_limiter_dependency( + limiter_class: type, + limiter_storage: dict, + limiter_init_args: dict[str, Any], + scope: Literal["user", "group", "global"], + prompt: str, + **kwargs, +): + """ + 一个高阶函数,用于创建不同类型的限制器依赖。 + + 参数: + limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.). + limiter_storage: 用于存储限制器实例的字典. + limiter_init_args: 限制器类的初始化参数. + scope: 限制作用域. + prompt: 触发限制时的提示信息. + **kwargs: 传递给特定限制器逻辑的额外参数. + """ + + async def dependency( + matcher: Matcher, session: EventSession, bot: Bot, event: Event + ) -> bool: + if await SUPERUSER(bot, event): + return True + + handler_id = ( + f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}" + ) + + key: str | None = None + if scope == "user": + key = session.id1 + elif scope == "group": + key = session.id3 or session.id2 or session.id1 + elif scope == "global": + key = f"global_{handler_id}" + + if not key: + return True + + if handler_id not in limiter_storage: + limiter_storage[handler_id] = limiter_class(**limiter_init_args) + limiter = limiter_storage[handler_id] + + if isinstance(limiter, ConcurrencyLimiter): + await limiter.acquire(key) + matcher.state["_concurrency_limiter_info"] = { + "limiter": limiter, + "key": key, + } + return True + else: + if limiter.check(key): + if isinstance(limiter, FreqLimiter): + limiter.start_cd( + key, kwargs.get("duration_sec", limiter.default_cd) + ) + return True + else: + left_time = limiter.left_time(key) + format_kwargs = { + "cd_str": TimeUtils.format_duration(left_time), + **(kwargs.get("prompt_format_kwargs", {})), + } + message = prompt.format(**format_kwargs) + await matcher.finish(message) + + return Depends(dependency) + + +def Cooldown( + duration: str, + *, + scope: Literal["user", "group", "global"] = "user", + prompt: str = "操作过于频繁,请等待 {cd_str}", +) -> bool: + """声明式冷却检查依赖,限制用户操作频率 + + 参数: + duration: 冷却时间字符串 (e.g., "30s", "10m", "1h") + scope: 冷却作用域 + prompt: 自定义的冷却提示消息,可使用 {cd_str} 占位符 + + 返回: + bool: 是否允许执行 + """ + try: + parsed_seconds = TimeUtils.parse_time_string(duration) + except ValueError as e: + raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}") + + return _create_limiter_dependency( + limiter_class=FreqLimiter, + limiter_storage=_coolers, + limiter_init_args={"default_cd_seconds": parsed_seconds}, + scope=scope, + prompt=prompt, + duration_sec=parsed_seconds, + ) + + +def RateLimit( + count: int, + duration: str, + *, + scope: Literal["user", "group", "global"] = "user", + prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}", +) -> bool: + """声明式速率限制依赖,在指定时间窗口内限制操作次数 + + 参数: + count: 在时间窗口内允许的最大调用次数 + duration: 时间窗口字符串 (e.g., "1m", "1h") + scope: 限制作用域 + prompt: 自定义的提示消息,可使用 {cd_str}, {duration_str}, {limit} 占位符 + + 返回: + bool: 是否允许执行 + """ + try: + parsed_seconds = TimeUtils.parse_time_string(duration) + except ValueError as e: + raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}") + + return _create_limiter_dependency( + limiter_class=RateLimiter, + limiter_storage=_rate_limiters, + limiter_init_args={"max_calls": count, "time_window": parsed_seconds}, + scope=scope, + prompt=prompt, + prompt_format_kwargs={"duration_str": duration, "limit": count}, + ) + + +def ConcurrencyLimit( + count: int, + *, + scope: Literal["user", "group", "global"] = "global", + prompt: str | None = "当前功能繁忙,请稍后再试...", +) -> bool: + """声明式并发数限制依赖,控制某个功能同时执行的实例数量 + + 参数: + count: 最大并发数 + scope: 限制作用域 + prompt: 提示消息(暂未使用,主要用于未来扩展超时功能) + + 返回: + bool: 是否允许执行 + """ + return _create_limiter_dependency( + limiter_class=ConcurrencyLimiter, + limiter_storage=_concurrency_limiters, + limiter_init_args={"max_concurrent": count}, + scope=scope, + prompt=prompt or "", + ) def CheckUg(check_user: bool = True, check_group: bool = True): @@ -75,7 +243,6 @@ def GetConfig( if module_: value = Config.get_config(module_, config, default_value) if value is None and prompt: - # await matcher.finish(prompt or f"配置项 {config} 未填写!") await matcher.finish(prompt) return value diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py index 9ab664f4..8b3ec282 100644 --- a/zhenxun/utils/exception.py +++ b/zhenxun/utils/exception.py @@ -1,3 +1,17 @@ +from nonebot.exception import IgnoredException + + +class CooldownError(IgnoredException): + """ + 冷却异常,用于在冷却时中断事件处理。 + 继承自 IgnoredException,不会在控制台留下错误堆栈。 + """ + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + class HookPriorityException(BaseException): """ 钩子优先级异常 diff --git a/zhenxun/utils/limiters.py b/zhenxun/utils/limiters.py new file mode 100644 index 00000000..1bd5f662 --- /dev/null +++ b/zhenxun/utils/limiters.py @@ -0,0 +1,140 @@ +import asyncio +from collections import defaultdict, deque +import time +from typing import Any + + +class FreqLimiter: + """ + 命令冷却,检测用户是否处于冷却状态 + """ + + def __init__(self, default_cd_seconds: int): + self.next_time: dict[Any, float] = defaultdict(float) + self.default_cd = default_cd_seconds + + def check(self, key: Any) -> bool: + return time.time() >= self.next_time[key] + + def start_cd(self, key: Any, cd_time: int = 0): + self.next_time[key] = time.time() + ( + cd_time if cd_time > 0 else self.default_cd + ) + + def left_time(self, key: Any) -> float: + return max(0.0, self.next_time[key] - time.time()) + + +class CountLimiter: + """ + 每日调用命令次数限制 + """ + + tz = None + + def __init__(self, max_num: int): + self.today = -1 + self.count: dict[Any, int] = defaultdict(int) + self.max = max_num + + def check(self, key: Any) -> bool: + import datetime + + day = datetime.datetime.now().day + if day != self.today: + self.today = day + self.count.clear() + return self.count[key] < self.max + + def get_num(self, key: Any) -> int: + return self.count[key] + + def increase(self, key: Any, num: int = 1): + self.count[key] += num + + def reset(self, key: Any): + self.count[key] = 0 + + +class UserBlockLimiter: + """ + 检测用户是否正在调用命令 (简单阻塞锁) + """ + + def __init__(self): + self.flag_data: dict[Any, bool] = defaultdict(bool) + self.time: dict[Any, float] = defaultdict(float) + + def set_true(self, key: Any): + self.time[key] = time.time() + self.flag_data[key] = True + + def set_false(self, key: Any): + self.flag_data[key] = False + + def check(self, key: Any) -> bool: + if self.flag_data[key] and time.time() - self.time[key] > 30: + self.set_false(key) + return not self.flag_data[key] + + +class RateLimiter: + """ + 一个简单的基于时间窗口的速率限制器。 + """ + + def __init__(self, max_calls: int, time_window: int): + self.requests: dict[Any, deque[float]] = defaultdict(deque) + self.max_calls = max_calls + self.time_window = time_window + + def check(self, key: Any) -> bool: + """检查是否超出速率限制。如果未超出,则记录本次调用。""" + now = time.time() + + while self.requests[key] and self.requests[key][0] <= now - self.time_window: + self.requests[key].popleft() + + if len(self.requests[key]) < self.max_calls: + self.requests[key].append(now) + return True + return False + + def left_time(self, key: Any) -> float: + """计算距离下次可调用还需等待的时间""" + if self.requests[key]: + return max(0.0, self.requests[key][0] + self.time_window - time.time()) + return 0.0 + + +class ConcurrencyLimiter: + """ + 一个基于 asyncio.Semaphore 的并发限制器。 + """ + + def __init__(self, max_concurrent: int): + self._semaphores: dict[Any, asyncio.Semaphore] = {} + self.max_concurrent = max_concurrent + self._active_tasks: dict[Any, int] = defaultdict(int) + + def _get_semaphore(self, key: Any) -> asyncio.Semaphore: + if key not in self._semaphores: + self._semaphores[key] = asyncio.Semaphore(self.max_concurrent) + return self._semaphores[key] + + async def acquire(self, key: Any): + """获取一个信号量,如果达到并发上限则会阻塞等待。""" + semaphore = self._get_semaphore(key) + await semaphore.acquire() + self._active_tasks[key] += 1 + + def release(self, key: Any): + """释放一个信号量。""" + if key in self._semaphores: + if self._active_tasks[key] > 0: + self._semaphores[key].release() + self._active_tasks[key] -= 1 + else: + import logging + + logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。") diff --git a/zhenxun/utils/message.py b/zhenxun/utils/message.py index 927b050c..5fec2213 100644 --- a/zhenxun/utils/message.py +++ b/zhenxun/utils/message.py @@ -49,11 +49,14 @@ class Config(BaseModel): class MessageUtils: @classmethod - def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]: + def __build_message( + cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None + ) -> list[Text | Image]: """构造消息 参数: msg_list: 消息列表 + format_args: 用于格式化字符串的参数字典. 返回: list[Text | Text]: 构造完成的消息列表 @@ -65,7 +68,15 @@ class MessageUtils: if msg.startswith("base64://"): message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:])))) else: - message_list.append(Text(msg)) + formatted_msg = msg + if format_args: + try: + formatted_msg = msg.format_map(format_args) + except (KeyError, IndexError) as e: + logger.debug( + f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。" + ) + message_list.append(Text(formatted_msg)) elif isinstance(msg, int | float): message_list.append(Text(str(msg))) elif isinstance(msg, Path): @@ -90,12 +101,15 @@ class MessageUtils: @classmethod def build_message( - cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]] + cls, + msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]], + format_args: dict | None = None, ) -> UniMessage: """构造消息 参数: msg_list: 消息列表 + format_args: 用于格式化字符串的参数字典. 返回: UniMessage: 构造完成的消息列表 @@ -105,7 +119,7 @@ class MessageUtils: msg_list = [msg_list] for m in msg_list: _data = m if isinstance(m, list) else [m] - message_list += cls.__build_message(_data) # type: ignore + message_list += cls.__build_message(_data, format_args) return UniMessage(message_list) @classmethod diff --git a/zhenxun/utils/time_utils.py b/zhenxun/utils/time_utils.py new file mode 100644 index 00000000..f478625d --- /dev/null +++ b/zhenxun/utils/time_utils.py @@ -0,0 +1,91 @@ +from datetime import date, datetime +import re + +import pytz + + +class TimeUtils: + DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai") + + @classmethod + def get_day_start(cls, target_date: date | datetime | None = None) -> datetime: + """获取某天的0点时间 + + 返回: + datetime: 今天某天的0点时间 + """ + if not target_date: + target_date = datetime.now(cls.DEFAULT_TIMEZONE) + + if isinstance(target_date, datetime) and target_date.tzinfo is None: + target_date = cls.DEFAULT_TIMEZONE.localize(target_date) + + return ( + target_date.replace(hour=0, minute=0, second=0, microsecond=0) + if isinstance(target_date, datetime) + else datetime.combine( + target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE + ) + ) + + @classmethod + def is_valid_date(cls, date_text: str, separator: str = "-") -> bool: + """日期是否合法 + + 参数: + date_text: 日期 + separator: 分隔符 + + 返回: + bool: 日期是否合法 + """ + try: + datetime.strptime(date_text, f"%Y{separator}%m{separator}%d") + return True + except ValueError: + return False + + @classmethod + def parse_time_string(cls, time_str: str) -> int: + """ + 将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数。 + """ + time_str = time_str.lower().strip() + match = re.match(r"^(\d+)([smh])$", time_str) + if not match: + raise ValueError( + f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。" + ) + + value, unit = int(match.group(1)), match.group(2) + + if unit == "s": + return value + if unit == "m": + return value * 60 + if unit == "h": + return value * 3600 + return 0 + + @classmethod + def format_duration(cls, seconds: float) -> str: + """ + 将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒") + """ + seconds = round(seconds, 1) + if seconds < 0.1: + return "不到1秒" + if seconds < 60: + return f"{seconds}秒" + + minutes, sec_remainder = divmod(int(seconds), 60) + + if minutes < 60: + if sec_remainder == 0: + return f"{minutes}分钟" + return f"{minutes}分钟{sec_remainder}秒" + + hours, rem_minutes = divmod(minutes, 60) + if rem_minutes == 0: + return f"{hours}小时" + return f"{hours}小时{rem_minutes}分钟" diff --git a/zhenxun/utils/utils.py b/zhenxun/utils/utils.py index 44dcd672..fc6b4096 100644 --- a/zhenxun/utils/utils.py +++ b/zhenxun/utils/utils.py @@ -1,19 +1,19 @@ -from collections import defaultdict from dataclasses import dataclass -from datetime import date, datetime +from datetime import datetime import os from pathlib import Path import time -from typing import Any, ClassVar +from typing import ClassVar import httpx from nonebot_plugin_uninfo import Uninfo import pypinyin -import pytz from zhenxun.configs.config import Config from zhenxun.services.log import logger +from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401 + @dataclass class EntityIDs: @@ -64,78 +64,6 @@ class ResourceDirManager: cls.__tree_append(path, deep) -class CountLimiter: - """ - 每日调用命令次数限制 - """ - - tz = pytz.timezone("Asia/Shanghai") - - def __init__(self, max_num): - self.today = -1 - self.count = defaultdict(int) - self.max = max_num - - def check(self, key) -> bool: - day = datetime.now(self.tz).day - if day != self.today: - self.today = day - self.count.clear() - return self.count[key] < self.max - - def get_num(self, key): - return self.count[key] - - def increase(self, key, num=1): - self.count[key] += num - - def reset(self, key): - self.count[key] = 0 - - -class UserBlockLimiter: - """ - 检测用户是否正在调用命令 - """ - - def __init__(self): - self.flag_data = defaultdict(bool) - self.time = time.time() - - def set_true(self, key: Any): - self.time = time.time() - self.flag_data[key] = True - - def set_false(self, key: Any): - self.flag_data[key] = False - - def check(self, key: Any) -> bool: - if time.time() - self.time > 30: - self.set_false(key) - return not self.flag_data[key] - - -class FreqLimiter: - """ - 命令冷却,检测用户是否处于冷却状态 - """ - - def __init__(self, default_cd_seconds: int): - self.next_time = defaultdict(float) - self.default_cd = default_cd_seconds - - def check(self, key: Any) -> bool: - return time.time() >= self.next_time[key] - - def start_cd(self, key: Any, cd_time: int = 0): - self.next_time[key] = time.time() + ( - cd_time if cd_time > 0 else self.default_cd - ) - - def left_time(self, key: Any) -> float: - return self.next_time[key] - time.time() - - def cn2py(word: str) -> str: """将字符串转化为拼音 @@ -277,20 +205,3 @@ def is_number(text: str) -> bool: return True except ValueError: return False - - -class TimeUtils: - @classmethod - def get_day_start(cls, target_date: date | datetime | None = None) -> datetime: - """获取某天的0点时间 - - 返回: - datetime: 今天某天的0点时间 - """ - if not target_date: - target_date = datetime.now() - return ( - target_date.replace(hour=0, minute=0, second=0, microsecond=0) - if isinstance(target_date, datetime) - else datetime.combine(target_date, datetime.min.time()) - )