mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
Merge branch 'main' into feature/bot-profile
This commit is contained in:
commit
e3e8f5b89b
@ -5,7 +5,7 @@ import nonebot
|
|||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
from nonebot.drivers import Driver
|
from nonebot.drivers import Driver
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
from tortoise.exceptions import OperationalError
|
from tortoise.exceptions import IntegrityError, OperationalError
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
|
||||||
from zhenxun.models.bot_connect_log import BotConnectLog
|
from zhenxun.models.bot_connect_log import BotConnectLog
|
||||||
@ -30,9 +30,12 @@ async def _(bot: Bot):
|
|||||||
bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
|
bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
|
||||||
)
|
)
|
||||||
if not await BotConsole.exists(bot_id=bot.self_id):
|
if not await BotConsole.exists(bot_id=bot.self_id):
|
||||||
await BotConsole.create(
|
try:
|
||||||
bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
|
await BotConsole.create(
|
||||||
)
|
bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
|
||||||
|
)
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e)
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_disconnect
|
@driver.on_bot_disconnect
|
||||||
|
|||||||
@ -11,14 +11,11 @@ from zhenxun.models.plugin_limit import PluginLimit
|
|||||||
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||||
|
from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.message import MessageUtils
|
||||||
from zhenxun.utils.utils import (
|
from zhenxun.utils.time_utils import TimeUtils
|
||||||
CountLimiter,
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
FreqLimiter,
|
|
||||||
UserBlockLimiter,
|
|
||||||
get_entity_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
@ -273,9 +270,16 @@ class LimitManager:
|
|||||||
key_type = channel_id or group_id
|
key_type = channel_id or group_id
|
||||||
if is_limit and not limiter.check(key_type):
|
if is_limit and not limiter.check(key_type):
|
||||||
if limit.result:
|
if limit.result:
|
||||||
|
format_kwargs = {}
|
||||||
|
if isinstance(limiter, FreqLimiter):
|
||||||
|
left_time = limiter.left_time(key_type)
|
||||||
|
cd_str = TimeUtils.format_duration(left_time)
|
||||||
|
format_kwargs = {"cd": cd_str}
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
MessageUtils.build_message(limit.result).send(),
|
MessageUtils.build_message(
|
||||||
|
limit.result, format_args=format_kwargs
|
||||||
|
).send(),
|
||||||
timeout=DB_TIMEOUT_SECONDS,
|
timeout=DB_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
15
zhenxun/builtin_plugins/hooks/limiter_hook.py
Normal file
15
zhenxun/builtin_plugins/hooks/limiter_hook.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot.message import run_postprocessor
|
||||||
|
|
||||||
|
from zhenxun.utils.limiters import ConcurrencyLimiter
|
||||||
|
|
||||||
|
|
||||||
|
@run_postprocessor
|
||||||
|
async def _concurrency_release_hook(matcher: Matcher):
|
||||||
|
"""
|
||||||
|
后处理器:在事件处理结束后,释放并发限制的信号量。
|
||||||
|
"""
|
||||||
|
if concurrency_info := matcher.state.get("_concurrency_limiter_info"):
|
||||||
|
limiter: ConcurrencyLimiter = concurrency_info["limiter"]
|
||||||
|
key = concurrency_info["key"]
|
||||||
|
limiter.release(key)
|
||||||
@ -8,7 +8,7 @@ from zhenxun.utils.echart_utils import ChartUtils
|
|||||||
from zhenxun.utils.echart_utils.models import Barh
|
from zhenxun.utils.echart_utils.models import Barh
|
||||||
from zhenxun.utils.enum import PluginType
|
from zhenxun.utils.enum import PluginType
|
||||||
from zhenxun.utils.image_utils import BuildImage
|
from zhenxun.utils.image_utils import BuildImage
|
||||||
from zhenxun.utils.utils import TimeUtils
|
from zhenxun.utils.time_utils import TimeUtils
|
||||||
|
|
||||||
|
|
||||||
class StatisticsManage:
|
class StatisticsManage:
|
||||||
|
|||||||
@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender")
|
|||||||
require("nonebot_plugin_uninfo")
|
require("nonebot_plugin_uninfo")
|
||||||
require("nonebot_plugin_waiter")
|
require("nonebot_plugin_waiter")
|
||||||
|
|
||||||
from .db_context import Model, disconnect
|
from .db_context import Model, disconnect, with_db_timeout
|
||||||
from .llm import (
|
from .llm import (
|
||||||
AI,
|
AI,
|
||||||
AIConfig,
|
AIConfig,
|
||||||
@ -80,4 +80,5 @@ __all__ = [
|
|||||||
"search_multimodal",
|
"search_multimodal",
|
||||||
"set_global_default_model_name",
|
"set_global_default_model_name",
|
||||||
"tool_registry",
|
"tool_registry",
|
||||||
|
"with_db_timeout",
|
||||||
]
|
]
|
||||||
|
|||||||
59
zhenxun/services/cache/__init__.py
vendored
59
zhenxun/services/cache/__init__.py
vendored
@ -63,6 +63,7 @@ from functools import wraps
|
|||||||
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
|
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
|
||||||
|
|
||||||
from aiocache import Cache as AioCache
|
from aiocache import Cache as AioCache
|
||||||
|
from aiocache import SimpleMemoryCache
|
||||||
from aiocache.base import BaseCache
|
from aiocache.base import BaseCache
|
||||||
from aiocache.serializers import JsonSerializer
|
from aiocache.serializers import JsonSerializer
|
||||||
import nonebot
|
import nonebot
|
||||||
@ -392,22 +393,14 @@ class CacheManager:
|
|||||||
def cache_backend(self) -> BaseCache | AioCache:
|
def cache_backend(self) -> BaseCache | AioCache:
|
||||||
"""获取缓存后端"""
|
"""获取缓存后端"""
|
||||||
if self._cache_backend is None:
|
if self._cache_backend is None:
|
||||||
try:
|
ttl = cache_config.redis_expire
|
||||||
from aiocache import RedisCache, SimpleMemoryCache
|
if cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
ttl = 0
|
||||||
|
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
|
||||||
|
elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host:
|
||||||
|
try:
|
||||||
|
from aiocache import RedisCache
|
||||||
|
|
||||||
if cache_config.cache_mode == CacheMode.NONE:
|
|
||||||
# 使用内存缓存但禁用持久化
|
|
||||||
self._cache_backend = SimpleMemoryCache(
|
|
||||||
serializer=JsonSerializer(),
|
|
||||||
namespace=CACHE_KEY_PREFIX,
|
|
||||||
timeout=30,
|
|
||||||
ttl=0, # 设置为0,不缓存
|
|
||||||
)
|
|
||||||
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
|
|
||||||
elif (
|
|
||||||
cache_config.cache_mode == CacheMode.REDIS
|
|
||||||
and cache_config.redis_host
|
|
||||||
):
|
|
||||||
# 使用Redis缓存
|
# 使用Redis缓存
|
||||||
self._cache_backend = RedisCache(
|
self._cache_backend = RedisCache(
|
||||||
serializer=JsonSerializer(),
|
serializer=JsonSerializer(),
|
||||||
@ -419,27 +412,25 @@ class CacheManager:
|
|||||||
password=cache_config.redis_password,
|
password=cache_config.redis_password,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"使用Redis缓存,地址: {cache_config.redis_host}", LOG_COMMAND
|
f"使用Redis缓存,地址: {cache_config.redis_host}",
|
||||||
|
LOG_COMMAND,
|
||||||
)
|
)
|
||||||
else:
|
return self._cache_backend
|
||||||
# 默认使用内存缓存
|
except ImportError as e:
|
||||||
self._cache_backend = SimpleMemoryCache(
|
logger.error(
|
||||||
serializer=JsonSerializer(),
|
"导入aiocache[redis]失败,将默认使用内存缓存...",
|
||||||
namespace=CACHE_KEY_PREFIX,
|
LOG_COMMAND,
|
||||||
timeout=30,
|
e=e,
|
||||||
ttl=cache_config.redis_expire,
|
|
||||||
)
|
)
|
||||||
logger.info("使用内存缓存", LOG_COMMAND)
|
else:
|
||||||
except ImportError:
|
logger.info("使用内存缓存", LOG_COMMAND)
|
||||||
logger.error("导入aiocache模块失败,使用内存缓存", LOG_COMMAND)
|
# 默认使用内存缓存
|
||||||
# 使用内存缓存
|
self._cache_backend = SimpleMemoryCache(
|
||||||
self._cache_backend = AioCache(
|
serializer=JsonSerializer(),
|
||||||
cache_class=AioCache.MEMORY,
|
namespace=CACHE_KEY_PREFIX,
|
||||||
serializer=JsonSerializer(),
|
timeout=30,
|
||||||
namespace=CACHE_KEY_PREFIX,
|
ttl=ttl,
|
||||||
timeout=30,
|
)
|
||||||
ttl=cache_config.redis_expire,
|
|
||||||
)
|
|
||||||
return self._cache_backend
|
return self._cache_backend
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
21
zhenxun/services/cache/cache_containers.py
vendored
21
zhenxun/services/cache/cache_containers.py
vendored
@ -54,11 +54,7 @@ class CacheDict:
|
|||||||
key: 字典键
|
key: 字典键
|
||||||
value: 字典值
|
value: 字典值
|
||||||
"""
|
"""
|
||||||
# 计算过期时间
|
expire_time = time.time() + self.expire if self.expire > 0 else 0
|
||||||
expire_time = 0
|
|
||||||
if self.expire > 0:
|
|
||||||
expire_time = time.time() + self.expire
|
|
||||||
|
|
||||||
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
self._data[key] = CacheData(value=value, expire_time=expire_time)
|
||||||
|
|
||||||
def __delitem__(self, key: str) -> None:
|
def __delitem__(self, key: str) -> None:
|
||||||
@ -274,12 +270,11 @@ class CacheList:
|
|||||||
self.clear()
|
self.clear()
|
||||||
raise IndexError(f"列表索引 {index} 超出范围")
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
if 0 <= index < len(self._data):
|
if not 0 <= index < len(self._data):
|
||||||
del self._data[index]
|
|
||||||
# 更新过期时间
|
|
||||||
self._update_expire_time()
|
|
||||||
else:
|
|
||||||
raise IndexError(f"列表索引 {index} 超出范围")
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
del self._data[index]
|
||||||
|
# 更新过期时间
|
||||||
|
self._update_expire_time()
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""获取列表长度
|
"""获取列表长度
|
||||||
@ -427,6 +422,7 @@ class CacheList:
|
|||||||
self.clear()
|
self.clear()
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
# sourcery skip: simplify-constant-sum
|
||||||
return sum(1 for item in self._data if item.value == value)
|
return sum(1 for item in self._data if item.value == value)
|
||||||
|
|
||||||
def _is_expired(self) -> bool:
|
def _is_expired(self) -> bool:
|
||||||
@ -435,10 +431,7 @@ class CacheList:
|
|||||||
|
|
||||||
def _update_expire_time(self) -> None:
|
def _update_expire_time(self) -> None:
|
||||||
"""更新过期时间"""
|
"""更新过期时间"""
|
||||||
if self.expire > 0:
|
self._expire_time = time.time() + self.expire if self.expire > 0 else 0
|
||||||
self._expire_time = time.time() + self.expire
|
|
||||||
else:
|
|
||||||
self._expire_time = 0
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""字符串表示
|
"""字符串表示
|
||||||
|
|||||||
146
zhenxun/services/db_context/__init__.py
Normal file
146
zhenxun/services/db_context/__init__.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import asyncio
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import nonebot
|
||||||
|
from nonebot.utils import is_coroutine_callable
|
||||||
|
from tortoise import Tortoise
|
||||||
|
from tortoise.connection import connections
|
||||||
|
|
||||||
|
from zhenxun.configs.config import BotConfig
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
|
||||||
|
from .base_model import Model
|
||||||
|
from .config import (
|
||||||
|
DB_TIMEOUT_SECONDS,
|
||||||
|
MYSQL_CONFIG,
|
||||||
|
POSTGRESQL_CONFIG,
|
||||||
|
SLOW_QUERY_THRESHOLD,
|
||||||
|
SQLITE_CONFIG,
|
||||||
|
db_model,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
from .exceptions import DbConnectError, DbUrlIsNode
|
||||||
|
from .utils import with_db_timeout
|
||||||
|
|
||||||
|
MODELS = db_model.models
|
||||||
|
SCRIPT_METHOD = db_model.script_method
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DB_TIMEOUT_SECONDS",
|
||||||
|
"MODELS",
|
||||||
|
"SCRIPT_METHOD",
|
||||||
|
"SLOW_QUERY_THRESHOLD",
|
||||||
|
"DbConnectError",
|
||||||
|
"DbUrlIsNode",
|
||||||
|
"Model",
|
||||||
|
"disconnect",
|
||||||
|
"init",
|
||||||
|
"with_db_timeout",
|
||||||
|
]
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
|
def get_config() -> dict:
|
||||||
|
"""获取数据库配置"""
|
||||||
|
parsed = urlparse(BotConfig.db_url)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
config = {
|
||||||
|
"connections": {
|
||||||
|
"default": BotConfig.db_url # 默认直接使用连接字符串
|
||||||
|
},
|
||||||
|
"apps": {
|
||||||
|
"models": {
|
||||||
|
"models": db_model.models,
|
||||||
|
"default_connection": "default",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"timezone": "Asia/Shanghai",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据数据库类型应用高级配置
|
||||||
|
if parsed.scheme.startswith("postgres"):
|
||||||
|
config["connections"]["default"] = {
|
||||||
|
"engine": "tortoise.backends.asyncpg",
|
||||||
|
"credentials": {
|
||||||
|
"host": parsed.hostname,
|
||||||
|
"port": parsed.port or 5432,
|
||||||
|
"user": parsed.username,
|
||||||
|
"password": parsed.password,
|
||||||
|
"database": parsed.path[1:],
|
||||||
|
},
|
||||||
|
**POSTGRESQL_CONFIG,
|
||||||
|
}
|
||||||
|
elif parsed.scheme == "mysql":
|
||||||
|
config["connections"]["default"] = {
|
||||||
|
"engine": "tortoise.backends.mysql",
|
||||||
|
"credentials": {
|
||||||
|
"host": parsed.hostname,
|
||||||
|
"port": parsed.port or 3306,
|
||||||
|
"user": parsed.username,
|
||||||
|
"password": parsed.password,
|
||||||
|
"database": parsed.path[1:],
|
||||||
|
},
|
||||||
|
**MYSQL_CONFIG,
|
||||||
|
}
|
||||||
|
elif parsed.scheme == "sqlite":
|
||||||
|
config["connections"]["default"] = {
|
||||||
|
"engine": "tortoise.backends.sqlite",
|
||||||
|
"credentials": {
|
||||||
|
"file_path": parsed.path or ":memory:",
|
||||||
|
},
|
||||||
|
**SQLITE_CONFIG,
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@PriorityLifecycle.on_startup(priority=1)
|
||||||
|
async def init():
|
||||||
|
global MODELS, SCRIPT_METHOD
|
||||||
|
|
||||||
|
MODELS = db_model.models
|
||||||
|
SCRIPT_METHOD = db_model.script_method
|
||||||
|
if not BotConfig.db_url:
|
||||||
|
error = prompt.format(host=driver.config.host, port=driver.config.port)
|
||||||
|
raise DbUrlIsNode("\n" + error.strip())
|
||||||
|
try:
|
||||||
|
await Tortoise.init(
|
||||||
|
config=get_config(),
|
||||||
|
)
|
||||||
|
if db_model.script_method:
|
||||||
|
db = Tortoise.get_connection("default")
|
||||||
|
logger.debug(
|
||||||
|
"即将运行SCRIPT_METHOD方法, 合计 "
|
||||||
|
f"<u><y>{len(db_model.script_method)}</y></u> 个..."
|
||||||
|
)
|
||||||
|
sql_list = []
|
||||||
|
for module, func in db_model.script_method:
|
||||||
|
try:
|
||||||
|
sql = await func() if is_coroutine_callable(func) else func()
|
||||||
|
if sql:
|
||||||
|
sql_list += sql
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
|
||||||
|
for sql in sql_list:
|
||||||
|
logger.debug(f"执行SQL: {sql}")
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
# await TestSQL.raw(sql)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
||||||
|
if sql_list:
|
||||||
|
logger.debug("SCRIPT_METHOD方法执行完毕!")
|
||||||
|
logger.debug("开始生成数据库表结构...")
|
||||||
|
await Tortoise.generate_schemas()
|
||||||
|
logger.debug("数据库表结构生成完毕!")
|
||||||
|
logger.info("Database loaded successfully!")
|
||||||
|
except Exception as e:
|
||||||
|
raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
async def disconnect():
|
||||||
|
await connections.close_all()
|
||||||
@ -1,56 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
import contextlib
|
import contextlib
|
||||||
import time
|
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
from nonebot.utils import is_coroutine_callable
|
|
||||||
from tortoise import Tortoise
|
|
||||||
from tortoise.backends.base.client import BaseDBAsyncClient
|
from tortoise.backends.base.client import BaseDBAsyncClient
|
||||||
from tortoise.connection import connections
|
|
||||||
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
|
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
|
||||||
from tortoise.models import Model as TortoiseModel
|
from tortoise.models import Model as TortoiseModel
|
||||||
from tortoise.transactions import in_transaction
|
from tortoise.transactions import in_transaction
|
||||||
|
|
||||||
from zhenxun.configs.config import BotConfig
|
|
||||||
from zhenxun.services.cache import CacheRoot
|
from zhenxun.services.cache import CacheRoot
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import DbLockType
|
from zhenxun.utils.enum import DbLockType
|
||||||
from zhenxun.utils.exception import HookPriorityException
|
|
||||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
|
||||||
|
|
||||||
driver = get_driver()
|
from .config import LOG_COMMAND, db_model
|
||||||
|
from .utils import with_db_timeout
|
||||||
SCRIPT_METHOD = []
|
|
||||||
MODELS: list[str] = []
|
|
||||||
|
|
||||||
# 数据库操作超时设置(秒)
|
|
||||||
DB_TIMEOUT_SECONDS = 3.0
|
|
||||||
|
|
||||||
# 性能监控阈值(秒)
|
|
||||||
SLOW_QUERY_THRESHOLD = 0.5
|
|
||||||
|
|
||||||
LOG_COMMAND = "DbContext"
|
|
||||||
|
|
||||||
|
|
||||||
async def with_db_timeout(
|
|
||||||
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
|
|
||||||
):
|
|
||||||
"""带超时控制的数据库操作"""
|
|
||||||
start_time = time.time()
|
|
||||||
try:
|
|
||||||
result = await asyncio.wait_for(coro, timeout=timeout)
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
if elapsed > SLOW_QUERY_THRESHOLD and operation:
|
|
||||||
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
|
|
||||||
return result
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
if operation:
|
|
||||||
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class Model(TortoiseModel):
|
class Model(TortoiseModel):
|
||||||
@ -63,11 +27,11 @@ class Model(TortoiseModel):
|
|||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
if cls.__module__ not in MODELS:
|
if cls.__module__ not in db_model.models:
|
||||||
MODELS.append(cls.__module__)
|
db_model.models.append(cls.__module__)
|
||||||
|
|
||||||
if func := getattr(cls, "_run_script", None):
|
if func := getattr(cls, "_run_script", None):
|
||||||
SCRIPT_METHOD.append((cls.__module__, func))
|
db_model.script_method.append((cls.__module__, func))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cache_type(cls) -> str | None:
|
def get_cache_type(cls) -> str | None:
|
||||||
@ -322,144 +286,3 @@ class Model(TortoiseModel):
|
|||||||
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
|
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
class DbUrlIsNode(HookPriorityException):
|
|
||||||
"""
|
|
||||||
数据库链接地址为空
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DbConnectError(Exception):
|
|
||||||
"""
|
|
||||||
数据库连接错误
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
POSTGRESQL_CONFIG = {
|
|
||||||
"max_size": 30, # 最大连接数
|
|
||||||
"min_size": 5, # 最小保持的连接数(可选)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MYSQL_CONFIG = {
|
|
||||||
"max_connections": 20, # 最大连接数
|
|
||||||
"connect_timeout": 30, # 连接超时(可选)
|
|
||||||
}
|
|
||||||
|
|
||||||
SQLITE_CONFIG = {
|
|
||||||
"journal_mode": "WAL", # 提高并发写入性能
|
|
||||||
"timeout": 30, # 锁等待超时(可选)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> dict:
|
|
||||||
"""获取数据库配置"""
|
|
||||||
parsed = urlparse(BotConfig.db_url)
|
|
||||||
|
|
||||||
# 基础配置
|
|
||||||
config = {
|
|
||||||
"connections": {
|
|
||||||
"default": BotConfig.db_url # 默认直接使用连接字符串
|
|
||||||
},
|
|
||||||
"apps": {
|
|
||||||
"models": {
|
|
||||||
"models": MODELS,
|
|
||||||
"default_connection": "default",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"timezone": "Asia/Shanghai",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 根据数据库类型应用高级配置
|
|
||||||
if parsed.scheme.startswith("postgres"):
|
|
||||||
config["connections"]["default"] = {
|
|
||||||
"engine": "tortoise.backends.asyncpg",
|
|
||||||
"credentials": {
|
|
||||||
"host": parsed.hostname,
|
|
||||||
"port": parsed.port or 5432,
|
|
||||||
"user": parsed.username,
|
|
||||||
"password": parsed.password,
|
|
||||||
"database": parsed.path[1:],
|
|
||||||
},
|
|
||||||
**POSTGRESQL_CONFIG,
|
|
||||||
}
|
|
||||||
elif parsed.scheme == "mysql":
|
|
||||||
config["connections"]["default"] = {
|
|
||||||
"engine": "tortoise.backends.mysql",
|
|
||||||
"credentials": {
|
|
||||||
"host": parsed.hostname,
|
|
||||||
"port": parsed.port or 3306,
|
|
||||||
"user": parsed.username,
|
|
||||||
"password": parsed.password,
|
|
||||||
"database": parsed.path[1:],
|
|
||||||
},
|
|
||||||
**MYSQL_CONFIG,
|
|
||||||
}
|
|
||||||
elif parsed.scheme == "sqlite":
|
|
||||||
config["connections"]["default"] = {
|
|
||||||
"engine": "tortoise.backends.sqlite",
|
|
||||||
"credentials": {
|
|
||||||
"file_path": parsed.path or ":memory:",
|
|
||||||
},
|
|
||||||
**SQLITE_CONFIG,
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@PriorityLifecycle.on_startup(priority=1)
|
|
||||||
async def init():
|
|
||||||
if not BotConfig.db_url:
|
|
||||||
# raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
|
|
||||||
error = f"""
|
|
||||||
**********************************************************************
|
|
||||||
🌟 **************************** 配置为空 ************************* 🌟
|
|
||||||
🚀 请打开 WebUi 进行基础配置 🚀
|
|
||||||
🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐
|
|
||||||
***********************************************************************
|
|
||||||
***********************************************************************
|
|
||||||
"""
|
|
||||||
raise DbUrlIsNode("\n" + error.strip())
|
|
||||||
try:
|
|
||||||
await Tortoise.init(
|
|
||||||
config=get_config(),
|
|
||||||
)
|
|
||||||
if SCRIPT_METHOD:
|
|
||||||
db = Tortoise.get_connection("default")
|
|
||||||
logger.debug(
|
|
||||||
"即将运行SCRIPT_METHOD方法, 合计 "
|
|
||||||
f"<u><y>{len(SCRIPT_METHOD)}</y></u> 个..."
|
|
||||||
)
|
|
||||||
sql_list = []
|
|
||||||
for module, func in SCRIPT_METHOD:
|
|
||||||
try:
|
|
||||||
sql = await func() if is_coroutine_callable(func) else func()
|
|
||||||
if sql:
|
|
||||||
sql_list += sql
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
|
|
||||||
for sql in sql_list:
|
|
||||||
logger.debug(f"执行SQL: {sql}")
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
|
|
||||||
)
|
|
||||||
# await TestSQL.raw(sql)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
|
||||||
if sql_list:
|
|
||||||
logger.debug("SCRIPT_METHOD方法执行完毕!")
|
|
||||||
logger.debug("开始生成数据库表结构...")
|
|
||||||
await Tortoise.generate_schemas()
|
|
||||||
logger.debug("数据库表结构生成完毕!")
|
|
||||||
logger.info("Database loaded successfully!")
|
|
||||||
except Exception as e:
|
|
||||||
raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
|
||||||
|
|
||||||
|
|
||||||
async def disconnect():
|
|
||||||
await connections.close_all()
|
|
||||||
46
zhenxun/services/db_context/config.py
Normal file
46
zhenxun/services/db_context/config.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# 数据库操作超时设置(秒)
|
||||||
|
DB_TIMEOUT_SECONDS = 3.0
|
||||||
|
|
||||||
|
# 性能监控阈值(秒)
|
||||||
|
SLOW_QUERY_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
LOG_COMMAND = "DbContext"
|
||||||
|
|
||||||
|
|
||||||
|
POSTGRESQL_CONFIG = {
|
||||||
|
"max_size": 30, # 最大连接数
|
||||||
|
"min_size": 5, # 最小保持的连接数(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MYSQL_CONFIG = {
|
||||||
|
"max_connections": 20, # 最大连接数
|
||||||
|
"connect_timeout": 30, # 连接超时(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
SQLITE_CONFIG = {
|
||||||
|
"journal_mode": "WAL", # 提高并发写入性能
|
||||||
|
"timeout": 30, # 锁等待超时(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DbModel(BaseModel):
|
||||||
|
script_method: list[tuple[str, Callable]] = []
|
||||||
|
models: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
db_model = DbModel()
|
||||||
|
|
||||||
|
|
||||||
|
prompt = """
|
||||||
|
**********************************************************************
|
||||||
|
🌟 **************************** 配置为空 ************************* 🌟
|
||||||
|
🚀 请打开 WebUi 进行基础配置 🚀
|
||||||
|
🌐 配置地址:http://{host}:{port}/#/configure 🌐
|
||||||
|
***********************************************************************
|
||||||
|
***********************************************************************
|
||||||
|
"""
|
||||||
14
zhenxun/services/db_context/exceptions.py
Normal file
14
zhenxun/services/db_context/exceptions.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
class DbUrlIsNode(Exception):
|
||||||
|
"""
|
||||||
|
数据库链接地址为空
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DbConnectError(Exception):
|
||||||
|
"""
|
||||||
|
数据库连接错误
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
27
zhenxun/services/db_context/utils.py
Normal file
27
zhenxun/services/db_context/utils.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
DB_TIMEOUT_SECONDS,
|
||||||
|
LOG_COMMAND,
|
||||||
|
SLOW_QUERY_THRESHOLD,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def with_db_timeout(
|
||||||
|
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
|
||||||
|
):
|
||||||
|
"""带超时控制的数据库操作"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(coro, timeout=timeout)
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed > SLOW_QUERY_THRESHOLD and operation:
|
||||||
|
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if operation:
|
||||||
|
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
|
||||||
|
raise
|
||||||
@ -54,19 +54,20 @@ class PluginInitManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def install_all(cls):
|
async def install_all(cls):
|
||||||
"""运行所有插件安装方法"""
|
"""运行所有插件安装方法"""
|
||||||
if cls.plugins:
|
if not cls.plugins:
|
||||||
for module_path, model in cls.plugins.items():
|
return
|
||||||
if model.install:
|
for module_path, model in cls.plugins.items():
|
||||||
class_ = model.class_()
|
if model.install:
|
||||||
try:
|
class_ = model.class_()
|
||||||
logger.debug(f"开始执行: {module_path}:install 方法")
|
try:
|
||||||
if is_coroutine_callable(class_.install):
|
logger.debug(f"开始执行: {module_path}:install 方法")
|
||||||
await class_.install()
|
if is_coroutine_callable(class_.install):
|
||||||
else:
|
await class_.install()
|
||||||
class_.install() # type: ignore
|
else:
|
||||||
logger.debug(f"执行: {module_path}:install 完成")
|
class_.install() # type: ignore
|
||||||
except Exception as e:
|
logger.debug(f"执行: {module_path}:install 完成")
|
||||||
logger.error(f"执行: {module_path}:install 失败", e=e)
|
except Exception as e:
|
||||||
|
logger.error(f"执行: {module_path}:install 失败", e=e)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def install(cls, module_path: str):
|
async def install(cls, module_path: str):
|
||||||
|
|||||||
@ -1,13 +1,181 @@
|
|||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.internal.params import Depends
|
from nonebot.internal.params import Depends
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.params import Command
|
from nonebot.params import Command
|
||||||
|
from nonebot.permission import SUPERUSER
|
||||||
from nonebot_plugin_session import EventSession
|
from nonebot_plugin_session import EventSession
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
|
||||||
from zhenxun.utils.message import MessageUtils
|
from zhenxun.utils.message import MessageUtils
|
||||||
|
from zhenxun.utils.time_utils import TimeUtils
|
||||||
|
|
||||||
|
_coolers: dict[str, FreqLimiter] = {}
|
||||||
|
_rate_limiters: dict[str, RateLimiter] = {}
|
||||||
|
_concurrency_limiters: dict[str, ConcurrencyLimiter] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_limiter_dependency(
|
||||||
|
limiter_class: type,
|
||||||
|
limiter_storage: dict,
|
||||||
|
limiter_init_args: dict[str, Any],
|
||||||
|
scope: Literal["user", "group", "global"],
|
||||||
|
prompt: str,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
一个高阶函数,用于创建不同类型的限制器依赖。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
limiter_class: 限制器类 (FreqLimiter, RateLimiter, etc.).
|
||||||
|
limiter_storage: 用于存储限制器实例的字典.
|
||||||
|
limiter_init_args: 限制器类的初始化参数.
|
||||||
|
scope: 限制作用域.
|
||||||
|
prompt: 触发限制时的提示信息.
|
||||||
|
**kwargs: 传递给特定限制器逻辑的额外参数.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
matcher: Matcher, session: EventSession, bot: Bot, event: Event
|
||||||
|
) -> bool:
|
||||||
|
if await SUPERUSER(bot, event):
|
||||||
|
return True
|
||||||
|
|
||||||
|
handler_id = (
|
||||||
|
f"{matcher.plugin_name}:{matcher.handlers[0].call.__code__.co_firstlineno}"
|
||||||
|
)
|
||||||
|
|
||||||
|
key: str | None = None
|
||||||
|
if scope == "user":
|
||||||
|
key = session.id1
|
||||||
|
elif scope == "group":
|
||||||
|
key = session.id3 or session.id2 or session.id1
|
||||||
|
elif scope == "global":
|
||||||
|
key = f"global_{handler_id}"
|
||||||
|
|
||||||
|
if not key:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if handler_id not in limiter_storage:
|
||||||
|
limiter_storage[handler_id] = limiter_class(**limiter_init_args)
|
||||||
|
limiter = limiter_storage[handler_id]
|
||||||
|
|
||||||
|
if isinstance(limiter, ConcurrencyLimiter):
|
||||||
|
await limiter.acquire(key)
|
||||||
|
matcher.state["_concurrency_limiter_info"] = {
|
||||||
|
"limiter": limiter,
|
||||||
|
"key": key,
|
||||||
|
}
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if limiter.check(key):
|
||||||
|
if isinstance(limiter, FreqLimiter):
|
||||||
|
limiter.start_cd(
|
||||||
|
key, kwargs.get("duration_sec", limiter.default_cd)
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
left_time = limiter.left_time(key)
|
||||||
|
format_kwargs = {
|
||||||
|
"cd_str": TimeUtils.format_duration(left_time),
|
||||||
|
**(kwargs.get("prompt_format_kwargs", {})),
|
||||||
|
}
|
||||||
|
message = prompt.format(**format_kwargs)
|
||||||
|
await matcher.finish(message)
|
||||||
|
|
||||||
|
return Depends(dependency)
|
||||||
|
|
||||||
|
|
||||||
|
def Cooldown(
|
||||||
|
duration: str,
|
||||||
|
*,
|
||||||
|
scope: Literal["user", "group", "global"] = "user",
|
||||||
|
prompt: str = "操作过于频繁,请等待 {cd_str}",
|
||||||
|
) -> bool:
|
||||||
|
"""声明式冷却检查依赖,限制用户操作频率
|
||||||
|
|
||||||
|
参数:
|
||||||
|
duration: 冷却时间字符串 (e.g., "30s", "10m", "1h")
|
||||||
|
scope: 冷却作用域
|
||||||
|
prompt: 自定义的冷却提示消息,可使用 {cd_str} 占位符
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否允许执行
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed_seconds = TimeUtils.parse_time_string(duration)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Cooldown装饰器中的duration格式错误: {e}")
|
||||||
|
|
||||||
|
return _create_limiter_dependency(
|
||||||
|
limiter_class=FreqLimiter,
|
||||||
|
limiter_storage=_coolers,
|
||||||
|
limiter_init_args={"default_cd_seconds": parsed_seconds},
|
||||||
|
scope=scope,
|
||||||
|
prompt=prompt,
|
||||||
|
duration_sec=parsed_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def RateLimit(
|
||||||
|
count: int,
|
||||||
|
duration: str,
|
||||||
|
*,
|
||||||
|
scope: Literal["user", "group", "global"] = "user",
|
||||||
|
prompt: str = "太快了,在 {duration_str} 内只能触发{limit}次,请等待 {cd_str}",
|
||||||
|
) -> bool:
|
||||||
|
"""声明式速率限制依赖,在指定时间窗口内限制操作次数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
count: 在时间窗口内允许的最大调用次数
|
||||||
|
duration: 时间窗口字符串 (e.g., "1m", "1h")
|
||||||
|
scope: 限制作用域
|
||||||
|
prompt: 自定义的提示消息,可使用 {cd_str}, {duration_str}, {limit} 占位符
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否允许执行
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed_seconds = TimeUtils.parse_time_string(duration)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"RateLimit装饰器中的duration格式错误: {e}")
|
||||||
|
|
||||||
|
return _create_limiter_dependency(
|
||||||
|
limiter_class=RateLimiter,
|
||||||
|
limiter_storage=_rate_limiters,
|
||||||
|
limiter_init_args={"max_calls": count, "time_window": parsed_seconds},
|
||||||
|
scope=scope,
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_format_kwargs={"duration_str": duration, "limit": count},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ConcurrencyLimit(
|
||||||
|
count: int,
|
||||||
|
*,
|
||||||
|
scope: Literal["user", "group", "global"] = "global",
|
||||||
|
prompt: str | None = "当前功能繁忙,请稍后再试...",
|
||||||
|
) -> bool:
|
||||||
|
"""声明式并发数限制依赖,控制某个功能同时执行的实例数量
|
||||||
|
|
||||||
|
参数:
|
||||||
|
count: 最大并发数
|
||||||
|
scope: 限制作用域
|
||||||
|
prompt: 提示消息(暂未使用,主要用于未来扩展超时功能)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否允许执行
|
||||||
|
"""
|
||||||
|
return _create_limiter_dependency(
|
||||||
|
limiter_class=ConcurrencyLimiter,
|
||||||
|
limiter_storage=_concurrency_limiters,
|
||||||
|
limiter_init_args={"max_concurrent": count},
|
||||||
|
scope=scope,
|
||||||
|
prompt=prompt or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def CheckUg(check_user: bool = True, check_group: bool = True):
|
def CheckUg(check_user: bool = True, check_group: bool = True):
|
||||||
@ -75,7 +243,6 @@ def GetConfig(
|
|||||||
if module_:
|
if module_:
|
||||||
value = Config.get_config(module_, config, default_value)
|
value = Config.get_config(module_, config, default_value)
|
||||||
if value is None and prompt:
|
if value is None and prompt:
|
||||||
# await matcher.finish(prompt or f"配置项 {config} 未填写!")
|
|
||||||
await matcher.finish(prompt)
|
await matcher.finish(prompt)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,17 @@
|
|||||||
|
from nonebot.exception import IgnoredException
|
||||||
|
|
||||||
|
|
||||||
|
class CooldownError(IgnoredException):
|
||||||
|
"""
|
||||||
|
冷却异常,用于在冷却时中断事件处理。
|
||||||
|
继承自 IgnoredException,不会在控制台留下错误堆栈。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class HookPriorityException(BaseException):
|
class HookPriorityException(BaseException):
|
||||||
"""
|
"""
|
||||||
钩子优先级异常
|
钩子优先级异常
|
||||||
|
|||||||
140
zhenxun/utils/limiters.py
Normal file
140
zhenxun/utils/limiters.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
import asyncio
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class FreqLimiter:
|
||||||
|
"""
|
||||||
|
命令冷却,检测用户是否处于冷却状态
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, default_cd_seconds: int):
|
||||||
|
self.next_time: dict[Any, float] = defaultdict(float)
|
||||||
|
self.default_cd = default_cd_seconds
|
||||||
|
|
||||||
|
def check(self, key: Any) -> bool:
|
||||||
|
return time.time() >= self.next_time[key]
|
||||||
|
|
||||||
|
def start_cd(self, key: Any, cd_time: int = 0):
|
||||||
|
self.next_time[key] = time.time() + (
|
||||||
|
cd_time if cd_time > 0 else self.default_cd
|
||||||
|
)
|
||||||
|
|
||||||
|
def left_time(self, key: Any) -> float:
|
||||||
|
return max(0.0, self.next_time[key] - time.time())
|
||||||
|
|
||||||
|
|
||||||
|
class CountLimiter:
|
||||||
|
"""
|
||||||
|
每日调用命令次数限制
|
||||||
|
"""
|
||||||
|
|
||||||
|
tz = None
|
||||||
|
|
||||||
|
def __init__(self, max_num: int):
|
||||||
|
self.today = -1
|
||||||
|
self.count: dict[Any, int] = defaultdict(int)
|
||||||
|
self.max = max_num
|
||||||
|
|
||||||
|
def check(self, key: Any) -> bool:
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
day = datetime.datetime.now().day
|
||||||
|
if day != self.today:
|
||||||
|
self.today = day
|
||||||
|
self.count.clear()
|
||||||
|
return self.count[key] < self.max
|
||||||
|
|
||||||
|
def get_num(self, key: Any) -> int:
|
||||||
|
return self.count[key]
|
||||||
|
|
||||||
|
def increase(self, key: Any, num: int = 1):
|
||||||
|
self.count[key] += num
|
||||||
|
|
||||||
|
def reset(self, key: Any):
|
||||||
|
self.count[key] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class UserBlockLimiter:
|
||||||
|
"""
|
||||||
|
检测用户是否正在调用命令 (简单阻塞锁)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.flag_data: dict[Any, bool] = defaultdict(bool)
|
||||||
|
self.time: dict[Any, float] = defaultdict(float)
|
||||||
|
|
||||||
|
def set_true(self, key: Any):
|
||||||
|
self.time[key] = time.time()
|
||||||
|
self.flag_data[key] = True
|
||||||
|
|
||||||
|
def set_false(self, key: Any):
|
||||||
|
self.flag_data[key] = False
|
||||||
|
|
||||||
|
def check(self, key: Any) -> bool:
|
||||||
|
if self.flag_data[key] and time.time() - self.time[key] > 30:
|
||||||
|
self.set_false(key)
|
||||||
|
return not self.flag_data[key]
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""
|
||||||
|
一个简单的基于时间窗口的速率限制器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_calls: int, time_window: int):
|
||||||
|
self.requests: dict[Any, deque[float]] = defaultdict(deque)
|
||||||
|
self.max_calls = max_calls
|
||||||
|
self.time_window = time_window
|
||||||
|
|
||||||
|
def check(self, key: Any) -> bool:
|
||||||
|
"""检查是否超出速率限制。如果未超出,则记录本次调用。"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
while self.requests[key] and self.requests[key][0] <= now - self.time_window:
|
||||||
|
self.requests[key].popleft()
|
||||||
|
|
||||||
|
if len(self.requests[key]) < self.max_calls:
|
||||||
|
self.requests[key].append(now)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def left_time(self, key: Any) -> float:
|
||||||
|
"""计算距离下次可调用还需等待的时间"""
|
||||||
|
if self.requests[key]:
|
||||||
|
return max(0.0, self.requests[key][0] + self.time_window - time.time())
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class ConcurrencyLimiter:
|
||||||
|
"""
|
||||||
|
一个基于 asyncio.Semaphore 的并发限制器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_concurrent: int):
|
||||||
|
self._semaphores: dict[Any, asyncio.Semaphore] = {}
|
||||||
|
self.max_concurrent = max_concurrent
|
||||||
|
self._active_tasks: dict[Any, int] = defaultdict(int)
|
||||||
|
|
||||||
|
def _get_semaphore(self, key: Any) -> asyncio.Semaphore:
|
||||||
|
if key not in self._semaphores:
|
||||||
|
self._semaphores[key] = asyncio.Semaphore(self.max_concurrent)
|
||||||
|
return self._semaphores[key]
|
||||||
|
|
||||||
|
async def acquire(self, key: Any):
|
||||||
|
"""获取一个信号量,如果达到并发上限则会阻塞等待。"""
|
||||||
|
semaphore = self._get_semaphore(key)
|
||||||
|
await semaphore.acquire()
|
||||||
|
self._active_tasks[key] += 1
|
||||||
|
|
||||||
|
def release(self, key: Any):
|
||||||
|
"""释放一个信号量。"""
|
||||||
|
if key in self._semaphores:
|
||||||
|
if self._active_tasks[key] > 0:
|
||||||
|
self._semaphores[key].release()
|
||||||
|
self._active_tasks[key] -= 1
|
||||||
|
else:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.warning(f"尝试释放键 '{key}' 的信号量时,计数已经为零。")
|
||||||
@ -49,11 +49,14 @@ class Config(BaseModel):
|
|||||||
|
|
||||||
class MessageUtils:
|
class MessageUtils:
|
||||||
@classmethod
|
@classmethod
|
||||||
def __build_message(cls, msg_list: list[MESSAGE_TYPE]) -> list[Text | Image]:
|
def __build_message(
|
||||||
|
cls, msg_list: list[MESSAGE_TYPE], format_args: dict | None = None
|
||||||
|
) -> list[Text | Image]:
|
||||||
"""构造消息
|
"""构造消息
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
msg_list: 消息列表
|
msg_list: 消息列表
|
||||||
|
format_args: 用于格式化字符串的参数字典.
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[Text | Text]: 构造完成的消息列表
|
list[Text | Text]: 构造完成的消息列表
|
||||||
@ -65,7 +68,15 @@ class MessageUtils:
|
|||||||
if msg.startswith("base64://"):
|
if msg.startswith("base64://"):
|
||||||
message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:]))))
|
message_list.append(Image(raw=BytesIO(base64.b64decode(msg[9:]))))
|
||||||
else:
|
else:
|
||||||
message_list.append(Text(msg))
|
formatted_msg = msg
|
||||||
|
if format_args:
|
||||||
|
try:
|
||||||
|
formatted_msg = msg.format_map(format_args)
|
||||||
|
except (KeyError, IndexError) as e:
|
||||||
|
logger.debug(
|
||||||
|
f"格式化字符串 '{msg}' 失败 ({e}),将使用原始文本。"
|
||||||
|
)
|
||||||
|
message_list.append(Text(formatted_msg))
|
||||||
elif isinstance(msg, int | float):
|
elif isinstance(msg, int | float):
|
||||||
message_list.append(Text(str(msg)))
|
message_list.append(Text(str(msg)))
|
||||||
elif isinstance(msg, Path):
|
elif isinstance(msg, Path):
|
||||||
@ -90,12 +101,15 @@ class MessageUtils:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_message(
|
def build_message(
|
||||||
cls, msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]]
|
cls,
|
||||||
|
msg_list: MESSAGE_TYPE | list[MESSAGE_TYPE | list[MESSAGE_TYPE]],
|
||||||
|
format_args: dict | None = None,
|
||||||
) -> UniMessage:
|
) -> UniMessage:
|
||||||
"""构造消息
|
"""构造消息
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
msg_list: 消息列表
|
msg_list: 消息列表
|
||||||
|
format_args: 用于格式化字符串的参数字典.
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
UniMessage: 构造完成的消息列表
|
UniMessage: 构造完成的消息列表
|
||||||
@ -105,7 +119,7 @@ class MessageUtils:
|
|||||||
msg_list = [msg_list]
|
msg_list = [msg_list]
|
||||||
for m in msg_list:
|
for m in msg_list:
|
||||||
_data = m if isinstance(m, list) else [m]
|
_data = m if isinstance(m, list) else [m]
|
||||||
message_list += cls.__build_message(_data) # type: ignore
|
message_list += cls.__build_message(_data, format_args)
|
||||||
return UniMessage(message_list)
|
return UniMessage(message_list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
91
zhenxun/utils/time_utils.py
Normal file
91
zhenxun/utils/time_utils.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from datetime import date, datetime
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
|
||||||
|
class TimeUtils:
|
||||||
|
DEFAULT_TIMEZONE = pytz.timezone("Asia/Shanghai")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
|
||||||
|
"""获取某天的0点时间
|
||||||
|
|
||||||
|
返回:
|
||||||
|
datetime: 今天某天的0点时间
|
||||||
|
"""
|
||||||
|
if not target_date:
|
||||||
|
target_date = datetime.now(cls.DEFAULT_TIMEZONE)
|
||||||
|
|
||||||
|
if isinstance(target_date, datetime) and target_date.tzinfo is None:
|
||||||
|
target_date = cls.DEFAULT_TIMEZONE.localize(target_date)
|
||||||
|
|
||||||
|
return (
|
||||||
|
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
if isinstance(target_date, datetime)
|
||||||
|
else datetime.combine(
|
||||||
|
target_date, datetime.min.time(), tzinfo=cls.DEFAULT_TIMEZONE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid_date(cls, date_text: str, separator: str = "-") -> bool:
|
||||||
|
"""日期是否合法
|
||||||
|
|
||||||
|
参数:
|
||||||
|
date_text: 日期
|
||||||
|
separator: 分隔符
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 日期是否合法
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
datetime.strptime(date_text, f"%Y{separator}%m{separator}%d")
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_time_string(cls, time_str: str) -> int:
|
||||||
|
"""
|
||||||
|
将带有单位的时间字符串 (e.g., "10s", "5m", "1h") 解析为总秒数。
|
||||||
|
"""
|
||||||
|
time_str = time_str.lower().strip()
|
||||||
|
match = re.match(r"^(\d+)([smh])$", time_str)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(
|
||||||
|
f"无效的时间格式: '{time_str}'。请使用如 '30s', '10m', '2h' 的格式。"
|
||||||
|
)
|
||||||
|
|
||||||
|
value, unit = int(match.group(1)), match.group(2)
|
||||||
|
|
||||||
|
if unit == "s":
|
||||||
|
return value
|
||||||
|
if unit == "m":
|
||||||
|
return value * 60
|
||||||
|
if unit == "h":
|
||||||
|
return value * 3600
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def format_duration(cls, seconds: float) -> str:
|
||||||
|
"""
|
||||||
|
将秒数格式化为易于阅读的字符串 (例如 "1小时5分钟", "30.5秒")
|
||||||
|
"""
|
||||||
|
seconds = round(seconds, 1)
|
||||||
|
if seconds < 0.1:
|
||||||
|
return "不到1秒"
|
||||||
|
if seconds < 60:
|
||||||
|
return f"{seconds}秒"
|
||||||
|
|
||||||
|
minutes, sec_remainder = divmod(int(seconds), 60)
|
||||||
|
|
||||||
|
if minutes < 60:
|
||||||
|
if sec_remainder == 0:
|
||||||
|
return f"{minutes}分钟"
|
||||||
|
return f"{minutes}分钟{sec_remainder}秒"
|
||||||
|
|
||||||
|
hours, rem_minutes = divmod(minutes, 60)
|
||||||
|
if rem_minutes == 0:
|
||||||
|
return f"{hours}小时"
|
||||||
|
return f"{hours}小时{rem_minutes}分钟"
|
||||||
@ -1,19 +1,19 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import date, datetime
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import Any, ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
import pypinyin
|
import pypinyin
|
||||||
import pytz
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .limiters import CountLimiter, FreqLimiter, UserBlockLimiter # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityIDs:
|
class EntityIDs:
|
||||||
@ -64,78 +64,6 @@ class ResourceDirManager:
|
|||||||
cls.__tree_append(path, deep)
|
cls.__tree_append(path, deep)
|
||||||
|
|
||||||
|
|
||||||
class CountLimiter:
|
|
||||||
"""
|
|
||||||
每日调用命令次数限制
|
|
||||||
"""
|
|
||||||
|
|
||||||
tz = pytz.timezone("Asia/Shanghai")
|
|
||||||
|
|
||||||
def __init__(self, max_num):
|
|
||||||
self.today = -1
|
|
||||||
self.count = defaultdict(int)
|
|
||||||
self.max = max_num
|
|
||||||
|
|
||||||
def check(self, key) -> bool:
|
|
||||||
day = datetime.now(self.tz).day
|
|
||||||
if day != self.today:
|
|
||||||
self.today = day
|
|
||||||
self.count.clear()
|
|
||||||
return self.count[key] < self.max
|
|
||||||
|
|
||||||
def get_num(self, key):
|
|
||||||
return self.count[key]
|
|
||||||
|
|
||||||
def increase(self, key, num=1):
|
|
||||||
self.count[key] += num
|
|
||||||
|
|
||||||
def reset(self, key):
|
|
||||||
self.count[key] = 0
|
|
||||||
|
|
||||||
|
|
||||||
class UserBlockLimiter:
|
|
||||||
"""
|
|
||||||
检测用户是否正在调用命令
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.flag_data = defaultdict(bool)
|
|
||||||
self.time = time.time()
|
|
||||||
|
|
||||||
def set_true(self, key: Any):
|
|
||||||
self.time = time.time()
|
|
||||||
self.flag_data[key] = True
|
|
||||||
|
|
||||||
def set_false(self, key: Any):
|
|
||||||
self.flag_data[key] = False
|
|
||||||
|
|
||||||
def check(self, key: Any) -> bool:
|
|
||||||
if time.time() - self.time > 30:
|
|
||||||
self.set_false(key)
|
|
||||||
return not self.flag_data[key]
|
|
||||||
|
|
||||||
|
|
||||||
class FreqLimiter:
|
|
||||||
"""
|
|
||||||
命令冷却,检测用户是否处于冷却状态
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, default_cd_seconds: int):
|
|
||||||
self.next_time = defaultdict(float)
|
|
||||||
self.default_cd = default_cd_seconds
|
|
||||||
|
|
||||||
def check(self, key: Any) -> bool:
|
|
||||||
return time.time() >= self.next_time[key]
|
|
||||||
|
|
||||||
def start_cd(self, key: Any, cd_time: int = 0):
|
|
||||||
self.next_time[key] = time.time() + (
|
|
||||||
cd_time if cd_time > 0 else self.default_cd
|
|
||||||
)
|
|
||||||
|
|
||||||
def left_time(self, key: Any) -> float:
|
|
||||||
return self.next_time[key] - time.time()
|
|
||||||
|
|
||||||
|
|
||||||
def cn2py(word: str) -> str:
|
def cn2py(word: str) -> str:
|
||||||
"""将字符串转化为拼音
|
"""将字符串转化为拼音
|
||||||
|
|
||||||
@ -277,20 +205,3 @@ def is_number(text: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TimeUtils:
|
|
||||||
@classmethod
|
|
||||||
def get_day_start(cls, target_date: date | datetime | None = None) -> datetime:
|
|
||||||
"""获取某天的0点时间
|
|
||||||
|
|
||||||
返回:
|
|
||||||
datetime: 今天某天的0点时间
|
|
||||||
"""
|
|
||||||
if not target_date:
|
|
||||||
target_date = datetime.now()
|
|
||||||
return (
|
|
||||||
target_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
||||||
if isinstance(target_date, datetime)
|
|
||||||
else datetime.combine(target_date, datetime.min.time())
|
|
||||||
)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user