格式化db_context (#1980)

*  格式化db_context

* 🔥 移除旧db-context

*  添加旧版本兼容
This commit is contained in:
HibiKier 2025-07-15 17:08:42 +08:00 committed by GitHub
parent faa91b8bd4
commit d218c569d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 293 additions and 248 deletions

View File

@ -5,7 +5,7 @@ import nonebot
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.drivers import Driver from nonebot.drivers import Driver
from tortoise import Tortoise from tortoise import Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import IntegrityError, OperationalError
import ujson as json import ujson as json
from zhenxun.models.bot_connect_log import BotConnectLog from zhenxun.models.bot_connect_log import BotConnectLog
@ -30,9 +30,12 @@ async def _(bot: Bot):
bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1 bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
) )
if not await BotConsole.exists(bot_id=bot.self_id): if not await BotConsole.exists(bot_id=bot.self_id):
await BotConsole.create( try:
bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot) await BotConsole.create(
) bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
)
except IntegrityError as e:
logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e)
@driver.on_bot_disconnect @driver.on_bot_disconnect

View File

@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender")
require("nonebot_plugin_uninfo") require("nonebot_plugin_uninfo")
require("nonebot_plugin_waiter") require("nonebot_plugin_waiter")
from .db_context import Model, disconnect from .db_context import Model, disconnect, with_db_timeout
from .llm import ( from .llm import (
AI, AI,
AIConfig, AIConfig,
@ -80,4 +80,5 @@ __all__ = [
"search_multimodal", "search_multimodal",
"set_global_default_model_name", "set_global_default_model_name",
"tool_registry", "tool_registry",
"with_db_timeout",
] ]

View File

@ -63,6 +63,7 @@ from functools import wraps
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
from aiocache import Cache as AioCache from aiocache import Cache as AioCache
from aiocache import SimpleMemoryCache
from aiocache.base import BaseCache from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer from aiocache.serializers import JsonSerializer
import nonebot import nonebot
@ -392,22 +393,14 @@ class CacheManager:
def cache_backend(self) -> BaseCache | AioCache: def cache_backend(self) -> BaseCache | AioCache:
"""获取缓存后端""" """获取缓存后端"""
if self._cache_backend is None: if self._cache_backend is None:
try: ttl = cache_config.redis_expire
from aiocache import RedisCache, SimpleMemoryCache if cache_config.cache_mode == CacheMode.NONE:
ttl = 0
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host:
try:
from aiocache import RedisCache
if cache_config.cache_mode == CacheMode.NONE:
# 使用内存缓存但禁用持久化
self._cache_backend = SimpleMemoryCache(
serializer=JsonSerializer(),
namespace=CACHE_KEY_PREFIX,
timeout=30,
ttl=0, # 设置为0不缓存
)
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
elif (
cache_config.cache_mode == CacheMode.REDIS
and cache_config.redis_host
):
# 使用Redis缓存 # 使用Redis缓存
self._cache_backend = RedisCache( self._cache_backend = RedisCache(
serializer=JsonSerializer(), serializer=JsonSerializer(),
@ -419,27 +412,25 @@ class CacheManager:
password=cache_config.redis_password, password=cache_config.redis_password,
) )
logger.info( logger.info(
f"使用Redis缓存地址: {cache_config.redis_host}", LOG_COMMAND f"使用Redis缓存地址: {cache_config.redis_host}",
LOG_COMMAND,
) )
else: return self._cache_backend
# 默认使用内存缓存 except ImportError as e:
self._cache_backend = SimpleMemoryCache( logger.error(
serializer=JsonSerializer(), "导入aiocache[redis]失败,将默认使用内存缓存...",
namespace=CACHE_KEY_PREFIX, LOG_COMMAND,
timeout=30, e=e,
ttl=cache_config.redis_expire,
) )
logger.info("使用内存缓存", LOG_COMMAND) else:
except ImportError: logger.info("使用内存缓存", LOG_COMMAND)
logger.error("导入aiocache模块失败使用内存缓存", LOG_COMMAND) # 默认使用内存缓存
# 使用内存缓存 self._cache_backend = SimpleMemoryCache(
self._cache_backend = AioCache( serializer=JsonSerializer(),
cache_class=AioCache.MEMORY, namespace=CACHE_KEY_PREFIX,
serializer=JsonSerializer(), timeout=30,
namespace=CACHE_KEY_PREFIX, ttl=ttl,
timeout=30, )
ttl=cache_config.redis_expire,
)
return self._cache_backend return self._cache_backend
@property @property

