zhenxun_plugin_farm/core/database/database.py

400 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextlib import asynccontextmanager
import os
from pathlib import Path
import re
from typing import Any
import aiosqlite
from zhenxun.services.log import logger
from ...utils.config import g_sDBFilePath, g_sDBPath
class CSqlManager:
def __init__(self):
dbPath = Path(g_sDBPath)
if dbPath and not dbPath.exists():
os.makedirs(dbPath, exist_ok=True)
@classmethod
async def cleanup(cls):
if hasattr(cls, "m_pDB") and cls.m_pDB:
await cls.m_pDB.close()
@classmethod
async def init(cls) -> bool:
try:
cls.m_pDB = await aiosqlite.connect(g_sDBFilePath)
cls.m_pDB.row_factory = aiosqlite.Row
return True
except Exception as e:
logger.debug("真寻农场初始化总数据库失败", e=e)
return False
@classmethod
@asynccontextmanager
async def _transaction(cls):
await cls.m_pDB.execute("BEGIN;")
try:
yield
except:
await cls.m_pDB.execute("ROLLBACK;")
raise
else:
await cls.m_pDB.execute("COMMIT;")
@classmethod
async def getTableInfo(cls, tableName: str) -> list:
if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", tableName):
raise ValueError(f"Illegal table name: {tableName}")
try:
cursor = await cls.m_pDB.execute(f'PRAGMA table_info("{tableName}")')
rows = await cursor.fetchall()
return [{"name": row[1], "type": row[2]} for row in rows]
except aiosqlite.Error:
return []
@classmethod
async def ensureTableSchema(cls, tableName: str, columns: dict) -> bool:
"""由AI生成
创建表或为已存在表添加缺失字段。
返回 True 表示有变更创建或新增列False 则无操作
Args:
tableName (_type_): 表名
columns (_type_): 字典
Returns:
_type_: _description_
"""
info = await cls.getTableInfo(tableName)
existing = {col["name"]: col["type"].upper() for col in info}
desired = {k: v.upper() for k, v in columns.items() if k != "PRIMARY KEY"}
primaryKey = columns.get("PRIMARY KEY", "")
if not existing:
colsDef = ", ".join(f'"{k}" {v}' for k, v in desired.items())
if primaryKey:
colsDef += f", PRIMARY KEY {primaryKey}"
await cls.m_pDB.execute(f'CREATE TABLE "{tableName}" ({colsDef});')
return True
toAdd = [k for k in desired if k not in existing]
toRemove = [k for k in existing if k not in desired]
typeMismatch = [
k for k in desired if k in existing and existing[k] != desired[k]
]
if toAdd and not toRemove and not typeMismatch:
for col in toAdd:
await cls.m_pDB.execute(
f'ALTER TABLE "{tableName}" ADD COLUMN "{col}" {columns[col]}'
)
return True
async with cls._transaction():
tmpTable = f"{tableName}_new"
colsDef = ", ".join(f'"{k}" {v}' for k, v in desired.items())
if primaryKey:
colsDef += f", PRIMARY KEY {primaryKey}"
await cls.m_pDB.execute(f'CREATE TABLE "{tmpTable}" ({colsDef});')
commonCols = [k for k in desired if k in existing]
if commonCols:
colsStr = ", ".join(f'"{c}"' for c in commonCols)
sql = (
f'INSERT INTO "{tmpTable}" ({colsStr}) '
f"SELECT {colsStr} "
f'FROM "{tableName}";'
)
await cls.m_pDB.execute(sql)
await cls.m_pDB.execute(f'DROP TABLE "{tableName}";')
await cls.m_pDB.execute(
f'ALTER TABLE "{tmpTable}" RENAME TO "{tableName}";'
)
return True
@classmethod
async def executeDB(cls, command: str) -> bool:
"""执行自定义SQL
Args:
command (str): SQL语句
Returns:
bool: 是否执行成功
"""
if not command:
return False
try:
async with cls._transaction():
await cls.m_pDB.execute(command)
return True
except Exception:
return False
@classmethod
async def insert(cls, tableName: str, data: dict) -> bool:
"""
插入数据
Args:
tableName: 表名
data: 要插入的数据字典,键为字段名,值为字段值
Returns:
bool: 是否执行成功
"""
if not data:
return False
try:
# 构建参数化查询
columns = ", ".join(f'"{k}"' for k in data.keys())
placeholders = ", ".join("?" for _ in data.keys())
values = list(data.values())
sql = f'INSERT INTO "{tableName}" ({columns}) VALUES ({placeholders})'
async with cls._transaction():
await cls.m_pDB.execute(sql, values)
return True
except Exception as e:
logger.debug("真寻农场插入数据失败!", e=e)
return False
@classmethod
async def batch_insert(cls, tableName: str, data_list: list) -> bool:
"""
批量插入数据
Args:
tableName: 表名
data_list: 要插入的数据字典列表
Returns:
bool: 是否执行成功
"""
if not data_list:
return False
try:
# 使用第一个字典的键作为所有记录的字段
columns = ", ".join(f'"{k}"' for k in data_list[0].keys())
placeholders = ", ".join("?" for _ in data_list[0].keys())
sql = f'INSERT INTO "{tableName}" ({columns}) VALUES ({placeholders})'
async with cls._transaction():
await cls.m_pDB.executemany(
sql, [list(data.values()) for data in data_list]
)
return True
except Exception as e:
logger.debug("真寻农场批量插入数据失败!", e=e)
return False
@classmethod
async def select(
cls,
tableName: str,
columns: list[Any] | None = None,
where: dict[str, Any] | None = None,
order_by: str | None = None,
limit: int | None = None,
) -> list[dict]:
"""
查询数据
Args:
tableName: 表名
columns: 要查询的字段列表None表示所有字段
where: 查询条件字典
order_by: 排序字段
limit: 限制返回记录数
Returns:
list: 查询结果列表,每个元素是一个字典
"""
try:
# 构建SELECT部分
if columns:
select_clause = ", ".join(f'"{col}"' for col in columns)
else:
select_clause = "*"
sql = f'SELECT {select_clause} FROM "{tableName}"'
# 构建WHERE部分
params = []
if where:
where_conditions = []
for key, value in where.items():
if isinstance(value, (list, tuple)):
# 处理IN查询
placeholders = ", ".join("?" for _ in value)
where_conditions.append(f'"{key}" IN ({placeholders})')
params.extend(value)
else:
where_conditions.append(f'"{key}" = ?')
params.append(value)
if where_conditions:
sql += " WHERE " + " AND ".join(where_conditions)
# 构建ORDER BY部分
if order_by:
sql += f" ORDER BY {order_by}"
# 构建LIMIT部分
if limit:
sql += f" LIMIT {limit}"
cursor = await cls.m_pDB.execute(sql, params)
rows = await cursor.fetchall()
# 转换为字典列表
result = []
for row in rows:
result.append(dict(row))
return result
except Exception as e:
logger.debug("真寻农场查询数据失败!", e=e)
return []
@classmethod
async def update(cls, tableName: str, data: dict, where: dict) -> bool:
"""
更新数据
Args:
tableName: 表名
data: 要更新的数据字典
where: 更新条件字典
Returns:
bool: 是否执行成功
"""
if not data:
return False
if not where:
return False
try:
# 构建SET部分
set_conditions = []
params = []
for key, value in data.items():
set_conditions.append(f'"{key}" = ?')
params.append(value)
# 构建WHERE部分
where_conditions = []
for key, value in where.items():
where_conditions.append(f'"{key}" = ?')
params.append(value)
sql = f'UPDATE "{tableName}" SET {", ".join(set_conditions)} WHERE {" AND ".join(where_conditions)}'
async with cls._transaction():
cursor = await cls.m_pDB.execute(sql, params)
# 检查是否影响了行
return cursor.rowcount > 0
except Exception as e:
logger.debug("真寻农场更新数据失败!", e=e)
return False
@classmethod
async def delete(cls, tableName: str, where: dict) -> bool:
"""
删除数据
Args:
tableName: 表名
where: 删除条件字典
Returns:
bool: 是否执行成功
"""
if not where:
return False
try:
# 构建WHERE部分
where_conditions = []
params = []
for key, value in where.items():
where_conditions.append(f'"{key}" = ?')
params.append(value)
sql = f'DELETE FROM "{tableName}" WHERE {" AND ".join(where_conditions)}'
async with cls._transaction():
cursor = await cls.m_pDB.execute(sql, params)
# 检查是否影响了行
return cursor.rowcount > 0
except Exception as e:
logger.debug("真寻农场删除数据失败!", e=e)
return False
@classmethod
async def exists(cls, tableName: str, where: dict) -> bool:
"""
检查记录是否存在
Args:
tableName: 表名
where: 查询条件字典
Returns:
bool: 是否存在符合条件的记录
"""
try:
result = await cls.select(tableName, columns=["1"], where=where, limit=1)
return len(result) > 0
except Exception as e:
logger.debug("真寻农场检查数据失败!", e=e)
return False
@classmethod
async def count(cls, tableName: str, where: dict = {}) -> int:
"""
统计记录数量
Args:
tableName: 表名
where: 查询条件字典
Returns:
int: 记录数量
"""
try:
# 构建WHERE部分
sql = f'SELECT COUNT(*) as count FROM "{tableName}"'
params = []
if where:
where_conditions = []
for key, value in where.items():
where_conditions.append(f'"{key}" = ?')
params.append(value)
sql += " WHERE " + " AND ".join(where_conditions)
cursor = await cls.m_pDB.execute(sql, params)
row = await cursor.fetchone()
return row["count"] if row else 0
except Exception as e:
logger.debug("真寻农场统计数据失败!", e=e)
return 0
# Shared module-level instance; constructing it creates the database
# directory (g_sDBPath) if it does not exist yet.
g_pSqlManager: CSqlManager = CSqlManager()