feat(llm): overhaul the LLM service module with stronger multimodal and tool support (#1953)

*  feat(llm): overhaul the LLM service module with stronger multimodal and tool support

🚀 Core feature enhancements
- Multi-model chaining: new `pipeline_chat` for complex task flows (usage sketched below)
- More providers: new adapters for ARK (Volcengine Ark) and SiliconFlow
- Stronger multimodal handling: media referenced by URL is downloaded and converted, for more flexible input
- Conversation history: `AI.analyze` accepts history messages and an optional UniMessage argument
- Text embeddings: new APIs such as `embed`, `analyze_multimodal`, and `search_multimodal`
- Model capability system: new `ModelCapabilities` centralizes model traits (multimodality, tool calling, etc.)
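
A minimal usage sketch of the new high-level API; the import path and the model identifiers below are assumptions, substitute your configured providers:

```python
# Minimal sketch of the new high-level API. The import path and the model
# identifiers are assumptions; adjust them to your deployment.
from zhenxun.services.llm import AI, embed, pipeline_chat


async def demo() -> None:
    ai = AI()
    # Session-level chat: history is kept on the AI instance and, by default,
    # multimodal parts are replaced with placeholders when stored.
    reply = await ai.chat("帮我总结一下今天的群聊", model="Gemini/gemini-2.0-flash")
    print(reply)

    # Chain models: each model's text output becomes the next model's input.
    result = await pipeline_chat(
        "请分析这段需求并给出实现方案……",
        model_chain=["DeepSeek/deepseek-chat", "Gemini/gemini-2.5-pro"],
        initial_instruction="先提炼需求要点",
        final_instruction="基于要点输出最终实现方案",
    )
    print(result.text)

    # Text embeddings via the new embed() helper (embedding model name assumed).
    vectors = await embed(["你好", "世界"], model="Gemini/text-embedding-004")
    print(len(vectors), len(vectors[0]))
```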

🔧 Architecture refactoring and optimization
- MCP tool system rework: configuration moved to a standalone `data/llm/mcp_tools.json`, preloaded with common tools (default file sketched below)
- Unified API call logic: common `_perform_api_call` method extracted to eliminate duplication
- Cross-platform compatibility: MCP tool npx commands are automatically wrapped on Windows
- HTTP client hardening: proxy configuration works across httpx versions (0.28+ adaptation)
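
For reference, the default `data/llm/mcp_tools.json` written on first startup mirrors the structure below, shown here as an equivalent Python literal; the Baidu key is a placeholder:

```python
# Equivalent Python literal of the default data/llm/mcp_tools.json created on
# first startup when the file is missing. BAIDU_MAP_API_KEY is a placeholder.
DEFAULT_MCP_TOOLS = {
    "mcpServers": {
        "baidu-map": {
            "command": "npx",
            "args": ["-y", "@baidumap/mcp-server-baidu-map"],
            "env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
            "description": "百度地图工具,提供地理编码、路线规划等功能。",
        },
        "sequential-thinking": {
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
            "description": "顺序思维工具,用于帮助模型进行多步骤推理。",
        },
        "Context7": {
            "command": "npx",
            "args": ["-y", "@upstash/context7-mcp@latest"],
            "description": "Upstash 提供的上下文管理和记忆工具。",
        },
    }
}
# On Windows, get_llm_config() rewrites each "npx" entry to
# command="cmd", args=["/c", "npx", *original_args] so the servers can spawn.
```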

🛠️ API and configuration polish
- Unified return type: `AI.analyze` now always returns `LLMResponse` (usage sketched below)
- Message conversion helper: new `message_to_unimessage` function
- Gemini adapter improvements: URL image download and encoding, configurable safety threshold
- Cache management: model instance caching and management added
- Config presets: extended `CommonOverrides` preset options
- History management: multimodal content can be replaced with placeholders for efficiency
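
A hedged sketch of the unified `analyze` flow; the UniMessage import location and the tool name are assumptions based on the defaults above:

```python
# Sketch of the unified analyze() flow. UniMessage's import location and the
# "baidu-map" tool name are assumptions; adapt them to your setup.
from nonebot_plugin_alconna import UniMessage

from zhenxun.services.llm import AI


async def plan_route() -> None:
    ai = AI()
    response = await ai.analyze(
        UniMessage("帮我规划从北京西站到故宫的路线"),
        instruction="你是一个出行助手",
        use_tools=["baidu-map"],       # tool names registered in mcp_tools.json
        tool_config={"mode": "auto"},  # auto / any / none
    )
    # analyze() now always returns an LLMResponse, even for plain-text answers.
    if response.tool_calls:
        print("模型请求调用工具:", response.tool_calls)
    else:
        print(response.text)
```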

📚 Documentation and developer experience
- README rewritten: full usage guide, API reference, and architecture overview
- Expanded docs: embedding models, cache management, tool registration, and more
- Richer logging: detailed debug output
- API cleanup: redundant functions removed, interfaces simplified

* 🎨  feat(llm): unify docstring format across LLM service functions

*  feat(llm): add new models and simplify provider config loading

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Rumio 2025-07-08 11:15:15 +08:00 committed by GitHub
parent 1e7ae38684
commit 48cbb2bf1d
24 changed files with 2128 additions and 1338 deletions

File diff suppressed because it is too large

View File

@ -10,10 +10,10 @@ from .api import (
TaskType,
analyze,
analyze_multimodal,
analyze_with_images,
chat,
code,
embed,
pipeline_chat,
search,
search_multimodal,
)
@ -35,6 +35,7 @@ from .manager import (
list_model_identifiers,
set_global_default_model_name,
)
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -43,6 +44,7 @@ from .types import (
LLMMessage,
LLMResponse,
LLMTool,
MCPCompatible,
ModelDetail,
ModelInfo,
ModelProvider,
@ -51,7 +53,7 @@ from .types import (
ToolMetadata,
UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
__all__ = [
"AI",
@ -65,6 +67,7 @@ __all__ = [
"LLMMessage",
"LLMResponse",
"LLMTool",
"MCPCompatible",
"ModelDetail",
"ModelInfo",
"ModelName",
@ -76,7 +79,6 @@ __all__ = [
"UsageInfo",
"analyze",
"analyze_multimodal",
"analyze_with_images",
"chat",
"clear_model_cache",
"code",
@ -88,9 +90,12 @@ __all__ = [
"list_available_models",
"list_embedding_models",
"list_model_identifiers",
"message_to_unimessage",
"pipeline_chat",
"register_llm_configs",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"unimsg_to_llm_parts",
]

View File

@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
LLMAdapterFactory.initialize()
@ -20,7 +19,6 @@ __all__ = [
"OpenAICompatAdapter",
"RequestData",
"ResponseData",
"ZhipuAdapter",
"get_adapter_for_api_type",
"register_adapter",
]

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool
class RequestData(BaseModel):
@ -60,7 +61,7 @@ class BaseAdapter(ABC):
"""支持的API类型列表"""
pass
def prepare_simple_request(
async def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
@ -86,7 +87,7 @@ class BaseAdapter(ABC):
config = model._generation_config
return self.prepare_advanced_request(
return await self.prepare_advanced_request(
model=model,
api_key=api_key,
messages=messages,
@ -96,13 +97,13 @@ class BaseAdapter(ABC):
)
@abstractmethod
def prepare_advanced_request(
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -238,6 +239,9 @@ class BaseAdapter(ABC):
message = choice.get("message", {})
content = message.get("content", "")
if content:
content = content.strip()
parsed_tool_calls: list[LLMToolCall] | None = None
if message_tool_calls := message.get("tool_calls"):
from ..types.models import LLMToolFunction
@ -375,7 +379,7 @@ class BaseAdapter(ABC):
if model.temperature is not None:
base_config["temperature"] = model.temperature
if model.max_tokens is not None:
if model.api_type in ["gemini", "gemini_native"]:
if model.api_type == "gemini":
base_config["maxOutputTokens"] = model.max_tokens
else:
base_config["max_tokens"] = model.max_tokens
@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
"""
@abstractmethod
def get_chat_endpoint(self) -> str:
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 chat completions 的端点"""
pass
@abstractmethod
def get_embedding_endpoint(self) -> str:
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""子类必须实现,返回 embeddings 的端点"""
pass
def prepare_advanced_request(
async def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI兼容API的通用实现"""
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
url = self.get_api_url(model, self.get_chat_endpoint())
url = self.get_api_url(model, self.get_chat_endpoint(model))
headers = self.get_base_headers(api_key)
openai_messages = self.convert_messages_to_openai_format(messages)
@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
}
if tools:
body["tools"] = tools
openai_tools = []
for tool in tools:
if tool.type == "function" and tool.function:
openai_tools.append({"type": "function", "function": tool.function})
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
openai_tools.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
if openai_tools:
body["tools"] = openai_tools
if tool_choice:
body["tool_choice"] = tool_choice
@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
_ = model, is_advanced # 未使用的参数
_ = model, is_advanced
return self.parse_openai_response(response_json)
def prepare_embedding_request(
@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
**kwargs: Any,
) -> RequestData:
"""准备嵌入请求 - OpenAI兼容格式"""
_ = task_type # 未使用的参数
url = self.get_api_url(model, self.get_embedding_endpoint())
_ = task_type
url = self.get_api_url(model, self.get_embedding_endpoint(model))
headers = self.get_base_headers(api_key)
body = {
@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
"input": texts,
}
# 应用额外的配置参数
if kwargs:
body.update(kwargs)

View File

@ -22,10 +22,8 @@ class LLMAdapterFactory:
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
cls.register_adapter(OpenAIAdapter())
cls.register_adapter(ZhipuAdapter())
cls.register_adapter(GeminiAdapter())
@classmethod

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMToolCall
from ..types.models import LLMTool, LLMToolCall
class GeminiAdapter(BaseAdapter):
@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):
return headers
def prepare_advanced_request(
async def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tools: list["LLMTool"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
return self._prepare_request(
model, api_key, messages, config, tools, tool_choice
)
def _prepare_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备 Gemini API 请求 - 支持所有高级功能"""
effective_config = config if config is not None else model._generation_config
endpoint = self._get_gemini_endpoint(model, effective_config)
@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list):
system_instruction_parts = [
part.convert_for_api("gemini") for part in msg.content
await part.convert_for_api_async("gemini")
for part in msg.content
]
continue
@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
gemini_contents.append({"role": "user", "parts": current_parts})
elif msg.role == "assistant" or msg.role == "model":
@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
current_parts.append(
await part_obj.convert_for_api_async("gemini")
)
if msg.tool_calls:
import json
@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):
all_tools_for_request = []
if tools:
for tool_item in tools:
if isinstance(tool_item, dict):
if "name" in tool_item and "description" in tool_item:
all_tools_for_request.append(
{"functionDeclarations": [tool_item]}
for tool in tools:
if tool.type == "function" and tool.function:
all_tools_for_request.append(
{"functionDeclarations": [tool.function]}
)
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
else:
all_tools_for_request.append(tool_item)
else:
all_tools_for_request.append(tool_item)
all_tools_for_request.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
elif tool.type == "google_search":
all_tools_for_request.append({"googleSearch": {}})
if effective_config:
if getattr(effective_config, "enable_grounding", False):
@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
logger.debug("隐式启用代码执行工具。")
if all_tools_for_request:
gemini_api_tools = self._convert_tools_to_gemini_format(
all_tools_for_request
)
if gemini_api_tools:
body["tools"] = gemini_api_tools
body["tools"] = all_tools_for_request
final_tool_choice = tool_choice
if final_tool_choice is None and effective_config:
@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):
return f"/v1beta/models/{model.model_name}:generateContent"
def _convert_tools_to_gemini_format(
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""转换工具格式为Gemini格式"""
gemini_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
gemini_tool = {
"functionDeclarations": [
{
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
]
}
gemini_tools.append(gemini_tool)
elif tool.get("type") == "code_execution":
gemini_tools.append(
{"codeExecution": {"language": tool.get("language", "python")}}
)
elif tool.get("type") == "google_search":
gemini_tools.append({"googleSearch": {}})
elif "googleSearch" in tool:
gemini_tools.append({"googleSearch": tool["googleSearch"]})
elif "codeExecution" in tool:
gemini_tools.append({"codeExecution": tool["codeExecution"]})
return gemini_tools
def _convert_tool_choice_to_gemini(
self, tool_choice_value: str | dict[str, Any]
) -> dict[str, Any]:
@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
for category, threshold in custom_safety_settings.items():
safety_settings.append({"category": category, "threshold": threshold})
else:
from ..config.providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
for category in safety_categories:
safety_settings.append(
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
)
safety_settings.append({"category": category, "threshold": threshold})
return safety_settings if safety_settings else None

View File

@ -1,12 +1,12 @@
"""
OpenAI API 适配器
支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务
支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
from .base import OpenAICompatAdapter
if TYPE_CHECKING:
from ..service import LLMModel
@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):
@property
def supported_api_types(self) -> list[str]:
return ["openai", "deepseek", "general_openai_compat"]
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
def get_chat_endpoint(self) -> str:
def get_chat_endpoint(self, model: "LLMModel") -> str:
"""返回聊天完成端点"""
if model.api_type == "ark":
return "/api/v3/chat/completions"
if model.api_type == "zhipu":
return "/api/paas/v4/chat/completions"
return "/v1/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回嵌入端点"""
def get_embedding_endpoint(self, model: "LLMModel") -> str:
"""根据API类型返回嵌入端点"""
if model.api_type == "zhipu":
return "/v4/embeddings"
return "/v1/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -1,57 +0,0 @@
"""
智谱 AI API 适配器
支持智谱 AI GLM 系列模型使用 OpenAI 兼容的接口格式
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
if TYPE_CHECKING:
from ..service import LLMModel
class ZhipuAdapter(OpenAICompatAdapter):
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
@property
def api_type(self) -> str:
return "zhipu"
@property
def supported_api_types(self) -> list[str]:
return ["zhipu"]
def get_chat_endpoint(self) -> str:
"""返回智谱AI聊天完成端点"""
return "/api/paas/v4/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回智谱AI嵌入端点"""
return "/v4/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - 智谱AI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -2,6 +2,7 @@
LLM 服务的高级 API 接口
"""
import copy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@ -14,6 +15,7 @@ from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -56,6 +58,7 @@ class AIConfig:
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
def __post_init__(self):
"""初始化后从配置中读取默认值"""
@ -81,7 +84,7 @@ class AI:
"""
初始化AI服务
Args:
参数:
config: AI 配置.
history: 可选的初始对话历史.
"""
@ -93,16 +96,65 @@ class AI:
self.history = []
logger.info("AI session history cleared.")
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
净化用于存入历史记录的消息
将非文本的多模态内容部分替换为文本占位符,以避免重复处理
"""
if not isinstance(message.content, list):
return message
sanitized_message = copy.deepcopy(message)
content_list = sanitized_message.content
if not isinstance(content_list, list):
return sanitized_message
new_content_parts: list[LLMContentPart] = []
has_multimodal_content = False
for part in content_list:
if isinstance(part, LLMContentPart) and part.type == "text":
new_content_parts.append(part)
else:
has_multimodal_content = True
if has_multimodal_content:
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
text_part_found = False
for part in new_content_parts:
if part.type == "text":
part.text = f"{placeholder} {part.text or ''}".strip()
text_part_found = True
break
if not text_part_found:
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
sanitized_message.content = new_content_parts
return sanitized_message
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
preserve_media_in_history: bool | None = None,
**kwargs: Any,
) -> str:
"""
进行一次聊天对话
此方法会自动使用和更新会话内的历史记录
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留,用于深度多轮媒体分析
- False: 不保留,替换为占位符,提高效率
- None (默认): 使用AI实例配置的默认值
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
current_message: LLMMessage
if isinstance(message, str):
@ -127,7 +179,20 @@ class AI:
final_messages, model, "聊天失败", kwargs
)
self.history.append(current_message)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(LLMMessage.assistant_text_response(response.text))
return response.text
@ -140,7 +205,18 @@ class AI:
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行"""
"""
代码执行
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
@ -168,7 +244,18 @@ class AI:
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索 - 支持多模态输入"""
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容,支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
@ -217,63 +304,69 @@ class AI:
async def analyze(
self,
message: UniMessage,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
) -> LLMResponse:
"""
内容分析 - 接收 UniMessage 对象进行多模态分析和工具调用
这是处理复杂交互的主要方法
参数:
message: 要分析的消息内容,支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
返回:
LLMResponse: 模型的完整响应结果
"""
content_parts = await unimsg_to_llm_parts(message)
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
else:
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools = None
if tools:
llm_tools = []
for tool_dict in tools:
if isinstance(tool_dict, dict):
if "name" in tool_dict and "description" in tool_dict:
llm_tool = LLMTool(
type="function",
function={
"name": tool_dict["name"],
"description": tool_dict["description"],
"parameters": tool_dict.get("parameters", {}),
},
)
llm_tools.append(llm_tool)
else:
llm_tools.append(LLMTool(**tool_dict))
else:
llm_tools.append(tool_dict)
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode == "auto":
tool_choice = "auto"
elif mode == "any":
tool_choice = "any"
elif mode == "none":
tool_choice = "none"
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
final_messages,
@ -284,9 +377,7 @@ class AI:
tool_choice=tool_choice,
)
if response.tool_calls:
return response
return response.text
return response
async def _execute_generation(
self,
@ -298,7 +389,7 @@ class AI:
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
"""通用的生成执行方法,封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
@ -311,7 +402,9 @@ class AI:
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages, tools=llm_tools, tool_choice=tool_choice
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
@ -380,7 +473,18 @@ class AI:
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
@ -420,7 +524,17 @@ async def chat(
model: ModelName = None,
**kwargs: Any,
) -> str:
"""聊天对话便捷函数"""
"""
聊天对话便捷函数
参数:
message: 用户输入的消息
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
ai = AI()
return await ai.chat(message, model=model, **kwargs)
@ -432,7 +546,18 @@ async def code(
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行便捷函数"""
"""
代码执行便捷函数
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典
"""
ai = AI()
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
@ -444,45 +569,56 @@ async def search(
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索便捷函数"""
"""
信息搜索便捷函数
参数:
query: 搜索查询内容
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
ai = AI()
return await ai.search(query, model=model, instruction=instruction, **kwargs)
async def analyze(
message: UniMessage,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""内容分析便捷函数"""
"""
内容分析便捷函数
参数:
message: 要分析的消息内容
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
ai = AI()
return await ai.analyze(
message,
instruction=instruction,
model=model,
tools=tools,
use_tools=use_tools,
tool_config=tool_config,
**kwargs,
)
async def analyze_with_images(
text: str,
images: list[str | Path | bytes] | str | Path | bytes,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""图片分析便捷函数"""
message = create_multimodal_message(text=text, images=images)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def analyze_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
@ -493,7 +629,21 @@ async def analyze_multimodal(
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""多模态分析便捷函数"""
"""
多模态分析便捷函数
参数:
text: 文本内容
images: 图片文件路径、字节数据或列表
videos: 视频文件路径、字节数据或列表
audios: 音频文件路径、字节数据或列表
instruction: 分析指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
@ -510,7 +660,21 @@ async def search_multimodal(
model: ModelName = None,
**kwargs: Any,
) -> dict[str, Any]:
"""多模态搜索便捷函数"""
"""
多模态搜索便捷函数
参数:
text: 文本内容
images: 图片文件路径、字节数据或列表
videos: 视频文件路径、字节数据或列表
audios: 音频文件路径、字节数据或列表
instruction: 搜索指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
@ -525,6 +689,101 @@ async def embed(
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""文本嵌入便捷函数"""
"""
文本嵌入便捷函数
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
ai = AI()
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
async def pipeline_chat(
message: UniMessage | str | list[LLMContentPart],
model_chain: list[ModelName],
*,
initial_instruction: str = "",
final_instruction: str = "",
**kwargs: Any,
) -> LLMResponse:
"""
AI模型链式调用,前一个模型的输出作为下一个模型的输入
参数:
message: 初始输入消息,支持多模态
model_chain: 模型名称列表
initial_instruction: 第一个模型的系统指令
final_instruction: 最后一个模型的系统指令
**kwargs: 传递给模型实例的其他参数
返回:
LLMResponse: 最后一个模型的响应结果
"""
if not model_chain:
raise ValueError("模型链`model_chain`不能为空。")
current_content: str | list[LLMContentPart]
if isinstance(message, str):
current_content = message
elif isinstance(message, list):
current_content = message
else:
current_content = await unimsg_to_llm_parts(message)
final_response: LLMResponse | None = None
for i, model_name in enumerate(model_chain):
if not model_name:
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
is_first_step = i == 0
is_last_step = i == len(model_chain) - 1
messages_for_step: list[LLMMessage] = []
instruction_for_step = ""
if is_first_step and initial_instruction:
instruction_for_step = initial_instruction
elif is_last_step and final_instruction:
instruction_for_step = final_instruction
if instruction_for_step:
messages_for_step.append(LLMMessage.system(instruction_for_step))
messages_for_step.append(LLMMessage.user(current_content))
logger.info(
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
f"使用模型 '{model_name}' 进行处理..."
)
try:
async with await get_model_instance(model_name, **kwargs) as model:
response = await model.generate_response(messages_for_step)
final_response = response
current_content = response.text.strip()
if not current_content and not is_last_step:
logger.warning(
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
)
break
except Exception as e:
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
raise LLMException(
f"流水线在模型 '{model_name}' 处执行失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
if final_response is None:
raise LLMException(
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
)
return final_response

View File

@ -14,6 +14,8 @@ from .generation import (
from .presets import CommonOverrides
from .providers import (
LLMConfig,
ToolConfig,
get_gemini_safety_threshold,
get_llm_config,
register_llm_configs,
set_default_model,
@ -25,8 +27,10 @@ __all__ = [
"LLMConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"ToolConfig",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_gemini_safety_threshold",
"get_llm_config",
"register_llm_configs",
"set_default_model",

View File

@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
params["temperature"] = self.temperature
if self.max_tokens is not None:
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
params["maxOutputTokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if self.top_k is not None:
params["topK"] = self.top_k
if self.top_p is not None:
@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = {"type": "json_object"}
logger.debug(f"{api_type} 启用 JSON 对象输出模式")
elif api_type in ["gemini", "gemini_native"]:
elif api_type == "gemini":
params["responseMimeType"] = "application/json"
if self.response_schema:
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if (
self.response_format != ResponseFormat.JSON
and self.response_mime_type is not None
@ -214,7 +214,7 @@ def apply_api_specific_mappings(
"""应用API特定的参数映射"""
mapped_params = params.copy()
if api_type in ["gemini", "gemini_native"]:
if api_type == "gemini":
if "max_tokens" in mapped_params:
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
if "top_k" in mapped_params:

View File

@ -71,14 +71,17 @@ class CommonOverrides:
@staticmethod
def gemini_safe() -> LLMGenerationConfig:
"""Gemini 安全模式:严格安全设置"""
"""Gemini 安全模式:使用配置的安全设置"""
from .providers import get_gemini_safety_threshold
threshold = get_gemini_safety_threshold()
return LLMGenerationConfig(
temperature=0.5,
safety_settings={
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HARASSMENT": threshold,
"HARM_CATEGORY_HATE_SPEECH": threshold,
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
},
)

View File

@ -4,15 +4,33 @@ LLM 提供商配置管理
负责注册和管理 AI 服务提供商的配置项
"""
from functools import lru_cache
import json
import sys
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ..types.models import ModelDetail, ProviderConfig
class ToolConfig(BaseModel):
"""MCP类型工具的配置定义"""
type: str = "mcp"
name: str = Field(..., description="工具的唯一名称标识")
description: str | None = Field(None, description="工具功能的描述")
mcp_config: dict[str, Any] | BaseModel = Field(
..., description="MCP服务器的特定配置"
)
AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
mcp_tools: list[ToolConfig] = Field(
default_factory=list, description="配置可用的外部MCP工具"
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
return [
{
"name": "DeepSeek",
"api_key": "sk-******",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://api.deepseek.com",
"api_type": "openai",
"models": [
@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
},
],
},
{
"name": "ARK",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://ark.cn-beijing.volces.com",
"api_type": "ark",
"models": [
{"model_name": "deepseek-r1-250528"},
{"model_name": "doubao-seed-1-6-250615"},
{"model_name": "doubao-seed-1-6-flash-250615"},
{"model_name": "doubao-seed-1-6-thinking-250615"},
],
},
{
"name": "siliconflow",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://api.siliconflow.cn",
"api_type": "openai",
"models": [
{"model_name": "deepseek-ai/DeepSeek-V3"},
],
},
{
"name": "GLM",
"api_key": "",
"api_key": "YOUR_ARK_API_KEY",
"api_base": "https://open.bigmodel.cn",
"api_type": "zhipu",
"models": [
@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
"api_type": "gemini",
"models": [
{"model_name": "gemini-2.0-flash"},
{"model_name": "gemini-2.5-flash-preview-05-20"},
{"model_name": "gemini-2.5-flash"},
{"model_name": "gemini-2.5-pro"},
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
],
},
]
def get_default_mcp_tools() -> dict[str, Any]:
"""
获取默认的MCP工具配置,用于在文件不存在时创建
包含了 baidu-map, Context7, sequential-thinking.
"""
return {
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
"description": "百度地图工具,提供地理编码、路线规划等功能。",
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
},
"Context7": {
"command": "npx",
"args": ["-y", "@upstash/context7-mcp@latest"],
"description": "Upstash 提供的上下文管理和记忆工具。",
},
}
}
def register_llm_configs():
"""注册 LLM 服务的配置项"""
logger.info("注册 LLM 服务的配置项")
@ -214,6 +285,19 @@ def register_llm_configs():
help="LLM服务请求重试的基础延迟时间",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"gemini_safety_threshold",
"BLOCK_MEDIUM_AND_ABOVE",
help=(
"Gemini 安全过滤阈值 "
"(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
"BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
"BLOCK_ONLY_HIGH: 只阻止高级别, "
"BLOCK_NONE: 不阻止)"
),
type=str,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
@ -225,24 +309,111 @@ def register_llm_configs():
)
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例
返回:
LLMConfig: LLM 配置实例
"""
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
ai_config = get_ai_config()
llm_data_path = DATA_PATH / "llm"
mcp_tools_path = llm_data_path / "mcp_tools.json"
mcp_tools_list = []
mcp_servers_dict = {}
if not mcp_tools_path.exists():
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
llm_data_path.mkdir(parents=True, exist_ok=True)
default_mcp_config = get_default_mcp_tools()
try:
with mcp_tools_path.open("w", encoding="utf-8") as f:
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
except Exception as e:
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
mcp_servers_dict = {}
else:
try:
with mcp_tools_path.open("r", encoding="utf-8") as f:
mcp_data = json.load(f)
mcp_servers_dict = mcp_data.get("mcpServers", {})
if not isinstance(mcp_servers_dict, dict):
logger.warning(
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
f"将使用空配置。"
)
mcp_servers_dict = {}
except json.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
except Exception as e:
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
mcp_servers_dict = {}
if sys.platform == "win32":
logger.debug("检测到Windows平台正在调整MCP工具的npx命令...")
for name, config in mcp_servers_dict.items():
if isinstance(config, dict) and config.get("command") == "npx":
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
original_args = config.get("args", [])
config["command"] = "cmd"
config["args"] = ["/c", "npx", *original_args]
if mcp_servers_dict:
mcp_tools_list = [
{
"name": name,
"type": "mcp",
"description": config.get("description", f"MCP tool for {name}"),
"mcp_config": config,
}
for name, config in mcp_servers_dict.items()
if isinstance(config, dict)
]
from ..tools.registry import tool_registry
for tool_dict in mcp_tools_list:
if isinstance(tool_dict, dict):
tool_name = tool_dict.get("name")
if not tool_name:
continue
config_model = tool_registry.get_mcp_config_model(tool_name)
if not config_model:
logger.debug(
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
f"将跳过特定配置验证,直接使用原始配置字典。"
)
continue
mcp_config_data = tool_dict.get("mcp_config", {})
try:
parsed_mcp_config = parse_as(config_model, mcp_config_data)
tool_dict["mcp_config"] = parsed_mcp_config
except Exception as e:
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
"timeout": ai_config.get("timeout", 180),
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
"mcp_tools": mcp_tools_list,
}
return LLMConfig(**config_data)
return parse_as(LLMConfig, config_data)
def get_gemini_safety_threshold() -> str:
"""获取 Gemini 安全过滤阈值配置
返回:
str: 安全过滤阈值
"""
ai_config = get_ai_config()
return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")
def validate_llm_config() -> tuple[bool, list[str]]:
@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
logger.info("默认模型已清除")
return True
@PriorityLifecycle.on_startup(priority=10)
async def _init_llm_config_on_startup():
"""
在服务启动时主动调用一次 get_llm_config
以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件
"""
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
try:
get_llm_config()
logger.info("LLM 配置初始化完成。")
except Exception as e:
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)

View File

@ -49,12 +49,36 @@ class LLMHttpClient:
max_keepalive_connections=self.config.max_keepalive_connections,
)
timeout = httpx.Timeout(self.config.timeout)
client_kwargs = {}
if self.config.proxy:
try:
version_parts = httpx.__version__.split(".")
major = int(
"".join(c for c in version_parts[0] if c.isdigit())
)
minor = (
int("".join(c for c in version_parts[1] if c.isdigit()))
if len(version_parts) > 1
else 0
)
if (major, minor) >= (0, 28):
client_kwargs["proxy"] = self.config.proxy
else:
client_kwargs["proxies"] = self.config.proxy
except (ValueError, IndexError):
client_kwargs["proxies"] = self.config.proxy
logger.warning(
f"无法解析 httpx 版本 '{httpx.__version__}'"
"LLM模块将默认使用旧版 'proxies' 参数语法。"
)
self._client = httpx.AsyncClient(
headers=headers,
limits=limits,
timeout=timeout,
proxies=self.config.proxy,
follow_redirects=True,
**client_kwargs,
)
if self._client is None:
raise LLMException(
@ -156,7 +180,16 @@ async def create_llm_http_client(
timeout: int = 180,
proxy: str | None = None,
) -> LLMHttpClient:
"""创建LLM HTTP客户端"""
"""
创建LLM HTTP客户端
参数:
timeout: 超时时间
proxy: 代理服务器地址
返回:
LLMHttpClient: HTTP客户端实例
"""
config = HttpClientConfig(timeout=timeout, proxy=proxy)
return LLMHttpClient(config)
@ -185,7 +218,20 @@ async def with_smart_retry(
provider_name: str | None = None,
**kwargs: Any,
) -> Any:
"""智能重试装饰器 - 支持Key轮询和错误分类"""
"""
智能重试装饰器 - 支持Key轮询和错误分类
参数:
func: 要重试的异步函数
*args: 传递给函数的位置参数
retry_config: 重试配置
key_store: API密钥状态存储
provider_name: 提供商名称
**kwargs: 传递给函数的关键字参数
返回:
Any: 函数执行结果
"""
config = retry_config or RetryConfig()
last_exception: Exception | None = None
failed_keys: set[str] = set()
@ -294,7 +340,17 @@ class KeyStatusStore:
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""获取下一个可用的API密钥轮询策略"""
"""
获取下一个可用的API密钥(轮询策略)
参数:
provider_name: 提供商名称
api_keys: API密钥列表
exclude_keys: 要排除的密钥集合
返回:
str | None: 可用的API密钥,如果没有可用密钥则返回None
"""
if not api_keys:
return None
@ -338,7 +394,13 @@ class KeyStatusStore:
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
async def record_failure(self, api_key: str, status_code: int | None):
"""记录失败使用"""
"""
记录失败使用
参数:
api_key: API密钥
status_code: HTTP状态码
"""
key_id = self._get_key_id(api_key)
async with self._lock:
if status_code in [401, 403]:
@ -356,7 +418,15 @@ class KeyStatusStore:
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""获取密钥使用统计"""
"""
获取密钥使用统计
参数:
api_keys: API密钥列表
返回:
dict[str, dict]: 密钥统计信息字典
"""
stats = {}
async with self._lock:
for key in api_keys:

View File

@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
from .types.capabilities import get_model_capabilities
DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
def get_configured_providers() -> list[ProviderConfig]:
"""从配置中获取Provider列表 - 简化版本"""
"""从配置中获取Provider列表 - 简化和修正版本"""
ai_config = get_ai_config()
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers_raw, list):
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers, list):
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
for i, item in enumerate(providers_raw):
if not isinstance(item, dict):
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
continue
try:
if not item.get("name"):
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
continue
if not item.get("api_key"):
logger.warning(
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
)
continue
if "api_type" not in item or not item["api_type"]:
provider_name = item.get("name", "").lower()
if "glm" in provider_name or "zhipu" in provider_name:
item["api_type"] = "zhipu"
elif "gemini" in provider_name or "google" in provider_name:
item["api_type"] = "gemini"
else:
item["api_type"] = "openai"
if "api_base" not in item or not item["api_base"]:
api_type = item.get("api_type")
if api_type:
default_api_base = get_default_api_base_for_type(api_type)
if default_api_base:
item["api_base"] = default_api_base
if "models" not in item:
item["models"] = [{"model_name": item.get("name", "default")}]
provider_conf = ProviderConfig(**item)
valid_providers.append(provider_conf)
except Exception as e:
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
for i, item in enumerate(providers):
if isinstance(item, ProviderConfig):
if not item.api_base:
default_api_base = get_default_api_base_for_type(item.api_type)
if default_api_base:
item.api_base = default_api_base
valid_providers.append(item)
else:
logger.warning(
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
f"已跳过。实际类型: {type(item)}"
)
return valid_providers
@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
def find_model_config(
provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""在配置中查找指定的 Provider 和 ModelDetail
"""
在配置中查找指定的 Provider 和 ModelDetail
Args:
参数:
provider_name: 提供商名称
model_name: 模型名称
Returns:
找到的 (ProviderConfig, ModelDetail) 元组未找到则返回 None
返回:
tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None
"""
providers = get_configured_providers()
@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis
def list_model_identifiers() -> dict[str, list[str]]:
"""列出所有模型的可用标识符
"""
列出所有模型的可用标识符
Returns:
字典键为模型的完整名称值为该模型的所有可用标识符列表
返回:
dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表
"""
providers = get_configured_providers()
result = {}
@ -248,7 +224,16 @@ async def get_model_instance(
provider_model_name: str | None = None,
override_config: dict[str, Any] | None = None,
) -> LLMModel:
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
"""
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
参数:
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'
override_config: 覆盖配置字典
返回:
LLMModel: 模型实例
"""
cache_key = _make_cache_key(provider_model_name, override_config)
cached_model = _get_cached_model(cache_key)
if cached_model:
@ -292,6 +277,10 @@ async def get_model_instance(
provider_config_found, model_detail_found = config_tuple_found
capabilities = get_model_capabilities(model_detail_found.model_name)
model_detail_found.is_embedding_model = capabilities.is_embedding_model
ai_config = get_ai_config()
global_proxy_setting = ai_config.get(PROXY_KEY)
default_timeout = (
@ -322,6 +311,7 @@ async def get_model_instance(
model_detail=model_detail_found,
key_store=key_store,
http_client=shared_http_client,
capabilities=capabilities,
)
if override_config:
@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:
def set_global_default_model_name(provider_model_name: str | None) -> bool:
"""设置全局默认模型名称"""
"""
设置全局默认模型名称
参数:
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'
返回:
bool: 设置是否成功
"""
if provider_model_name:
prov_name, mod_name = parse_provider_model_string(provider_model_name)
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
async def get_key_usage_stats() -> dict[str, Any]:
"""获取所有Provider的Key使用统计"""
"""
获取所有Provider的Key使用统计
返回:
dict[str, Any]: 包含所有Provider的Key使用统计信息
"""
providers = get_configured_providers()
stats = {}
@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
"""重置指定Provider的Key状态"""
"""
重置指定Provider的Key状态
参数:
provider_name: 提供商名称
api_key: 要重置的特定API密钥,如果为None则重置所有密钥
返回:
bool: 重置是否成功
"""
providers = get_configured_providers()
target_provider = None

View File

@ -6,11 +6,13 @@ LLM 模型实现类
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
import json
from typing import Any
from zhenxun.services.log import logger
from .adapters.base import RequestData
from .config import LLMGenerationConfig
from .config.providers import get_ai_config
from .core import (
@ -30,6 +32,8 @@ from .types import (
ModelDetail,
ProviderConfig,
)
from .types.capabilities import ModelCapabilities, ModelModality
from .utils import _sanitize_request_body_for_logging
class LLMModelBase(ABC):
@ -42,7 +46,17 @@ class LLMModelBase(ABC):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
"""
生成文本
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
pass
@abstractmethod
@ -54,7 +68,19 @@ class LLMModelBase(ABC):
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应"""
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
pass
@abstractmethod
@ -64,7 +90,17 @@ class LLMModelBase(ABC):
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
pass
@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
model_detail: ModelDetail,
key_store: KeyStatusStore,
http_client: LLMHttpClient,
capabilities: ModelCapabilities,
config_override: LLMGenerationConfig | None = None,
):
self.provider_config = provider_config
self.model_detail = model_detail
self.key_store = key_store
self.http_client: LLMHttpClient = http_client
self.capabilities = capabilities
self._generation_config = config_override
self.provider_name = provider_config.name
@ -99,6 +137,34 @@ class LLMModel(LLMModelBase):
self._is_closed = False
def can_process_images(self) -> bool:
"""检查模型是否支持图片作为输入。"""
return ModelModality.IMAGE in self.capabilities.input_modalities
def can_process_video(self) -> bool:
"""检查模型是否支持视频作为输入。"""
return ModelModality.VIDEO in self.capabilities.input_modalities
def can_process_audio(self) -> bool:
"""检查模型是否支持音频作为输入。"""
return ModelModality.AUDIO in self.capabilities.input_modalities
def can_generate_images(self) -> bool:
"""检查模型是否支持生成图片。"""
return ModelModality.IMAGE in self.capabilities.output_modalities
def can_generate_audio(self) -> bool:
"""检查模型是否支持生成音频 (TTS)。"""
return ModelModality.AUDIO in self.capabilities.output_modalities
def can_use_tools(self) -> bool:
"""检查模型是否支持工具调用/函数调用。"""
return self.capabilities.supports_tool_calling
def is_embedding_model(self) -> bool:
"""检查这是否是一个嵌入模型。"""
return self.capabilities.is_embedding_model
async def _get_http_client(self) -> LLMHttpClient:
"""获取HTTP客户端"""
if self.http_client.is_closed:
@ -135,24 +201,54 @@ class LLMModel(LLMModelBase):
return selected_key
async def _execute_embedding_request(
async def _perform_api_call(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
prepare_request_func: Callable[[str], Awaitable["RequestData"]],
parse_response_func: Callable[[dict[str, Any]], Any],
http_client: "LLMHttpClient",
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
log_context: str = "API",
) -> Any:
"""
执行API调用的通用核心方法
该方法封装了以下通用逻辑:
1. 选择API密钥
2. 准备和记录请求
3. 发送HTTP POST请求
4. 处理HTTP错误和API特定错误
5. 记录密钥使用状态
6. 解析成功的响应
参数:
prepare_request_func: 准备请求的函数
parse_response_func: 解析响应的函数
http_client: HTTP客户端
failed_keys: 失败的密钥集合
log_context: 日志上下文
返回:
Any: 解析后的响应数据
"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
request_data = await prepare_request_func(api_key)
logger.info(
f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
f"[{log_context}]"
)
logger.debug(f"📡 请求URL: {request_data.url}")
masked_key = (
f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
)
logger.debug(f"🔑 API密钥: {masked_key}")
logger.debug(f"📋 请求头: {dict(request_data.headers)}")
sanitized_body = _sanitize_request_body_for_logging(request_data.body)
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
logger.debug(f"📦 请求体: {request_body_str}")
http_response = await http_client.post(
request_data.url,
@ -160,121 +256,16 @@ class LLMModel(LLMModelBase):
json=request_data.body,
)
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
error_code = LLMErrorCode.API_REQUEST_FAILED
if http_response.status_code in [401, 403]:
error_code = LLMErrorCode.API_KEY_INVALID
elif http_response.status_code == 429:
error_code = LLMErrorCode.API_RATE_LIMITED
raise LLMException(
f"HTTP嵌入请求失败: {http_response.status_code}",
code=error_code,
details={
"status_code": http_response.status_code,
"response": error_text,
"api_key": api_key,
},
)
try:
response_json = http_response.json()
adapter.validate_embedding_response(response_json)
embeddings = adapter.parse_embedding_response(response_json)
except Exception as e:
logger.error(f"解析嵌入响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API嵌入响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return embeddings
except LLMException:
raise
except Exception as e:
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成嵌入失败: {e}",
code=LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools_dict,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools_dict,
tool_choice=tool_choice,
)
http_response = await http_client.post(
request_data.url,
headers=request_data.headers,
json=request_data.body,
)
logger.debug(f"📥 响应状态码: {http_response.status_code}")
logger.debug(f"📄 响应头: {dict(http_response.headers)}")
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP请求失败: {http_response.status_code} - {error_text}"
f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
f"[{log_context}]"
)
logger.debug(f"💥 完整错误响应: {error_text}")
await self.key_store.record_failure(api_key, http_response.status_code)
@ -299,69 +290,165 @@ class LLMModel(LLMModelBase):
try:
response_json = http_response.json()
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
llm_response = LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
response_json_str = json.dumps(
response_json, ensure_ascii=False, indent=2
)
logger.debug(f"📋 响应JSON: {response_json_str}")
parsed_data = parse_response_func(response_json)
except Exception as e:
logger.error(f"解析响应失败: {e}", e=e)
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API响应失败: {e}",
f"解析API {log_context} 响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return llm_response
logger.debug(f"✅ API密钥使用成功: {masked_key}")
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
return parsed_data
except LLMException:
raise
except Exception as e:
logger.error(f"生成响应时发生未预期错误: {e}", e=e)
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
logger.error(error_log_msg, e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成响应失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
error_log_msg,
code=LLMErrorCode.GENERATION_FAILED
if log_context == "Generation"
else LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_embedding_request(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
async def prepare_request(api_key: str) -> RequestData:
return adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
)
def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
adapter.validate_embedding_response(response_json)
return adapter.parse_embedding_response(response_json)
return await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Embedding",
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
async def prepare_request(api_key: str) -> RequestData:
return await adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools,
tool_choice=tool_choice,
)
def parse_response(response_json: dict[str, Any]) -> LLMResponse:
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
return LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
)
return await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Generation",
)
async def close(self):
"""
标记模型实例的当前使用周期结束
@ -400,7 +487,17 @@ class LLMModel(LLMModelBase):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本 - 通过 generate_response 实现"""
"""
生成文本 - 通过 generate_response 实现
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
self._check_not_closed()
messages: list[LLMMessage] = []
@ -439,11 +536,21 @@ class LLMModel(LLMModelBase):
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
max_tool_iterations: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应 - 实现完整的工具调用循环"""
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
@ -468,109 +575,43 @@ class LLMModel(LLMModelBase):
merged_dict.update(config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
tools_dict: list[dict[str, Any]] | None = None
if tools:
tools_dict = []
for tool in tools:
if hasattr(tool, "model_dump"):
model_dump_func = getattr(tool, "model_dump")
tools_dict.append(model_dump_func(exclude_none=True))
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
try:
tools_dict.append(dict(tool))
except (TypeError, ValueError):
logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
http_client = await self._get_http_client()
current_messages = list(messages)
for iteration in range(max_tool_iterations):
logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
async with AsyncExitStack() as stack:
activated_tools = []
if tools:
for tool in tools:
if tool.type == "mcp" and callable(tool.mcp_session):
func_obj = getattr(tool.mcp_session, "func", None)
tool_name = (
getattr(func_obj, "__name__", "unknown")
if func_obj
else "unknown"
)
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
active_session = await stack.enter_async_context(
tool.mcp_session()
)
activated_tools.append(
LLMTool.from_mcp_session(
session=active_session, annotations=tool.annotations
)
)
else:
activated_tools.append(tool)
llm_response = await self._execute_with_smart_retry(
adapter,
current_messages,
messages,
final_request_config,
tools_dict if iteration == 0 else None,
tool_choice if iteration == 0 else None,
activated_tools if activated_tools else None,
tool_choice,
http_client,
)
response_tool_calls = llm_response.tool_calls or []
if not response_tool_calls or not tool_executor:
logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
return llm_response
logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
assistant_message_content = llm_response.text if llm_response.text else ""
current_messages.append(
LLMMessage.assistant_tool_calls(
content=assistant_message_content, tool_calls=response_tool_calls
)
)
tool_response_messages: list[LLMMessage] = []
for tool_call in response_tool_calls:
tool_name = tool_call.function.name
try:
tool_args_dict = json.loads(tool_call.function.arguments)
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
tool_result = await tool_executor(tool_name, tool_args_dict)
logger.debug(
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result=tool_result,
)
)
except json.JSONDecodeError as e:
logger.error(
f"工具 '{tool_name}' 参数JSON解析失败: "
f"{tool_call.function.arguments}, 错误: {e}"
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Argument JSON parsing failed",
"details": str(e),
},
)
)
except Exception as e:
logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Tool execution failed",
"details": str(e),
},
)
)
current_messages.extend(tool_response_messages)
logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
raise LLMException(
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
code=LLMErrorCode.GENERATION_FAILED,
details={
"iterations": max_tool_iterations,
"last_messages": current_messages[-2:],
},
)
return llm_response
async def generate_embeddings(
self,
@ -578,7 +619,17 @@ class LLMModel(LLMModelBase):
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
self._check_not_closed()
if not texts:
return []
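A minimal usage sketch for the embedding entry point above; the import path is assumed from the module layout, and the caller is expected to pass an LLMModel instance obtained from the manager.

# Hypothetical caller of LLMModel.generate_embeddings (illustrative only).
from zhenxun.services.llm.types import EmbeddingTaskType


async def embed_documents(model, texts: list[str]) -> list[list[float]]:
    """model: an LLMModel instance obtained from the manager (assumption)."""
    return await model.generate_embeddings(
        texts, task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT
    )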

View File

@ -0,0 +1,7 @@
"""
工具模块导出
"""
from .registry import tool_registry
__all__ = ["tool_registry"]

View File

@ -0,0 +1,181 @@
"""
工具注册表
负责加载、管理和实例化来自配置的工具
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import TYPE_CHECKING
from pydantic import BaseModel
from zhenxun.services.log import logger
from ..types import LLMTool
if TYPE_CHECKING:
from ..config.providers import ToolConfig
from ..types.protocols import MCPCompatible
class ToolRegistry:
"""工具注册表,用于管理和实例化配置的工具。"""
def __init__(self):
self._function_tools: dict[str, LLMTool] = {}
self._mcp_config_models: dict[str, type[BaseModel]] = {}
if TYPE_CHECKING:
self._mcp_factories: dict[
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
] = {}
else:
self._mcp_factories: dict[str, Callable] = {}
self._tool_configs: dict[str, "ToolConfig"] | None = None
self._tool_cache: dict[str, "LLMTool"] = {}
def _load_configs_if_needed(self):
"""如果尚未加载则从主配置中加载MCP工具定义。"""
if self._tool_configs is None:
logger.debug("首次访问正在加载MCP工具配置...")
from ..config.providers import get_llm_config
llm_config = get_llm_config()
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
def function_tool(
self,
name: str,
description: str,
parameters: dict,
required: list[str] | None = None,
):
"""
装饰器:在代码中注册一个简单的、无状态的函数工具
参数:
name: 工具的唯一名称
description: 工具功能的描述
parameters: OpenAPI格式的函数参数schema的properties部分
required: 必需的参数列表
"""
def decorator(func: Callable):
if name in self._function_tools or name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的工具: {name}")
tool_definition = LLMTool.create(
name=name,
description=description,
parameters=parameters,
required=required,
)
self._function_tools[name] = tool_definition
logger.info(f"已在代码中注册函数工具: '{name}'")
tool_definition.annotations = tool_definition.annotations or {}
tool_definition.annotations["executable"] = func
return func
return decorator
def mcp_tool(self, name: str, config_model: type[BaseModel]):
"""
装饰器:注册一个MCP工具及其配置模型
参数:
name: 工具的唯一名称,必须与配置文件中的名称匹配
config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`
"""
def decorator(factory_func: Callable):
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory_func
self._mcp_config_models[name] = config_model
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
return factory_func
return decorator
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
"""根据名称获取MCP工具的配置模型。"""
return self._mcp_config_models.get(name)
def register_mcp_factory(
self,
name: str,
factory: Callable,
):
"""
在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联
参数:
name: 工具的唯一名称,必须与配置文件中的名称匹配
factory: 一个返回异步生成器的可调用对象(会话工厂)
"""
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory
logger.info(f"已注册 MCP 会话工厂: '{name}'")
def get_tool(self, name: str) -> "LLMTool":
"""
根据名称获取一个 LLMTool 定义。
对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
而不是一个已激活的会话。
"""
logger.debug(f"🔍 请求获取工具定义: {name}")
if name in self._tool_cache:
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
return self._tool_cache[name]
if name in self._function_tools:
logger.debug(f"🛠️ 获取函数工具定义: {name}")
tool = self._function_tools[name]
self._tool_cache[name] = tool
return tool
self._load_configs_if_needed()
if self._tool_configs is None or name not in self._tool_configs:
known_tools = list(self._function_tools.keys()) + (
list(self._tool_configs.keys()) if self._tool_configs else []
)
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
logger.debug(f"📋 可用工具定义列表: {known_tools}")
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
config = self._tool_configs[name]
tool: "LLMTool"
if name not in self._mcp_factories:
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
available_factories = list(self._mcp_factories.keys())
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
raise ValueError(
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
)
logger.info(f"🔧 创建MCP工具定义: {name}")
factory = self._mcp_factories[name]
typed_mcp_config = config.mcp_config
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
configured_factory = partial(factory, config=typed_mcp_config)
tool = LLMTool.from_mcp_session(session=configured_factory)
self._tool_cache[name] = tool
logger.debug(f"💾 MCP工具定义已缓存: {name}")
return tool
def get_tools(self, names: list[str]) -> list["LLMTool"]:
"""根据名称列表获取多个 LLMTool 实例。"""
return [self.get_tool(name) for name in names]
tool_registry = ToolRegistry()
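As a usage sketch for the registry above: registering a small function tool in code and fetching it later. The tool name, parameter schema, and import path are illustrative assumptions, not shipped defaults.

# Illustrative registration via the decorator defined above; the weather tool
# and the import path are assumptions, not part of this change.
from zhenxun.services.llm.tools import tool_registry


@tool_registry.function_tool(
    name="get_weather",
    description="查询指定城市的当前天气",
    parameters={"city": {"type": "string", "description": "城市名称"}},
    required=["city"],
)
async def get_weather(city: str) -> dict:
    # A real tool would call a weather API here.
    return {"city": city, "weather": "sunny"}


tools = tool_registry.get_tools(["get_weather"])  # list[LLMTool]; later calls hit the cache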

View File

@ -4,6 +4,7 @@ LLM 类型定义模块
统一导出所有核心类型、协议和异常定义
"""
from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
from .content import (
LLMContentPart,
LLMMessage,
@ -26,6 +27,7 @@ from .models import (
ToolMetadata,
UsageInfo,
)
from .protocols import MCPCompatible
__all__ = [
"EmbeddingTaskType",
@ -41,8 +43,11 @@ __all__ = [
"LLMTool",
"LLMToolCall",
"LLMToolFunction",
"MCPCompatible",
"ModelCapabilities",
"ModelDetail",
"ModelInfo",
"ModelModality",
"ModelName",
"ModelProvider",
"ProviderConfig",
@ -50,5 +55,6 @@ __all__ = [
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"get_model_capabilities",
"get_user_friendly_error_message",
]

View File

@ -0,0 +1,128 @@
"""
LLM 模型能力定义模块
定义模型的输入/输出模态、工具调用支持等核心能力
"""
from enum import Enum
import fnmatch
from pydantic import BaseModel, Field
class ModelModality(str, Enum):
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
EMBEDDING = "embedding"
class ModelCapabilities(BaseModel):
"""定义一个模型的核心、稳定能力。"""
input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
supports_tool_calling: bool = False
is_embedding_model: bool = False
STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
GEMINI_CAPABILITIES = ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
)
MODEL_ALIAS_MAPPING: dict[str, str] = {
"deepseek-v3*": "deepseek-chat",
"deepseek-ai/DeepSeek-V3": "deepseek-chat",
"deepseek-r1*": "deepseek-reasoner",
}
MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
"gemini-*-tts": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.AUDIO},
),
"gemini-*-native-audio-*": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
supports_tool_calling=True,
),
"gemini-2.0-flash-preview-image-generation": ModelCapabilities(
input_modalities={
ModelModality.TEXT,
ModelModality.IMAGE,
ModelModality.AUDIO,
ModelModality.VIDEO,
},
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
supports_tool_calling=True,
),
"gemini-embedding-exp": ModelCapabilities(
input_modalities={ModelModality.TEXT},
output_modalities={ModelModality.EMBEDDING},
is_embedding_model=True,
),
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
"gemini-2.0-flash*": GEMINI_CAPABILITIES,
"gemini-1.5-flash*": GEMINI_CAPABILITIES,
"GLM-4V-Flash": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"GLM-4V-Plus*": ModelCapabilities(
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
output_modalities={ModelModality.TEXT},
supports_tool_calling=True,
),
"glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
"doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
}
def get_model_capabilities(model_name: str) -> ModelCapabilities:
"""
从注册表获取模型能力,支持别名映射和通配符匹配
查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
"""
canonical_name = model_name
for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
if fnmatch.fnmatch(model_name, alias_pattern):
canonical_name = c_name
break
if canonical_name in MODEL_CAPABILITIES_REGISTRY:
return MODEL_CAPABILITIES_REGISTRY[canonical_name]
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
return capabilities
return ModelCapabilities()
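A short walkthrough of the lookup order above, using only entries defined in this registry; the package import path is assumed from the module layout.

# Lookup walkthrough for get_model_capabilities, based on the registry above.
from zhenxun.services.llm.types import ModelModality, get_model_capabilities

caps = get_model_capabilities("gemini-2.5-flash-preview")  # wildcard "gemini-2.5-flash*"
assert caps.supports_tool_calling and ModelModality.IMAGE in caps.input_modalities

caps = get_model_capabilities("deepseek-v3-0324")  # alias "deepseek-v3*" -> "deepseek-chat"
assert caps.supports_tool_calling and not caps.is_embedding_model

caps = get_model_capabilities("some-unknown-model")  # no match -> text-only defaults
assert caps.input_modalities == {ModelModality.TEXT}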

View File

@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
return None
def convert_for_api(self, api_type: str) -> dict[str, Any]:
async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
"""根据API类型转换多模态内容格式"""
from zhenxun.utils.http_utils import AsyncHttpx
if self.type == "text":
if api_type == "openai":
return {"type": "text", "text": self.text}
@ -248,20 +250,23 @@ class LLMContentPart(BaseModel):
mime_type, data = base64_info
return {"inlineData": {"mimeType": mime_type, "data": data}}
else:
# 如果无法解析 Base64 数据,抛出异常
raise ValueError(
f"无法解析Base64图像数据: {self.image_source[:50]}..."
)
else:
logger.warning(
f"Gemini API需要Base64格式但提供的是URL: {self.image_source}"
)
return {
"inlineData": {
"mimeType": "image/jpeg",
"data": self.image_source,
elif self.is_image_url():
logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
try:
image_bytes = await AsyncHttpx.get_content(self.image_source)
mime_type = self.mime_type or "image/jpeg"
base64_data = base64.b64encode(image_bytes).decode("utf-8")
return {
"inlineData": {"mimeType": mime_type, "data": base64_data}
}
}
except Exception as e:
logger.error(f"下载或编码URL图片失败: {e}", e=e)
raise ValueError(f"无法处理图片URL: {e}")
else:
raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
else:
return {"type": "image_url", "image_url": {"url": self.image_source}}

View File

@ -4,13 +4,25 @@ LLM 数据模型定义
包含模型信息、配置、工具定义和响应数据的模型类
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
if TYPE_CHECKING:
from .protocols import MCPCompatible
MCPSessionType = (
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
)
else:
MCPCompatible = object
MCPSessionType = Any
ModelName = str | None
@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
class LLMTool(BaseModel):
"""LLM 工具定义(支持 MCP 风格)"""
model_config = {"arbitrary_types_allowed": True}
type: str = "function"
function: dict[str, Any]
function: dict[str, Any] | None = None
mcp_session: MCPSessionType = None
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
def model_post_init(self, /, __context: Any) -> None:
"""验证工具定义的有效性"""
_ = __context
if self.type == "function" and self.function is None:
raise ValueError("函数类型的工具必须包含 'function' 字段。")
if self.type == "mcp" and self.mcp_session is None:
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
@classmethod
def create(
cls,
@ -111,7 +134,7 @@ class LLMTool(BaseModel):
required: list[str] | None = None,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""创建工具"""
"""创建函数工具"""
function_def = {
"name": name,
"description": description,
@ -123,6 +146,15 @@ class LLMTool(BaseModel):
}
return cls(type="function", function=function_def, annotations=annotations)
@classmethod
def from_mcp_session(
cls,
session: Any,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""从 MCP 会话创建工具"""
return cls(type="mcp", mcp_session=session, annotations=annotations)
class LLMCodeExecution(BaseModel):
"""代码执行结果"""

View File

@ -0,0 +1,24 @@
"""
LLM 模块的协议定义
"""
from typing import Any, Protocol
class MCPCompatible(Protocol):
"""
一个协议,定义了与LLM模块兼容的MCP会话对象应具备的行为。
任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible。
"""
def to_api_tool(self, api_type: str) -> dict[str, Any]:
"""
将此MCP会话转换为特定LLM提供商API所需的工具格式
参数:
api_type: 目标API的类型 (例如 'gemini', 'openai')
返回:
dict[str, Any]: 一个字典,代表可以在API请求中使用的工具定义
"""
...
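A minimal sketch of an object satisfying this protocol; the returned tool shapes are only illustrative, not a real MCP client.

# Minimal illustrative implementation of MCPCompatible (not a real MCP client).
from typing import Any


class DummyMCPSession:
    def to_api_tool(self, api_type: str) -> dict[str, Any]:
        if api_type == "gemini":
            return {"functionDeclarations": [{"name": "dummy", "description": "示例工具"}]}
        return {
            "type": "function",
            "function": {"name": "dummy", "description": "示例工具", "parameters": {}},
        }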

View File

@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
"""
import base64
import copy
from pathlib import Path
from nonebot.adapters import Message as PlatformMessage
from nonebot_plugin_alconna.uniseg import (
At,
File,
@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
)
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from .types import LLMContentPart
@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
"""
UniMessage 实例转换为一个 LLMContentPart 列表
这是处理多模态输入的核心转换逻辑
参数:
message: 要转换的UniMessage实例
返回:
list[LLMContentPart]: 转换后的内容部分列表
"""
parts: list[LLMContentPart] = []
for seg in message:
@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
if seg.path:
part = await LLMContentPart.from_path(seg.path)
elif seg.url:
logger.warning(
f"直接使用 URL 的 {type(seg).__name__} 段,"
f"API 可能不支持: {seg.url}"
)
part = LLMContentPart.text_part(
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
)
elif hasattr(seg, "raw") and seg.raw:
try:
logger.debug(f"检测到媒体URL开始下载: {seg.url}")
media_bytes = await AsyncHttpx.get_content(seg.url)
new_seg = copy.copy(seg)
new_seg.raw = media_bytes
seg = new_seg
logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
except Exception as e:
logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
part = LLMContentPart.text_part(
f"[下载媒体失败: {seg.name or seg.url}]"
)
if part:
parts.append(part)
continue
if hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", None)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
@ -127,50 +147,19 @@ def create_multimodal_message(
audio_mimetypes: list[str] | str | None = None,
) -> UniMessage:
"""
创建多模态消息的便捷函数,方便第三方调用
创建多模态消息的便捷函数
Args:
参数:
text: 文本内容
images: 图片数据,支持路径、字节数据或URL
videos: 视频数据,支持路径、字节数据或URL
audios: 音频数据,支持路径、字节数据或URL
image_mimetypes: 图片MIME类型,当images为bytes时需要指定
video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
videos: 视频数据
audios: 音频数据
image_mimetypes: 图片MIME类型,bytes数据时需要指定
video_mimetypes: 视频MIME类型,bytes数据时需要指定
audio_mimetypes: 音频MIME类型,bytes数据时需要指定
Returns:
返回:
UniMessage: 构建好的多模态消息
Examples:
# 纯文本
msg = create_multimodal_message("请分析这段文字")
# 文本 + 单张图片(路径)
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
# 文本 + 多张图片
msg = create_multimodal_message(
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
)
# 文本 + 图片字节数据
msg = create_multimodal_message(
"分析", images=image_data, image_mimetypes="image/jpeg"
)
# 文本 + 视频
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
# 文本 + 音频
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
# 混合多模态
msg = create_multimodal_message(
"分析这些媒体文件",
images="/path/to/image.jpg",
videos="/path/to/video.mp4",
audios="/path/to/audio.wav"
)
"""
message = UniMessage()
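Since the long example block was trimmed from the docstring above, a couple of hedged usage lines for create_multimodal_message; paths and byte data are placeholders.

# Usage sketch for create_multimodal_message (defined above); paths are placeholders.
from pathlib import Path

msg1 = create_multimodal_message("请比较这两张图片", images=["/tmp/a.jpg", "/tmp/b.jpg"])

audio_bytes = Path("/tmp/voice.wav").read_bytes()
msg2 = create_multimodal_message("转录这段音频", audios=audio_bytes, audio_mimetypes="audio/wav")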
@ -196,7 +185,7 @@ def _add_media_to_message(
media_class: type,
default_mimetype: str,
) -> None:
"""添加媒体文件到 UniMessage 的辅助函数"""
"""添加媒体文件到 UniMessage"""
if not isinstance(media_items, list):
media_items = [media_items]
@ -216,3 +205,80 @@ def _add_media_to_message(
elif isinstance(item, bytes):
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
message.append(media_class(raw=item, mimetype=mimetype))
def message_to_unimessage(message: PlatformMessage) -> UniMessage:
"""
将平台特定的 Message 对象转换为通用的 UniMessage
主要用于处理引用消息等未被自动转换的消息体
参数:
message: 平台特定的Message对象
返回:
UniMessage: 转换后的通用消息对象
"""
uni_segments = []
for seg in message:
if seg.type == "text":
uni_segments.append(Text(seg.data.get("text", "")))
elif seg.type == "image":
uni_segments.append(Image(url=seg.data.get("url")))
elif seg.type == "record":
uni_segments.append(Voice(url=seg.data.get("url")))
elif seg.type == "video":
uni_segments.append(Video(url=seg.data.get("url")))
elif seg.type == "at":
uni_segments.append(At("user", str(seg.data.get("qq", ""))))
else:
logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
return UniMessage(uni_segments)
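A hedged example of the conversion above, assuming the OneBot v11 adapter; the `url` field is shown populated as it would be on a received message.

# Illustrative only: converting a received OneBot v11 message (adapter assumed).
from nonebot.adapters.onebot.v11 import Message, MessageSegment

platform_msg = Message([
    MessageSegment("text", {"text": "看看这个"}),
    MessageSegment("image", {"file": "p.png", "url": "https://example.com/p.png"}),
])
uni = message_to_unimessage(platform_msg)
# uni ≈ UniMessage([Text("看看这个"), Image(url="https://example.com/p.png")])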
def _sanitize_request_body_for_logging(body: dict) -> dict:
"""
净化请求体用于日志记录:移除大数据字段并添加摘要信息
参数:
body: 原始请求体字典
返回:
dict: 净化后的请求体字典
"""
try:
sanitized_body = copy.deepcopy(body)
if "contents" in sanitized_body and isinstance(
sanitized_body["contents"], list
):
for content_item in sanitized_body["contents"]:
if "parts" in content_item and isinstance(content_item["parts"], list):
media_summary = []
new_parts = []
for part in content_item["parts"]:
if "inlineData" in part and isinstance(
part["inlineData"], dict
):
data = part["inlineData"].get("data")
if isinstance(data, str):
mime_type = part["inlineData"].get(
"mimeType", "unknown"
)
media_summary.append(f"{mime_type} ({len(data)} chars)")
continue
new_parts.append(part)
if media_summary:
summary_text = (
f"[多模态内容: {len(media_summary)}个文件 - "
f"{', '.join(media_summary)}]"
)
new_parts.insert(0, {"text": summary_text})
content_item["parts"] = new_parts
return sanitized_body
except Exception as e:
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
return body
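A quick before/after for the sanitizer above; the payload is made up, and the resulting parts follow directly from the logic shown.

# Example input/output for _sanitize_request_body_for_logging.
body = {
    "contents": [
        {
            "parts": [
                {"text": "描述这张图片"},
                {"inlineData": {"mimeType": "image/png", "data": "A" * 120000}},
            ]
        }
    ]
}
safe = _sanitize_request_body_for_logging(body)
# safe["contents"][0]["parts"] ==
#   [{"text": "[多模态内容: 1个文件 - image/png (120000 chars)]"},
#    {"text": "描述这张图片"}]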