View File

@ -54,11 +54,7 @@ class CacheDict:
key: 字典键 key: 字典键
value: 字典值 value: 字典值
""" """
# 计算过期时间 expire_time = time.time() + self.expire if self.expire > 0 else 0
expire_time = 0
if self.expire > 0:
expire_time = time.time() + self.expire
self._data[key] = CacheData(value=value, expire_time=expire_time) self._data[key] = CacheData(value=value, expire_time=expire_time)
def __delitem__(self, key: str) -> None: def __delitem__(self, key: str) -> None:
@ -274,12 +270,11 @@ class CacheList:
self.clear() self.clear()
raise IndexError(f"列表索引 {index} 超出范围") raise IndexError(f"列表索引 {index} 超出范围")
if 0 <= index < len(self._data): if not 0 <= index < len(self._data):
del self._data[index]
# 更新过期时间
self._update_expire_time()
else:
raise IndexError(f"列表索引 {index} 超出范围") raise IndexError(f"列表索引 {index} 超出范围")
del self._data[index]
# 更新过期时间
self._update_expire_time()
def __len__(self) -> int: def __len__(self) -> int:
"""获取列表长度 """获取列表长度
@ -427,6 +422,7 @@ class CacheList:
self.clear() self.clear()
return 0 return 0
# sourcery skip: simplify-constant-sum
return sum(1 for item in self._data if item.value == value) return sum(1 for item in self._data if item.value == value)
def _is_expired(self) -> bool: def _is_expired(self) -> bool:
@ -435,10 +431,7 @@ class CacheList:
def _update_expire_time(self) -> None: def _update_expire_time(self) -> None:
"""更新过期时间""" """更新过期时间"""
if self.expire > 0: self._expire_time = time.time() + self.expire if self.expire > 0 else 0
self._expire_time = time.time() + self.expire
else:
self._expire_time = 0
def __str__(self) -> str: def __str__(self) -> str:
"""字符串表示 """字符串表示

View File

@ -0,0 +1,146 @@
import asyncio
from urllib.parse import urlparse
import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.connection import connections
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .base_model import Model
from .config import (
DB_TIMEOUT_SECONDS,
MYSQL_CONFIG,
POSTGRESQL_CONFIG,
SLOW_QUERY_THRESHOLD,
SQLITE_CONFIG,
db_model,
prompt,
)
from .exceptions import DbConnectError, DbUrlIsNode
from .utils import with_db_timeout
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
__all__ = [
"DB_TIMEOUT_SECONDS",
"MODELS",
"SCRIPT_METHOD",
"SLOW_QUERY_THRESHOLD",
"DbConnectError",
"DbUrlIsNode",
"Model",
"disconnect",
"init",
"with_db_timeout",
]
driver = nonebot.get_driver()
def get_config() -> dict:
"""获取数据库配置"""
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": db_model.models,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path or ":memory:",
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
global MODELS, SCRIPT_METHOD
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
if not BotConfig.db_url:
error = prompt.format(host=driver.config.host, port=driver.config.port)
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
config=get_config(),
)
if db_model.script_method:
db = Tortoise.get_connection("default")
logger.debug(
"即将运行SCRIPT_METHOD方法, 合计 "
f"<u><y>{len(db_model.script_method)}</y></u> 个..."
)
sql_list = []
for module, func in db_model.script_method:
try:
sql = await func() if is_coroutine_callable(func) else func()
if sql:
sql_list += sql
except Exception as e:
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
await connections.close_all()

