mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
- [Refactor] LLM service core architecture:
  - Introduced a middleware pipeline that unifies the request lifecycle (retries, key selection, logging, network requests).
  - Refactored adapters into a component-based design, separating config mapping, message conversion, response parsing, and tool serialization.
  - Removed the `with_smart_retry` decorator; its functionality is taken over by the middleware.
  - Removed `LLMToolExecutor`; tool-execution logic is integrated into `ToolInvoker`.
- [Feature] Enhanced configuration system (see the builder sketch after this changelog):
  - `LLMGenerationConfig` now uses a component-based structure (Core, Reasoning, Visual, Output, Safety, ToolConfig).
  - Added `GenConfigBuilder` for a semantic way to build configurations.
  - Added `LLMEmbeddingConfig` for embedding-specific configuration.
  - Migrated `CommonOverrides` to the new config structure.
- [Feature] Strengthened tool system (see the decorator sketch after this changelog):
  - Introduced `ToolInvoker` for more flexible tool execution, supporting callbacks and structured errors.
  - The `function_tool` decorator supports dynamic Pydantic model creation and dependency injection (`ToolParam`, `RunContext`).
  - Added platform-native tool support (`GeminiCodeExecution`, `GeminiGoogleSearch`, `GeminiUrlContext`).
- [Feature] Advanced generation and embeddings (see the embedding sketch after this changelog):
  - The `generate_structured` method supports an In-Context Validation and Repair (IVR) loop and AutoCoT (chain-of-thought) wrapping.
  - Added the `embed_query` and `embed_documents` convenience embedding APIs.
  - `OpenAIImageAdapter` supports OpenAI-compatible image generation.
  - `SmartAdapter` routes requests intelligently by model name.
- [Refactor] Message and type system:
  - Extended `LLMContentPart` to support more modalities and code-execution-related content.
  - Updated the `LLMMessage` and `LLMResponse` structures to support `content_parts` and chain-of-thought signatures.
  - Unified `LLMErrorCode` and user-friendly error messages, with more detailed network/proxy error hints.
  - `pyproject.toml`: removed `bilireq`, added `json_repair`.
- [Optimization] Logging and debugging:
  - Introduced `DebugLogOptions` for fine-grained control over log sanitization.
  - Enhanced the log sanitizer to handle more kinds of sensitive data and long strings.
- [Cleanup] Removed deprecated modules:
  - `zhenxun/services/llm/memory.py`
  - `zhenxun/services/llm/executor.py`
  - `zhenxun/services/llm/config/presets.py`
  - `zhenxun/services/llm/types/content.py`
  - `zhenxun/services/llm/types/enums.py`
  - `zhenxun/services/llm/tools/__init__.py`
  - `zhenxun/services/llm/tools/manager.py`
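The builder-based configuration flow might look like the following minimal sketch. The method names (`temperature`, `max_tokens`, `build`) and the import path are assumptions inferred from the component names above, not the confirmed `GenConfigBuilder` API:

    # Hypothetical GenConfigBuilder usage (names and path assumed, not verified)
    from zhenxun.services.llm.config import GenConfigBuilder

    config = (
        GenConfigBuilder()
        .temperature(0.7)   # assumed Core-component setter
        .max_tokens(1024)   # assumed Output-component setter
        .build()            # assumed to yield an LLMGenerationConfig
    )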
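Similarly, a hedged sketch of the `function_tool` decorator described above; the decorator's exact signature and the schema it derives are assumptions:

    # Hypothetical function_tool usage (signature assumed)
    from zhenxun.services.llm.tools import function_tool

    @function_tool(name="get_weather", description="Look up current weather")
    async def get_weather(city: str) -> dict:
        # Parameters are said to become a dynamically created Pydantic model,
        # so `city` would surface to the model as a JSON-schema property.
        return {"city": city, "temp_c": 21}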
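And the convenience embedding APIs would plausibly be used like this; the import path and exact call shape are assumptions:

    # Hypothetical embed_query / embed_documents usage (paths assumed)
    from zhenxun.services.llm import embed_documents, embed_query

    query_vec = await embed_query("how are adapters structured?")
    doc_vecs = await embed_documents(["first document", "second document"])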
579 lines
19 KiB
Python
"""
|
||
LLM 适配器基类和通用数据结构
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
import json
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Any
|
||
import uuid
|
||
|
||
import httpx
|
||
from pydantic import BaseModel
|
||
|
||
from zhenxun.configs.path_config import TEMP_PATH
|
||
from zhenxun.services.log import logger
|
||
|
||
from ..types import LLMContentPart
|
||
from ..types.exceptions import LLMErrorCode, LLMException
|
||
from ..types.models import LLMToolCall
|
||
|
||
if TYPE_CHECKING:
|
||
from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
|
||
from ..service import LLMModel
|
||
from ..types import LLMMessage
|
||
from ..types.models import ToolChoice
|
||
|
||
|
||
class RequestData(BaseModel):
    """Encapsulated request data."""

    url: str
    headers: dict[str, str]
    body: dict[str, Any]
    files: dict[str, Any] | list[tuple[str, Any]] | None = None


class ResponseData(BaseModel):
    """Encapsulated response data, supporting all advanced features."""

    text: str
    content_parts: list[LLMContentPart] | None = None
    images: list[bytes | Path] | None = None
    usage_info: dict[str, Any] | None = None
    raw_response: dict[str, Any] | None = None
    tool_calls: list[LLMToolCall] | None = None
    code_executions: list[Any] | None = None
    grounding_metadata: Any | None = None
    cache_info: Any | None = None
    thought_text: str | None = None
    thought_signature: str | None = None

    code_execution_results: list[dict[str, Any]] | None = None
    search_results: list[dict[str, Any]] | None = None
    function_calls: list[dict[str, Any]] | None = None
    safety_ratings: list[dict[str, Any]] | None = None
    citations: list[dict[str, Any]] | None = None

def process_image_data(image_data: bytes) -> bytes | Path:
    """
    Process image data: anything over 2 MB is saved to a temp directory
    instead of being kept in memory.
    """
    max_inline_size = 2 * 1024 * 1024
    if len(image_data) > max_inline_size:
        save_dir = TEMP_PATH / "llm"
        save_dir.mkdir(parents=True, exist_ok=True)
        file_name = f"{uuid.uuid4()}.png"
        file_path = save_dir / file_name
        file_path.write_bytes(image_data)
        logger.info(
            f"Image data too large ({len(image_data)} bytes); "
            f"saved to temp file: {file_path}",
            "LLMAdapter",
        )
        return file_path.resolve()
    return image_data
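
# Illustrative usage (not part of the original source): callers must accept
# either inline bytes or a Path to the spilled temp file, e.g.
#
#     result = process_image_data(png_bytes)
#     if isinstance(result, Path):
#         ...  # reference the temp file instead of embedding the raw bytes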


class BaseAdapter(ABC):
    """Base class for LLM API adapters."""

    @property
    def log_sanitization_context(self) -> str:
        """Context name used for log sanitization; defaults to 'default'."""
        return "default"

    @property
    @abstractmethod
    def api_type(self) -> str:
        """API type identifier."""
        pass

    @property
    @abstractmethod
    def supported_api_types(self) -> list[str]:
        """List of supported API types."""
        pass

    async def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """Prepare a simple text-generation request.

        Default implementation: converts the simple request into the advanced
        request format. Subclasses may override this method to provide a
        specialized, optimized implementation.
        """
        from ..types import LLMMessage

        messages: list[LLMMessage] = []

        if history:
            for msg in history:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                messages.append(LLMMessage(role=role, content=content))

        messages.append(LLMMessage(role="user", content=prompt))

        config = model._generation_config

        return await self.prepare_advanced_request(
            model=model,
            api_key=api_key,
            messages=messages,
            config=config,
            tools=None,
            tool_choice=None,
        )

    @abstractmethod
    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[Any] | None = None,
        tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
    ) -> RequestData:
        """Prepare an advanced request."""
        pass

    @abstractmethod
    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """Parse the API response."""
        pass

    @abstractmethod
    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        config: "LLMEmbeddingConfig",
    ) -> RequestData:
        """Prepare a text-embedding request."""
        pass

    @abstractmethod
    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """Parse the text-embedding response."""
        pass

    @abstractmethod
    def convert_generation_config(
        self, config: "LLMGenerationConfig", model: "LLMModel"
    ) -> dict[str, Any]:
        """Convert the generic generation config into API-specific parameters."""
        pass

    def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
        """Validate the embedding API response."""
        if response_json.get("error"):
            error_info = response_json["error"]
            msg = (
                error_info.get("message", str(error_info))
                if isinstance(error_info, dict)
                else str(error_info)
            )
            raise LLMException(
                f"Embedding API error: {msg}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                details=response_json,
            )

    def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
        """Build the API URL."""
        if not model.api_base:
            raise LLMException(
                f"api_base is not set for model {model.model_name}",
                code=LLMErrorCode.CONFIGURATION_ERROR,
            )
        return f"{model.api_base.rstrip('/')}{endpoint}"

    def get_base_headers(self, api_key: str) -> dict[str, str]:
        """Get the base request headers."""
        from zhenxun.utils.user_agent import get_user_agent

        headers = get_user_agent()
        headers.update(
            {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def validate_response(self, response_json: dict[str, Any]) -> None:
        """Validate the API response, parsing the error structures of different APIs."""
        if response_json.get("error"):
            error_info = response_json["error"]

            if isinstance(error_info, dict):
                error_message = error_info.get("message", "unknown error")
                error_code = error_info.get("code", "unknown")
                error_type = error_info.get("type", "api_error")

                error_code_mapping = {
                    "invalid_api_key": LLMErrorCode.API_KEY_INVALID,
                    "authentication_failed": LLMErrorCode.API_KEY_INVALID,
                    "insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED,
                    "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
                    "invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
                    "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "invalid_request_error": LLMErrorCode.INVALID_PARAMETER,
                    "invalid_parameter": LLMErrorCode.INVALID_PARAMETER,
                }

                llm_error_code = error_code_mapping.get(
                    error_code, LLMErrorCode.API_RESPONSE_INVALID
                )

                logger.error(
                    f"API returned an error: {error_message} "
                    f"(code: {error_code}, type: {error_type})"
                )
            else:
                error_message = str(error_info)
                error_code = "unknown"
                llm_error_code = LLMErrorCode.API_RESPONSE_INVALID

                logger.error(f"API returned an error: {error_message}")

            raise LLMException(
                f"API request failed: {error_message}",
                code=llm_error_code,
                details={"api_error": error_info, "error_code": error_code},
            )

        if "candidates" in response_json:
            candidates = response_json.get("candidates", [])
            if candidates:
                candidate = candidates[0]
                finish_reason = candidate.get("finishReason")
                if finish_reason in ["SAFETY", "RECITATION"]:
                    safety_ratings = candidate.get("safetyRatings", [])
                    logger.warning(
                        f"Gemini content was safety-filtered: {finish_reason}, "
                        f"safety ratings: {safety_ratings}"
                    )
                    raise LLMException(
                        f"Content was safety-filtered: {finish_reason}",
                        code=LLMErrorCode.CONTENT_FILTERED,
                        details={
                            "finish_reason": finish_reason,
                            "safety_ratings": safety_ratings,
                        },
                    )

        if not response_json:
            logger.error("API returned an empty response")
            raise LLMException(
                "API returned an empty response",
                code=LLMErrorCode.API_RESPONSE_INVALID,
                details={"response": response_json},
            )

    def _apply_generation_config(
        self,
        model: "LLMModel",
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """Shared config-application logic."""
        if config is not None:
            return self.convert_generation_config(config, model)

        if model._generation_config:
            return self.convert_generation_config(model._generation_config, model)

        return {}

    def apply_config_override(
        self,
        model: "LLMModel",
        body: dict[str, Any],
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """Apply config overrides."""
        config_params = self._apply_generation_config(model, config)
        body.update(config_params)
        return body
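
    # Note the precedence implemented above: an explicit per-request `config`
    # wins over the model's default `_generation_config`; if neither is set,
    # no parameters are injected and the provider's defaults apply.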

    def handle_http_error(self, response: httpx.Response) -> LLMException | None:
        """
        Handle an HTTP error response.

        Returns None if the status code indicates success (200); otherwise
        builds an LLMException for the caller to raise.
        """
        if response.status_code == 200:
            return None

        error_text = response.content.decode("utf-8", errors="ignore")
        error_status = ""
        error_msg = error_text
        try:
            error_json = json.loads(error_text)
            if isinstance(error_json, dict) and "error" in error_json:
                error_info = error_json["error"]
                if isinstance(error_info, dict):
                    error_msg = error_info.get("message", error_msg)
                    raw_status = error_info.get("status") or error_info.get("code")
                    error_status = str(raw_status) if raw_status is not None else ""
                elif error_info is not None:
                    error_msg = str(error_info)
                    error_status = error_msg
        except Exception:
            pass

        status_upper = error_status.upper() if error_status else ""
        text_upper = error_text.upper()

        error_code = LLMErrorCode.API_REQUEST_FAILED
        if response.status_code == 400:
            if (
                "FAILED_PRECONDITION" in status_upper
                or "LOCATION IS NOT SUPPORTED" in text_upper
            ):
                error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED
            elif "INVALID_ARGUMENT" in status_upper:
                error_code = LLMErrorCode.INVALID_PARAMETER
            elif "API_KEY_INVALID" in text_upper or "API KEY NOT VALID" in text_upper:
                error_code = LLMErrorCode.API_KEY_INVALID
            else:
                error_code = LLMErrorCode.INVALID_PARAMETER
        elif response.status_code in [401, 403]:
            if error_msg and (
                "country" in error_msg.lower()
                or "region" in error_msg.lower()
                or "unsupported" in error_msg.lower()
            ):
                error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED
            else:
                error_code = LLMErrorCode.API_KEY_INVALID
        elif response.status_code == 404:
            error_code = LLMErrorCode.MODEL_NOT_FOUND
        elif response.status_code == 429:
            if (
                "RESOURCE_EXHAUSTED" in status_upper
                or "INSUFFICIENT_QUOTA" in status_upper
                or ("quota" in error_msg.lower() if error_msg else False)
            ):
                error_code = LLMErrorCode.API_QUOTA_EXCEEDED
            else:
                error_code = LLMErrorCode.API_RATE_LIMITED
        elif response.status_code in [402, 413]:
            error_code = LLMErrorCode.API_QUOTA_EXCEEDED
        elif response.status_code == 422:
            error_code = LLMErrorCode.GENERATION_FAILED
        elif response.status_code >= 500:
            error_code = LLMErrorCode.API_TIMEOUT

        return LLMException(
            f"HTTP request failed: {response.status_code} "
            f"({error_status or 'Unknown'})",
            code=error_code,
            details={
                "status_code": response.status_code,
                "api_status": error_status,
                "response": error_text,
            },
        )
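
    # Expected call pattern for handle_http_error (illustrative sketch only;
    # the transport layer lives outside this module):
    #
    #     response = await client.post(req.url, headers=req.headers, json=req.body)
    #     if exc := adapter.handle_http_error(response):
    #         raise exc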


class OpenAICompatAdapter(BaseAdapter):
    """
    Generic adapter for all OpenAI-compatible APIs.
    """

    @property
    def log_sanitization_context(self) -> str:
        return "openai_request"

    @abstractmethod
    def get_chat_endpoint(self, model: "LLMModel") -> str:
        """Must be implemented by subclasses; returns the chat-completions endpoint."""
        pass

    @abstractmethod
    def get_embedding_endpoint(self, model: "LLMModel") -> str:
        """Must be implemented by subclasses; returns the embeddings endpoint."""
        pass

    async def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """Prepare a simple text request (shared OpenAI-compatible implementation)."""
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self.get_base_headers(api_key)

        messages = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})

        body = {
            "model": model.model_name,
            "messages": messages,
        }

        body = self.apply_config_override(model, body)

        return RequestData(url=url, headers=headers, body=body)

    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[Any] | None = None,
        tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
    ) -> RequestData:
        """Prepare an advanced request in OpenAI-compatible format."""
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self.get_base_headers(api_key)
        if model.api_type == "openrouter":
            headers.update(
                {
                    "HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot",
                    "X-Title": "Zhenxun Bot",
                }
            )

        from .components.openai_components import OpenAIMessageConverter

        converter = OpenAIMessageConverter()
        openai_messages = converter.convert_messages(messages)

        body = {
            "model": model.model_name,
            "messages": openai_messages,
        }

        openai_tools: list[dict[str, Any]] | None = None
        executables: list[Any] = []
        if tools:
            for tool in tools:
                if hasattr(tool, "get_definition"):
                    executables.append(tool)

        if executables:
            import asyncio

            from zhenxun.utils.pydantic_compat import model_dump

            definition_tasks = [
                executable.get_definition() for executable in executables
            ]
            tool_defs = []
            if definition_tasks:
                tool_defs = await asyncio.gather(*definition_tasks)

            if tool_defs:
                openai_tools = [
                    {"type": "function", "function": model_dump(tool)}
                    for tool in tool_defs
                ]

        if openai_tools:
            body["tools"] = openai_tools

        if tool_choice:
            body["tool_choice"] = tool_choice

        body = self.apply_config_override(model, body, config)
        return RequestData(url=url, headers=headers, body=body)

    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """Parse the response by delegating to the componentized ResponseParser."""
        _ = model, is_advanced
        from .components.openai_components import OpenAIResponseParser

        parser = OpenAIResponseParser()
        return parser.parse(response_json)

    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        config: "LLMEmbeddingConfig",
    ) -> RequestData:
        """Prepare an embedding request in OpenAI-compatible format."""
        url = self.get_api_url(model, self.get_embedding_endpoint(model))
        headers = self.get_base_headers(api_key)

        body = {
            "model": model.model_name,
            "input": texts,
        }

        if config.output_dimensionality:
            body["dimensions"] = config.output_dimensionality

        if config.task_type:
            body["task"] = config.task_type

        if config.encoding_format and config.encoding_format != "float":
            body["encoding_format"] = config.encoding_format

        return RequestData(url=url, headers=headers, body=body)

    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """Parse an embedding response in OpenAI-compatible format."""
        self.validate_embedding_response(response_json)

        try:
            data = response_json.get("data", [])
            if not data:
                raise LLMException(
                    "Embedding response contains no data",
                    code=LLMErrorCode.EMBEDDING_FAILED,
                    details=response_json,
                )

            embeddings = []
            for item in data:
                if "embedding" in item:
                    embeddings.append(item["embedding"])
                else:
                    raise LLMException(
                        "Malformed embedding response: missing 'embedding' field",
                        code=LLMErrorCode.EMBEDDING_FAILED,
                        details=item,
                    )

            return embeddings

        except LLMException:
            # Re-raise structured errors as-is instead of double-wrapping them.
            raise
        except Exception as e:
            logger.error(f"Failed to parse embedding response: {e}", e=e)
            raise LLMException(
                f"Failed to parse embedding response: {e}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                cause=e,
            )
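
# A minimal concrete-subclass sketch (illustrative only; the endpoint paths and
# the empty config mapping below are assumptions, not part of this module):
#
#     class ExampleOpenAIAdapter(OpenAICompatAdapter):
#         @property
#         def api_type(self) -> str:
#             return "openai"
#
#         @property
#         def supported_api_types(self) -> list[str]:
#             return ["openai", "openrouter"]
#
#         def get_chat_endpoint(self, model: "LLMModel") -> str:
#             return "/v1/chat/completions"
#
#         def get_embedding_endpoint(self, model: "LLMModel") -> str:
#             return "/v1/embeddings"
#
#         def convert_generation_config(
#             self, config: "LLMGenerationConfig", model: "LLMModel"
#         ) -> dict[str, Any]:
#             return {}  # map LLMGenerationConfig fields to OpenAI params here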