zhenxun_bot/zhenxun/services/llm/adapters/base.py
Rumio c667fc215e
feat(llm): 增强LLM服务,支持图片生成、响应验证与OpenRouter集成 (#2054)
*  feat(llm): 增强LLM服务,支持图片生成、响应验证与OpenRouter集成

- 【新功能】统一图片生成与编辑API `create_image`,支持文生图、图生图及多图输入
- 【新功能】引入LLM响应验证机制,通过 `validation_policy` 和 `response_validator` 确保响应内容符合预期,例如强制返回图片
- 【新功能】适配OpenRouter API,扩展LLM服务提供商支持,并添加OpenRouter特定请求头
- 【重构】将日志净化逻辑重构至 `log_sanitizer` 模块,提供统一的净化入口,并应用于NoneBot消息、LLM请求/响应日志
- 【修复】优化Gemini适配器,正确解析图片生成响应中的Base64图片数据,并更新模型能力注册表

*  feat(image): 优化图片生成响应并返回完整LLMResponse

*  feat(llm): 为 OpenAI 兼容请求体添加日志净化

* 🐛 fix(ui): 截断UI调试HTML日志中的长base64图片数据

---------

Co-authored-by: webjoin111 <455457521@qq.com>
2025-10-01 18:41:46 +08:00

593 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 适配器基类和通用数据结构
"""
from abc import ABC, abstractmethod
import asyncio
import base64
import binascii
import json
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

from zhenxun.services.log import logger

from ..types.exceptions import LLMErrorCode, LLMException
from ..types.models import LLMToolCall

if TYPE_CHECKING:
    from ..config.generation import LLMGenerationConfig
    from ..service import LLMModel
    from ..types.content import LLMMessage
    from ..types.enums import EmbeddingTaskType
    from ..types.protocols import ToolExecutable


class RequestData(BaseModel):
    """An HTTP request prepared by an adapter.

    Bundles everything needed to issue the call: the full endpoint URL, the
    request headers and the request body (a dict; the adapters in this module
    set ``Content-Type: application/json``).
    """

    # Absolute endpoint URL, typically built via ``get_api_url``.
    url: str
    # Header name -> header value.
    headers: dict[str, str]
    # Request payload, presumably serialized to JSON by the caller.
    body: dict[str, Any]


class ResponseData(BaseModel):
    """Provider-neutral result of parsing an LLM API response.

    Only ``text`` is mandatory. Every other field defaults to ``None`` and is
    filled in only when the adapter extracted that piece of information.
    """

    # Main textual output (may be a placeholder when only an image/tool call
    # was returned).
    text: str
    # Decoded image payload, when the response carried an inline image.
    image_bytes: bytes | None = None
    # Provider usage/token accounting block, passed through as returned.
    usage_info: dict[str, Any] | None = None
    # The full decoded response JSON, kept for debugging/inspection.
    raw_response: dict[str, Any] | None = None
    # Parsed function/tool calls requested by the model.
    tool_calls: list[LLMToolCall] | None = None
    # Provider-specific extras below; presumably populated by adapters other
    # than the OpenAI-compatible one (not set anywhere in this module).
    code_executions: list[Any] | None = None
    grounding_metadata: Any | None = None
    cache_info: Any | None = None
    code_execution_results: list[dict[str, Any]] | None = None
    search_results: list[dict[str, Any]] | None = None
    function_calls: list[dict[str, Any]] | None = None
    safety_ratings: list[dict[str, Any]] | None = None
    citations: list[dict[str, Any]] | None = None


class BaseAdapter(ABC):
    """Base class for LLM API adapters.

    An adapter converts provider-neutral inputs (messages, generation config,
    tools) into a provider-specific ``RequestData`` and converts the provider's
    JSON reply back into ``ResponseData``. Besides the abstract hooks, this
    class offers shared helpers for OpenAI-style message conversion, response
    parsing and error validation that concrete adapters can reuse.
    """

    # Gemini ``finishReason`` values treated as "blocked by a content filter".
    # SAFETY/RECITATION were the original set; the others are the remaining
    # block-type reasons documented for the Gemini API (IMAGE_SAFETY matters
    # for image generation). A tuple keeps plain ``==`` membership semantics.
    _SAFETY_FINISH_REASONS: tuple[str, ...] = (
        "SAFETY",
        "RECITATION",
        "BLOCKLIST",
        "PROHIBITED_CONTENT",
        "SPII",
        "IMAGE_SAFETY",
    )

    @property
    @abstractmethod
    def api_type(self) -> str:
        """Identifier of the API type implemented by this adapter."""
        pass

    @property
    @abstractmethod
    def supported_api_types(self) -> list[str]:
        """All API type identifiers this adapter can handle."""
        pass

    async def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """Build a plain text-generation request.

        Default implementation: convert ``history`` (dicts with ``role`` and
        ``content``; missing keys default to ``"user"`` and ``""``) plus the
        new user ``prompt`` into ``LLMMessage`` objects and delegate to
        ``prepare_advanced_request`` with the model's own generation config.
        Subclasses may override this with a provider-specific shortcut.
        """
        from ..types.content import LLMMessage

        messages: list[LLMMessage] = []
        if history:
            for msg in history:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                messages.append(LLMMessage(role=role, content=content))
        messages.append(LLMMessage(role="user", content=prompt))
        config = model._generation_config
        return await self.prepare_advanced_request(
            model=model,
            api_key=api_key,
            messages=messages,
            config=config,
            tools=None,
            tool_choice=None,
        )

    @abstractmethod
    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: dict[str, "ToolExecutable"] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """Build a request from full message history, config and tools."""
        pass

    @abstractmethod
    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """Parse a provider response into ``ResponseData``."""
        pass

    @abstractmethod
    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        task_type: "EmbeddingTaskType | str",
        **kwargs: Any,
    ) -> RequestData:
        """Build a text-embedding request for ``texts``."""
        pass

    @abstractmethod
    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """Extract the embedding vectors from an embedding response."""
        pass

    def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
        """Raise ``LLMException`` (``EMBEDDING_FAILED``) if the body has ``error``.

        The message is taken from ``error["message"]`` when ``error`` is a
        dict, otherwise from ``str(error)``.
        """
        if "error" in response_json:
            error_info = response_json["error"]
            msg = (
                error_info.get("message", str(error_info))
                if isinstance(error_info, dict)
                else str(error_info)
            )
            raise LLMException(
                f"嵌入API错误: {msg}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                details=response_json,
            )

    def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
        """Join ``model.api_base`` (trailing ``/`` stripped) with ``endpoint``.

        Raises:
            LLMException: ``CONFIGURATION_ERROR`` when ``api_base`` is empty.
        """
        if not model.api_base:
            raise LLMException(
                f"模型 {model.model_name} 的 api_base 未设置",
                code=LLMErrorCode.CONFIGURATION_ERROR,
            )
        return f"{model.api_base.rstrip('/')}{endpoint}"

    def get_base_headers(self, api_key: str) -> dict[str, str]:
        """Return a fresh header dict: user agent + JSON content type + bearer auth."""
        from zhenxun.utils.user_agent import get_user_agent

        # Copy before updating: if get_user_agent() hands out a shared dict,
        # mutating it would leak this API key (and any later per-request
        # headers) into every other request.
        headers = dict(get_user_agent())
        headers.update(
            {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def convert_messages_to_openai_format(
        self, messages: list["LLMMessage"]
    ) -> list[dict[str, Any]]:
        """Convert ``LLMMessage`` objects into OpenAI chat ``messages`` dicts.

        - ``tool`` messages carry ``tool_call_id``, ``name`` and ``content``.
        - String content is passed through unchanged; ``None`` content (e.g.
          an assistant turn that only issues tool calls) is sent as ``null``.
        - Content-part lists map ``text`` parts to ``{"type": "text"}`` and
          ``image`` parts to ``{"type": "image_url"}`` using
          ``part.image_source`` as the URL; other part types are not forwarded.
        - Assistant ``tool_calls`` are serialized as OpenAI function calls.
        - ``name`` is copied for non-tool messages when set.
        """
        openai_messages: list[dict[str, Any]] = []
        for msg in messages:
            openai_msg: dict[str, Any] = {"role": msg.role}
            if msg.role == "tool":
                openai_msg["tool_call_id"] = msg.tool_call_id
                openai_msg["name"] = msg.name
                openai_msg["content"] = msg.content
            elif msg.content is None or isinstance(msg.content, str):
                # None used to crash below when iterated as a list of parts.
                openai_msg["content"] = msg.content
            else:
                content_parts: list[dict[str, Any]] = []
                for part in msg.content:
                    if part.type == "text":
                        content_parts.append({"type": "text", "text": part.text})
                    elif part.type == "image":
                        content_parts.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": part.image_source},
                            }
                        )
                openai_msg["content"] = content_parts

            if msg.role == "assistant" and msg.tool_calls:
                openai_msg["tool_calls"] = [
                    {
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.function.name,
                            "arguments": call.function.arguments,
                        },
                    }
                    for call in msg.tool_calls
                ]
            if msg.name and msg.role != "tool":
                openai_msg["name"] = msg.name
            openai_messages.append(openai_msg)
        return openai_messages

    @staticmethod
    def _looks_like_image(data: bytes) -> bool:
        """Return True if ``data`` starts with a PNG/JPEG/GIF/WEBP signature."""
        if data.startswith(
            (b"\x89PNG\r\n\x1a\n", b"\xff\xd8\xff", b"GIF87a", b"GIF89a")
        ):
            return True
        # WEBP container: b"RIFF" + 4-byte size + b"WEBP".
        return data[:4] == b"RIFF" and data[8:12] == b"WEBP"

    @staticmethod
    def _join_text_parts(parts: list[Any]) -> str:
        """Concatenate the ``text`` of ``{"type": "text"}`` content parts."""
        return "".join(
            part["text"]
            for part in parts
            if isinstance(part, dict)
            and part.get("type") == "text"
            and isinstance(part.get("text"), str)
        )

    @classmethod
    def _extract_image_from_json_text(cls, content: str) -> bytes | None:
        """Decode an inline image from a JSON-object text reply, if any.

        Recognizes ``{"b64_json": "<base64>"}`` and ``{"data": "<base64>"}``.
        The generic ``data`` key is accepted only when the decoded bytes carry
        a known image signature, so an ordinary JSON answer such as
        ``{"data": "test"}`` (valid base64!) is not mistaken for an image.
        Returns ``None`` when no image can be decoded.
        """
        try:
            content_json = json.loads(content)
        except json.JSONDecodeError:
            return None
        if not isinstance(content_json, dict):
            return None
        try:
            if "b64_json" in content_json:
                return base64.b64decode(content_json["b64_json"])
            data = content_json.get("data")
            if isinstance(data, str):
                decoded = base64.b64decode(data)
                if cls._looks_like_image(decoded):
                    return decoded
        except (binascii.Error, ValueError, TypeError):
            # Bad padding/alphabet, non-ASCII text or a non-string value.
            return None
        return None

    @staticmethod
    def _extract_image_from_message_images(message: dict[str, Any]) -> bytes | None:
        """Decode the first ``data:image/<type>;base64,...`` URL in ``images``.

        Handles the ``message["images"]`` shape
        ``[{"type": "image_url", "image_url": {"url": "data:..."}}]`` used for
        OpenRouter image output. Any image subtype is accepted (not only PNG).
        Decoding failures are logged and yield ``None``.
        """
        images = message.get("images")
        if not isinstance(images, list) or not images:
            return None
        image_info = images[0]
        if not isinstance(image_info, dict) or image_info.get("type") != "image_url":
            return None
        image_url_obj = image_info.get("image_url") or {}
        if not isinstance(image_url_obj, dict):
            return None
        url_str = image_url_obj.get("url", "")
        if not isinstance(url_str, str) or not url_str.startswith("data:image/"):
            return None
        header, sep, b64_data = url_str.partition(",")
        if not sep or ";base64" not in header:
            return None
        try:
            return base64.b64decode(b64_data)
        except (binascii.Error, ValueError) as e:
            logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")
            return None

    @staticmethod
    def _parse_openai_tool_calls(
        raw_tool_calls: list[dict[str, Any]],
    ) -> list[LLMToolCall] | None:
        """Convert OpenAI ``tool_calls`` entries into ``LLMToolCall`` objects.

        Only ``type == "function"`` entries are kept; malformed entries are
        logged and skipped. ``arguments`` sent as a dict/list (seen from some
        OpenAI-compatible servers) is re-serialized to a JSON string, matching
        the string form this module sends back in
        ``convert_messages_to_openai_format``.

        Returns:
            The parsed calls, or ``None`` if none could be parsed.
        """
        from ..types.models import LLMToolFunction

        parsed: list[LLMToolCall] = []
        for tc_data in raw_tool_calls:
            try:
                if tc_data.get("type") != "function":
                    continue
                arguments = tc_data["function"]["arguments"]
                if isinstance(arguments, (dict, list)):
                    arguments = json.dumps(arguments, ensure_ascii=False)
                parsed.append(
                    LLMToolCall(
                        id=tc_data["id"],
                        function=LLMToolFunction(
                            name=tc_data["function"]["name"],
                            arguments=arguments,
                        ),
                    )
                )
            except KeyError as e:
                logger.warning(
                    f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
                )
            except Exception as e:
                logger.warning(
                    f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
                )
        return parsed or None

    def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData:
        """Parse an OpenAI-compatible chat completion response.

        Only the first choice is used. Besides text and tool calls, two inline
        image shapes are recognized:

        - the whole (stripped) content is a JSON object with ``b64_json``, or
          with ``data`` whose base64 decodes to a known image format; the text
          then becomes ``[图片已生成]``;
        - ``message["images"]`` holding a base64 ``data:image/...`` URL; the
          placeholder text is used only if the message had no text.

        Args:
            response_json: Decoded JSON body of the completion response.

        Returns:
            ``ResponseData``; an empty ``choices`` list yields empty text.

        Raises:
            LLMException: from ``validate_response`` for error, filtered or
                empty responses, or ``RESPONSE_PARSE_ERROR`` if parsing fails.
        """
        self.validate_response(response_json)
        try:
            choices = response_json.get("choices", [])
            if not choices:
                logger.debug("OpenAI响应中没有choices可能为空回复或流结束。")
                return ResponseData(text="", raw_response=response_json)

            choice = choices[0]
            # Treat "message": null like a missing message.
            message = choice.get("message") or {}
            content = message.get("content", "")
            if isinstance(content, list):
                # Some compatible providers return content as a list of parts.
                content = self._join_text_parts(content)
            if content:
                content = content.strip()

            image_bytes: bytes | None = None
            if content and content.startswith("{") and content.endswith("}"):
                image_bytes = self._extract_image_from_json_text(content)
                if image_bytes is not None:
                    content = "[图片已生成]"
            if image_bytes is None:
                # Checked even when the text looked like JSON, so a JSON text
                # reply no longer hides images attached to the message.
                image_bytes = self._extract_image_from_message_images(message)
                if image_bytes is not None and not content:
                    content = "[图片已生成]"

            parsed_tool_calls: list[LLMToolCall] | None = None
            if message_tool_calls := message.get("tool_calls"):
                parsed_tool_calls = self._parse_openai_tool_calls(message_tool_calls)

            final_text = content if content is not None else ""
            if not final_text and parsed_tool_calls:
                final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"

            return ResponseData(
                text=final_text,
                tool_calls=parsed_tool_calls,
                usage_info=response_json.get("usage"),
                image_bytes=image_bytes,
                raw_response=response_json,
            )
        except Exception as e:
            logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
            raise LLMException(
                f"解析API响应失败: {e}",
                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                cause=e,
            ) from e

    def validate_response(self, response_json: dict[str, Any]) -> None:
        """Raise ``LLMException`` if the response signals a failure.

        Checks, in order:

        - an ``error`` object. The provider ``code`` is matched as a string,
          so numeric codes (e.g. ``{"code": 429}``) are recognized too, with
          Gemini's ``status`` string as a fallback. Unmapped codes become
          ``API_RESPONSE_INVALID``.
        - Gemini safety blocks: a first candidate whose ``finishReason`` is a
          block reason, or a prompt-level ``promptFeedback.blockReason``
          (``CONTENT_FILTERED``).
        - an empty response body (``API_RESPONSE_INVALID``).
        """
        if "error" in response_json:
            error_info = response_json["error"]
            if isinstance(error_info, dict):
                error_message = error_info.get("message", "未知错误")
                error_code = error_info.get("code", "unknown")
                error_type = error_info.get("type", "api_error")
                error_code_mapping = {
                    "invalid_api_key": LLMErrorCode.API_KEY_INVALID,
                    "authentication_failed": LLMErrorCode.API_KEY_INVALID,
                    "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "insufficient_quota": LLMErrorCode.API_RATE_LIMITED,
                    "model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
                    "invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
                    "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    # Numeric HTTP-style codes (OpenRouter / Gemini bodies).
                    "401": LLMErrorCode.API_KEY_INVALID,
                    "402": LLMErrorCode.API_RATE_LIMITED,
                    "429": LLMErrorCode.API_RATE_LIMITED,
                    # Gemini ``status`` strings.
                    "UNAUTHENTICATED": LLMErrorCode.API_KEY_INVALID,
                    "RESOURCE_EXHAUSTED": LLMErrorCode.API_RATE_LIMITED,
                }
                llm_error_code = error_code_mapping.get(str(error_code))
                if llm_error_code is None:
                    llm_error_code = error_code_mapping.get(
                        str(error_info.get("status", "")),
                        LLMErrorCode.API_RESPONSE_INVALID,
                    )
                logger.error(
                    f"API返回错误: {error_message} "
                    f"(代码: {error_code}, 类型: {error_type})"
                )
            else:
                error_message = str(error_info)
                error_code = "unknown"
                llm_error_code = LLMErrorCode.API_RESPONSE_INVALID
                logger.error(f"API返回错误: {error_message}")
            raise LLMException(
                f"API请求失败: {error_message}",
                code=llm_error_code,
                details={"api_error": error_info, "error_code": error_code},
            )

        candidates = response_json.get("candidates")
        if candidates:
            candidate = candidates[0]
            finish_reason = candidate.get("finishReason")
            if finish_reason in self._SAFETY_FINISH_REASONS:
                safety_ratings = candidate.get("safetyRatings", [])
                logger.warning(
                    f"Gemini内容被安全过滤: {finish_reason}, "
                    f"安全评级: {safety_ratings}"
                )
                raise LLMException(
                    f"内容被安全过滤: {finish_reason}",
                    code=LLMErrorCode.CONTENT_FILTERED,
                    details={
                        "finish_reason": finish_reason,
                        "safety_ratings": safety_ratings,
                    },
                )

        # A blocked *prompt* comes back with no candidates and a blockReason.
        prompt_feedback = response_json.get("promptFeedback")
        if isinstance(prompt_feedback, dict) and prompt_feedback.get("blockReason"):
            block_reason = prompt_feedback["blockReason"]
            safety_ratings = prompt_feedback.get("safetyRatings", [])
            logger.warning(
                f"Gemini请求被安全过滤: {block_reason}, "
                f"安全评级: {safety_ratings}"
            )
            raise LLMException(
                f"内容被安全过滤: {block_reason}",
                code=LLMErrorCode.CONTENT_FILTERED,
                details={
                    "block_reason": block_reason,
                    "safety_ratings": safety_ratings,
                },
            )

        if not response_json:
            logger.error("API返回空响应")
            raise LLMException(
                "API返回空响应",
                code=LLMErrorCode.API_RESPONSE_INVALID,
                details={"response": response_json},
            )

    def _apply_generation_config(
        self,
        model: "LLMModel",
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """Return generation parameters for the request body.

        Priority: explicit ``config`` -> ``model._generation_config`` -> a
        minimal dict built from ``model.temperature`` / ``model.max_tokens``
        (``maxOutputTokens`` for the ``gemini`` API type, else ``max_tokens``).
        """
        if config is not None:
            return config.to_api_params(model.api_type, model.model_name)
        if model._generation_config is not None:
            return model._generation_config.to_api_params(
                model.api_type, model.model_name
            )
        base_config: dict[str, Any] = {}
        if model.temperature is not None:
            base_config["temperature"] = model.temperature
        if model.max_tokens is not None:
            if model.api_type == "gemini":
                base_config["maxOutputTokens"] = model.max_tokens
            else:
                base_config["max_tokens"] = model.max_tokens
        return base_config

    def apply_config_override(
        self,
        model: "LLMModel",
        body: dict[str, Any],
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """Merge generation parameters into ``body`` in place and return it."""
        config_params = self._apply_generation_config(model, config)
        body.update(config_params)
        return body


class OpenAICompatAdapter(BaseAdapter):
    """
    Generic adapter for every OpenAI-compatible API.

    Subclasses only provide the chat and embedding endpoint paths; request
    building and response parsing follow the OpenAI wire format.
    """

    @abstractmethod
    def get_chat_endpoint(self, model: "LLMModel") -> str:
        """Return the chat-completions endpoint path (appended to ``api_base``)."""
        pass

    @abstractmethod
    def get_embedding_endpoint(self, model: "LLMModel") -> str:
        """Return the embeddings endpoint path (appended to ``api_base``)."""
        pass

    def _build_openai_compat_headers(
        self, model: "LLMModel", api_key: str
    ) -> dict[str, str]:
        """Return base headers, plus ``HTTP-Referer``/``X-Title`` for OpenRouter.

        Shared by every request builder so simple, advanced and embedding
        requests send the same headers (previously only advanced requests
        added the OpenRouter ones).
        """
        headers = self.get_base_headers(api_key)
        if model.api_type == "openrouter":
            headers.update(
                {
                    "HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot",
                    "X-Title": "Zhenxun Bot",
                }
            )
        return headers

    async def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """Build a plain chat request from raw history dicts and a prompt.

        ``history`` entries are forwarded verbatim as OpenAI messages, followed
        by the user ``prompt``. Generation parameters come from the model's
        defaults (no explicit config is passed).
        """
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self._build_openai_compat_headers(model, api_key)
        messages: list[dict[str, Any]] = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})
        body: dict[str, Any] = {
            "model": model.model_name,
            "messages": messages,
        }
        body = self.apply_config_override(model, body)
        return RequestData(url=url, headers=headers, body=body)

    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: dict[str, "ToolExecutable"] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """Build an OpenAI chat request with messages, optional tools and config.

        Tool definitions are fetched concurrently from each executable's
        ``get_definition()`` and sent as ``{"type": "function", ...}`` entries.
        ``tool_choice`` is only included when tools are part of the request,
        since OpenAI rejects ``tool_choice`` without ``tools``.
        """
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self._build_openai_compat_headers(model, api_key)
        body: dict[str, Any] = {
            "model": model.model_name,
            "messages": self.convert_messages_to_openai_format(messages),
        }
        if tools:
            from zhenxun.utils.pydantic_compat import model_dump

            openai_tools = await asyncio.gather(
                *(executable.get_definition() for executable in tools.values())
            )
            if openai_tools:
                body["tools"] = [
                    {"type": "function", "function": model_dump(tool)}
                    for tool in openai_tools
                ]
        if tool_choice and "tools" in body:
            body["tool_choice"] = tool_choice
        body = self.apply_config_override(model, body, config)
        return RequestData(url=url, headers=headers, body=body)

    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """Parse a chat response with the shared OpenAI-format parser."""
        _ = model, is_advanced
        return self.parse_openai_response(response_json)

    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        task_type: "EmbeddingTaskType | str",
        **kwargs: Any,
    ) -> RequestData:
        """Build an OpenAI embeddings request for ``texts``.

        ``task_type`` is ignored by the OpenAI format; extra ``kwargs`` are
        merged into the body and may override ``model``/``input``.
        """
        _ = task_type
        url = self.get_api_url(model, self.get_embedding_endpoint(model))
        headers = self._build_openai_compat_headers(model, api_key)
        body: dict[str, Any] = {
            "model": model.model_name,
            "input": texts,
        }
        if kwargs:
            body.update(kwargs)
        return RequestData(url=url, headers=headers, body=body)

    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """Extract the ``embedding`` of every item in ``response_json["data"]``.

        Raises:
            LLMException: ``EMBEDDING_FAILED`` when the response reports an
                error, ``data`` is missing/empty, an item lacks ``embedding``,
                or the payload cannot be read. Exceptions raised here keep
                their own message and ``details`` instead of being re-wrapped.
        """
        self.validate_embedding_response(response_json)
        try:
            data = response_json.get("data", [])
            if not data:
                raise LLMException(
                    "嵌入响应中没有数据",
                    code=LLMErrorCode.EMBEDDING_FAILED,
                    details=response_json,
                )
            embeddings: list[list[float]] = []
            for item in data:
                if "embedding" not in item:
                    raise LLMException(
                        "嵌入响应格式错误缺少embedding字段",
                        code=LLMErrorCode.EMBEDDING_FAILED,
                        details=item,
                    )
                embeddings.append(item["embedding"])
            return embeddings
        except Exception as e:
            logger.error(f"解析嵌入响应失败: {e}", e=e)
            if isinstance(e, LLMException):
                # Already specific; re-wrapping would drop its details.
                raise
            raise LLMException(
                f"解析嵌入响应失败: {e}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                cause=e,
            ) from e