Format db_context

HibiKier 2025-07-15 00:44:05 +08:00
parent faa91b8bd4
commit 88493328ee
10 changed files with 567 additions and 66 deletions

View File

@@ -5,7 +5,7 @@ import nonebot
 from nonebot.adapters import Bot
 from nonebot.drivers import Driver
 from tortoise import Tortoise
-from tortoise.exceptions import OperationalError
+from tortoise.exceptions import IntegrityError, OperationalError
 import ujson as json

 from zhenxun.models.bot_connect_log import BotConnectLog
@@ -30,9 +30,12 @@ async def _(bot: Bot):
         bot_id=bot.self_id, platform=bot.adapter, connect_time=datetime.now(), type=1
     )
     if not await BotConsole.exists(bot_id=bot.self_id):
-        await BotConsole.create(
-            bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
-        )
+        try:
+            await BotConsole.create(
+                bot_id=bot.self_id, platform=PlatformUtils.get_platform(bot)
+            )
+        except IntegrityError as e:
+            logger.warning(f"Record for bot {bot.self_id} already exists...", e=e)


 @driver.on_bot_disconnect
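
The try/except above exists because exists() followed by create() is racy: two connect events for the same bot can both pass the exists() check, and the second create() then violates the unique constraint. A minimal sketch of the pattern in isolation, assuming any Tortoise model with a unique constraint on the given fields (the helper name is illustrative, not part of the commit):

from tortoise.exceptions import IntegrityError


async def ensure_row(model, **unique_fields):
    """Create the row if missing; tolerate losing a create race (sketch)."""
    try:
        await model.create(**unique_fields)
    except IntegrityError:
        pass  # another task created the row between check and create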

View File

@@ -18,7 +18,7 @@ require("nonebot_plugin_htmlrender")
 require("nonebot_plugin_uninfo")
 require("nonebot_plugin_waiter")

-from .db_context import Model, disconnect
+from .db_context import Model, disconnect, with_db_timeout
 from .llm import (
     AI,
     AIConfig,
@@ -80,4 +80,5 @@ __all__ = [
     "search_multimodal",
     "set_global_default_model_name",
     "tool_registry",
+    "with_db_timeout",
 ]

View File

@@ -63,6 +63,7 @@ from functools import wraps
 from typing import Any, ClassVar, Generic, TypeVar, get_type_hints

 from aiocache import Cache as AioCache
 from aiocache import SimpleMemoryCache
+from aiocache.base import BaseCache
 from aiocache.serializers import JsonSerializer
 import nonebot
@@ -392,22 +393,14 @@ class CacheManager:
     def cache_backend(self) -> BaseCache | AioCache:
         """Get the cache backend."""
         if self._cache_backend is None:
-            try:
-                from aiocache import RedisCache, SimpleMemoryCache
-
-                if cache_config.cache_mode == CacheMode.NONE:
-                    # Use in-memory cache with persistence disabled
-                    self._cache_backend = SimpleMemoryCache(
-                        serializer=JsonSerializer(),
-                        namespace=CACHE_KEY_PREFIX,
-                        timeout=30,
-                        ttl=0,  # 0 disables caching
-                    )
-                    logger.info("Caching disabled; using non-persistent in-memory cache", LOG_COMMAND)
-                elif (
-                    cache_config.cache_mode == CacheMode.REDIS
-                    and cache_config.redis_host
-                ):
+            ttl = cache_config.redis_expire
+            if cache_config.cache_mode == CacheMode.NONE:
+                ttl = 0
+                logger.info("Caching disabled; using non-persistent in-memory cache", LOG_COMMAND)
+            elif cache_config.cache_mode == CacheMode.REDIS and cache_config.redis_host:
+                try:
+                    from aiocache import RedisCache
+
                     # Use the Redis cache
                     self._cache_backend = RedisCache(
                         serializer=JsonSerializer(),
@@ -419,27 +412,25 @@ class CacheManager:
                         password=cache_config.redis_password,
                     )
                     logger.info(
-                        f"Using Redis cache at: {cache_config.redis_host}", LOG_COMMAND
+                        f"Using Redis cache at: {cache_config.redis_host}",
+                        LOG_COMMAND,
                     )
-                else:
-                    # Fall back to the in-memory cache by default
-                    self._cache_backend = SimpleMemoryCache(
-                        serializer=JsonSerializer(),
-                        namespace=CACHE_KEY_PREFIX,
-                        timeout=30,
-                        ttl=cache_config.redis_expire,
-                    )
-                    logger.info("Using in-memory cache", LOG_COMMAND)
-            except ImportError:
-                logger.error("Failed to import aiocache; using in-memory cache", LOG_COMMAND)
-                # Use the in-memory cache
-                self._cache_backend = AioCache(
-                    cache_class=AioCache.MEMORY,
-                    serializer=JsonSerializer(),
-                    namespace=CACHE_KEY_PREFIX,
-                    timeout=30,
-                    ttl=cache_config.redis_expire,
-                )
+                    return self._cache_backend
+                except ImportError as e:
+                    logger.error(
+                        "Failed to import aiocache[redis]; falling back to the in-memory cache...",
+                        LOG_COMMAND,
+                        e=e,
+                    )
+            else:
+                logger.info("Using in-memory cache", LOG_COMMAND)
+            # Fall back to the in-memory cache by default
+            self._cache_backend = SimpleMemoryCache(
+                serializer=JsonSerializer(),
+                namespace=CACHE_KEY_PREFIX,
+                timeout=30,
+                ttl=ttl,
+            )
         return self._cache_backend

     @property
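
After this refactor the backend choice is a single fallback chain: compute a ttl, try Redis only when configured, and land on SimpleMemoryCache in every other case (including a failed Redis import). A condensed sketch of that flow — make_backend and its parameters stand in for cache_config; only the aiocache classes come from the diff:

from aiocache import SimpleMemoryCache
from aiocache.serializers import JsonSerializer


def make_backend(mode: str, redis_host: str | None, expire: int):
    """Pick a cache backend (sketch of the control flow above)."""
    ttl = 0 if mode == "NONE" else expire  # ttl=0 effectively disables caching
    if mode == "REDIS" and redis_host:
        try:
            from aiocache import RedisCache  # needs the aiocache[redis] extra

            return RedisCache(serializer=JsonSerializer(), endpoint=redis_host)
        except ImportError:
            pass  # fall through to the in-memory backend, as above
    return SimpleMemoryCache(serializer=JsonSerializer(), timeout=30, ttl=ttl)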

View File

@@ -54,11 +54,7 @@ class CacheDict:
             key: dictionary key
             value: dictionary value
         """
-        # Calculate the expire time
-        expire_time = 0
-        if self.expire > 0:
-            expire_time = time.time() + self.expire
-
+        expire_time = time.time() + self.expire if self.expire > 0 else 0
         self._data[key] = CacheData(value=value, expire_time=expire_time)

     def __delitem__(self, key: str) -> None:
@@ -274,12 +270,11 @@ class CacheList:
             self.clear()
             raise IndexError(f"List index {index} out of range")

-        if 0 <= index < len(self._data):
-            del self._data[index]
-            # Refresh the expire time
-            self._update_expire_time()
-        else:
+        if not 0 <= index < len(self._data):
             raise IndexError(f"List index {index} out of range")
+        del self._data[index]
+        # Refresh the expire time
+        self._update_expire_time()

     def __len__(self) -> int:
         """Get the list length
@@ -427,6 +422,7 @@ class CacheList:
             self.clear()
             return 0

+        # sourcery skip: simplify-constant-sum
         return sum(1 for item in self._data if item.value == value)

     def _is_expired(self) -> bool:

@@ -435,10 +431,7 @@ class CacheList:
     def _update_expire_time(self) -> None:
         """Refresh the expire time."""
-        if self.expire > 0:
-            self._expire_time = time.time() + self.expire
-        else:
-            self._expire_time = 0
+        self._expire_time = time.time() + self.expire if self.expire > 0 else 0

     def __str__(self) -> str:
         """String representation

View File

@@ -0,0 +1,137 @@
import asyncio
from urllib.parse import urlparse

import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.connection import connections

from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle

from .base_model import Model
from .config import (
    DB_TIMEOUT_SECONDS,
    MYSQL_CONFIG,
    POSTGRESQL_CONFIG,
    SLOW_QUERY_THRESHOLD,
    SQLITE_CONFIG,
    db_model,
    prompt,
)
from .exceptions import DbConnectError, DbUrlIsNode
from .utils import with_db_timeout

__all__ = [
    "DB_TIMEOUT_SECONDS",
    "SLOW_QUERY_THRESHOLD",
    "DbConnectError",
    "DbUrlIsNode",
    "Model",
    "disconnect",
    "init",
    "with_db_timeout",
]

driver = nonebot.get_driver()


def get_config() -> dict:
    """Build the database configuration."""
    parsed = urlparse(BotConfig.db_url)

    # Base configuration
    config = {
        "connections": {
            "default": BotConfig.db_url  # Use the connection string directly by default
        },
        "apps": {
            "models": {
                "models": db_model.models,
                "default_connection": "default",
            }
        },
        "timezone": "Asia/Shanghai",
    }

    # Apply engine-specific settings based on the database type
    if parsed.scheme.startswith("postgres"):
        config["connections"]["default"] = {
            "engine": "tortoise.backends.asyncpg",
            "credentials": {
                "host": parsed.hostname,
                "port": parsed.port or 5432,
                "user": parsed.username,
                "password": parsed.password,
                "database": parsed.path[1:],
            },
            **POSTGRESQL_CONFIG,
        }
    elif parsed.scheme == "mysql":
        config["connections"]["default"] = {
            "engine": "tortoise.backends.mysql",
            "credentials": {
                "host": parsed.hostname,
                "port": parsed.port or 3306,
                "user": parsed.username,
                "password": parsed.password,
                "database": parsed.path[1:],
            },
            **MYSQL_CONFIG,
        }
    elif parsed.scheme == "sqlite":
        config["connections"]["default"] = {
            "engine": "tortoise.backends.sqlite",
            "credentials": {
                "file_path": parsed.path or ":memory:",
            },
            **SQLITE_CONFIG,
        }

    return config


@PriorityLifecycle.on_startup(priority=1)
async def init():
    if not BotConfig.db_url:
        error = prompt.format(host=driver.config.host, port=driver.config.port)
        raise DbUrlIsNode("\n" + error.strip())
    try:
        await Tortoise.init(
            config=get_config(),
        )
        if db_model.script_method:
            db = Tortoise.get_connection("default")
            logger.debug(
                "About to run SCRIPT_METHOD hooks, "
                f"<u><y>{len(db_model.script_method)}</y></u> in total..."
            )
            sql_list = []
            for module, func in db_model.script_method:
                try:
                    sql = await func() if is_coroutine_callable(func) else func()
                    if sql:
                        sql_list += sql
                except Exception as e:
                    logger.debug(f"{module} SCRIPT_METHOD hook failed...", e=e)
            for sql in sql_list:
                logger.debug(f"Executing SQL: {sql}")
                try:
                    await asyncio.wait_for(
                        db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
                    )
                    # await TestSQL.raw(sql)
                except Exception as e:
                    logger.debug(f"Failed to execute SQL: {sql}...", e=e)
            if sql_list:
                logger.debug("SCRIPT_METHOD hooks finished!")
        logger.debug("Generating database schemas...")
        await Tortoise.generate_schemas()
        logger.debug("Database schemas generated!")
        logger.info("Database loaded successfully!")
    except Exception as e:
        raise DbConnectError(f"Database connection error... e:{e}") from e


async def disconnect():
    await connections.close_all()
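
For a concrete sense of what get_config() produces, here is the approximate result for a hypothetical PostgreSQL URL — the URL and credentials are invented; the shape follows the function above, with POSTGRESQL_CONFIG merged in:

# For BotConfig.db_url = "postgres://bot:secret@localhost:5432/zhenxun":
CONFIG_EXAMPLE = {
    "connections": {
        "default": {
            "engine": "tortoise.backends.asyncpg",
            "credentials": {
                "host": "localhost",
                "port": 5432,
                "user": "bot",
                "password": "secret",
                "database": "zhenxun",
            },
            "max_size": 30,  # from POSTGRESQL_CONFIG
            "min_size": 5,
        }
    },
    "apps": {
        "models": {"models": db_model.models, "default_connection": "default"}
    },
    "timezone": "Asia/Shanghai",
}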

View File

@@ -0,0 +1,288 @@
import asyncio
from collections.abc import Iterable
import contextlib
from typing import Any, ClassVar
from typing_extensions import Self

from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.exceptions import IntegrityError, MultipleObjectsReturned
from tortoise.models import Model as TortoiseModel
from tortoise.transactions import in_transaction

from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger
from zhenxun.utils.enum import DbLockType

from .config import LOG_COMMAND, db_model
from .utils import with_db_timeout


class Model(TortoiseModel):
    """
    Enhanced ORM base class that resolves nested-lock issues.
    """

    sem_data: ClassVar[dict[str, dict[str, asyncio.Semaphore]]] = {}
    _current_locks: ClassVar[dict[int, DbLockType]] = {}  # Locks held by the current task

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.__module__ not in db_model.models:
            db_model.models.append(cls.__module__)
        if func := getattr(cls, "_run_script", None):
            db_model.script_method.append((cls.__module__, func))

    @classmethod
    def get_cache_type(cls) -> str | None:
        """Get the cache type."""
        return getattr(cls, "cache_type", None)

    @classmethod
    def get_cache_key_field(cls) -> str | tuple[str]:
        """Get the cache key field(s)."""
        return getattr(cls, "cache_key_field", "id")

    @classmethod
    def get_cache_key(cls, instance) -> str | None:
        """Get the cache key.

        Args:
            instance: model instance

        Returns:
            str | None: the cache key, or None if it cannot be derived
        """
        from zhenxun.services.cache.config import COMPOSITE_KEY_SEPARATOR

        key_field = cls.get_cache_key_field()
        if isinstance(key_field, tuple):
            # Composite primary key
            key_parts = []
            for field in key_field:
                if hasattr(instance, field):
                    value = getattr(instance, field, None)
                    key_parts.append(value if value is not None else "")
                else:
                    # A required field is missing
                    key_parts.append("")
            # Return None if there are no usable parts
            return COMPOSITE_KEY_SEPARATOR.join(key_parts) if key_parts else None
        elif hasattr(instance, key_field):
            value = getattr(instance, key_field, None)
            return str(value) if value is not None else None
        return None

    @classmethod
    def get_semaphore(cls, lock_type: DbLockType):
        enable_lock = getattr(cls, "enable_lock", None)
        if not enable_lock or lock_type not in enable_lock:
            return None
        if cls.__name__ not in cls.sem_data:
            cls.sem_data[cls.__name__] = {}
        if lock_type not in cls.sem_data[cls.__name__]:
            cls.sem_data[cls.__name__][lock_type] = asyncio.Semaphore(1)
        return cls.sem_data[cls.__name__][lock_type]

    @classmethod
    def _require_lock(cls, lock_type: DbLockType) -> bool:
        """Check whether the lock actually needs to be acquired."""
        task_id = id(asyncio.current_task())
        return cls._current_locks.get(task_id) != lock_type

    @classmethod
    @contextlib.asynccontextmanager
    async def _lock_context(cls, lock_type: DbLockType):
        """Lock context with reentrancy checking."""
        task_id = id(asyncio.current_task())
        need_lock = cls._require_lock(lock_type)
        if need_lock and (sem := cls.get_semaphore(lock_type)):
            cls._current_locks[task_id] = lock_type
            async with sem:
                yield
            cls._current_locks.pop(task_id, None)
        else:
            yield

    @classmethod
    async def create(
        cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
    ) -> Self:
        """Create a row, guarded by the CREATE lock."""
        async with cls._lock_context(DbLockType.CREATE):
            # Call the parent class directly to avoid triggering save()'s lock
            result = await super().create(using_db=using_db, **kwargs)
            if cache_type := cls.get_cache_type():
                await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result))
            return result

    @classmethod
    async def get_or_create(
        cls,
        defaults: dict | None = None,
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        """Get or create a row (lock-free; relies on database constraints)."""
        result = await super().get_or_create(
            defaults=defaults, using_db=using_db, **kwargs
        )
        if cache_type := cls.get_cache_type():
            await CacheRoot.invalidate_cache(cache_type, cls.get_cache_key(result[0]))
        return result

    @classmethod
    async def update_or_create(
        cls,
        defaults: dict | None = None,
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        """Update or create a row, guarded by the UPSERT lock."""
        async with cls._lock_context(DbLockType.UPSERT):
            try:
                # Try updating first (with a row lock)
                async with in_transaction():
                    if obj := await cls.filter(**kwargs).select_for_update().first():
                        await obj.update_from_dict(defaults or {})
                        await obj.save()
                        result = (obj, False)
                    else:
                        # create() will not re-acquire the lock here
                        result = await cls.create(**kwargs, **(defaults or {})), True
                if cache_type := cls.get_cache_type():
                    await CacheRoot.invalidate_cache(
                        cache_type, cls.get_cache_key(result[0])
                    )
                return result
            except IntegrityError:
                # Handle a unique-constraint conflict in the rare race case
                obj = await cls.get(**kwargs)
                return obj, False

    async def save(
        self,
        using_db: BaseDBAsyncClient | None = None,
        update_fields: Iterable[str] | None = None,
        force_create: bool = False,
        force_update: bool = False,
    ):
        """Save the row, choosing the lock type from the operation."""
        lock_type = (
            DbLockType.CREATE
            if getattr(self, "id", None) is None
            else DbLockType.UPDATE
        )
        async with self._lock_context(lock_type):
            await super().save(
                using_db=using_db,
                update_fields=update_fields,
                force_create=force_create,
                force_update=force_update,
            )
            if cache_type := getattr(self, "cache_type", None):
                await CacheRoot.invalidate_cache(
                    cache_type, self.__class__.get_cache_key(self)
                )

    async def delete(self, using_db: BaseDBAsyncClient | None = None):
        cache_type = getattr(self, "cache_type", None)
        key = self.__class__.get_cache_key(self) if cache_type else None

        # Perform the delete
        await super().delete(using_db=using_db)

        # Invalidate the cache
        if cache_type:
            await CacheRoot.invalidate_cache(cache_type, key)

    @classmethod
    async def safe_get_or_none(
        cls,
        *args,
        using_db: BaseDBAsyncClient | None = None,
        clean_duplicates: bool = True,
        **kwargs: Any,
    ) -> Self | None:
        """Safely fetch one record or None; if multiple records exist, return the newest.

        Note: by default, duplicate records are deleted, keeping only the newest.

        Args:
            *args: query arguments
            using_db: database connection
            clean_duplicates: whether to delete duplicate records, keeping only the newest
            **kwargs: query arguments

        Returns:
            Self | None: the query result, or None if it does not exist
        """
        try:
            # First try get_or_none for a single record
            try:
                return await with_db_timeout(
                    cls.get_or_none(*args, using_db=using_db, **kwargs),
                    operation=f"{cls.__name__}.get_or_none",
                )
            except MultipleObjectsReturned:
                # Multiple records exist; handle them specially
                logger.warning(
                    f"{cls.__name__} safe_get_or_none found multiple records: {kwargs}",
                    LOG_COMMAND,
                )
                # Fetch all matching records
                records = await with_db_timeout(
                    cls.filter(*args, **kwargs).all(),
                    operation=f"{cls.__name__}.filter.all",
                )
                if not records:
                    return None
                # Clean up duplicate records if requested
                if clean_duplicates and hasattr(records[0], "id"):
                    # Sort by id, newest first
                    records = sorted(
                        records, key=lambda x: getattr(x, "id", 0), reverse=True
                    )
                    for record in records[1:]:
                        try:
                            await with_db_timeout(
                                record.delete(),
                                operation=f"{cls.__name__}.delete_duplicate",
                            )
                            logger.info(
                                f"{cls.__name__} deleted duplicate record:"
                                f" id={getattr(record, 'id', None)}",
                                LOG_COMMAND,
                            )
                        except Exception as del_e:
                            logger.error(f"Failed to delete duplicate record: {del_e}")
                    return records[0]
                # No cleanup requested: return the newest record by id
                if hasattr(cls, "id"):
                    return await with_db_timeout(
                        cls.filter(*args, **kwargs).order_by("-id").first(),
                        operation=f"{cls.__name__}.filter.order_by.first",
                    )
                # Without an id field, return the first record
                return await with_db_timeout(
                    cls.filter(*args, **kwargs).first(),
                    operation=f"{cls.__name__}.filter.first",
                )
        except asyncio.TimeoutError:
            logger.error(
                f"Database operation timed out: {cls.__name__}.safe_get_or_none",
                LOG_COMMAND,
            )
            return None
        except Exception as e:
            # Re-raise other errors
            logger.error(
                f"Database operation failed: {cls.__name__}.safe_get_or_none, {e!s}",
                LOG_COMMAND,
            )
            raise
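
The base class discovers all per-model behavior through getattr, so a subclass opts in purely with class attributes. A sketch of a consuming model, assuming only the attribute names probed above (enable_lock, cache_type, cache_key_field); the Friend model itself is hypothetical:

from tortoise import fields

from zhenxun.services.db_context import Model
from zhenxun.utils.enum import DbLockType


class Friend(Model):  # hypothetical model, for illustration only
    id = fields.IntField(pk=True)
    user_id = fields.CharField(max_length=255, unique=True)

    enable_lock = [DbLockType.CREATE, DbLockType.UPSERT]  # serialize these writes
    cache_type = "friend"  # cache bucket invalidated on create/save/delete
    cache_key_field = "user_id"  # used by get_cache_key()


# Tolerates (and cleans up) duplicate rows instead of raising
# MultipleObjectsReturned:
#     friend = await Friend.safe_get_or_none(user_id="12345")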

View File

@@ -0,0 +1,46 @@
from collections.abc import Callable

from pydantic import BaseModel

# Database operation timeout (seconds)
DB_TIMEOUT_SECONDS = 3.0

# Slow-query threshold for performance monitoring (seconds)
SLOW_QUERY_THRESHOLD = 0.5

LOG_COMMAND = "DbContext"

POSTGRESQL_CONFIG = {
    "max_size": 30,  # Maximum number of connections
    "min_size": 5,  # Minimum number of idle connections (optional)
}

MYSQL_CONFIG = {
    "max_connections": 20,  # Maximum number of connections
    "connect_timeout": 30,  # Connection timeout (optional)
}

SQLITE_CONFIG = {
    "journal_mode": "WAL",  # Improves concurrent write performance
    "timeout": 30,  # Lock wait timeout (optional)
}


class DbModel(BaseModel):
    script_method: list[tuple[str, Callable]] = []
    models: list[str] = []


db_model = DbModel()

prompt = """
**********************************************************************
🌟 ************************ Configuration is empty ************************ 🌟
🚀 Please open the WebUI to complete the basic setup 🚀
🌐 Configuration URL: http://{host}:{port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""

View File

@@ -0,0 +1,14 @@
class DbUrlIsNode(Exception):
    """
    Database URL is empty
    """

    pass


class DbConnectError(Exception):
    """
    Database connection error
    """

    pass

View File

@@ -0,0 +1,27 @@
import asyncio
import time

from zhenxun.services.log import logger

from .config import (
    DB_TIMEOUT_SECONDS,
    LOG_COMMAND,
    SLOW_QUERY_THRESHOLD,
)


async def with_db_timeout(
    coro, timeout: float = DB_TIMEOUT_SECONDS, operation: str | None = None
):
    """Run a database operation with a timeout."""
    start_time = time.time()
    try:
        result = await asyncio.wait_for(coro, timeout=timeout)
        elapsed = time.time() - start_time
        if elapsed > SLOW_QUERY_THRESHOLD and operation:
            logger.warning(f"Slow query: {operation} took {elapsed:.3f}s", LOG_COMMAND)
        return result
    except asyncio.TimeoutError:
        if operation:
            logger.error(f"Database operation timed out: {operation} (>{timeout}s)", LOG_COMMAND)
        raise
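
with_db_timeout takes the un-awaited coroutine plus an optional label, which is how base_model.py calls it. A usage sketch — the User query is hypothetical; the constants are the ones defined in config.py:

# Raises asyncio.TimeoutError after DB_TIMEOUT_SECONDS (3.0s) and logs a
# slow-query warning past SLOW_QUERY_THRESHOLD (0.5s):
user = await with_db_timeout(
    User.get_or_none(user_id="12345"),  # hypothetical model/query
    operation="User.get_or_none",
)

# An explicit, longer timeout for an expensive statement:
await with_db_timeout(
    db.execute_query_dict(sql),  # e.g. db = Tortoise.get_connection("default")
    timeout=10.0,
    operation="schema_migration",
)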

View File

@@ -54,19 +54,20 @@ class PluginInitManager:
     @classmethod
     async def install_all(cls):
         """Run every plugin's install hook."""
-        if cls.plugins:
-            for module_path, model in cls.plugins.items():
-                if model.install:
-                    class_ = model.class_()
-                    try:
-                        logger.debug(f"Running: {module_path}:install")
-                        if is_coroutine_callable(class_.install):
-                            await class_.install()
-                        else:
-                            class_.install()  # type: ignore
-                        logger.debug(f"Finished: {module_path}:install")
-                    except Exception as e:
-                        logger.error(f"Failed: {module_path}:install", e=e)
+        if not cls.plugins:
+            return
+        for module_path, model in cls.plugins.items():
+            if model.install:
+                class_ = model.class_()
+                try:
+                    logger.debug(f"Running: {module_path}:install")
+                    if is_coroutine_callable(class_.install):
+                        await class_.install()
+                    else:
+                        class_.install()  # type: ignore
+                    logger.debug(f"Finished: {module_path}:install")
+                except Exception as e:
+                    logger.error(f"Failed: {module_path}:install", e=e)

     @classmethod
     async def install(cls, module_path: str):