From d218c569d46760e630cef14a5cbf12421b612891 Mon Sep 17 00:00:00 2001 From: HibiKier <45528451+HibiKier@users.noreply.github.com> Date: Tue, 15 Jul 2025 17:08:42 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E6=A0=BC=E5=BC=8F=E5=8C=96db=5Fco?= =?UTF-8?q?ntext=20(#1980)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :sparkles: 格式化db_context * :fire: 移除旧db-context * :zap: 添加旧版本兼容 --- zhenxun/builtin_plugins/__init__.py | 11 +- zhenxun/services/__init__.py | 3 +- zhenxun/services/cache/__init__.py | 59 +++--- zhenxun/services/cache/cache_containers.py | 21 +- zhenxun/services/db_context/__init__.py | 146 ++++++++++++++ .../base_model.py} | 187 +----------------- zhenxun/services/db_context/config.py | 46 +++++ zhenxun/services/db_context/exceptions.py | 14 ++ zhenxun/services/db_context/utils.py | 27 +++ zhenxun/services/plugin_init.py | 27 +-- 10 files changed, 293 insertions(+), 248 deletions(-) create mode 100644 zhenxun/services/db_context/__init__.py rename zhenxun/services/{db_context.py => db_context/base_model.py} (64%) create mode 100644 zhenxun/services/db_context/config.py create mode 100644 zhenxun/services/db_context/exceptions.py create mode 100644 zhenxun/services/db_context/utils.py diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index 4003e506..a5aa7a4b 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -5,7 +5,7 @@ import nonebot from nonebot.adapters import Bot from nonebot.drivers import Driver from tortoise import Tortoise -from tortoise.exceptions import OperationalError +from tortoise.exceptions import IntegrityError, OperationalError import ujson as json from zhenxun.models.bot_connect_log import BotConnectLog @@ -30,9 +30,12 @@ async def _(bot: Bot): bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1 ) if not await BotConsole.exists(bot_id=bot.self_id): - await BotConsole.create( - bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot) - ) + try: + await BotConsole.create( + bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot) + ) + except IntegrityError as e: + logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e) @driver.on_bot_disconnect diff --git a/zhenxun/services/__init__.py b/zhenxun/services/__init__.py index 4c820b87..b3dc292e 100644 --- a/zhenxun/services/__init__.py +++ b/zhenxun/services/__init__.py @@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender") require("nonebot_plugin_uninfo") require("nonebot_plugin_waiter") -from .db_context import Model, disconnect +from .db_context import Model, disconnect, with_db_timeout from .llm import ( AI, AIConfig, @@ -80,4 +80,5 @@ __all__ = [ "search_multimodal", "set_global_default_model_name", "tool_registry", + "with_db_timeout", ] diff --git a/zhenxun/services/cache/__init__.py b/zhenxun/services/cache/__init__.py index 76b05a5c..e2ce9c30 100644 --- a/zhenxun/services/cache/__init__.py +++ b/zhenxun/services/cache/__init__.py @@ -63,6 +63,7 @@ from functools import wraps from typing import Any, ClassVar, Generic, TypeVar, get_type_hints from aiocache import Cache as AioCache +from aiocache import SimpleMemoryCache from aiocache.base import BaseCache from aiocache.serializers import JsonSerializer import nonebot @@ -392,22 +393,14 @@ class CacheManager: def cache_backend(self) -> BaseCache | AioCache: """获取缓存后端""" if self._cache_backend is None: - try: - from aiocache import RedisCache, SimpleMemoryCache + ttl = cache_config.redis_expire + if cache_config.cache_mode == CacheMode.NONE: + ttl = 0 + logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND) + elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host: + try: + from aiocache import RedisCache - if cache_config.cache_mode == CacheMode.NONE: - # 使用内存缓存但禁用持久化 - self._cache_backend = SimpleMemoryCache( - serializer=JsonSerializer(), - namespace=CACHE_KEY_PREFIX, - timeout=30, - ttl=0, # 设置为0,不缓存 - ) - logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND) - elif ( - cache_config.cache_mode == CacheMode.REDIS - and cache_config.redis_host - ): # 使用Redis缓存 self._cache_backend = RedisCache( serializer=JsonSerializer(), @@ -419,27 +412,25 @@ class CacheManager: password=cache_config.redis_password, ) logger.info( - f"使用Redis缓存,地址: {cache_config.redis_host}", LOG_COMMAND + f"使用Redis缓存,地址: {cache_config.redis_host}", + LOG_COMMAND, ) - else: - # 默认使用内存缓存 - self._cache_backend = SimpleMemoryCache( - serializer=JsonSerializer(), - namespace=CACHE_KEY_PREFIX, - timeout=30, - ttl=cache_config.redis_expire, + return self._cache_backend + except ImportError as e: + logger.error( + "导入aiocache[redis]失败,将默认使用内存缓存...", + LOG_COMMAND, + e=e, ) - logger.info("使用内存缓存", LOG_COMMAND) - except ImportError: - logger.error("导入aiocache模块失败,使用内存缓存", LOG_COMMAND) - # 使用内存缓存 - self._cache_backend = AioCache( - cache_class=AioCache.MEMORY, - serializer=JsonSerializer(), - namespace=CACHE_KEY_PREFIX, - timeout=30, - ttl=cache_config.redis_expire, - ) + else: + logger.info("使用内存缓存", LOG_COMMAND) + # 默认使用内存缓存 + self._cache_backend = SimpleMemoryCache( + serializer=JsonSerializer(), + namespace=CACHE_KEY_PREFIX, + timeout=30, + ttl=ttl, + ) return self._cache_backend @property diff --git a/zhenxun/services/cache/cache_containers.py b/zhenxun/services/cache/cache_containers.py index b0efe3fb..42b35239 100644 --- a/zhenxun/services/cache/cache_containers.py +++ b/zhenxun/services/cache/cache_containers.py @@ -54,11 +54,7 @@ class CacheDict: key: 字典键 value: 字典值 """ - # 计算过期时间 - expire_time = 0 - if self.expire > 0: - expire_time = time.time() + self.expire - + expire_time = time.time() + self.expire if self.expire > 0 else 0 self._data[key] = CacheData(value=value, expire_time=expire_time) def __delitem__(self, key: str) -> None: @@ -274,12 +270,11 @@ class CacheList: self.clear() raise IndexError(f"列表索引 {index} 超出范围") - if 0 <= index < len(self._data): - del self._data[index] - # 更新过期时间 - self._update_expire_time() - else: + if not 0 <= index < len(self._data): raise IndexError(f"列表索引 {index} 超出范围") + del self._data[index] + # 更新过期时间 + self._update_expire_time() def __len__(self) -> int: """获取列表长度 @@ -427,6 +422,7 @@ class CacheList: self.clear() return 0 + # sourcery skip: simplify-constant-sum return sum(1 for item in self._data if item.value == value) def _is_expired(self) -> bool: @@ -435,10 +431,7 @@ class CacheList: def _update_expire_time(self) -> None: """更新过期时间""" - if self.expire > 0: - self._expire_time = time.time() + self.expire - else: - self._expire_time = 0 + self._expire_time = time.time() + self.expire if self.expire > 0 else 0 def __str__(self) -> str: """字符串表示 diff --git a/zhenxun/services/db_context/__init__.py b/zhenxun/services/db_context/__init__.py new file mode 100644 index 00000000..26fd9bcd --- /dev/null +++ b/zhenxun/services/db_context/__init__.py @@ -0,0 +1,146 @@ +import asyncio +from urllib.parse import urlparse + +import nonebot +from nonebot.utils import is_coroutine_callable +from tortoise import Tortoise +from tortoise.connection import connections + +from zhenxun.configs.config import BotConfig +from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle + +from .base_model import Model +from .config import ( + DB_TIMEOUT_SECONDS, + MYSQL_CONFIG, + POSTGRESQL_CONFIG, + SLOW_QUERY_THRESHOLD, + SQLITE_CONFIG, + db_model, + prompt, +) +from .exceptions import DbConnectError, DbUrlIsNode +from .utils import with_db_timeout + +MODELS = db_model.models +SCRIPT_METHOD = db_model.script_method + +__all__ = [ + "DB_TIMEOUT_SECONDS", + "MODELS", + "SCRIPT_METHOD", + "SLOW_QUERY_THRESHOLD", + "DbConnectError", + "DbUrlIsNode", + "Model", + "disconnect", + "init", + "with_db_timeout", +] + +driver = nonebot.get_driver() + + +def get_config() -> dict: + """获取数据库配置""" + parsed = urlparse(BotConfig.db_url) + + # 基础配置 + config = { + "connections": { + "default": BotConfig.db_url # 默认直接使用连接字符串 + }, + "apps": { + "models": { + "models": db_model.models, + "default_connection": "default", + } + }, + "timezone": "Asia/Shanghai", + } + + # 根据数据库类型应用高级配置 + if parsed.scheme.startswith("postgres"): + config["connections"]["default"] = { + "engine": "tortoise.backends.asyncpg", + "credentials": { + "host": parsed.hostname, + "port": parsed.port or 5432, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path[1:], + }, + **POSTGRESQL_CONFIG, + } + elif parsed.scheme == "mysql": + config["connections"]["default"] = { + "engine": "tortoise.backends.mysql", + "credentials": { + "host": parsed.hostname, + "port": parsed.port or 3306, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path[1:], + }, + **MYSQL_CONFIG, + } + elif parsed.scheme == "sqlite": + config["connections"]["default"] = { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": parsed.path or ":memory:", + }, + **SQLITE_CONFIG, + } + return config + + +@PriorityLifecycle.on_startup(priority=1) +async def init(): + global MODELS, SCRIPT_METHOD + + MODELS = db_model.models + SCRIPT_METHOD = db_model.script_method + if not BotConfig.db_url: + error = prompt.format(host=driver.config.host, port=driver.config.port) + raise DbUrlIsNode("\n" + error.strip()) + try: + await Tortoise.init( + config=get_config(), + ) + if db_model.script_method: + db = Tortoise.get_connection("default") + logger.debug( + "即将运行SCRIPT_METHOD方法, 合计 " + f"{len(db_model.script_method)} 个..." + ) + sql_list = [] + for module, func in db_model.script_method: + try: + sql = await func() if is_coroutine_callable(func) else func() + if sql: + sql_list += sql + except Exception as e: + logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e) + for sql in sql_list: + logger.debug(f"执行SQL: {sql}") + try: + await asyncio.wait_for( + db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS + ) + # await TestSQL.raw(sql) + except Exception as e: + logger.debug(f"执行SQL: {sql} 错误...", e=e) + if sql_list: + logger.debug("SCRIPT_METHOD方法执行完毕!") + logger.debug("开始生成数据库表结构...") + await Tortoise.generate_schemas() + logger.debug("数据库表结构生成完毕!") + logger.info("Database loaded successfully!") + except Exception as e: + raise DbConnectError(f"数据库连接错误... e:{e}") from e + + +async def disconnect(): + await connections.close_all() diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context/base_model.py similarity index 64% rename from zhenxun/services/db_context.py rename to zhenxun/services/db_context/base_model.py index e6c42472..3e0e23ef 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context/base_model.py @@ -1,56 +1,20 @@ import asyncio from collections.abc import Iterable import contextlib -import time from typing import Any, ClassVar from typing_extensions import Self -from urllib.parse import urlparse -from nonebot import get_driver -from nonebot.utils import is_coroutine_callable -from tortoise import Tortoise from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.connection import connections from tortoise.exceptions import IntegrityError, MultipleObjectsReturned from tortoise.models import Model as TortoiseModel from tortoise.transactions import in_transaction -from zhenxun.configs.config import BotConfig from zhenxun.services.cache import CacheRoot from zhenxun.services.log import logger from zhenxun.utils.enum import DbLockType -from zhenxun.utils.exception import HookPriorityException -from zhenxun.utils.manager.priority_manager import PriorityLifecycle -driver = get_driver() - -SCRIPT_METHOD = [] -MODELS: list[str] = [] - -# 数据库操作超时设置(秒) -DB_TIMEOUT_SECONDS = 3.0 - -# 性能监控阈值(秒) -SLOW_QUERY_THRESHOLD = 0.5 - -LOG_COMMAND = "DbContext" - - -async def with_db_timeout( - coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None -): - """带超时控制的数据库操作""" - start_time = time.time() - try: - result = await asyncio.wait_for(coro, timeout=timeout) - elapsed = time.time() - start_time - if elapsed > SLOW_QUERY_THRESHOLD and operation: - logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND) - return result - except asyncio.TimeoutError: - if operation: - logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND) - raise +from .config import LOG_COMMAND, db_model +from .utils import with_db_timeout class Model(TortoiseModel): @@ -63,11 +27,11 @@ class Model(TortoiseModel): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if cls.__module__ not in MODELS: - MODELS.append(cls.__module__) + if cls.__module__ not in db_model.models: + db_model.models.append(cls.__module__) if func := getattr(cls, "_run_script", None): - SCRIPT_METHOD.append((cls.__module__, func)) + db_model.script_method.append((cls.__module__, func)) @classmethod def get_cache_type(cls) -> str | None: @@ -322,144 +286,3 @@ class Model(TortoiseModel): f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND ) raise - - -class DbUrlIsNode(HookPriorityException): - """ - 数据库链接地址为空 - """ - - pass - - -class DbConnectError(Exception): - """ - 数据库连接错误 - """ - - pass - - -POSTGRESQL_CONFIG = { - "max_size": 30, # 最大连接数 - "min_size": 5, # 最小保持的连接数(可选) -} - - -MYSQL_CONFIG = { - "max_connections": 20, # 最大连接数 - "connect_timeout": 30, # 连接超时(可选) -} - -SQLITE_CONFIG = { - "journal_mode": "WAL", # 提高并发写入性能 - "timeout": 30, # 锁等待超时(可选) -} - - -def get_config() -> dict: - """获取数据库配置""" - parsed = urlparse(BotConfig.db_url) - - # 基础配置 - config = { - "connections": { - "default": BotConfig.db_url # 默认直接使用连接字符串 - }, - "apps": { - "models": { - "models": MODELS, - "default_connection": "default", - } - }, - "timezone": "Asia/Shanghai", - } - - # 根据数据库类型应用高级配置 - if parsed.scheme.startswith("postgres"): - config["connections"]["default"] = { - "engine": "tortoise.backends.asyncpg", - "credentials": { - "host": parsed.hostname, - "port": parsed.port or 5432, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path[1:], - }, - **POSTGRESQL_CONFIG, - } - elif parsed.scheme == "mysql": - config["connections"]["default"] = { - "engine": "tortoise.backends.mysql", - "credentials": { - "host": parsed.hostname, - "port": parsed.port or 3306, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path[1:], - }, - **MYSQL_CONFIG, - } - elif parsed.scheme == "sqlite": - config["connections"]["default"] = { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": parsed.path or ":memory:", - }, - **SQLITE_CONFIG, - } - return config - - -@PriorityLifecycle.on_startup(priority=1) -async def init(): - if not BotConfig.db_url: - # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") - error = f""" -********************************************************************** -🌟 **************************** 配置为空 ************************* 🌟 -🚀 请打开 WebUi 进行基础配置 🚀 -🌐 配置地址:http://{driver.config.host}:{driver.config.port}/#/configure 🌐 -*********************************************************************** -*********************************************************************** - """ - raise DbUrlIsNode("\n" + error.strip()) - try: - await Tortoise.init( - config=get_config(), - ) - if SCRIPT_METHOD: - db = Tortoise.get_connection("default") - logger.debug( - "即将运行SCRIPT_METHOD方法, 合计 " - f"{len(SCRIPT_METHOD)} 个..." - ) - sql_list = [] - for module, func in SCRIPT_METHOD: - try: - sql = await func() if is_coroutine_callable(func) else func() - if sql: - sql_list += sql - except Exception as e: - logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e) - for sql in sql_list: - logger.debug(f"执行SQL: {sql}") - try: - await asyncio.wait_for( - db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS - ) - # await TestSQL.raw(sql) - except Exception as e: - logger.debug(f"执行SQL: {sql} 错误...", e=e) - if sql_list: - logger.debug("SCRIPT_METHOD方法执行完毕!") - logger.debug("开始生成数据库表结构...") - await Tortoise.generate_schemas() - logger.debug("数据库表结构生成完毕!") - logger.info("Database loaded successfully!") - except Exception as e: - raise DbConnectError(f"数据库连接错误... e:{e}") from e - - -async def disconnect(): - await connections.close_all() diff --git a/zhenxun/services/db_context/config.py b/zhenxun/services/db_context/config.py new file mode 100644 index 00000000..ae6d6b8c --- /dev/null +++ b/zhenxun/services/db_context/config.py @@ -0,0 +1,46 @@ +from collections.abc import Callable + +from pydantic import BaseModel + +# 数据库操作超时设置(秒) +DB_TIMEOUT_SECONDS = 3.0 + +# 性能监控阈值(秒) +SLOW_QUERY_THRESHOLD = 0.5 + +LOG_COMMAND = "DbContext" + + +POSTGRESQL_CONFIG = { + "max_size": 30, # 最大连接数 + "min_size": 5, # 最小保持的连接数(可选) +} + + +MYSQL_CONFIG = { + "max_connections": 20, # 最大连接数 + "connect_timeout": 30, # 连接超时(可选) +} + +SQLITE_CONFIG = { + "journal_mode": "WAL", # 提高并发写入性能 + "timeout": 30, # 锁等待超时(可选) +} + + +class DbModel(BaseModel): + script_method: list[tuple[str, Callable]] = [] + models: list[str] = [] + + +db_model = DbModel() + + +prompt = """ +********************************************************************** +🌟 **************************** 配置为空 ************************* 🌟 +🚀 请打开 WebUi 进行基础配置 🚀 +🌐 配置地址:http://{host}:{port}/#/configure 🌐 +*********************************************************************** +*********************************************************************** +""" diff --git a/zhenxun/services/db_context/exceptions.py b/zhenxun/services/db_context/exceptions.py new file mode 100644 index 00000000..163f92e2 --- /dev/null +++ b/zhenxun/services/db_context/exceptions.py @@ -0,0 +1,14 @@ +class DbUrlIsNode(Exception): + """ + 数据库链接地址为空 + """ + + pass + + +class DbConnectError(Exception): + """ + 数据库连接错误 + """ + + pass diff --git a/zhenxun/services/db_context/utils.py b/zhenxun/services/db_context/utils.py new file mode 100644 index 00000000..47db548f --- /dev/null +++ b/zhenxun/services/db_context/utils.py @@ -0,0 +1,27 @@ +import asyncio +import time + +from zhenxun.services.log import logger + +from .config import ( + DB_TIMEOUT_SECONDS, + LOG_COMMAND, + SLOW_QUERY_THRESHOLD, +) + + +async def with_db_timeout( + coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None +): + """带超时控制的数据库操作""" + start_time = time.time() + try: + result = await asyncio.wait_for(coro, timeout=timeout) + elapsed = time.time() - start_time + if elapsed > SLOW_QUERY_THRESHOLD and operation: + logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND) + return result + except asyncio.TimeoutError: + if operation: + logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND) + raise diff --git a/zhenxun/services/plugin_init.py b/zhenxun/services/plugin_init.py index a622a9e8..a7b8685a 100644 --- a/zhenxun/services/plugin_init.py +++ b/zhenxun/services/plugin_init.py @@ -54,19 +54,20 @@ class PluginInitManager: @classmethod async def install_all(cls): """运行所有插件安装方法""" - if cls.plugins: - for module_path, model in cls.plugins.items(): - if model.install: - class_ = model.class_() - try: - logger.debug(f"开始执行: {module_path}:install 方法") - if is_coroutine_callable(class_.install): - await class_.install() - else: - class_.install() # type: ignore - logger.debug(f"执行: {module_path}:install 完成") - except Exception as e: - logger.error(f"执行: {module_path}:install 失败", e=e) + if not cls.plugins: + return + for module_path, model in cls.plugins.items(): + if model.install: + class_ = model.class_() + try: + logger.debug(f"开始执行: {module_path}:install 方法") + if is_coroutine_callable(class_.install): + await class_.install() + else: + class_.install() # type: ignore + logger.debug(f"执行: {module_path}:install 完成") + except Exception as e: + logger.error(f"执行: {module_path}:install 失败", e=e) @classmethod async def install(cls, module_path: str):