zhenxun_bot/zhenxun/services/llm/core.py
webjoin111 bba90e62db ♻️ refactor(llm): 重构 LLM 服务架构,引入中间件与组件化适配器
- 【重构】LLM 服务核心架构:
    - 引入中间件管道,统一处理请求生命周期(重试、密钥选择、日志、网络请求)。
    - 适配器重构为组件化设计,分离配置映射、消息转换、响应解析和工具序列化逻辑。
    - 移除 `with_smart_retry` 装饰器,其功能由中间件接管。
    - 移除 `LLMToolExecutor`,工具执行逻辑集成到 `ToolInvoker`。
- 【功能】增强配置系统:
    - `LLMGenerationConfig` 采用组件化结构(Core, Reasoning, Visual, Output, Safety, ToolConfig)。
    - 新增 `GenConfigBuilder` 提供语义化配置构建方式。
    - 新增 `LLMEmbeddingConfig` 用于嵌入专用配置。
    - `CommonOverrides` 迁移并更新至新配置结构。
- 【功能】强化工具系统:
    - 引入 `ToolInvoker` 实现更灵活的工具执行,支持回调与结构化错误。
    - `function_tool` 装饰器支持动态 Pydantic 模型创建和依赖注入 (`ToolParam`, `RunContext`)。
    - 平台原生工具支持 (`GeminiCodeExecution`, `GeminiGoogleSearch`, `GeminiUrlContext`)。
- 【功能】高级生成与嵌入:
    - `generate_structured` 方法支持 In-Context Validation and Repair (IVR) 循环和 AutoCoT (思维链) 包装。
    - 新增 `embed_query` 和 `embed_documents` 便捷嵌入 API。
    - `OpenAIImageAdapter` 支持 OpenAI 兼容的图像生成。
    - `SmartAdapter` 实现模型名称智能路由。
- 【重构】消息与类型系统:
    - `LLMContentPart` 扩展支持更多模态和代码执行相关内容。
    - `LLMMessage` 和 `LLMResponse` 结构更新,支持 `content_parts` 和思维链签名。
    - 统一 `LLMErrorCode` 和用户友好错误消息,提供更详细的网络/代理错误提示。
    - `pyproject.toml` 移除 `bilireq`,新增 `json_repair`。
- 【优化】日志与调试:
    - 引入 `DebugLogOptions`,提供细粒度日志脱敏控制。
    - 增强日志净化器,处理更多敏感数据和长字符串。
- 【清理】删除废弃模块:
    - `zhenxun/services/llm/memory.py`
    - `zhenxun/services/llm/executor.py`
    - `zhenxun/services/llm/config/presets.py`
    - `zhenxun/services/llm/types/content.py`
    - `zhenxun/services/llm/types/enums.py`
    - `zhenxun/services/llm/tools/__init__.py`
    - `zhenxun/services/llm/tools/manager.py`
2025-12-07 18:57:55 +08:00

597 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 核心基础设施模块
包含执行 LLM 请求所需的底层组件,如 HTTP 客户端、API Key 存储和智能重试逻辑。
"""
import asyncio
from dataclasses import asdict, dataclass
from enum import IntEnum
import json
import os
import time
from typing import Any
import aiofiles
import httpx
import nonebot
from pydantic import BaseModel
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.user_agent import get_user_agent
from .types import ProviderConfig
from .types.exceptions import LLMErrorCode, LLMException
# NoneBot driver handle; this module uses it to register an on-shutdown hook.
driver = nonebot.get_driver()
class HttpClientConfig(BaseModel):
    """Settings used to construct the ``httpx.AsyncClient`` of an ``LLMHttpClient``."""

    # Value handed to ``httpx.Timeout`` when the client is built.
    timeout: int = 180
    # Forwarded to ``httpx.Limits(max_connections=...)``.
    max_connections: int = 100
    # Forwarded to ``httpx.Limits(max_keepalive_connections=...)``.
    max_keepalive_connections: int = 20
    # Optional proxy URL; when falsy no proxy argument is passed to httpx.
    proxy: str | None = None
class LLMHttpClient:
    """HTTP client dedicated to LLM requests.

    Wraps a lazily created ``httpx.AsyncClient`` and counts in-flight POST
    requests so that ``close`` can warn when requests are still running.
    """

    def __init__(self, config: HttpClientConfig | None = None):
        self.config = config or HttpClientConfig()
        self._client: httpx.AsyncClient | None = None
        self._active_requests = 0
        self._lock = asyncio.Lock()

    def _proxy_kwargs(self) -> dict[str, Any]:
        """Return the proxy keyword argument matching the installed httpx.

        httpx >= 0.28 is given ``proxy``; older versions (or an unparsable
        version string) are given ``proxies``. Empty when no proxy is set.
        """
        proxy_url = self.config.proxy
        if not proxy_url:
            return {}
        try:
            pieces = httpx.__version__.split(".")
            # Keep only digits so suffixes such as "0b1" do not break int().
            major = int("".join(filter(str.isdigit, pieces[0])))
            minor = 0
            if len(pieces) > 1:
                minor = int("".join(filter(str.isdigit, pieces[1])))
        except (ValueError, IndexError):
            logger.warning(
                f"无法解析 httpx 版本 '{httpx.__version__}'"
                "LLM模块将默认使用旧版 'proxies' 参数语法。"
            )
            return {"proxies": proxy_url}
        keyword = "proxy" if (major, minor) >= (0, 28) else "proxies"
        return {keyword: proxy_url}

    async def _ensure_client_initialized(self) -> httpx.AsyncClient:
        """Return the underlying client, creating it first if it is missing or closed.

        Raises:
            LLMException: with ``CONFIGURATION_ERROR`` if no client exists
                after the initialization attempt.
        """
        if self.is_closed:
            async with self._lock:
                # Re-check under the lock: another task may have created it.
                if self.is_closed:
                    logger.debug(
                        f"LLMHttpClient: 正在初始化新的 httpx.AsyncClient "
                        f"配置: {self.config}"
                    )
                    request_headers = get_user_agent()
                    pool_limits = httpx.Limits(
                        max_connections=self.config.max_connections,
                        max_keepalive_connections=(
                            self.config.max_keepalive_connections
                        ),
                    )
                    request_timeout = httpx.Timeout(self.config.timeout)
                    extra_kwargs = self._proxy_kwargs()
                    self._client = httpx.AsyncClient(
                        headers=request_headers,
                        limits=pool_limits,
                        timeout=request_timeout,
                        follow_redirects=True,
                        **extra_kwargs,
                    )
        if self._client is None:
            raise LLMException(
                "HTTP 客户端初始化失败。", LLMErrorCode.CONFIGURATION_ERROR
            )
        return self._client

    async def post(self, url: str, **kwargs: Any) -> httpx.Response:
        """POST to ``url`` through the shared client, tracking the active-request count.

        ``kwargs`` are forwarded unchanged to ``httpx.AsyncClient.post``.
        """
        http = await self._ensure_client_initialized()
        async with self._lock:
            self._active_requests += 1
        try:
            response = await http.post(url, **kwargs)
            return response
        finally:
            # Always decrement, including when the request raised.
            async with self._lock:
                self._active_requests -= 1

    async def close(self):
        """Close the underlying client if it is open; warn if requests are still active."""
        async with self._lock:
            current = self._client
            if not current or current.is_closed:
                return
            logger.debug(
                f"LLMHttpClient: 正在关闭,配置: {self.config}. "
                f"活跃请求数: {self._active_requests}"
            )
            if self._active_requests > 0:
                logger.warning(
                    f"LLMHttpClient: 关闭时仍有 {self._active_requests} "
                    f"个请求处于活跃状态。"
                )
            await current.aclose()
            self._client = None
            logger.debug(f"配置为 {self.config} 的 LLMHttpClient 已完全关闭。")

    @property
    def is_closed(self) -> bool:
        """True when no client exists yet or the existing one has been closed."""
        return self._client is None or self._client.is_closed
class LLMHttpClientManager:
    """Factory and pool of ``LLMHttpClient`` instances keyed by (timeout, proxy)."""

    def __init__(self):
        self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
        self._lock = asyncio.Lock()

    def _get_client_key(
        self, provider_config: ProviderConfig
    ) -> tuple[int, str | None]:
        """Build the pool key from the provider's timeout and proxy."""
        return (provider_config.timeout, provider_config.proxy)

    async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
        """Return a pooled client for ``provider_config``, creating one when needed.

        An existing client is reused only while it is not closed; a closed
        one is replaced by a fresh ``LLMHttpClient`` under the same key.
        """
        key = self._get_client_key(provider_config)
        async with self._lock:
            cached = self._clients.get(key)
            if cached is not None:
                if not cached.is_closed:
                    logger.debug(
                        f"LLMHttpClientManager: 复用现有的 LLMHttpClient 密钥: {key}"
                    )
                    return cached
                logger.debug(
                    f"LLMHttpClientManager: 发现密钥 {key} 对应的客户端已关闭。"
                    f"正在创建新的客户端。"
                )
            logger.debug(f"LLMHttpClientManager: 为密钥 {key} 创建新的 LLMHttpClient")
            fresh = LLMHttpClient(
                config=HttpClientConfig(
                    timeout=provider_config.timeout, proxy=provider_config.proxy
                )
            )
            self._clients[key] = fresh
            return fresh

    async def shutdown(self):
        """Close every open pooled client concurrently and empty the pool."""
        async with self._lock:
            logger.info(
                f"LLMHttpClientManager: 正在关闭。关闭 {len(self._clients)} 个客户端。"
            )
            pending = []
            for pooled in self._clients.values():
                if pooled and not pooled.is_closed:
                    pending.append(pooled.close())
            if pending:
                # return_exceptions=True: one failing close must not abort the rest.
                await asyncio.gather(*pending, return_exceptions=True)
            self._clients.clear()
            logger.info("LLMHttpClientManager: 关闭完成。")
# Module-level LLMHttpClientManager instance created at import time.
http_client_manager = LLMHttpClientManager()
async def create_llm_http_client(
    timeout: int = 180,
    proxy: str | None = None,
) -> LLMHttpClient:
    """
    Create a standalone LLM HTTP client.

    Args:
        timeout: Timeout in seconds.
        proxy: Proxy server address.

    Returns:
        LLMHttpClient: A new client built from an ``HttpClientConfig``
        holding the given timeout and proxy.
    """
    return LLMHttpClient(HttpClientConfig(timeout=timeout, proxy=proxy))
class KeyStatus(IntEnum):
    """Key status values used for sorting and display."""

    DISABLED = 0  # cooldown long enough to count as disabled
    ERROR = 1  # success rate below the error threshold
    COOLDOWN = 2  # temporarily cooling down
    WARNING = 3  # usable but slow
    HEALTHY = 4  # working normally
    UNUSED = 5  # no calls recorded yet
@dataclass
class KeyStats:
    """Detailed state and usage statistics for a single API key."""

    # Epoch timestamp (``time.time()`` scale) before which the key is unavailable.
    cooldown_until: float = 0.0
    success_count: int = 0
    failure_count: int = 0
    # Sum of latencies of successful calls; divided by success_count for the mean.
    total_latency: float = 0.0
    last_error_info: str | None = None

    @property
    def is_available(self) -> bool:
        """Whether the cooldown deadline has passed."""
        return self.cooldown_until <= time.time()

    @property
    def avg_latency(self) -> float:
        """Mean latency over successful calls, or 0.0 when there are none."""
        if self.success_count > 0:
            return self.total_latency / self.success_count
        return 0.0

    @property
    def success_rate(self) -> float:
        """Success percentage (0-100); 100.0 when no call has been recorded."""
        calls = self.success_count + self.failure_count
        if calls > 0:
            return self.success_count / calls * 100
        return 100.0

    @property
    def status(self) -> KeyStatus:
        """Derive the current ``KeyStatus`` from the cooldown and counters."""
        remaining = max(0, self.cooldown_until - time.time())
        # A remaining cooldown of (almost) a full year is reported as disabled.
        disabled_threshold = 31536000 - 60
        if remaining > disabled_threshold:
            return KeyStatus.DISABLED
        if remaining > 0:
            return KeyStatus.COOLDOWN

        calls = self.success_count + self.failure_count
        if calls == 0:
            return KeyStatus.UNUSED
        if self.success_rate < 70:
            return KeyStatus.ERROR
        is_slow = calls >= 5 and self.avg_latency > 15000
        return KeyStatus.WARNING if is_slow else KeyStatus.HEALTHY

    @property
    def suggested_action(self) -> str:
        """Short operator hint for the current status."""
        hints = {
            KeyStatus.DISABLED: "更换Key",
            KeyStatus.ERROR: "检查网络/重置",
            KeyStatus.COOLDOWN: "等待/重置",
            KeyStatus.WARNING: "观察",
            KeyStatus.HEALTHY: "-",
            KeyStatus.UNUSED: "-",
        }
        return hints.get(self.status, "未知")
class RetryConfig:
    """Retry settings: attempt limit, delay, backoff flag and key-rotation flag."""

    def __init__(
        self,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        exponential_backoff: bool = True,
        key_rotation: bool = True,
    ):
        # Stored verbatim; no validation is performed here.
        self.max_retries, self.retry_delay = max_retries, retry_delay
        self.exponential_backoff, self.key_rotation = (
            exponential_backoff,
            key_rotation,
        )
