mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
161 lines
5.0 KiB
Python
161 lines
5.0 KiB
Python
import asyncio
|
||
from pathlib import Path
|
||
from urllib.parse import urlparse
|
||
|
||
import aiofiles
|
||
import nonebot
|
||
from nonebot.utils import is_coroutine_callable
|
||
from tortoise import Tortoise
|
||
from tortoise.connection import connections
|
||
|
||
from zhenxun.configs.config import BotConfig
|
||
from zhenxun.services.log import logger
|
||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||
|
||
from .base_model import Model
|
||
from .config import (
|
||
DB_TIMEOUT_SECONDS,
|
||
MYSQL_CONFIG,
|
||
POSTGRESQL_CONFIG,
|
||
SLOW_QUERY_THRESHOLD,
|
||
SQLITE_CONFIG,
|
||
db_model,
|
||
prompt,
|
||
)
|
||
from .exceptions import DbConnectError, DbUrlIsNode
|
||
from .utils import with_db_timeout
|
||
|
||
# Module-level aliases of the registries held by ``db_model``.
# ``init()`` rebinds both names at startup.
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method

# Names re-exported as this package's public API.
__all__ = [
    "DB_TIMEOUT_SECONDS",
    "MODELS",
    "SCRIPT_METHOD",
    "SLOW_QUERY_THRESHOLD",
    "DbConnectError",
    "DbUrlIsNode",
    "Model",
    "disconnect",
    "init",
    "with_db_timeout",
]

# NoneBot driver; ``init()`` reads ``driver.config.host`` / ``driver.config.port``
# when it has to build the "missing db_url" prompt.
driver = nonebot.get_driver()
|
||
|
||
|
||
def get_config() -> dict:
    """Build the Tortoise ORM configuration dict from ``BotConfig.db_url``.

    The ``default`` connection starts out as the raw URL string. For
    ``postgres*``, ``mysql`` and ``sqlite`` URLs it is replaced by an explicit
    ``engine`` / ``credentials`` mapping merged with the matching
    ``POSTGRESQL_CONFIG`` / ``MYSQL_CONFIG`` / ``SQLITE_CONFIG`` options
    (those options are spread last, so they win on key collisions).

    Returns:
        A dict with ``connections``, ``apps`` and ``timezone`` keys.

    Raises:
        DbUrlIsNode: If ``BotConfig.db_url`` is empty.
    """
    # Local import: keeps this block self-contained without touching the
    # module's import section.
    from urllib.parse import unquote

    if not BotConfig.db_url:
        raise DbUrlIsNode("数据库Url连接字符串为空,请检查配置文件(.env.dev)")
    parsed = urlparse(BotConfig.db_url)

    # Base configuration
    config = {
        "connections": {
            "default": BotConfig.db_url  # by default, use the URL string as-is
        },
        "apps": {
            "models": {
                "models": db_model.models,
                "default_connection": "default",
            }
        },
        "timezone": "Asia/Shanghai",
    }

    def _decode(value):
        # urlparse() returns user info still percent-encoded (e.g. "p%40ss"),
        # so decode it before handing it to the database driver.
        return None if value is None else unquote(value)

    def _server_credentials(default_port: int) -> dict:
        # Shared credential layout for the client/server backends.
        return {
            "host": parsed.hostname,
            "port": parsed.port or default_port,
            "user": _decode(parsed.username),
            "password": _decode(parsed.password),
            "database": parsed.path[1:],  # drop the leading "/"
        }

    # Apply backend-specific advanced configuration
    if parsed.scheme.startswith("postgres"):
        config["connections"]["default"] = {
            "engine": "tortoise.backends.asyncpg",
            "credentials": _server_credentials(5432),
            **POSTGRESQL_CONFIG,
        }
    elif parsed.scheme == "mysql":
        config["connections"]["default"] = {
            "engine": "tortoise.backends.mysql",
            "credentials": _server_credentials(3306),
            **MYSQL_CONFIG,
        }
    elif parsed.scheme == "sqlite":
        # For "sqlite://data/db/x.db" urlparse() puts "data" into netloc and
        # only "/db/x.db" into path; joining both keeps the full path.
        # When netloc is empty ("sqlite:rel/x.db", "sqlite:///abs/x.db") this
        # is exactly parsed.path.
        file_path = parsed.netloc + parsed.path
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)
        config["connections"]["default"] = {
            "engine": "tortoise.backends.sqlite",
            "credentials": {
                "file_path": file_path,
            },
            **SQLITE_CONFIG,
        }
    return config
