feat: 优化数据库初始化配置

This commit is contained in:
molanp 2025-07-02 14:21:09 +08:00 committed by GitHub
parent 5734bea175
commit 98b4249e29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
from asyncio import Semaphore from asyncio import Semaphore
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, ClassVar from typing import Any, ClassVar
from urllib.parse import urlparse
from typing_extensions import Self from typing_extensions import Self
import nonebot import nonebot
@ -186,11 +187,72 @@ async def init():
if not BotConfig.db_url: if not BotConfig.db_url:
raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...") raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...")
try: try:
db_url = BotConfig.db_url
url_scheme = db_url.split(":", 1)[0]
config = None
if url_scheme in ("postgres", "postgresql"):
# 解析 db_url
url = urlparse(db_url)
credentials = {
"host": url.hostname,
"port": url.port or 5432,
"user": url.username,
"password": url.password,
"database": url.path.lstrip("/"),
"minsize": 1,
"maxsize": 50,
}
config = {
"connections": {
"default": {
"engine": "tortoise.backends.asyncpg",
"credentials": credentials,
}
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
}
elif url_scheme in ("mysql", "mysql+aiomysql"):
url = urlparse(db_url)
credentials = {
"host": url.hostname,
"port": url.port or 3306,
"user": url.username,
"password": url.password,
"database": url.path.lstrip("/"),
"minsize": 1,
"maxsize": 50,
}
config = {
"connections": {
"default": {
"engine": "tortoise.backends.mysql",
"credentials": credentials,
}
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
}
else:
# sqlite 或其它,直接用 db_url
await Tortoise.init( await Tortoise.init(
db_url=BotConfig.db_url, db_url=db_url,
modules={"models": MODELS}, modules={"models": MODELS},
timezone="Asia/Shanghai", timezone="Asia/Shanghai",
) )
if config:
await Tortoise.init(config=config)
if SCRIPT_METHOD: if SCRIPT_METHOD:
db = Tortoise.get_connection("default") db = Tortoise.get_connection("default")
logger.debug( logger.debug(