View File

@ -1,56 +1,20 @@
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Iterable
import contextlib import contextlib
import time
from typing import Any, ClassVar from typing import Any, ClassVar
from typing_extensions import Self from typing_extensions import Self
from urllib.parse import urlparse
from nonebot import get_driver
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.connection import connections
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction from tortoise.transactions import in_transaction
from zhenxun.configs.config import BotConfig
from zhenxun.services.cache import CacheRoot from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType from zhenxun.utils.enum import DbLockType
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver = get_driver() from .config import LOG_COMMAND, db_model
from .utils import with_db_timeout
SCRIPT_METHOD = []
MODELS: list[str] = []
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise
class Model(TortoiseModel): class Model(TortoiseModel):
@ -63,11 +27,11 @@ class Model(TortoiseModel):
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
if cls.__module__ not in MODELS: if cls.__module__ not in db_model.models:
MODELS.append(cls.__module__) db_model.models.append(cls.__module__)
if func := getattr(cls, "_run_script", None): if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func)) db_model.script_method.append((cls.__module__, func))
@classmethod @classmethod
def get_cache_type(cls) -> str | None: def get_cache_type(cls) -> str | None:
@ -322,144 +286,3 @@ class Model(TortoiseModel):
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
) )
raise raise
class DbUrlIsNode(HookPriorityException):
"""
数据库链接地址为空
"""
pass
class DbConnectError(Exception):
"""
数据库连接错误
"""
pass
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
def get_config() -> dict:
"""获取数据库配置"""
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path or ":memory:",
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
# raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
error = f"""
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{driver.config.host}:{driver.config.port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
config=get_config(),
)
if SCRIPT_METHOD:
db = Tortoise.get_connection("default")
logger.debug(
"即将运行SCRIPT_METHOD方法, 合计 "
f"<u><y>{len(SCRIPT_METHOD)}</y></u> 个..."
)
sql_list = []
for module, func in SCRIPT_METHOD:
try:
sql = await func() if is_coroutine_callable(func) else func()
if sql:
sql_list += sql
except Exception as e:
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
await connections.close_all()

View File

@ -0,0 +1,46 @@
from collections.abc import Callable
from pydantic import BaseModel
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
class DbModel(BaseModel):
script_method: list[tuple[str, Callable]] = []
models: list[str] = []
db_model = DbModel()
prompt = """
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{host}:{port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""

View File

@ -0,0 +1,14 @@
class DbUrlIsNode(Exception):
"""
数据库链接地址为空
"""
pass
class DbConnectError(Exception):
"""
数据库连接错误
"""
pass

View File

@ -0,0 +1,27 @@
import asyncio
import time
from zhenxun.services.log import logger
from .config import (
DB_TIMEOUT_SECONDS,
LOG_COMMAND,
SLOW_QUERY_THRESHOLD,
)
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise

View File

@ -54,19 +54,20 @@ class PluginInitManager:
@classmethod @classmethod
async def install_all(cls): async def install_all(cls):
"""运行所有插件安装方法""" """运行所有插件安装方法"""
if cls.plugins: if not cls.plugins:
for module_path, model in cls.plugins.items(): return
if model.install: for module_path, model in cls.plugins.items():
class_ = model.class_() if model.install:
try: class_ = model.class_()
logger.debug(f"开始执行: {module_path}:install 方法") try:
if is_coroutine_callable(class_.install): logger.debug(f"开始执行: {module_path}:install 方法")
await class_.install() if is_coroutine_callable(class_.install):
else: await class_.install()
class_.install() # type: ignore else:
logger.debug(f"执行: {module_path}:install 完成") class_.install() # type: ignore
except Exception as e: logger.debug(f"执行: {module_path}:install 完成")
logger.error(f"执行: {module_path}:install 失败", e=e) except Exception as e:
logger.error(f"执行: {module_path}:install 失败", e=e)
@classmethod @classmethod
async def install(cls, module_path: str): async def install(cls, module_path: str):