|
||
|
||
|
||
async def _collect_script_sql() -> list:
    """Call every entry of ``db_model.script_method`` and gather the SQL they return.

    A callable that raises is logged at debug level and skipped.
    """
    statements = []
    for module, func in db_model.script_method:
        try:
            if is_coroutine_callable(func):
                produced = await func()
            else:
                produced = func()
            if produced:
                statements.extend(produced)
        except Exception as e:
            logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
    return statements


async def _execute_script_sql(db, statements) -> None:
    """Execute each statement on ``db``, bounded by ``DB_TIMEOUT_SECONDS``.

    A failing (or timed-out) statement is logged at debug level and the
    remaining statements are still executed.
    """
    for statement in statements:
        logger.debug(f"执行SQL: {statement}")
        try:
            await asyncio.wait_for(
                db.execute_query_dict(statement), timeout=DB_TIMEOUT_SECONDS
            )
        except Exception as e:
            logger.debug(f"执行SQL: {statement} 错误...", e=e)


@PriorityLifecycle.on_startup(priority=1)
async def init():
    """Startup hook: prepare ``.env.dev``, connect Tortoise and create the schema.

    Steps, in order:
      1. If ``.env.dev`` is missing, create it as a copy of ``.env.example``.
      2. Refresh the module-level ``MODELS`` / ``SCRIPT_METHOD`` aliases.
      3. Initialise Tortoise with ``get_config()``, run the registered script
         methods' SQL (if any), then generate the schemas.

    Raises:
        DbUrlIsNode: If ``BotConfig.db_url`` is empty.
        DbConnectError: If anything fails while initialising Tortoise, running
            the script step, or generating the schemas.
    """
    global MODELS, SCRIPT_METHOD

    example_path = Path(".env.example")
    dev_path = Path(".env.dev")
    if not dev_path.exists():
        async with aiofiles.open(example_path, encoding="utf-8") as src:
            template = await src.read()
        async with aiofiles.open(dev_path, "w", encoding="utf-8") as dst:
            await dst.write(template)
        logger.info("已生成 .env.dev 文件,请根据 .env.example 文件配置进行配置")

    MODELS = db_model.models
    SCRIPT_METHOD = db_model.script_method

    if not BotConfig.db_url:
        message = prompt.format(host=driver.config.host, port=driver.config.port)
        raise DbUrlIsNode("\n" + message.strip())

    try:
        await Tortoise.init(config=get_config())

        if db_model.script_method:
            db = Tortoise.get_connection("default")
            logger.debug(
                "即将运行SCRIPT_METHOD方法, 合计 "
                f"<u><y>{len(db_model.script_method)}</y></u> 个..."
            )
            statements = await _collect_script_sql()
            await _execute_script_sql(db, statements)
            if statements:
                logger.debug("SCRIPT_METHOD方法执行完毕!")

        logger.debug("开始生成数据库表结构...")
        await Tortoise.generate_schemas()
        logger.debug("数据库表结构生成完毕!")
        logger.info("Database loaded successfully!")
    except Exception as e:
        # Any failure in the block above is reported as a connection error,
        # keeping the original exception as the cause.
        raise DbConnectError(f"数据库连接错误... e:{e}") from e
|
||
|
||
|
||
async def disconnect() -> None:
    """Close every database connection via Tortoise's ``connections.close_all()``."""
    await connections.close_all()
|