格式化db_context

This commit is contained in:
HibiKier 2025-07-15 00:44:05 +08:00
parent faa91b8bd4
commit 88493328ee
10 changed files with 567 additions and 66 deletions

View File

@ -5,7 +5,7 @@ import nonebot
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.drivers import Driver from nonebot.drivers import Driver
from tortoise import Tortoise from tortoise import Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import IntegrityError, OperationalError
import ujson as json import ujson as json
from zhenxun.models.bot_connect_log import BotConnectLog from zhenxun.models.bot_connect_log import BotConnectLog
@ -30,9 +30,12 @@ async def _(bot: Bot):
bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1 bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
) )
if not await BotConsole.exists(bot_id=bot.self_id): if not await BotConsole.exists(bot_id=bot.self_id):
await BotConsole.create( try:
bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot) await BotConsole.create(
) bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
)
except IntegrityError as e:
logger.warning(f"记录bot: {bot.self_id} 数据已存在...", e=e)
@driver.on_bot_disconnect @driver.on_bot_disconnect

View File

@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender")
require("nonebot_plugin_uninfo") require("nonebot_plugin_uninfo")
require("nonebot_plugin_waiter") require("nonebot_plugin_waiter")
from .db_context import Model, disconnect from .db_context import Model, disconnect, with_db_timeout
from .llm import ( from .llm import (
AI, AI,
AIConfig, AIConfig,
@ -80,4 +80,5 @@ __all__ = [
"search_multimodal", "search_multimodal",
"set_global_default_model_name", "set_global_default_model_name",
"tool_registry", "tool_registry",
"with_db_timeout",
] ]

View File

@ -63,6 +63,7 @@ from functools import wraps
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
from aiocache import Cache as AioCache from aiocache import Cache as AioCache
from aiocache import SimpleMemoryCache
from aiocache.base import BaseCache from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer from aiocache.serializers import JsonSerializer
import nonebot import nonebot
@ -392,22 +393,14 @@ class CacheManager:
def cache_backend(self) -> BaseCache | AioCache: def cache_backend(self) -> BaseCache | AioCache:
"""获取缓存后端""" """获取缓存后端"""
if self._cache_backend is None: if self._cache_backend is None:
try: ttl = cache_config.redis_expire
from aiocache import RedisCache, SimpleMemoryCache if cache_config.cache_mode == CacheMode.NONE:
ttl = 0
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host:
try:
from aiocache import RedisCache
if cache_config.cache_mode == CacheMode.NONE:
# 使用内存缓存但禁用持久化
self._cache_backend = SimpleMemoryCache(
serializer=JsonSerializer(),
namespace=CACHE_KEY_PREFIX,
timeout=30,
ttl=0, # 设置为0不缓存
)
logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND)
elif (
cache_config.cache_mode == CacheMode.REDIS
and cache_config.redis_host
):
# 使用Redis缓存 # 使用Redis缓存
self._cache_backend = RedisCache( self._cache_backend = RedisCache(
serializer=JsonSerializer(), serializer=JsonSerializer(),
@ -419,27 +412,25 @@ class CacheManager:
password=cache_config.redis_password, password=cache_config.redis_password,
) )
logger.info( logger.info(
f"使用Redis缓存地址: {cache_config.redis_host}", LOG_COMMAND f"使用Redis缓存地址: {cache_config.redis_host}",
LOG_COMMAND,
) )
else: return self._cache_backend
# 默认使用内存缓存 except ImportError as e:
self._cache_backend = SimpleMemoryCache( logger.error(
serializer=JsonSerializer(), "导入aiocache[reids]失败,将默认使用内存缓存...",
namespace=CACHE_KEY_PREFIX, LOG_COMMAND,
timeout=30, e=e,
ttl=cache_config.redis_expire,
) )
logger.info("使用内存缓存", LOG_COMMAND) else:
except ImportError: logger.info("使用内存缓存", LOG_COMMAND)
logger.error("导入aiocache模块失败使用内存缓存", LOG_COMMAND) # 默认使用内存缓存
# 使用内存缓存 self._cache_backend = SimpleMemoryCache(
self._cache_backend = AioCache( serializer=JsonSerializer(),
cache_class=AioCache.MEMORY, namespace=CACHE_KEY_PREFIX,
serializer=JsonSerializer(), timeout=30,
namespace=CACHE_KEY_PREFIX, ttl=ttl,
timeout=30, )
ttl=cache_config.redis_expire,
)
return self._cache_backend return self._cache_backend
@property @property

View File

