Commit df1a2429b1 by Rumio, 2025-12-08 05:26:37 +00:00 (committed by GitHub)
40 changed files with 6087 additions and 14376 deletions

File diff suppressed because it is too large


@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
imagehash = "^4.3.1"
cn2an = "^0.5.22"
dateparser = "^1.2.0"
bilireq = ">=0.2.10"
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
python-multipart = "^0.0.9"
aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -47,10 +46,10 @@ nonebot-plugin-uninfo = ">=0.7.3"
nonebot-plugin-waiter = "^0.8.1"
multidict = ">=6.0.0,!=6.3.2"
pydantic = ">=1.0.0, <2.0.0"
redis = { version = ">=5", optional = true }
asyncpg = { version = ">=0.20.0", optional = true }
alibabacloud-devops20210625 = "^5.0.2"
json_repair = "^0.54.0"
[tool.poetry.group.dev.dependencies]
nonebug = "^0.4"

File diff suppressed because it is too large


@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
imagehash = "^4.3.1"
cn2an = "^0.5.22"
dateparser = "^1.2.0"
bilireq = ">=0.2.10"
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
python-multipart = "^0.0.9"
aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -47,10 +46,10 @@ nonebot-plugin-uninfo = ">=0.7.3"
nonebot-plugin-waiter = "^0.8.1"
multidict = ">=6.0.0,!=6.3.2"
pydantic = ">=2.0.0, <3.0.0"
redis = { version = ">=5", optional = true }
asyncpg = { version = ">=0.20.0", optional = true }
alibabacloud-devops20210625 = "^5.0.2"
json_repair = "^0.54.0"
[tool.poetry.group.dev.dependencies]
nonebug = "^0.4"


@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
imagehash = "^4.3.1"
cn2an = "^0.5.22"
dateparser = "^1.2.0"
bilireq = ">=0.2.10"
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
python-multipart = "^0.0.9"
aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -46,6 +45,7 @@ tenacity = "^9.0.0"
nonebot-plugin-uninfo = ">=0.7.3"
nonebot-plugin-waiter = "^0.8.1"
multidict = ">=6.0.0,!=6.3.2"
json_repair = "^0.54.0"
redis = { version = ">=5", optional = true }
asyncpg = { version = ">=0.20.0", optional = true }


@@ -21,7 +21,6 @@ feedparser>=6.0.11,<7.0.0
ImageHash>=4.3.1,<5.0.0
cn2an>=0.5.22,<0.6.0
dateparser>=1.2.0,<2.0.0
bilireq>=0.2.10
python-jose[cryptography]>=3.3.0,<4.0.0
python-multipart>=0.0.9,<0.1.0
aiocache[redis]>=0.12.3,<0.13.0
@@ -32,6 +31,6 @@ nonebot-plugin-uninfo>=0.7.3
nonebot-plugin-waiter>=0.8.1,<0.9.0
multidict>=6.0.0,<7.0.0,!=6.3.2
alibabacloud-devops20210625>=5.0.2,<6.0.0
json_repair>=0.54.0,<0.55.0
redis>=5
asyncpg>=0.20.0


@@ -9,13 +9,15 @@ from .api import (
code,
create_image,
embed,
embed_documents,
embed_query,
generate,
generate_structured,
run_with_tools,
search,
)
from .config import (
CommonOverrides,
GenConfigBuilder,
LLMGenerationConfig,
register_llm_configs,
)
@@ -32,8 +34,8 @@ from .manager import (
list_model_identifiers,
set_global_default_model_name,
)
from .session import AI, AIConfig
from .tools import function_tool, tool_provider_manager
from .session import AI, AIConfig, MemoryProcessor, set_default_memory_backend
from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@@ -50,6 +52,11 @@ from .types import (
ToolMetadata,
UsageInfo,
)
from .types.models import (
GeminiCodeExecution,
GeminiGoogleSearch,
GeminiUrlContext,
)
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
__all__ = [
@@ -57,19 +64,26 @@ __all__ = [
"AIConfig",
"CommonOverrides",
"EmbeddingTaskType",
"GeminiCodeExecution",
"GeminiGoogleSearch",
"GeminiUrlContext",
"GenConfigBuilder",
"LLMContentPart",
"LLMErrorCode",
"LLMException",
"LLMGenerationConfig",
"LLMMessage",
"LLMResponse",
"MemoryProcessor",
"ModelDetail",
"ModelInfo",
"ModelName",
"ModelProvider",
"ResponseFormat",
"RunContext",
"TaskType",
"ToolCategory",
"ToolInvoker",
"ToolMetadata",
"UsageInfo",
"chat",
@@ -78,6 +92,8 @@ __all__ = [
"create_image",
"create_multimodal_message",
"embed",
"embed_documents",
"embed_query",
"function_tool",
"generate",
"generate_structured",
@@ -89,8 +105,8 @@ __all__ = [
"list_model_identifiers",
"message_to_unimessage",
"register_llm_configs",
"run_with_tools",
"search",
"set_default_memory_backend",
"set_global_default_model_name",
"tool_provider_manager",
"unimsg_to_llm_parts",


@@ -7,16 +7,18 @@ LLM 适配器模块
from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .openai import DeepSeekAdapter, OpenAIAdapter, OpenAIImageAdapter
LLMAdapterFactory.initialize()
__all__ = [
"BaseAdapter",
"DeepSeekAdapter",
"GeminiAdapter",
"LLMAdapterFactory",
"OpenAIAdapter",
"OpenAICompatAdapter",
"OpenAIImageAdapter",
"RequestData",
"ResponseData",
"get_adapter_for_api_type",


@@ -3,24 +3,26 @@ LLM 适配器基类和通用数据结构
"""
from abc import ABC, abstractmethod
import base64
import binascii
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any
import uuid
import httpx
from pydantic import BaseModel
from zhenxun.configs.path_config import TEMP_PATH
from zhenxun.services.log import logger
from ..types import LLMContentPart
from ..types.exceptions import LLMErrorCode, LLMException
from ..types.models import LLMToolCall
if TYPE_CHECKING:
from ..config.generation import LLMGenerationConfig
from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.protocols import ToolExecutable
from ..types import LLMMessage
from ..types.models import ToolChoice
class RequestData(BaseModel):
@@ -29,19 +31,23 @@ class RequestData(BaseModel):
url: str
headers: dict[str, str]
body: dict[str, Any]
files: dict[str, Any] | list[tuple[str, Any]] | None = None
class ResponseData(BaseModel):
"""响应数据封装 - 支持所有高级功能"""
text: str
images: list[bytes] | None = None
content_parts: list[LLMContentPart] | None = None
images: list[bytes | Path] | None = None
usage_info: dict[str, Any] | None = None
raw_response: dict[str, Any] | None = None
tool_calls: list[LLMToolCall] | None = None
code_executions: list[Any] | None = None
grounding_metadata: Any | None = None
cache_info: Any | None = None
thought_text: str | None = None
thought_signature: str | None = None
code_execution_results: list[dict[str, Any]] | None = None
search_results: list[dict[str, Any]] | None = None
@@ -50,9 +56,33 @@
citations: list[dict[str, Any]] | None = None
def process_image_data(image_data: bytes) -> bytes | Path:
"""
处理图片数据,若超过 2MB 则保存到临时目录,避免占用内存
"""
max_inline_size = 2 * 1024 * 1024
if len(image_data) > max_inline_size:
save_dir = TEMP_PATH / "llm"
save_dir.mkdir(parents=True, exist_ok=True)
file_name = f"{uuid.uuid4()}.png"
file_path = save_dir / file_name
file_path.write_bytes(image_data)
logger.info(
f"图片数据过大 ({len(image_data)} bytes),已保存到临时文件: {file_path}",
"LLMAdapter",
)
return file_path.resolve()
return image_data
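The 2 MB spill-to-disk threshold means every consumer of a `ResponseData.images` list must accept both `bytes` and `Path` entries. A minimal normalization helper (hypothetical, not part of this commit):

```python
from pathlib import Path


def load_image_bytes(image: bytes | Path) -> bytes:
    # process_image_data() returns raw bytes for small images (<= 2 MB)
    # and a Path under TEMP_PATH/llm for larger ones; normalize both here.
    if isinstance(image, Path):
        return image.read_bytes()
    return image
```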
class BaseAdapter(ABC):
"""LLM API适配器基类"""
@property
def log_sanitization_context(self) -> str:
"""用于日志清洗的上下文名称,默认 'default'"""
return "default"
@property
@abstractmethod
def api_type(self) -> str:
@@ -77,7 +107,7 @@ class BaseAdapter(ABC):
默认实现将简单请求转换为高级请求格式
子类可以重写此方法以提供特定的优化实现
"""
from ..types.content import LLMMessage
from ..types import LLMMessage
messages: list[LLMMessage] = []
@@ -107,8 +137,8 @@
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[Any] | None = None,
tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
) -> RequestData:
"""准备高级请求"""
pass
@@ -129,8 +159,7 @@
model: "LLMModel",
api_key: str,
texts: list[str],
task_type: "EmbeddingTaskType | str",
**kwargs: Any,
config: "LLMEmbeddingConfig",
) -> RequestData:
"""准备文本嵌入请求"""
pass
@@ -142,9 +171,16 @@
"""解析文本嵌入响应"""
pass
@abstractmethod
def convert_generation_config(
self, config: "LLMGenerationConfig", model: "LLMModel"
) -> dict[str, Any]:
"""将通用生成配置转换为特定API的参数字典"""
pass
def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
"""验证嵌入API响应"""
if "error" in response_json:
if response_json.get("error"):
error_info = response_json["error"]
msg = (
error_info.get("message", str(error_info))
@@ -179,158 +215,9 @@
)
return headers
def convert_messages_to_openai_format(
self, messages: list["LLMMessage"]
) -> list[dict[str, Any]]:
"""将LLMMessage转换为OpenAI格式 - 通用方法"""
openai_messages: list[dict[str, Any]] = []
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role}
if msg.role == "tool":
openai_msg["tool_call_id"] = msg.tool_call_id
openai_msg["name"] = msg.name
openai_msg["content"] = msg.content
else:
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
content_parts = []
for part in msg.content:
if part.type == "text":
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image":
content_parts.append(
{
"type": "image_url",
"image_url": {"url": part.image_source},
}
)
openai_msg["content"] = content_parts
if msg.role == "assistant" and msg.tool_calls:
assistant_tool_calls = []
for call in msg.tool_calls:
assistant_tool_calls.append(
{
"id": call.id,
"type": "function",
"function": {
"name": call.function.name,
"arguments": call.function.arguments,
},
}
)
openai_msg["tool_calls"] = assistant_tool_calls
if msg.name and msg.role != "tool":
openai_msg["name"] = msg.name
openai_messages.append(openai_msg)
return openai_messages
def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData:
"""解析OpenAI格式的响应 - 通用方法"""
self.validate_response(response_json)
try:
choices = response_json.get("choices", [])
if not choices:
logger.debug("OpenAI响应中没有choices可能为空回复或流结束。")
return ResponseData(text="", raw_response=response_json)
choice = choices[0]
message = choice.get("message", {})
content = message.get("content", "")
if content:
content = content.strip()
images_bytes: list[bytes] = []
if content and content.startswith("{") and content.endswith("}"):
try:
content_json = json.loads(content)
if "b64_json" in content_json:
images_bytes.append(base64.b64decode(content_json["b64_json"]))
content = "[图片已生成]"
elif "data" in content_json and isinstance(
content_json["data"], str
):
images_bytes.append(base64.b64decode(content_json["data"]))
content = "[图片已生成]"
except (json.JSONDecodeError, KeyError, binascii.Error):
pass
elif (
"images" in message
and isinstance(message["images"], list)
and message["images"]
):
image_info = message["images"][0]
if image_info.get("type") == "image_url":
image_url_obj = image_info.get("image_url", {})
url_str = image_url_obj.get("url", "")
if url_str.startswith("data:image/png;base64,"):
try:
b64_data = url_str.split(",", 1)[1]
images_bytes.append(base64.b64decode(b64_data))
content = content if content else "[图片已生成]"
except (IndexError, binascii.Error) as e:
logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")
parsed_tool_calls: list[LLMToolCall] | None = None
if message_tool_calls := message.get("tool_calls"):
from ..types.models import LLMToolFunction
parsed_tool_calls = []
for tc_data in message_tool_calls:
try:
if tc_data.get("type") == "function":
parsed_tool_calls.append(
LLMToolCall(
id=tc_data["id"],
function=LLMToolFunction(
name=tc_data["function"]["name"],
arguments=tc_data["function"]["arguments"],
),
)
)
except KeyError as e:
logger.warning(
f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(
f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
)
if not parsed_tool_calls:
parsed_tool_calls = None
final_text = content if content is not None else ""
if not final_text and parsed_tool_calls:
final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"
usage_info = response_json.get("usage")
return ResponseData(
text=final_text,
tool_calls=parsed_tool_calls,
usage_info=usage_info,
images=images_bytes if images_bytes else None,
raw_response=response_json,
)
except Exception as e:
logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def validate_response(self, response_json: dict[str, Any]) -> None:
"""验证API响应解析不同API的错误结构"""
if "error" in response_json:
if response_json.get("error"):
error_info = response_json["error"]
if isinstance(error_info, dict):
@@ -341,12 +228,15 @@
error_code_mapping = {
"invalid_api_key": LLMErrorCode.API_KEY_INVALID,
"authentication_failed": LLMErrorCode.API_KEY_INVALID,
"insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED,
"rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
"quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
"model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
"invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
"context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
"max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
"invalid_request_error": LLMErrorCode.INVALID_PARAMETER,
"invalid_parameter": LLMErrorCode.INVALID_PARAMETER,
}
llm_error_code = error_code_mapping.get(
@@ -405,23 +295,12 @@
) -> dict[str, Any]:
"""通用的配置应用逻辑"""
if config is not None:
return config.to_api_params(model.api_type, model.model_name)
return self.convert_generation_config(config, model)
if model._generation_config is not None:
return model._generation_config.to_api_params(
model.api_type, model.model_name
)
if model._generation_config:
return self.convert_generation_config(model._generation_config, model)
base_config = {}
if model.temperature is not None:
base_config["temperature"] = model.temperature
if model.max_tokens is not None:
if model.api_type == "gemini":
base_config["maxOutputTokens"] = model.max_tokens
else:
base_config["max_tokens"] = model.max_tokens
return base_config
return {}
def apply_config_override(
self,
@@ -434,12 +313,96 @@
body.update(config_params)
return body
def handle_http_error(self, response: httpx.Response) -> LLMException | None:
"""
处理 HTTP 错误响应
如果响应状态码表示成功 (200),返回 None,否则构造 LLMException 供外部捕获
"""
if response.status_code == 200:
return None
error_text = response.content.decode("utf-8", errors="ignore")
error_status = ""
error_msg = error_text
try:
error_json = json.loads(error_text)
if isinstance(error_json, dict) and "error" in error_json:
error_info = error_json["error"]
if isinstance(error_info, dict):
error_msg = error_info.get("message", error_msg)
raw_status = error_info.get("status") or error_info.get("code")
error_status = str(raw_status) if raw_status is not None else ""
elif error_info is not None:
error_msg = str(error_info)
error_status = error_msg
except Exception:
pass
status_upper = error_status.upper() if error_status else ""
text_upper = error_text.upper()
error_code = LLMErrorCode.API_REQUEST_FAILED
if response.status_code == 400:
if (
"FAILED_PRECONDITION" in status_upper
or "LOCATION IS NOT SUPPORTED" in text_upper
):
error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED
elif "INVALID_ARGUMENT" in status_upper:
error_code = LLMErrorCode.INVALID_PARAMETER
elif "API_KEY_INVALID" in text_upper or "API KEY NOT VALID" in text_upper:
error_code = LLMErrorCode.API_KEY_INVALID
else:
error_code = LLMErrorCode.INVALID_PARAMETER
elif response.status_code in [401, 403]:
if error_msg and (
"country" in error_msg.lower()
or "region" in error_msg.lower()
or "unsupported" in error_msg.lower()
):
error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED
elif "PERMISSION_DENIED" in status_upper:
error_code = LLMErrorCode.API_KEY_INVALID
else:
error_code = LLMErrorCode.API_KEY_INVALID
elif response.status_code == 404:
error_code = LLMErrorCode.MODEL_NOT_FOUND
elif response.status_code == 429:
if (
"RESOURCE_EXHAUSTED" in status_upper
or "INSUFFICIENT_QUOTA" in status_upper
or ("quota" in error_msg.lower() if error_msg else False)
):
error_code = LLMErrorCode.API_QUOTA_EXCEEDED
else:
error_code = LLMErrorCode.API_RATE_LIMITED
elif response.status_code in [402, 413]:
error_code = LLMErrorCode.API_QUOTA_EXCEEDED
elif response.status_code == 422:
error_code = LLMErrorCode.GENERATION_FAILED
elif response.status_code >= 500:
error_code = LLMErrorCode.API_TIMEOUT
return LLMException(
f"HTTP请求失败: {response.status_code} ({error_status or 'Unknown'})",
code=error_code,
details={
"status_code": response.status_code,
"api_status": error_status,
"response": error_text,
},
)
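Since handle_http_error returns the exception rather than raising it, the transport layer decides when to raise. A sketch of the intended call pattern, assuming an httpx-based client (function name and call site are illustrative, not the commit's actual code):

```python
import httpx


async def post_and_check(adapter: BaseAdapter, request: RequestData) -> dict:
    # Send the prepared request, then let the adapter map any non-200
    # response onto an LLMErrorCode before raising.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            request.url, headers=request.headers, json=request.body
        )
    if error := adapter.handle_http_error(response):
        raise error  # LLMException carrying status_code/api_status details
    return response.json()
```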
class OpenAICompatAdapter(BaseAdapter):
"""
处理所有 OpenAI 兼容 API 的通用适配器
"""
@property
def log_sanitization_context(self) -> str:
return "openai_request"
@abstractmethod
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 chat completions 的端点"""
@@ -481,8 +444,8 @@ class OpenAICompatAdapter(BaseAdapter):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[Any] | None = None,
tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
url = self.get_api_url(model, self.get_chat_endpoint(model))
@@ -494,28 +457,44 @@
"X-Title": "Zhenxun Bot",
}
)
openai_messages = self.convert_messages_to_openai_format(messages)
from .components.openai_components import OpenAIMessageConverter
converter = OpenAIMessageConverter()
openai_messages = converter.convert_messages(messages)
body = {
"model": model.model_name,
"messages": openai_messages,
}
openai_tools: list[dict[str, Any]] | None = None
executables: list[Any] = []
if tools:
for tool in tools:
if hasattr(tool, "get_definition"):
executables.append(tool)
if executables:
import asyncio
from zhenxun.utils.pydantic_compat import model_dump
definition_tasks = [
executable.get_definition() for executable in tools.values()
executable.get_definition() for executable in executables
]
openai_tools = await asyncio.gather(*definition_tasks)
if openai_tools:
body["tools"] = [
tool_defs = []
if definition_tasks:
tool_defs = await asyncio.gather(*definition_tasks)
if tool_defs:
openai_tools = [
{"type": "function", "function": model_dump(tool)}
for tool in openai_tools
for tool in tool_defs
]
if openai_tools:
body["tools"] = openai_tools
if tool_choice:
body["tool_choice"] = tool_choice
@@ -528,20 +507,21 @@
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
"""解析响应 - 直接使用组件化 ResponseParser"""
_ = model, is_advanced
return self.parse_openai_response(response_json)
from .components.openai_components import OpenAIResponseParser
parser = OpenAIResponseParser()
return parser.parse(response_json)
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
task_type: "EmbeddingTaskType | str",
**kwargs: Any,
config: "LLMEmbeddingConfig",
) -> RequestData:
"""准备嵌入请求 - OpenAI兼容格式"""
_ = task_type
url = self.get_api_url(model, self.get_embedding_endpoint(model))
headers = self.get_base_headers(api_key)
@@ -550,8 +530,14 @@
"input": texts,
}
if kwargs:
body.update(kwargs)
if config.output_dimensionality:
body["dimensions"] = config.output_dimensionality
if config.task_type:
body["task"] = config.task_type
if config.encoding_format and config.encoding_format != "float":
body["encoding_format"] = config.encoding_format
return RequestData(url=url, headers=headers, body=body)
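For orientation, with texts=["hello"] and an LLMEmbeddingConfig carrying output_dimensionality=256 and task_type="retrieval_document", the body assembled above comes out roughly as follows (model name is illustrative):

```python
body = {
    "model": "text-embedding-3-small",  # illustrative model name
    "input": ["hello"],
    "dimensions": 256,                # from config.output_dimensionality
    "task": "retrieval_document",     # from config.task_type
    # encoding_format is omitted because the "float" default is skipped
}
```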


@@ -0,0 +1 @@


@@ -0,0 +1,606 @@
import base64
import json
from pathlib import Path
from typing import Any
from zhenxun.services.llm.adapters.base import ResponseData, process_image_data
from zhenxun.services.llm.adapters.components.interfaces import (
ConfigMapper,
MessageConverter,
ResponseParser,
ToolSerializer,
)
from zhenxun.services.llm.config.generation import (
ImageAspectRatio,
LLMGenerationConfig,
ReasoningEffort,
ResponseFormat,
)
from zhenxun.services.llm.config.providers import get_gemini_safety_threshold
from zhenxun.services.llm.types import (
CodeExecutionOutcome,
LLMContentPart,
LLMMessage,
)
from zhenxun.services.llm.types.capabilities import ModelCapabilities
from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
from zhenxun.services.llm.types.models import (
LLMGroundingAttribution,
LLMGroundingMetadata,
LLMToolCall,
LLMToolFunction,
ModelDetail,
ToolDefinition,
)
from zhenxun.services.llm.utils import (
resolve_json_schema_refs,
sanitize_schema_for_llm,
)
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.pydantic_compat import model_copy, model_dump
class GeminiConfigMapper(ConfigMapper):
def map_config(
self,
config: LLMGenerationConfig,
model_detail: ModelDetail | None = None,
capabilities: ModelCapabilities | None = None,
) -> dict[str, Any]:
params: dict[str, Any] = {}
if config.core:
if config.core.temperature is not None:
params["temperature"] = config.core.temperature
if config.core.max_tokens is not None:
params["maxOutputTokens"] = config.core.max_tokens
if config.core.top_k is not None:
params["topK"] = config.core.top_k
if config.core.top_p is not None:
params["topP"] = config.core.top_p
if config.output:
if config.output.response_format == ResponseFormat.JSON:
params["responseMimeType"] = "application/json"
if config.output.response_schema:
params["responseJsonSchema"] = config.output.response_schema
elif config.output.response_mime_type is not None:
params["responseMimeType"] = config.output.response_mime_type
if (
config.output.response_schema is not None
and "responseJsonSchema" not in params
):
params["responseJsonSchema"] = config.output.response_schema
if config.output.response_modalities:
params["responseModalities"] = config.output.response_modalities
if config.tool_config:
fc_config: dict[str, Any] = {"mode": config.tool_config.mode}
if (
config.tool_config.allowed_function_names
and config.tool_config.mode == "ANY"
):
builtins = {"code_execution", "google_search", "google_map"}
user_funcs = [
name
for name in config.tool_config.allowed_function_names
if name not in builtins
]
if user_funcs:
fc_config["allowedFunctionNames"] = user_funcs
params["toolConfig"] = {"functionCallingConfig": fc_config}
if config.reasoning:
thinking_config = params.setdefault("thinkingConfig", {})
if config.reasoning.budget_tokens is not None:
if (
config.reasoning.budget_tokens <= 0
or config.reasoning.budget_tokens >= 1
):
budget_value = int(config.reasoning.budget_tokens)
else:
budget_value = int(config.reasoning.budget_tokens * 32768)
thinking_config["thinkingBudget"] = budget_value
elif config.reasoning.effort:
if config.reasoning.effort == ReasoningEffort.MEDIUM:
thinking_config["thinkingLevel"] = "HIGH"
else:
thinking_config["thinkingLevel"] = config.reasoning.effort.value
if config.reasoning.show_thoughts is not None:
thinking_config["includeThoughts"] = config.reasoning.show_thoughts
elif capabilities and capabilities.reasoning_visibility == "visible":
thinking_config["includeThoughts"] = True
if config.visual:
image_config: dict[str, Any] = {}
if config.visual.aspect_ratio is not None:
ar_value = (
config.visual.aspect_ratio.value
if isinstance(config.visual.aspect_ratio, ImageAspectRatio)
else config.visual.aspect_ratio
)
image_config["aspectRatio"] = ar_value
if config.visual.resolution:
image_config["imageSize"] = config.visual.resolution
if image_config:
params["imageConfig"] = image_config
if config.visual.media_resolution:
media_value = config.visual.media_resolution.upper()
if not media_value.startswith("MEDIA_RESOLUTION_"):
media_value = f"MEDIA_RESOLUTION_{media_value}"
params["mediaResolution"] = media_value
if config.custom_params:
mapped_custom = config.custom_params.copy()
if "max_tokens" in mapped_custom:
mapped_custom["maxOutputTokens"] = mapped_custom.pop("max_tokens")
if "top_k" in mapped_custom:
mapped_custom["topK"] = mapped_custom.pop("top_k")
if "top_p" in mapped_custom:
mapped_custom["topP"] = mapped_custom.pop("top_p")
for key in (
"code_execution_timeout",
"grounding_config",
"dynamic_threshold",
"user_location",
"reflexion_retries",
):
mapped_custom.pop(key, None)
for unsupported in [
"frequency_penalty",
"presence_penalty",
"repetition_penalty",
]:
if unsupported in mapped_custom:
mapped_custom.pop(unsupported)
params.update(mapped_custom)
safety_settings: list[dict[str, Any]] = []
if config.safety and config.safety.safety_settings:
for category, threshold in config.safety.safety_settings.items():
safety_settings.append({"category": category, "threshold": threshold})
else:
threshold = get_gemini_safety_threshold()
for category in [
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
]:
safety_settings.append({"category": category, "threshold": threshold})
if safety_settings:
params["safetySettings"] = safety_settings
return params
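The budget_tokens branch above overloads a single field: values strictly between 0 and 1 are read as a fraction of a 32768-token ceiling, while anything else passes through as an absolute budget (including -1, which GeminiAdapter.prepare_advanced_request later uses to auto-enable thinking). A standalone restatement of that convention:

```python
def resolve_thinking_budget(budget_tokens: float) -> int:
    # Mirrors GeminiConfigMapper: fractional values scale the 32768-token
    # ceiling; other values (including -1) pass through unchanged.
    if budget_tokens <= 0 or budget_tokens >= 1:
        return int(budget_tokens)
    return int(budget_tokens * 32768)


assert resolve_thinking_budget(0.25) == 8192
assert resolve_thinking_budget(2048) == 2048
assert resolve_thinking_budget(-1) == -1
```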
class GeminiMessageConverter(MessageConverter):
async def convert_part(self, part: LLMContentPart) -> dict[str, Any]:
"""将单个内容部分转换为 Gemini API 格式"""
def _get_gemini_resolution_dict() -> dict[str, Any]:
if part.media_resolution:
value = part.media_resolution.upper()
if not value.startswith("MEDIA_RESOLUTION_"):
value = f"MEDIA_RESOLUTION_{value}"
return {"media_resolution": {"level": value}}
return {}
if part.type == "text":
return {"text": part.text}
if part.type == "thought":
return {"text": part.thought_text, "thought": True}
if part.type == "image":
if not part.image_source:
raise ValueError("图像类型的内容必须包含image_source")
if part.is_image_base64():
base64_info = part.get_base64_data()
if base64_info:
mime_type, data = base64_info
payload = {"inlineData": {"mimeType": mime_type, "data": data}}
payload.update(_get_gemini_resolution_dict())
return payload
raise ValueError(f"无法解析Base64图像数据: {part.image_source[:50]}...")
if part.is_image_url():
logger.debug(f"正在为Gemini下载并编码URL图片: {part.image_source}")
try:
image_bytes = await AsyncHttpx.get_content(part.image_source)
mime_type = part.mime_type or "image/jpeg"
base64_data = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"inlineData": {"mimeType": mime_type, "data": base64_data}
}
payload.update(_get_gemini_resolution_dict())
return payload
except Exception as e:
logger.error(f"下载或编码URL图片失败: {e}", e=e)
raise ValueError(f"无法处理图片URL: {e}")
raise ValueError(f"不支持的图像源格式: {part.image_source[:50]}...")
if part.type == "video":
if not part.video_source:
raise ValueError("视频类型的内容必须包含video_source")
if part.video_source.startswith("data:"):
try:
header, data = part.video_source.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
payload = {"inlineData": {"mimeType": mime_type, "data": data}}
payload.update(_get_gemini_resolution_dict())
return payload
except (ValueError, IndexError):
raise ValueError(
f"无法解析Base64视频数据: {part.video_source[:50]}..."
)
raise ValueError(
"Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
)
if part.type == "audio":
if not part.audio_source:
raise ValueError("音频类型的内容必须包含audio_source")
if part.audio_source.startswith("data:"):
try:
header, data = part.audio_source.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
payload = {"inlineData": {"mimeType": mime_type, "data": data}}
payload.update(_get_gemini_resolution_dict())
return payload
except (ValueError, IndexError):
raise ValueError(
f"无法解析Base64音频数据: {part.audio_source[:50]}..."
)
raise ValueError(
"Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
)
if part.type == "file":
if part.file_uri:
payload = {
"fileData": {"mimeType": part.mime_type, "fileUri": part.file_uri}
}
payload.update(_get_gemini_resolution_dict())
return payload
if part.file_source:
file_name = (
part.metadata.get("name", "file") if part.metadata else "file"
)
return {"text": f"[文件: {file_name}]\n{part.file_source}"}
raise ValueError("文件类型的内容必须包含file_uri或file_source")
raise ValueError(f"不支持的内容类型: {part.type}")
async def convert_messages_async(
self, messages: list[LLMMessage]
) -> list[dict[str, Any]]:
gemini_contents: list[dict[str, Any]] = []
for msg in messages:
current_parts: list[dict[str, Any]] = []
if msg.role == "system":
continue
elif msg.role == "user":
if isinstance(msg.content, str):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(await self.convert_part(part_obj))
gemini_contents.append({"role": "user", "parts": current_parts})
elif msg.role == "assistant" or msg.role == "model":
if isinstance(msg.content, str) and msg.content:
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
part_dict = await self.convert_part(part_obj)
if "executableCode" in part_dict:
part_dict["executable_code"] = part_dict.pop(
"executableCode"
)
if "codeExecutionResult" in part_dict:
part_dict["code_execution_result"] = part_dict.pop(
"codeExecutionResult"
)
if (
part_obj.metadata
and "thought_signature" in part_obj.metadata
):
part_dict["thoughtSignature"] = part_obj.metadata[
"thought_signature"
]
current_parts.append(part_dict)
if msg.tool_calls:
for call in msg.tool_calls:
fc_part = {
"functionCall": {
"name": call.function.name,
"args": json.loads(call.function.arguments),
}
}
if call.thought_signature:
fc_part["thoughtSignature"] = call.thought_signature
current_parts.append(fc_part)
if current_parts:
gemini_contents.append({"role": "model", "parts": current_parts})
elif msg.role == "tool":
if not msg.name:
raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。")
try:
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
tool_result_obj = json.loads(content_str)
except json.JSONDecodeError:
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
tool_result_obj = {"raw_output": content_str}
if isinstance(tool_result_obj, list):
final_response_payload = {"result": tool_result_obj}
elif not isinstance(tool_result_obj, dict):
final_response_payload = {"result": tool_result_obj}
else:
final_response_payload = tool_result_obj
current_parts.append(
{
"functionResponse": {
"name": msg.name,
"response": final_response_payload,
}
}
)
if gemini_contents and gemini_contents[-1]["role"] == "function":
gemini_contents[-1]["parts"].extend(current_parts)
else:
gemini_contents.append({"role": "function", "parts": current_parts})
return gemini_contents
def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
raise NotImplementedError("Use convert_messages_async for Gemini")
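convert_messages is a deliberate stub because Gemini conversion may download URL images and must therefore be awaited. A usage sketch, assuming LLMMessage accepts role/content keywords (driver code is illustrative):

```python
import asyncio


async def demo() -> None:
    converter = GeminiMessageConverter()
    contents = await converter.convert_messages_async([
        LLMMessage(role="user", content="你好"),
        LLMMessage(role="assistant", content="Hello!"),
    ])
    # Assistant turns come back as role "model", with text wrapped in parts:
    # [{"role": "user", "parts": [{"text": "你好"}]},
    #  {"role": "model", "parts": [{"text": "Hello!"}]}]
    print(contents)


asyncio.run(demo())
```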
class GeminiToolSerializer(ToolSerializer):
def serialize_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
function_declarations: list[dict[str, Any]] = []
for tool_def in tools:
tool_copy = model_copy(tool_def)
tool_copy.parameters = resolve_json_schema_refs(tool_copy.parameters)
tool_copy.parameters = sanitize_schema_for_llm(
tool_copy.parameters, api_type="gemini"
)
function_declarations.append(model_dump(tool_copy))
return function_declarations
class GeminiResponseParser(ResponseParser):
def validate_response(self, response_json: dict[str, Any]) -> None:
if error := response_json.get("error"):
code = error.get("code")
message = error.get("message", "")
status = error.get("status")
details = error.get("details", [])
if code == 429 or status == "RESOURCE_EXHAUSTED":
is_quota = any(
d.get("reason") in ("QUOTA_EXCEEDED", "SERVICE_DISABLED")
for d in details
if isinstance(d, dict)
)
if is_quota or "quota" in message.lower():
raise LLMException(
f"Gemini配额耗尽: {message}",
code=LLMErrorCode.API_QUOTA_EXCEEDED,
details=error,
)
raise LLMException(
f"Gemini速率限制: {message}",
code=LLMErrorCode.API_RATE_LIMITED,
details=error,
)
if code == 400 or status in ("INVALID_ARGUMENT", "FAILED_PRECONDITION"):
raise LLMException(
f"Gemini参数错误: {message}",
code=LLMErrorCode.INVALID_PARAMETER,
details=error,
recoverable=False,
)
if prompt_feedback := response_json.get("promptFeedback"):
if block_reason := prompt_feedback.get("blockReason"):
raise LLMException(
f"内容被安全过滤: {block_reason}",
code=LLMErrorCode.CONTENT_FILTERED,
details={
"block_reason": block_reason,
"safety_ratings": prompt_feedback.get("safetyRatings"),
},
)
def parse(self, response_json: dict[str, Any]) -> ResponseData:
self.validate_response(response_json)
if "image_generation" in response_json and isinstance(
response_json["image_generation"], dict
):
candidates_source = response_json["image_generation"]
else:
candidates_source = response_json
candidates = candidates_source.get("candidates", [])
usage_info = response_json.get("usageMetadata")
if not candidates:
return ResponseData(text="", raw_response=response_json)
candidate = candidates[0]
thought_signature: str | None = None
content_data = candidate.get("content", {})
parts = content_data.get("parts", [])
text_content = ""
images_payload: list[bytes | Path] = []
parsed_tool_calls: list[LLMToolCall] | None = None
parsed_code_executions: list[dict[str, Any]] = []
content_parts: list[LLMContentPart] = []
thought_summary_parts: list[str] = []
answer_parts = []
for part in parts:
part_signature = part.get("thoughtSignature")
if part_signature and thought_signature is None:
thought_signature = part_signature
part_metadata: dict[str, Any] | None = None
if part_signature:
part_metadata = {"thought_signature": part_signature}
if part.get("thought") is True:
t_text = part.get("text", "")
thought_summary_parts.append(t_text)
content_parts.append(LLMContentPart.thought_part(t_text))
elif "text" in part:
answer_parts.append(part["text"])
c_part = LLMContentPart(
type="text", text=part["text"], metadata=part_metadata
)
content_parts.append(c_part)
elif "thoughtSummary" in part:
thought_summary_parts.append(part["thoughtSummary"])
content_parts.append(
LLMContentPart.thought_part(part["thoughtSummary"])
)
elif "inlineData" in part:
inline_data = part["inlineData"]
if "data" in inline_data:
decoded = base64.b64decode(inline_data["data"])
images_payload.append(process_image_data(decoded))
elif "functionCall" in part:
if parsed_tool_calls is None:
parsed_tool_calls = []
fc_data = part["functionCall"]
fc_sig = part_signature
try:
call_id = f"call_gemini_{len(parsed_tool_calls)}"
parsed_tool_calls.append(
LLMToolCall(
id=call_id,
thought_signature=fc_sig,
function=LLMToolFunction(
name=fc_data["name"],
arguments=json.dumps(fc_data["args"]),
),
)
)
except Exception as e:
logger.warning(
f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
)
elif "executableCode" in part:
exec_code = part["executableCode"]
lang = exec_code.get("language", "PYTHON")
code = exec_code.get("code", "")
content_parts.append(LLMContentPart.executable_code_part(lang, code))
answer_parts.append(f"\n[生成代码 ({lang})]:\n```python\n{code}\n```\n")
elif "codeExecutionResult" in part:
result = part["codeExecutionResult"]
outcome = result.get("outcome", CodeExecutionOutcome.OUTCOME_UNKNOWN)
output = result.get("output", "")
content_parts.append(
LLMContentPart.execution_result_part(outcome, output)
)
parsed_code_executions.append(result)
if outcome == CodeExecutionOutcome.OUTCOME_OK:
answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n")
else:
answer_parts.append(f"\n[代码执行失败 ({outcome})]:\n{output}\n")
full_answer = "".join(answer_parts).strip()
text_content = full_answer
final_thought_text = (
"\n\n".join(thought_summary_parts).strip()
if thought_summary_parts
else None
)
grounding_metadata_obj = None
if grounding_data := candidate.get("groundingMetadata"):
try:
sep_content = None
sep_field = grounding_data.get("searchEntryPoint")
if isinstance(sep_field, dict):
sep_content = sep_field.get("renderedContent")
attributions = []
if chunks := grounding_data.get("groundingChunks"):
for chunk in chunks:
if web := chunk.get("web"):
attributions.append(
LLMGroundingAttribution(
title=web.get("title"),
uri=web.get("uri"),
snippet=web.get("snippet"),
confidence_score=None,
)
)
grounding_metadata_obj = LLMGroundingMetadata(
web_search_queries=grounding_data.get("webSearchQueries"),
grounding_attributions=attributions or None,
search_suggestions=grounding_data.get("searchSuggestions"),
search_entry_point=sep_content,
map_widget_token=grounding_data.get("googleMapsWidgetContextToken"),
)
except Exception as e:
logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}")
return ResponseData(
text=text_content,
tool_calls=parsed_tool_calls,
code_executions=parsed_code_executions if parsed_code_executions else None,
content_parts=content_parts if content_parts else None,
images=images_payload if images_payload else None,
usage_info=usage_info,
raw_response=response_json,
grounding_metadata=grounding_metadata_obj,
thought_text=final_thought_text,
thought_signature=thought_signature,
)
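A minimal round trip through the parser, using a hand-written stand-in for a real Gemini payload, shows how thought parts are separated from answer text:

```python
parser = GeminiResponseParser()
data = parser.parse({
    "candidates": [{
        "content": {"parts": [
            {"text": "Let me check the sum.", "thought": True},
            {"text": "2 + 2 = 4"},
        ]}
    }],
    "usageMetadata": {"totalTokenCount": 12},
})
assert data.text == "2 + 2 = 4"                    # answer parts only
assert data.thought_text == "Let me check the sum."
assert data.usage_info == {"totalTokenCount": 12}
```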


@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Any
from zhenxun.services.llm.adapters.base import ResponseData
from zhenxun.services.llm.config.generation import LLMGenerationConfig
from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.llm.types.capabilities import ModelCapabilities
from zhenxun.services.llm.types.models import ModelDetail, ToolDefinition
class ConfigMapper(ABC):
@abstractmethod
def map_config(
self,
config: LLMGenerationConfig,
model_detail: ModelDetail | None = None,
capabilities: ModelCapabilities | None = None,
) -> dict[str, Any]:
"""将通用生成配置转换为特定 API 的参数字典"""
...
class MessageConverter(ABC):
@abstractmethod
def convert_messages(
self, messages: list[LLMMessage]
) -> list[dict[str, Any]] | dict[str, Any]:
"""将通用消息列表转换为特定 API 的消息格式"""
...
class ToolSerializer(ABC):
@abstractmethod
def serialize_tools(self, tools: list[ToolDefinition]) -> Any:
"""将通用工具定义转换为特定 API 的工具格式"""
...
class ResponseParser(ABC):
@abstractmethod
def parse(self, response_json: dict[str, Any]) -> ResponseData:
"""将特定 API 的响应解析为通用响应数据"""
...
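Taken together, one adapter is assembled from one implementation of each seam. A sketch of that pairing (a hypothetical bundle type; the commit's adapters instantiate their components inline instead):

```python
from dataclasses import dataclass


@dataclass
class AdapterComponents:
    # One backend = one implementation of each of the four interfaces.
    config_mapper: ConfigMapper
    message_converter: MessageConverter
    tool_serializer: ToolSerializer
    response_parser: ResponseParser
```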


@@ -0,0 +1,347 @@
import base64
import binascii
import json
from pathlib import Path
from typing import Any
from zhenxun.services.llm.adapters.base import ResponseData, process_image_data
from zhenxun.services.llm.adapters.components.interfaces import (
ConfigMapper,
MessageConverter,
ResponseParser,
ToolSerializer,
)
from zhenxun.services.llm.config.generation import (
ImageAspectRatio,
LLMGenerationConfig,
ResponseFormat,
StructuredOutputStrategy,
)
from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.llm.types.capabilities import ModelCapabilities
from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
from zhenxun.services.llm.types.models import (
LLMToolCall,
LLMToolFunction,
ModelDetail,
ToolDefinition,
)
from zhenxun.services.llm.utils import sanitize_schema_for_llm
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump
class OpenAIConfigMapper(ConfigMapper):
def __init__(self, api_type: str = "openai"):
self.api_type = api_type
def map_config(
self,
config: LLMGenerationConfig,
model_detail: ModelDetail | None = None,
capabilities: ModelCapabilities | None = None,
) -> dict[str, Any]:
params: dict[str, Any] = {}
strategy = config.output.structured_output_strategy if config.output else None
if strategy is None:
strategy = (
StructuredOutputStrategy.TOOL_CALL
if self.api_type == "deepseek"
else StructuredOutputStrategy.NATIVE
)
if config.core:
if config.core.temperature is not None:
params["temperature"] = config.core.temperature
if config.core.max_tokens is not None:
params["max_tokens"] = config.core.max_tokens
if config.core.top_k is not None:
params["top_k"] = config.core.top_k
if config.core.top_p is not None:
params["top_p"] = config.core.top_p
if config.core.frequency_penalty is not None:
params["frequency_penalty"] = config.core.frequency_penalty
if config.core.presence_penalty is not None:
params["presence_penalty"] = config.core.presence_penalty
if config.core.stop is not None:
params["stop"] = config.core.stop
if config.core.repetition_penalty is not None:
if self.api_type == "openai":
logger.warning("OpenAI官方API不支持repetition_penalty参数已忽略")
else:
params["repetition_penalty"] = config.core.repetition_penalty
if config.reasoning and config.reasoning.effort:
params["reasoning_effort"] = config.reasoning.effort.value.lower()
if config.output:
if isinstance(config.output.response_format, dict):
params["response_format"] = config.output.response_format
elif (
config.output.response_format == ResponseFormat.JSON
and strategy == StructuredOutputStrategy.NATIVE
):
if config.output.response_schema:
sanitized = sanitize_schema_for_llm(
config.output.response_schema, api_type="openai"
)
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "structured_response",
"schema": sanitized,
"strict": True,
},
}
else:
params["response_format"] = {"type": "json_object"}
if config.tool_config:
mode = config.tool_config.mode
if mode == "NONE":
params["tool_choice"] = "none"
elif mode == "AUTO":
params["tool_choice"] = "auto"
elif mode == "ANY":
params["tool_choice"] = "required"
if config.visual and config.visual.aspect_ratio:
size_map = {
ImageAspectRatio.SQUARE: "1024x1024",
ImageAspectRatio.LANDSCAPE_16_9: "1792x1024",
ImageAspectRatio.PORTRAIT_9_16: "1024x1792",
}
ar = config.visual.aspect_ratio
if isinstance(ar, ImageAspectRatio):
mapped_size = size_map.get(ar)
if mapped_size:
params["size"] = mapped_size
elif isinstance(ar, str):
params["size"] = ar
if config.custom_params:
mapped_custom = config.custom_params.copy()
if "repetition_penalty" in mapped_custom and self.api_type == "openai":
mapped_custom.pop("repetition_penalty")
if "stop" in mapped_custom:
stop_value = mapped_custom["stop"]
if isinstance(stop_value, str):
mapped_custom["stop"] = [stop_value]
params.update(mapped_custom)
return params
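When a response_schema is present and the NATIVE strategy applies, the mapper emits OpenAI's strict structured-output envelope. The resulting fragment has this shape (schema contents illustrative):

```python
# Shape of params["response_format"] for a schema-bearing JSON request:
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "structured_response",
        "schema": {  # sanitized via sanitize_schema_for_llm(..., "openai")
            "type": "object",
            "properties": {"answer": {"type": "string"}},
        },
        "strict": True,
    },
}
```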
class OpenAIMessageConverter(MessageConverter):
def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
openai_messages: list[dict[str, Any]] = []
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role}
if msg.role == "tool":
openai_msg["tool_call_id"] = msg.tool_call_id
openai_msg["name"] = msg.name
openai_msg["content"] = msg.content
else:
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
content_parts = []
for part in msg.content:
if part.type == "text":
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image":
content_parts.append(
{
"type": "image_url",
"image_url": {"url": part.image_source},
}
)
openai_msg["content"] = content_parts
if msg.role == "assistant" and msg.tool_calls:
assistant_tool_calls = []
for call in msg.tool_calls:
assistant_tool_calls.append(
{
"id": call.id,
"type": "function",
"function": {
"name": call.function.name,
"arguments": call.function.arguments,
},
}
)
openai_msg["tool_calls"] = assistant_tool_calls
if msg.name and msg.role != "tool":
openai_msg["name"] = msg.name
openai_messages.append(openai_msg)
return openai_messages
class OpenAIToolSerializer(ToolSerializer):
def serialize_tools(
self, tools: list[ToolDefinition]
) -> list[dict[str, Any]] | None:
if not tools:
return None
openai_tools = []
for tool in tools:
tool_dict = model_dump(tool)
parameters = tool_dict.get("parameters")
if parameters:
tool_dict["parameters"] = sanitize_schema_for_llm(
parameters, api_type="openai"
)
tool_dict["strict"] = True
openai_tools.append({"type": "function", "function": tool_dict})
return openai_tools
class OpenAIResponseParser(ResponseParser):
def validate_response(self, response_json: dict[str, Any]) -> None:
if response_json.get("error"):
error_info = response_json["error"]
if isinstance(error_info, dict):
error_message = error_info.get("message", "未知错误")
error_code = error_info.get("code", "unknown")
error_code_mapping = {
"invalid_api_key": LLMErrorCode.API_KEY_INVALID,
"authentication_failed": LLMErrorCode.API_KEY_INVALID,
"insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED,
"rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
"quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
"model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
"invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
"context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
"max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
"invalid_request_error": LLMErrorCode.INVALID_PARAMETER,
"invalid_parameter": LLMErrorCode.INVALID_PARAMETER,
}
llm_error_code = error_code_mapping.get(
error_code, LLMErrorCode.API_RESPONSE_INVALID
)
else:
error_message = str(error_info)
error_code = "unknown"
llm_error_code = LLMErrorCode.API_RESPONSE_INVALID
raise LLMException(
f"API请求失败: {error_message}",
code=llm_error_code,
details={"api_error": error_info, "error_code": error_code},
)
def parse(self, response_json: dict[str, Any]) -> ResponseData:
self.validate_response(response_json)
choices = response_json.get("choices", [])
if not choices:
return ResponseData(text="", raw_response=response_json)
choice = choices[0]
message = choice.get("message", {})
content = message.get("content", "")
reasoning_content = message.get("reasoning_content", None)
refusal = message.get("refusal")
if refusal:
raise LLMException(
f"模型拒绝生成请求: {refusal}",
code=LLMErrorCode.CONTENT_FILTERED,
details={"refusal": refusal},
recoverable=False,
)
if content:
content = content.strip()
images_payload: list[bytes | Path] = []
if content and content.startswith("{") and content.endswith("}"):
try:
content_json = json.loads(content)
if "b64_json" in content_json:
b64_str = content_json["b64_json"]
if isinstance(b64_str, str) and b64_str.startswith("data:"):
b64_str = b64_str.split(",", 1)[1]
decoded = base64.b64decode(b64_str)
images_payload.append(process_image_data(decoded))
content = "[图片已生成]"
elif "data" in content_json and isinstance(content_json["data"], str):
b64_str = content_json["data"]
if b64_str.startswith("data:"):
b64_str = b64_str.split(",", 1)[1]
decoded = base64.b64decode(b64_str)
images_payload.append(process_image_data(decoded))
content = "[图片已生成]"
except (json.JSONDecodeError, KeyError, binascii.Error):
pass
elif (
"images" in message
and isinstance(message["images"], list)
and message["images"]
):
for image_info in message["images"]:
if image_info.get("type") == "image_url":
image_url_obj = image_info.get("image_url", {})
url_str = image_url_obj.get("url", "")
if url_str.startswith("data:image"):
try:
b64_data = url_str.split(",", 1)[1]
decoded = base64.b64decode(b64_data)
images_payload.append(process_image_data(decoded))
except (IndexError, binascii.Error) as e:
logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")
if images_payload:
content = content if content else "[图片已生成]"
parsed_tool_calls: list[LLMToolCall] | None = None
if message_tool_calls := message.get("tool_calls"):
parsed_tool_calls = []
for tc_data in message_tool_calls:
try:
if tc_data.get("type") == "function":
parsed_tool_calls.append(
LLMToolCall(
id=tc_data["id"],
function=LLMToolFunction(
name=tc_data["function"]["name"],
arguments=tc_data["function"]["arguments"],
),
)
)
except KeyError as e:
logger.warning(
f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(
f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
)
if not parsed_tool_calls:
parsed_tool_calls = None
final_text = content if content is not None else ""
if not final_text and parsed_tool_calls:
final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"
usage_info = response_json.get("usage")
return ResponseData(
text=final_text,
tool_calls=parsed_tool_calls,
usage_info=usage_info,
images=images_payload if images_payload else None,
raw_response=response_json,
thought_text=reasoning_content,
)
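A quick sanity check of the parser against a hand-written completion payload (not a captured API response):

```python
parser = OpenAIResponseParser()
data = parser.parse({
    "choices": [{"message": {"content": "  Hi there.  "}}],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3},
})
assert data.text == "Hi there."  # surrounding whitespace is stripped
assert data.tool_calls is None   # no tool_calls key in the message
assert data.usage_info == {"prompt_tokens": 5, "completion_tokens": 3}
```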


@@ -2,10 +2,17 @@
LLM 适配器工厂类
"""
from typing import ClassVar
import fnmatch
from typing import TYPE_CHECKING, Any, ClassVar
from ..types.exceptions import LLMErrorCode, LLMException
from .base import BaseAdapter
from ..types.models import ToolChoice
from .base import BaseAdapter, RequestData, ResponseData
if TYPE_CHECKING:
from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
from ..service import LLMModel
from ..types import LLMMessage
class LLMAdapterFactory:
@@ -21,10 +28,13 @@
return
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .openai import DeepSeekAdapter, OpenAIAdapter, OpenAIImageAdapter
cls.register_adapter(OpenAIAdapter())
cls.register_adapter(DeepSeekAdapter())
cls.register_adapter(GeminiAdapter())
cls.register_adapter(SmartAdapter())
cls.register_adapter(OpenAIImageAdapter())
@classmethod
def register_adapter(cls, adapter: BaseAdapter) -> None:
@@ -74,3 +84,100 @@ def get_adapter_for_api_type(api_type: str) -> BaseAdapter:
def register_adapter(adapter: BaseAdapter) -> None:
"""注册新的适配器"""
LLMAdapterFactory.register_adapter(adapter)
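register_adapter is the public extension point. A hedged sketch of wiring in a third-party OpenAI-compatible backend (provider name and endpoints are invented; the exact abstract surface of OpenAICompatAdapter may require further overrides):

```python
from .base import OpenAICompatAdapter


class ExampleAdapter(OpenAICompatAdapter):
    # Hypothetical adapter: for an OpenAI-compatible provider, only the
    # identity and endpoints need to differ.
    @property
    def api_type(self) -> str:
        return "example"

    @property
    def supported_api_types(self) -> list[str]:
        return ["example"]

    def get_chat_endpoint(self, model: "LLMModel") -> str:
        return "/v1/chat/completions"

    def get_embedding_endpoint(self, model: "LLMModel") -> str:
        return "/v1/embeddings"


register_adapter(ExampleAdapter())
```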
class SmartAdapter(BaseAdapter):
"""
智能路由适配器
本身不处理序列化,而是根据规则委托给 OpenAIAdapter 或 GeminiAdapter
"""
@property
def log_sanitization_context(self) -> str:
return "openai_request"
_ROUTING_RULES: ClassVar[list[tuple[str, str]]] = [
("*nano-banana*", "gemini"),
("*gemini*", "gemini"),
]
_DEFAULT_API_TYPE: ClassVar[str] = "openai"
def __init__(self):
self._adapter_cache: dict[str, BaseAdapter] = {}
@property
def api_type(self) -> str:
return "smart"
@property
def supported_api_types(self) -> list[str]:
return ["smart"]
def _get_delegate_adapter(self, model: "LLMModel") -> BaseAdapter:
"""
核心路由逻辑,决定使用哪个适配器 (带缓存)
"""
if model.model_detail.api_type:
return get_adapter_for_api_type(model.model_detail.api_type)
model_name = model.model_name
if model_name in self._adapter_cache:
return self._adapter_cache[model_name]
target_api_type = self._DEFAULT_API_TYPE
model_name_lower = model_name.lower()
for pattern, api_type in self._ROUTING_RULES:
if fnmatch.fnmatch(model_name_lower, pattern):
target_api_type = api_type
break
adapter = get_adapter_for_api_type(target_api_type)
self._adapter_cache[model_name] = adapter
return adapter
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[Any] | None = None,
tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
) -> RequestData:
adapter = self._get_delegate_adapter(model)
return await adapter.prepare_advanced_request(
model, api_key, messages, config, tools, tool_choice
)
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
adapter = self._get_delegate_adapter(model)
return adapter.parse_response(model, response_json, is_advanced)
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
config: "LLMEmbeddingConfig",
) -> RequestData:
adapter = self._get_delegate_adapter(model)
return adapter.prepare_embedding_request(model, api_key, texts, config)
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
return get_adapter_for_api_type("openai").parse_embedding_response(
response_json
)
def convert_generation_config(
self, config: "LLMGenerationConfig", model: "LLMModel"
) -> dict[str, Any]:
adapter = self._get_delegate_adapter(model)
return adapter.convert_generation_config(config, model)
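The routing table is easy to check in isolation: the first matching glob wins, and unmatched names fall back to "openai" (an explicit model_detail.api_type bypasses matching entirely). A standalone restatement:

```python
import fnmatch

RULES = [("*nano-banana*", "gemini"), ("*gemini*", "gemini")]


def route(model_name: str, default: str = "openai") -> str:
    # Mirrors SmartAdapter._get_delegate_adapter's matching loop.
    name = model_name.lower()
    for pattern, api_type in RULES:
        if fnmatch.fnmatch(name, pattern):
            return api_type
    return default


assert route("gemini-2.0-flash") == "gemini"
assert route("Nano-Banana-Pro") == "gemini"
assert route("gpt-4o-mini") == "openai"
```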


@@ -2,27 +2,35 @@
Gemini API 适配器
"""
import base64
from typing import TYPE_CHECKING, Any
from zhenxun.services.log import logger
from ..config.generation import ResponseFormat
from ..types import LLMContentPart
from ..types.exceptions import LLMErrorCode, LLMException
from ..utils import sanitize_schema_for_llm
from ..types.models import BasePlatformTool, ToolChoice
from .base import BaseAdapter, RequestData, ResponseData
from .components.gemini_components import (
GeminiConfigMapper,
GeminiMessageConverter,
GeminiResponseParser,
GeminiToolSerializer,
)
if TYPE_CHECKING:
from ..config.generation import LLMGenerationConfig
from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMToolCall
from ..types.protocols import ToolExecutable
from ..types import LLMMessage
class GeminiAdapter(BaseAdapter):
"""Gemini API 适配器"""
@property
def log_sanitization_context(self) -> str:
return "gemini_request"
@property
def api_type(self) -> str:
return "gemini"
@@ -47,110 +55,75 @@
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[Any] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
) -> RequestData:
"""准备高级请求"""
effective_config = config if config is not None else model._generation_config
if tools:
from ..types.models import GeminiUrlContext
context_urls: list[str] = []
for tool in tools:
if isinstance(tool, GeminiUrlContext):
context_urls.extend(tool.urls)
if context_urls and messages:
last_msg = messages[-1]
if last_msg.role == "user":
url_text = "\n\n[Context URLs]:\n" + "\n".join(context_urls)
if isinstance(last_msg.content, str):
last_msg.content += url_text
elif isinstance(last_msg.content, list):
last_msg.content.append(LLMContentPart.text_part(url_text))
has_function_tools = False
if tools:
has_function_tools = any(hasattr(tool, "get_definition") for tool in tools)
is_structured = False
if effective_config and effective_config.output:
if (
effective_config.output.response_schema
or effective_config.output.response_format == ResponseFormat.JSON
or effective_config.output.response_mime_type == "application/json"
):
is_structured = True
if (has_function_tools or is_structured) and effective_config:
if effective_config.reasoning is None:
from ..config.generation import ReasoningConfig
effective_config.reasoning = ReasoningConfig()
if (
effective_config.reasoning.budget_tokens is None
and effective_config.reasoning.effort is None
):
reason_desc = "工具调用" if has_function_tools else "结构化输出"
logger.debug(
f"检测到{reason_desc},自动为模型 {model.model_name} 开启思维链增强"
)
effective_config.reasoning.budget_tokens = -1
endpoint = self._get_gemini_endpoint(model, effective_config)
url = self.get_api_url(model, endpoint)
headers = self.get_base_headers(api_key)
gemini_contents: list[dict[str, Any]] = []
converter = GeminiMessageConverter()
system_instruction_parts: list[dict[str, Any]] | None = None
for msg in messages:
current_parts: list[dict[str, Any]] = []
if msg.role == "system":
if isinstance(msg.content, str):
system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list):
system_instruction_parts = [
await part.convert_for_api_async("gemini")
for part in msg.content
await converter.convert_part(part) for part in msg.content
]
continue
elif msg.role == "user":
if isinstance(msg.content, str):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
gemini_contents.append({"role": "user", "parts": current_parts})
elif msg.role == "assistant" or msg.role == "model":
if isinstance(msg.content, str) and msg.content:
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
if msg.tool_calls:
import json
for call in msg.tool_calls:
current_parts.append(
{
"functionCall": {
"name": call.function.name,
"args": json.loads(call.function.arguments),
}
}
)
if current_parts:
gemini_contents.append({"role": "model", "parts": current_parts})
elif msg.role == "tool":
if not msg.name:
raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。")
import json
try:
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
tool_result_obj = json.loads(content_str)
except json.JSONDecodeError:
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
logger.warning(
f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. "
f"包装为原始字符串。"
)
tool_result_obj = {"raw_output": content_str}
if isinstance(tool_result_obj, list):
logger.debug(
f"工具 '{msg.name}' 的返回结果是列表,"
f"正在为Gemini API包装为JSON对象。"
)
final_response_payload = {"result": tool_result_obj}
elif not isinstance(tool_result_obj, dict):
final_response_payload = {"result": tool_result_obj}
else:
final_response_payload = tool_result_obj
current_parts.append(
{
"functionResponse": {
"name": msg.name,
"response": final_response_payload,
}
}
)
gemini_contents.append({"role": "function", "parts": current_parts})
gemini_contents = await converter.convert_messages_async(messages)
body: dict[str, Any] = {"contents": gemini_contents}
@@ -158,75 +131,78 @@
body["systemInstruction"] = {"parts": system_instruction_parts}
all_tools_for_request = []
has_user_functions = False
if tools:
import asyncio
from ..types.protocols import ToolExecutable
from zhenxun.utils.pydantic_compat import model_dump
function_tools: list[ToolExecutable] = []
gemini_tools_dict: dict[str, Any] = {}
definition_tasks = [
executable.get_definition() for executable in tools.values()
]
tool_definitions = await asyncio.gather(*definition_tasks)
for tool in tools:
if isinstance(tool, BasePlatformTool):
declaration = tool.get_tool_declaration()
if declaration:
gemini_tools_dict.update(declaration)
elif hasattr(tool, "get_definition"):
function_tools.append(tool)
function_declarations = []
for tool_def in tool_definitions:
tool_def.parameters = sanitize_schema_for_llm(
tool_def.parameters, api_type="gemini"
)
function_declarations.append(model_dump(tool_def))
if function_tools:
import asyncio
if function_declarations:
all_tools_for_request.append(
{"functionDeclarations": function_declarations}
)
definition_tasks = [
executable.get_definition() for executable in function_tools
]
tool_definitions = await asyncio.gather(*definition_tasks)
if effective_config:
if getattr(effective_config, "enable_grounding", False):
has_explicit_gs_tool = any(
"googleSearch" in tool_item for tool_item in all_tools_for_request
)
if not has_explicit_gs_tool:
all_tools_for_request.append({"googleSearch": {}})
logger.debug("隐式启用 Google Search 工具进行信息来源关联。")
serializer = GeminiToolSerializer()
function_declarations = serializer.serialize_tools(tool_definitions)
if getattr(effective_config, "enable_code_execution", False):
has_explicit_ce_tool = any(
"codeExecution" in tool_item for tool_item in all_tools_for_request
)
if not has_explicit_ce_tool:
all_tools_for_request.append({"codeExecution": {}})
logger.debug("隐式启用代码执行工具。")
if function_declarations:
gemini_tools_dict["functionDeclarations"] = function_declarations
has_user_functions = True
if gemini_tools_dict:
all_tools_for_request.append(gemini_tools_dict)
if all_tools_for_request:
body["tools"] = all_tools_for_request
final_tool_choice = tool_choice
if final_tool_choice is None and effective_config:
final_tool_choice = getattr(effective_config, "tool_choice", None)
tool_config_updates: dict[str, Any] = {}
if (
effective_config
and effective_config.custom_params
and "user_location" in effective_config.custom_params
):
tool_config_updates["retrievalConfig"] = {
"latLng": effective_config.custom_params["user_location"]
}
if final_tool_choice:
if isinstance(final_tool_choice, str):
mode_upper = final_tool_choice.upper()
if mode_upper in ["AUTO", "NONE", "ANY"]:
body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
else:
body["toolConfig"] = self._convert_tool_choice_to_gemini(
final_tool_choice
)
else:
body["toolConfig"] = self._convert_tool_choice_to_gemini(
final_tool_choice
if tool_config_updates:
body.setdefault("toolConfig", {}).update(tool_config_updates)
converted_params: dict[str, Any] = {}
if effective_config:
converted_params = self.convert_generation_config(effective_config, model)
if converted_params:
if "toolConfig" in converted_params:
tool_config_payload = converted_params.pop("toolConfig")
fc_config = tool_config_payload.get("functionCallingConfig")
should_apply_fc = has_user_functions or (
fc_config and fc_config.get("mode") == "NONE"
)
if should_apply_fc:
body.setdefault("toolConfig", {}).update(tool_config_payload)
elif fc_config and fc_config.get("mode") != "AUTO":
logger.debug(
"Gemini: 忽略针对纯内置工具的 functionCallingConfig (API限制)"
)
final_generation_config = self._build_gemini_generation_config(
model, effective_config
)
if final_generation_config:
body["generationConfig"] = final_generation_config
if "safetySettings" in converted_params:
body["safetySettings"] = converted_params.pop("safetySettings")
safety_settings = self._build_safety_settings(effective_config)
if safety_settings:
body["safetySettings"] = safety_settings
if converted_params:
body["generationConfig"] = converted_params
return RequestData(url=url, headers=headers, body=body)
@ -242,317 +218,56 @@ class GeminiAdapter(BaseAdapter):
def _get_gemini_endpoint(
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> str:
"""根据配置选择Gemini API端点"""
if config:
if getattr(config, "enable_code_execution", False):
return f"/v1beta/models/{model.model_name}:generateContent"
if getattr(config, "enable_grounding", False):
return f"/v1beta/models/{model.model_name}:generateContent"
"""返回Gemini generateContent 端点"""
return f"/v1beta/models/{model.model_name}:generateContent"
def _convert_tool_choice_to_gemini(
self, tool_choice_value: str | dict[str, Any]
) -> dict[str, Any]:
"""转换工具选择策略为Gemini格式"""
if isinstance(tool_choice_value, str):
mode_upper = tool_choice_value.upper()
if mode_upper in ["AUTO", "NONE", "ANY"]:
return {"functionCallingConfig": {"mode": mode_upper}}
else:
logger.warning(
f"不支持的 tool_choice 字符串值: '{tool_choice_value}'"
f"回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
elif isinstance(tool_choice_value, dict):
if (
tool_choice_value.get("type") == "function"
and "function" in tool_choice_value
):
func_name = tool_choice_value["function"].get("name")
if func_name:
return {
"functionCallingConfig": {
"mode": "ANY",
"allowedFunctionNames": [func_name],
}
}
else:
logger.warning(
f"tool_choice dict 中的函数名无效: {tool_choice_value}"
f"回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
elif "functionCallingConfig" in tool_choice_value:
return {
"functionCallingConfig": tool_choice_value["functionCallingConfig"]
}
else:
logger.warning(
f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
logger.warning(
f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
def _build_gemini_generation_config(
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> dict[str, Any]:
"""构建Gemini生成配置"""
effective_config = config if config is not None else model._generation_config
if not effective_config:
return {}
generation_config = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
)
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
return generation_config
def _build_safety_settings(
self, config: "LLMGenerationConfig | None" = None
) -> list[dict[str, Any]] | None:
"""构建安全设置"""
if not config:
return None
safety_settings = []
safety_categories = [
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
]
custom_safety_settings = getattr(config, "safety_settings", None)
if custom_safety_settings:
for category, threshold in custom_safety_settings.items():
safety_settings.append({"category": category, "threshold": threshold})
else:
from ..config.providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
for category in safety_categories:
safety_settings.append({"category": category, "threshold": threshold})
return safety_settings if safety_settings else None
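# 结果示意(未提供自定义 safety_settings 时;阈值来自全局配置,此处取值仅为假设):
# [
#     {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
#     {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
#     ...
# ]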
def validate_response(self, response_json: dict[str, Any]) -> None:
"""验证 Gemini API 响应,增加对 promptFeedback 的检查"""
super().validate_response(response_json)
if prompt_feedback := response_json.get("promptFeedback"):
if block_reason := prompt_feedback.get("blockReason"):
logger.warning(
f"Gemini 内容因 promptFeedback 被安全过滤: {block_reason}"
)
raise LLMException(
f"内容被安全过滤: {block_reason}",
code=LLMErrorCode.CONTENT_FILTERED,
details={
"block_reason": block_reason,
"safety_ratings": prompt_feedback.get("safetyRatings"),
},
)
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析API响应"""
return self._parse_response(model, response_json, is_advanced)
def _parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析 Gemini API 响应"""
_ = is_advanced
self.validate_response(response_json)
try:
if "image_generation" in response_json and isinstance(
response_json["image_generation"], dict
):
candidates_source = response_json["image_generation"]
else:
candidates_source = response_json
candidates = candidates_source.get("candidates", [])
usage_info = response_json.get("usageMetadata")
if not candidates:
logger.debug("Gemini响应中没有candidates。")
return ResponseData(text="", raw_response=response_json)
candidate = candidates[0]
if candidate.get("finishReason") in [
"RECITATION",
"OTHER",
] and not candidate.get("content"):
logger.warning(
f"Gemini candidate finished with reason "
f"'{candidate.get('finishReason')}' and no content."
)
return ResponseData(
text="",
raw_response=response_json,
usage_info=response_json.get("usageMetadata"),
)
content_data = candidate.get("content", {})
parts = content_data.get("parts", [])
text_content = ""
images_bytes: list[bytes] = []
parsed_tool_calls: list["LLMToolCall"] | None = None
thought_summary_parts = []
answer_parts = []
for part in parts:
if "text" in part:
answer_parts.append(part["text"])
elif "thought" in part:
thought_summary_parts.append(part["thought"])
elif "thoughtSummary" in part:
thought_summary_parts.append(part["thoughtSummary"])
elif "inlineData" in part:
inline_data = part["inlineData"]
if "data" in inline_data:
images_bytes.append(base64.b64decode(inline_data["data"]))
elif "functionCall" in part:
if parsed_tool_calls is None:
parsed_tool_calls = []
fc_data = part["functionCall"]
try:
import json
from ..types.models import LLMToolCall, LLMToolFunction
call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
parsed_tool_calls.append(
LLMToolCall(
id=call_id,
function=LLMToolFunction(
name=fc_data["name"],
arguments=json.dumps(fc_data["args"]),
),
)
)
except KeyError as e:
logger.warning(
f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(
f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
)
elif "codeExecutionResult" in part:
result = part["codeExecutionResult"]
if result.get("outcome") == "OK":
output = result.get("output", "")
answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n")
else:
answer_parts.append(
f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
)
if thought_summary_parts:
full_thought_summary = "\n".join(thought_summary_parts).strip()
full_answer = "".join(answer_parts).strip()
formatted_parts = []
if full_thought_summary:
formatted_parts.append(f"🤔 **思考过程**\n\n{full_thought_summary}")
if full_answer:
separator = "\n\n---\n\n" if full_thought_summary else ""
formatted_parts.append(f"{separator}✅ **回答**\n\n{full_answer}")
text_content = "".join(formatted_parts)
else:
text_content = "".join(answer_parts)
usage_info = response_json.get("usageMetadata")
grounding_metadata_obj = None
if grounding_data := candidate.get("groundingMetadata"):
try:
from ..types.models import LLMGroundingMetadata
grounding_metadata_obj = LLMGroundingMetadata(**grounding_data)
except Exception as e:
logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}")
return ResponseData(
text=text_content,
tool_calls=parsed_tool_calls,
images=images_bytes if images_bytes else None,
usage_info=usage_info,
raw_response=response_json,
grounding_metadata=grounding_metadata_obj,
)
except Exception as e:
logger.error(f"解析 Gemini 响应失败: {e}", e=e)
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
_ = model, is_advanced
parser = GeminiResponseParser()
return parser.parse(response_json)
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
task_type: "EmbeddingTaskType | str",
**kwargs: Any,
config: "LLMEmbeddingConfig",
) -> RequestData:
"""准备文本嵌入请求"""
api_model_name = model.model_name
if not api_model_name.startswith("models/"):
api_model_name = f"models/{api_model_name}"
url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents")
if not model.api_base:
raise LLMException(
f"模型 {model.model_name} 的 api_base 未设置",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
base_url = model.api_base.rstrip("/")
url = f"{base_url}/v1beta/{api_model_name}:batchEmbedContents"
headers = self.get_base_headers(api_key)
requests_payload = []
for text_content in texts:
safe_text = text_content if text_content else " "
request_item: dict[str, Any] = {
"content": {"parts": [{"text": text_content}]},
"model": api_model_name,
"content": {"parts": [{"text": safe_text}]},
}
from ..types.enums import EmbeddingTaskType
if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT:
request_item["task_type"] = str(task_type).upper()
if title := kwargs.get("title"):
request_item["title"] = title
if output_dimensionality := kwargs.get("output_dimensionality"):
request_item["output_dimensionality"] = output_dimensionality
if config.task_type:
request_item["task_type"] = str(config.task_type).upper()
if config.title:
request_item["title"] = config.title
if config.output_dimensionality:
request_item["output_dimensionality"] = config.output_dimensionality
requests_payload.append(request_item)
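# 载荷示意(基于上方循环;模型名仅为假设):对 texts=["你好"] 且
# config.task_type="RETRIEVAL_QUERY" 时,单个 request_item 形如:
# {
#     "model": "models/text-embedding-004",
#     "content": {"parts": [{"text": "你好"}]},
#     "task_type": "RETRIEVAL_QUERY",
# }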
@ -601,3 +316,9 @@ class GeminiAdapter(BaseAdapter):
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details=response_json,
)
def convert_generation_config(
self, config: "LLMGenerationConfig", model: "LLMModel"
) -> dict[str, Any]:
mapper = GeminiConfigMapper()
return mapper.map_config(config, model.model_detail, model.capabilities)

View File

@ -1,15 +1,181 @@
"""
OpenAI API 适配器
支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务
支持 OpenAI、智谱AI 等 OpenAI 兼容的 API 服务
"""
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
import base64
from pathlib import Path
from typing import TYPE_CHECKING, Any
from .base import OpenAICompatAdapter
import json_repair
from zhenxun.services.llm.config.generation import ImageAspectRatio
from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from ..types import StructuredOutputStrategy
from ..types.models import ToolChoice
from ..utils import sanitize_schema_for_llm
from .base import (
BaseAdapter,
OpenAICompatAdapter,
RequestData,
ResponseData,
process_image_data,
)
from .components.openai_components import (
OpenAIConfigMapper,
OpenAIMessageConverter,
OpenAIResponseParser,
OpenAIToolSerializer,
)
if TYPE_CHECKING:
from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
from ..service import LLMModel
from ..types import LLMMessage
class APIProtocol(ABC):
"""API 协议策略基类"""
@abstractmethod
def build_request_body(
self,
model: "LLMModel",
messages: list["LLMMessage"],
tools: list[dict[str, Any]] | None,
tool_choice: Any,
) -> dict[str, Any]:
"""构建不同协议下的请求体"""
pass
@abstractmethod
def parse_response(self, response_json: dict[str, Any]) -> ResponseData:
"""解析不同协议下的响应"""
pass
class StandardProtocol(APIProtocol):
"""标准 OpenAI 协议策略"""
def __init__(self, adapter: "OpenAICompatAdapter"):
self.adapter = adapter
def build_request_body(
self,
model: "LLMModel",
messages: list["LLMMessage"],
tools: list[dict[str, Any]] | None,
tool_choice: Any,
) -> dict[str, Any]:
converter = OpenAIMessageConverter()
openai_messages = converter.convert_messages(messages)
body: dict[str, Any] = {
"model": model.model_name,
"messages": openai_messages,
}
if tools:
body["tools"] = tools
if tool_choice:
body["tool_choice"] = tool_choice
return body
def parse_response(self, response_json: dict[str, Any]) -> ResponseData:
parser = OpenAIResponseParser()
return parser.parse(response_json)
class ResponsesProtocol(APIProtocol):
"""/v1/responses 新版协议策略"""
def __init__(self, adapter: "OpenAICompatAdapter"):
self.adapter = adapter
def build_request_body(
self,
model: "LLMModel",
messages: list["LLMMessage"],
tools: list[dict[str, Any]] | None,
tool_choice: Any,
) -> dict[str, Any]:
input_items: list[dict[str, Any]] = []
for msg in messages:
role = msg.role
content_list: list[dict[str, Any]] = []
raw_contents = (
msg.content if isinstance(msg.content, list) else [msg.content]
)
for part in raw_contents:
if part is None:
continue
if isinstance(part, str):
content_list.append({"type": "input_text", "text": part})
continue
if hasattr(part, "type"):
part_type = getattr(part, "type", None)
if part_type == "text":
content_list.append(
{"type": "input_text", "text": getattr(part, "text", "")}
)
elif part_type == "image":
content_list.append(
{
"type": "input_image",
"image_url": getattr(part, "image_source", ""),
}
)
continue
if isinstance(part, dict):
part_type = part.get("type")
if part_type == "text":
content_list.append(
{"type": "input_text", "text": part.get("text", "")}
)
elif part_type in {"image", "image_url"}:
image_src = part.get("image_url") or part.get(
"image_source", ""
)
content_list.append(
{
"type": "input_image",
"image_url": image_src,
}
)
input_items.append({"role": role, "content": content_list})
body: dict[str, Any] = {
"model": model.model_name,
"input": input_items,
}
if tools:
body["tools"] = tools
if tool_choice:
body["tool_choice"] = tool_choice
return body
def parse_response(self, response_json: dict[str, Any]) -> ResponseData:
self.adapter.validate_response(response_json)
text_content = ""
for item in response_json.get("output", []):
if item.get("type") == "message" and item.get("role") == "assistant":
for content_item in item.get("content", []):
if content_item.get("type") == "output_text":
text_content += content_item.get("text", "")
return ResponseData(
text=text_content,
usage_info=response_json.get("usage"),
raw_response=response_json,
)
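# 请求体示意(基于上方 build_request_body;模型名仅为假设):
# {
#     "model": "gpt-4o",
#     "input": [
#         {"role": "user", "content": [{"type": "input_text", "text": "你好"}]}
#     ],
# }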
class OpenAIAdapter(OpenAICompatAdapter):
@ -23,23 +189,411 @@ class OpenAIAdapter(OpenAICompatAdapter):
def supported_api_types(self) -> list[str]:
return [
"openai",
"deepseek",
"zhipu",
"general_openai_compat",
"ark",
"openrouter",
"openai_responses",
]
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""返回聊天完成端点"""
if model.api_type == "ark":
if model.model_detail.endpoint:
return model.model_detail.endpoint
current_api_type = model.model_detail.api_type or model.api_type
if current_api_type == "openai_responses":
return "/v1/responses"
if current_api_type == "ark":
return "/api/v3/chat/completions"
if model.api_type == "zhipu":
if current_api_type == "zhipu":
return "/api/paas/v4/chat/completions"
return "/v1/chat/completions"
def _get_protocol_strategy(self, model: "LLMModel") -> APIProtocol:
"""根据 API 类型获取对应的处理策略"""
current_api_type = model.model_detail.api_type or model.api_type
if current_api_type == "openai_responses":
return ResponsesProtocol(self)
return StandardProtocol(self)
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""根据API类型返回嵌入端点"""
if model.api_type == "zhipu":
return "/v4/embeddings"
return "/v1/embeddings"
def convert_generation_config(
self, config: "LLMGenerationConfig", model: "LLMModel"
) -> dict[str, Any]:
mapper = OpenAIConfigMapper(api_type=self.api_type)
return mapper.map_config(config, model.model_detail, model.capabilities)
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[Any] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
) -> "RequestData":
"""根据不同协议策略构建高级请求"""
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
if model.api_type == "openrouter":
headers.update(
{
"HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot",
"X-Title": "Zhenxun Bot",
}
)
default_config = getattr(model, "_generation_config", None)
effective_config = config if config is not None else default_config
structured_strategy = (
effective_config.output.structured_output_strategy
if effective_config and effective_config.output
else None
)
if structured_strategy is None:
structured_strategy = StructuredOutputStrategy.NATIVE
openai_tools: list[dict[str, Any]] | None = None
executables: list[Any] = []
if tools:
if isinstance(tools, dict):
executables = list(tools.values())
else:
for tool in tools:
if hasattr(tool, "get_definition"):
executables.append(tool)
definition_tasks = [executable.get_definition() for executable in executables]
tool_defs: list[Any] = []
if definition_tasks:
import asyncio
tool_defs = await asyncio.gather(*definition_tasks)
if tool_defs:
serializer = OpenAIToolSerializer()
openai_tools = serializer.serialize_tools(tool_defs)
final_tool_choice = tool_choice
if final_tool_choice is None:
if (
effective_config
and effective_config.tool_config
and effective_config.tool_config.mode == "ANY"
):
allowed = effective_config.tool_config.allowed_function_names
if allowed:
if len(allowed) == 1:
final_tool_choice = {
"type": "function",
"function": {"name": allowed[0]},
}
else:
logger.warning(
"OpenAI API 不支持多个 allowed_function_names降级为"
" required。"
)
final_tool_choice = "required"
else:
final_tool_choice = "required"
if (
structured_strategy == StructuredOutputStrategy.TOOL_CALL
and effective_config
and effective_config.output
and effective_config.output.response_schema
):
sanitized_schema = sanitize_schema_for_llm(
effective_config.output.response_schema, api_type="openai"
)
structured_tool = {
"type": "function",
"function": {
"name": "return_structured_response",
"description": "Return the final structured response.",
"parameters": sanitized_schema,
"strict": True if model.api_type != "deepseek" else False,
},
}
if openai_tools is None:
openai_tools = []
openai_tools.append(structured_tool)
final_tool_choice = {
"type": "function",
"function": {"name": "return_structured_response"},
}
protocol_strategy = self._get_protocol_strategy(model)
body = protocol_strategy.build_request_body(
model=model,
messages=messages,
tools=openai_tools,
tool_choice=final_tool_choice,
)
body = self.apply_config_override(model, body, config)
if final_tool_choice is not None:
body["tool_choice"] = final_tool_choice
response_format = body.get("response_format", {})
inject_prompt = (
structured_strategy == StructuredOutputStrategy.NATIVE
and isinstance(response_format, dict)
and response_format.get("type") == "json_object"
)
if inject_prompt:
messages_list = body.get("messages", [])
has_json_keyword = False
for msg in messages_list:
content = msg.get("content")
if isinstance(content, str) and "json" in content.lower():
has_json_keyword = True
break
if isinstance(content, list):
for part in content:
if (
isinstance(part, dict)
and part.get("type") == "text"
and "json" in part.get("text", "").lower()
):
has_json_keyword = True
break
if has_json_keyword:
break
if not has_json_keyword:
injection_text = (
"请务必输出合法的 JSON 格式避免额外的文本、Markdown 或解释。"
)
system_msg = next(
(m for m in messages_list if m.get("role") == "system"), None
)
if system_msg:
if isinstance(system_msg.get("content"), str):
system_msg["content"] += " " + injection_text
elif isinstance(system_msg.get("content"), list):
system_msg["content"].append(
{"type": "text", "text": injection_text}
)
else:
messages_list.insert(
0, {"role": "system", "content": injection_text}
)
body["messages"] = messages_list
return RequestData(url=url, headers=headers, body=body)
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 使用策略模式委托处理"""
_ = is_advanced
protocol_strategy = self._get_protocol_strategy(model)
response_data = protocol_strategy.parse_response(response_json)
if response_data.tool_calls:
target_tool = next(
(
tc
for tc in response_data.tool_calls
if tc.function.name == "return_structured_response"
),
None,
)
if target_tool:
response_data.text = json_repair.repair_json(
target_tool.function.arguments
)
remaining = [
tc
for tc in response_data.tool_calls
if tc.function.name != "return_structured_response"
]
response_data.tool_calls = remaining or None
return response_data
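# 流程示意(TOOL_CALL 结构化输出策略;JSON 内容仅为假设):模型被强制调用
# return_structured_response 工具,其 arguments 即使被截断,例如
# '{"name": "Alice", "age": 30',也会经 json_repair 修复为
# '{"name": "Alice", "age": 30}' 并写回 response_data.text。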
class DeepSeekAdapter(OpenAIAdapter):
"""DeepSeek 专用适配器 (基于 OpenAI 协议)"""
@property
def api_type(self) -> str:
return "deepseek"
@property
def supported_api_types(self) -> list[str]:
return ["deepseek"]
class OpenAIImageAdapter(BaseAdapter):
"""OpenAI 图像生成/编辑适配器"""
@property
def api_type(self) -> str:
return "openai_image"
@property
def log_sanitization_context(self) -> str:
return "openai_request"
@property
def supported_api_types(self) -> list[str]:
return ["openai_image", "nano_banana"]
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[Any] | None = None,
tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
) -> RequestData:
_ = tools, tool_choice
effective_config = config if config is not None else model._generation_config
headers = self.get_base_headers(api_key)
prompt = ""
images_bytes_list: list[bytes] = []
for msg in reversed(messages):
if msg.role != "user":
continue
if isinstance(msg.content, str):
prompt = msg.content
elif isinstance(msg.content, list):
for part in msg.content:
if part.type == "text" and not prompt:
prompt = part.text
elif part.type == "image":
if part.is_image_base64():
if b64_data := part.get_base64_data():
_, b64_str = b64_data
images_bytes_list.append(base64.b64decode(b64_str))
elif part.is_image_url() and part.image_source:
images_bytes_list.append(
await AsyncHttpx.get_content(part.image_source)
)
if prompt:
break
if not prompt and not images_bytes_list:
raise LLMException(
"图像生成需要提供 Prompt",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
body: dict[str, Any] = {
"model": model.model_name,
"prompt": prompt,
"response_format": "b64_json",
}
if effective_config:
if effective_config.visual:
if effective_config.visual.aspect_ratio:
ar = effective_config.visual.aspect_ratio
size_map = {
ImageAspectRatio.SQUARE: "1024x1024",
ImageAspectRatio.LANDSCAPE_16_9: "1792x1024",
ImageAspectRatio.PORTRAIT_9_16: "1024x1792",
}
if isinstance(ar, ImageAspectRatio) and ar in size_map:
body["size"] = size_map[ar]
body["aspect_ratio"] = ar.value
elif isinstance(ar, str):
if "x" in ar:
body["size"] = ar
else:
body["aspect_ratio"] = ar
if effective_config.visual.resolution:
res_val = effective_config.visual.resolution
if not isinstance(res_val, str):
res_val = getattr(res_val, "value", res_val)
body["image_size"] = res_val
if effective_config.custom_params:
body.update(effective_config.custom_params)
if images_bytes_list:
b64_images = []
for img_bytes in images_bytes_list:
b64_str = base64.b64encode(img_bytes).decode("utf-8")
b64_images.append(b64_str)
body["image"] = b64_images
endpoint = "/v1/images/generations"
url = self.get_api_url(model, endpoint)
return RequestData(url=url, headers=headers, body=body)
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
_ = model, is_advanced
self.validate_response(response_json)
images_data: list[bytes | Path] = []
data_list = response_json.get("data", [])
for item in data_list:
if "b64_json" in item:
try:
b64_str = item["b64_json"]
if b64_str.startswith("data:"):
b64_str = b64_str.split(",", 1)[1]
img = base64.b64decode(b64_str)
images_data.append(process_image_data(img))
except Exception as exc:
logger.error(f"Base64 解码失败: {exc}")
elif "url" in item:
logger.warning(
f"API 返回了 URL 而不是 Base64: {item.get('url', 'unknown')}"
)
text_summary = (
f"已生成 {len(images_data)} 张图片。"
if images_data
else "图像生成接口调用成功,但未解析到图片数据。"
)
return ResponseData(
text=text_summary,
images=images_data if images_data else None,
raw_response=response_json,
)
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
config: "LLMEmbeddingConfig",
) -> RequestData:
raise NotImplementedError("OpenAIImageAdapter 不支持 Embedding")
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
raise NotImplementedError("OpenAIImageAdapter 不支持 Embedding")
def convert_generation_config(
self, config: "LLMGenerationConfig", model: "LLMModel"
) -> dict[str, Any]:
_ = config, model
return {}

View File

@ -2,6 +2,7 @@
LLM 服务的高级 API 接口 - 便捷函数入口 (无状态)
"""
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any, TypeVar, overload
@ -11,19 +12,24 @@ from pydantic import BaseModel
from zhenxun.services.log import logger
from .config import CommonOverrides
from .config.generation import LLMGenerationConfig, create_generation_config_from_kwargs
from .config.generation import (
GenConfigBuilder,
LLMEmbeddingConfig,
LLMGenerationConfig,
OutputConfig,
)
from .manager import get_model_instance
from .session import AI
from .tools.manager import tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
LLMErrorCode,
LLMException,
LLMMessage,
LLMResponse,
ModelName,
ToolChoice,
)
from .types.exceptions import get_user_friendly_error_message
from .utils import create_multimodal_message
T = TypeVar("T", bound=BaseModel)
@ -34,9 +40,10 @@ async def chat(
*,
model: ModelName = None,
instruction: str | None = None,
tools: list[dict[str, Any] | str] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
tools: list[Any] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
timeout: float | None = None,
) -> LLMResponse:
"""
无状态的聊天对话便捷函数,通过临时的AI会话实例与LLM模型交互
@ -47,14 +54,13 @@ async def chat(
instruction: 系统指令,用于指导AI的行为和回复风格
tools: 可用的工具列表,支持字典配置或字符串标识符
tool_choice: 工具选择策略,控制AI如何选择和使用工具
**kwargs: 额外的生成配置参数,会被转换为LLMGenerationConfig
config: (可选) 生成配置对象,将与默认配置合并后传递
timeout: (可选) HTTP 请求超时时间
返回:
LLMResponse: 包含AI回复内容、使用信息和工具调用等的完整响应对象
"""
try:
config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
ai_session = AI()
return await ai_session.chat(
@ -64,12 +70,14 @@ async def chat(
tools=tools,
tool_choice=tool_choice,
config=config,
timeout=timeout,
)
except LLMException:
raise
except Exception as e:
logger.error(f"执行 chat 函数失败: {e}", e=e)
raise LLMException(f"聊天执行失败: {e}", cause=e)
friendly_msg = get_user_friendly_error_message(e)
logger.error(f"执行 chat 函数失败: {e} | 建议: {friendly_msg}", e=e)
raise LLMException(f"聊天执行失败: {friendly_msg}", cause=e)
async def code(
@ -77,7 +85,6 @@ async def code(
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
无状态的代码执行便捷函数支持在沙箱环境中执行代码
@ -86,66 +93,25 @@ async def code(
prompt: 代码执行的提示词,描述要执行的代码任务
model: 要使用的模型名称,为 None 时使用全局默认模型
timeout: 代码执行超时时间,防止长时间运行的代码阻塞
**kwargs: 额外的生成配置参数
返回:
LLMResponse: 包含代码执行结果的完整响应对象
"""
resolved_model = model or "Gemini/gemini-2.0-flash"
resolved_model = model
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
final_config = config.to_dict()
final_config.update(kwargs)
return await chat(prompt, model=resolved_model, **final_config)
async def search(
query: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
**kwargs: Any,
) -> LLMResponse:
"""
无状态的信息搜索便捷函数,利用搜索工具获取实时信息
参数:
query: 搜索查询内容,支持多种输入格式
model: 要使用的模型名称,如果为None则使用默认模型
instruction: 搜索任务的系统指令,指导AI如何处理搜索结果
**kwargs: 额外的生成配置参数
返回:
LLMResponse: 包含搜索结果和AI整合回复的完整响应对象
"""
logger.debug("执行无状态 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
final_config = search_config.to_dict()
final_config.update(kwargs)
return await chat(
query,
model=model,
instruction=instruction,
**final_config,
)
return await chat(prompt, model=resolved_model, config=config)
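# 用法示意(最小草图,假设默认模型支持代码执行):
# resp = await code("用 Python 计算前 20 个斐波那契数并输出", timeout=30)
# print(resp.text)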
async def embed(
texts: list[str] | str,
*,
model: ModelName = None,
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
config: LLMEmbeddingConfig | None = None,
) -> list[list[float]]:
"""
无状态的文本嵌入便捷函数,将文本转换为向量表示
@ -153,8 +119,7 @@ async def embed(
参数:
texts: 要生成嵌入的文本内容,支持单个字符串或字符串列表
model: 要使用的嵌入模型名称,如果为None则使用默认模型
task_type: 嵌入任务类型,影响向量的优化方向,如检索、分类等
**kwargs: 额外的模型配置参数
config: 嵌入配置对象
返回:
list[list[float]]: 文本对应的嵌入向量列表,每个向量为浮点数列表
@ -164,27 +129,71 @@ async def embed(
if not texts:
return []
final_config = config or LLMEmbeddingConfig()
try:
async with await get_model_instance(model) as model_instance:
return await model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
return await model_instance.generate_embeddings(texts, config=final_config)
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
friendly_msg = get_user_friendly_error_message(e)
logger.error(f"文本嵌入失败: {e} | 建议: {friendly_msg}", e=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
f"文本嵌入失败: {friendly_msg}",
code=LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def embed_query(
text: str,
*,
model: ModelName = None,
dimensions: int | None = None,
) -> list[float]:
"""
语义化便捷 API,为检索查询生成嵌入
"""
config = LLMEmbeddingConfig(
task_type="RETRIEVAL_QUERY",
output_dimensionality=dimensions,
)
vectors = await embed([text], model=model, config=config)
return vectors[0] if vectors else []
async def embed_documents(
texts: list[str],
*,
model: ModelName = None,
dimensions: int | None = None,
title: str | None = None,
) -> list[list[float]]:
"""
语义化便捷 API,为文档集合生成嵌入
"""
config = LLMEmbeddingConfig(
task_type="RETRIEVAL_DOCUMENT",
output_dimensionality=dimensions,
title=title,
)
return await embed(texts, model=model, config=config)
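# 用法示意(最小草图,维度 768 仅为假设,实际取决于所选嵌入模型):
# doc_vectors = await embed_documents(["文档一", "文档二"], dimensions=768)
# query_vector = await embed_query("检索问题", dimensions=768)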
async def generate_structured(
message: str | LLMMessage | list[LLMContentPart],
response_model: type[T],
*,
model: ModelName = None,
tools: list[Any] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
max_validation_retries: int | None = None,
validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
error_prompt_template: str | None = None,
auto_thinking: bool = False,
instruction: str | None = None,
**kwargs: Any,
timeout: float | None = None,
) -> T:
"""
无状态地生成结构化响应,并自动解析为指定的Pydantic模型
@ -192,39 +201,48 @@ async def generate_structured(
参数:
message: 用户输入的消息内容,支持多种格式
response_model: 用于解析和验证响应的Pydantic模型类
max_validation_retries: 校验失败时的最大重试次数,默认为 None (使用全局配置)
validation_callback: 自定义校验回调函数,抛出异常视为校验失败
error_prompt_template: 自定义错误反馈提示词模板
auto_thinking: 是否自动开启思维链 (CoT) 包装,适用于不支持原生思考的模型
model: 要使用的模型名称,如果为None则使用默认模型
instruction: 系统指令,用于指导AI生成符合要求的结构化输出
**kwargs: 额外的生成配置参数
timeout: HTTP 请求超时时间
返回:
T: 解析后的Pydantic模型实例,类型为response_model指定的类型
"""
try:
config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
ai_session = AI()
return await ai_session.generate_structured(
message,
response_model,
model=model,
tools=tools,
tool_choice=tool_choice,
max_validation_retries=max_validation_retries,
validation_callback=validation_callback,
error_prompt_template=error_prompt_template,
auto_thinking=auto_thinking,
instruction=instruction,
config=config,
timeout=timeout,
)
except LLMException:
raise
except Exception as e:
logger.error(f"生成结构化响应失败: {e}", e=e)
raise LLMException(f"生成结构化响应失败: {e}", cause=e)
friendly_msg = get_user_friendly_error_message(e)
logger.error(f"生成结构化响应失败: {e} | 建议: {friendly_msg}", e=e)
raise LLMException(f"生成结构化响应失败: {friendly_msg}", cause=e)
async def generate(
messages: list[LLMMessage],
*,
model: ModelName = None,
tools: list[dict[str, Any] | str] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
tools: list[Any] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
根据完整的消息列表生成一次性响应,这是一个无状态的底层函数
@ -234,109 +252,56 @@ async def generate(
model: 要使用的模型名称,如果为None则使用默认模型
tools: 可用的工具列表,支持字典配置或字符串标识符
tool_choice: 工具选择策略,控制AI如何选择和使用工具
**kwargs: 额外的生成配置参数,会覆盖默认配置
config: (可选) 生成配置对象,将与默认配置合并后传递
返回:
LLMResponse: 包含AI回复内容、使用信息和工具调用等的完整响应对象
"""
try:
if isinstance(config, GenConfigBuilder):
config = config.build()
async with await get_model_instance(
model, override_config=kwargs
model, override_config=None
) as model_instance:
return await model_instance.generate_response(
messages,
tools=tools, # type: ignore
config=config,
tools=tools, # type: ignore[arg-type]
tool_choice=tool_choice,
)
except LLMException:
raise
except Exception as e:
logger.error(f"生成响应失败: {e}", e=e)
raise LLMException(f"生成响应失败: {e}", cause=e)
async def run_with_tools(
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str | None = None,
tools: list[str],
max_cycles: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""
无状态地执行一个带本地Python函数的LLM调用循环
参数:
message: 用户输入
model: 使用的模型
instruction: 系统指令
tools: 要使用的本地函数工具名称列表 (必须已通过 @function_tool 注册)
max_cycles: 最大工具调用循环次数
**kwargs: 额外的生成配置参数
返回:
LLMResponse: 包含最终回复的响应对象
"""
from .executor import ExecutionConfig, LLMToolExecutor
from .utils import normalize_to_llm_messages
messages = await normalize_to_llm_messages(message, instruction)
async with await get_model_instance(
model, override_config=kwargs
) as model_instance:
resolved_tools = await tool_provider_manager.get_function_tools(tools)
if not resolved_tools:
logger.warning(
"run_with_tools 未找到任何可用的本地函数工具,将作为普通聊天执行。"
)
return await model_instance.generate_response(messages, tools=None)
executor = LLMToolExecutor(model_instance)
config = ExecutionConfig(max_cycles=max_cycles)
final_history = await executor.run(messages, resolved_tools, config)
for msg in reversed(final_history):
if msg.role == "assistant":
text = msg.content if isinstance(msg.content, str) else str(msg.content)
return LLMResponse(text=text, tool_calls=msg.tool_calls)
raise LLMException(
"带工具的执行循环未能产生有效的助手回复。", code=LLMErrorCode.GENERATION_FAILED
)
friendly_msg = get_user_friendly_error_message(e)
logger.error(f"生成响应失败: {e} | 建议: {friendly_msg}", e=e)
raise LLMException(f"生成响应失败: {friendly_msg}", cause=e)
async def _generate_image_from_message(
message: UniMessage,
model: ModelName = None,
**kwargs: Any,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
[内部] UniMessage 生成图片的核心辅助函数
"""
from .utils import normalize_to_llm_messages
config = (
create_generation_config_from_kwargs(**kwargs)
if kwargs
else LLMGenerationConfig()
)
if isinstance(config, GenConfigBuilder):
config = config.build()
config = config or LLMGenerationConfig()
config.validation_policy = {"require_image": True}
config.response_modalities = ["IMAGE", "TEXT"]
if config.output is None:
config.output = OutputConfig()
config.output.response_modalities = ["IMAGE", "TEXT"]
try:
messages = await normalize_to_llm_messages(message)
async with await get_model_instance(model) as model_instance:
if not model_instance.can_generate_images():
raise LLMException(
f"模型 '{model_instance.provider_name}/{model_instance.model_name}'"
f"不支持图片生成",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
response = await model_instance.generate_response(messages, config=config)
if not response.images:
@ -347,8 +312,9 @@ async def _generate_image_from_message(
except LLMException:
raise
except Exception as e:
logger.error(f"执行图片生成时发生未知错误: {e}", e=e)
raise LLMException(f"图片生成失败: {e}", cause=e)
friendly_msg = get_user_friendly_error_message(e)
logger.error(f"执行图片生成时发生未知错误: {e} | 建议: {friendly_msg}", e=e)
raise LLMException(f"图片生成失败: {friendly_msg}", cause=e)
@overload
@ -357,7 +323,6 @@ async def create_image(
*,
images: None = None,
model: ModelName = None,
**kwargs: Any,
) -> LLMResponse:
"""根据文本提示生成一张新图片。"""
...
@ -369,7 +334,6 @@ async def create_image(
*,
images: list[Path | bytes | str] | Path | bytes | str,
model: ModelName = None,
**kwargs: Any,
) -> LLMResponse:
"""在给定图片的基础上,根据文本提示进行编辑或重新生成。"""
...
@ -380,7 +344,7 @@ async def create_image(
*,
images: list[Path | bytes | str] | Path | bytes | str | None = None,
model: ModelName = None,
**kwargs: Any,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
智能图片生成/编辑函数
@ -400,4 +364,42 @@ async def create_image(
message = create_multimodal_message(text=text_prompt, images=image_list)
return await _generate_image_from_message(message, model=model, **kwargs)
return await _generate_image_from_message(message, model=model, config=config)
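# 用法示意(最小草图,文件路径仅为假设):
# resp = await create_image("一只戴帽子的橘猫")
# resp = await create_image("改成水彩风格", images=Path("cat.png"))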
async def search(
query: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
无状态的信息搜索便捷函数,利用搜索工具获取实时信息
参数:
query: 搜索查询内容,支持多种输入格式
model: 要使用的模型名称,如果为None则使用默认模型
config: (可选) 生成配置对象,将与预设配置合并后传递
instruction: 搜索任务的系统指令,指导AI如何处理搜索结果
返回:
LLMResponse: 包含搜索结果和AI整合回复的完整响应对象
"""
logger.debug("执行无状态 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
if isinstance(config, GenConfigBuilder):
config = config.build()
final_config = search_config.merge_with(config)
return await chat(
query,
model=model,
instruction=instruction,
config=final_config,
)
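# 用法示意(最小草图,需要支持 Google 搜索工具的模型,如 Gemini 系列):
# resp = await search("2024 年诺贝尔物理学奖得主是谁?")
# print(resp.text)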

View File

@ -5,13 +5,12 @@ LLM 配置模块
"""
from .generation import (
CommonOverrides,
GenConfigBuilder,
LLMEmbeddingConfig,
LLMGenerationConfig,
ModelConfigOverride,
apply_api_specific_mappings,
create_generation_config_from_kwargs,
validate_override_params,
)
from .presets import CommonOverrides
from .providers import (
LLMConfig,
get_gemini_safety_threshold,
@ -23,11 +22,10 @@ from .providers import (
__all__ = [
"CommonOverrides",
"GenConfigBuilder",
"LLMConfig",
"LLMEmbeddingConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_gemini_safety_threshold",
"get_llm_config",
"register_llm_configs",

View File

@ -3,209 +3,397 @@ LLM 生成配置相关类和函数
"""
from collections.abc import Callable
from typing import Any
from enum import Enum
from typing import Any, Literal
from typing_extensions import Self
from pydantic import BaseModel, ConfigDict, Field
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump
from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_validate
from ..types import LLMResponse
from ..types.enums import ResponseFormat
from ..types import LLMResponse, ResponseFormat, StructuredOutputStrategy
from ..types.exceptions import LLMErrorCode, LLMException
from .providers import get_gemini_safety_threshold
class ModelConfigOverride(BaseModel):
"""模型配置覆盖参数"""
class ReasoningEffort(str, Enum):
"""推理努力程度枚举"""
LOW = "LOW"
MEDIUM = "MEDIUM"
HIGH = "HIGH"
class ImageAspectRatio(str, Enum):
"""图像宽高比枚举"""
SQUARE = "1:1"
LANDSCAPE_16_9 = "16:9"
PORTRAIT_9_16 = "9:16"
LANDSCAPE_4_3 = "4:3"
PORTRAIT_3_4 = "3:4"
LANDSCAPE_3_2 = "3:2"
PORTRAIT_2_3 = "2:3"
class ImageResolution(str, Enum):
"""图像分辨率/质量枚举"""
STANDARD = "STANDARD"
HD = "HD"
class CoreConfig(BaseModel):
"""核心生成参数"""
temperature: float | None = Field(
default=None, ge=0.0, le=2.0, description="生成温度"
)
"""生成温度"""
max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数")
"""最大输出token数"""
top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数")
"""核采样参数"""
top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数")
"""Top-K采样参数"""
frequency_penalty: float | None = Field(
default=None, ge=-2.0, le=2.0, description="频率惩罚"
)
"""频率惩罚"""
presence_penalty: float | None = Field(
default=None, ge=-2.0, le=2.0, description="存在惩罚"
)
"""存在惩罚"""
repetition_penalty: float | None = Field(
default=None, ge=0.0, le=2.0, description="重复惩罚"
)
"""重复惩罚"""
stop: list[str] | str | None = Field(default=None, description="停止序列")
"""停止序列"""
class ReasoningConfig(BaseModel):
"""推理能力配置"""
effort: ReasoningEffort | None = Field(
default=None, description="推理努力程度 (适用于 O1, Gemini 3)"
)
"""推理努力程度 (适用于 O1, Gemini 3)"""
budget_tokens: int | None = Field(
default=None, description="具体的思考 Token 预算 (适用于 Gemini 2.5)"
)
"""具体的思考 Token 预算 (适用于 Gemini 2.5)"""
show_thoughts: bool | None = Field(
default=None, description="是否在响应中显式包含思维链内容"
)
"""是否在响应中显式包含思维链内容"""
class VisualConfig(BaseModel):
"""视觉生成配置"""
aspect_ratio: ImageAspectRatio | str | None = Field(
default=None, description="宽高比"
)
"""宽高比"""
resolution: ImageResolution | str | None = Field(
default=None, description="生成质量/分辨率"
)
"""生成质量/分辨率"""
media_resolution: str | None = Field(
default=None,
description="输入媒体的解析度 (Gemini 3+): 'LOW', 'MEDIUM', 'HIGH'",
)
"""输入媒体的解析度 (Gemini 3+): 'LOW', 'MEDIUM', 'HIGH'"""
style: str | None = Field(
default=None, description="图像风格 (如 DALL-E 3 vivid/natural)"
)
"""图像风格 (如 DALL-E 3 vivid/natural)"""
class OutputConfig(BaseModel):
"""输出格式控制"""
response_format: ResponseFormat | dict[str, Any] | None = Field(
default=None, description="期望的响应格式"
)
"""期望的响应格式"""
response_mime_type: str | None = Field(
default=None, description="响应MIME类型Gemini专用"
)
"""响应MIME类型Gemini专用"""
response_schema: dict[str, Any] | None = Field(
default=None, description="JSON响应模式"
)
thinking_budget: float | None = Field(
default=None, ge=0.0, le=1.0, description="思考预算"
)
include_thoughts: bool | None = Field(
default=None, description="是否在响应中包含思维过程Gemini专用"
)
safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
"""JSON响应模式"""
response_modalities: list[str] | None = Field(
default=None, description="响应模态类型"
default=None, description="响应模态类型 (TEXT, IMAGE, AUDIO)"
)
"""响应模态类型 (TEXT, IMAGE, AUDIO)"""
structured_output_strategy: StructuredOutputStrategy | str | None = Field(
default=None, description="结构化输出策略 (NATIVE/TOOL_CALL/PROMPT)"
)
"""结构化输出策略 (NATIVE/TOOL_CALL/PROMPT)"""
enable_code_execution: bool | None = Field(
default=None, description="是否启用代码执行"
class SafetyConfig(BaseModel):
"""安全设置"""
safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
"""安全设置"""
class ToolConfig(BaseModel):
"""工具调用控制配置"""
mode: Literal["AUTO", "ANY", "NONE"] = Field(
default="AUTO",
description="工具调用模式: AUTO(自动), ANY(强制), NONE(禁用)",
)
enable_grounding: bool | None = Field(
default=None, description="是否启用信息来源关联"
"""工具调用模式: AUTO(自动), ANY(强制), NONE(禁用)"""
allowed_function_names: list[str] | None = Field(
default=None,
description="当 mode 为 ANY 时,允许调用的函数名称白名单",
)
"""当 mode 为 ANY 时,允许调用的函数名称白名单"""
class LLMGenerationConfig(BaseModel):
"""
LLM 生成配置
采用组件化设计,不再扁平化参数
"""
core: CoreConfig | None = Field(default=None, description="基础生成参数")
"""基础生成参数"""
reasoning: ReasoningConfig | None = Field(default=None, description="推理能力配置")
"""推理能力配置"""
visual: VisualConfig | None = Field(default=None, description="视觉生成配置")
"""视觉生成配置"""
output: OutputConfig | None = Field(default=None, description="输出格式配置")
"""输出格式配置"""
safety: SafetyConfig | None = Field(default=None, description="安全配置")
"""安全配置"""
tool_config: ToolConfig | None = Field(default=None, description="工具调用策略配置")
"""工具调用策略配置"""
enable_caching: bool | None = Field(default=None, description="是否启用响应缓存")
"""是否启用响应缓存"""
custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")
"""自定义参数"""
validation_policy: dict[str, Any] | None = Field(
default=None, description="声明式的响应验证策略 (例如: {'require_image': True})"
)
"""声明式的响应验证策略 (例如: {'require_image': True})"""
response_validator: Callable[[LLMResponse], None] | None = Field(
default=None, description="一个高级回调函数,用于验证响应,验证失败时应抛出异常"
default=None,
description="一个高级回调函数,用于验证响应,验证失败时应抛出异常",
)
"""一个高级回调函数,用于验证响应,验证失败时应抛出异常"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@classmethod
def builder(cls) -> "GenConfigBuilder":
"""创建一个新的配置构建器"""
return GenConfigBuilder()
def to_dict(self) -> dict[str, Any]:
"""转换为字典排除None值"""
"""
转换为字典排除None值
注意这会返回嵌套结构的字典适配器需要处理这种嵌套
"""
return model_dump(self, exclude_none=True)
model_data = model_dump(self, exclude_none=True)
def merge_with(self, other: "LLMGenerationConfig | None") -> "LLMGenerationConfig":
"""
与另一个配置对象进行深度合并
other 中的非 None 字段会覆盖当前配置中的对应字段
返回一个新的配置对象,原对象不变
"""
if not other:
return model_copy(self, deep=True)
result = {}
for key, value in model_data.items():
if key == "custom_params" and isinstance(value, dict):
result.update(value)
else:
result[key] = value
new_config = model_copy(self, deep=True)
return result
def _merge_component(base_comp, override_comp, comp_cls):
if override_comp is None:
return base_comp
if base_comp is None:
return override_comp
updates = model_dump(override_comp, exclude_none=True)
return model_copy(base_comp, update=updates)
def merge_with_base_config(
new_config.core = _merge_component(new_config.core, other.core, CoreConfig)
new_config.reasoning = _merge_component(
new_config.reasoning, other.reasoning, ReasoningConfig
)
new_config.visual = _merge_component(
new_config.visual, other.visual, VisualConfig
)
new_config.output = _merge_component(
new_config.output, other.output, OutputConfig
)
new_config.safety = _merge_component(
new_config.safety, other.safety, SafetyConfig
)
new_config.tool_config = _merge_component(
new_config.tool_config, other.tool_config, ToolConfig
)
if other.enable_caching is not None:
new_config.enable_caching = other.enable_caching
if other.custom_params:
if new_config.custom_params is None:
new_config.custom_params = {}
new_config.custom_params.update(other.custom_params)
if other.validation_policy:
if new_config.validation_policy is None:
new_config.validation_policy = {}
new_config.validation_policy.update(other.validation_policy)
if other.response_validator:
new_config.response_validator = other.response_validator
return new_config
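# 合并语义示意(最小草图):other 中非 None 的字段覆盖当前值,其余保留。
# base = LLMGenerationConfig(core=CoreConfig(temperature=0.2, max_tokens=1024))
# override = LLMGenerationConfig(core=CoreConfig(temperature=0.9))
# merged = base.merge_with(override)
# assert merged.core.temperature == 0.9 and merged.core.max_tokens == 1024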
class LLMEmbeddingConfig(BaseModel):
"""Embedding 专用配置"""
task_type: str | None = Field(default=None, description="任务类型 (Gemini/Jina)")
"""任务类型 (Gemini/Jina)"""
output_dimensionality: int | None = Field(
default=None, description="输出维度/压缩维度 (Gemini/Jina/OpenAI)"
)
"""输出维度/压缩维度 (Gemini/Jina/OpenAI)"""
title: str | None = Field(
default=None, description="仅用于 Gemini RETRIEVAL_DOCUMENT 任务的标题"
)
"""仅用于 Gemini RETRIEVAL_DOCUMENT 任务的标题"""
encoding_format: str | None = Field(
default="float", description="编码格式 (float/base64)"
)
"""编码格式 (float/base64)"""
model_config = ConfigDict(arbitrary_types_allowed=True)
class GenConfigBuilder:
"""
LLM 生成配置的语义化构建器
设计原则:高频业务场景优先,低频参数命名空间化
"""
def __init__(self):
self._config = LLMGenerationConfig()
def _ensure_core(self) -> CoreConfig:
if self._config.core is None:
self._config.core = CoreConfig()
return self._config.core
def _ensure_output(self) -> OutputConfig:
if self._config.output is None:
self._config.output = OutputConfig()
return self._config.output
def _ensure_reasoning(self) -> ReasoningConfig:
if self._config.reasoning is None:
self._config.reasoning = ReasoningConfig()
return self._config.reasoning
def as_json(self, schema: dict[str, Any] | None = None) -> Self:
"""
[高频] 强制模型输出 JSON 格式
"""
out = self._ensure_output()
out.response_format = ResponseFormat.JSON
if schema:
out.response_schema = schema
return self
def enable_thinking(
self, budget_tokens: int = -1, show_thoughts: bool = False
) -> Self:
"""
[高频] 启用模型的思考/推理能力 (如 Gemini 2.0 Flash Thinking, DeepSeek R1)
"""
reasoning = self._ensure_reasoning()
reasoning.budget_tokens = budget_tokens
reasoning.show_thoughts = show_thoughts
return self
def config_core(
self,
base_temperature: float | None = None,
base_max_tokens: int | None = None,
) -> dict[str, Any]:
"""与基础配置合并,覆盖参数优先"""
merged = {}
temperature: float | None = None,
max_tokens: int | None = None,
top_p: float | None = None,
top_k: int | None = None,
stop: list[str] | str | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
) -> Self:
"""
[低频] 配置核心生成参数
"""
core = self._ensure_core()
if temperature is not None:
core.temperature = temperature
if max_tokens is not None:
core.max_tokens = max_tokens
if top_p is not None:
core.top_p = top_p
if top_k is not None:
core.top_k = top_k
if stop is not None:
core.stop = stop
if frequency_penalty is not None:
core.frequency_penalty = frequency_penalty
if presence_penalty is not None:
core.presence_penalty = presence_penalty
return self
if base_temperature is not None:
merged["temperature"] = base_temperature
if base_max_tokens is not None:
merged["max_tokens"] = base_max_tokens
def config_safety(self, settings: dict[str, str]) -> Self:
"""
[低频] 配置安全过滤设置
"""
if self._config.safety is None:
self._config.safety = SafetyConfig()
self._config.safety.safety_settings = settings
return self
override_dict = self.to_dict()
merged.update(override_dict)
def config_visual(
self,
aspect_ratio: ImageAspectRatio | str | None = None,
resolution: ImageResolution | str | None = None,
) -> Self:
"""
[低频] 配置视觉生成参数 (DALL-E 3 / Gemini Imagen)
"""
if self._config.visual is None:
self._config.visual = VisualConfig()
if aspect_ratio:
self._config.visual.aspect_ratio = aspect_ratio
if resolution:
self._config.visual.resolution = resolution
return self
return merged
def set_custom_param(self, key: str, value: Any) -> Self:
"""设置特定于厂商的自定义参数"""
if self._config.custom_params is None:
self._config.custom_params = {}
self._config.custom_params[key] = value
return self
class LLMGenerationConfig(ModelConfigOverride):
"""LLM 生成配置,继承模型配置覆盖参数"""
def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]:
"""转换为API参数支持不同API类型的参数名映射"""
_ = model_name
params = {}
if self.temperature is not None:
params["temperature"] = self.temperature
if self.max_tokens is not None:
if api_type == "gemini":
params["maxOutputTokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
if api_type == "gemini":
if self.top_k is not None:
params["topK"] = self.top_k
if self.top_p is not None:
params["topP"] = self.top_p
else:
if self.top_k is not None:
params["top_k"] = self.top_k
if self.top_p is not None:
params["top_p"] = self.top_p
if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
if self.frequency_penalty is not None:
params["frequency_penalty"] = self.frequency_penalty
if self.presence_penalty is not None:
params["presence_penalty"] = self.presence_penalty
if self.repetition_penalty is not None:
if api_type == "openai":
logger.warning("OpenAI官方API不支持repetition_penalty参数已忽略")
else:
params["repetition_penalty"] = self.repetition_penalty
if self.response_format is not None:
if isinstance(self.response_format, dict):
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = self.response_format
logger.debug(
f"{api_type} 使用自定义 response_format: "
f"{self.response_format}"
)
elif self.response_format == ResponseFormat.JSON:
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = {"type": "json_object"}
logger.debug(f"{api_type} 启用 JSON 对象输出模式")
elif api_type == "gemini":
params["responseMimeType"] = "application/json"
if self.response_schema:
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if self.custom_params:
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
params.update(custom_mapped)
if api_type == "gemini":
if (
self.response_format != ResponseFormat.JSON
and self.response_mime_type is not None
):
params["responseMimeType"] = self.response_mime_type
logger.debug(
f"使用显式设置的 responseMimeType: {self.response_mime_type}"
)
if self.response_schema is not None and "responseSchema" not in params:
params["responseSchema"] = self.response_schema
if self.thinking_budget is not None or self.include_thoughts is not None:
thinking_config = params.setdefault("thinkingConfig", {})
if self.thinking_budget is not None:
max_budget = 24576
budget_value = int(self.thinking_budget * max_budget)
thinking_config["thinkingBudget"] = budget_value
logger.debug(
f"已将 thinking_budget (float: {self.thinking_budget}) "
f"转换为 Gemini API 的整数格式: {budget_value}"
)
if self.include_thoughts is not None:
thinking_config["includeThoughts"] = self.include_thoughts
logger.debug(f"已设置 includeThoughts: {self.include_thoughts}")
if self.safety_settings is not None:
params["safetySettings"] = self.safety_settings
if self.response_modalities is not None:
params["responseModalities"] = self.response_modalities
logger.debug(f"{api_type}转换配置参数: {len(params)}个参数")
return params
def build(self) -> LLMGenerationConfig:
"""构建最终的配置对象"""
return self._config
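# 链式用法示意(最小草图):
# config = (
#     GenConfigBuilder()
#     .as_json()
#     .enable_thinking(budget_tokens=1024, show_thoughts=True)
#     .config_core(temperature=0.7, max_tokens=2048)
#     .build()
# )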
def validate_override_params(
@ -215,12 +403,12 @@ def validate_override_params(
if override_config is None:
return LLMGenerationConfig()
if isinstance(override_config, LLMGenerationConfig):
return override_config
if isinstance(override_config, dict):
try:
filtered_config = {
k: v for k, v in override_config.items() if v is not None
}
return LLMGenerationConfig(**filtered_config)
return model_validate(LLMGenerationConfig, override_config)
except Exception as e:
logger.warning(f"覆盖配置参数验证失败: {e}")
raise LLMException(
@ -229,56 +417,107 @@ def validate_override_params(
cause=e,
)
return override_config
raise LLMException(
f"不支持的配置类型: {type(override_config)}",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
def apply_api_specific_mappings(
params: dict[str, Any], api_type: str
) -> dict[str, Any]:
"""应用API特定的参数映射"""
mapped_params = params.copy()
class CommonOverrides:
"""常用的配置覆盖预设"""
if api_type == "gemini":
if "max_tokens" in mapped_params:
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
if "top_k" in mapped_params:
mapped_params["topK"] = mapped_params.pop("top_k")
if "top_p" in mapped_params:
mapped_params["topP"] = mapped_params.pop("top_p")
@staticmethod
def gemini_json() -> LLMGenerationConfig:
"""Gemini JSON模式强制JSON输出"""
return LLMGenerationConfig(
core=CoreConfig(),
output=OutputConfig(
response_format=ResponseFormat.JSON,
response_mime_type="application/json",
),
)
unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"]
for param in unsupported:
if param in mapped_params:
logger.warning(f"Gemini 原生API不支持参数 '{param}',已忽略")
mapped_params.pop(param)
@staticmethod
def gemini_2_5_thinking(tokens: int = -1) -> LLMGenerationConfig:
"""Gemini 2.5 思考模式:默认 -1 (动态思考)0 为禁用,>=1024 为固定预算"""
return LLMGenerationConfig(
core=CoreConfig(temperature=1.0),
reasoning=ReasoningConfig(budget_tokens=tokens, show_thoughts=True),
)
elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
if "repetition_penalty" in mapped_params and api_type == "openai":
logger.warning("OpenAI官方API不支持repetition_penalty参数已忽略")
mapped_params.pop("repetition_penalty")
@staticmethod
def gemini_3_thinking(level: str = "HIGH") -> LLMGenerationConfig:
"""Gemini 3 深度思考模式:使用思考等级"""
try:
effort = ReasoningEffort(level.upper())
except ValueError:
effort = ReasoningEffort.HIGH
if "stop" in mapped_params:
stop_value = mapped_params["stop"]
if isinstance(stop_value, str):
mapped_params["stop"] = [stop_value]
return LLMGenerationConfig(
core=CoreConfig(),
reasoning=ReasoningConfig(effort=effort, show_thoughts=True),
)
return mapped_params
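# A hedged example of the mapping above with illustrative values: for the
# "gemini" api_type, snake_case keys are renamed to camelCase and unsupported
# penalty parameters are dropped (with a warning).
mapped = apply_api_specific_mappings(
    {"max_tokens": 1024, "top_k": 40, "presence_penalty": 0.5}, "gemini"
)
assert mapped == {"maxOutputTokens": 1024, "topK": 40}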
@staticmethod
def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig:
"""Gemini 结构化输出自定义JSON模式"""
return LLMGenerationConfig(
core=CoreConfig(),
output=OutputConfig(
response_mime_type="application/json", response_schema=schema
),
)
@staticmethod
def gemini_safe() -> LLMGenerationConfig:
"""Gemini 安全模式:使用配置的安全设置"""
threshold = get_gemini_safety_threshold()
return LLMGenerationConfig(
core=CoreConfig(),
safety=SafetyConfig(
safety_settings={
"HARM_CATEGORY_HARASSMENT": threshold,
"HARM_CATEGORY_HATE_SPEECH": threshold,
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
}
),
)
def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
"""从关键字参数创建生成配置"""
model_fields = getattr(LLMGenerationConfig, "model_fields", {})
known_fields = set(model_fields.keys())
known_params = {}
custom_params = {}
@staticmethod
def gemini_code_execution() -> LLMGenerationConfig:
"""Gemini 代码执行模式:启用代码执行功能"""
return LLMGenerationConfig(
core=CoreConfig(),
custom_params={"code_execution_timeout": 30},
)
for key, value in kwargs.items():
if key in known_fields:
known_params[key] = value
else:
custom_params[key] = value
@staticmethod
def gemini_grounding() -> LLMGenerationConfig:
"""Gemini 信息来源关联模式启用Google搜索"""
return LLMGenerationConfig(
core=CoreConfig(),
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
if custom_params:
known_params["custom_params"] = custom_params
@staticmethod
def gemini_nano_banana(aspect_ratio: str = "16:9") -> LLMGenerationConfig:
"""Gemini Nano Banana Pro自定义比例生图"""
try:
ar = ImageAspectRatio(aspect_ratio)
except ValueError:
ar = ImageAspectRatio.LANDSCAPE_16_9
return LLMGenerationConfig(**known_params)
return LLMGenerationConfig(
core=CoreConfig(),
visual=VisualConfig(aspect_ratio=ar),
)
@staticmethod
def gemini_high_res() -> LLMGenerationConfig:
"""Gemini 3: 强制使用高解析度处理输入媒体"""
return LLMGenerationConfig(
visual=VisualConfig(media_resolution="HIGH", resolution=ImageResolution.HD)
)
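# A hedged usage sketch for the presets above. The model identifier and the
# prompt are illustrative; `get_model_instance` and `LLMMessage` follow their
# definitions elsewhere in this diff.
async def _demo_presets() -> None:
    config = CommonOverrides.gemini_2_5_thinking(tokens=2048)
    async with await get_model_instance("Gemini/gemini-2.5-flash") as model:
        response = await model.generate_response(
            [LLMMessage.user("总结这段文本")], config=config
        )
        print(response.text)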


@ -1,172 +0,0 @@
"""
LLM 预设配置
提供常用的配置预设,特别是针对 Gemini 的高级功能
"""
from typing import Any
from .generation import LLMGenerationConfig
class CommonOverrides:
"""常用的配置覆盖预设"""
@staticmethod
def creative() -> LLMGenerationConfig:
"""创意模式:高温度,鼓励创新"""
return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1)
@staticmethod
def precise() -> LLMGenerationConfig:
"""精确模式:低温度,确定性输出"""
return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0)
@staticmethod
def balanced() -> LLMGenerationConfig:
"""平衡模式:中等温度"""
return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0)
@staticmethod
def concise(max_tokens: int = 100) -> LLMGenerationConfig:
"""简洁模式:限制输出长度"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=max_tokens,
stop=["\n\n", "", "", ""],
)
@staticmethod
def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
"""详细模式:鼓励详细输出"""
return LLMGenerationConfig(
temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
)
@staticmethod
def gemini_json() -> LLMGenerationConfig:
"""Gemini JSON模式强制JSON输出"""
return LLMGenerationConfig(
temperature=0.3, response_mime_type="application/json"
)
@staticmethod
def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
"""Gemini 思考模式:使用思考预算"""
return LLMGenerationConfig(temperature=0.7, thinking_budget=budget)
@staticmethod
def gemini_creative() -> LLMGenerationConfig:
"""Gemini 创意模式:高温度创意输出"""
return LLMGenerationConfig(temperature=0.9, top_p=0.95)
@staticmethod
def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig:
"""Gemini 结构化输出自定义JSON模式"""
return LLMGenerationConfig(
temperature=0.3,
response_mime_type="application/json",
response_schema=schema,
)
@staticmethod
def gemini_safe() -> LLMGenerationConfig:
"""Gemini 安全模式:使用配置的安全设置"""
from .providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
return LLMGenerationConfig(
temperature=0.5,
safety_settings={
"HARM_CATEGORY_HARASSMENT": threshold,
"HARM_CATEGORY_HATE_SPEECH": threshold,
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
},
)
@staticmethod
def gemini_multimodal() -> LLMGenerationConfig:
"""Gemini 多模态模式:优化多模态处理"""
return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8)
@staticmethod
def gemini_code_execution() -> LLMGenerationConfig:
"""Gemini 代码执行模式:启用代码执行功能"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=4096,
enable_code_execution=True,
custom_params={"code_execution_timeout": 30},
)
@staticmethod
def gemini_grounding() -> LLMGenerationConfig:
"""Gemini 信息来源关联模式启用Google搜索"""
return LLMGenerationConfig(
temperature=0.5,
max_tokens=4096,
enable_grounding=True,
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
@staticmethod
def gemini_cached() -> LLMGenerationConfig:
"""Gemini 缓存模式:启用响应缓存"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=2048,
enable_caching=True,
)
@staticmethod
def gemini_advanced() -> LLMGenerationConfig:
"""Gemini 高级模式:启用所有高级功能"""
return LLMGenerationConfig(
temperature=0.5,
max_tokens=4096,
enable_code_execution=True,
enable_grounding=True,
enable_caching=True,
custom_params={
"code_execution_timeout": 30,
"grounding_config": {
"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
},
},
)
@staticmethod
def gemini_research() -> LLMGenerationConfig:
"""Gemini 研究模式:思考+搜索+结构化输出"""
return LLMGenerationConfig(
temperature=0.6,
max_tokens=4096,
thinking_budget=0.8,
enable_grounding=True,
response_mime_type="application/json",
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
@staticmethod
def gemini_analysis() -> LLMGenerationConfig:
"""Gemini 分析模式:深度思考+详细输出"""
return LLMGenerationConfig(
temperature=0.4,
max_tokens=6000,
thinking_budget=0.9,
top_p=0.8,
)
@staticmethod
def gemini_fast_response() -> LLMGenerationConfig:
"""Gemini 快速响应模式:低延迟+简洁输出"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=512,
top_p=0.8,
)


@ -13,6 +13,7 @@ from zhenxun.configs.config import Config
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.pydantic_compat import model_dump
from ..core import key_store
from ..tools import tool_provider_manager
@ -22,6 +23,39 @@ AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
class DebugLogOptions(BaseModel):
"""调试日志细粒度控制"""
show_tools: bool = Field(
default=True, description="是否在日志中显示工具定义(JSON Schema)"
)
show_schema: bool = Field(
default=True, description="是否在日志中显示结构化输出Schema(response_format)"
)
show_safety: bool = Field(
default=True, description="是否在日志中显示安全设置(safetySettings)"
)
def __bool__(self) -> bool:
"""支持 bool(debug_options) 的语法,方便兼容旧逻辑。"""
return self.show_tools or self.show_schema or self.show_safety
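# A minimal sketch of what __bool__ enables: treating the option group as a
# single switch, as the legacy bool-valued `debug_log` config did.
opts = DebugLogOptions(show_tools=True, show_schema=False, show_safety=False)
assert bool(opts) is True
assert not bool(DebugLogOptions(show_tools=False, show_schema=False, show_safety=False))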
class ClientSettings(BaseModel):
"""LLM 客户端通用设置"""
timeout: int = Field(default=300, description="API请求超时时间")
max_retries: int = Field(default=3, description="请求失败时的最大重试次数")
retry_delay: int = Field(default=2, description="请求重试的基础延迟时间(秒)")
structured_retries: int = Field(
default=2, description="结构化生成校验失败时的最大重试次数 (IVR)"
)
proxy: str | None = Field(
default=None,
description="网络代理,例如 http://127.0.0.1:7890",
)
class LLMConfig(BaseModel):
"""LLM 服务配置类"""
@ -29,20 +63,16 @@ class LLMConfig(BaseModel):
default=None,
description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
)
proxy: str | None = Field(
default=None,
description="LLM服务请求使用的网络代理例如 http://127.0.0.1:7890",
)
timeout: int = Field(default=180, description="LLM服务API请求超时时间")
max_retries_llm: int = Field(
default=3, description="LLM服务请求失败时的最大重试次数"
)
retry_delay_llm: int = Field(
default=2, description="LLM服务请求重试的基础延迟时间"
client_settings: ClientSettings = Field(
default_factory=ClientSettings, description="客户端连接与重试配置"
)
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
debug_log: DebugLogOptions | bool = Field(
default_factory=DebugLogOptions,
description="LLM请求日志详情开关。支持 bool (全开/全关) 或 dict (细粒度控制)。",
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
@ -226,36 +256,29 @@ def register_llm_configs():
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"proxy",
llm_config.proxy,
help="LLM服务请求使用的网络代理例如 http://127.0.0.1:7890",
type=str,
"client_settings",
model_dump(llm_config.client_settings),
help=(
"LLM客户端高级设置。\n"
"包含: timeout(超时秒数), max_retries(重试次数), "
"retry_delay(重试延迟), structured_retries(结构化生成重试), proxy(代理)"
),
type=dict,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"timeout",
llm_config.timeout,
help="LLM服务API请求超时时间",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"max_retries_llm",
llm_config.max_retries_llm,
help="LLM服务请求失败时的最大重试次数",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"retry_delay_llm",
llm_config.retry_delay_llm,
help="LLM服务请求重试的基础延迟时间",
type=int,
"debug_log",
{"show_tools": True, "show_schema": True, "show_safety": True},
help=(
"LLM日志详情开关。示例: {'show_tools': True, 'show_schema': False, "
"'show_safety': False}"
),
type=dict,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"gemini_safety_threshold",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_NONE",
help=(
"Gemini 安全过滤阈值 "
"(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
@ -270,7 +293,20 @@ def register_llm_configs():
AI_CONFIG_GROUP,
PROVIDERS_CONFIG_KEY,
get_default_providers(),
help="配置多个 AI 服务提供商及其模型信息",
help=(
"配置多个 AI 服务提供商及其模型信息。\n"
"注意:可以在特定模型配置下添加 'api_type' 以覆盖提供商的全局设置。\n"
"支持的 api_type 包括:\n"
"- 'openai': 标准 OpenAI 格式 (DeepSeek, SiliconFlow, Moonshot 等)\n"
"- 'gemini': Google Gemini API\n"
"- 'zhipu': 智谱 AI (GLM)\n"
"- 'ark': 字节跳动火山引擎 (Doubao)\n"
"- 'openrouter': OpenRouter 聚合平台\n"
"- 'openai_image': OpenAI 兼容的图像生成接口 (DALL-E)\n"
"- 'openai_responses': 支持新版 responses 格式的 OpenAI 兼容接口\n"
"- 'smart': 智能路由模式 (主要用于第三方中转场景,自动根据模型名"
"分发请求到 openai 或 gemini)"
),
default_value=[],
type=list[ProviderConfig],
)
@ -278,15 +314,21 @@ def register_llm_configs():
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例,不再加载 MCP 工具配置"""
"""获取 LLM 配置实例"""
ai_config = get_ai_config()
raw_debug = ai_config.get("debug_log", False)
if isinstance(raw_debug, bool):
debug_log_val = DebugLogOptions(
show_tools=raw_debug, show_schema=raw_debug, show_safety=raw_debug
)
else:
debug_log_val = raw_debug
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
"timeout": ai_config.get("timeout", 180),
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
"client_settings": ai_config.get("client_settings", {}),
"debug_log": debug_log_val,
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
}
@ -314,14 +356,14 @@ def validate_llm_config() -> tuple[bool, list[str]]:
try:
llm_config = get_llm_config()
if llm_config.timeout <= 0:
if llm_config.client_settings.timeout <= 0:
errors.append("timeout 必须大于 0")
if llm_config.max_retries_llm < 0:
errors.append("max_retries_llm 不能小于 0")
if llm_config.client_settings.max_retries < 0:
errors.append("max_retries 不能小于 0")
if llm_config.retry_delay_llm <= 0:
errors.append("retry_delay_llm 必须大于 0")
if llm_config.client_settings.retry_delay <= 0:
errors.append("retry_delay 必须大于 0")
if not llm_config.providers:
errors.append("至少需要配置一个 AI 服务提供商")


@ -254,7 +254,7 @@ class KeyStats:
if total_calls == 0:
return KeyStatus.UNUSED
if self.success_rate < 80:
if self.success_rate < 70:
return KeyStatus.ERROR
if total_calls >= 5 and self.avg_latency > 15000:
@ -292,96 +292,6 @@ class RetryConfig:
self.key_rotation = key_rotation
async def with_smart_retry(
func,
*args,
retry_config: RetryConfig | None = None,
key_store: "KeyStatusStore | None" = None,
provider_name: str | None = None,
**kwargs: Any,
) -> Any:
"""
智能重试装饰器 - 支持Key轮询和错误分类
参数:
func: 要重试的异步函数
*args: 传递给函数的位置参数
retry_config: 重试配置
key_store: API密钥状态存储
provider_name: 提供商名称
**kwargs: 传递给函数的关键字参数
返回:
Any: 函数执行结果
"""
config = retry_config or RetryConfig()
last_exception: Exception | None = None
failed_keys: set[str] = set()
model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None)
all_provider_keys = model_instance.api_keys if model_instance else []
for attempt in range(config.max_retries + 1):
try:
if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
kwargs["failed_keys"] = failed_keys
start_time = time.monotonic()
result = await func(*args, **kwargs)
latency = (time.monotonic() - start_time) * 1000
if key_store and isinstance(result, tuple) and len(result) == 2:
_, api_key_used = result
if api_key_used:
await key_store.record_success(api_key_used, latency)
return result
else:
return result
except LLMException as e:
last_exception = e
api_key_in_use = e.details.get("api_key")
if api_key_in_use:
failed_keys.add(api_key_in_use)
if key_store and provider_name and len(all_provider_keys) > 1:
status_code = e.details.get("status_code")
error_message = f"({e.code.name}) {e.message}"
await key_store.record_failure(
api_key_in_use, status_code, error_message
)
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
if not should_retry:
logger.error(f"不可重试的错误,停止重试: {e}")
raise
if attempt < config.max_retries:
wait_time = config.retry_delay
if config.exponential_backoff:
wait_time *= 2**attempt
logger.warning(
f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}"
)
await asyncio.sleep(wait_time)
else:
logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
except Exception as e:
last_exception = e
logger.error(f"非LLM异常停止重试: {e}")
raise LLMException(
f"操作失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
if last_exception:
raise last_exception
else:
raise RuntimeError("重试函数未能正常执行且未捕获到异常")
def _should_retry_llm_error(
error: LLMException, attempt: int, max_retries: int
) -> bool:
@ -390,7 +300,9 @@ def _should_retry_llm_error(
LLMErrorCode.MODEL_NOT_FOUND,
LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
LLMErrorCode.USER_LOCATION_NOT_SUPPORTED,
LLMErrorCode.INVALID_PARAMETER,
LLMErrorCode.CONFIGURATION_ERROR,
LLMErrorCode.API_KEY_INVALID,
}
if error.code in non_retryable_errors:
@ -404,15 +316,12 @@ def _should_retry_llm_error(
LLMErrorCode.RESPONSE_PARSE_ERROR,
LLMErrorCode.GENERATION_FAILED,
LLMErrorCode.CONTENT_FILTERED,
LLMErrorCode.API_KEY_INVALID,
LLMErrorCode.API_QUOTA_EXCEEDED,
}
if error.code in retryable_errors:
if error.code == LLMErrorCode.API_QUOTA_EXCEEDED:
return attempt < min(2, max_retries)
elif error.code == LLMErrorCode.CONTENT_FILTERED:
return attempt < min(1, max_retries)
return True
return False
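# A hedged illustration of the policy above: quota errors retry at most twice,
# while invalid keys fail immediately. The LLMException constructor signature
# is assumed from its usage elsewhere in this diff.
quota = LLMException("quota exhausted", code=LLMErrorCode.API_QUOTA_EXCEEDED)
assert _should_retry_llm_error(quota, attempt=0, max_retries=5) is True
assert _should_retry_llm_error(quota, attempt=2, max_retries=5) is False
bad_key = LLMException("bad key", code=LLMErrorCode.API_KEY_INVALID)
assert _should_retry_llm_error(bad_key, attempt=0, max_retries=5) is False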
@ -558,14 +467,68 @@ class KeyStatusStore:
now = time.time()
cooldown_duration = 300
if status_code in [401, 403, 404]:
location_not_supported = error_message and (
"USER_LOCATION_NOT_SUPPORTED" in error_message
or "User location is not supported" in error_message
)
if location_not_supported:
logger.warning(
f"API Key {key_id} 请求失败,原因是地区不支持 (Gemini)。"
" 这通常是代理节点问题Key 本身可能是正常的。跳过冷却。"
)
async with self._lock:
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.failure_count += 1
stats.last_error_info = error_message[:256]
await self._save_to_file_internal()
return
if error_message and (
"API_QUOTA_EXCEEDED" in error_message
or "insufficient_quota" in error_message.lower()
):
cooldown_duration = 3600
logger.warning(f"API Key {key_id} 额度耗尽,冷却 1 小时。")
is_key_invalid = status_code == 401 or (
status_code == 400
and error_message
and (
"API_KEY_INVALID" in error_message
or "API key not valid" in error_message
)
)
if is_key_invalid:
cooldown_duration = 31536000
log_level = "error"
log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
elif status_code == 403:
cooldown_duration = 3600
log_level = "warning"
log_message = f"API密钥权限不足或地区不支持(403)冷却1小时: {key_id}"
elif status_code == 404:
log_level = "error"
log_message = "API请求返回 404 (未找到),可能是模型名称错误或接口地址"
f"错误,不冷却密钥: {key_id}"
elif status_code == 422:
cooldown_duration = 0
log_level = "warning"
log_message = f"API请求无法处理(422),可能是生成故障,不冷却密钥: {key_id}"
elif status_code == 429:
cooldown_duration = 60
log_level = "warning"
log_message = f"API密钥被限流冷却60秒: {key_id}"
elif error_message and (
"ConnectError" in error_message
or "NetworkError" in error_message
or "Connection refused" in error_message
or "RemoteProtocolError" in error_message
or "ProxyError" in error_message
):
cooldown_duration = 0
log_level = "warning"
log_message = f"网络连接层异常(代理/DNS),不冷却密钥: {key_id}"
else:
log_level = "warning"
log_message = f"API密钥遇到临时性错误冷却{cooldown_duration}秒: {key_id}"


@ -1,193 +0,0 @@
"""
LLM 轻量级工具执行器
提供驱动 LLM 与本地函数工具之间交互的核心循环
"""
import asyncio
from enum import Enum
import json
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.pydantic_compat import model_dump
from .service import LLMModel
from .types import (
LLMErrorCode,
LLMException,
LLMMessage,
ToolExecutable,
ToolResult,
)
class ExecutionConfig(BaseModel):
"""
轻量级执行器的配置
"""
max_cycles: int = Field(default=5, description="工具调用循环的最大次数。")
class ToolErrorType(str, Enum):
"""结构化工具错误的类型枚举。"""
TOOL_NOT_FOUND = "ToolNotFound"
INVALID_ARGUMENTS = "InvalidArguments"
EXECUTION_ERROR = "ExecutionError"
USER_CANCELLATION = "UserCancellation"
class ToolErrorResult(BaseModel):
"""一个结构化的工具执行错误模型,用于返回给 LLM。"""
error_type: ToolErrorType = Field(..., description="错误的类型。")
message: str = Field(..., description="对错误的详细描述。")
is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。")
def model_dump(self, **kwargs):
return model_dump(self, **kwargs)
def _is_exception_retryable(e: Exception) -> bool:
"""判断一个异常是否应该触发重试。"""
if isinstance(e, LLMException):
retryable_codes = {
LLMErrorCode.API_REQUEST_FAILED,
LLMErrorCode.API_TIMEOUT,
LLMErrorCode.API_RATE_LIMITED,
}
return e.code in retryable_codes
return True
class LLMToolExecutor:
"""
一个通用的执行器,负责驱动 LLM 与工具之间的多轮交互
"""
def __init__(self, model: LLMModel):
self.model = model
async def run(
self,
messages: list[LLMMessage],
tools: dict[str, ToolExecutable],
config: ExecutionConfig | None = None,
) -> list[LLMMessage]:
"""
执行完整的思考-行动循环
"""
effective_config = config or ExecutionConfig()
execution_history = list(messages)
for i in range(effective_config.max_cycles):
response = await self.model.generate_response(
execution_history, tools=tools
)
assistant_message = LLMMessage(
role="assistant",
content=response.text,
tool_calls=response.tool_calls,
)
execution_history.append(assistant_message)
if not response.tool_calls:
logger.info("✅ LLMToolExecutor模型未请求工具调用执行结束。")
return execution_history
logger.info(
f"🛠️ LLMToolExecutor模型请求并行调用 {len(response.tool_calls)} 个工具"
)
tool_results = await self._execute_tools_parallel_safely(
response.tool_calls,
tools,
)
execution_history.extend(tool_results)
raise LLMException(
f"超过最大工具调用循环次数 ({effective_config.max_cycles})。",
code=LLMErrorCode.GENERATION_FAILED,
)
async def _execute_single_tool_safely(
self, tool_call: Any, available_tools: dict[str, ToolExecutable]
) -> tuple[Any, ToolResult]:
"""安全地执行单个工具调用。"""
tool_name = tool_call.function.name
arguments = {}
try:
if tool_call.function.arguments:
arguments = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
error_result = ToolErrorResult(
error_type=ToolErrorType.INVALID_ARGUMENTS,
message=f"参数解析失败: {e}",
is_retryable=False,
)
return tool_call, ToolResult(output=model_dump(error_result))
try:
executable = available_tools.get(tool_name)
if not executable:
raise LLMException(
f"Tool '{tool_name}' not found.",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
@Retry.simple(
stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None
)
async def execute_with_retry():
return await executable.execute(**arguments)
execution_result = await execute_with_retry()
if execution_result is None:
raise LLMException("工具执行在多次重试后仍然失败。")
return tool_call, execution_result
except Exception as e:
error_type = ToolErrorType.EXECUTION_ERROR
is_retryable = _is_exception_retryable(e)
if (
isinstance(e, LLMException)
and e.code == LLMErrorCode.CONFIGURATION_ERROR
):
error_type = ToolErrorType.TOOL_NOT_FOUND
is_retryable = False
error_result = ToolErrorResult(
error_type=error_type, message=str(e), is_retryable=is_retryable
)
return tool_call, ToolResult(output=model_dump(error_result))
async def _execute_tools_parallel_safely(
self,
tool_calls: list[Any],
available_tools: dict[str, ToolExecutable],
) -> list[LLMMessage]:
"""并行执行所有工具调用,并对每个调用的错误进行隔离。"""
if not tool_calls:
return []
tasks = [
self._execute_single_tool_safely(call, available_tools)
for call in tool_calls
]
results = await asyncio.gather(*tasks)
tool_messages = [
LLMMessage.tool_response(
tool_call_id=original_call.id,
function_name=original_call.function.name,
result=result.output,
)
for original_call, result in results
]
return tool_messages


@ -13,15 +13,19 @@ from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import dump_json_safely
from .config import validate_override_params
from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
from .config.generation import LLMGenerationConfig
from .config.providers import (
AI_CONFIG_GROUP,
PROVIDERS_CONFIG_KEY,
get_ai_config,
get_llm_config,
)
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
from .types.capabilities import get_model_capabilities
DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
TIMEOUT_KEY = "timeout"
_model_cache: dict[str, tuple[LLMModel, float]] = {}
_cache_ttl = 3600
@ -39,7 +43,8 @@ def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str |
def _make_cache_key(
provider_model_name: str | None, override_config: dict | None
provider_model_name: str | None,
override_config: dict | LLMGenerationConfig | None,
) -> str:
"""生成缓存键"""
config_str = (
@ -115,11 +120,12 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
"""根据API类型获取默认的API基础地址"""
default_api_bases = {
"openai": "https://api.openai.com",
"deepseek": "https://api.deepseek.com",
"deepseek": "https://api.deepseek.com/beta",
"zhipu": "https://open.bigmodel.cn",
"gemini": "https://generativelanguage.googleapis.com",
"openrouter": "https://openrouter.ai/api",
"general_openai_compat": None,
"smart": None,
"openai_responses": None,
}
return default_api_bases.get(api_type)
@ -244,7 +250,7 @@ def list_embedding_models() -> list[dict[str, Any]]:
async def get_model_instance(
provider_model_name: str | None = None,
override_config: dict[str, Any] | None = None,
override_config: dict[str, Any] | LLMGenerationConfig | None = None,
) -> LLMModel:
"""
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
@ -303,21 +309,20 @@ async def get_model_instance(
model_detail_found.is_embedding_model = capabilities.is_embedding_model
ai_config = get_ai_config()
global_proxy_setting = ai_config.get(PROXY_KEY)
llm_config = get_llm_config()
client_settings = llm_config.client_settings
default_timeout = (
provider_config_found.timeout
if provider_config_found.timeout is not None
else 180
else client_settings.timeout
)
global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)
config_for_http_client = ProviderConfig(
name=provider_config_found.name,
api_key=provider_config_found.api_key,
models=provider_config_found.models,
timeout=global_timeout_setting,
proxy=global_proxy_setting,
timeout=default_timeout,
proxy=client_settings.proxy,
api_base=provider_config_found.api_base,
api_type=provider_config_found.api_type,
openai_compat=provider_config_found.openai_compat,


@ -1,55 +0,0 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any
from .types import LLMMessage
class BaseMemory(ABC):
"""
记忆系统的抽象基类
定义了任何记忆后端都必须实现的接口
"""
@abstractmethod
async def get_history(self, session_id: str) -> list[LLMMessage]:
"""根据会话ID获取历史记录。"""
raise NotImplementedError
@abstractmethod
async def add_message(self, session_id: str, message: LLMMessage) -> None:
"""向指定会话添加一条消息。"""
raise NotImplementedError
@abstractmethod
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
"""向指定会话添加多条消息。"""
raise NotImplementedError
@abstractmethod
async def clear_history(self, session_id: str) -> None:
"""清空指定会话的历史记录。"""
raise NotImplementedError
class InMemoryMemory(BaseMemory):
"""
一个简单的、默认的内存记忆后端
将历史记录存储在进程内存中的字典里
"""
def __init__(self, **kwargs: Any):
self._history: dict[str, list[LLMMessage]] = defaultdict(list)
async def get_history(self, session_id: str) -> list[LLMMessage]:
return self._history.get(session_id, []).copy()
async def add_message(self, session_id: str, message: LLMMessage) -> None:
self._history[session_id].append(message)
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
self._history[session_id].extend(messages)
async def clear_history(self, session_id: str) -> None:
if session_id in self._history:
del self._history[session_id]

File diff suppressed because it is too large


@ -4,30 +4,34 @@ LLM 服务 - 会话客户端
提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable
import copy
from dataclasses import dataclass, field
import json
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import uuid
from jinja2 import Environment
from nonebot.compat import type_validate_json
from jinja2 import Template
from nonebot.utils import is_coroutine_callable
from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_json_schema
from zhenxun.utils.pydantic_compat import model_json_schema
from .config import (
CommonOverrides,
GenConfigBuilder,
LLMEmbeddingConfig,
LLMGenerationConfig,
)
from .config.providers import get_ai_config
from .config.generation import OutputConfig
from .config.providers import get_ai_config, get_llm_config
from .manager import get_global_default_model_name, get_model_instance
from .memory import BaseMemory, InMemoryMemory
from .tools.manager import tool_provider_manager
from .tools import tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
LLMErrorCode,
LLMException,
@ -35,19 +39,28 @@ from .types import (
LLMResponse,
ModelName,
ResponseFormat,
StructuredOutputStrategy,
ToolChoice,
ToolExecutable,
ToolProvider,
)
from .utils import normalize_to_llm_messages
from .types.models import (
GeminiCodeExecution,
GeminiGoogleSearch,
)
from .utils import (
create_cot_wrapper,
normalize_to_llm_messages,
parse_and_validate_json,
should_apply_autocot,
)
T = TypeVar("T", bound=BaseModel)
jinja_env = Environment(autoescape=False)
@dataclass
class AIConfig:
"""AI配置类 - [重构后] 简化版本"""
"""AI配置类"""
model: ModelName = None
default_embedding_model: ModelName = None
@ -61,6 +74,98 @@ class AIConfig:
self.model = ai_config.get("default_model_name")
class BaseMemory(ABC):
"""记忆系统的抽象基类。"""
@abstractmethod
async def get_history(self, session_id: str) -> list[LLMMessage]:
raise NotImplementedError
@abstractmethod
async def add_message(self, session_id: str, message: LLMMessage) -> None:
raise NotImplementedError
@abstractmethod
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
raise NotImplementedError
@abstractmethod
async def clear_history(self, session_id: str) -> None:
raise NotImplementedError
class InMemoryMemory(BaseMemory):
"""一个简单的、默认的内存记忆后端。"""
def __init__(self, max_messages: int = 50, **kwargs: Any):
self._history: dict[str, list[LLMMessage]] = defaultdict(list)
self._max_messages = max_messages
def _trim_history(self, session_id: str) -> None:
"""修剪历史记录,确保不超过最大长度,同时保留 System Prompt"""
history = self._history[session_id]
if len(history) <= self._max_messages:
return
has_system = history and history[0].role == "system"
if has_system:
keep_count = max(0, self._max_messages - 1)
self._history[session_id] = [history[0], *history[-keep_count:]]
else:
self._history[session_id] = history[-self._max_messages :]
async def get_history(self, session_id: str) -> list[LLMMessage]:
return self._history.get(session_id, []).copy()
async def add_message(self, session_id: str, message: LLMMessage) -> None:
self._history[session_id].append(message)
self._trim_history(session_id)
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
self._history[session_id].extend(messages)
self._trim_history(session_id)
async def clear_history(self, session_id: str) -> None:
if session_id in self._history:
del self._history[session_id]
class MemoryProcessor(ABC):
"""记忆处理器接口"""
@abstractmethod
async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
pass
_default_memory_factory: Callable[[], BaseMemory] | None = None
def set_default_memory_backend(factory: Callable[[], BaseMemory]):
"""
设置全局默认记忆后端工厂,允许统一替换会话的记忆实现
"""
global _default_memory_factory
_default_memory_factory = factory
def _get_default_memory() -> BaseMemory:
if _default_memory_factory:
return _default_memory_factory()
return InMemoryMemory()
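# A hedged sketch of swapping the process-wide default memory backend.
# `CappedMemory` is hypothetical; any BaseMemory implementation (e.g. one
# backed by Redis) could be supplied instead.
class CappedMemory(InMemoryMemory):
    def __init__(self) -> None:
        super().__init__(max_messages=20)

set_default_memory_backend(CappedMemory)
# AI() instances created without an explicit `memory` argument will now
# obtain a CappedMemory via _get_default_memory().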
DEFAULT_IVR_TEMPLATE = (
"你的响应未能通过结构校验。\n"
"错误详情: {error_msg}\n\n"
"请执行以下步骤进行修正:\n"
"1. 反思:分析为什么会出现这个错误。\n"
"2. 修正:生成一个新的、符合 Schema 要求的 JSON 对象。\n"
"请直接输出修正后的 JSON不要包含 Markdown 标记或其他解释。"
)
class AI:
"""
统一的AI服务类 - 提供了带记忆的会话接口
@ -73,6 +178,7 @@ class AI:
config: AIConfig | None = None,
memory: BaseMemory | None = None,
default_generation_config: LLMGenerationConfig | None = None,
processors: list[MemoryProcessor] | None = None,
):
"""
初始化AI服务
@ -81,24 +187,45 @@ class AI:
session_id: 唯一的会话ID,用于隔离记忆
config: AI 配置.
memory: 可选的自定义记忆后端,如果为None则使用默认的InMemoryMemory
default_generation_config: (新增) 此AI实例的默认生成配置
default_generation_config: 此AI实例的默认生成配置
processors: 记忆处理器列表,在添加记忆后触发
"""
self.session_id = session_id or str(uuid.uuid4())
self.config = config or AIConfig()
self.memory = memory or InMemoryMemory()
self.memory = memory or _get_default_memory()
self.default_generation_config = (
default_generation_config or LLMGenerationConfig()
)
self.processors = processors or []
global_providers = tool_provider_manager._providers
config_providers = self.config.tool_providers
self._tool_providers = list(dict.fromkeys(global_providers + config_providers))
self.message_buffer: list[LLMMessage] = []
async def clear_history(self):
"""清空当前会话的历史记录。"""
await self.memory.clear_history(self.session_id)
logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})")
async def add_observation(
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
):
"""
将一条观察消息加入缓冲区,不立即触发模型调用
返回:
int: 缓冲区中消息的数量
"""
current_message = await self._normalize_input_to_message(message)
self.message_buffer.append(current_message)
content_preview = str(current_message.content)[:50]
logger.debug(
f"[放入观察] {content_preview} (缓冲区大小: {len(self.message_buffer)})",
"AI_MEMORY",
)
return len(self.message_buffer)
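# A hedged usage sketch: observations accumulate in the buffer and are flushed
# into a single model call via use_buffer=True; the texts are illustrative.
async def _demo_buffer(ai: "AI") -> None:
    await ai.add_observation("用户A: 晚饭吃什么?")
    await ai.add_observation("用户B: 想吃火锅")
    response = await ai.chat("请总结以上讨论", use_buffer=True)
    print(response.text)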
async def add_user_message_to_history(
self, message: str | LLMMessage | list[LLMContentPart]
):
@ -161,7 +288,7 @@ class AI:
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
) -> LLMMessage:
"""
[重构后] 内部辅助方法将各种输入类型统一转换为单个 LLMMessage 对象
内部辅助方法将各种输入类型统一转换为单个 LLMMessage 对象
它调用共享的工具函数,并提取最后一条消息(通常是用户输入)
"""
messages = await normalize_to_llm_messages(message)
@ -172,17 +299,79 @@ class AI:
)
return messages[-1]
async def generate_internal(
self,
messages: list[LLMMessage],
*,
model: ModelName = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
tools: list[Any] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
timeout: float | None = None,
model_instance: Any = None,
) -> LLMResponse:
"""
内部生成核心方法,负责配置合并、工具解析和模型调用
此方法不处理历史记录的存储,由 AgentExecutor 或 chat 方法调用
"""
final_config = self.default_generation_config
if isinstance(config, GenConfigBuilder):
config = config.build()
if config:
final_config = final_config.merge_with(config)
final_tools_list = []
if tools:
if isinstance(tools, dict):
final_tools_list = list(tools.values())
elif isinstance(tools, list):
to_resolve: list[Any] = []
for t in tools:
if isinstance(t, str | dict):
to_resolve.append(t)
else:
final_tools_list.append(t)
if to_resolve:
resolved_dict = await self._resolve_tools(to_resolve)
final_tools_list.extend(resolved_dict.values())
if model_instance:
return await model_instance.generate_response(
messages,
config=final_config,
tools=final_tools_list if final_tools_list else None,
tool_choice=tool_choice,
timeout=timeout,
)
resolved_model_name = self._resolve_model_name(model or self.config.model)
async with await get_model_instance(
resolved_model_name,
override_config=None,
) as instance:
return await instance.generate_response(
messages,
config=final_config,
tools=final_tools_list if final_tools_list else None,
tool_choice=tool_choice,
timeout=timeout,
)
async def chat(
self,
message: str | UniMessage | LLMMessage | list[LLMContentPart],
message: str | UniMessage | LLMMessage | list[LLMContentPart] | None,
*,
model: ModelName = None,
instruction: str | None = None,
template_vars: dict[str, Any] | None = None,
preserve_media_in_history: bool | None = None,
tools: list[dict[str, Any] | str] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
config: LLMGenerationConfig | None = None,
tools: list[Any] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
use_buffer: bool = False,
timeout: float | None = None,
) -> LLMResponse:
"""
核心交互方法:管理会话历史并执行单次LLM调用
@ -198,18 +387,27 @@ class AI:
tools: 可用的工具列表或工具字典,支持临时工具和预配置工具
tool_choice: 工具选择策略,控制AI如何选择和使用工具
config: 生成配置对象,用于覆盖默认的生成参数
use_buffer: 是否刷新并包含消息缓冲区的内容,在此次对话中一次性提交
timeout: HTTP 请求超时时间
返回:
LLMResponse: 包含AI回复、工具调用请求、使用信息等的完整响应对象
"""
current_message = await self._normalize_input_to_message(message)
messages_to_add: list[LLMMessage] = []
if message:
current_message = await self._normalize_input_to_message(message)
messages_to_add.append(current_message)
if use_buffer and self.message_buffer:
messages_to_add = self.message_buffer + messages_to_add
self.message_buffer.clear()
messages_for_run = []
final_instruction = instruction
if final_instruction and template_vars:
try:
template = jinja_env.from_string(final_instruction)
template = Template(final_instruction)
final_instruction = template.render(**template_vars)
logger.debug(f"渲染后的系统指令: {final_instruction}")
except Exception as e:
@ -220,51 +418,55 @@ class AI:
current_history = await self.memory.get_history(self.session_id)
messages_for_run.extend(current_history)
messages_for_run.append(current_message)
messages_for_run.extend(messages_to_add)
try:
resolved_model_name = self._resolve_model_name(model or self.config.model)
final_config = model_copy(self.default_generation_config, deep=True)
if config:
update_dict = model_dump(config, exclude_unset=True)
final_config = model_copy(final_config, update=update_dict)
ad_hoc_tools = None
if tools:
if isinstance(tools, dict):
ad_hoc_tools = tools
else:
ad_hoc_tools = await self._resolve_tools(tools)
async with await get_model_instance(
resolved_model_name,
override_config=final_config.to_dict(),
) as model_instance:
response = await model_instance.generate_response(
messages_for_run, tools=ad_hoc_tools, tool_choice=tool_choice
)
response = await self.generate_internal(
messages_for_run,
model=model,
config=config,
tools=tools,
tool_choice=tool_choice,
timeout=timeout,
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
user_msg_to_store = (
current_message
if should_preserve
else self._sanitize_message_for_history(current_message)
)
assistant_response_msg = LLMMessage.assistant_text_response(response.text)
if response.tool_calls:
assistant_response_msg = LLMMessage.assistant_tool_calls(
response.tool_calls, response.text
msgs_to_store: list[LLMMessage] = []
for msg in messages_to_add:
store_msg = (
msg if should_preserve else self._sanitize_message_for_history(msg)
)
msgs_to_store.append(store_msg)
if response.content_parts:
assistant_response_msg = LLMMessage(
role="assistant",
content=response.content_parts,
tool_calls=response.tool_calls,
)
else:
assistant_response_msg = LLMMessage.assistant_text_response(
response.text
)
if response.tool_calls:
assistant_response_msg = LLMMessage.assistant_tool_calls(
response.tool_calls, response.text
)
await self.memory.add_messages(
self.session_id, [user_msg_to_store, assistant_response_msg]
self.session_id, [*msgs_to_store, assistant_response_msg]
)
if self.processors:
for processor in self.processors:
await processor.process(
self.session_id, [*msgs_to_store, assistant_response_msg]
)
return response
except Exception as e:
@ -280,7 +482,7 @@ class AI:
*,
model: ModelName = None,
timeout: int | None = None,
config: LLMGenerationConfig | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
代码执行
@ -294,16 +496,18 @@ class AI:
返回:
LLMResponse: 包含执行结果的完整响应对象
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
resolved_model = model or self.config.model
code_config = CommonOverrides.gemini_code_execution()
if timeout:
code_config.custom_params = code_config.custom_params or {}
code_config.custom_params["code_execution_timeout"] = timeout
if isinstance(config, GenConfigBuilder):
config = config.build()
if config:
update_dict = model_dump(config, exclude_unset=True)
code_config = model_copy(code_config, update=update_dict)
code_config = code_config.merge_with(config)
return await self.chat(prompt, model=resolved_model, config=code_config)
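# A hedged usage sketch for the code-execution entry point above; the prompt
# and timeout are illustrative, and the resolved model must support Gemini
# code execution.
async def _demo_code(ai: "AI") -> None:
    response = await ai.code("用 Python 计算 2**32 并输出结果", timeout=60)
    print(response.text)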
@ -317,7 +521,7 @@ class AI:
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
template_vars: dict[str, Any] | None = None,
config: LLMGenerationConfig | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
信息搜索的便捷入口,原生支持多模态查询
@ -325,9 +529,11 @@ class AI:
logger.info("执行 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
if isinstance(config, GenConfigBuilder):
config = config.build()
if config:
update_dict = model_dump(config, exclude_unset=True)
search_config = model_copy(search_config, update=update_dict)
search_config = search_config.merge_with(config)
return await self.chat(
query,
@ -339,21 +545,31 @@ class AI:
async def generate_structured(
self,
message: str | LLMMessage | list[LLMContentPart],
message: str | LLMMessage | list[LLMContentPart] | None,
response_model: type[T],
*,
model: ModelName = None,
tools: list[Any] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
instruction: str | None = None,
config: LLMGenerationConfig | None = None,
timeout: float | None = None,
template_vars: dict[str, Any] | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
max_validation_retries: int | None = None,
validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
error_prompt_template: str | None = None,
auto_thinking: bool = False,
) -> T:
"""
生成结构化响应并自动解析为指定的Pydantic模型
参数:
message: 用户输入的消息内容,支持多种格式
message: 用户输入的消息内容,支持多种格式,为None时只使用历史+缓冲区
response_model: 用于解析和验证响应的Pydantic模型类
model: 要使用的模型名称,如果为None则使用配置中的默认模型
instruction: 本次调用的特定系统指令,会与JSON Schema指令合并
timeout: HTTP 请求超时时间
template_vars: 系统指令中的模板变量,用于动态渲染
config: 生成配置对象,用于覆盖默认的生成参数
返回:
@ -362,6 +578,46 @@ class AI:
异常:
LLMException: 如果模型返回的不是有效的JSON或验证失败
"""
if isinstance(config, GenConfigBuilder):
config = config.build()
final_config = self.default_generation_config.merge_with(config)
if final_config is None:
final_config = LLMGenerationConfig()
if max_validation_retries is None:
max_validation_retries = get_llm_config().client_settings.structured_retries
resolved_model_name = self._resolve_model_name(model or self.config.model)
request_autocot = True if auto_thinking is False else auto_thinking
effective_auto_thinking = should_apply_autocot(
request_autocot, resolved_model_name, final_config
)
target_model: type[T] = response_model
if effective_auto_thinking:
target_model = cast(type[T], create_cot_wrapper(response_model))
response_model = target_model
cot_instruction = (
"请务必先在 `reasoning` 字段中进行详细的一步步推理,确保逻辑正确,"
"然后再填充 `result` 字段。"
)
if instruction:
instruction = f"{instruction}\n\n{cot_instruction}"
else:
instruction = cot_instruction
final_instruction = instruction
if final_instruction and template_vars:
try:
template = Template(final_instruction)
final_instruction = template.render(**template_vars)
except Exception as e:
logger.error(f"渲染结构化指令模板失败: {e}", e=e)
try:
json_schema = model_json_schema(response_model)
except AttributeError:
@ -369,41 +625,149 @@ class AI:
schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
system_prompt = (
(f"{instruction}\n\n" if instruction else "")
+ "你必须严格按照以下 JSON Schema 格式进行响应。"
+ "不要包含任何额外的解释、注释或代码块标记,只返回纯粹的 JSON 对象。\n\n"
prompt_prefix = f"{final_instruction}\n\n" if final_instruction else ""
structured_strategy = (
final_config.output.structured_output_strategy
if final_config.output
else None
)
system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
if structured_strategy == StructuredOutputStrategy.TOOL_CALL:
system_prompt = prompt_prefix + "请调用提供的工具提交结构化数据。"
else:
system_prompt = (
prompt_prefix
+ "请严格按照以下 JSON Schema 格式进行响应。不应包含任何额外的解释、"
"注释或代码块标记,只返回一个合法的 JSON 对象。\n\n"
)
system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
final_config = model_copy(config) if config else LLMGenerationConfig()
final_config.response_format = ResponseFormat.JSON
final_config.response_schema = json_schema
response = await self.chat(
message, model=model, instruction=system_prompt, config=final_config
structured_strategy = (
final_config.output.structured_output_strategy
if final_config.output
else StructuredOutputStrategy.NATIVE
)
try:
return type_validate_json(response_model, response.text)
except ValidationError as e:
logger.error(f"LLM结构化输出验证失败: {e}", e=e)
raise LLMException(
"LLM返回的JSON未能通过结构验证。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": response.text, "validation_error": str(e)},
cause=e,
)
except Exception as e:
logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
raise LLMException(
"解析LLM的JSON输出时失败。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": response.text},
cause=e,
final_tools_list: list[ToolExecutable] | None = None
if structured_strategy != StructuredOutputStrategy.NATIVE:
if tools:
final_tools_list = []
if isinstance(tools, dict):
final_tools_list = list(tools.values())
elif isinstance(tools, list):
to_resolve: list[Any] = []
for t in tools:
if isinstance(t, str | dict):
to_resolve.append(t)
else:
final_tools_list.append(t)
if to_resolve:
resolved_dict = await self._resolve_tools(to_resolve)
final_tools_list.extend(resolved_dict.values())
elif tools:
logger.warning(
"检测到在 generate_structured (NATIVE 策略) 中传入了 tools。"
"为了避免 API 冲突(Gemini)及输出歧义(OpenAI),这些"
"tools 将被本次请求忽略。"
"若需使用工具,请使用 chat() 方法或 Agent 流程。"
)
if final_config.output is None:
final_config.output = OutputConfig()
final_config.output.response_format = ResponseFormat.JSON
final_config.output.response_schema = json_schema
messages_for_run = [LLMMessage.system(system_prompt)]
current_history = await self.memory.get_history(self.session_id)
messages_for_run.extend(current_history)
messages_for_run.extend(self.message_buffer)
if message:
normalized_message = await self._normalize_input_to_message(message)
messages_for_run.append(normalized_message)
ivr_messages = list(messages_for_run)
last_exception: Exception | None = None
for attempt in range(max_validation_retries + 1):
current_response_text: str = ""
async with await get_model_instance(
resolved_model_name,
override_config=None,
) as model_instance:
response = await model_instance.generate_response(
ivr_messages,
config=final_config,
tools=final_tools_list if final_tools_list else None,
tool_choice=tool_choice,
timeout=timeout,
)
current_response_text = response.text
try:
parsed_obj = parse_and_validate_json(response.text, target_model)
final_obj: T = cast(T, parsed_obj)
if effective_auto_thinking:
logger.debug(
f"AutoCoT 思考过程: {getattr(parsed_obj, 'reasoning', '')}"
)
final_obj = cast(T, getattr(parsed_obj, "result"))
if validation_callback:
if is_coroutine_callable(validation_callback):
await validation_callback(final_obj)
else:
validation_callback(final_obj)
return final_obj
except Exception as e:
is_llm_error = isinstance(e, LLMException)
llm_error: LLMException | None = (
cast(LLMException, e) if is_llm_error else None
)
last_exception = e
if attempt < max_validation_retries:
error_msg = (
llm_error.details.get("validation_error", str(e))
if llm_error
else str(e)
)
raw_response = current_response_text or (
llm_error.details.get("raw_response", "") if llm_error else ""
)
logger.warning(
f"结构化校验失败 (尝试 {attempt + 1}/"
f"{max_validation_retries + 1})。正在尝试 IVR 修复... 错误:"
f"{error_msg}"
)
if raw_response:
ivr_messages.append(
LLMMessage.assistant_text_response(raw_response)
)
else:
logger.warning(
"IVR 警告: 无法获取上一轮生成的原始文本,"
"模型将在无上下文情况下尝试修复。"
)
template = error_prompt_template or DEFAULT_IVR_TEMPLATE
feedback_prompt = template.format(error_msg=error_msg)
ivr_messages.append(LLMMessage.user(feedback_prompt))
continue
if llm_error and not llm_error.recoverable:
raise llm_error
if last_exception:
raise last_exception
raise LLMException(
"IVR 循环异常结束,未能生成有效结果。", code=LLMErrorCode.GENERATION_FAILED
)
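# A hedged usage sketch for generate_structured: a small Pydantic model plus an
# optional validation callback that feeds the IVR retry loop on failure. Field
# names and the prompt are illustrative.
class Movie(BaseModel):
    title: str
    year: int

def _check_year(movie: Movie) -> None:
    if movie.year < 1888:
        raise ValueError("year predates cinema")

async def _demo_structured(ai: "AI") -> None:
    movie = await ai.generate_structured(
        "推荐一部电影", Movie, validation_callback=_check_year
    )
    print(movie.title, movie.year)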
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
if model_name:
@ -423,8 +787,7 @@ class AI:
texts: list[str] | str,
*,
model: ModelName = None,
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
config: LLMEmbeddingConfig | None = None,
) -> list[list[float]]:
"""
生成文本嵌入向量,将文本转换为数值向量表示
@ -432,14 +795,13 @@ class AI:
参数:
texts: 要生成嵌入的文本内容,支持单个字符串或字符串列表
model: 嵌入模型名称,如果为None则使用配置中的默认嵌入模型
task_type: 嵌入任务类型,影响向量的优化方向(如检索、分类等)
**kwargs: 传递给嵌入模型的额外参数
config: 嵌入配置
返回:
list[list[float]]: 文本对应的嵌入向量列表,每个向量为浮点数列表
异常:
LLMException: 如果嵌入生成失败或模型配置错误
LLMException: 当嵌入生成失败或模型配置错误时抛出
"""
if isinstance(texts, str):
texts = [texts]
@ -452,18 +814,20 @@ class AI:
)
if not resolved_model_str:
raise LLMException(
"使用 embed 功能时必须指定嵌入模型名称,"
"或在 AIConfig 中配置 default_embedding_model。",
"使用 embed 方法时未指定嵌入模型名称,"
"且 AIConfig 未设置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
final_config = config or LLMEmbeddingConfig()
async with await get_model_instance(
resolved_model_str,
override_config=None,
) as embedding_model_instance:
return await embedding_model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
texts, config=final_config
)
except LLMException:
raise
@ -484,6 +848,15 @@ class AI:
resolved: dict[str, ToolExecutable] = {}
for config in tool_configs:
if isinstance(config, str):
if config == "google_search":
resolved[config] = GeminiGoogleSearch() # type: ignore[arg-type]
continue
elif config == "code_execution":
resolved[config] = GeminiCodeExecution() # type: ignore[arg-type]
continue
elif config == "url_context":
pass
name = config if isinstance(config, str) else config.get("name")
if not name:
raise LLMException(


@ -0,0 +1,839 @@
"""
工具模块
整合了工具参数解析器、工具提供者管理器与工具执行逻辑,便于在 LLM 服务层统一调用
"""
import asyncio
from collections.abc import Callable
from enum import Enum
import inspect
import json
import re
import time
from typing import (
Annotated,
Any,
Optional,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)
from typing_extensions import override
from httpx import NetworkError, TimeoutException
try:
import ujson as fast_json
except ImportError:
fast_json = json
import nonebot
from nonebot.dependencies import Dependent, Param
from nonebot.internal.adapter import Bot, Event
from nonebot.internal.params import (
BotParam,
DefaultParam,
DependParam,
DependsInner,
EventParam,
StateParam,
)
from pydantic import BaseModel, Field, ValidationError, create_model
from pydantic.fields import FieldInfo
from zhenxun.services.log import logger
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.pydantic_compat import model_dump, model_fields, model_json_schema
from .types import (
LLMErrorCode,
LLMException,
LLMMessage,
LLMToolCall,
ToolExecutable,
ToolProvider,
ToolResult,
)
from .types.models import ToolDefinition
from .types.protocols import BaseCallbackHandler, ToolCallData
class ToolParam(Param):
"""
工具参数提取器
用于在自定义工具函数(Function Tool)中,从 LLM 解析出的参数字典
(`state["_tool_params"]`)
中提取特定的参数值,通常配合 `Annotated` 和依赖注入系统使用
"""
def __init__(self, *args: Any, name: str, **kwargs: Any):
super().__init__(*args, **kwargs)
self.name = name
def __repr__(self) -> str:
return f"ToolParam(name={self.name})"
@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional["ToolParam"]:
if param.default is not inspect.Parameter.empty and isinstance(
param.default, DependsInner
):
return None
if get_origin(param.annotation) is Annotated:
for arg in get_args(param.annotation):
if isinstance(arg, DependsInner):
return None
if param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
return cls(name=param.name)
return None
@override
async def _solve(self, **kwargs: Any) -> Any:
state: dict[str, Any] = kwargs.get("state", {})
tool_params = state.get("_tool_params", {})
if self.name in tool_params:
return tool_params[self.name]
return None
class RunContext(BaseModel):
"""
依赖注入容器(DI Container),保留原有上下文信息的同时提升获取类型的能力
"""
session_id: str | None = None
scope: dict[str, Any] = Field(default_factory=dict)
extra: dict[str, Any] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
class RunContextParam(Param):
"""自动注入 RunContext 的参数解析器"""
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional["RunContextParam"]:
if param.annotation is RunContext:
return cls()
return None
async def _solve(self, **kwargs: Any) -> Any:
state = kwargs.get("state", {})
return state.get("_agent_context")
def _parse_docstring_params(docstring: str | None) -> dict[str, str]:
"""
解析文档字符串提取参数描述
支持 Google Style (Args:), ReST Style (:param:), 和中文风格 (参数:)
"""
if not docstring:
return {}
params: dict[str, str] = {}
lines = docstring.splitlines()
rest_pattern = re.compile(r"[:@]param\s+(\w+)\s*:?\s*(.*)")
found_rest = False
for line in lines:
match = rest_pattern.search(line)
if match:
params[match.group(1)] = match.group(2).strip()
found_rest = True
if found_rest:
return params
section_header_pattern = re.compile(
r"^\s*(?:Args|Arguments|Parameters|参数)\s*[:]\s*$"
)
param_section_active = False
google_pattern = re.compile(r"^\s*(\**\w+)(?:\s*\(.*?\))?\s*[:：]\s*(.*)")
for line in lines:
stripped_line = line.strip()
if not stripped_line:
continue
if section_header_pattern.match(line):
param_section_active = True
continue
if param_section_active:
if (
stripped_line.endswith(":") or stripped_line.endswith("")
) and not google_pattern.match(line):
param_section_active = False
continue
match = google_pattern.match(line)
if match:
name = match.group(1).lstrip("*")
desc = match.group(2).strip()
params[name] = desc
return params
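# A hedged example of the parser above: a Google-style Args section yields a
# name -> description mapping.
_doc = """Query the weather.

Args:
    city: Target city name.
    days: Forecast length in days.
"""
assert _parse_docstring_params(_doc) == {
    "city": "Target city name.",
    "days": "Forecast length in days.",
}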
def _create_dynamic_model(func: Callable) -> type[BaseModel]:
"""根据函数签名动态创建 Pydantic 模型"""
sig = inspect.signature(func)
doc_params = _parse_docstring_params(func.__doc__)
type_hints = get_type_hints(func, include_extras=True)
fields = {}
for name, param in sig.parameters.items():
if name in ("self", "cls"):
continue
annotation = type_hints.get(name, Any)
default = param.default
is_run_context = False
if annotation is RunContext:
is_run_context = True
else:
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
if RunContext in args:
is_run_context = True
if is_run_context:
continue
if default is not inspect.Parameter.empty and isinstance(default, DependsInner):
continue
if get_origin(annotation) is Annotated:
args = get_args(annotation)
if any(isinstance(arg, DependsInner) for arg in args):
continue
description = doc_params.get(name)
if isinstance(default, FieldInfo):
if description and not getattr(default, "description", None):
default.description = description
fields[name] = (annotation, default)
else:
if default is inspect.Parameter.empty:
default = ...
fields[name] = (annotation, Field(default, description=description))
return create_model(f"{func.__name__}Params", **fields)
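# A hedged example for _create_dynamic_model: a plain signature becomes a
# Pydantic model whose JSON schema can be handed to the LLM; the function
# below is illustrative.
def _weather(city: str, days: int = 1) -> str:
    """查询天气。

    参数:
        city: 目标城市
        days: 预报天数
    """
    return f"{city} 未来 {days} 天预报"

_WeatherParams = _create_dynamic_model(_weather)
# model_json_schema(_WeatherParams)["properties"] now describes "city" and
# "days", with descriptions pulled from the docstring and "city" required.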
class FunctionExecutable(ToolExecutable):
"""一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。"""
def __init__(
self,
func: Callable,
name: str,
description: str,
params_model: type[BaseModel] | None = None,
unpack_args: bool = False,
):
self._func = func
self._name = name
self._description = description
self._params_model = params_model
self._unpack_args = unpack_args
self.dependent = Dependent[Any].parse(
call=func,
allow_types=(
DependParam,
BotParam,
EventParam,
StateParam,
RunContextParam,
ToolParam,
DefaultParam,
),
)
async def get_definition(self) -> ToolDefinition:
if not self._params_model:
return ToolDefinition(
name=self._name,
description=self._description,
parameters={"type": "object", "properties": {}},
)
schema = model_json_schema(self._params_model)
return ToolDefinition(
name=self._name,
description=self._description,
parameters={
"type": "object",
"properties": schema.get("properties", {}),
"required": schema.get("required", []),
},
)
async def execute(
self, context: RunContext | None = None, **kwargs: Any
) -> ToolResult:
context = context or RunContext()
tool_arguments = kwargs
if self._params_model:
try:
_fields = model_fields(self._params_model)
validation_input = {
key: value for key, value in kwargs.items() if key in _fields
}
validated_params = self._params_model(**validation_input)
if self._unpack_args:
tool_arguments = model_dump(validated_params)
except ValidationError as e:
error_msgs = []
for err in e.errors():
loc = ".".join(str(x) for x in err["loc"])
msg = err["msg"]
error_msgs.append(f"Parameter '{loc}': {msg}")
formatted_error = "; ".join(error_msgs)
error_payload = {
"error_type": "InvalidArguments",
"message": f"Parameter validation failed: {formatted_error}",
"is_retryable": True,
}
return ToolResult(
output=json.dumps(error_payload, ensure_ascii=False),
display_content=f"Validation Error: {formatted_error}",
)
except Exception as e:
logger.error(
f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
)
raise
state = {
"_tool_params": tool_arguments,
"_agent_context": context,
}
bot: Bot | None = None
if context and context.scope.get("bot"):
bot = context.scope.get("bot")
if not bot:
try:
bot = nonebot.get_bot()
except ValueError:
pass
event: Event | None = None
if context and context.scope.get("event"):
event = context.scope.get("event")
raw_result = await self.dependent(
bot=bot,
event=event,
state=state,
)
return ToolResult(output=raw_result, display_content=str(raw_result))
class BuiltinFunctionToolProvider(ToolProvider):
"""一个内置的 ToolProvider用于处理通过装饰器注册的函数。"""
def __init__(self):
self._functions: dict[str, dict[str, Any]] = {}
def register(
self,
name: str,
func: Callable,
description: str,
params_model: type[BaseModel] | None = None,
unpack_args: bool = False,
):
self._functions[name] = {
"func": func,
"description": description,
"params_model": params_model,
"unpack_args": unpack_args,
}
async def initialize(self) -> None:
pass
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
executables = {}
for name, info in self._functions.items():
executables[name] = FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
unpack_args=info.get("unpack_args", False),
)
return executables
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
if config.get("type", "function") == "function" and name in self._functions:
info = self._functions[name]
return FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
unpack_args=info.get("unpack_args", False),
)
return None
class ToolProviderManager:
"""工具提供者的中心化管理器,采用单例模式。"""
_instance: "ToolProviderManager | None" = None
def __new__(cls) -> "ToolProviderManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, "_initialized") and self._initialized:
return
self._providers: list[ToolProvider] = []
self._resolved_tools: dict[str, ToolExecutable] | None = None
self._init_lock = asyncio.Lock()
self._init_promise: asyncio.Task | None = None
self._builtin_function_provider = BuiltinFunctionToolProvider()
self.register(self._builtin_function_provider)
self._initialized = True
def register(self, provider: ToolProvider):
"""注册一个新的 ToolProvider。"""
if provider not in self._providers:
self._providers.append(provider)
logger.info(f"已注册工具提供者: {provider.__class__.__name__}")
def function_tool(
self,
name: str,
description: str,
params_model: type[BaseModel] | None = None,
):
"""装饰器:将一个函数注册为内置工具。"""
def decorator(func: Callable):
if name in self._builtin_function_provider._functions:
logger.warning(f"正在覆盖已注册的函数工具: {name}")
final_model = params_model
unpack_args = False
if final_model is None:
final_model = _create_dynamic_model(func)
unpack_args = True
self._builtin_function_provider.register(
name=name,
func=func,
description=description,
params_model=final_model,
unpack_args=unpack_args,
)
logger.info(f"已注册函数工具: '{name}'")
return func
return decorator
async def initialize(self) -> None:
"""懒加载初始化所有已注册的 ToolProvider。"""
if not self._init_promise:
async with self._init_lock:
if not self._init_promise:
self._init_promise = asyncio.create_task(
self._initialize_providers()
)
await self._init_promise
async def _initialize_providers(self) -> None:
"""内部初始化逻辑。"""
logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...")
init_tasks = [provider.initialize() for provider in self._providers]
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.info("所有工具提供者初始化完成。")
async def get_resolved_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
获取所有已发现和解析的工具
此方法会触发懒加载初始化并根据是否传入过滤器来决定是否使用全局缓存
"""
await self.initialize()
has_filters = allowed_servers is not None or excluded_servers is not None
if not has_filters and self._resolved_tools is not None:
logger.debug("使用全局工具缓存。")
return self._resolved_tools
if has_filters:
logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
logger.debug(
f"过滤器详情: allowed_servers={allowed_servers}, "
f"excluded_servers={excluded_servers}"
)
else:
logger.info("未应用过滤器,开始全局工具发现...")
all_tools: dict[str, ToolExecutable] = {}
discover_tasks = []
for provider in self._providers:
sig = inspect.signature(provider.discover_tools)
params_to_pass = {}
if "allowed_servers" in sig.parameters:
params_to_pass["allowed_servers"] = allowed_servers
if "excluded_servers" in sig.parameters:
params_to_pass["excluded_servers"] = excluded_servers
discover_tasks.append(provider.discover_tools(**params_to_pass))
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
for i, provider_result in enumerate(results):
provider_name = self._providers[i].__class__.__name__
if isinstance(provider_result, dict):
logger.debug(
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
)
for name, executable in provider_result.items():
if name in all_tools:
logger.warning(
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
)
all_tools[name] = executable
elif isinstance(provider_result, Exception):
logger.error(
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
)
if not has_filters:
self._resolved_tools = all_tools
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
else:
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
return all_tools
async def resolve_specific_tools(
self, tool_names: list[str]
) -> dict[str, ToolExecutable]:
"""
仅解析指定名称的工具避免触发全量工具发现
"""
resolved: dict[str, ToolExecutable] = {}
if not tool_names:
return resolved
await self.initialize()
for name in tool_names:
config: dict[str, Any] = {"name": name}
for provider in self._providers:
try:
executable = await provider.get_tool_executable(name, config)
except Exception as exc:
logger.error(
f"provider '{provider.__class__.__name__}' 在解析工具 '{name}'"
f"时出错: {exc}",
e=exc,
)
continue
if executable:
resolved[name] = executable
break
else:
logger.warning(f"没有找到名为 '{name}' 的工具,已跳过。")
return resolved
async def get_function_tools(
self, names: list[str] | None = None
) -> dict[str, ToolExecutable]:
"""
仅从内置的函数提供者中解析指定的工具
"""
all_function_tools = await self._builtin_function_provider.discover_tools()
if names is None:
return all_function_tools
resolved_tools = {}
for name in names:
if name in all_function_tools:
resolved_tools[name] = all_function_tools[name]
else:
logger.warning(
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
)
return resolved_tools
tool_provider_manager = ToolProviderManager()
function_tool = tool_provider_manager.function_tool
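# Usage sketch (illustrative; "echo_text" is a hypothetical tool, and its
# params model is auto-generated by _create_dynamic_model above):
async def _example_function_tool_usage():  # pragma: no cover - illustrative
    @function_tool(name="echo_text", description="Echo the given text back")
    async def echo_text(text: str) -> str:
        return text

    # Resolve only locally registered function tools (no global discovery).
    tools = await tool_provider_manager.get_function_tools(["echo_text"])
    definition = await tools["echo_text"].get_definition()
    # definition.parameters is a JSON schema with a required "text" property.
    return definition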
class ToolErrorType(str, Enum):
"""结构化工具错误的类型枚举。"""
TOOL_NOT_FOUND = "ToolNotFound"
INVALID_ARGUMENTS = "InvalidArguments"
EXECUTION_ERROR = "ExecutionError"
USER_CANCELLATION = "UserCancellation"
class ToolErrorResult(BaseModel):
"""一个结构化的工具执行错误模型。"""
error_type: ToolErrorType = Field(..., description="错误的类型。")
message: str = Field(..., description="对错误的详细描述。")
is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。")
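# Example serialized payload handed back to the model on failure (shape per
# the fields above; the message text is illustrative):
#   {"error_type": "InvalidArguments",
#    "message": "Parameter 'city': field required",
#    "is_retryable": true}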
class ToolInvoker:
"""
全能工具执行器
负责接收工具调用请求解析参数触发回调执行工具并返回标准化的结果
"""
def __init__(self, callbacks: list[BaseCallbackHandler] | None = None):
self.callbacks = callbacks or []
async def _trigger_callbacks(self, event_name: str, *args, **kwargs: Any) -> None:
if not self.callbacks:
return
tasks = [
getattr(handler, event_name)(*args, **kwargs)
for handler in self.callbacks
if hasattr(handler, event_name)
]
await asyncio.gather(*tasks, return_exceptions=True)
async def execute_tool_call(
self,
tool_call: LLMToolCall,
available_tools: dict[str, ToolExecutable],
context: Any | None = None,
) -> tuple[LLMToolCall, ToolResult]:
tool_name = tool_call.function.name
arguments_str = tool_call.function.arguments
arguments: dict[str, Any] = {}
try:
if arguments_str:
arguments = json.loads(arguments_str)
except json.JSONDecodeError as e:
error_result = ToolErrorResult(
error_type=ToolErrorType.INVALID_ARGUMENTS,
message=f"参数解析失败: {e}",
is_retryable=False,
)
return tool_call, ToolResult(output=model_dump(error_result))
tool_data = ToolCallData(tool_name=tool_name, tool_args=arguments)
pre_calculated_result: ToolResult | None = None
for handler in self.callbacks:
res = await handler.on_tool_start(tool_call, tool_data)
if isinstance(res, ToolCallData):
tool_data = res
arguments = tool_data.tool_args
tool_call.function.arguments = json.dumps(arguments, ensure_ascii=False)
elif isinstance(res, ToolResult):
pre_calculated_result = res
break
if pre_calculated_result:
return tool_call, pre_calculated_result
executable = available_tools.get(tool_name)
if not executable:
error_result = ToolErrorResult(
error_type=ToolErrorType.TOOL_NOT_FOUND,
message=f"Tool '{tool_name}' not found.",
is_retryable=False,
)
return tool_call, ToolResult(output=model_dump(error_result))
from .config.providers import get_llm_config
if not get_llm_config().debug_log:
try:
definition = await executable.get_definition()
schema_payload = getattr(definition, "parameters", {})
schema_json = fast_json.dumps(
schema_payload,
ensure_ascii=False,
)
logger.debug(
f"🔍 [JIT Schema] {tool_name}: {schema_json}",
"ToolInvoker",
)
except Exception as e:
logger.trace(f"JIT Schema logging failed: {e}")
start_t = time.monotonic()
result: ToolResult | None = None
error: Exception | None = None
try:
@Retry.simple(stop_max_attempt=2, wait_fixed_seconds=1)
async def execute_with_retry():
return await executable.execute(context=context, **arguments)
result = await execute_with_retry()
except ValidationError as e:
error = e
error_msgs = []
for err in e.errors():
loc = ".".join(str(x) for x in err["loc"])
msg = err["msg"]
error_msgs.append(f"参数 '{loc}': {msg}")
formatted_error = "; ".join(error_msgs)
error_result = ToolErrorResult(
error_type=ToolErrorType.INVALID_ARGUMENTS,
message=f"参数验证失败。请根据错误修正你的输入: {formatted_error}",
is_retryable=True,
)
result = ToolResult(output=model_dump(error_result))
except (TimeoutException, NetworkError) as e:
error = e
error_result = ToolErrorResult(
error_type=ToolErrorType.EXECUTION_ERROR,
message=f"工具执行网络超时或连接失败: {e!s}",
is_retryable=False,
)
result = ToolResult(output=model_dump(error_result))
except Exception as e:
error = e
error_type = ToolErrorType.EXECUTION_ERROR
if (
isinstance(e, LLMException)
and e.code == LLMErrorCode.CONFIGURATION_ERROR
):
error_type = ToolErrorType.TOOL_NOT_FOUND
is_retryable = False
error_result = ToolErrorResult(
error_type=error_type, message=str(e), is_retryable=is_retryable
)
result = ToolResult(output=model_dump(error_result))
duration = time.monotonic() - start_t
await self._trigger_callbacks(
"on_tool_end",
result=result,
error=error,
tool_call=tool_call,
duration=duration,
)
if result is None:
raise LLMException("工具执行未返回任何结果。")
return tool_call, result
async def execute_batch(
self,
tool_calls: list[LLMToolCall],
available_tools: dict[str, ToolExecutable],
context: Any | None = None,
) -> list[LLMMessage]:
if not tool_calls:
return []
tasks = [
self.execute_tool_call(call, available_tools, context)
for call in tool_calls
]
results = await asyncio.gather(*tasks, return_exceptions=True)
tool_messages: list[LLMMessage] = []
for index, result_pair in enumerate(results):
original_call = tool_calls[index]
if isinstance(result_pair, Exception):
logger.error(
f"工具执行发生未捕获异常: {original_call.function.name}, "
f"错误: {result_pair}"
)
tool_messages.append(
LLMMessage.tool_response(
tool_call_id=original_call.id,
function_name=original_call.function.name,
result={
"error": f"System Execution Error: {result_pair}",
"status": "failed",
},
)
)
continue
tool_call_result = cast(tuple[LLMToolCall, ToolResult], result_pair)
_, tool_result = tool_call_result
tool_messages.append(
LLMMessage.tool_response(
tool_call_id=original_call.id,
function_name=original_call.function.name,
result=tool_result.output,
)
)
return tool_messages
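# Batch-execution sketch (illustrative; assumes LLMToolFunction is importable
# in this module's scope and that "echo_text" was registered via
# @function_tool):
async def _example_execute_batch():  # pragma: no cover - illustrative
    invoker = ToolInvoker()
    call = LLMToolCall(
        id="call_1",
        function=LLMToolFunction(name="echo_text", arguments='{"text": "hi"}'),
    )
    tools = await tool_provider_manager.get_function_tools(["echo_text"])
    # Returns one role="tool" LLMMessage per call; errors come back as
    # structured payloads rather than raised exceptions.
    return await invoker.execute_batch([call], tools)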
__all__ = [
"RunContext",
"RunContextParam",
"ToolErrorResult",
"ToolErrorType",
"ToolInvoker",
"ToolParam",
"function_tool",
"tool_provider_manager",
]

View File

@ -1,13 +0,0 @@
"""
工具模块导出
"""
from .manager import tool_provider_manager
function_tool = tool_provider_manager.function_tool
__all__ = [
"function_tool",
"tool_provider_manager",
]

View File

@ -1,293 +0,0 @@
"""
工具提供者管理器
负责注册生命周期管理包括懒加载和统一提供所有工具
"""
import asyncio
from collections.abc import Callable
import inspect
from typing import Any
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_json_schema
from ..types import ToolExecutable, ToolProvider
from ..types.models import ToolDefinition, ToolResult
class FunctionExecutable(ToolExecutable):
"""一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。"""
def __init__(
self,
func: Callable,
name: str,
description: str,
params_model: type[BaseModel] | None,
):
self._func = func
self._name = name
self._description = description
self._params_model = params_model
async def get_definition(self) -> ToolDefinition:
if not self._params_model:
return ToolDefinition(
name=self._name,
description=self._description,
parameters={"type": "object", "properties": {}},
)
schema = model_json_schema(self._params_model)
return ToolDefinition(
name=self._name,
description=self._description,
parameters={
"type": "object",
"properties": schema.get("properties", {}),
"required": schema.get("required", []),
},
)
async def execute(self, **kwargs: Any) -> ToolResult:
raw_result: Any
if self._params_model:
try:
params_instance = self._params_model(**kwargs)
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(params_instance)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(params_instance)
)
except Exception as e:
logger.error(
f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
)
raise
else:
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(**kwargs)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(**kwargs)
)
return ToolResult(output=raw_result, display_content=str(raw_result))
class BuiltinFunctionToolProvider(ToolProvider):
"""一个内置的 ToolProvider用于处理通过装饰器注册的函数。"""
def __init__(self):
self._functions: dict[str, dict[str, Any]] = {}
def register(
self,
name: str,
func: Callable,
description: str,
params_model: type[BaseModel] | None,
):
self._functions[name] = {
"func": func,
"description": description,
"params_model": params_model,
}
async def initialize(self) -> None:
pass
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
executables = {}
for name, info in self._functions.items():
executables[name] = FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return executables
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
if config.get("type") == "function" and name in self._functions:
info = self._functions[name]
return FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return None
class ToolProviderManager:
"""工具提供者的中心化管理器,采用单例模式。"""
_instance: "ToolProviderManager | None" = None
def __new__(cls) -> "ToolProviderManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, "_initialized") and self._initialized:
return
self._providers: list[ToolProvider] = []
self._resolved_tools: dict[str, ToolExecutable] | None = None
self._init_lock = asyncio.Lock()
self._init_promise: asyncio.Task | None = None
self._builtin_function_provider = BuiltinFunctionToolProvider()
self.register(self._builtin_function_provider)
self._initialized = True
def register(self, provider: ToolProvider):
"""注册一个新的 ToolProvider。"""
if provider not in self._providers:
self._providers.append(provider)
logger.info(f"已注册工具提供者: {provider.__class__.__name__}")
def function_tool(
self,
name: str,
description: str,
params_model: type[BaseModel] | None = None,
):
"""装饰器:将一个函数注册为内置工具。"""
def decorator(func: Callable):
if name in self._builtin_function_provider._functions:
logger.warning(f"正在覆盖已注册的函数工具: {name}")
self._builtin_function_provider.register(
name=name,
func=func,
description=description,
params_model=params_model,
)
logger.info(f"已注册函数工具: '{name}'")
return func
return decorator
async def initialize(self) -> None:
"""懒加载初始化所有已注册的 ToolProvider。"""
if not self._init_promise:
async with self._init_lock:
if not self._init_promise:
self._init_promise = asyncio.create_task(
self._initialize_providers()
)
await self._init_promise
async def _initialize_providers(self) -> None:
"""内部初始化逻辑。"""
logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...")
init_tasks = [provider.initialize() for provider in self._providers]
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.info("所有工具提供者初始化完成。")
async def get_resolved_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
获取所有已发现和解析的工具
此方法会触发懒加载初始化并根据是否传入过滤器来决定是否使用全局缓存
"""
await self.initialize()
has_filters = allowed_servers is not None or excluded_servers is not None
if not has_filters and self._resolved_tools is not None:
logger.debug("使用全局工具缓存。")
return self._resolved_tools
if has_filters:
logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
logger.debug(
f"过滤器详情: allowed_servers={allowed_servers}, "
f"excluded_servers={excluded_servers}"
)
else:
logger.info("未应用过滤器,开始全局工具发现...")
all_tools: dict[str, ToolExecutable] = {}
discover_tasks = []
for provider in self._providers:
sig = inspect.signature(provider.discover_tools)
params_to_pass = {}
if "allowed_servers" in sig.parameters:
params_to_pass["allowed_servers"] = allowed_servers
if "excluded_servers" in sig.parameters:
params_to_pass["excluded_servers"] = excluded_servers
discover_tasks.append(provider.discover_tools(**params_to_pass))
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
for i, provider_result in enumerate(results):
provider_name = self._providers[i].__class__.__name__
if isinstance(provider_result, dict):
logger.debug(
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
)
for name, executable in provider_result.items():
if name in all_tools:
logger.warning(
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
)
all_tools[name] = executable
elif isinstance(provider_result, Exception):
logger.error(
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
)
if not has_filters:
self._resolved_tools = all_tools
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
else:
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
return all_tools
async def get_function_tools(
self, names: list[str] | None = None
) -> dict[str, ToolExecutable]:
"""
仅从内置的函数提供者中解析指定的工具
"""
all_function_tools = await self._builtin_function_provider.discover_tools()
if names is None:
return all_function_tools
resolved_tools = {}
for name in names:
if name in all_function_tools:
resolved_tools[name] = all_function_tools[name]
else:
logger.warning(
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
)
return resolved_tools
tool_provider_manager = ToolProviderManager()

View File

@ -5,30 +5,32 @@ LLM 类型定义模块
"""
from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
from .content import (
LLMContentPart,
LLMMessage,
LLMResponse,
)
from .enums import (
EmbeddingTaskType,
ModelProvider,
ResponseFormat,
TaskType,
ToolCategory,
)
from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
from .models import (
CodeExecutionOutcome,
EmbeddingTaskType,
GeminiCodeExecution,
GeminiGoogleSearch,
GeminiUrlContext,
LLMCacheInfo,
LLMCodeExecution,
LLMContentPart,
LLMGroundingAttribution,
LLMGroundingMetadata,
LLMMessage,
LLMResponse,
LLMToolCall,
LLMToolFunction,
ModelDetail,
ModelInfo,
ModelName,
ModelProvider,
ProviderConfig,
ResponseFormat,
StructuredOutputStrategy,
TaskType,
ToolCategory,
ToolChoice,
ToolMetadata,
ToolResult,
UsageInfo,
@ -36,7 +38,11 @@ from .models import (
from .protocols import ToolExecutable, ToolProvider
__all__ = [
"CodeExecutionOutcome",
"EmbeddingTaskType",
"GeminiCodeExecution",
"GeminiGoogleSearch",
"GeminiUrlContext",
"LLMCacheInfo",
"LLMCodeExecution",
"LLMContentPart",
@ -56,8 +62,10 @@ __all__ = [
"ModelProvider",
"ProviderConfig",
"ResponseFormat",
"StructuredOutputStrategy",
"TaskType",
"ToolCategory",
"ToolChoice",
"ToolExecutable",
"ToolMetadata",
"ToolProvider",

View File

@ -6,6 +6,7 @@ LLM 模型能力定义模块
from enum import Enum
import fnmatch
from typing import Literal
from pydantic import BaseModel, Field
@ -20,6 +21,35 @@ class ModelModality(str, Enum):
EMBEDDING = "embedding"
class ReasoningMode(str, Enum):
"""推理/思考模式类型"""
NONE = "none"
BUDGET = "budget"
LEVEL = "level"
EFFORT = "effort"
PATTERNS_GEMINI_2_5 = [
"gemini-2.5*",
"gemini-flash*",
"gemini*lite*",
"gemini-flash-latest",
]
PATTERNS_GEMINI_3 = [
"gemini-3*",
"gemini-exp*",
]
PATTERNS_OPENAI_REASONING = [
"o1-*",
"o3-*",
"deepseek-r1*",
"deepseek-reasoner",
]
class ModelCapabilities(BaseModel):
"""定义一个模型的核心、稳定能力。"""
@ -27,6 +57,8 @@ class ModelCapabilities(BaseModel):
output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
supports_tool_calling: bool = False
is_embedding_model: bool = False
reasoning_mode: ReasoningMode = ReasoningMode.NONE
reasoning_visibility: Literal["visible", "hidden", "none"] = "none"
STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
@ -35,7 +67,7 @@ STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
supports_tool_calling=True,
)
GEMINI_CAPABILITIES = ModelCapabilities(
CAP_GEMINI_2_5 = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
@ -44,21 +76,44 @@ GEMINI_CAPABILITIES = ModelCapabilities(
},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
reasoning_mode=ReasoningMode.BUDGET,
reasoning_visibility="visible",
)
GEMINI_IMAGE_GEN_CAPABILITIES = ModelCapabilities(
CAP_GEMINI_3 = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
reasoning_mode=ReasoningMode.LEVEL,
reasoning_visibility="visible",
)
CAP_GEMINI_IMAGE_GEN = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
supports_tool_calling=True,
)
GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES = ModelCapabilities(
CAP_OPENAI_REASONING = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
reasoning_mode=ReasoningMode.EFFORT,
reasoning_visibility="hidden",
)
CAP_GPT_ADVANCED = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
GPT_MULTIMODAL_IO_CAPABILITIES = ModelCapabilities(
CAP_GPT_MULTIMODAL_IO = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
supports_tool_calling=True,
@ -76,6 +131,12 @@ GPT_VIDEO_GENERATION_CAPABILITIES = ModelCapabilities(
supports_tool_calling=True,
)
EMBEDDING_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.EMBEDDING},
is_embedding_model=True,
)
DEFAULT_PERMISSIVE_CAPABILITIES = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
@ -107,17 +168,33 @@ MODEL_ALIAS_MAPPING: dict[str, str] = {
}
MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
"gemini-*-tts": ModelCapabilities(
def _build_registry() -> dict[str, ModelCapabilities]:
"""构建模型能力注册表,展开模式列表以减少冗余"""
registry: dict[str, ModelCapabilities] = {}
def register_family(patterns: list[str], cap: ModelCapabilities) -> None:
for pattern in patterns:
registry[pattern] = cap
register_family(
["*gemini-*-image-preview*", "gemini-*-image*"], CAP_GEMINI_IMAGE_GEN
)
register_family(PATTERNS_GEMINI_2_5, CAP_GEMINI_2_5)
register_family(PATTERNS_GEMINI_3, CAP_GEMINI_3)
register_family(PATTERNS_OPENAI_REASONING, CAP_OPENAI_REASONING)
registry["gemini-*-tts"] = ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.AUDIO},
),
"gemini-*-native-audio-*": ModelCapabilities(
)
registry["gemini-*-native-audio-*"] = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
supports_tool_calling=True,
),
"gemini-2.0-flash-preview-image-generation": ModelCapabilities(
)
registry["gemini-2.0-flash-preview-image-generation"] = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
@ -126,39 +203,39 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
},
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
supports_tool_calling=True,
),
"gemini-embedding-exp": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.EMBEDDING},
is_embedding_model=True,
),
"*gemini-*-image-preview*": GEMINI_IMAGE_GEN_CAPABILITIES,
"gemini-*-pro*": GEMINI_CAPABILITIES,
"gemini-*-flash*": GEMINI_CAPABILITIES,
"GLM-4V-Flash": ModelCapabilities(
)
registry["GLM-4V-Flash"] = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"GLM-4V-Plus*": ModelCapabilities(
)
registry["GLM-4V-Plus*"] = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
"gpt-5*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
"gpt-4.1*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
"gpt-4o*": GPT_MULTIMODAL_IO_CAPABILITIES,
"o3*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
"o4-mini*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
"gpt image*": GPT_IMAGE_GENERATION_CAPABILITIES,
"sora*": GPT_VIDEO_GENERATION_CAPABILITIES,
}
)
register_family(
["glm-4-*", "glm-z1-*", "deepseek-chat"], STANDARD_TEXT_TOOL_CAPABILITIES
)
register_family(
["doubao-seed-*", "doubao-1-5-thinking-vision-pro"],
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
)
register_family(["gpt-5*", "gpt-4.1*", "o4-mini*"], CAP_GPT_ADVANCED)
registry["gpt-4o*"] = CAP_GPT_MULTIMODAL_IO
registry["gpt image*"] = GPT_IMAGE_GENERATION_CAPABILITIES
registry["sora*"] = GPT_VIDEO_GENERATION_CAPABILITIES
registry["*embedding*"] = EMBEDDING_CAPABILITIES
return registry
MODEL_CAPABILITIES_REGISTRY = _build_registry()
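# Lookup sketch (assumes get_model_capabilities below resolves aliases and
# fnmatch-matches the registry patterns, per the `fnmatch` import above):
#   get_model_capabilities("gemini-2.5-pro")   -> CAP_GEMINI_2_5
#   get_model_capabilities("o3-mini")          -> CAP_OPENAI_REASONING
#   get_model_capabilities("text-embedding-3") -> EMBEDDING_CAPABILITIES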
def get_model_capabilities(model_name: str) -> ModelCapabilities:

View File

@ -1,434 +0,0 @@
"""
LLM 内容类型定义
包含多模态内容部分消息和响应的数据模型
"""
import base64
import mimetypes
from pathlib import Path
from typing import Any
import aiofiles
from pydantic import BaseModel
from zhenxun.services.log import logger
class LLMContentPart(BaseModel):
"""LLM 消息内容部分 - 支持多模态内容"""
type: str
text: str | None = None
image_source: str | None = None
audio_source: str | None = None
video_source: str | None = None
document_source: str | None = None
file_uri: str | None = None
file_source: str | None = None
url: str | None = None
mime_type: str | None = None
metadata: dict[str, Any] | None = None
def model_post_init(self, /, __context: Any) -> None:
"""验证内容部分的有效性"""
_ = __context
validation_rules = {
"text": lambda: self.text,
"image": lambda: self.image_source,
"audio": lambda: self.audio_source,
"video": lambda: self.video_source,
"document": lambda: self.document_source,
"file": lambda: self.file_uri or self.file_source,
"url": lambda: self.url,
}
if self.type in validation_rules:
if not validation_rules[self.type]():
raise ValueError(f"{self.type}类型的内容部分必须包含相应字段")
@classmethod
def text_part(cls, text: str) -> "LLMContentPart":
"""创建文本内容部分"""
return cls(type="text", text=text)
@classmethod
def image_url_part(cls, url: str) -> "LLMContentPart":
"""创建图片URL内容部分"""
return cls(type="image", image_source=url)
@classmethod
def image_base64_part(
cls, data: str, mime_type: str = "image/png"
) -> "LLMContentPart":
"""创建Base64图片内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="image", image_source=data_url)
@classmethod
def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart":
"""创建音频URL内容部分"""
return cls(type="audio", audio_source=url, mime_type=mime_type)
@classmethod
def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart":
"""创建视频URL内容部分"""
return cls(type="video", video_source=url, mime_type=mime_type)
@classmethod
def video_base64_part(
cls, data: str, mime_type: str = "video/mp4"
) -> "LLMContentPart":
"""创建Base64视频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="video", video_source=data_url, mime_type=mime_type)
@classmethod
def audio_base64_part(
cls, data: str, mime_type: str = "audio/wav"
) -> "LLMContentPart":
"""创建Base64音频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="audio", audio_source=data_url, mime_type=mime_type)
@classmethod
def file_uri_part(
cls,
file_uri: str,
mime_type: str | None = None,
metadata: dict[str, Any] | None = None,
) -> "LLMContentPart":
"""创建Gemini File API URI内容部分"""
return cls(
type="file",
file_uri=file_uri,
mime_type=mime_type,
metadata=metadata or {},
)
@classmethod
async def from_path(
cls, path_like: str | Path, target_api: str | None = None
) -> "LLMContentPart | None":
"""
从本地文件路径创建 LLMContentPart
自动检测MIME类型并根据类型如图片可能加载为Base64
target_api 可以用于提示如何最好地准备数据例如 'gemini' 可能偏好 base64
"""
try:
path = Path(path_like)
if not path.exists() or not path.is_file():
logger.warning(f"文件不存在或不是一个文件: {path}")
return None
mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
if not mime_type:
logger.warning(
f"无法猜测文件 {path.name} 的MIME类型将尝试作为文本文件处理。"
)
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
return cls.text_part(text_content)
except Exception as e:
logger.error(f"读取文本文件 {path.name} 失败: {e}")
return None
if mime_type.startswith("image/"):
if target_api == "gemini" or not path.is_absolute():
try:
async with aiofiles.open(path, "rb") as f:
img_bytes = await f.read()
base64_data = base64.b64encode(img_bytes).decode("utf-8")
return cls.image_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
return None
else:
logger.warning(
f"为本地图片路径 {path.name} 生成 image_url_part。"
"实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
)
return cls.image_url_part(url=path.resolve().as_uri())
elif mime_type.startswith("audio/"):
return cls.audio_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif mime_type.startswith("video/"):
if target_api == "gemini":
# 对于 Gemini API将视频转换为 base64
try:
async with aiofiles.open(path, "rb") as f:
video_bytes = await f.read()
base64_data = base64.b64encode(video_bytes).decode("utf-8")
return cls.video_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
return None
else:
return cls.video_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif (
mime_type.startswith("text/")
or mime_type == "application/json"
or mime_type == "application/xml"
):
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
return cls.text_part(text_content)
except Exception as e:
logger.error(f"读取文本类文件 {path.name} 失败: {e}")
return None
else:
logger.info(
f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
)
return cls.file_uri_part(
file_uri=path.resolve().as_uri(),
mime_type=mime_type,
metadata={"name": path.name, "source": "local_path"},
)
except Exception as e:
logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}")
return None
def is_image_url(self) -> bool:
"""检查图像源是否为URL"""
if not self.image_source:
return False
return self.image_source.startswith(("http://", "https://"))
def is_image_base64(self) -> bool:
"""检查图像源是否为Base64 Data URL"""
if not self.image_source:
return False
return self.image_source.startswith("data:")
def get_base64_data(self) -> tuple[str, str] | None:
"""从Data URL中提取Base64数据和MIME类型"""
if not self.is_image_base64() or not self.image_source:
return None
try:
header, data = self.image_source.split(",", 1)
mime_part = header.split(";")[0].replace("data:", "")
return mime_part, data
except (ValueError, IndexError):
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
return None
async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
"""根据API类型转换多模态内容格式"""
from zhenxun.utils.http_utils import AsyncHttpx
if self.type == "text":
if api_type == "openai":
return {"type": "text", "text": self.text}
elif api_type == "gemini":
return {"text": self.text}
else:
return {"type": "text", "text": self.text}
elif self.type == "image":
if not self.image_source:
raise ValueError("图像类型的内容必须包含image_source")
if api_type == "openai":
return {"type": "image_url", "image_url": {"url": self.image_source}}
elif api_type == "gemini":
if self.is_image_base64():
base64_info = self.get_base64_data()
if base64_info:
mime_type, data = base64_info
return {"inlineData": {"mimeType": mime_type, "data": data}}
else:
raise ValueError(
f"无法解析Base64图像数据: {self.image_source[:50]}..."
)
elif self.is_image_url():
logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
try:
image_bytes = await AsyncHttpx.get_content(self.image_source)
mime_type = self.mime_type or "image/jpeg"
base64_data = base64.b64encode(image_bytes).decode("utf-8")
return {
"inlineData": {"mimeType": mime_type, "data": base64_data}
}
except Exception as e:
logger.error(f"下载或编码URL图片失败: {e}", e=e)
raise ValueError(f"无法处理图片URL: {e}")
else:
raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
else:
return {"type": "image_url", "image_url": {"url": self.image_source}}
elif self.type == "video":
if not self.video_source:
raise ValueError("视频类型的内容必须包含video_source")
if api_type == "gemini":
# Gemini 支持视频,但需要通过 File API 上传
if self.video_source.startswith("data:"):
# 处理 base64 视频数据
try:
header, data = self.video_source.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
return {"inlineData": {"mimeType": mime_type, "data": data}}
except (ValueError, IndexError):
raise ValueError(
f"无法解析Base64视频数据: {self.video_source[:50]}..."
)
else:
# 对于 URL 或其他格式,暂时不支持直接内联
raise ValueError(
"Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
)
else:
# 其他 API 可能不支持视频
raise ValueError(f"API类型 '{api_type}' 不支持视频内容")
elif self.type == "audio":
if not self.audio_source:
raise ValueError("音频类型的内容必须包含audio_source")
if api_type == "gemini":
# Gemini 支持音频,处理方式类似视频
if self.audio_source.startswith("data:"):
try:
header, data = self.audio_source.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
return {"inlineData": {"mimeType": mime_type, "data": data}}
except (ValueError, IndexError):
raise ValueError(
f"无法解析Base64音频数据: {self.audio_source[:50]}..."
)
else:
raise ValueError(
"Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
)
else:
raise ValueError(f"API类型 '{api_type}' 不支持音频内容")
elif self.type == "file":
if api_type == "gemini" and self.file_uri:
return {
"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
}
elif self.file_source:
file_name = (
self.metadata.get("name", "file") if self.metadata else "file"
)
if api_type == "gemini":
return {"text": f"[文件: {file_name}]\n{self.file_source}"}
else:
return {
"type": "text",
"text": f"[文件: {file_name}]\n{self.file_source}",
}
else:
raise ValueError("文件类型的内容必须包含file_uri或file_source")
else:
raise ValueError(f"不支持的内容类型: {self.type}")
class LLMMessage(BaseModel):
"""LLM 消息"""
role: str
content: str | list[LLMContentPart]
name: str | None = None
tool_calls: list[Any] | None = None
tool_call_id: str | None = None
def model_post_init(self, /, __context: Any) -> None:
"""验证消息的有效性"""
_ = __context
if self.role == "tool":
if not self.tool_call_id:
raise ValueError("工具角色的消息必须包含 tool_call_id")
if not self.name:
raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
if self.role == "tool" and not isinstance(self.content, str):
logger.warning(
f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
"将尝试转换为字符串。"
)
try:
self.content = str(self.content)
except Exception as e:
raise ValueError(f"无法将工具角色的内容转换为字符串: {e}")
@classmethod
def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
"""创建用户消息"""
return cls(role="user", content=content)
@classmethod
def assistant_tool_calls(
cls,
tool_calls: list[Any],
content: str | list[LLMContentPart] = "",
) -> "LLMMessage":
"""创建助手请求工具调用的消息"""
return cls(role="assistant", content=content, tool_calls=tool_calls)
@classmethod
def assistant_text_response(
cls, content: str | list[LLMContentPart]
) -> "LLMMessage":
"""创建助手纯文本回复的消息"""
return cls(role="assistant", content=content, tool_calls=None)
@classmethod
def tool_response(
cls,
tool_call_id: str,
function_name: str,
result: Any,
) -> "LLMMessage":
"""创建工具执行结果的消息"""
import json
try:
content_str = json.dumps(result)
except TypeError as e:
logger.error(
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
)
content_str = json.dumps(
{"error": "工具结果无法JSON序列化", "details": str(e)}
)
return cls(
role="tool",
content=content_str,
tool_call_id=tool_call_id,
name=function_name,
)
@classmethod
def system(cls, content: str) -> "LLMMessage":
"""创建系统消息"""
return cls(role="system", content=content)
class LLMResponse(BaseModel):
"""LLM 响应"""
text: str
images: list[bytes] | None = None
usage_info: dict[str, Any] | None = None
raw_response: dict[str, Any] | None = None
tool_calls: list[Any] | None = None
code_executions: list[Any] | None = None
grounding_metadata: Any | None = None
cache_info: Any | None = None

View File

@ -1,78 +0,0 @@
"""
LLM 枚举类型定义
"""
from enum import Enum, auto
class ModelProvider(Enum):
"""模型提供商枚举"""
OPENAI = "openai"
GEMINI = "gemini"
ZHIXPU = "zhipu"
CUSTOM = "custom"
class ResponseFormat(Enum):
"""响应格式枚举"""
TEXT = "text"
JSON = "json"
MULTIMODAL = "multimodal"
class EmbeddingTaskType(str, Enum):
"""文本嵌入任务类型 (主要用于Gemini)"""
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
CLASSIFICATION = "CLASSIFICATION"
CLUSTERING = "CLUSTERING"
QUESTION_ANSWERING = "QUESTION_ANSWERING"
FACT_VERIFICATION = "FACT_VERIFICATION"
class ToolCategory(Enum):
"""工具分类枚举"""
FILE_SYSTEM = auto()
NETWORK = auto()
SYSTEM_INFO = auto()
CALCULATION = auto()
DATA_PROCESSING = auto()
CUSTOM = auto()
class TaskType(Enum):
"""任务类型枚举"""
CHAT = "chat"
CODE = "code"
SEARCH = "search"
ANALYSIS = "analysis"
GENERATION = "generation"
MULTIMODAL = "multimodal"
class LLMErrorCode(Enum):
"""LLM 服务相关的错误代码枚举"""
MODEL_INIT_FAILED = 2000
MODEL_NOT_FOUND = 2001
API_REQUEST_FAILED = 2002
API_RESPONSE_INVALID = 2003
API_KEY_INVALID = 2004
API_QUOTA_EXCEEDED = 2005
API_TIMEOUT = 2006
API_RATE_LIMITED = 2007
NO_AVAILABLE_KEYS = 2008
UNKNOWN_API_TYPE = 2009
CONFIGURATION_ERROR = 2010
RESPONSE_PARSE_ERROR = 2011
CONTEXT_LENGTH_EXCEEDED = 2012
CONTENT_FILTERED = 2013
USER_LOCATION_NOT_SUPPORTED = 2014
GENERATION_FAILED = 2015
EMBEDDING_FAILED = 2016

View File

@ -2,9 +2,31 @@
LLM 异常类型定义
"""
from enum import Enum
from typing import Any
from .enums import LLMErrorCode
class LLMErrorCode(Enum):
"""LLM 服务相关的错误代码枚举"""
MODEL_INIT_FAILED = 2000
MODEL_NOT_FOUND = 2001
API_REQUEST_FAILED = 2002
API_RESPONSE_INVALID = 2003
API_KEY_INVALID = 2004
API_QUOTA_EXCEEDED = 2005
API_TIMEOUT = 2006
API_RATE_LIMITED = 2007
NO_AVAILABLE_KEYS = 2008
UNKNOWN_API_TYPE = 2009
CONFIGURATION_ERROR = 2010
RESPONSE_PARSE_ERROR = 2011
CONTEXT_LENGTH_EXCEEDED = 2012
CONTENT_FILTERED = 2013
USER_LOCATION_NOT_SUPPORTED = 2014
INVALID_PARAMETER = 2017
GENERATION_FAILED = 2015
EMBEDDING_FAILED = 2016
class LLMException(Exception):
@ -27,7 +49,11 @@ class LLMException(Exception):
def __str__(self) -> str:
if self.details:
return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})"
safe_details = {k: v for k, v in self.details.items() if k != "api_key"}
if safe_details:
return (
f"{self.message} (错误码: {self.code.name}, 详情: {safe_details})"
)
return f"{self.message} (错误码: {self.code.name})"
@property
@ -46,10 +72,13 @@ class LLMException(Exception):
"当前所有API密钥均不可用请稍后再试或联系管理员。"
),
LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
"当前地区暂不支持此AI服务请联系管理员或尝试其他模型。"
"当前网络环境不支持此 AI 模型 (如 Gemini/OpenAI)。\n"
"原因: 代理节点所在地区(如香港/国内/非支持区)被服务商屏蔽。\n"
"建议: 请尝试更换代理节点至支持的地区(如美国/日本/新加坡)。"
),
LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败请稍后再试。",
LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常请稍后再试。",
LLMErrorCode.INVALID_PARAMETER: "请求参数错误,请检查输入内容。",
LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误请联系管理员。",
LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。",
LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。",
@ -66,15 +95,19 @@ def get_user_friendly_error_message(error: Exception) -> str:
error_str = str(error).lower()
if "timeout" in error_str or "超时" in error_str:
return "请求超时,请稍后再试。"
elif "connection" in error_str or "连接" in error_str:
return "网络连接失败,请检查网络后重试。"
elif "permission" in error_str or "权限" in error_str:
return "权限不足,请联系管理员。"
elif "not found" in error_str or "未找到" in error_str:
return "请求的资源未找到,请检查配置。"
elif "invalid" in error_str or "无效" in error_str:
if "timeout" in error_str or "timed out" in error_str:
return "网络请求超时,请检查服务器网络或代理连接。"
if "connect" in error_str and ("refused" in error_str or "error" in error_str):
return "无法连接到 AI 服务商,请检查网络连接或代理设置。"
if "proxy" in error_str:
return "代理连接失败,请检查代理服务器是否正常运行。"
if "ssl" in error_str or "certificate" in error_str:
return "SSL 证书验证失败,请检查网络环境。"
if "permission" in error_str or "forbidden" in error_str:
return "权限不足,可能是 API Key 权限受限。"
if "not found" in error_str:
return "请求的资源未找到 (404),请检查模型名称或端点配置。"
if "invalid" in error_str or "无效" in error_str:
return "请求参数无效,请检查输入。"
else:
return "服务暂时不可用,请稍后再试。"
return f"服务暂时不可用 ({type(error).__name__}),请稍后再试。"

View File

@ -4,12 +4,459 @@ LLM 数据模型定义
包含模型信息配置工具定义和响应数据的模型类
"""
import base64
from dataclasses import dataclass, field
from typing import Any
from enum import Enum, auto
import mimetypes
from pathlib import Path
import sys
from typing import Any, Literal
import aiofiles
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
from zhenxun.services.log import logger
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
class ModelProvider(Enum):
"""模型提供商枚举"""
OPENAI = "openai"
GEMINI = "gemini"
ZHIXPU = "zhipu"
CUSTOM = "custom"
class ResponseFormat(Enum):
"""响应格式枚举"""
TEXT = "text"
JSON = "json"
MULTIMODAL = "multimodal"
class StructuredOutputStrategy(str, Enum):
"""结构化输出策略"""
NATIVE = "native"
"""使用原生 API (如 OpenAI json_object/json_schema, Gemini mime_type)"""
TOOL_CALL = "tool_call"
"""构造虚假工具调用来强制输出结构化数据 (适用于指令跟随弱但工具调用强的模型)"""
PROMPT = "prompt"
"""仅在 Prompt 中追加 Schema 说明,依赖文本补全"""
class EmbeddingTaskType(str, Enum):
"""文本嵌入任务类型 (主要用于Gemini)"""
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
CLASSIFICATION = "CLASSIFICATION"
CLUSTERING = "CLUSTERING"
QUESTION_ANSWERING = "QUESTION_ANSWERING"
FACT_VERIFICATION = "FACT_VERIFICATION"
class ToolCategory(Enum):
"""工具分类枚举"""
FILE_SYSTEM = auto()
NETWORK = auto()
SYSTEM_INFO = auto()
CALCULATION = auto()
DATA_PROCESSING = auto()
CUSTOM = auto()
class CodeExecutionOutcome(StrEnum):
"""代码执行结果状态枚举"""
OUTCOME_OK = "OUTCOME_OK"
OUTCOME_FAILED = "OUTCOME_FAILED"
OUTCOME_DEADLINE_EXCEEDED = "OUTCOME_DEADLINE_EXCEEDED"
OUTCOME_COMPILATION_ERROR = "OUTCOME_COMPILATION_ERROR"
OUTCOME_RUNTIME_ERROR = "OUTCOME_RUNTIME_ERROR"
OUTCOME_UNKNOWN = "OUTCOME_UNKNOWN"
class TaskType(Enum):
"""任务类型枚举"""
CHAT = "chat"
CODE = "code"
SEARCH = "search"
ANALYSIS = "analysis"
GENERATION = "generation"
MULTIMODAL = "multimodal"
class LLMContentPart(BaseModel):
"""
LLM 消息内容部分 - 支持多模态内容
这是一个联合体模型`type` 字段决定了哪些其他字段是有效的
例如
- type='text': 使用 `text` 字段
- type='image': 使用 `image_source` 字段
- type='executable_code': 使用 `code_language` `code_content` 字段
"""
type: str
text: str | None = None
image_source: str | None = None
audio_source: str | None = None
video_source: str | None = None
document_source: str | None = None
file_uri: str | None = None
file_source: str | None = None
url: str | None = None
mime_type: str | None = None
thought_text: str | None = None
media_resolution: str | None = None
code_language: str | None = None
code_content: str | None = None
execution_outcome: str | None = None
execution_output: str | None = None
metadata: dict[str, Any] | None = None
def model_post_init(self, /, __context: Any) -> None:
"""验证内容部分的有效性"""
_ = __context
validation_rules = {
"text": lambda: self.text is not None,
"image": lambda: self.image_source,
"audio": lambda: self.audio_source,
"video": lambda: self.video_source,
"document": lambda: self.document_source,
"file": lambda: self.file_uri or self.file_source,
"url": lambda: self.url,
"thought": lambda: self.thought_text,
"executable_code": lambda: self.code_content is not None,
"execution_result": lambda: self.execution_outcome is not None,
}
if self.type in validation_rules:
if not validation_rules[self.type]():
raise ValueError(f"{self.type}类型的内容部分必须包含相应字段")
@classmethod
def text_part(cls, text: str) -> "LLMContentPart":
"""创建文本内容部分"""
return cls(type="text", text=text)
@classmethod
def thought_part(cls, text: str) -> "LLMContentPart":
"""创建思考过程内容部分"""
return cls(type="thought", thought_text=text)
@classmethod
def image_url_part(cls, url: str) -> "LLMContentPart":
"""创建图片URL内容部分"""
return cls(type="image", image_source=url)
@classmethod
def image_base64_part(
cls, data: str, mime_type: str = "image/png"
) -> "LLMContentPart":
"""创建Base64图片内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="image", image_source=data_url)
@classmethod
def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart":
"""创建音频URL内容部分"""
return cls(type="audio", audio_source=url, mime_type=mime_type)
@classmethod
def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart":
"""创建视频URL内容部分"""
return cls(type="video", video_source=url, mime_type=mime_type)
@classmethod
def video_base64_part(
cls, data: str, mime_type: str = "video/mp4"
) -> "LLMContentPart":
"""创建Base64视频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="video", video_source=data_url, mime_type=mime_type)
@classmethod
def audio_base64_part(
cls, data: str, mime_type: str = "audio/wav"
) -> "LLMContentPart":
"""创建Base64音频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="audio", audio_source=data_url, mime_type=mime_type)
@classmethod
def file_uri_part(
cls,
file_uri: str,
mime_type: str | None = None,
metadata: dict[str, Any] | None = None,
) -> "LLMContentPart":
"""创建Gemini File API URI内容部分"""
return cls(
type="file",
file_uri=file_uri,
mime_type=mime_type,
metadata=metadata or {},
)
@classmethod
def executable_code_part(cls, language: str, code: str) -> "LLMContentPart":
"""创建可执行代码内容部分"""
return cls(type="executable_code", code_language=language, code_content=code)
@classmethod
def execution_result_part(
cls, outcome: str, output: str | None
) -> "LLMContentPart":
"""创建代码执行结果部分"""
return cls(
type="execution_result", execution_outcome=outcome, execution_output=output
)
@classmethod
async def from_path(
cls, path_like: str | Path, target_api: str | None = None
) -> "LLMContentPart | None":
"""
从本地文件路径创建 LLMContentPart
自动检测MIME类型并根据类型如图片可能加载为Base64
target_api 可以用于提示如何最好地准备数据例如 'gemini' 可能偏好 base64
"""
try:
path = Path(path_like)
if not path.exists() or not path.is_file():
logger.warning(f"文件不存在或不是一个文件: {path}")
return None
mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
if not mime_type:
logger.warning(
f"无法猜测文件 {path.name} 的MIME类型将尝试作为文本文件处理。"
)
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
return cls.text_part(text_content)
except Exception as e:
logger.error(f"读取文本文件 {path.name} 失败: {e}")
return None
if mime_type.startswith("image/"):
if target_api == "gemini" or not path.is_absolute():
try:
async with aiofiles.open(path, "rb") as f:
img_bytes = await f.read()
base64_data = base64.b64encode(img_bytes).decode("utf-8")
return cls.image_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
return None
else:
logger.warning(
f"为本地图片路径 {path.name} 生成 image_url_part。"
"实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
)
return cls.image_url_part(url=path.resolve().as_uri())
elif mime_type.startswith("audio/"):
return cls.audio_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif mime_type.startswith("video/"):
if target_api == "gemini":
try:
async with aiofiles.open(path, "rb") as f:
video_bytes = await f.read()
base64_data = base64.b64encode(video_bytes).decode("utf-8")
return cls.video_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
return None
else:
return cls.video_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif (
mime_type.startswith("text/")
or mime_type == "application/json"
or mime_type == "application/xml"
):
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
return cls.text_part(text_content)
except Exception as e:
logger.error(f"读取文本类文件 {path.name} 失败: {e}")
return None
else:
logger.info(
f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
)
return cls.file_uri_part(
file_uri=path.resolve().as_uri(),
mime_type=mime_type,
metadata={"name": path.name, "source": "local_path"},
)
except Exception as e:
logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}")
return None
def is_image_url(self) -> bool:
"""检查图像源是否为URL"""
if not self.image_source:
return False
return self.image_source.startswith(("http://", "https://"))
def is_image_base64(self) -> bool:
"""检查图像源是否为Base64 Data URL"""
if not self.image_source:
return False
return self.image_source.startswith("data:")
def get_base64_data(self) -> tuple[str, str] | None:
"""从Data URL中提取Base64数据和MIME类型"""
if not self.is_image_base64() or not self.image_source:
return None
try:
header, data = self.image_source.split(",", 1)
mime_part = header.split(";")[0].replace("data:", "")
return mime_part, data
except (ValueError, IndexError):
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
return None
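# Construction sketch using the factory methods above (values illustrative):
#   LLMContentPart.text_part("hello")
#   LLMContentPart.image_base64_part(data=b64_str, mime_type="image/png")
#   LLMContentPart.executable_code_part("python", "print(1)")
#   LLMContentPart.execution_result_part("OUTCOME_OK", "1\n")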
class LLMMessage(BaseModel):
"""
LLM 消息对象用于构建对话历史
核心字段说明
- role: 消息角色推荐值为 'user', 'assistant', 'system', 'tool'
- content: 消息内容可以是纯文本字符串也可以是 LLMContentPart 列表用于多模态
- tool_calls: ( assistant) 包含模型生成的工具调用请求
- tool_call_id: ( tool) 对应 tool 消息响应的调用 ID
- name: ( tool) 对应 tool 消息响应的函数名称
"""
role: str
content: str | list[LLMContentPart]
name: str | None = None
tool_calls: list[Any] | None = None
tool_call_id: str | None = None
thought_signature: str | None = None
def model_post_init(self, /, __context: Any) -> None:
"""验证消息的有效性"""
_ = __context
if self.role == "tool":
if not self.tool_call_id:
raise ValueError("工具角色的消息必须包含 tool_call_id")
if not self.name:
raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
if self.role == "tool" and not isinstance(self.content, str):
logger.warning(
f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
"将尝试转换为字符串。"
)
try:
self.content = str(self.content)
except Exception as e:
raise ValueError(f"无法将工具角色的内容转换为字符串: {e}")
@classmethod
def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
"""创建用户消息"""
return cls(role="user", content=content)
@classmethod
def assistant_tool_calls(
cls,
tool_calls: list[Any],
content: str | list[LLMContentPart] = "",
) -> "LLMMessage":
"""创建助手请求工具调用的消息"""
return cls(role="assistant", content=content, tool_calls=tool_calls)
@classmethod
def assistant_text_response(
cls, content: str | list[LLMContentPart]
) -> "LLMMessage":
"""创建助手纯文本回复的消息"""
return cls(role="assistant", content=content, tool_calls=None)
@classmethod
def tool_response(
cls,
tool_call_id: str,
function_name: str,
result: Any,
) -> "LLMMessage":
"""创建工具执行结果的消息"""
import json
try:
content_str = json.dumps(result)
except TypeError as e:
logger.error(
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
)
content_str = json.dumps(
{"error": "工具结果无法JSON序列化", "details": str(e)}
)
return cls(
role="tool",
content=content_str,
tool_call_id=tool_call_id,
name=function_name,
)
@classmethod
def system(cls, content: str) -> "LLMMessage":
"""创建系统消息"""
return cls(role="system", content=content)
class LLMResponse(BaseModel):
"""
LLM 响应对象封装了模型生成的全部信息
核心字段说明
- text: 模型生成的文本内容如果是纯文本回复此字段即为结果
- tool_calls: 如果模型决定调用工具此列表包含调用详情
- content_parts: 包含多模态或结构化内容的原始部分列表如思维链代码块
- raw_response: 原始的第三方 API 响应字典用于调试
- images: 如果请求涉及生图此处包含生成的图片数据
"""
text: str
content_parts: list[Any] | None = None
images: list[bytes | Path] | None = None
usage_info: dict[str, Any] | None = None
raw_response: dict[str, Any] | None = None
tool_calls: list[Any] | None = None
code_executions: list[Any] | None = None
grounding_metadata: Any | None = None
cache_info: Any | None = None
thought_text: str | None = None
thought_signature: str | None = None
ModelName = str | None
@ -26,6 +473,64 @@ class ToolDefinition(BaseModel):
)
class ToolChoice(BaseModel):
"""统一的工具选择配置"""
mode: Literal["auto", "none", "any", "required"] = Field(
default="auto", description="工具调用模式"
)
allowed_function_names: list[str] | None = Field(
default=None, description="允许调用的函数名称列表"
)
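# e.g. ToolChoice(mode="any", allowed_function_names=["echo_text"]) restricts
# the model to calling one of the listed functions ("echo_text" is
# illustrative).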
class BasePlatformTool(BaseModel):
"""平台原生工具基类"""
class Config:
extra = "forbid"
def get_tool_declaration(self) -> dict[str, Any]:
"""获取放入 'tools' 列表中的声明对象 (Snake Case)"""
raise NotImplementedError
def get_tool_config(self) -> dict[str, Any] | None:
"""获取放入 'toolConfig' 中的配置对象 (Snake Case)"""
return None
class GeminiCodeExecution(BasePlatformTool):
"""Gemini 代码执行工具"""
def get_tool_declaration(self) -> dict[str, Any]:
return {"code_execution": {}}
class GeminiGoogleSearch(BasePlatformTool):
"""Gemini 谷歌搜索 (Grounding) 工具"""
mode: Literal["MODE_DYNAMIC"] = "MODE_DYNAMIC"
dynamic_threshold: float | None = Field(default=None)
def get_tool_declaration(self) -> dict[str, Any]:
return {"google_search": {}}
def get_tool_config(self) -> dict[str, Any] | None:
return None
class GeminiUrlContext(BasePlatformTool):
"""Gemini 网址上下文工具"""
urls: list[str] = Field(..., description="作为上下文的 URL 列表", max_length=20)
def get_tool_declaration(self) -> dict[str, Any]:
return {"google_search": {}, "url_context": {}}
def get_tool_config(self) -> dict[str, Any] | None:
return None
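# Declaration sketch: these objects expand to provider-native tool entries.
#   GeminiCodeExecution().get_tool_declaration()
#       -> {"code_execution": {}}
#   GeminiUrlContext(urls=["https://example.com"]).get_tool_declaration()
#       -> {"google_search": {}, "url_context": {}}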
class ToolResult(BaseModel):
"""
一个结构化的工具执行结果模型
@ -87,6 +592,8 @@ class ModelDetail(BaseModel):
is_embedding_model: bool = False
temperature: float | None = None
max_tokens: int | None = None
api_type: str | None = None
endpoint: str | None = None
class ProviderConfig(BaseModel):
@ -116,6 +623,7 @@ class LLMToolCall(BaseModel):
id: str
function: LLMToolFunction
thought_signature: str | None = None
class LLMCodeExecution(BaseModel):
@ -143,6 +651,12 @@ class LLMGroundingMetadata(BaseModel):
web_search_queries: list[str] | None = None
grounding_attributions: list[LLMGroundingAttribution] | None = None
search_suggestions: list[dict[str, Any]] | None = None
search_entry_point: str | None = Field(
default=None, description="Google搜索建议的HTML片段(renderedContent)"
)
map_widget_token: str | None = Field(
default=None, description="Google Maps 前端组件令牌"
)
class LLMCacheInfo(BaseModel):

View File

@ -2,10 +2,97 @@
LLM 模块的协议定义
"""
from typing import Any, Protocol
from abc import ABC
from typing import TYPE_CHECKING, Any, Protocol, Union
from pydantic import BaseModel
from .models import ToolDefinition, ToolResult
if TYPE_CHECKING:
from .models import LLMMessage, LLMResponse, LLMToolCall
class ToolCallData(BaseModel):
"""传递给 on_tool_start 的数据模型"""
tool_name: str
tool_args: dict[str, Any]
class ToolCallCompleteData(BaseModel):
"""传递给 on_tool_call_complete 的数据模型"""
id: str
name: str
arguments: str
result: "ToolResult"
class BaseCallbackHandler(ABC):
"""
Agent/LLM 生命周期回调处理器的基类。
下沉至 LLM 层,以允许 ToolInvoker 直接调用。
"""
async def on_agent_start(self, messages: list["LLMMessage"], **kwargs: Any) -> None:
"""在 AgentExecutor 开始运行时调用。"""
pass
async def on_model_start(
self, model_name: str, messages: list["LLMMessage"], **kwargs: Any
) -> None:
"""在向LLM发起请求之前调用。"""
pass
async def on_model_end(
self, response: "LLMResponse", duration: float, **kwargs: Any
) -> None:
"""在收到LLM响应之后调用。"""
pass
async def on_tool_start(
self, tool_call: "LLMToolCall", data: ToolCallData, **kwargs: Any
) -> Union[ToolCallData, "ToolResult", None]:
"""
在单个工具即将被执行时调用
返回:
ToolCallData: 修改参数并继续执行
ToolResult: 拦截执行并直接返回给模型
None: 正常继续
"""
pass
async def on_tool_end(
self,
result: Union["ToolResult", None],
error: Exception | None,
tool_call: "LLMToolCall",
duration: float,
**kwargs: Any,
) -> None:
"""在单个工具执行完毕后调用,无论成功或失败。"""
pass
async def on_tool_call_complete(
self, data: ToolCallCompleteData, **kwargs: Any
) -> None:
"""在工具调用完成并准备创建响应消息时调用。"""
pass
async def on_human_input_request(self, query: str, **kwargs: Any) -> str | None:
"""
Agent 需要人类输入时调用
"""
return None
async def on_agent_end(
self, final_history: list["LLMMessage"], duration: float, **kwargs: Any
) -> None:
"""在 AgentExecutor 运行结束时调用。"""
pass
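# Minimal handler sketch (illustrative; the in-memory cache is hypothetical):
class _ExampleCachingHandler(BaseCallbackHandler):
    _cache: dict[str, ToolResult] = {}

    async def on_tool_start(self, tool_call, data, **kwargs):
        key = f"{data.tool_name}:{data.tool_args}"
        if key in self._cache:
            return self._cache[key]  # intercept: short-circuit execution
        return None  # continue with normal execution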
class ToolExecutable(Protocol):
"""
@ -19,10 +106,14 @@ class ToolExecutable(Protocol):
"""
...
async def execute(self, **kwargs: Any) -> ToolResult:
async def execute(self, context: Any | None = None, **kwargs: Any) -> ToolResult:
"""
异步执行工具并返回一个结构化的结果
参数由LLM根据工具定义生成
Args:
context: 运行时上下文 (RunContext)可选注入
**kwargs: 工具参数
"""
...

View File

@ -3,26 +3,176 @@ LLM 模块的工具和转换函数
"""
import base64
import copy
from collections.abc import Awaitable, Callable
import io
from pathlib import Path
from typing import Any
from typing import Any, TypeVar
import aiofiles
import json_repair
from nonebot.adapters import Message as PlatformMessage
from nonebot.compat import type_validate_json
from nonebot_plugin_alconna.uniseg import (
At,
File,
Image,
Reply,
Segment,
Text,
UniMessage,
Video,
Voice,
)
from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field, ValidationError, create_model
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.pydantic_compat import model_validate
from .types import LLMContentPart, LLMMessage
from .types import LLMContentPart, LLMErrorCode, LLMException, LLMMessage
from .types.capabilities import ReasoningMode, get_model_capabilities
T = TypeVar("T", bound=BaseModel)
S = TypeVar("S", bound=Segment)
_SEGMENT_HANDLERS: dict[
type[Segment], Callable[[Any], Awaitable[LLMContentPart | None]]
] = {}
def register_segment_handler(seg_type: type[S]):
"""装饰器:注册 Uniseg 消息段的处理器"""
def decorator(func: Callable[[S], Awaitable[LLMContentPart | None]]):
_SEGMENT_HANDLERS[seg_type] = func
return func
return decorator
async def _process_media_data(seg: Any, default_mime: str) -> tuple[str, str] | None:
"""
[内部复用] 通用媒体数据处理获取 Base64 数据和 MIME 类型
优先顺序Raw -> Path -> URL (下载)
"""
mime_type = getattr(seg, "mimetype", None) or default_mime
b64_data = None
if hasattr(seg, "raw") and seg.raw:
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
elif getattr(seg, "path", None):
try:
path = Path(seg.path)
if path.exists():
async with aiofiles.open(path, "rb") as f:
content = await f.read()
b64_data = base64.b64encode(content).decode("utf-8")
except Exception as e:
logger.error(f"读取媒体文件失败: {seg.path}, 错误: {e}")
elif getattr(seg, "url", None):
try:
logger.debug(f"检测到媒体URL开始下载: {seg.url}")
media_bytes = await AsyncHttpx.get_content(seg.url)
b64_data = base64.b64encode(media_bytes).decode("utf-8")
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
except Exception as e:
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
return None
if b64_data:
return mime_type, b64_data
return None
@register_segment_handler(Text)
async def _handle_text(seg: Text) -> LLMContentPart | None:
if seg.text.strip():
return LLMContentPart.text_part(seg.text)
return None
@register_segment_handler(Image)
async def _handle_image(seg: Image) -> LLMContentPart | None:
media_info = await _process_media_data(seg, "image/png")
if media_info:
mime, data = media_info
return LLMContentPart.image_base64_part(data, mime)
return None
@register_segment_handler(Voice)
async def _handle_voice(seg: Voice) -> LLMContentPart | None:
media_info = await _process_media_data(seg, "audio/wav")
if media_info:
mime, data = media_info
return LLMContentPart.audio_base64_part(data, mime)
return LLMContentPart.text_part(f"[语音消息: {seg.id or 'unknown'}]")
@register_segment_handler(Video)
async def _handle_video(seg: Video) -> LLMContentPart | None:
media_info = await _process_media_data(seg, "video/mp4")
if media_info:
mime, data = media_info
return LLMContentPart.video_base64_part(data, mime)
return LLMContentPart.text_part(f"[视频消息: {seg.id or 'unknown'}]")
@register_segment_handler(File)
async def _handle_file(seg: File) -> LLMContentPart | None:
if seg.path:
return await LLMContentPart.from_path(seg.path)
return LLMContentPart.text_part(f"[文件: {seg.name} (ID: {seg.id})]")
@register_segment_handler(At)
async def _handle_at(seg: At) -> LLMContentPart | None:
if seg.flag == "all":
return LLMContentPart.text_part("[提及所有人]")
return LLMContentPart.text_part(f"[提及用户: {seg.target}]")
@register_segment_handler(Reply)
async def _handle_reply(seg: Reply) -> LLMContentPart | None:
text = str(seg.msg) if seg.msg else ""
if text:
return LLMContentPart.text_part(f'[回复消息: "{text[:50]}..."]')
return LLMContentPart.text_part("[回复了一条消息]")
async def _transform_to_content_part(item: Any) -> LLMContentPart:
"""
将混合输入转换为统一的 LLMContentPart便于 normalize_to_llm_messages 使用
"""
if isinstance(item, LLMContentPart):
return item
if isinstance(item, str):
return LLMContentPart.text_part(item)
if isinstance(item, Path):
part = await LLMContentPart.from_path(item)
if part is None:
raise ValueError(f"无法从路径加载内容: {item}")
return part
if isinstance(item, dict):
return LLMContentPart(**item)
if PILImageType and isinstance(item, PILImageType):
buffer = io.BytesIO()
fmt = item.format or "PNG"
item.save(buffer, format=fmt)
b64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
mime_type = f"image/{fmt.lower()}"
return LLMContentPart.image_base64_part(b64_data, mime_type)
raise TypeError(f"不支持的输入类型用于构建 ContentPart: {type(item)}")
async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
@@ -36,110 +186,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
返回:
list[LLMContentPart]: 转换后的内容部分列表
"""
# 段处理器已在模块导入时通过 @register_segment_handler 装饰器注册,此处无需惰性初始化
parts: list[LLMContentPart] = []
for seg in message:
part = None
if isinstance(seg, Text):
if seg.text.strip():
part = LLMContentPart.text_part(seg.text)
elif isinstance(seg, Image):
if seg.path:
part = await LLMContentPart.from_path(seg.path, target_api="gemini")
elif seg.url:
part = LLMContentPart.image_url_part(seg.url)
elif hasattr(seg, "raw") and seg.raw:
mime_type = (
getattr(seg, "mimetype", "image/png")
if hasattr(seg, "mimetype")
else "image/png"
)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
part = LLMContentPart.image_base64_part(b64_data, mime_type)
elif isinstance(seg, File | Voice | Video):
if seg.path:
part = await LLMContentPart.from_path(seg.path)
elif seg.url:
try:
logger.debug(f"检测到媒体URL开始下载: {seg.url}")
media_bytes = await AsyncHttpx.get_content(seg.url)
new_seg = copy.copy(seg)
new_seg.raw = media_bytes
seg = new_seg
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
except Exception as e:
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
part = LLMContentPart.text_part(
f"[下载媒体失败: {seg.name or seg.url}]"
)
handler = _SEGMENT_HANDLERS.get(type(seg))
if handler:
try:
part = await handler(seg)
if part:
parts.append(part)
continue
if hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", None)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
if isinstance(seg, Video):
if not mime_type:
mime_type = "video/mp4"
part = LLMContentPart.video_base64_part(
data=b64_data, mime_type=mime_type
)
logger.debug(
f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
)
elif isinstance(seg, Voice):
if not mime_type:
mime_type = "audio/wav"
part = LLMContentPart.audio_base64_part(
data=b64_data, mime_type=mime_type
)
logger.debug(
f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
)
else:
part = LLMContentPart.text_part(
f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
)
logger.debug(
f"处理其他文件字节数据: {mime_type}, "
f"大小: {len(seg.raw)} bytes"
)
elif isinstance(seg, At):
if seg.flag == "all":
part = LLMContentPart.text_part("[提及所有人]")
else:
part = LLMContentPart.text_part(f"[提及用户: {seg.target}]")
elif isinstance(seg, Reply):
if seg.msg:
try:
extract_method = getattr(seg.msg, "extract_plain_text", None)
if extract_method and callable(extract_method):
reply_text = str(extract_method()).strip()
else:
reply_text = str(seg.msg).strip()
if reply_text:
part = LLMContentPart.text_part(
f'[回复消息: "{reply_text[:50]}..."]'
)
except Exception:
part = LLMContentPart.text_part("[回复了一条消息]")
if part:
parts.append(part)
except Exception as e:
logger.warning(f"处理消息段 {seg} 失败: {e}", "LLMUtils")
return parts
async def normalize_to_llm_messages(
message: str | UniMessage | LLMMessage | list[LLMContentPart] | list[LLMMessage],
message: str | UniMessage | LLMMessage | list[Any],
instruction: str | None = None,
) -> list[LLMMessage]:
"""
@@ -167,7 +232,10 @@ async def normalize_to_llm_messages(
content_parts = await unimsg_to_llm_parts(message)
messages.append(LLMMessage.user(content_parts))
elif isinstance(message, list):
messages.append(LLMMessage.user(message)) # type: ignore
parts = []
for item in message:
parts.append(await _transform_to_content_part(item))
messages.append(LLMMessage.user(parts))
else:
raise TypeError(f"不支持的消息类型: {type(message)}")
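# Editorial sketch: the broadened `list[Any]` input lets a caller mix plain
# strings, Paths, raw dicts and PIL images in one user turn; each item passes
# through _transform_to_content_part before being wrapped into a user message.
async def _demo_mixed_input() -> list[LLMMessage]:
    return await normalize_to_llm_messages(
        ["描述这张截图", Path("screenshot.png")],  # str 与 Path 混合输入
        instruction="你是一个图片分析助手",
    )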
@@ -255,53 +323,271 @@ def message_to_unimessage(message: PlatformMessage) -> UniMessage:
返回:
UniMessage: 转换后的通用消息对象
"""
uni_segments = []
for seg in message:
if seg.type == "text":
uni_segments.append(Text(seg.data.get("text", "")))
elif seg.type == "image":
uni_segments.append(Image(url=seg.data.get("url")))
elif seg.type == "record":
uni_segments.append(Voice(url=seg.data.get("url")))
elif seg.type == "video":
uni_segments.append(Video(url=seg.data.get("url")))
elif seg.type == "at":
uni_segments.append(At("user", str(seg.data.get("qq", ""))))
else:
logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
return UniMessage.of(message)
return UniMessage(uni_segments)
def resolve_json_schema_refs(schema: dict) -> dict:
"""
递归解析 JSON Schema 中的 $ref将其替换为 $defs/definitions 中的定义
用于兼容不支持 $ref Gemini API
"""
definitions = schema.get("$defs") or schema.get("definitions") or {}
def _resolve(node: Any) -> Any:
if isinstance(node, dict):
if "$ref" in node:
ref_name = node["$ref"].split("/")[-1]
if ref_name in definitions:
return _resolve(definitions[ref_name])
return {
key: _resolve(value)
for key, value in node.items()
if key not in ("$defs", "definitions")
}
if isinstance(node, list):
return [_resolve(item) for item in node]
return node
return _resolve(schema)
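# Editorial behaviour sketch: $ref nodes are replaced inline by their
# definitions and the $defs/definitions containers are dropped from the output.
# Note the inlining assumes non-recursive definitions; a self-referential $ref
# would recurse without terminating.
_REF_DEMO = {
    "type": "object",
    "properties": {"pet": {"$ref": "#/$defs/Pet"}},
    "$defs": {"Pet": {"type": "string"}},
}
assert resolve_json_schema_refs(_REF_DEMO) == {
    "type": "object",
    "properties": {"pet": {"type": "string"}},
}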
def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
"""
递归地净化 JSON Schema移除特定 LLM API 不支持的关键字
参数:
schema: 要净化的 JSON Schema (可以是字典列表或其它类型)
api_type: 目标 API 的类型例如 'gemini'
返回:
Any: 净化后的 JSON Schema
"""
if isinstance(schema, dict):
schema_copy = {}
for key, value in schema.items():
if api_type == "gemini":
unsupported_keys = ["exclusiveMinimum", "exclusiveMaximum", "default"]
if key in unsupported_keys:
continue
if key == "format" and isinstance(value, str):
supported_formats = ["enum", "date-time"]
if value not in supported_formats:
continue
schema_copy[key] = sanitize_schema_for_llm(value, api_type)
return schema_copy
elif isinstance(schema, list):
if isinstance(schema, list):
return [sanitize_schema_for_llm(item, api_type) for item in schema]
if isinstance(schema, dict):
schema_copy = schema.copy()
if api_type == "gemini":
if "const" in schema_copy:
schema_copy["enum"] = [schema_copy.pop("const")]
if "type" in schema_copy and isinstance(schema_copy["type"], list):
types_list = schema_copy["type"]
if "null" in types_list:
schema_copy["nullable"] = True
types_list = [t for t in types_list if t != "null"]
if len(types_list) == 1:
schema_copy["type"] = types_list[0]
else:
schema_copy["type"] = types_list
if "anyOf" in schema_copy:
any_of = schema_copy["anyOf"]
has_null = any(
isinstance(x, dict) and x.get("type") == "null" for x in any_of
)
if has_null:
schema_copy["nullable"] = True
new_any_of = [
x
for x in any_of
if not (isinstance(x, dict) and x.get("type") == "null")
]
if len(new_any_of) == 1:
schema_copy.update(new_any_of[0])
schema_copy.pop("anyOf", None)
else:
schema_copy["anyOf"] = new_any_of
unsupported_keys = [
"exclusiveMinimum",
"exclusiveMaximum",
"default",
"title",
"additionalProperties",
"$schema",
"$id",
]
for key in unsupported_keys:
schema_copy.pop(key, None)
if schema_copy.get("format") and schema_copy["format"] not in [
"enum",
"date-time",
]:
schema_copy.pop("format", None)
elif api_type == "openai":
unsupported_keys = [
"default",
"minLength",
"maxLength",
"pattern",
"format",
"minimum",
"maximum",
"multipleOf",
"patternProperties",
"minItems",
"maxItems",
"uniqueItems",
"$schema",
"title",
]
for key in unsupported_keys:
schema_copy.pop(key, None)
if "$ref" in schema_copy:
ref_key = schema_copy["$ref"].split("/")[-1]
defs = schema_copy.get("$defs") or schema_copy.get("definitions")
if defs and ref_key in defs:
schema_copy.pop("$ref", None)
schema_copy.update(defs[ref_key])
else:
return {"$ref": schema_copy["$ref"]}
is_object = (
schema_copy.get("type") == "object" or "properties" in schema_copy
)
if is_object:
schema_copy["type"] = "object"
schema_copy["additionalProperties"] = False
properties = schema_copy.get("properties", {})
required = schema_copy.get("required", [])
if properties:
existing_req = set(required)
for prop in properties.keys():
if prop not in existing_req:
required.append(prop)
schema_copy["required"] = required
for def_key in ["$defs", "definitions"]:
if def_key in schema_copy and isinstance(schema_copy[def_key], dict):
schema_copy[def_key] = {
k: sanitize_schema_for_llm(v, api_type)
for k, v in schema_copy[def_key].items()
}
recursive_keys = ["properties", "items", "allOf", "anyOf", "oneOf"]
for key in recursive_keys:
if key in schema_copy:
if key == "properties" and isinstance(schema_copy[key], dict):
schema_copy[key] = {
k: sanitize_schema_for_llm(v, api_type)
for k, v in schema_copy[key].items()
}
else:
schema_copy[key] = sanitize_schema_for_llm(
schema_copy[key], api_type
)
return schema_copy
else:
return schema
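# Editorial behaviour sketch for api_type="gemini": Optional fields expressed
# as anyOf-with-null collapse into `nullable`, and unsupported keywords such
# as `title` and `default` are stripped.
_GEMINI_DEMO = {
    "type": "object",
    "title": "Demo",
    "properties": {
        "age": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None}
    },
}
_cleaned = sanitize_schema_for_llm(_GEMINI_DEMO, "gemini")
assert "title" not in _cleaned
assert _cleaned["properties"]["age"] == {"nullable": True, "type": "integer"}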
def extract_text_from_content(
content: str | list[LLMContentPart] | None,
) -> str:
"""
从消息内容中提取纯文本自动过滤非文本部分防止污染 Prompt
"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join(
part.text for part in content if part.type == "text" and part.text
)
return str(content)
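# Editorial sketch: non-text parts are silently dropped, so an image part
# cannot leak a base64 payload or URL into the assembled prompt string.
_parts = [
    LLMContentPart.text_part("hello"),
    LLMContentPart.image_url_part("https://example.com/a.png"),
]
assert extract_text_from_content(_parts) == "hello"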
def parse_and_validate_json(text: str, response_model: type[T]) -> T:
"""
通用工具尝试将文本解析为指定的 Pydantic 模型并统一处理异常
"""
try:
return type_validate_json(response_model, text)
except (ValidationError, ValueError) as e:
try:
logger.warning(f"标准JSON解析失败尝试使用json_repair修复: {e}")
repaired_obj = json_repair.loads(text, skip_json_loads=True)
return model_validate(response_model, repaired_obj)
except Exception as repair_error:
logger.error(
f"LLM结构化输出校验最终失败: {repair_error}",
e=repair_error,
)
raise LLMException(
"LLM返回的JSON未能通过结构验证。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={
"raw_response": text,
"validation_error": str(repair_error),
"original_error": repair_error,
},
cause=repair_error,
)
except Exception as e:
logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
raise LLMException(
"解析LLM的JSON输出时失败。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": text},
cause=e,
)
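# Editorial sketch: the json_repair fallback tolerates typical LLM output
# damage such as trailing commas or a truncated closing brace.
class _Verdict(BaseModel):
    ok: bool
    score: int

assert parse_and_validate_json('{"ok": true, "score": 7,}', _Verdict).score == 7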
def create_cot_wrapper(inner_model: type[BaseModel]) -> type[BaseModel]:
"""
[动态运行时封装]
创建一个包含思维链 (Chain of Thought) 的包装模型
强制模型在生成最终 JSON 结构前先输出一个 reasoning 字段进行思考
"""
wrapper_name = f"CoT_{inner_model.__name__}"
return create_model(
wrapper_name,
reasoning=(
str,
Field(
...,
min_length=10,
description=(
"在生成最终结果之前,请务必在此字段中详细描述你的推理步骤、计算过程或思考逻辑。禁止留空。"
),
),
),
result=(
inner_model,
Field(
...,
),
),
)
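# Editorial sketch: the wrapper yields a two-field envelope, a free-text
# `reasoning` field (at least 10 characters) plus `result` typed as the inner
# model, so validation forces the model to write its reasoning first.
class _Answer(BaseModel):
    value: int

_CoTAnswer = create_cot_wrapper(_Answer)
_wrapped = _CoTAnswer(
    reasoning="先将 2 与 3 相加得到 5,故最终答案取 5。",
    result={"value": 5},
)
assert _wrapped.result.value == 5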
def should_apply_autocot(
requested: bool,
model_name: str | None,
config: Any,
) -> bool:
"""
[智能决策管道]
判断是否应该应用 AutoCoT (显式思维链包装)
防止在模型已有原生思维能力时进行双重思考
"""
if not requested:
return False
if config:
thinking_budget = getattr(config, "thinking_budget", 0) or 0
if thinking_budget > 0:
return False
if getattr(config, "thinking_level", None) is not None:
return False
if model_name:
caps = get_model_capabilities(model_name)
if caps.reasoning_mode != ReasoningMode.NONE:
return False
return True
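# Editorial decision sketch: wrapping is skipped whenever native reasoning is
# already active; SimpleNamespace stands in for a real generation config here.
from types import SimpleNamespace

assert should_apply_autocot(True, None, SimpleNamespace(thinking_budget=1024)) is False
assert should_apply_autocot(True, None, None) is True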

View File

@@ -14,9 +14,34 @@ def _truncate_base64_string(value: str, threshold: int = 256) -> str:
if value.startswith(prefixes) and len(value) > threshold:
prefix = next((p for p in prefixes if value.startswith(p)), "base64")
return f"[{prefix}_data_omitted_len={len(value)}]"
if len(value) > 1000:
return f"[long_string_omitted_len={len(value)}] {value[:20]}...{value[-20:]}"
if len(value) > 2000:
return f"[long_string_omitted_len={len(value)}] {value[:50]}...{value[-20:]}"
return value
def _truncate_vector_list(vector: list, threshold: int = 10) -> list:
"""如果列表过长通常是embedding向量则截断它用于日志显示。"""
if isinstance(vector, list) and len(vector) > threshold:
return [*vector[:3], f"...({len(vector)} floats omitted)...", *vector[-3:]]
return vector
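# Editorial sketch: an embedding vector shrinks to a head/tail preview in logs.
assert _truncate_vector_list([0.1] * 1536) == [
    0.1, 0.1, 0.1, "...(1536 floats omitted)...", 0.1, 0.1, 0.1
]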
def _recursive_sanitize_any(obj: Any) -> Any:
"""递归清洗任何对象中的长字符串"""
if isinstance(obj, dict):
return {k: _recursive_sanitize_any(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_recursive_sanitize_any(v) for v in obj]
elif isinstance(obj, str):
return _truncate_base64_string(obj)
return obj
def _sanitize_ui_html(html_string: str) -> str:
"""
专门用于净化UI渲染调试HTML的函数
@@ -64,6 +89,37 @@ def _sanitize_openai_response(response_json: dict) -> dict:
message["images"][i]["image_url"]["url"] = (
_truncate_base64_string(url)
)
if "reasoning_details" in message and isinstance(
message["reasoning_details"], list
):
for detail in message["reasoning_details"]:
if isinstance(detail, dict):
if "data" in detail and isinstance(detail["data"], str):
if len(detail["data"]) > 100:
detail["data"] = (
f"[encrypted_data_omitted_len={len(detail['data'])}]"
)
if "text" in detail and isinstance(detail["text"], str):
detail["text"] = _truncate_base64_string(
detail["text"], threshold=2000
)
if "data" in sanitized_json and isinstance(sanitized_json["data"], list):
for item in sanitized_json["data"]:
if "embedding" in item and isinstance(item["embedding"], list):
item["embedding"] = _truncate_vector_list(item["embedding"])
if "b64_json" in item and isinstance(item["b64_json"], str):
if len(item["b64_json"]) > 256:
item["b64_json"] = (
f"[base64_json_omitted_len={len(item['b64_json'])}]"
)
if "input" in sanitized_json and isinstance(sanitized_json["input"], list):
for item in sanitized_json["input"]:
if "content" in item and isinstance(item["content"], list):
for part in item["content"]:
if isinstance(part, dict) and part.get("type") == "input_image":
image_url = part.get("image_url")
if isinstance(image_url, str):
part["image_url"] = _truncate_base64_string(image_url)
return sanitized_json
except Exception:
return response_json
@@ -71,22 +127,44 @@ def _sanitize_openai_response(response_json: dict) -> dict:
def _sanitize_openai_request(body: dict) -> dict:
"""净化OpenAI兼容API的请求体主要截断图片base64。"""
from zhenxun.services.llm.config.providers import (
DebugLogOptions,
get_llm_config,
)
debug_conf = get_llm_config().debug_log
if isinstance(debug_conf, bool):
debug_conf = DebugLogOptions(
show_tools=debug_conf, show_schema=debug_conf, show_safety=debug_conf
)
try:
sanitized_json = copy.deepcopy(body)
if "messages" in sanitized_json and isinstance(
sanitized_json["messages"], list
):
for message in sanitized_json["messages"]:
if "content" in message and isinstance(message["content"], list):
for i, part in enumerate(message["content"]):
if part.get("type") == "image_url":
if "image_url" in part and isinstance(
part["image_url"], dict
):
url = part["image_url"].get("url", "")
message["content"][i]["image_url"]["url"] = (
_truncate_base64_string(url)
)
sanitized_json = _recursive_sanitize_any(copy.deepcopy(body))
if "tools" in sanitized_json and not debug_conf.show_tools:
tools = sanitized_json["tools"]
if isinstance(tools, list):
tool_names = []
for t in tools:
if isinstance(t, dict):
name = None
if "function" in t and isinstance(t["function"], dict):
name = t["function"].get("name")
if not name and "name" in t:
name = t.get("name")
tool_names.append(name or "unknown")
sanitized_json["tools"] = (
f"<{len(tool_names)} tools hidden: {', '.join(tool_names)}>"
)
if "response_format" in sanitized_json and not debug_conf.show_schema:
response_format = sanitized_json["response_format"]
if isinstance(response_format, dict):
if response_format.get("type") == "json_schema":
sanitized_json["response_format"] = {
"type": "json_schema",
"json_schema": "<JSON Schema Hidden>",
}
return sanitized_json
except Exception:
return body
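# Editorial effect sketch, assuming show_tools=False in the active
# DebugLogOptions: the tool list collapses to a one-line summary instead of
# dumping every tool's full JSON schema into the log.
_body = {
    "model": "gpt-4o",
    "tools": [
        {"type": "function", "function": {"name": "web_search", "parameters": {}}}
    ],
}
# _sanitize_openai_request(_body)["tools"] -> "<1 tools hidden: web_search>"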
@@ -94,6 +172,9 @@ def _sanitize_openai_request(body: dict) -> dict:
def _sanitize_gemini_response(response_json: dict) -> dict:
"""净化Gemini API的响应体处理文本和图片生成两种格式。"""
from zhenxun.services.llm.config.providers import get_llm_config
debug_mode = get_llm_config().debug_log
try:
sanitized_json = copy.deepcopy(response_json)
@@ -114,6 +195,15 @@ def _sanitize_gemini_response(response_json: dict) -> dict:
content["parts"][i]["inlineData"]["data"] = (
f"[base64_data_omitted_len={len(data)}]"
)
if "thoughtSignature" in part:
signature = part.get("thoughtSignature", "")
if isinstance(signature, str) and len(signature) > 256:
content["parts"][i]["thoughtSignature"] = (
f"[signature_omitted_len={len(signature)}]"
)
if not debug_mode and isinstance(candidate, dict):
if "safetyRatings" in candidate:
candidate["safetyRatings"] = "<Safety Ratings Hidden>"
if "candidates" in sanitized_json:
_process_candidates(sanitized_json["candidates"])
@@ -124,6 +214,19 @@ def _sanitize_gemini_response(response_json: dict) -> dict:
if "candidates" in sanitized_json["image_generation"]:
_process_candidates(sanitized_json["image_generation"]["candidates"])
if "embeddings" in sanitized_json and isinstance(
sanitized_json["embeddings"], list
):
for embedding in sanitized_json["embeddings"]:
if "values" in embedding and isinstance(embedding["values"], list):
embedding["values"] = _truncate_vector_list(embedding["values"])
if not debug_mode and "promptFeedback" in sanitized_json:
prompt_feedback = sanitized_json.get("promptFeedback") or {}
if isinstance(prompt_feedback, dict) and "safetyRatings" in prompt_feedback:
prompt_feedback["safetyRatings"] = "<Safety Ratings Hidden>"
sanitized_json["promptFeedback"] = prompt_feedback
return sanitized_json
except Exception:
return response_json
@@ -131,8 +234,46 @@ def _sanitize_gemini_response(response_json: dict) -> dict:
def _sanitize_gemini_request(body: dict) -> dict:
"""净化Gemini API的请求体进行结构转换和总结。"""
from zhenxun.services.llm.config.providers import (
DebugLogOptions,
get_llm_config,
)
debug_conf = get_llm_config().debug_log
if isinstance(debug_conf, bool):
debug_conf = DebugLogOptions(
show_tools=debug_conf, show_schema=debug_conf, show_safety=debug_conf
)
try:
sanitized_body = copy.deepcopy(body)
if "tools" in sanitized_body and not debug_conf.show_tools:
tool_summary = []
for tool_group in sanitized_body["tools"]:
if (
isinstance(tool_group, dict)
and "functionDeclarations" in tool_group
):
declarations = tool_group["functionDeclarations"]
if isinstance(declarations, list):
for func in declarations:
if isinstance(func, dict):
tool_summary.append(func.get("name", "unknown"))
sanitized_body["tools"] = (
f"<{len(tool_summary)} functions hidden: {', '.join(tool_summary)}>"
)
if not debug_conf.show_safety and "safetySettings" in sanitized_body:
sanitized_body["safetySettings"] = "<Safety Settings Hidden>"
if not debug_conf.show_schema and "generationConfig" in sanitized_body:
generation_config = sanitized_body["generationConfig"]
if (
isinstance(generation_config, dict)
and "responseJsonSchema" in generation_config
):
generation_config["responseJsonSchema"] = "<JSON Schema Hidden>"
if "contents" in sanitized_body and isinstance(
sanitized_body["contents"], list
):
@@ -153,6 +294,13 @@ def _sanitize_gemini_request(body: dict) -> dict:
continue
new_parts.append(part)
if "thoughtSignature" in part:
sig = part["thoughtSignature"]
if isinstance(sig, str) and len(sig) > 64:
part["thoughtSignature"] = (
f"[signature_omitted_len={len(sig)}]"
)
if media_summary:
summary_text = (
f"[多模态内容: {len(media_summary)}个文件 - "
@@ -195,8 +343,5 @@ def sanitize_for_logging(data: Any, context: str | None = None) -> Any:
elif context == "ui_html":
if isinstance(data, str):
return _sanitize_ui_html(data)
else:
if isinstance(data, str):
return _truncate_base64_string(data)
return data
return _recursive_sanitize_any(data)

View File

@@ -10,8 +10,14 @@ from enum import Enum
from pathlib import Path
from typing import Any, TypeVar, get_args, get_origin
from nonebot.compat import PYDANTIC_V2, model_dump
from pydantic import VERSION, BaseModel
from nonebot.compat import (
PYDANTIC_V2,
model_dump,
model_fields,
type_validate_json,
type_validate_python,
)
from pydantic import BaseModel
import ujson as json
T = TypeVar("T", bound=BaseModel)
@@ -27,9 +33,13 @@ __all__ = [
"model_construct",
"model_copy",
"model_dump",
"model_dump_json",
"model_fields",
"model_json_schema",
"model_validate",
"parse_as",
"type_validate_json",
"type_validate_python",
]
@@ -58,12 +68,18 @@ def model_construct(model_class: type[T], **kwargs: Any) -> T:
def model_validate(model_class: type[T], obj: Any) -> T:
"""
Pydantic `model_validate` (v2) `parse_obj` (v1) 的兼容函数
Pydantic 模型验证兼容函数
"""
return type_validate_python(model_class, obj)
def model_dump_json(model: BaseModel, **kwargs: Any) -> str:
"""
Pydantic `model.json()` (v1) `model.model_dump_json()` (v2) 的兼容函数
"""
if PYDANTIC_V2:
return model_class.model_validate(obj)
else:
return model_class.parse_obj(obj)
return model.model_dump_json(**kwargs)
return model.json(**kwargs)
if PYDANTIC_V2:
@@ -78,8 +94,7 @@ def model_json_schema(model_class: type[BaseModel], **kwargs: Any) -> dict[str,
"""
if PYDANTIC_V2:
return model_class.model_json_schema(**kwargs)
else:
return model_class.schema(by_alias=kwargs.get("by_alias", True))
return model_class.schema(by_alias=kwargs.get("by_alias", True))
def _is_pydantic_type(t: Any) -> bool:
@@ -108,18 +123,7 @@ def _dump_pydantic_obj(obj: Any) -> Any:
return obj
def parse_as(type_: type[V], obj: Any) -> V:
"""
一个兼容 Pydantic V1 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数
"""
if VERSION.startswith("1"):
from pydantic import parse_obj_as
return parse_obj_as(type_, obj)
else:
from pydantic import TypeAdapter # type: ignore
return TypeAdapter(type_).validate_python(obj)
parse_as = type_validate_python
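# Editorial sketch: with parse_as now a direct alias of
# nonebot.compat.type_validate_python, one call validates arbitrary types on
# both Pydantic majors with lax coercion.
assert parse_as(list[int], ["1", 2, "3"]) == [1, 2, 3]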
def dump_json_safely(obj: Any, **kwargs) -> str: