zhenxun_bot/zhenxun/services/db_context.py

240 lines
7.5 KiB
Python
Raw Normal View History

2025-07-01 16:56:34 +08:00
from asyncio import Semaphore
from collections.abc import Iterable
from typing import Any, ClassVar
from typing_extensions import Self
import nonebot
from nonebot.utils import is_coroutine_callable
2024-08-21 22:22:42 +08:00
from tortoise import Tortoise
2025-07-01 16:56:34 +08:00
from tortoise.backends.base.client import BaseDBAsyncClient
2024-02-04 04:18:54 +08:00
from tortoise.connection import connections
2025-07-01 16:56:34 +08:00
from tortoise.models import Model as TortoiseModel
2024-02-04 04:18:54 +08:00
2024-08-24 19:32:52 +08:00
from zhenxun.configs.config import BotConfig
2025-07-01 16:56:34 +08:00
from zhenxun.utils.enum import DbLockType
2025-07-02 16:13:47 +08:00
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
2024-02-04 04:18:54 +08:00
2025-07-01 16:56:34 +08:00
from .cache import CacheRoot
2024-02-04 04:18:54 +08:00
from .log import logger
# (module path, _run_script callable) pairs collected from Model subclasses;
# each hook is executed once during database initialization (see init()).
SCRIPT_METHOD: list[tuple[str, Any]] = []
# Dotted module paths of every Model subclass, passed to Tortoise.init().
MODELS: list[str] = []

driver = nonebot.get_driver()

# Cache reloads are suppressed until the first bot connection is established.
CACHE_FLAG = False


@driver.on_bot_connect
def _():
    """Enable cache reloading once a bot has connected."""
    global CACHE_FLAG
    CACHE_FLAG = True
