from contextlib import asynccontextmanager
import os
from pathlib import Path
import re
from typing import Any

import aiosqlite

from zhenxun.services.log import logger

from ...utils.config import g_sDBFilePath, g_sDBPath


class CSqlManager:
    """Async helper around a single aiosqlite connection shared by the class.

    :meth:`init` stores the connection on the class as ``m_pDB``; every other
    classmethod operates on it. Table and column names are interpolated into
    SQL inside double quotes, while values are always bound as ``?``
    parameters.
    """

    # Shared connection, assigned by init().
    m_pDB: aiosqlite.Connection

    def __init__(self):
        """Create the database directory ``g_sDBPath`` if it does not exist."""
        dbPath = Path(g_sDBPath)
        if not dbPath.exists():
            os.makedirs(dbPath, exist_ok=True)

    @classmethod
    async def cleanup(cls):
        """Close the shared connection if one has been opened."""
        if hasattr(cls, "m_pDB") and cls.m_pDB:
            await cls.m_pDB.close()

    @classmethod
    async def init(cls) -> bool:
        """Open the database file ``g_sDBFilePath`` and store it on the class.

        Rows are produced as ``aiosqlite.Row`` so they can be read by column
        name (``select`` converts them with ``dict(row)``, ``count`` reads
        ``row["count"]``).

        Returns:
            bool: True on success, False if opening failed (logged at debug).
        """
        try:
            cls.m_pDB = await aiosqlite.connect(g_sDBFilePath)
            cls.m_pDB.row_factory = aiosqlite.Row
            return True
        except Exception as e:
            logger.debug("真寻农场初始化总数据库失败", e=e)
            return False

    @classmethod
    @asynccontextmanager
    async def _transaction(cls):
        """Run the enclosed statements inside one explicit transaction.

        Commits when the body completes normally. On any exception, including
        a failed COMMIT, rolls back (if a transaction is still open) and
        re-raises, so the shared connection is never left stuck inside an
        open transaction.

        NOTE(review): all callers share one connection; if two coroutines can
        enter this concurrently, the second BEGIN fails with "cannot start a
        transaction within a transaction". Consider an asyncio.Lock if that
        can happen.
        """
        await cls.m_pDB.execute("BEGIN;")
        try:
            yield
            await cls.m_pDB.execute("COMMIT;")
        except BaseException:
            # SQLite rolls back by itself on some errors; issuing ROLLBACK with
            # no open transaction would raise and hide the original exception.
            if cls.m_pDB.in_transaction:
                await cls.m_pDB.execute("ROLLBACK;")
            raise

    @staticmethod
    def _buildWhere(where: dict[str, Any] | None) -> tuple[str, list[Any]]:
        """Build a parameterised WHERE clause from a condition dict.

        Conditions are AND-ed together. A list/tuple value becomes
        ``"key" IN (?, ...)``; any other value becomes ``"key" = ?``.

        Args:
            where: Column name -> required value(s); None/empty means no filter.

        Returns:
            tuple[str, list]: The clause (``" WHERE ..."``, or ``""`` when there
            are no conditions) and the parameters in placeholder order.
        """
        if not where:
            return "", []
        conditions: list[str] = []
        params: list[Any] = []
        for key, value in where.items():
            if isinstance(value, (list, tuple)):
                placeholders = ", ".join("?" for _ in value)
                conditions.append(f'"{key}" IN ({placeholders})')
                params.extend(value)
            else:
                conditions.append(f'"{key}" = ?')
                params.append(value)
        return " WHERE " + " AND ".join(conditions), params

    @classmethod
    async def getTableInfo(cls, tableName: str) -> list[dict]:
        """Return ``[{"name": ..., "type": ...}, ...]`` for a table's columns.

        Args:
            tableName: Table to inspect; must be a plain SQL identifier.

        Returns:
            list[dict]: One entry per column, with the declared type reported
            by ``PRAGMA table_info``. Empty if the table does not exist or the
            query fails.

        Raises:
            ValueError: If ``tableName`` is not a plain identifier.
        """
        # fullmatch (not match + "$") so a trailing newline is also rejected.
        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", tableName):
            raise ValueError(f"Illegal table name: {tableName}")
        try:
            cursor = await cls.m_pDB.execute(f'PRAGMA table_info("{tableName}")')
            rows = await cursor.fetchall()
            return [{"name": row[1], "type": row[2]} for row in rows]
        except aiosqlite.Error:
            return []

    @classmethod
    async def ensureTableSchema(cls, tableName: str, columns: dict) -> bool:
        """Create a table, or migrate an existing one to the given columns.

        - Table missing: it is created.
        - Only new columns: they are added with ``ALTER TABLE ... ADD COLUMN``
          inside one transaction.
        - Columns removed or with a different declared type: the table is
          rebuilt inside a transaction (create ``<name>_new``, copy the shared
          columns, drop the old table, rename the new one).
        - Otherwise nothing is done.

        Args:
            tableName: Table name (validated by :meth:`getTableInfo`).
            columns: Column name -> column definition text. The special key
                ``"PRIMARY KEY"`` holds the text appended after
                ``PRIMARY KEY`` as a table constraint.

        Returns:
            bool: True if the table was created or altered, False if the
            existing schema already matched.

        Raises:
            ValueError: If ``tableName`` is not a plain identifier.
        """
        info = await cls.getTableInfo(tableName)
        existing = {col["name"]: col["type"].upper() for col in info}
        # Definitions are emitted verbatim in DDL so case-sensitive parts such
        # as DEFAULT 'text' are preserved; the upper-cased copy is only used to
        # compare against the PRAGMA output.
        definitions = {k: v for k, v in columns.items() if k != "PRIMARY KEY"}
        desired = {k: v.upper() for k, v in definitions.items()}
        primaryKey = columns.get("PRIMARY KEY", "")

        colsDef = ", ".join(f'"{k}" {v}' for k, v in definitions.items())
        if primaryKey:
            colsDef += f", PRIMARY KEY {primaryKey}"

        if not existing:
            await cls.m_pDB.execute(f'CREATE TABLE "{tableName}" ({colsDef});')
            return True

        toAdd = [k for k in desired if k not in existing]
        toRemove = [k for k in existing if k not in desired]
        # NOTE(review): PRAGMA table_info reports only the declared type name,
        # not constraints, so a definition such as "TEXT NOT NULL DEFAULT ''"
        # always compares unequal to "TEXT" and triggers a rebuild. Changes
        # that only touch the PRIMARY KEY constraint are not detected.
        typeMismatch = [
            k for k in desired if k in existing and existing[k] != desired[k]
        ]

        if not toAdd and not toRemove and not typeMismatch:
            # Schema already matches; avoid rewriting the whole table.
            return False

        if not toRemove and not typeMismatch:
            # Pure additions: ALTER TABLE is far cheaper than a rebuild. The
            # transaction keeps the additions all-or-nothing.
            async with cls._transaction():
                for col in toAdd:
                    await cls.m_pDB.execute(
                        f'ALTER TABLE "{tableName}" ADD COLUMN "{col}" {definitions[col]}'
                    )
            return True

        async with cls._transaction():
            tmpTable = f"{tableName}_new"
            await cls.m_pDB.execute(f'CREATE TABLE "{tmpTable}" ({colsDef});')
            commonCols = [k for k in desired if k in existing]
            if commonCols:
                colsStr = ", ".join(f'"{c}"' for c in commonCols)
                await cls.m_pDB.execute(
                    f'INSERT INTO "{tmpTable}" ({colsStr}) '
                    f'SELECT {colsStr} FROM "{tableName}";'
                )
            await cls.m_pDB.execute(f'DROP TABLE "{tableName}";')
            await cls.m_pDB.execute(
                f'ALTER TABLE "{tmpTable}" RENAME TO "{tableName}";'
            )
        return True

    @classmethod
    async def executeDB(cls, command: str) -> bool:
        """Execute one arbitrary SQL statement inside a transaction.

        Args:
            command (str): The SQL statement.

        Returns:
            bool: True on success; False for an empty command or on error.
        """
        if not command:
            return False
        try:
            async with cls._transaction():
                await cls.m_pDB.execute(command)
            return True
        except Exception as e:
            logger.debug("真寻农场执行自定义SQL失败!", e=e)
            return False

    @classmethod
    async def insert(cls, tableName: str, data: dict) -> bool:
        """Insert one row.

        Args:
            tableName: Target table.
            data: Column name -> value for the new row.

        Returns:
            bool: True on success; False for empty ``data`` or on error.
        """
        if not data:
            return False
        try:
            columns = ", ".join(f'"{k}"' for k in data)
            placeholders = ", ".join("?" for _ in data)
            sql = f'INSERT INTO "{tableName}" ({columns}) VALUES ({placeholders})'
            async with cls._transaction():
                await cls.m_pDB.execute(sql, list(data.values()))
            return True
        except Exception as e:
            logger.debug("真寻农场插入数据失败!", e=e)
            return False

    @classmethod
    async def batch_insert(cls, tableName: str, data_list: list) -> bool:
        """Insert many rows in a single transaction.

        The first dict's keys define the column list; every other dict must
        have exactly the same keys (in any order).

        Args:
            tableName: Target table.
            data_list: List of column name -> value dicts.

        Returns:
            bool: True on success; False for an empty list, mismatched keys,
            or any database error (nothing is inserted in that case).
        """
        if not data_list:
            return False
        try:
            keys = list(data_list[0].keys())
            columns = ", ".join(f'"{k}"' for k in keys)
            placeholders = ", ".join("?" for _ in keys)
            sql = f'INSERT INTO "{tableName}" ({columns}) VALUES ({placeholders})'
            rows = []
            for data in data_list:
                # Look values up by key, so a dict whose keys are ordered
                # differently from the first one still binds to the right
                # columns; extra or missing keys are rejected.
                if len(data) != len(keys):
                    raise ValueError("batch_insert rows must share the same keys")
                rows.append([data[k] for k in keys])
            async with cls._transaction():
                await cls.m_pDB.executemany(sql, rows)
            return True
        except Exception as e:
            logger.debug("真寻农场批量插入数据失败!", e=e)
            return False

    @classmethod
    async def select(
        cls,
        tableName: str,
        columns: list[Any] | None = None,
        where: dict[str, Any] | None = None,
        order_by: str | None = None,
        limit: int | None = None,
    ) -> list[dict]:
        """Query rows.

        Args:
            tableName: Table to read.
            columns: Column names to return; None/empty selects ``*``.
            where: Conditions AND-ed together; list/tuple values mean ``IN``.
            order_by: Raw ORDER BY expression (may include ASC/DESC).
            limit: Maximum number of rows; a falsy value means no limit.

        Returns:
            list[dict]: One dict per row; empty on error.
        """
        try:
            selectClause = (
                ", ".join(f'"{col}"' for col in columns) if columns else "*"
            )
            whereClause, params = cls._buildWhere(where)
            sql = f'SELECT {selectClause} FROM "{tableName}"{whereClause}'
            # NOTE: order_by is inserted verbatim, so it must never come from
            # untrusted input.
            if order_by:
                sql += f" ORDER BY {order_by}"
            if limit:
                sql += f" LIMIT {int(limit)}"
            cursor = await cls.m_pDB.execute(sql, params)
            rows = await cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.debug("真寻农场查询数据失败!", e=e)
            return []

    @classmethod
    async def update(cls, tableName: str, data: dict, where: dict) -> bool:
        """Update rows matching every ``where`` condition (equality only).

        Args:
            tableName: Target table.
            data: Column name -> new value.
            where: Column name -> required value, AND-ed together.

        Returns:
            bool: True if at least one row was changed; False for empty
            ``data``/``where``, no matching row, or an error.
        """
        if not data or not where:
            return False
        try:
            setConditions = [f'"{key}" = ?' for key in data]
            whereConditions = [f'"{key}" = ?' for key in where]
            # SET parameters come first, then WHERE parameters, matching the
            # placeholder order in the statement.
            params = [*data.values(), *where.values()]
            sql = (
                f'UPDATE "{tableName}" SET {", ".join(setConditions)} '
                f'WHERE {" AND ".join(whereConditions)}'
            )
            async with cls._transaction():
                cursor = await cls.m_pDB.execute(sql, params)
            return cursor.rowcount > 0
        except Exception as e:
            logger.debug("真寻农场更新数据失败!", e=e)
            return False

    @classmethod
    async def delete(cls, tableName: str, where: dict) -> bool:
        """Delete rows matching every ``where`` condition (equality only).

        Args:
            tableName: Target table.
            where: Column name -> required value, AND-ed together. Must be
                non-empty so the whole table is never wiped by accident.

        Returns:
            bool: True if at least one row was deleted; False for empty
            ``where``, no matching row, or an error.
        """
        if not where:
            return False
        try:
            whereConditions = [f'"{key}" = ?' for key in where]
            params = list(where.values())
            sql = f'DELETE FROM "{tableName}" WHERE {" AND ".join(whereConditions)}'
            async with cls._transaction():
                cursor = await cls.m_pDB.execute(sql, params)
            return cursor.rowcount > 0
        except Exception as e:
            logger.debug("真寻农场删除数据失败!", e=e)
            return False

    @classmethod
    async def exists(cls, tableName: str, where: dict) -> bool:
        """Return whether at least one row matches ``where``.

        Args:
            tableName: Table to check.
            where: Conditions AND-ed together; list/tuple values mean ``IN``.

        Returns:
            bool: True if a matching row exists; False otherwise or on error.
        """
        try:
            whereClause, params = cls._buildWhere(where)
            # Literal 1 (not the quoted "1", which only works through SQLite's
            # legacy double-quoted-string fallback).
            cursor = await cls.m_pDB.execute(
                f'SELECT 1 FROM "{tableName}"{whereClause} LIMIT 1', params
            )
            return await cursor.fetchone() is not None
        except Exception as e:
            logger.debug("真寻农场检查数据失败!", e=e)
            return False

    @classmethod
    async def count(cls, tableName: str, where: dict | None = None) -> int:
        """Count rows matching ``where``.

        Args:
            tableName: Table to count.
            where: Conditions AND-ed together; list/tuple values mean ``IN``.
                None/empty counts every row.

        Returns:
            int: Number of matching rows; 0 on error.
        """
        try:
            whereClause, params = cls._buildWhere(where)
            cursor = await cls.m_pDB.execute(
                f'SELECT COUNT(*) as count FROM "{tableName}"{whereClause}', params
            )
            row = await cursor.fetchone()
            return row["count"] if row else 0
        except Exception as e:
            logger.debug("真寻农场统计数据失败!", e=e)
            return 0


g_pSqlManager = CSqlManager()