import base64
import binascii
import json
from pathlib import Path
from typing import Any

from zhenxun.services.llm.adapters.base import ResponseData, process_image_data
from zhenxun.services.llm.adapters.components.interfaces import (
    ConfigMapper,
    MessageConverter,
    ResponseParser,
    ToolSerializer,
)
from zhenxun.services.llm.config.generation import (
    ImageAspectRatio,
    LLMGenerationConfig,
    ResponseFormat,
    StructuredOutputStrategy,
)
from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.llm.types.capabilities import ModelCapabilities
from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
from zhenxun.services.llm.types.models import (
    LLMToolCall,
    LLMToolFunction,
    ModelDetail,
    ToolDefinition,
)
from zhenxun.services.llm.utils import sanitize_schema_for_llm
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump


class OpenAIConfigMapper(ConfigMapper):
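    """Map ``LLMGenerationConfig`` onto OpenAI-compatible request parameters."""
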
    def __init__(self, api_type: str = "openai"):
        self.api_type = api_type

    def map_config(
        self,
        config: LLMGenerationConfig,
        model_detail: ModelDetail | None = None,
        capabilities: ModelCapabilities | None = None,
    ) -> dict[str, Any]:
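        """Build the OpenAI-compatible parameter dict for a request.

        Copies core sampling options that are set, resolves the structured
        output strategy, and applies provider-specific handling keyed on
        ``self.api_type`` (e.g. dropping ``repetition_penalty`` for the
        official OpenAI API).
        """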
        params: dict[str, Any] = {}
        strategy = config.output.structured_output_strategy if config.output else None
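        # No explicit strategy: DeepSeek-style endpoints default to tool-call
        # based structured output, other OpenAI-compatible APIs to native
        # response_format support.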
        if strategy is None:
            strategy = (
                StructuredOutputStrategy.TOOL_CALL
                if self.api_type == "deepseek"
                else StructuredOutputStrategy.NATIVE
            )

        if config.core:
            if config.core.temperature is not None:
                params["temperature"] = config.core.temperature
            if config.core.max_tokens is not None:
                params["max_tokens"] = config.core.max_tokens
            if config.core.top_k is not None:
                params["top_k"] = config.core.top_k
            if config.core.top_p is not None:
                params["top_p"] = config.core.top_p
            if config.core.frequency_penalty is not None:
                params["frequency_penalty"] = config.core.frequency_penalty
            if config.core.presence_penalty is not None:
                params["presence_penalty"] = config.core.presence_penalty
            if config.core.stop is not None:
                params["stop"] = config.core.stop

            if config.core.repetition_penalty is not None:
                if self.api_type == "openai":
                    logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略")
                else:
                    params["repetition_penalty"] = config.core.repetition_penalty

        if config.reasoning and config.reasoning.effort:
            params["reasoning_effort"] = config.reasoning.effort.value.lower()

        if config.output:
            if isinstance(config.output.response_format, dict):
                params["response_format"] = config.output.response_format
            elif (
                config.output.response_format == ResponseFormat.JSON
                and strategy == StructuredOutputStrategy.NATIVE
            ):
                if config.output.response_schema:
                    sanitized = sanitize_schema_for_llm(
                        config.output.response_schema, api_type="openai"
                    )
                    params["response_format"] = {
                        "type": "json_schema",
                        "json_schema": {
                            "name": "structured_response",
                            "schema": sanitized,
                            "strict": True,
                        },
                    }
                else:
                    params["response_format"] = {"type": "json_object"}

        if config.tool_config:
            mode = config.tool_config.mode
            if mode == "NONE":
                params["tool_choice"] = "none"
            elif mode == "AUTO":
                params["tool_choice"] = "auto"
            elif mode == "ANY":
                params["tool_choice"] = "required"

        if config.visual and config.visual.aspect_ratio:
            size_map = {
                ImageAspectRatio.SQUARE: "1024x1024",
                ImageAspectRatio.LANDSCAPE_16_9: "1792x1024",
                ImageAspectRatio.PORTRAIT_9_16: "1024x1792",
            }
            ar = config.visual.aspect_ratio
            if isinstance(ar, ImageAspectRatio):
                mapped_size = size_map.get(ar)
                if mapped_size:
                    params["size"] = mapped_size
            elif isinstance(ar, str):
                params["size"] = ar

        if config.custom_params:
            mapped_custom = config.custom_params.copy()
            if "repetition_penalty" in mapped_custom and self.api_type == "openai":
                mapped_custom.pop("repetition_penalty")
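
            # Normalize a bare string stop sequence into a list.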
            if "stop" in mapped_custom:
                stop_value = mapped_custom["stop"]
                if isinstance(stop_value, str):
                    mapped_custom["stop"] = [stop_value]

            params.update(mapped_custom)

        return params


class OpenAIMessageConverter(MessageConverter):
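    """Convert internal ``LLMMessage`` objects into OpenAI chat message dicts."""
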
    def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
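        """Convert messages, handling tool results, multimodal parts and tool calls."""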
        openai_messages: list[dict[str, Any]] = []
        for msg in messages:
            openai_msg: dict[str, Any] = {"role": msg.role}

            if msg.role == "tool":
                openai_msg["tool_call_id"] = msg.tool_call_id
                openai_msg["name"] = msg.name
                openai_msg["content"] = msg.content
            else:
                if isinstance(msg.content, str):
                    openai_msg["content"] = msg.content
                else:
                    content_parts = []
                    for part in msg.content:
                        if part.type == "text":
                            content_parts.append({"type": "text", "text": part.text})
                        elif part.type == "image":
                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {"url": part.image_source},
                                }
                            )
                    openai_msg["content"] = content_parts

            if msg.role == "assistant" and msg.tool_calls:
                assistant_tool_calls = []
                for call in msg.tool_calls:
                    assistant_tool_calls.append(
                        {
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.function.name,
                                "arguments": call.function.arguments,
                            },
                        }
                    )
                openai_msg["tool_calls"] = assistant_tool_calls

            if msg.name and msg.role != "tool":
                openai_msg["name"] = msg.name

            openai_messages.append(openai_msg)
        return openai_messages


class OpenAIToolSerializer(ToolSerializer):
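    """Serialize ``ToolDefinition`` objects into OpenAI function-calling tool specs."""
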
    def serialize_tools(
        self, tools: list[ToolDefinition]
    ) -> list[dict[str, Any]] | None:
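        """Return OpenAI ``tools`` entries with sanitized, strict JSON schemas."""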
        if not tools:
            return None

        openai_tools = []
        for tool in tools:
            tool_dict = model_dump(tool)
            parameters = tool_dict.get("parameters")
            if parameters:
                tool_dict["parameters"] = sanitize_schema_for_llm(
                    parameters, api_type="openai"
                )
            tool_dict["strict"] = True
            openai_tools.append({"type": "function", "function": tool_dict})
        return openai_tools


class OpenAIResponseParser(ResponseParser):
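    """Parse OpenAI-compatible chat completion responses into ``ResponseData``."""
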
    def validate_response(self, response_json: dict[str, Any]) -> None:
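        """Raise ``LLMException`` if the response body carries an ``error`` object."""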
        if response_json.get("error"):
            error_info = response_json["error"]
            if isinstance(error_info, dict):
                error_message = error_info.get("message", "未知错误")
                error_code = error_info.get("code", "unknown")
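
                # Map provider error codes onto internal LLMErrorCode values.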
                error_code_mapping = {
                    "invalid_api_key": LLMErrorCode.API_KEY_INVALID,
                    "authentication_failed": LLMErrorCode.API_KEY_INVALID,
                    "insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED,
                    "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
                    "invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
                    "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "invalid_request_error": LLMErrorCode.INVALID_PARAMETER,
                    "invalid_parameter": LLMErrorCode.INVALID_PARAMETER,
                }

                llm_error_code = error_code_mapping.get(
                    error_code, LLMErrorCode.API_RESPONSE_INVALID
                )
            else:
                error_message = str(error_info)
                error_code = "unknown"
                llm_error_code = LLMErrorCode.API_RESPONSE_INVALID

            raise LLMException(
                f"API请求失败: {error_message}",
                code=llm_error_code,
                details={"api_error": error_info, "error_code": error_code},
            )

    def parse(self, response_json: dict[str, Any]) -> ResponseData:
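        """Parse a chat completion response into ``ResponseData``.

        Extracts text, reasoning content, tool calls and any inline base64
        image payloads, and raises on refusals.
        """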
        self.validate_response(response_json)

        choices = response_json.get("choices", [])
        if not choices:
            return ResponseData(text="", raw_response=response_json)

        choice = choices[0]
        message = choice.get("message", {})
        content = message.get("content", "")
        reasoning_content = message.get("reasoning_content", None)
        refusal = message.get("refusal")

        if refusal:
            raise LLMException(
                f"模型拒绝生成请求: {refusal}",
                code=LLMErrorCode.CONTENT_FILTERED,
                details={"refusal": refusal},
                recoverable=False,
            )

        if content:
            content = content.strip()

        images_payload: list[bytes | Path] = []
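        # Some OpenAI-compatible image endpoints return base64 image data as a
        # JSON string in `content`; OpenRouter-style responses put data URLs in
        # `message["images"]` instead.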
        if content and content.startswith("{") and content.endswith("}"):
            try:
                content_json = json.loads(content)
                if "b64_json" in content_json:
                    b64_str = content_json["b64_json"]
                    if isinstance(b64_str, str) and b64_str.startswith("data:"):
                        b64_str = b64_str.split(",", 1)[1]
                    decoded = base64.b64decode(b64_str)
                    images_payload.append(process_image_data(decoded))
                    content = "[图片已生成]"
                elif "data" in content_json and isinstance(content_json["data"], str):
                    b64_str = content_json["data"]
                    if b64_str.startswith("data:"):
                        b64_str = b64_str.split(",", 1)[1]
                    decoded = base64.b64decode(b64_str)
                    images_payload.append(process_image_data(decoded))
                    content = "[图片已生成]"

            except (json.JSONDecodeError, KeyError, binascii.Error):
                pass
        elif (
            "images" in message
            and isinstance(message["images"], list)
            and message["images"]
        ):
            for image_info in message["images"]:
                if image_info.get("type") == "image_url":
                    image_url_obj = image_info.get("image_url", {})
                    url_str = image_url_obj.get("url", "")
                    if url_str.startswith("data:image"):
                        try:
                            b64_data = url_str.split(",", 1)[1]
                            decoded = base64.b64decode(b64_data)
                            images_payload.append(process_image_data(decoded))
                        except (IndexError, binascii.Error) as e:
                            logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")

        if images_payload:
            content = content if content else "[图片已生成]"
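
        # Collect tool calls requested by the model, skipping malformed entries.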
        parsed_tool_calls: list[LLMToolCall] | None = None
        if message_tool_calls := message.get("tool_calls"):
            parsed_tool_calls = []
            for tc_data in message_tool_calls:
                try:
                    if tc_data.get("type") == "function":
                        parsed_tool_calls.append(
                            LLMToolCall(
                                id=tc_data["id"],
                                function=LLMToolFunction(
                                    name=tc_data["function"]["name"],
                                    arguments=tc_data["function"]["arguments"],
                                ),
                            )
                        )
                except KeyError as e:
                    logger.warning(
                        f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
                    )
                except Exception as e:
                    logger.warning(
                        f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
                    )
            if not parsed_tool_calls:
                parsed_tool_calls = None

        final_text = content if content is not None else ""
        if not final_text and parsed_tool_calls:
            final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"

        usage_info = response_json.get("usage")

        return ResponseData(
            text=final_text,
            tool_calls=parsed_tool_calls,
            usage_info=usage_info,
            images=images_payload if images_payload else None,
            raw_response=response_json,
            thought_text=reasoning_content,
        )