diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 9f4da8ba..b8897c83 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -1,6 +1,7 @@ from asyncio import Semaphore from collections.abc import Iterable from typing import Any, ClassVar +from urllib.parse import urlparse from typing_extensions import Self import nonebot @@ -186,11 +187,72 @@ async def init(): if not BotConfig.db_url: raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...") try: - await Tortoise.init( - db_url=BotConfig.db_url, - modules={"models": MODELS}, - timezone="Asia/Shanghai", - ) + db_url = BotConfig.db_url + url_scheme = db_url.split(":", 1)[0] + config = None + + if url_scheme in ("postgres", "postgresql"): + # 解析 db_url + url = urlparse(db_url) + credentials = { + "host": url.hostname, + "port": url.port or 5432, + "user": url.username, + "password": url.password, + "database": url.path.lstrip("/"), + "minsize": 1, + "maxsize": 50, + } + config = { + "connections": { + "default": { + "engine": "tortoise.backends.asyncpg", + "credentials": credentials, + } + }, + "apps": { + "models": { + "models": MODELS, + "default_connection": "default", + } + }, + } + elif url_scheme in ("mysql", "mysql+aiomysql"): + url = urlparse(db_url) + credentials = { + "host": url.hostname, + "port": url.port or 3306, + "user": url.username, + "password": url.password, + "database": url.path.lstrip("/"), + "minsize": 1, + "maxsize": 50, + } + config = { + "connections": { + "default": { + "engine": "tortoise.backends.mysql", + "credentials": credentials, + } + }, + "apps": { + "models": { + "models": MODELS, + "default_connection": "default", + } + }, + } + else: + # sqlite 或其它,直接用 db_url + await Tortoise.init( + db_url=db_url, + modules={"models": MODELS}, + timezone="Asia/Shanghai", + ) + + if config: + await Tortoise.init(config=config) + if SCRIPT_METHOD: db = Tortoise.get_connection("default") logger.debug(