import asyncio
from collections import deque
from collections.abc import Iterable
import contextlib
import time
from typing import Any, ClassVar
from typing_extensions import Self

from nonebot_plugin_apscheduler import scheduler
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction

from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType

from .config import LOG_COMMAND, db_model
from .utils import with_db_timeout

start_time = time.time()


class ModelCallMetrics:
    """Count database method calls on Model subclasses within a 30-second
    sliding window, grouped by module and method.
    """

    window_seconds: ClassVar[float] = 30.0
    _call_history: ClassVar[dict[str, dict[str, deque[float]]]] = {}

    @classmethod
    def _trim(cls, module: str, method: str, now: float | None = None):
        """Drop records for the given module/method that fall outside the time window."""
        module_hist = cls._call_history.get(module)
        if not module_hist or method not in module_hist:
            return

        now = time.monotonic() if now is None else now
        time_limit = now - cls.window_seconds
        history = module_hist[method]
        while history and history[0] < time_limit:
            history.popleft()

    @classmethod
    def record_call(cls, module: str, method: str):
        """Record one database method call for the given module and method."""
        # Ignore calls made during the first 20 seconds after startup.
        if time.time() - start_time < 20:
            return
        now = time.monotonic()
        method_hist = cls._call_history.setdefault(module, {})
        history = method_hist.setdefault(method, deque())
        history.append(now)
        cls._trim(module, method, now)

    @classmethod
    def get_count(cls, module: str, method: str) -> int:
        """Return the call count within the window for the given module and method."""
        cls._trim(module, method)
        return len(cls._call_history.get(module, {}).get(method, ()))

    @classmethod
    def get_method_counts(cls, module: str) -> dict[str, int]:
        """Return per-method call counts within the window for the given module."""
        methods = cls._call_history.get(module, {})
        return {m: cls.get_count(module, m) for m in list(methods)}

    @classmethod
    def get_all_counts(cls) -> dict[str, dict[str, int]]:
        """Return per-method call counts within the window for every module."""
        modules = list(cls._call_history)
        return {module: cls.get_method_counts(module) for module in modules}
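
# Illustrative usage of the metrics above (the module path is just an example;
# real callers pass cls.__module__, and calls in the first 20 s after startup
# are ignored):
#
#     ModelCallMetrics.record_call("zhenxun.models.example", "save")
#     ModelCallMetrics.get_count("zhenxun.models.example", "save")
#     # -> 1 (while the call is still inside the 30-second window)
#     ModelCallMetrics.get_all_counts()
#     # -> {"zhenxun.models.example": {"save": 1}}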


class Model(TortoiseModel):
    """
    Enhanced ORM base class that resolves nested-lock issues.
    """

    sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
    _current_locks: ClassVar[dict[int, DbLockType]] = {}  # lock type currently held by each task
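
    # How subclasses opt in (illustrative; the class name and values below are
    # hypothetical, only the attribute names are read by this base class via
    # getattr()):
    #
    #     class ExampleUser(Model):
    #         cache_type = "example_user"                # cache bucket to invalidate
    #         cache_key_field = ("user_id", "platform")  # field(s) forming the cache key
    #         enable_lock = [DbLockType.CREATE]          # lock types guarded by a semaphore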

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.__module__ not in db_model.models:
            db_model.models.append(cls.__module__)

        if func := getattr(cls, "_run_script", None):
            db_model.script_method.append((cls.__module__, func))

    @classmethod
    def get_cache_type(cls) -> str | None:
        """Return the cache type of this model, if any."""
        return getattr(cls, "cache_type", None)

    @classmethod
    def get_cache_key_field(cls) -> str | tuple[str]:
        """Return the field (or tuple of fields) used to build the cache key."""
        return getattr(cls, "cache_key_field", "id")

    @classmethod
    def get_cache_key(cls, instance) -> str | None:
        """Build the cache key for a model instance.

        Parameters:
            instance: model instance

        Returns:
            str | None: the cache key, or None if it cannot be derived
        """
        from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR

        key_field = cls.get_cache_key_field()

        if isinstance(key_field, tuple):
            # Composite key built from several fields
            key_parts = []
            for field in key_field:
                if hasattr(instance, field):
                    value = getattr(instance, field, None)
                    key_parts.append(value if value is not None else "")
                else:
                    # A missing field contributes an empty segment
                    key_parts.append("")

            # Return None only when there are no key parts at all
            return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
        elif hasattr(instance, key_field):
            value = getattr(instance, key_field, None)
            return str(value) if value is not None else None

        return None
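
    # Example (field names and values are hypothetical): with
    # cache_key_field = ("user_id", "group_id"), user_id="123" and
    # group_id="456", this returns "123" + COMPOSITE_KEY_SEPARATOR + "456";
    # with the default single field "id" it returns str(instance.id),
    # or None when that value is missing.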

    @classmethod
    def get_semaphore(cls, lock_type: DbLockType):
        """Return the per-class semaphore for the given lock type, or None if locking is disabled."""
        enable_lock = getattr(cls, "enable_lock", None)
        if not enable_lock or lock_type not in enable_lock:
            return None

        if cls.__name__ not in cls.sem_data:
            cls.sem_data[cls.__name__] = {}
        if lock_type not in cls.sem_data[cls.__name__]:
            cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
        return cls.sem_data[cls.__name__][lock_type]

    @classmethod
    def _require_lock(cls, lock_type: DbLockType) -> bool:
        """Return True if the current task actually needs to acquire the lock."""
        task_id = id(asyncio.current_task())
        return cls._current_locks.get(task_id) != lock_type

    @classmethod
    @contextlib.asynccontextmanager
    async def _lock_context(cls, lock_type: DbLockType):
        """Lock context with a re-entrancy check."""
        task_id = id(asyncio.current_task())
        need_lock = cls._require_lock(lock_type)

        if need_lock and (sem := cls.get_semaphore(lock_type)):
            cls._current_locks[task_id] = lock_type
            try:
                async with sem:
                    yield
            finally:
                # Always clear the marker, even if the body raises, so the task
                # does not appear to hold the lock forever.
                cls._current_locks.pop(task_id, None)
        else:
            yield
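
    # Re-entrancy walkthrough: create() records CREATE for the current task and
    # acquires the semaphore; the save() performed inside the parent create()
    # maps to CREATE as well (the instance has no id yet), _require_lock() sees
    # the same lock type already recorded, and the inner _lock_context() yields
    # without touching the semaphore, so the nested call cannot deadlock.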

    @classmethod
    async def create(
        cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
    ) -> Self:
        """Create a row (serialized with the CREATE lock)."""
        # ModelCallMetrics.record_call(cls.__module__, "create")
        async with cls._lock_context(DbLockType.CREATE):
            # Call the parent create(); the nested save() it performs reuses the
            # CREATE lock already held by this task instead of re-acquiring it.
            result = await super().create(using_db=using_db, **kwargs)
            if cache_type := cls.get_cache_type():
                await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
            return result

    @classmethod
    async def get_or_create(
        cls,
        defaults: dict | None = None,
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        """Get or create a row (lock-free variant that relies on database constraints)."""
        ModelCallMetrics.record_call(cls.__module__, "get_or_create")
        result = await super().get_or_create(
            defaults=defaults, using_db=using_db, **kwargs
        )
        if cache_type := cls.get_cache_type():
            await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
        return result

    @classmethod
    async def update_or_create(
        cls,
        defaults: dict | None = None,
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        """Update or create a row (serialized with the UPSERT lock)."""
        ModelCallMetrics.record_call(cls.__module__, "update_or_create")
        async with cls._lock_context(DbLockType.UPSERT):
            try:
                # Try the update first, holding a row lock
                async with in_transaction():
                    if obj := await cls.filter(**kwargs).select_for_update().first():
                        await obj.update_from_dict(defaults or {})
                        await obj.save()
                        result = (obj, False)
                    else:
                        # create() acquires its own CREATE lock, not this one again
                        result = await cls.create(**kwargs, **(defaults or {})), True

                    if cache_type := cls.get_cache_type():
                        await CacheRoot.invalidate_cache(
                            cache_type, cls.get_cache_key(result[0])
                        )
                    return result
            except IntegrityError:
                # Handle the rare unique-constraint race: fetch the row that won
                obj = await cls.get(**kwargs)
                return obj, False
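
    # Concurrency sketch: if two tasks race here, select_for_update() inside
    # in_transaction() row-locks an existing row for the update path, while a
    # duplicate INSERT from the losing task surfaces as IntegrityError and is
    # resolved by re-reading the row with cls.get(**kwargs), so both callers
    # end up with the same record.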

    async def save(
        self,
        using_db: BaseDBAsyncClient | None = None,
        update_fields: Iterable[str] | None = None,
        force_create: bool = False,
        force_update: bool = False,
    ):
        """Save the row (the lock type is chosen from the kind of operation)."""
        ModelCallMetrics.record_call(self.__class__.__module__, "save")
        lock_type = (
            DbLockType.CREATE
            if getattr(self, "id", None) is None
            else DbLockType.UPDATE
        )
        async with self._lock_context(lock_type):
            await super().save(
                using_db=using_db,
                update_fields=update_fields,
                force_create=force_create,
                force_update=force_update,
            )
            if cache_type := getattr(self, "cache_type", None):
                await CacheRoot.invalidate_cache(
                    cache_type, self.__class__.get_cache_key(self)
                )

    async def delete(self, using_db: BaseDBAsyncClient | None = None):
        ModelCallMetrics.record_call(self.__class__.__module__, "delete")
        cache_type = getattr(self, "cache_type", None)
        key = self.__class__.get_cache_key(self) if cache_type else None
        # Perform the delete
        await super().delete(using_db=using_db)

        # Invalidate the cache
        if cache_type:
            await CacheRoot.invalidate_cache(cache_type, key)

    @classmethod
    async def safe_get_or_none(
        cls,
        *args,
        using_db: BaseDBAsyncClient | None = None,
        clean_duplicates: bool = True,
        **kwargs: Any,
    ) -> Self | None:
        """Safely fetch a single record or None; when several records match, return the newest one.

        Note: by default duplicate records are deleted, keeping only the newest.

        Parameters:
            *args: query arguments
            using_db: database connection
            clean_duplicates: whether to delete duplicate records, keeping only the newest
            **kwargs: query arguments

        Returns:
            Self | None: the matching record, or None if none exists
        """
        ModelCallMetrics.record_call(cls.__module__, "safe_get_or_none")
        try:
            # First try get_or_none to fetch a single record
            try:
                return await with_db_timeout(
                    cls.get_or_none(*args, using_db=using_db, **kwargs),
                    operation=f"{cls.__name__}.get_or_none",
                    source="DataBaseModel",
                )
            except MultipleObjectsReturned:
                # Several records matched; handle them specially
                logger.warning(
                    f"{cls.__name__} safe_get_or_none found multiple records: {kwargs}",
                    LOG_COMMAND,
                )

                # Fetch all matching records
                records = await with_db_timeout(
                    cls.filter(*args, **kwargs).all(),
                    operation=f"{cls.__name__}.filter.all",
                    source="DataBaseModel",
                )

                if not records:
                    return None

                # Clean up duplicate records if requested
                if clean_duplicates and hasattr(records[0], "id"):
                    # Sort by id, newest first
                    records = sorted(
                        records, key=lambda x: getattr(x, "id", 0), reverse=True
                    )
                    for record in records[1:]:
                        try:
                            await with_db_timeout(
                                record.delete(),
                                operation=f"{cls.__name__}.delete_duplicate",
                                source="DataBaseModel",
                            )
                            logger.info(
                                f"{cls.__name__} deleted duplicate record:"
                                f" id={getattr(record, 'id', None)}",
                                LOG_COMMAND,
                            )
                        except Exception as del_e:
                            logger.error(f"Failed to delete duplicate record: {del_e}")
                    return records[0]
                # Not cleaning duplicates, or no id field: return the newest record
                if hasattr(cls, "id"):
                    return await with_db_timeout(
                        cls.filter(*args, **kwargs).order_by("-id").first(),
                        operation=f"{cls.__name__}.filter.order_by.first",
                        source="DataBaseModel",
                    )
                # No id field: return the first record
                return await with_db_timeout(
                    cls.filter(*args, **kwargs).first(),
                    operation=f"{cls.__name__}.filter.first",
                    source="DataBaseModel",
                )
        except asyncio.TimeoutError:
            logger.error(
                f"Database operation timed out: {cls.__name__}.safe_get_or_none",
                LOG_COMMAND,
            )
            return None
        except Exception as e:
            # Any other error is logged and re-raised
            logger.error(
                f"Database operation failed: {cls.__name__}.safe_get_or_none, {e!s}",
                LOG_COMMAND,
            )
            raise
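
    # Usage sketch (the model and field are hypothetical):
    #
    #     user = await ExampleUser.safe_get_or_none(user_id="123")
    #
    # If several rows match, the newest (highest id) one is returned and, with
    # clean_duplicates=True (the default), the older duplicates are deleted.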


@scheduler.scheduled_job(
    "interval",
    seconds=10,
)
async def _():
    counts = ModelCallMetrics.get_all_counts()
    if not counts:
        return

    total_calls = sum(sum(method_counts.values()) for method_counts in counts.values())

    lines: list[str] = [
        "Model database method calls (last 30 seconds):",
        f"Total: {total_calls}",
    ]
    for module in sorted(counts.keys()):
        lines.append(f"- module: {module}")
        method_counts = counts[module]
        if not method_counts:
            lines.append("  (no calls)")
            continue
        lines.extend(
            f"  {method}: {method_counts[method]}"
            for method in sorted(method_counts.keys())
        )
    logger.debug("\n".join(lines))