class Model(TortoiseModel):
    """Base ORM model with auto-registration, optional locks and cache reload.

    Subclassing automatically records the subclass's module in ``MODELS`` so
    Tortoise can discover it at init time. Optional class attributes customize
    behavior:

    * ``_run_script`` -- callable collected into ``SCRIPT_METHOD`` and run
      during database initialization.
    * ``enable_lock`` -- iterable of :class:`DbLockType`; a ``Semaphore(1)``
      is created per entry to serialize create/update saves.
    * ``cache_type`` -- cache category reloaded through ``CacheRoot`` after
      any mutating operation (only once ``CACHE_FLAG`` is set on bot connect).
    """

    # Lock registry: {module path: {lock type: binary semaphore}}.
    sem_data: ClassVar[dict[str, dict[DbLockType, Semaphore]]] = {}

    def __init_subclass__(cls, **kwargs):
        # Keep cooperative subclass hooks (mixins) working.
        super().__init_subclass__(**kwargs)
        MODELS.append(cls.__module__)
        if func := getattr(cls, "_run_script", None):
            SCRIPT_METHOD.append((cls.__module__, func))
        if enable_lock := getattr(cls, "enable_lock", []):
            # Create one binary semaphore per requested lock type.
            cls.sem_data[cls.__module__] = {
                lock: Semaphore(1) for lock in enable_lock
            }

    @classmethod
    def get_semaphore(cls, lock_type: DbLockType) -> Semaphore | None:
        """Return the semaphore registered for *lock_type*, or ``None``."""
        return cls.sem_data.get(cls.__module__, {}).get(lock_type)

    @classmethod
    def get_cache_type(cls):
        """Return this model's ``cache_type`` once caching is enabled."""
        return getattr(cls, "cache_type", None) if CACHE_FLAG else None

    @classmethod
    async def create(
        cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
    ) -> Self:
        """Typed passthrough to ``TortoiseModel.create``; behavior unchanged."""
        return await super().create(using_db=using_db, **kwargs)

    @classmethod
    async def get_or_create(
        cls,
        defaults: dict | None = None,
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        """Fetch or create a row, reloading the cache only on creation."""
        result, is_create = await super().get_or_create(
            defaults=defaults, using_db=using_db, **kwargs
        )
        if is_create and (cache_type := cls.get_cache_type()):
            await CacheRoot.reload(cache_type)
        return (result, is_create)

    @classmethod
    async def update_or_create(
        cls,
        defaults: dict | None = None,
        using_db: BaseDBAsyncClient | None = None,
        **kwargs: Any,
    ) -> tuple[Self, bool]:
        """Update or create a row, then reload the cache if configured."""
        result = await super().update_or_create(
            defaults=defaults, using_db=using_db, **kwargs
        )
        if cache_type := cls.get_cache_type():
            await CacheRoot.reload(cache_type)
        return result

    @classmethod
    async def bulk_create(  # type: ignore
        cls,
        objects: Iterable[Self],
        batch_size: int | None = None,
        ignore_conflicts: bool = False,
        update_fields: Iterable[str] | None = None,
        on_conflict: Iterable[str] | None = None,
        using_db: BaseDBAsyncClient | None = None,
    ) -> list[Self]:
        """Bulk-insert *objects*, then reload the cache if configured."""
        result = await super().bulk_create(
            objects=objects,
            batch_size=batch_size,
            ignore_conflicts=ignore_conflicts,
            update_fields=update_fields,
            on_conflict=on_conflict,
            using_db=using_db,
        )
        if cache_type := cls.get_cache_type():
            await CacheRoot.reload(cache_type)
        return result

    @classmethod
    async def bulk_update(  # type: ignore
        cls,
        objects: Iterable[Self],
        fields: Iterable[str],
        batch_size: int | None = None,
        using_db: BaseDBAsyncClient | None = None,
    ) -> int:
        """Bulk-update *fields* on *objects*, then reload the cache."""
        result = await super().bulk_update(
            objects=objects,
            fields=fields,
            batch_size=batch_size,
            using_db=using_db,
        )
        if cache_type := cls.get_cache_type():
            await CacheRoot.reload(cache_type)
        return result

    async def save(
        self,
        using_db: BaseDBAsyncClient | None = None,
        update_fields: Iterable[str] | None = None,
        force_create: bool = False,
        force_update: bool = False,
    ):
        """Save the instance under the configured create/update lock, if any."""
        # A missing primary key means this save is an INSERT, not an UPDATE.
        lock_type = (
            DbLockType.CREATE
            if getattr(self, "id", None) is None
            else DbLockType.UPDATE
        )
        sem = self.get_semaphore(lock_type)
        do_save = super().save
        if sem:
            async with sem:
                await do_save(
                    using_db=using_db,
                    update_fields=update_fields,
                    force_create=force_create,
                    force_update=force_update,
                )
        else:
            await do_save(
                using_db=using_db,
                update_fields=update_fields,
                force_create=force_create,
                force_update=force_update,
            )
        if cache_type := self.get_cache_type():
            await CacheRoot.reload(cache_type)

    async def delete(self, using_db: BaseDBAsyncClient | None = None):
        """Delete this row, then reload the cache if configured."""
        await super().delete(using_db=using_db)
        if cache_type := self.get_cache_type():
            await CacheRoot.reload(cache_type)
class DbUrlMissing(Exception):
    """Raised when the database URL (``DB_URL``) is not configured."""
class DbConnectError(Exception):
    """Raised when connecting to / initializing the database fails."""
@PriorityLifecycle.on_startup(priority=1)
async def init():
    """Initialize the database at startup (high-priority lifecycle hook).

    Connects Tortoise ORM with ``BotConfig.db_url`` and all registered
    ``MODELS``, runs every collected ``_run_script`` hook, executes the SQL
    those hooks return, then generates schemas.

    Raises:
        DbUrlMissing: ``BotConfig.db_url`` is empty/unset.
        DbConnectError: any failure during ORM init, script execution or
            schema generation (original exception chained).
    """
    if not BotConfig.db_url:
        # raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...")
        # Banner pointing the user at the WebUI configuration page.
        error = f"""
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{driver.config.host}:{driver.config.port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""
        raise DbUrlMissing("\n" + error.strip())
    try:
        await Tortoise.init(
            db_url=BotConfig.db_url,
            modules={"models": MODELS},
            timezone="Asia/Shanghai",
        )
        if SCRIPT_METHOD:
            db = Tortoise.get_connection("default")
            logger.debug(
                "即将运行SCRIPT_METHOD方法, 合计 "
                f"<u><y>{len(SCRIPT_METHOD)}</y></u> 个..."
            )
            # Collect SQL statements from each hook; hooks may be sync or
            # async, and a failing hook is logged but does not abort startup.
            sql_list = []
            for module, func in SCRIPT_METHOD:
                try:
                    sql = await func() if is_coroutine_callable(func) else func()
                    if sql:
                        sql_list += sql
                except Exception as e:
                    logger.debug(f"{module} 执行SCRIPT_METHOD方法出错...", e=e)
            # Best-effort execution: a failing statement is logged and skipped.
            for sql in sql_list:
                logger.debug(f"执行SQL: {sql}")
                try:
                    await db.execute_query_dict(sql)
                    # await TestSQL.raw(sql)
                except Exception as e:
                    logger.debug(f"执行SQL: {sql} 错误...", e=e)
            if sql_list:
                logger.debug("SCRIPT_METHOD方法执行完毕!")
        await Tortoise.generate_schemas()
        logger.info("Database loaded successfully!")
    except Exception as e:
        raise DbConnectError(f"数据库连接错误... e:{e}") from e
async def disconnect():
    """Close all Tortoise ORM connections (shutdown cleanup counterpart of init)."""
    await connections.close_all()