def _should_retry_llm_error(
    error: LLMException, attempt: int, max_retries: int
) -> bool:
    """Decide whether an LLM error is worth retrying.

    Args:
        error: The exception whose ``code`` is inspected.
        attempt: Current attempt number; only consulted for quota errors.
        max_retries: Configured retry limit; only consulted for quota errors.

    Returns:
        False for codes in the non-retryable set and for unknown codes.
        For ``API_QUOTA_EXCEEDED``, True only while
        ``attempt < min(2, max_retries)``. True for every other retryable code.
    """
    code = error.code

    permanent_codes = {
        LLMErrorCode.MODEL_NOT_FOUND,
        LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
        LLMErrorCode.USER_LOCATION_NOT_SUPPORTED,
        LLMErrorCode.INVALID_PARAMETER,
        LLMErrorCode.CONFIGURATION_ERROR,
        LLMErrorCode.API_KEY_INVALID,
    }
    if code in permanent_codes:
        return False

    transient_codes = {
        LLMErrorCode.API_REQUEST_FAILED,
        LLMErrorCode.API_TIMEOUT,
        LLMErrorCode.API_RATE_LIMITED,
        LLMErrorCode.API_RESPONSE_INVALID,
        LLMErrorCode.RESPONSE_PARSE_ERROR,
        LLMErrorCode.GENERATION_FAILED,
        LLMErrorCode.CONTENT_FILTERED,
        LLMErrorCode.API_QUOTA_EXCEEDED,
    }
    if code not in transient_codes:
        # Unknown codes are treated as not retryable.
        return False

    if code == LLMErrorCode.API_QUOTA_EXCEEDED:
        # Quota errors get at most two attempts, fewer if max_retries is lower.
        return attempt < min(2, max_retries)
    return True
