diff --git a/zhenxun/builtin_plugins/hooks/call_hook.py b/zhenxun/builtin_plugins/hooks/call_hook.py index f87bd983..023efa31 100644 --- a/zhenxun/builtin_plugins/hooks/call_hook.py +++ b/zhenxun/builtin_plugins/hooks/call_hook.py @@ -1,7 +1,10 @@ +import asyncio +from collections import deque from typing import Any from nonebot.adapters import Bot, Message from nonebot.adapters.onebot.v11 import MessageSegment +from nonebot_plugin_apscheduler import scheduler from zhenxun.configs.config import Config from zhenxun.models.bot_message_store import BotMessageStore @@ -13,6 +16,45 @@ from zhenxun.utils.platform import PlatformUtils LOG_COMMAND = "MessageHook" +_BOT_MSG_BUFFER: deque[dict[str, Any]] = deque() +_BOT_MSG_BUFFER_LOCK = asyncio.Lock() +_BOT_MSG_BULK_SIZE = 50 +_PENDING_TASKS: set[asyncio.Task] = set() + + +async def _flush_bot_messages(): + async with _BOT_MSG_BUFFER_LOCK: + if not _BOT_MSG_BUFFER: + return + items: list[dict[str, Any]] = [] + while _BOT_MSG_BUFFER: + items.append(_BOT_MSG_BUFFER.popleft()) + try: + await BotMessageStore.bulk_create([BotMessageStore(**it) for it in items]) + except Exception as e: + logger.warning("批量写入BotMessageStore失败", LOG_COMMAND, e=e) + # 尝试降级逐条写入,避免数据全部丢失 + try: + for it in items: + await BotMessageStore.create(**it) + except Exception as e2: + logger.warning("逐条写入BotMessageStore失败", LOG_COMMAND, e=e2) + + +async def _enqueue_bot_message(item: dict[str, Any]): + async with _BOT_MSG_BUFFER_LOCK: + _BOT_MSG_BUFFER.append(item) + if len(_BOT_MSG_BUFFER) >= _BOT_MSG_BULK_SIZE: + task = asyncio.create_task(_flush_bot_messages()) + _PENDING_TASKS.add(task) + task.add_done_callback(_PENDING_TASKS.discard) + + +@scheduler.scheduled_job("interval", seconds=10) +async def _flush_bot_messages_job(): + await _flush_bot_messages() + + def replace_message(message: Message) -> str: """将消息中的at、image、record、face替换为字符串 @@ -95,18 +137,20 @@ async def handle_api_result( if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"): return try: - await BotMessageStore.create( - bot_id=bot.self_id, - user_id=user_id, - group_id=group_id, - sent_type=BotSentType.GROUP - if message_type == "group" - else BotSentType.PRIVATE, - text=replace_message(message), - plain_text=message.extract_plain_text() - if isinstance(message, Message) - else replace_message(message), - platform=PlatformUtils.get_platform(bot), + await _enqueue_bot_message( + { + "bot_id": bot.self_id, + "user_id": user_id, + "group_id": group_id, + "sent_type": BotSentType.GROUP + if message_type == "group" + else BotSentType.PRIVATE, + "text": replace_message(message), + "plain_text": message.extract_plain_text() + if isinstance(message, Message) + else replace_message(message), + "platform": PlatformUtils.get_platform(bot), + } ) logger.debug(f"消息发送记录,message: {format_message_for_log(message)}") except Exception as e: diff --git a/zhenxun/services/cache/cache_containers.py b/zhenxun/services/cache/cache_containers.py index aad8878f..bf4bbd95 100644 --- a/zhenxun/services/cache/cache_containers.py +++ b/zhenxun/services/cache/cache_containers.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import random import time from typing import Any, Generic, TypeVar @@ -112,7 +113,7 @@ class CacheDict(Generic[T]): if expire is not None and expire > 0: expire_time = time.time() + expire elif self.expire > 0: - expire_time = time.time() + self.expire + expire_time = time.time() + self.expire + random.randint(0, 600) self._data[key] = CacheData(value=value, expire_time=expire_time) diff --git a/zhenxun/services/db_context/base_model.py b/zhenxun/services/db_context/base_model.py index ff642258..ad7460bf 100644 --- a/zhenxun/services/db_context/base_model.py +++ b/zhenxun/services/db_context/base_model.py @@ -1,9 +1,12 @@ import asyncio +from collections import deque from collections.abc import Iterable import contextlib +import time from typing import Any, ClassVar from typing_extensions import Self +from nonebot_plugin_apscheduler import scheduler from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.exceptions import IntegrityError, MultipleObjectsReturned from tortoise.models import Model as TortoiseModel @@ -16,6 +19,57 @@ from zhenxun.utils.enum import DbLockType from .config import LOG_COMMAND, db_model from .utils import with_db_timeout +start_time = time.time() + + +class ModelCallMetrics: + """记录Model数据库方法调用次数(30秒窗口内,按模块与方法分组)""" + + window_seconds: ClassVar[float] = 30.0 + _call_history: ClassVar[dict[str, dict[str, deque[float]]]] = {} + + @classmethod + def _trim(cls, module: str, method: str, now: float | None = None): + """移除指定模块/方法超出时间窗口的记录""" + module_hist = cls._call_history.get(module) + if not module_hist or method not in module_hist: + return + + now = time.monotonic() if now is None else now + time_limit = now - cls.window_seconds + history = module_hist[method] + while history and history[0] < time_limit: + history.popleft() + + @classmethod + def record_call(cls, module: str, method: str): + """记录一次数据库方法调用(按模块与方法)""" + if time.time() - start_time < 20: + return + now = time.monotonic() + method_hist = cls._call_history.setdefault(module, {}) + history = method_hist.setdefault(method, deque()) + history.append(now) + cls._trim(module, method, now) + + @classmethod + def get_count(cls, module: str, method: str) -> int: + """获取指定模块指定方法在窗口内的调用次数""" + cls._trim(module, method) + return len(cls._call_history.get(module, {}).get(method, ())) + + @classmethod + def get_method_counts(cls, module: str) -> dict[str, int]: + """获取指定模块各方法在窗口内的调用统计""" + methods = cls._call_history.get(module, {}) + return {m: cls.get_count(module, m) for m in list(methods)} + + @classmethod + def get_all_counts(cls) -> dict[str, dict[str, int]]: + """获取所有模块各方法在窗口内的调用统计""" + modules = list(cls._call_history) + return {module: cls.get_method_counts(module) for module in modules} + class Model(TortoiseModel): """ @@ -114,6 +168,7 @@ class Model(TortoiseModel): cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any ) -> Self: """创建数据(使用CREATE锁)""" + # ModelCallMetrics.record_call(cls.__module__, "create") async with cls._lock_context(DbLockType.CREATE): # 直接调用父类的_create方法避免触发save的锁 result = await super().create(using_db=using_db, **kwargs) @@ -129,6 +184,7 @@ class Model(TortoiseModel): **kwargs: Any, ) -> tuple[Self, bool]: """获取或创建数据(无锁版本,依赖数据库约束)""" + ModelCallMetrics.record_call(cls.__module__, "get_or_create") result = await super().get_or_create( defaults=defaults, using_db=using_db, **kwargs ) @@ -144,6 +200,7 @@ class Model(TortoiseModel): **kwargs: Any, ) -> tuple[Self, bool]: """更新或创建数据(使用UPSERT锁)""" + ModelCallMetrics.record_call(cls.__module__, "update_or_create") async with cls._lock_context(DbLockType.UPSERT): try: # 先尝试更新(带行锁) @@ -174,6 +231,7 @@ class Model(TortoiseModel): force_update: bool = False, ): """保存数据(根据操作类型自动选择锁)""" + ModelCallMetrics.record_call(self.__class__.__module__, "save") lock_type = ( DbLockType.CREATE if getattr(self, "id", None) is None @@ -192,6 +250,7 @@ class Model(TortoiseModel): ) async def delete(self, using_db: BaseDBAsyncClient | None = None): + ModelCallMetrics.record_call(self.__class__.__module__, "delete") cache_type = getattr(self, "cache_type", None) key = self.__class__.get_cache_key(self) if cache_type else None # 执行删除操作 @@ -221,6 +280,7 @@ class Model(TortoiseModel): 返回: Self | None: 查询结果,如果不存在返回None """ + ModelCallMetrics.record_call(cls.__module__, "safe_get_or_none") try: # 先尝试使用 get_or_none 获取单个记录 try: @@ -291,3 +351,28 @@ class Model(TortoiseModel): f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND ) raise + + +@scheduler.scheduled_job( + "interval", + seconds=10, +) +async def _(): + counts = ModelCallMetrics.get_all_counts() + if not counts: + return + + total_calls = sum(sum(method_counts.values()) for method_counts in counts.values()) + + lines: list[str] = ["Model 数据库方法调用次数(最近30秒):", f"总计: {total_calls}"] + for module in sorted(counts.keys()): + lines.append(f"- module: {module}") + method_counts = counts[module] + if not method_counts: + lines.append(" (no calls)") + continue + lines.extend( + f" {method}: {method_counts[method]}" + for method in sorted(method_counts.keys()) + ) + logger.debug("\n".join(lines))