@ -54,11 +54,7 @@ class CacheDict:
key: 字典键 key: 字典键
value: 字典值 value: 字典值
""" """
# 计算过期时间 expire_time = time.time() + self.expire if self.expire > 0 else 0
expire_time = 0
if self.expire > 0:
expire_time = time.time() + self.expire
self._data[key] = CacheData(value=value, expire_time=expire_time) self._data[key] = CacheData(value=value, expire_time=expire_time)
def __delitem__(self, key: str) -> None: def __delitem__(self, key: str) -> None:
@ -274,12 +270,11 @@ class CacheList:
self.clear() self.clear()
raise IndexError(f"列表索引 {index} 超出范围") raise IndexError(f"列表索引 {index} 超出范围")
if 0 <= index < len(self._data): if not 0 <= index < len(self._data):
del self._data[index]
# 更新过期时间
self._update_expire_time()
else:
raise IndexError(f"列表索引 {index} 超出范围") raise IndexError(f"列表索引 {index} 超出范围")
del self._data[index]
# 更新过期时间
self._update_expire_time()
def __len__(self) -> int: def __len__(self) -> int:
"""获取列表长度 """获取列表长度
@ -427,6 +422,7 @@ class CacheList:
self.clear() self.clear()
return 0 return 0
# sourcery skip: simplify-constant-sum
return sum(1 for item in self._data if item.value == value) return sum(1 for item in self._data if item.value == value)
def _is_expired(self) -> bool: def _is_expired(self) -> bool:
@ -435,10 +431,7 @@ class CacheList:
def _update_expire_time(self) -> None: def _update_expire_time(self) -> None:
"""更新过期时间""" """更新过期时间"""
if self.expire > 0: self._expire_time = time.time() + self.expire if self.expire > 0 else 0
self._expire_time = time.time() + self.expire
else:
self._expire_time = 0
def __str__(self) -> str: def __str__(self) -> str:
"""字符串表示 """字符串表示

View File

@ -0,0 +1,137 @@
import asyncio
from urllib.parse import urlparse
import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.connection import connections
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .base_model import Model
from .config import (
DB_TIMEOUT_SECONDS,
MYSQL_CONFIG,
POSTGRESQL_CONFIG,
SLOW_QUERY_THRESHOLD,
SQLITE_CONFIG,
db_model,
prompt,
)
from .exceptions import DbConnectError, DbUrlIsNode
from .utils import with_db_timeout
__all__ = [
"DB_TIMEOUT_SECONDS",
"SLOW_QUERY_THRESHOLD",
"DbConnectError",
"DbUrlIsNode",
"Model",
"disconnect",
"init",
"with_db_timeout",
]
driver = nonebot.get_driver()
def get_config() -> dict:
"""获取数据库配置"""
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": db_model.models,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path or ":memory:",
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
error = prompt.format(host=driver.config.host, port=driver.config.port)
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
config=get_config(),
)
if db_model.script_method:
db = Tortoise.get_connection("default")
logger.debug(
"即将运行SCRIPT_METHOD方法, 合计 "
f"<u><y>{len(db_model.script_method)}</y></u> 个..."
)
sql_list = []
for module, func in db_model.script_method:
try:
sql = await func() if is_coroutine_callable(func) else func()
if sql:
sql_list += sql
except Exception as e:
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
await connections.close_all()

View File