class KeyStatusStore:
"""API Key 状态管理存储 - 支持持久化"""
def __init__(self):
self._key_stats: dict[str, KeyStats] = {}
self._provider_key_index: dict[str, int] = {}
self._lock = asyncio.Lock()
self._file_path = DATA_PATH / "llm" / "key_status.json"
async def initialize(self):
"""从文件异步加载密钥状态,在应用启动时调用"""
async with self._lock:
if not self._file_path.exists():
logger.info("未找到密钥状态文件,将使用内存状态启动。")
return
try:
logger.info(f"正在从 {self._file_path} 加载密钥状态...")
async with aiofiles.open(self._file_path, encoding="utf-8") as f:
content = await f.read()
if not content:
logger.warning("密钥状态文件为空。")
return
data = json.loads(content)
for key, stats_dict in data.items():
self._key_stats[key] = KeyStats(**stats_dict)
logger.info(f"成功加载 {len(self._key_stats)} 个密钥的状态。")
except json.JSONDecodeError:
logger.error(f"密钥状态文件 {self._file_path} 格式错误,无法解析。")
except Exception as e:
logger.error(f"加载密钥状态文件时发生错误: {e}", e=e)
async def _save_to_file_internal(self):
"""
[内部方法] 将当前密钥状态安全地写入JSON文件。
假定调用方已持有锁。
"""
data_to_save = {key: asdict(stats) for key, stats in self._key_stats.items()}
try:
self._file_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = self._file_path.with_suffix(".json.tmp")
async with aiofiles.open(temp_path, "w", encoding="utf-8") as f:
await f.write(json.dumps(data_to_save, ensure_ascii=False, indent=2))
if self._file_path.exists():
self._file_path.unlink()
os.rename(temp_path, self._file_path)
logger.debug("密钥状态已成功持久化到文件。")
except Exception as e:
logger.error(f"保存密钥状态到文件失败: {e}", e=e)
async def shutdown(self):
"""在应用关闭时安全地保存状态"""
async with self._lock:
await self._save_to_file_internal()
logger.info("KeyStatusStore 已在关闭前保存状态。")
async def get_next_available_key(
self,
provider_name: str,
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""
获取下一个可用的API密钥轮询策略
参数:
provider_name: 提供商名称。
api_keys: API密钥列表。
exclude_keys: 要排除的密钥集合。
返回:
str | None: 可用的API密钥如果没有可用密钥则返回None。
"""
if not api_keys:
return None
exclude_keys = exclude_keys or set()
async with self._lock:
for key in api_keys:
if key not in self._key_stats:
self._key_stats[key] = KeyStats()
available_keys = [
key
for key in api_keys
if key not in exclude_keys and self._key_stats[key].is_available
]
if not available_keys:
return api_keys[0]
current_index = self._provider_key_index.get(provider_name, 0)
selected_key = available_keys[current_index % len(available_keys)]
self._provider_key_index[provider_name] = current_index + 1
total_usage = (
self._key_stats[selected_key].success_count
+ self._key_stats[selected_key].failure_count
)
logger.debug(
f"轮询选择API密钥: {self._get_key_id(selected_key)} "
f"(使用次数: {total_usage})"
)
return selected_key
async def record_success(self, api_key: str, latency: float):
"""记录成功使用,并持久化"""
async with self._lock:
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.cooldown_until = 0.0
stats.success_count += 1
stats.total_latency += latency
stats.last_error_info = None
await self._save_to_file_internal()
logger.debug(
f"记录API密钥成功使用: {self._get_key_id(api_key)}, 延迟: {latency:.2f}ms"
)
async def record_failure(
self, api_key: str, status_code: int | None, error_message: str
):
"""
记录失败使用,并设置冷却时间
参数:
api_key: API密钥。
status_code: HTTP状态码。
error_message: 错误信息。
"""
key_id = self._get_key_id(api_key)
now = time.time()
cooldown_duration = 300
location_not_supported = error_message and (
"USER_LOCATION_NOT_SUPPORTED" in error_message
or "User location is not supported" in error_message
)
if location_not_supported:
logger.warning(
f"API Key {key_id} 请求失败,原因是地区不支持 (Gemini)。"
" 这通常是代理节点问题Key 本身可能是正常的。跳过冷却。"
)
async with self._lock:
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.failure_count += 1
stats.last_error_info = error_message[:256]
await self._save_to_file_internal()
return
if error_message and (
"API_QUOTA_EXCEEDED" in error_message
or "insufficient_quota" in error_message.lower()
):
cooldown_duration = 3600
logger.warning(f"API Key {key_id} 额度耗尽,冷却 1 小时。")
is_key_invalid = status_code == 401 or (
status_code == 400
and error_message
and (
"API_KEY_INVALID" in error_message
or "API key not valid" in error_message
)
)
if is_key_invalid:
cooldown_duration = 31536000
log_level = "error"
log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
elif status_code == 403:
cooldown_duration = 3600
log_level = "warning"
log_message = f"API密钥权限不足或地区不支持(403)冷却1小时: {key_id}"
elif status_code == 404:
log_level = "error"
log_message = "API请求返回 404 (未找到),可能是模型名称错误或接口地址"
f"错误,不冷却密钥: {key_id}"
elif status_code == 422:
cooldown_duration = 0
log_level = "warning"
log_message = f"API请求无法处理(422),可能是生成故障,不冷却密钥: {key_id}"
elif status_code == 429:
cooldown_duration = 60
log_level = "warning"
log_message = f"API密钥被限流冷却60秒: {key_id}"
elif error_message and (
"ConnectError" in error_message
or "NetworkError" in error_message
or "Connection refused" in error_message
or "RemoteProtocolError" in error_message
or "ProxyError" in error_message
):
cooldown_duration = 0
log_level = "warning"
log_message = f"网络连接层异常(代理/DNS),不冷却密钥: {key_id}"
else:
log_level = "warning"
log_message = f"API密钥遇到临时性错误冷却{cooldown_duration}秒: {key_id}"
async with self._lock:
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.cooldown_until = now + cooldown_duration
stats.failure_count += 1
stats.last_error_info = error_message[:256]
await self._save_to_file_internal()
getattr(logger, log_level)(log_message)
async def reset_key_status(self, api_key: str):
"""重置密钥状态,并持久化"""
async with self._lock:
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.cooldown_until = 0.0
stats.last_error_info = None
await self._save_to_file_internal()
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""
获取密钥使用统计,并计算出用于展示的派生数据。
参数:
api_keys: API密钥列表。
返回:
dict[str, dict]: 包含丰富状态和统计信息的密钥字典。
"""
stats_dict = {}
now = time.time()
async with self._lock:
for key in api_keys:
key_id = self._get_key_id(key)
stats = self._key_stats.get(key, KeyStats())
stats_dict[key_id] = {
"status_enum": stats.status,
"cooldown_seconds_left": max(0, stats.cooldown_until - now),
"total_calls": stats.success_count + stats.failure_count,
"success_count": stats.success_count,
"failure_count": stats.failure_count,
"success_rate": stats.success_rate,
"avg_latency": stats.avg_latency,
"last_error": stats.last_error_info,
"suggested_action": stats.suggested_action,
}
return stats_dict
def _get_key_id(self, api_key: str) -> str:
"""获取API密钥的标识符用于日志"""
if len(api_key) <= 8:
return api_key
return f"{api_key[:4]}...{api_key[-4:]}"
# Module-level KeyStatusStore instance; its state is saved by the shutdown hook below.
key_store = KeyStatusStore()
@driver.on_shutdown
async def _shutdown_key_store():
    """Driver shutdown hook: let the key store persist its state."""
    await key_store.shutdown()