From c8c21d7370ccfda3576e600aba67bbdf481f55fd Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Tue, 15 Jul 2025 00:45:52 +0800 Subject: [PATCH] =?UTF-8?q?:fire:=20=E7=A7=BB=E9=99=A4=E6=97=A7db-context?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zhenxun/services/cache/__init__.py | 2 +- zhenxun/services/db_context.py | 465 ----------------------------- 2 files changed, 1 insertion(+), 466 deletions(-) delete mode 100644 zhenxun/services/db_context.py diff --git a/zhenxun/services/cache/__init__.py b/zhenxun/services/cache/__init__.py index f0a95805..e2ce9c30 100644 --- a/zhenxun/services/cache/__init__.py +++ b/zhenxun/services/cache/__init__.py @@ -418,7 +418,7 @@ class CacheManager: return self._cache_backend except ImportError as e: logger.error( - "导入aiocache[reids]失败,将默认使用内存缓存...", + "导入aiocache[redis]失败,将默认使用内存缓存...", LOG_COMMAND, e=e, ) diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py deleted file mode 100644 index e6c42472..00000000 --- a/zhenxun/services/db_context.py +++ /dev/null @@ -1,465 +0,0 @@ -import asyncio -from collections.abc import Iterable -import contextlib -import time -from typing import Any, ClassVar -from typing_extensions import Self -from urllib.parse import urlparse - -from nonebot import get_driver -from nonebot.utils import is_coroutine_callable -from tortoise import Tortoise -from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.connection import connections -from tortoise.exceptions import IntegrityError, MultipleObjectsReturned -from tortoise.models import Model as TortoiseModel -from tortoise.transactions import in_transaction - -from zhenxun.configs.config import BotConfig -from zhenxun.services.cache import CacheRoot -from zhenxun.services.log import logger -from zhenxun.utils.enum import DbLockType -from zhenxun.utils.exception import HookPriorityException -from zhenxun.utils.manager.priority_manager import PriorityLifecycle - -driver = get_driver() - -SCRIPT_METHOD = [] -MODELS: list[str] = [] - -# 数据库操作超时设置(秒) -DB_TIMEOUT_SECONDS = 3.0 - -# 性能监控阈值(秒) -SLOW_QUERY_THRESHOLD = 0.5 - -LOG_COMMAND = "DbContext" - - -async def with_db_timeout( - coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None -): - """带超时控制的数据库操作""" - start_time = time.time() - try: - result = await asyncio.wait_for(coro, timeout=timeout) - elapsed = time.time() - start_time - if elapsed > SLOW_QUERY_THRESHOLD and operation: - logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND) - return result - except asyncio.TimeoutError: - if operation: - logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND) - raise - - -class Model(TortoiseModel): - """ - 增强的ORM基类,解决锁嵌套问题 - """ - - sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {} - _current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁 - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if cls.__module__ not in MODELS: - MODELS.append(cls.__module__) - - if func := getattr(cls, "_run_script", None): - SCRIPT_METHOD.append((cls.__module__, func)) - - @classmethod - def get_cache_type(cls) -> str | None: - """获取缓存类型""" - return getattr(cls, "cache_type", None) - - @classmethod - def get_cache_key_field(cls) -> str | tuple[str]: - """获取缓存键字段""" - return getattr(cls, "cache_key_field", "id") - - @classmethod - def get_cache_key(cls, instance) -> str | None: - """获取缓存键 - - 参数: - instance: 模型实例 - - 返回: - str | None: 缓存键,如果无法获取则返回None - """ - from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR - - key_field = cls.get_cache_key_field() - - if isinstance(key_field, tuple): - # 多字段主键 - key_parts = [] - for field in key_field: - if hasattr(instance, field): - value = getattr(instance, field, None) - key_parts.append(value if value is not None else "") - else: - # 如果缺少任何必要的字段,返回None - key_parts.append("") - - # 如果没有有效参数,返回None - return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None - elif hasattr(instance, key_field): - value = getattr(instance, key_field, None) - return str(value) if value is not None else None - - return None - - @classmethod - def get_semaphore(cls, lock_type: DbLockType): - enable_lock = getattr(cls, "enable_lock", None) - if not enable_lock or lock_type not in enable_lock: - return None - - if cls.__name__ not in cls.sem_data: - cls.sem_data[cls.__name__] = {} - if lock_type not in cls.sem_data[cls.__name__]: - cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1) - return cls.sem_data[cls.__name__][lock_type] - - @classmethod - def _require_lock(cls, lock_type: DbLockType) -> bool: - """检查是否需要真正加锁""" - task_id = id(asyncio.current_task()) - return cls._current_locks.get(task_id) != lock_type - - @classmethod - @contextlib.asynccontextmanager - async def _lock_context(cls, lock_type: DbLockType): - """带重入检查的锁上下文""" - task_id = id(asyncio.current_task()) - need_lock = cls._require_lock(lock_type) - - if need_lock and (sem := cls.get_semaphore(lock_type)): - cls._current_locks[task_id] = lock_type - async with sem: - yield - cls._current_locks.pop(task_id, None) - else: - yield - - @classmethod - async def create( - cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any - ) -> Self: - """创建数据(使用CREATE锁)""" - async with cls._lock_context(DbLockType.CREATE): - # 直接调用父类的_create方法避免触发save的锁 - result = await super().create(using_db=using_db, **kwargs) - if cache_type := cls.get_cache_type(): - await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result)) - return result - - @classmethod - async def get_or_create( - cls, - defaults: dict | None = None, - using_db: BaseDBAsyncClient | None = None, - **kwargs: Any, - ) -> tuple[Self, bool]: - """获取或创建数据(无锁版本,依赖数据库约束)""" - result = await super().get_or_create( - defaults=defaults, using_db=using_db, **kwargs - ) - if cache_type := cls.get_cache_type(): - await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0])) - return result - - @classmethod - async def update_or_create( - cls, - defaults: dict | None = None, - using_db: BaseDBAsyncClient | None = None, - **kwargs: Any, - ) -> tuple[Self, bool]: - """更新或创建数据(使用UPSERT锁)""" - async with cls._lock_context(DbLockType.UPSERT): - try: - # 先尝试更新(带行锁) - async with in_transaction(): - if obj := await cls.filter(**kwargs).select_for_update().first(): - await obj.update_from_dict(defaults or {}) - await obj.save() - result = (obj, False) - else: - # 创建时不重复加锁 - result = await cls.create(**kwargs, **(defaults or {})), True - - if cache_type := cls.get_cache_type(): - await CacheRoot.invalidate_cache( - cache_type, cls.get_cache_key(result[0]) - ) - return result - except IntegrityError: - # 处理极端情况下的唯一约束冲突 - obj = await cls.get(**kwargs) - return obj, False - - async def save( - self, - using_db: BaseDBAsyncClient | None = None, - update_fields: Iterable[str] | None = None, - force_create: bool = False, - force_update: bool = False, - ): - """保存数据(根据操作类型自动选择锁)""" - lock_type = ( - DbLockType.CREATE - if getattr(self, "id", None) is None - else DbLockType.UPDATE - ) - async with self._lock_context(lock_type): - await super().save( - using_db=using_db, - update_fields=update_fields, - force_create=force_create, - force_update=force_update, - ) - if cache_type := getattr(self, "cache_type", None): - await CacheRoot.invalidate_cache( - cache_type, self.__class__.get_cache_key(self) - ) - - async def delete(self, using_db: BaseDBAsyncClient | None = None): - cache_type = getattr(self, "cache_type", None) - key = self.__class__.get_cache_key(self) if cache_type else None - # 执行删除操作 - await super().delete(using_db=using_db) - - # 清除缓存 - if cache_type: - await CacheRoot.invalidate_cache(cache_type, key) - - @classmethod - async def safe_get_or_none( - cls, - *args, - using_db: BaseDBAsyncClient | None = None, - clean_duplicates: bool = True, - **kwargs: Any, - ) -> Self | None: - """安全地获取一条记录或None,处理存在多个记录时返回最新的那个 - 注意,默认会删除重复的记录,仅保留最新的 - - 参数: - *args: 查询参数 - using_db: 数据库连接 - clean_duplicates: 是否删除重复的记录,仅保留最新的 - **kwargs: 查询参数 - - 返回: - Self | None: 查询结果,如果不存在返回None - """ - try: - # 先尝试使用 get_or_none 获取单个记录 - try: - return await with_db_timeout( - cls.get_or_none(*args, using_db=using_db, **kwargs), - operation=f"{cls.__name__}.get_or_none", - ) - except MultipleObjectsReturned: - # 如果出现多个记录的情况,进行特殊处理 - logger.warning( - f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}", - LOG_COMMAND, - ) - - # 查询所有匹配记录 - records = await with_db_timeout( - cls.filter(*args, **kwargs).all(), - operation=f"{cls.__name__}.filter.all", - ) - - if not records: - return None - - # 如果需要清理重复记录 - if clean_duplicates and hasattr(records[0], "id"): - # 按 id 排序 - records = sorted( - records, key=lambda x: getattr(x, "id", 0), reverse=True - ) - for record in records[1:]: - try: - await with_db_timeout( - record.delete(), - operation=f"{cls.__name__}.delete_duplicate", - ) - logger.info( - f"{cls.__name__} 删除重复记录:" - f" id={getattr(record, 'id', None)}", - LOG_COMMAND, - ) - except Exception as del_e: - logger.error(f"删除重复记录失败: {del_e}") - return records[0] - # 如果不需要清理或没有 id 字段,则返回最新的记录 - if hasattr(cls, "id"): - return await with_db_timeout( - cls.filter(*args, **kwargs).order_by("-id").first(), - operation=f"{cls.__name__}.filter.order_by.first", - ) - # 如果没有 id 字段,则返回第一个记录 - return await with_db_timeout( - cls.filter(*args, **kwargs).first(), - operation=f"{cls.__name__}.filter.first", - ) - except asyncio.TimeoutError: - logger.error( - f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND - ) - return None - except Exception as e: - # 其他类型的错误则继续抛出 - logger.error( - f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND - ) - raise - - -class DbUrlIsNode(HookPriorityException): - """ - 数据库链接地址为空 - """ - - pass - - -class DbConnectError(Exception): - """ - 数据库连接错误 - """ - - pass - - -POSTGRESQL_CONFIG = { - "max_size": 30, # 最大连接数 - "min_size": 5, # 最小保持的连接数(可选) -} - - -MYSQL_CONFIG = { - "max_connections": 20, # 最大连接数 - "connect_timeout": 30, # 连接超时(可选) -} - -SQLITE_CONFIG = { - "journal_mode": "WAL", # 提高并发写入性能 - "timeout": 30, # 锁等待超时(可选) -} - - -def get_config() -> dict: - """获取数据库配置""" - parsed = urlparse(BotConfig.db_url) - - # 基础配置 - config = { - "connections": { - "default": BotConfig.db_url # 默认直接使用连接字符串 - }, - "apps": { - "models": { - "models": MODELS, - "default_connection": "default", - } - }, - "timezone": "Asia/Shanghai", - } - - # 根据数据库类型应用高级配置 - if parsed.scheme.startswith("postgres"): - config["connections"]["default"] = { - "engine": "tortoise.backends.asyncpg", - "credentials": { - "host": parsed.hostname, - "port": parsed.port or 5432, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path[1:], - }, - **POSTGRESQL_CONFIG, - } - elif parsed.scheme == "mysql": - config["connections"]["default"] = { - "engine": "tortoise.backends.mysql", - "credentials": { - "host": parsed.hostname, - "port": parsed.port or 3306, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path[1:], - }, - **MYSQL_CONFIG, - } - elif parsed.scheme == "sqlite": - config["connections"]["default"] = { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": parsed.path or ":memory:", - }, - **SQLITE_CONFIG, - } - return config - - -@PriorityLifecycle.on_startup(priority=1) -async def init(): - if not BotConfig.db_url: - # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") - error = f""" -********************************************************************** -🌟 **************************** 配置为空 ************************* 🌟 -🚀 请打开 WebUi 进行基础配置 🚀 -🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐 -*********************************************************************** -*********************************************************************** - """ - raise DbUrlIsNode("\n" + error.strip()) - try: - await Tortoise.init( - config=get_config(), - ) - if SCRIPT_METHOD: - db = Tortoise.get_connection("default") - logger.debug( - "即将运行SCRIPT_METHOD方法, 合计 " - f"{len(SCRIPT_METHOD)} 个..." - ) - sql_list = [] - for module, func in SCRIPT_METHOD: - try: - sql = await func() if is_coroutine_callable(func) else func() - if sql: - sql_list += sql - except Exception as e: - logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e) - for sql in sql_list: - logger.debug(f"执行SQL: {sql}") - try: - await asyncio.wait_for( - db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS - ) - # await TestSQL.raw(sql) - except Exception as e: - logger.debug(f"执行SQL: {sql} 错误...", e=e) - if sql_list: - logger.debug("SCRIPT_METHOD方法执行完毕!") - logger.debug("开始生成数据库表结构...") - await Tortoise.generate_schemas() - logger.debug("数据库表结构生成完毕!") - logger.info("Database loaded successfully!") - except Exception as e: - raise DbConnectError(f"数据库连接错误... e:{e}") from e - - -async def disconnect(): - await connections.close_all()