@ -0,0 +1,288 @@
import asyncio
from collections.abc import Iterable
import contextlib
from typing import Any, ClassVar
from typing_extensions import Self
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType
from .config import LOG_COMMAND, db_model
from .utils import with_db_timeout
class Model(TortoiseModel):
"""
增强的ORM基类解决锁嵌套问题
"""
sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
_current_locks: ClassVar[dict[int, DbLockType]] = {} # 跟踪当前协程持有的锁
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__module__ not in db_model.models:
db_model.models.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
db_model.script_method.append((cls.__module__, func))
@classmethod
def get_cache_type(cls) -> str | None:
"""获取缓存类型"""
return getattr(cls, "cache_type", None)
@classmethod
def get_cache_key_field(cls) -> str | tuple[str]:
"""获取缓存键字段"""
return getattr(cls, "cache_key_field", "id")
@classmethod
def get_cache_key(cls, instance) -> str | None:
"""获取缓存键
参数:
instance: 模型实例
返回:
str | None: 缓存键如果无法获取则返回None
"""
from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR
key_field = cls.get_cache_key_field()
if isinstance(key_field, tuple):
# 多字段主键
key_parts = []
for field in key_field:
if hasattr(instance, field):
value = getattr(instance, field, None)
key_parts.append(value if value is not None else "")
else:
# 如果缺少任何必要的字段返回None
key_parts.append("")
# 如果没有有效参数返回None
return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
elif hasattr(instance, key_field):
value = getattr(instance, key_field, None)
return str(value) if value is not None else None
return None
@classmethod
def get_semaphore(cls, lock_type: DbLockType):
enable_lock = getattr(cls, "enable_lock", None)
if not enable_lock or lock_type not in enable_lock:
return None
if cls.__name__ not in cls.sem_data:
cls.sem_data[cls.__name__] = {}
if lock_type not in cls.sem_data[cls.__name__]:
cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
return cls.sem_data[cls.__name__][lock_type]
@classmethod
def _require_lock(cls, lock_type: DbLockType) -> bool:
"""检查是否需要真正加锁"""
task_id = id(asyncio.current_task())
return cls._current_locks.get(task_id) != lock_type
@classmethod
@contextlib.asynccontextmanager
async def _lock_context(cls, lock_type: DbLockType):
"""带重入检查的锁上下文"""
task_id = id(asyncio.current_task())
need_lock = cls._require_lock(lock_type)
if need_lock and (sem := cls.get_semaphore(lock_type)):
cls._current_locks[task_id] = lock_type
async with sem:
yield
cls._current_locks.pop(task_id, None)
else:
yield
@classmethod
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self:
"""创建数据使用CREATE锁"""
async with cls._lock_context(DbLockType.CREATE):
# 直接调用父类的_create方法避免触发save的锁
result = await super().create(using_db=using_db, **kwargs)
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
return result
@classmethod
async def get_or_create(
cls,
defaults: dict | None = None,
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
"""获取或创建数据(无锁版本,依赖数据库约束)"""
result = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
return result
@classmethod
async def update_or_create(
cls,
defaults: dict | None = None,
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
"""更新或创建数据使用UPSERT锁"""
async with cls._lock_context(DbLockType.UPSERT):
try:
# 先尝试更新(带行锁)
async with in_transaction():
if obj := await cls.filter(**kwargs).select_for_update().first():
await obj.update_from_dict(defaults or {})
await obj.save()
result = (obj, False)
else:
# 创建时不重复加锁
result = await cls.create(**kwargs, **(defaults or {})), True
if cache_type := cls.get_cache_type():
await CacheRoot.invalidate_cache(
cache_type, cls.get_cache_key(result[0])
)
return result
except IntegrityError:
# 处理极端情况下的唯一约束冲突
obj = await cls.get(**kwargs)
return obj, False
async def save(
self,
using_db: BaseDBAsyncClient | None = None,
update_fields: Iterable[str] | None = None,
force_create: bool = False,
force_update: bool = False,
):
"""保存数据(根据操作类型自动选择锁)"""
lock_type = (
DbLockType.CREATE
if getattr(self, "id", None) is None
else DbLockType.UPDATE
)
async with self._lock_context(lock_type):
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
if cache_type := getattr(self, "cache_type", None):
await CacheRoot.invalidate_cache(
cache_type, self.__class__.get_cache_key(self)
)
async def delete(self, using_db: BaseDBAsyncClient | None = None):
cache_type = getattr(self, "cache_type", None)
key = self.__class__.get_cache_key(self) if cache_type else None
# 执行删除操作
await super().delete(using_db=using_db)
# 清除缓存
if cache_type:
await CacheRoot.invalidate_cache(cache_type, key)
@classmethod
async def safe_get_or_none(
cls,
*args,
using_db: BaseDBAsyncClient | None = None,
clean_duplicates: bool = True,
**kwargs: Any,
) -> Self | None:
"""安全地获取一条记录或None处理存在多个记录时返回最新的那个
注意默认会删除重复的记录仅保留最新的
参数:
*args: 查询参数
using_db: 数据库连接
clean_duplicates: 是否删除重复的记录仅保留最新的
**kwargs: 查询参数
返回:
Self | None: 查询结果如果不存在返回None
"""
try:
# 先尝试使用 get_or_none 获取单个记录
try:
return await with_db_timeout(
cls.get_or_none(*args, using_db=using_db, **kwargs),
operation=f"{cls.__name__}.get_or_none",
)
except MultipleObjectsReturned:
# 如果出现多个记录的情况,进行特殊处理
logger.warning(
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}",
LOG_COMMAND,
)
# 查询所有匹配记录
records = await with_db_timeout(
cls.filter(*args, **kwargs).all(),
operation=f"{cls.__name__}.filter.all",
)
if not records:
return None
# 如果需要清理重复记录
if clean_duplicates and hasattr(records[0], "id"):
# 按 id 排序
records = sorted(
records, key=lambda x: getattr(x, "id", 0), reverse=True
)
for record in records[1:]:
try:
await with_db_timeout(
record.delete(),
operation=f"{cls.__name__}.delete_duplicate",
)
logger.info(
f"{cls.__name__} 删除重复记录:"
f" id={getattr(record, 'id', None)}",
LOG_COMMAND,
)
except Exception as del_e:
logger.error(f"删除重复记录失败: {del_e}")
return records[0]
# 如果不需要清理或没有 id 字段,则返回最新的记录
if hasattr(cls, "id"):
return await with_db_timeout(
cls.filter(*args, **kwargs).order_by("-id").first(),
operation=f"{cls.__name__}.filter.order_by.first",
)
# 如果没有 id 字段,则返回第一个记录
return await with_db_timeout(
cls.filter(*args, **kwargs).first(),
operation=f"{cls.__name__}.filter.first",
)
except asyncio.TimeoutError:
logger.error(
f"数据库操作超时: {cls.__name__}.safe_get_or_none", LOG_COMMAND
)
return None
except Exception as e:
# 其他类型的错误则继续抛出
logger.error(
f"数据库操作异常: {cls.__name__}.safe_get_or_none, {e!s}", LOG_COMMAND
)
raise

View File

@ -0,0 +1,46 @@
from collections.abc import Callable
from pydantic import BaseModel
# 数据库操作超时设置(秒)
DB_TIMEOUT_SECONDS = 3.0
# 性能监控阈值(秒)
SLOW_QUERY_THRESHOLD = 0.5
LOG_COMMAND = "DbContext"
POSTGRESQL_CONFIG = {
"max_size": 30, # 最大连接数
"min_size": 5, # 最小保持的连接数(可选)
}
MYSQL_CONFIG = {
"max_connections": 20, # 最大连接数
"connect_timeout": 30, # 连接超时(可选)
}
SQLITE_CONFIG = {
"journal_mode": "WAL", # 提高并发写入性能
"timeout": 30, # 锁等待超时(可选)
}
class DbModel(BaseModel):
script_method: list[tuple[str, Callable]] = []
models: list[str] = []
db_model = DbModel()
prompt = """
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{host}:{port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""

View File

@ -0,0 +1,14 @@
class DbUrlIsNode(Exception):
"""
数据库链接地址为空
"""
pass
class DbConnectError(Exception):
"""
数据库连接错误
"""
pass

View File

@ -0,0 +1,27 @@
import asyncio
import time
from zhenxun.services.log import logger
from .config import (
DB_TIMEOUT_SECONDS,
LOG_COMMAND,
SLOW_QUERY_THRESHOLD,
)
async def with_db_timeout(
coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
"""带超时控制的数据库操作"""
start_time = time.time()
try:
result = await asyncio.wait_for(coro, timeout=timeout)
elapsed = time.time() - start_time
if elapsed > SLOW_QUERY_THRESHOLD and operation:
logger.warning(f"慢查询: {operation} 耗时 {elapsed:.3f}s", LOG_COMMAND)
return result
except asyncio.TimeoutError:
if operation:
logger.error(f"数据库操作超时: {operation} (>{timeout}s)", LOG_COMMAND)
raise

View File

@ -54,19 +54,20 @@ class PluginInitManager:
@classmethod @classmethod
async def install_all(cls): async def install_all(cls):
"""运行所有插件安装方法""" """运行所有插件安装方法"""
if cls.plugins: if not cls.plugins:
for module_path, model in cls.plugins.items(): return
if model.install: for module_path, model in cls.plugins.items():
class_ = model.class_() if model.install:
try: class_ = model.class_()
logger.debug(f"开始执行: {module_path}:install 方法") try:
if is_coroutine_callable(class_.install): logger.debug(f"开始执行: {module_path}:install 方法")
await class_.install() if is_coroutine_callable(class_.install):
else: await class_.install()
class_.install() # type: ignore else:
logger.debug(f"执行: {module_path}:install 完成") class_.install() # type: ignore
except Exception as e: logger.debug(f"执行: {module_path}:install 完成")
logger.error(f"执行: {module_path}:install 失败", e=e) except Exception as e:
logger.error(f"执行: {module_path}:install 失败", e=e)
@classmethod @classmethod
async def install(cls, module_path: str): async def install(cls, module_path: str):