zhenxun_bot/zhenxun/services/db_context/__init__.py

161 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from pathlib import Path
from urllib.parse import urlparse
import aiofiles
import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
from tortoise.connection import connections
from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .base_model import Model
from .config import (
DB_TIMEOUT_SECONDS,
MYSQL_CONFIG,
POSTGRESQL_CONFIG,
SLOW_QUERY_THRESHOLD,
SQLITE_CONFIG,
db_model,
prompt,
)
from .exceptions import DbConnectError, DbUrlIsNode
from .utils import with_db_timeout
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
__all__ = [
"DB_TIMEOUT_SECONDS",
"MODELS",
"SCRIPT_METHOD",
"SLOW_QUERY_THRESHOLD",
"DbConnectError",
"DbUrlIsNode",
"Model",
"disconnect",
"init",
"with_db_timeout",
]
driver = nonebot.get_driver()
def get_config() -> dict:
"""获取数据库配置"""
if not BotConfig.db_url:
raise DbUrlIsNode("数据库Url连接字符串为空请检查配置文件.env.dev")
parsed = urlparse(BotConfig.db_url)
# 基础配置
config = {
"connections": {
"default": BotConfig.db_url # 默认直接使用连接字符串
},
"apps": {
"models": {
"models": db_model.models,
"default_connection": "default",
}
},
"timezone": "Asia/Shanghai",
}
# 根据数据库类型应用高级配置
if parsed.scheme.startswith("postgres"):
config["connections"]["default"] = {
"engine": "tortoise.backends.asyncpg",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 5432,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**POSTGRESQL_CONFIG,
}
elif parsed.scheme == "mysql":
config["connections"]["default"] = {
"engine": "tortoise.backends.mysql",
"credentials": {
"host": parsed.hostname,
"port": parsed.port or 3306,
"user": parsed.username,
"password": parsed.password,
"database": parsed.path[1:],
},
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
Path(parsed.path).parent.mkdir(parents=True, exist_ok=True)
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": parsed.path,
},
**SQLITE_CONFIG,
}
return config
@PriorityLifecycle.on_startup(priority=1)
async def init():
global MODELS, SCRIPT_METHOD
env_example_file = Path() / ".env.example"
env_dev_file = Path() / ".env.dev"
if not env_dev_file.exists():
async with aiofiles.open(env_example_file, encoding="utf-8") as f:
env_text = await f.read()
async with aiofiles.open(env_dev_file, "w", encoding="utf-8") as f:
await f.write(env_text)
logger.info("已生成 .env.dev 文件,请根据 .env.example 文件配置进行配置")
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
if not BotConfig.db_url:
error = prompt.format(host=driver.config.host, port=driver.config.port)
raise DbUrlIsNode("\n" + error.strip())
try:
await Tortoise.init(
config=get_config(),
)
if db_model.script_method:
db = Tortoise.get_connection("default")
logger.debug(
"即将运行SCRIPT_METHOD方法, 合计 "
f"<u><y>{len(db_model.script_method)}</y></u> 个..."
)
sql_list = []
for module, func in db_model.script_method:
try:
sql = await func() if is_coroutine_callable(func) else func()
if sql:
sql_list += sql
except Exception as e:
logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
for sql in sql_list:
logger.debug(f"执行SQL: {sql}")
try:
await asyncio.wait_for(
db.execute_query_dict(sql), timeout=DB_TIMEOUT_SECONDS
)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
logger.debug("SCRIPT_METHOD方法执行完毕!")
logger.debug("开始生成数据库表结构...")
await Tortoise.generate_schemas()
logger.debug("数据库表结构生成完毕!")
logger.info("Database loaded successfully!")
except Exception as e:
raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
await connections.close_all()