From d218c569d46760e630cef14a5cbf12421b612891 Mon Sep 17 00:00:00 2001
From: HibiKier <45528451+HibiKier@users.noreply.github.com>
Date: Tue, 15 Jul 2025 17:08:42 +0800
Subject: [PATCH 1/2] =?UTF-8?q?:sparkles:=20=E6=A0=BC=E5=BC=8F=E5=8C=96db?=
=?UTF-8?q?=5Fcontext=20(#1980)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* :sparkles: 格式化db_context
* :fire: 移除旧db-context
* :zap: 添加旧版本兼容
---
zhenxun/builtin_plugins/__init__.py | 11 +-
zhenxun/services/__init__.py | 3 +-
zhenxun/services/cache/__init__.py | 59 +++---
zhenxun/services/cache/cache_containers.py | 21 +-
zhenxun/services/db_context/__init__.py | 146 ++++++++++++++
.../base_model.py} | 187 +-----------------
zhenxun/services/db_context/config.py | 46 +++++
zhenxun/services/db_context/exceptions.py | 14 ++
zhenxun/services/db_context/utils.py | 27 +++
zhenxun/services/plugin_init.py | 27 +--
10 files changed, 293 insertions(+), 248 deletions(-)
create mode 100644 zhenxun/services/db_context/__init__.py
rename zhenxun/services/{db_context.py => db_context/base_model.py} (64%)
create mode 100644 zhenxun/services/db_context/config.py
create mode 100644 zhenxun/services/db_context/exceptions.py
create mode 100644 zhenxun/services/db_context/utils.py
diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py
index 4003e506..a5aa7a4b 100644
--- a/zhenxun/builtin_plugins/__init__.py
+++ b/zhenxun/builtin_plugins/__init__.py
@@ -5,7 +5,7 @@ import nonebot
from nonebot.adapters import Bot
from nonebot.drivers import Driver
from tortoise import Tortoise
-from tortoise.exceptions import OperationalError
+from tortoise.exceptions import IntegrityError, OperationalError
import ujson as json
from zhenxun.models.bot_connect_log import BotConnectLog
@@ -30,9 +30,12 @@ async def _(bot: Bot):
bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
)
if not await BotConsole.exists(bot_id=bot.self_id):
- await BotConsole.create(
- bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
- )
+ try:
+ await BotConsole.create(
+ bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
+ )
+ except IntegrityError as e:
+ logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e)
@driver.on_bot_disconnect
diff --git a/zhenxun/services/__init__.py b/zhenxun/services/__init__.py
index 4c820b87..b3dc292e 100644
--- a/zhenxun/services/__init__.py
+++ b/zhenxun/services/__init__.py
@@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender")
require("nonebot_plugin_uninfo")
require("nonebot_plugin_waiter")
-from .db_context import Model, disconnect
+from .db_context import Model, disconnect, with_db_timeout
from .llm import (
AI,
AIConfig,
@@ -80,4 +80,5 @@ __all__ = [
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
+ "with_db_timeout",
]
diff --git a/zhenxun/services/cache/__init__.py b/zhenxun/services/cache/__init__.py
index 76b05a5c..e2ce9c30 100644
--- a/zhenxun/services/cache/__init__.py
+++ b/zhenxun/services/cache/__init__.py
@@ -63,6 +63,7 @@ from functools import wraps
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
from aiocache import Cache as AioCache
+from aiocache import SimpleMemoryCache
from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer
import nonebot
@@ -392,22 +393,14 @@ class CacheManager:
def cache_backend(self) -> BaseCache | AioCache:
"""获取缓存后端"""
if self._cache_backend is None:
- try:
- from aiocache import RedisCache, SimpleMemoryCache
+ ttl = cache_config.redis_expire
+ if cache_config.cache_mode == CacheMode.NONE:
+ ttl = 0
+ logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
+ elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host:
+ try:
+ from aiocache import RedisCache
- if cache_config.cache_mode == CacheMode.NONE:
- # 使用内存缓存但禁用持久化
- self._cache_backend = SimpleMemoryCache(
- serializer=JsonSerializer(),
- namespace=CACHE_KEY_PREFIX,
- timeout=30,
- ttl=0, # 设置为0,不缓存
- )
- logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
- elif (
- cache_config.cache_mode == CacheMode.REDIS
- and cache_config.redis_host
- ):
# 使用Redis缓存
self._cache_backend = RedisCache(
serializer=JsonSerializer(),
@@ -419,27 +412,25 @@ class CacheManager:
password=cache_config.redis_password,
)
logger.info(
- f"使用Redis缓存,地址: {cache_config.redis_host}", LOG_COMMAND
+ f"使用Redis缓存,地址: {cache_config.redis_host}",
+ LOG_COMMAND,
)
- else:
- # 默认使用内存缓存
- self._cache_backend = SimpleMemoryCache(
- serializer=JsonSerializer(),
- namespace=CACHE_KEY_PREFIX,
- timeout=30,
- ttl=cache_config.redis_expire,
+ return self._cache_backend
+ except ImportError as e:
+ logger.error(
+ "导入aiocache[redis]失败,将默认使用内存缓存...",
+ LOG_COMMAND,
+ e=e,
)
- logger.info("使用内存缓存", LOG_COMMAND)
- except ImportError:
- logger.error("导入aiocache模块失败,使用内存缓存", LOG_COMMAND)
- # 使用内存缓存
- self._cache_backend = AioCache(
- cache_class=AioCache.MEMORY,
- serializer=JsonSerializer(),
- namespace=CACHE_KEY_PREFIX,
- timeout=30,
- ttl=cache_config.redis_expire,
- )
+ else:
+ logger.info("使用内存缓存", LOG_COMMAND)
+ # 默认使用内存缓存
+ self._cache_backend = SimpleMemoryCache(
+ serializer=JsonSerializer(),
+ namespace=CACHE_KEY_PREFIX,
+ timeout=30,
+ ttl=ttl,
+ )
return self._cache_backend
@property
diff --git a/zhenxun/services/cache/cache_containers.py b/zhenxun/services/cache/cache_containers.py
index b0efe3fb..42b35239 100644
--- a/zhenxun/services/cache/cache_containers.py
+++ b/zhenxun/services/cache/cache_containers.py
@@ -54,11 +54,7 @@ class CacheDict:
key: 字典键
value: 字典值
"""
- # 计算过期时间
- expire_time = 0
- if self.expire > 0:
- expire_time = time.time() + self.expire
-
+ expire_time = time.time() + self.expire if self.expire > 0 else 0
self._data[key] = CacheData(value=value, expire_time=expire_time)
def __delitem__(self, key: str) -> None:
@@ -274,12 +270,11 @@ class CacheList:
self.clear()
raise IndexError(f"列表索引 {index} 超出范围")
- if 0 <= index < len(self._data):
- del self._data[index]
- # 更新过期时间
- self._update_expire_time()
- else:
+ if not 0 <= index < len(self._data):
raise IndexError(f"列表索引 {index} 超出范围")
+ del self._data[index]
+ # 更新过期时间
+ self._update_expire_time()
def __len__(self) -> int:
"""获取列表长度
@@ -427,6 +422,7 @@ class CacheList:
self.clear()
return 0
+ # sourcery skip: simplify-constant-sum
return sum(1 for item in self._data if item.value == value)
def _is_expired(self) -> bool:
@@ -435,10 +431,7 @@ class CacheList:
def _update_expire_time(self) -> None:
"""更新过期时间"""
- if self.expire > 0:
- self._expire_time = time.time() + self.expire
- else:
- self._expire_time = 0
+ self._expire_time = time.time() + self.expire if self.expire > 0 else 0
def __str__(self) -> str:
"""字符串表示
diff --git a/zhenxun/services/db_context/__init__.py b/zhenxun/services/db_context/__init__.py
new file mode 100644
index 00000000..26fd9bcd
--- /dev/null
+++ b/zhenxun/services/db_context/__init__.py
@@ -0,0 +1,146 @@
+import asyncio
+from urllib.parse import urlparse
+
+import nonebot
+from nonebot.utils import is_coroutine_callable
+from tortoise import Tortoise
+from tortoise.connection import connections
+
+from zhenxun.configs.config import BotConfig
+from zhenxun.services.log import logger
+from zhenxun.utils.manager.priority_manager import PriorityLifecycle
+
+from .base_model import Model
+from .config import (
+ DB_TIMEOUT_SECONDS,
+ MYSQL_CONFIG,
+ POSTGRESQL_CONFIG,
+ SLOW_QUERY_THRESHOLD,
+ SQLITE_CONFIG,
+ db_model,
+ prompt,
+)
+from .exceptions import DbConnectError, DbUrlIsNode
+from .utils import with_db_timeout
+
+MODELS = db_model.models
+SCRIPT_METHOD = db_model.script_method
+
+__all__ = [
+ "DB_TIMEOUT_SECONDS",
+ "MODELS",
+ "SCRIPT_METHOD",
+ "SLOW_QUERY_THRESHOLD",
+ "DbConnectError",
+ "DbUrlIsNode",
+ "Model",
+ "disconnect",
+ "init",
+ "with_db_timeout",
+]
+
+driver = nonebot.get_driver()
+
+
+def get_config() -> dict:
+ """获取数据库配置"""
+ parsed = urlparse(BotConfig.db_url)
+
+ # 基础配置
+ config = {
+ "connections": {
+ "default": BotConfig.db_url # 默认直接使用连接字符串
+ },
+ "apps": {
+ "models": {
+ "models": db_model.models,
+ "default_connection": "default",
+ }
+ },
+ "timezone": "Asia/Shanghai",
+ }
+
+ # 根据数据库类型应用高级配置
+ if parsed.scheme.startswith("postgres"):
+ config["connections"]["default"] = {
+ "engine": "tortoise.backends.asyncpg",
+ "credentials": {
+ "host": parsed.hostname,
+ "port": parsed.port or 5432,
+ "user": parsed.username,
+ "password": parsed.password,
+ "database": parsed.path[1:],
+ },
+ **POSTGRESQL_CONFIG,
+ }
+ elif parsed.scheme == "mysql":
+ config["connections"]["default"] = {
+ "engine": "tortoise.backends.mysql",
+ "credentials": {
+ "host": parsed.hostname,
+ "port": parsed.port or 3306,
+ "user": parsed.username,
+ "password": parsed.password,
+ "database": parsed.path[1:],
+ },
+ **MYSQL_CONFIG,
+ }
+ elif parsed.scheme == "sqlite":
+ config["connections"]["default"] = {
+ "engine": "tortoise.backends.sqlite",
+ "credentials": {
+ "file_path": parsed.path or ":memory:",
+ },
+ **SQLITE_CONFIG,
+ }
+ return config
+
+
+@PriorityLifecycle.on_startup(priority=1)
+async def init():
+ global MODELS, SCRIPT_METHOD
+
+ MODELS = db_model.models
+ SCRIPT_METHOD = db_model.script_method
+ if not BotConfig.db_url:
+ error = prompt.format(host=driver.config.host, port=driver.config.port)
+ raise DbUrlIsNode("\n" + error.strip())
+ try:
+ await Tortoise.init(
+ config=get_config(),
+ )
+ if db_model.script_method:
+ db = Tortoise.get_connection("default")
+ logger.debug(
+ "即将运行SCRIPT_METHOD方法, 合计 "
+ f"{len(db_model.script_method)} 个..."
+ )
+ sql_list = []
+ for module, func in db_model.script_method:
+ try:
+ sql = await func() if is_coroutine_callable(func) else func()
+ if sql:
+ sql_list += sql
+ except Exception as e:
+ logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
+ for sql in sql_list:
+ logger.debug(f"执行SQL: {sql}")
+ try:
+ await asyncio.wait_for(
+ db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
+ )
+ # await TestSQL.raw(sql)
+ except Exception as e:
+ logger.debug(f"执行SQL: {sql} 错误...", e=e)
+ if sql_list:
+ logger.debug("SCRIPT_METHOD方法执行完毕!")
+ logger.debug("开始生成数据库表结构...")
+ await Tortoise.generate_schemas()
+ logger.debug("数据库表结构生成完毕!")
+ logger.info("Database loaded successfully!")
+ except Exception as e:
+ raise DbConnectError(f"数据库连接错误... e:{e}") from e
+
+
+async def disconnect():
+ await connections.close_all()
diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context/base_model.py
similarity index 64%
rename from zhenxun/services/db_context.py
rename to zhenxun/services/db_context/base_model.py
index e6c42472..3e0e23ef 100644
--- a/zhenxun/services/db_context.py
+++ b/zhenxun/services/db_context/base_model.py
@@ -1,56 +1,20 @@
import asyncio
from collections.abc import Iterable
import contextlib
-import time
from typing import Any, ClassVar
from typing_extensions import Self
-from urllib.parse import urlparse
-from nonebot import get_driver
-from nonebot.utils import is_coroutine_callable
-from tortoise import Tortoise
from tortoise.backends.base.client import BaseDBAsyncClient
-from tortoise.connection import connections
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction
-from zhenxun.configs.config import BotConfig
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType
-from zhenxun.utils.exception import HookPriorityException
-from zhenxun.utils.manager.priority_manager import PriorityLifecycle
-driver = get_driver()
-
-SCRIPT_METHOD = []
-MODELS: list[str] = []
-
-# 数据库操作超时设置(秒)
-DB_TIMEOUT_SECONDS = 3.0
-
-# 性能监控阈值(秒)
-SLOW_QUERY_THRESHOLD = 0.5
-
-LOG_COMMAND = "DbContext"
-
-
-async def with_db_timeout(
- coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
-):
- """带超时控制的数据库操作"""
- start_time = time.time()
- try:
- result = await asyncio.wait_for(coro, timeout=timeout)
- elapsed = time.time() - start_time
- if elapsed > SLOW_QUERY_THRESHOLD and operation:
- logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
- return result
- except asyncio.TimeoutError:
- if operation:
- logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
- raise
+from .config import LOG_COMMAND, db_model
+from .utils import with_db_timeout
class Model(TortoiseModel):
@@ -63,11 +27,11 @@ class Model(TortoiseModel):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
- if cls.__module__ not in MODELS:
- MODELS.append(cls.__module__)
+ if cls.__module__ not in db_model.models:
+ db_model.models.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
- SCRIPT_METHOD.append((cls.__module__, func))
+ db_model.script_method.append((cls.__module__, func))
@classmethod
def get_cache_type(cls) -> str | None:
@@ -322,144 +286,3 @@ class Model(TortoiseModel):
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
)
raise
-
-
-class DbUrlIsNode(HookPriorityException):
- """
- 数据库链接地址为空
- """
-
- pass
-
-
-class DbConnectError(Exception):
- """
- 数据库连接错误
- """
-
- pass
-
-
-POSTGRESQL_CONFIG = {
- "max_size": 30, # 最大连接数
- "min_size": 5, # 最小保持的连接数(可选)
-}
-
-
-MYSQL_CONFIG = {
- "max_connections": 20, # 最大连接数
- "connect_timeout": 30, # 连接超时(可选)
-}
-
-SQLITE_CONFIG = {
- "journal_mode": "WAL", # 提高并发写入性能
- "timeout": 30, # 锁等待超时(可选)
-}
-
-
-def get_config() -> dict:
- """获取数据库配置"""
- parsed = urlparse(BotConfig.db_url)
-
- # 基础配置
- config = {
- "connections": {
- "default": BotConfig.db_url # 默认直接使用连接字符串
- },
- "apps": {
- "models": {
- "models": MODELS,
- "default_connection": "default",
- }
- },
- "timezone": "Asia/Shanghai",
- }
-
- # 根据数据库类型应用高级配置
- if parsed.scheme.startswith("postgres"):
- config["connections"]["default"] = {
- "engine": "tortoise.backends.asyncpg",
- "credentials": {
- "host": parsed.hostname,
- "port": parsed.port or 5432,
- "user": parsed.username,
- "password": parsed.password,
- "database": parsed.path[1:],
- },
- **POSTGRESQL_CONFIG,
- }
- elif parsed.scheme == "mysql":
- config["connections"]["default"] = {
- "engine": "tortoise.backends.mysql",
- "credentials": {
- "host": parsed.hostname,
- "port": parsed.port or 3306,
- "user": parsed.username,
- "password": parsed.password,
- "database": parsed.path[1:],
- },
- **MYSQL_CONFIG,
- }
- elif parsed.scheme == "sqlite":
- config["connections"]["default"] = {
- "engine": "tortoise.backends.sqlite",
- "credentials": {
- "file_path": parsed.path or ":memory:",
- },
- **SQLITE_CONFIG,
- }
- return config
-
-
-@PriorityLifecycle.on_startup(priority=1)
-async def init():
- if not BotConfig.db_url:
- # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
- error = f"""
-**********************************************************************
-🌟 **************************** 配置为空 ************************* 🌟
-🚀 请打开 WebUi 进行基础配置 🚀
-🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐
-***********************************************************************
-***********************************************************************
- """
- raise DbUrlIsNode("\n" + error.strip())
- try:
- await Tortoise.init(
- config=get_config(),
- )
- if SCRIPT_METHOD:
- db = Tortoise.get_connection("default")
- logger.debug(
- "即将运行SCRIPT_METHOD方法, 合计 "
- f"{len(SCRIPT_METHOD)} 个..."
- )
- sql_list = []
- for module, func in SCRIPT_METHOD:
- try:
- sql = await func() if is_coroutine_callable(func) else func()
- if sql:
- sql_list += sql
- except Exception as e:
- logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
- for sql in sql_list:
- logger.debug(f"执行SQL: {sql}")
- try:
- await asyncio.wait_for(
- db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
- )
- # await TestSQL.raw(sql)
- except Exception as e:
- logger.debug(f"执行SQL: {sql} 错误...", e=e)
- if sql_list:
- logger.debug("SCRIPT_METHOD方法执行完毕!")
- logger.debug("开始生成数据库表结构...")
- await Tortoise.generate_schemas()
- logger.debug("数据库表结构生成完毕!")
- logger.info("Database loaded successfully!")
- except Exception as e:
- raise DbConnectError(f"数据库连接错误... e:{e}") from e
-
-
-async def disconnect():
- await connections.close_all()
diff --git a/zhenxun/services/db_context/config.py b/zhenxun/services/db_context/config.py
new file mode 100644
index 00000000..ae6d6b8c
--- /dev/null
+++ b/zhenxun/services/db_context/config.py
@@ -0,0 +1,46 @@
+from collections.abc import Callable
+
+from pydantic import BaseModel
+
+# 数据库操作超时设置(秒)
+DB_TIMEOUT_SECONDS = 3.0
+
+# 性能监控阈值(秒)
+SLOW_QUERY_THRESHOLD = 0.5
+
+LOG_COMMAND = "DbContext"
+
+
+POSTGRESQL_CONFIG = {
+ "max_size": 30, # 最大连接数
+ "min_size": 5, # 最小保持的连接数(可选)
+}
+
+
+MYSQL_CONFIG = {
+ "max_connections": 20, # 最大连接数
+ "connect_timeout": 30, # 连接超时(可选)
+}
+
+SQLITE_CONFIG = {
+ "journal_mode": "WAL", # 提高并发写入性能
+ "timeout": 30, # 锁等待超时(可选)
+}
+
+
+class DbModel(BaseModel):
+ script_method: list[tuple[str, Callable]] = []
+ models: list[str] = []
+
+
+db_model = DbModel()
+
+
+prompt = """
+**********************************************************************
+🌟 **************************** 配置为空 ************************* 🌟
+🚀 请打开 WebUi 进行基础配置 🚀
+🌐 配置地址:http://{host}:{port}/#/configure 🌐
+***********************************************************************
+***********************************************************************
+"""
diff --git a/zhenxun/services/db_context/exceptions.py b/zhenxun/services/db_context/exceptions.py
new file mode 100644
index 00000000..163f92e2
--- /dev/null
+++ b/zhenxun/services/db_context/exceptions.py
@@ -0,0 +1,14 @@
+class DbUrlIsNode(Exception):
+ """
+ 数据库链接地址为空
+ """
+
+ pass
+
+
+class DbConnectError(Exception):
+ """
+ 数据库连接错误
+ """
+
+ pass
diff --git a/zhenxun/services/db_context/utils.py b/zhenxun/services/db_context/utils.py
new file mode 100644
index 00000000..47db548f
--- /dev/null
+++ b/zhenxun/services/db_context/utils.py
@@ -0,0 +1,27 @@
+import asyncio
+import time
+
+from zhenxun.services.log import logger
+
+from .config import (
+ DB_TIMEOUT_SECONDS,
+ LOG_COMMAND,
+ SLOW_QUERY_THRESHOLD,
+)
+
+
+async def with_db_timeout(
+ coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
+):
+ """带超时控制的数据库操作"""
+ start_time = time.time()
+ try:
+ result = await asyncio.wait_for(coro, timeout=timeout)
+ elapsed = time.time() - start_time
+ if elapsed > SLOW_QUERY_THRESHOLD and operation:
+ logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
+ return result
+ except asyncio.TimeoutError:
+ if operation:
+ logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
+ raise
diff --git a/zhenxun/services/plugin_init.py b/zhenxun/services/plugin_init.py
index a622a9e8..a7b8685a 100644
--- a/zhenxun/services/plugin_init.py
+++ b/zhenxun/services/plugin_init.py
@@ -54,19 +54,20 @@ class PluginInitManager:
@classmethod
async def install_all(cls):
"""运行所有插件安装方法"""
- if cls.plugins:
- for module_path, model in cls.plugins.items():
- if model.install:
- class_ = model.class_()
- try:
- logger.debug(f"开始执行: {module_path}:install 方法")
- if is_coroutine_callable(class_.install):
- await class_.install()
- else:
- class_.install() # type: ignore
- logger.debug(f"执行: {module_path}:install 完成")
- except Exception as e:
- logger.error(f"执行: {module_path}:install 失败", e=e)
+ if not cls.plugins:
+ return
+ for module_path, model in cls.plugins.items():
+ if model.install:
+ class_ = model.class_()
+ try:
+ logger.debug(f"开始执行: {module_path}:install 方法")
+ if is_coroutine_callable(class_.install):
+ await class_.install()
+ else:
+ class_.install() # type: ignore
+ logger.debug(f"执行: {module_path}:install 完成")
+ except Exception as e:
+ logger.error(f"执行: {module_path}:install 失败", e=e)
@classmethod
async def install(cls, module_path: str):
From b993450a234e43f0518ab1dfbabe6af95dace8ec Mon Sep 17 00:00:00 2001
From: Rumio <32546670+webjoin111@users.noreply.github.com>
Date: Tue, 15 Jul 2025 17:13:33 +0800
Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20feat(limit,=20message):=20?=
=?UTF-8?q?=E5=BC=95=E5=85=A5=E5=A3=B0=E6=98=8E=E5=BC=8F=E9=99=90=E6=B5=81?=
=?UTF-8?q?=E7=B3=BB=E7=BB=9F=E5=B9=B6=E5=A2=9E=E5=BC=BA=E6=B6=88=E6=81=AF?=
=?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E5=8A=9F=E8=83=BD=20(#1978)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 新增 Cooldown、RateLimit、ConcurrencyLimit 三种限流依赖
- MessageUtils 支持动态格式化字符串 (format_args 参数)
- 插件CD限制消息显示精确剩余时间
- 重构限流逻辑至 utils/limiters.py,新增时间工具模块
- 整合时间工具函数并优化时区处理
- 新增 limiter_hook 自动释放资源,CooldownError 优化异常处理
- 冷却提示从固定文本改为动态显示剩余时间
- 示例:总结功能冷却中,请等待 1分30秒 后再试~
Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
---
.../builtin_plugins/hooks/auth/auth_limit.py | 18 +-
zhenxun/builtin_plugins/hooks/limiter_hook.py | 15 ++
.../statistics/_data_source.py | 2 +-
zhenxun/utils/depends/__init__.py | 171 +++++++++++++++++-
zhenxun/utils/exception.py | 14 ++
zhenxun/utils/limiters.py | 140 ++++++++++++++
zhenxun/utils/message.py | 22 ++-
zhenxun/utils/time_utils.py | 91 ++++++++++
zhenxun/utils/utils.py | 97 +---------
9 files changed, 463 insertions(+), 107 deletions(-)
create mode 100644 zhenxun/builtin_plugins/hooks/limiter_hook.py
create mode 100644 zhenxun/utils/limiters.py
create mode 100644 zhenxun/utils/time_utils.py
diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py
index d199ff0d..80650472 100644
--- a/zhenxun/builtin_plugins/hooks/auth/auth_limit.py
+++ b/zhenxun/builtin_plugins/hooks/auth/auth_limit.py
@@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
from zhenxun.services.log import logger
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
+from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.message import MessageUtils
-from zhenxun.utils.utils import (
- CountLimiter,
- FreqLimiter,
- UserBlockLimiter,
- get_entity_ids,
-)
+from zhenxun.utils.time_utils import TimeUtils
+from zhenxun.utils.utils import get_entity_ids
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
from .exception import SkipPluginException
@@ -273,9 +270,16 @@ class LimitManager:
key_type = channel_id or group_id
if is_limit and not limiter.check(key_type):
if limit.result:
+ format_kwargs = {}
+ if isinstance(limiter, FreqLimiter):
+ left_time = limiter.left_time(key_type)
+ cd_str = TimeUtils.format_duration(left_time)
+ format_kwargs = {"cd": cd_str}
try:
await asyncio.wait_for(
- MessageUtils.build_message(limit.result).send(),
+ MessageUtils.build_message(
+ limit.result, format_args=format_kwargs
+ ).send(),
timeout=DB_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
diff --git a/zhenxun/builtin_plugins/hooks/limiter_hook.py b/zhenxun/builtin_plugins/hooks/limiter_hook.py
new file mode 100644
index 00000000..22680941
--- /dev/null
+++ b/zhenxun/builtin_plugins/hooks/limiter_hook.py
@@ -0,0 +1,15 @@
+from nonebot.matcher import Matcher
+from nonebot.message import run_postprocessor
+
+from zhenxun.utils.limiters import ConcurrencyLimiter
+
+
+@run_postprocessor
+async def _concurrency_release_hook(matcher: Matcher):
+ """
+ 后处理器:在事件处理结束后,释放并发限制的信号量。
+ """
+ if concurrency_info := matcher.state.get("_concurrency_limiter_info"):
+ limiter: ConcurrencyLimiter = concurrency_info["limiter"]
+ key = concurrency_info["key"]
+ limiter.release(key)
diff --git a/zhenxun/builtin_plugins/statistics/_data_source.py b/zhenxun/builtin_plugins/statistics/_data_source.py
index 81e2b035..ab426ae6 100644
--- a/zhenxun/builtin_plugins/statistics/_data_source.py
+++ b/zhenxun/builtin_plugins/statistics/_data_source.py
@@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
from zhenxun.utils.echart_utils.models import Barh
from zhenxun.utils.enum import PluginType
from zhenxun.utils.image_utils import BuildImage
-from zhenxun.utils.utils import TimeUtils
+from zhenxun.utils.time_utils import TimeUtils
class StatisticsManage:
diff --git a/zhenxun/utils/depends/__init__.py b/zhenxun/utils/depends/__init__.py
index 887813dd..e52dfd48 100644
--- a/zhenxun/utils/depends/__init__.py
+++ b/zhenxun/utils/depends/__init__.py
@@ -1,13 +1,181 @@
-from typing import Any
+from typing import Any, Literal
+from nonebot.adapters import Bot, Event
from nonebot.internal.params import Depends
from nonebot.matcher import Matcher
from nonebot.params import Command
+from nonebot.permission import SUPERUSER
from nonebot_plugin_session import EventSession
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import Config
+from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
from zhenxun.utils.message import MessageUtils
+from zhenxun.utils.time_utils import TimeUtils
+
+_coolers: dict[str, FreqLimiter] = {}
+_rate_limiters: dict[str, RateLimiter] = {}
+_concurrency_limiters: dict[str, ConcurrencyLimiter] = {}
+
+
+def _create_limiter_dependency(
+ limiter_class: type,
+ limiter_storage: dict,
+ limiter_init_args: dict[str, Any],
+ scope: Literal["user", "group", "global"],
+ prompt: str,
+ **kwargs,
+):
+ """
+ 一个高阶函数,用于创建不同类型的限制器依赖。
+
+ 参数:
+ limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.).
+ limiter_storage: 用于存储限制器实例的字典.
+ limiter_init_args: 限制器类的初始化参数.
+ scope: 限制作用域.
+ prompt: 触发限制时的提示信息.
+ **kwargs: 传递给特定限制器逻辑的额外参数.
+ """
+
+ async def dependency(
+ matcher: Matcher, session: EventSession, bot: Bot, event: Event
+ ) -> bool:
+ if await SUPERUSER(bot, event):
+ return True
+
+ handler_id = (
+ f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}"
+ )
+
+ key: str | None = None
+ if scope == "user":
+ key = session.id1
+ elif scope == "group":
+ key = session.id3 or session.id2 or session.id1
+ elif scope == "global":
+ key = f"global_{handler_id}"
+
+ if not key:
+ return True
+
+ if handler_id not in limiter_storage:
+ limiter_storage[handler_id] = limiter_class(**limiter_init_args)
+ limiter = limiter_storage[handler_id]
+
+ if isinstance(limiter, ConcurrencyLimiter):
+ await limiter.acquire(key)
+ matcher.state["_concurrency_limiter_info"] = {
+ "limiter": limiter,
+ "key": key,
+ }
+ return True
+ else:
+ if limiter.check(key):
+ if isinstance(limiter, FreqLimiter):
+ limiter.start_cd(
+ key, kwargs.get("duration_sec", limiter.default_cd)
+ )
+ return True
+ else:
+ left_time = limiter.left_time(key)
+ format_kwargs = {
+ "cd_str": TimeUtils.format_duration(left_time),
+ **(kwargs.get("prompt_format_kwargs", {})),
+ }
+ message = prompt.format(**format_kwargs)
+ await matcher.finish(message)
+
+ return Depends(dependency)
+
+
+def Cooldown(
+ duration: str,
+ *,
+ scope: Literal["user", "group", "global"] = "user",
+ prompt: str = "操作过于频繁,请等待 {cd_str}",
+) -> bool:
+ """声明式冷却检查依赖,限制用户操作频率
+
+ 参数:
+ duration: 冷却时间字符串 (e.g., "30s", "10m", "1h")
+ scope: 冷却作用域
+ prompt: 自定义的冷却提示消息,可使用 {cd_str} 占位符
+
+ 返回:
+ bool: 是否允许执行
+ """
+ try:
+ parsed_seconds = TimeUtils.parse_time_string(duration)
+ except ValueError as e:
+ raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}")
+
+ return _create_limiter_dependency(
+ limiter_class=FreqLimiter,
+ limiter_storage=_coolers,
+ limiter_init_args={"default_cd_seconds": parsed_seconds},
+ scope=scope,
+ prompt=prompt,
+ duration_sec=parsed_seconds,
+ )
+
+
+def RateLimit(
+ count: int,
+ duration: str,
+ *,
+ scope: Literal["user", "group", "global"] = "user",
+ prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}",
+) -> bool:
+ """声明式速率限制依赖,在指定时间窗口内限制操作次数
+
+ 参数:
+ count: 在时间窗口内允许的最大调用次数
+ duration: 时间窗口字符串 (e.g., "1m", "1h")
+ scope: 限制作用域
+ prompt: 自定义的提示消息,可使用 {cd_str}, {duration_str}, {limit} 占位符
+
+ 返回:
+ bool: 是否允许执行
+ """
+ try:
+ parsed_seconds = TimeUtils.parse_time_string(duration)
+ except ValueError as e:
+ raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}")
+
+ return _create_limiter_dependency(
+ limiter_class=RateLimiter,
+ limiter_storage=_rate_limiters,
+ limiter_init_args={"max_calls": count, "time_window": parsed_seconds},
+ scope=scope,
+ prompt=prompt,
+ prompt_format_kwargs={"duration_str": duration, "limit": count},
+ )
+
+
+def ConcurrencyLimit(
+ count: int,
+ *,
+ scope: Literal["user", "group", "global"] = "global",
+ prompt: str | None = "当前功能繁忙,请稍后再试...",
+) -> bool:
+ """声明式并发数限制依赖,控制某个功能同时执行的实例数量
+
+ 参数:
+ count: 最大并发数
+ scope: 限制作用域
+ prompt: 提示消息(暂未使用,主要用于未来扩展超时功能)
+
+ 返回:
+ bool: 是否允许执行
+ """
+ return _create_limiter_dependency(
+ limiter_class=ConcurrencyLimiter,
+ limiter_storage=_concurrency_limiters,
+ limiter_init_args={"max_concurrent": count},
+ scope=scope,
+ prompt=prompt or "",
+ )
def CheckUg(check_user: bool = True, check_group: bool = True):
@@ -75,7 +243,6 @@ def GetConfig(
if module_:
value = Config.get_config(module_, config, default_value)
if value is None and prompt:
- # await matcher.finish(prompt or f"配置项 {config} 未填写!")
await matcher.finish(prompt)
return value
diff --git a/zhenxun/utils/exception.py b/zhenxun/utils/exception.py
index 9ab664f4..8b3ec282 100644
--- a/zhenxun/utils/exception.py
+++ b/zhenxun/utils/exception.py
@@ -1,3 +1,17 @@
+from nonebot.exception import IgnoredException
+
+
+class CooldownError(IgnoredException):
+ """
+ 冷却异常,用于在冷却时中断事件处理。
+ 继承自 IgnoredException,不会在控制台留下错误堆栈。
+ """
+
+ def __init__(self, message: str):
+ self.message = message
+ super().__init__(message)
+
+
class HookPriorityException(BaseException):
"""
钩子优先级异常
diff --git a/zhenxun/utils/limiters.py b/zhenxun/utils/limiters.py
new file mode 100644
index 00000000..1bd5f662
--- /dev/null
+++ b/zhenxun/utils/limiters.py
@@ -0,0 +1,140 @@
+import asyncio
+from collections import defaultdict, deque
+import time
+from typing import Any
+
+
+class FreqLimiter:
+ """
+ 命令冷却,检测用户是否处于冷却状态
+ """
+
+ def __init__(self, default_cd_seconds: int):
+ self.next_time: dict[Any, float] = defaultdict(float)
+ self.default_cd = default_cd_seconds
+
+ def check(self, key: Any) -> bool:
+ return time.time() >= self.next_time[key]
+
+ def start_cd(self, key: Any, cd_time: int = 0):
+ self.next_time[key] = time.time() + (
+ cd_time if cd_time > 0 else self.default_cd
+ )
+
+ def left_time(self, key: Any) -> float:
+ return max(0.0, self.next_time[key] - time.time())
+
+
+class CountLimiter:
+ """
+ 每日调用命令次数限制
+ """
+
+ tz = None
+
+ def __init__(self, max_num: int):
+ self.today = -1
+ self.count: dict[Any, int] = defaultdict(int)
+ self.max = max_num
+
+ def check(self, key: Any) -> bool:
+ import datetime
+
+ day = datetime.datetime.now().day
+ if day != self.today:
+ self.today = day
+ self.count.clear()
+ return self.count[key] < self.max
+
+ def get_num(self, key: Any) -> int:
+ return self.count[key]
+
+ def increase(self, key: Any, num: int = 1):
+ self.count[key] += num
+
+ def reset(self, key: Any):
+ self.count[key] = 0
+
+
+class UserBlockLimiter:
+ """
+ 检测用户是否正在调用命令 (简单阻塞锁)
+ """
+
+ def __init__(self):
+ self.flag_data: dict[Any, bool] = defaultdict(bool)
+ self.time: dict[Any, float] = defaultdict(float)
+
+ def set_true(self, key: Any):
+ self.time[key] = time.time()
+ self.flag_data[key] = True
+
+ def set_false(self, key: Any):
+ self.flag_data[key] = False
+
+ def check(self, key: Any) -> bool:
+ if self.flag_data[key] and time.time() - self.time[key] > 30:
+ self.set_false(key)
+ return not self.flag_data[key]
+
+
+class RateLimiter:
+ """
+ 一个简单的基于时间窗口的速率限制器。
+ """
+
+ def __init__(self, max_calls: int, time_window: int):
+ self.requests: dict[Any, deque[float]] = defaultdict(deque)
+ self.max_calls = max_calls
+ self.time_window = time_window
+
+ def check(self, key: Any) -> bool:
+ """检查是否超出速率限制。如果未超出,则记录本次调用。"""
+ now = time.time()
+
+ while self.requests[key] and self.requests[key][0] <= now - self.time_window:
+ self.requests[key].popleft()
+
+ if len(self.requests[key]) < self.max_calls:
+ self.requests[key].append(now)
+ return True
+ return False
+
+ def left_time(self, key: Any) -> float:
+ """计算距离下次可调用还需等待的时间"""
+ if self.requests[key]:
+ return max(0.0, self.requests[key][0] + self.time_window - time.time())
+ return 0.0
+
+
+class ConcurrencyLimiter:
+ """
+ 一个基于 asyncio.Semaphore 的并发限制器。
+ """
+
+ def __init__(self, max_concurrent: int):
+ self._semaphores: dict[Any, asyncio.Semaphore] = {}
+ self.max_concurrent = max_concurrent
+ self._active_tasks: dict[Any, int] = defaultdict(int)
+
+ def _get_semaphore(self, key: Any) -> asyncio.Semaphore:
+ if key not in self._semaphores:
+ self._semaphores[key] = asyncio.Semaphore(self.max_concurrent)
+ return self._semaphores[key]
+
+ async def acquire(self, key: Any):
+ """获取一个信号量,如果达到并发上限则会阻塞等待。"""
+ semaphore = self._get_semaphore(key)
+ await semaphore.acquire()
+ self._active_tasks[key] += 1
+
+ def release(self, key: Any):
+ """释放一个信号量。"""
+ if key in self._semaphores:
+ if self._active_tasks[key] > 0:
+ self._semaphores[key].release()
+ self._active_tasks[key] -= 1
+ else:
+ import logging
+
+ logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。")
diff --git a/zhenxun/utils/message.py b/zhenxun/utils/message.py
index 927b050c..5fec2213 100644
--- a/zhenxun/utils/message.py
+++ b/zhenxun/utils/message.py
@@ -49,11 +49,14 @@ class Config(BaseModel):
class MessageUtils:
@classmethod
- def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]:
+ def __build_message(
+ cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None
+ ) -> list[Text | Image]:
"""构造消息
参数:
msg_list: 消息列表
+ format_args: 用于格式化字符串的参数字典.
返回:
list[Text | Text]: 构造完成的消息列表
@@ -65,7 +68,15 @@ class MessageUtils:
if msg.startswith("base64://"):
message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:]))))
else:
- message_list.append(Text(msg))
+ formatted_msg = msg
+ if format_args:
+ try:
+ formatted_msg = msg.format_map(format_args)
+ except (KeyError, IndexError) as e:
+ logger.debug(
+ f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。"
+ )
+ message_list.append(Text(formatted_msg))
elif isinstance(msg, int | float):
message_list.append(Text(str(msg)))
elif isinstance(msg, Path):
@@ -90,12 +101,15 @@ class MessageUtils:
@classmethod
def build_message(
- cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]]
+ cls,
+ msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]],
+ format_args: dict | None = None,
) -> UniMessage:
"""构造消息
参数:
msg_list: 消息列表
+ format_args: 用于格式化字符串的参数字典.
返回:
UniMessage: 构造完成的消息列表
@@ -105,7 +119,7 @@ class MessageUtils:
msg_list = [msg_list]
for m in msg_list:
_data = m if isinstance(m, list) else [m]
- message_list += cls.__build_message(_data) # type: ignore
+ message_list += cls.__build_message(_data, format_args)
return UniMessage(message_list)
@classmethod
diff --git a/zhenxun/utils/time_utils.py b/zhenxun/utils/time_utils.py
new file mode 100644
index 00000000..f478625d
--- /dev/null
+++ b/zhenxun/utils/time_utils.py
@@ -0,0 +1,91 @@
+from datetime import date, datetime
+import re
+
+import pytz
+
+
+class TimeUtils:
+ DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai")
+
+ @classmethod
+ def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
+ """获取某天的0点时间
+
+ 返回:
+ datetime: 今天某天的0点时间
+ """
+ if not target_date:
+ target_date = datetime.now(cls.DEFAULT_TIMEZONE)
+
+ if isinstance(target_date, datetime) and target_date.tzinfo is None:
+ target_date = cls.DEFAULT_TIMEZONE.localize(target_date)
+
+ return (
+ target_date.replace(hour=0, minute=0, second=0, microsecond=0)
+ if isinstance(target_date, datetime)
+ else datetime.combine(
+ target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE
+ )
+ )
+
+ @classmethod
+ def is_valid_date(cls, date_text: str, separator: str = "-") -> bool:
+ """日期是否合法
+
+ 参数:
+ date_text: 日期
+ separator: 分隔符
+
+ 返回:
+ bool: 日期是否合法
+ """
+ try:
+ datetime.strptime(date_text, f"%Y{separator}%m{separator}%d")
+ return True
+ except ValueError:
+ return False
+
+ @classmethod
+ def parse_time_string(cls, time_str: str) -> int:
+ """
+ 将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数。
+ """
+ time_str = time_str.lower().strip()
+ match = re.match(r"^(\d+)([smh])$", time_str)
+ if not match:
+ raise ValueError(
+ f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。"
+ )
+
+ value, unit = int(match.group(1)), match.group(2)
+
+ if unit == "s":
+ return value
+ if unit == "m":
+ return value * 60
+ if unit == "h":
+ return value * 3600
+ return 0
+
+ @classmethod
+ def format_duration(cls, seconds: float) -> str:
+ """
+ 将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒")
+ """
+ seconds = round(seconds, 1)
+ if seconds < 0.1:
+ return "不到1秒"
+ if seconds < 60:
+ return f"{seconds}秒"
+
+ minutes, sec_remainder = divmod(int(seconds), 60)
+
+ if minutes < 60:
+ if sec_remainder == 0:
+ return f"{minutes}分钟"
+ return f"{minutes}分钟{sec_remainder}秒"
+
+ hours, rem_minutes = divmod(minutes, 60)
+ if rem_minutes == 0:
+ return f"{hours}小时"
+ return f"{hours}小时{rem_minutes}分钟"
diff --git a/zhenxun/utils/utils.py b/zhenxun/utils/utils.py
index 44dcd672..fc6b4096 100644
--- a/zhenxun/utils/utils.py
+++ b/zhenxun/utils/utils.py
@@ -1,19 +1,19 @@
-from collections import defaultdict
from dataclasses import dataclass
-from datetime import date, datetime
+from datetime import datetime
import os
from pathlib import Path
import time
-from typing import Any, ClassVar
+from typing import ClassVar
import httpx
from nonebot_plugin_uninfo import Uninfo
import pypinyin
-import pytz
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
+from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401
+
@dataclass
class EntityIDs:
@@ -64,78 +64,6 @@ class ResourceDirManager:
cls.__tree_append(path, deep)
-class CountLimiter:
- """
- 每日调用命令次数限制
- """
-
- tz = pytz.timezone("Asia/Shanghai")
-
- def __init__(self, max_num):
- self.today = -1
- self.count = defaultdict(int)
- self.max = max_num
-
- def check(self, key) -> bool:
- day = datetime.now(self.tz).day
- if day != self.today:
- self.today = day
- self.count.clear()
- return self.count[key] < self.max
-
- def get_num(self, key):
- return self.count[key]
-
- def increase(self, key, num=1):
- self.count[key] += num
-
- def reset(self, key):
- self.count[key] = 0
-
-
-class UserBlockLimiter:
- """
- 检测用户是否正在调用命令
- """
-
- def __init__(self):
- self.flag_data = defaultdict(bool)
- self.time = time.time()
-
- def set_true(self, key: Any):
- self.time = time.time()
- self.flag_data[key] = True
-
- def set_false(self, key: Any):
- self.flag_data[key] = False
-
- def check(self, key: Any) -> bool:
- if time.time() - self.time > 30:
- self.set_false(key)
- return not self.flag_data[key]
-
-
-class FreqLimiter:
- """
- 命令冷却,检测用户是否处于冷却状态
- """
-
- def __init__(self, default_cd_seconds: int):
- self.next_time = defaultdict(float)
- self.default_cd = default_cd_seconds
-
- def check(self, key: Any) -> bool:
- return time.time() >= self.next_time[key]
-
- def start_cd(self, key: Any, cd_time: int = 0):
- self.next_time[key] = time.time() + (
- cd_time if cd_time > 0 else self.default_cd
- )
-
- def left_time(self, key: Any) -> float:
- return self.next_time[key] - time.time()
-
-
def cn2py(word: str) -> str:
"""将字符串转化为拼音
@@ -277,20 +205,3 @@ def is_number(text: str) -> bool:
return True
except ValueError:
return False
-
-
-class TimeUtils:
- @classmethod
- def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
- """获取某天的0点时间
-
- 返回:
- datetime: 今天某天的0点时间
- """
- if not target_date:
- target_date = datetime.now()
- return (
- target_date.replace(hour=0, minute=0, second=0, microsecond=0)
- if isinstance(target_date, datetime)
- else datetime.combine(target_date, datetime.min.time())
- )