mirror of https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
✨ feat(llm): Comprehensive refactor of the LLM service module with enhanced multimodal and tool support (#1953)
* ✨ feat(llm): Comprehensive refactor of the LLM service module with enhanced multimodal and tool support

🚀 Core feature enhancements
- Multi-model chaining: new `pipeline_chat` for complex multi-step task flows
- Expanded provider support: new ARK (Volcengine Ark) and SiliconFlow adapters
- Stronger multimodal handling: media files can be downloaded and converted from URLs for more flexible input
- Conversation history: `AI.analyze` accepts history context and an optional UniMessage argument
- Text embeddings: new `embed`, `analyze_multimodal`, `search_multimodal`, and related APIs
- Model capability system: new `ModelCapabilities` to centrally manage model features (multimodality, tool calling, etc.)

🔧 Architecture refactoring and optimization
- MCP tool system rework: configuration moved to a standalone `data/llm/mcp_tools.json`, with common tools preset
- Unified API call logic: shared `_perform_api_call` method extracted to eliminate duplication
- Cross-platform compatibility: npx commands for MCP tools are automatically wrapped on Windows
- HTTP client improvements: proxy configuration handled across httpx versions (0.28+ adaptation)

🛠️ API and configuration polish
- Unified return type: `AI.analyze` now always returns `LLMResponse`
- Message conversion utility: new `message_to_unimessage` helper
- Gemini adapter improvements: URL image download/encoding and a configurable safety threshold
- Cache management: model instance caching and management added
- Config presets: extended CommonOverrides preset options
- History management: multimodal content can be replaced with placeholders for efficiency

📚 Documentation and developer experience
- README fully rewritten: complete usage guide, API reference, and architecture overview
- Expanded docs: embedding models, cache management, tool registration, and more
- Richer logging: detailed debug output
- API cleanup: redundant functions removed, interface streamlined

* 🎨 feat(llm): Unify docstring format across LLM service functions

* ✨ feat(llm): Add new models and simplify provider configuration loading

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
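For orientation, a minimal sketch of how the reworked high-level entry points are meant to be called after this change (the import path is an assumption; `Gemini/gemini-2.0-flash` appears in the diff as a default fallback):

```python
from zhenxun.services.llm import chat, search  # import path assumed

async def quick_tour() -> None:
    # One-shot chat through the module-level convenience wrapper.
    answer = await chat("用一句话介绍你自己", model="Gemini/gemini-2.0-flash")
    # Grounded web search; the result dict carries text/sources/queries/success.
    result = await search("zhenxun_bot 是什么?")
    print(answer, result["text"])
```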
This commit is contained in:
parent 1e7ae38684
commit 48cbb2bf1d
@@ -10,10 +10,10 @@ from .api import (
    TaskType,
    analyze,
    analyze_multimodal,
    analyze_with_images,
    chat,
    code,
    embed,
    pipeline_chat,
    search,
    search_multimodal,
)
@@ -35,6 +35,7 @@ from .manager import (
    list_model_identifiers,
    set_global_default_model_name,
)
from .tools import tool_registry
from .types import (
    EmbeddingTaskType,
    LLMContentPart,
@@ -43,6 +44,7 @@ from .types import (
    LLMMessage,
    LLMResponse,
    LLMTool,
    MCPCompatible,
    ModelDetail,
    ModelInfo,
    ModelProvider,
@@ -51,7 +53,7 @@ from .types import (
    ToolMetadata,
    UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts

__all__ = [
    "AI",
@@ -65,6 +67,7 @@ __all__ = [
    "LLMMessage",
    "LLMResponse",
    "LLMTool",
    "MCPCompatible",
    "ModelDetail",
    "ModelInfo",
    "ModelName",
@@ -76,7 +79,6 @@ __all__ = [
    "UsageInfo",
    "analyze",
    "analyze_multimodal",
    "analyze_with_images",
    "chat",
    "clear_model_cache",
    "code",
@@ -88,9 +90,12 @@ __all__ = [
    "list_available_models",
    "list_embedding_models",
    "list_model_identifiers",
    "message_to_unimessage",
    "pipeline_chat",
    "register_llm_configs",
    "search",
    "search_multimodal",
    "set_global_default_model_name",
    "tool_registry",
    "unimsg_to_llm_parts",
]
@@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter

LLMAdapterFactory.initialize()

@@ -20,7 +19,6 @@ __all__ = [
    "OpenAICompatAdapter",
    "RequestData",
    "ResponseData",
    "ZhipuAdapter",
    "get_adapter_for_api_type",
    "register_adapter",
]
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
    from ..service import LLMModel
    from ..types.content import LLMMessage
    from ..types.enums import EmbeddingTaskType
    from ..types.models import LLMTool


class RequestData(BaseModel):
@@ -60,7 +61,7 @@ class BaseAdapter(ABC):
        """支持的API类型列表"""
        pass

    def prepare_simple_request(
    async def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
@@ -86,7 +87,7 @@ class BaseAdapter(ABC):

        config = model._generation_config

        return self.prepare_advanced_request(
        return await self.prepare_advanced_request(
            model=model,
            api_key=api_key,
            messages=messages,
@@ -96,13 +97,13 @@ class BaseAdapter(ABC):
        )

    @abstractmethod
    def prepare_advanced_request(
    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tools: list["LLMTool"] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """准备高级请求"""
@@ -238,6 +239,9 @@ class BaseAdapter(ABC):
        message = choice.get("message", {})
        content = message.get("content", "")

        if content:
            content = content.strip()

        parsed_tool_calls: list[LLMToolCall] | None = None
        if message_tool_calls := message.get("tool_calls"):
            from ..types.models import LLMToolFunction
@@ -375,7 +379,7 @@ class BaseAdapter(ABC):
        if model.temperature is not None:
            base_config["temperature"] = model.temperature
        if model.max_tokens is not None:
            if model.api_type in ["gemini", "gemini_native"]:
            if model.api_type == "gemini":
                base_config["maxOutputTokens"] = model.max_tokens
            else:
                base_config["max_tokens"] = model.max_tokens
@@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
    """

    @abstractmethod
    def get_chat_endpoint(self) -> str:
    def get_chat_endpoint(self, model: "LLMModel") -> str:
        """子类必须实现,返回 chat completions 的端点"""
        pass

    @abstractmethod
    def get_embedding_endpoint(self) -> str:
    def get_embedding_endpoint(self, model: "LLMModel") -> str:
        """子类必须实现,返回 embeddings 的端点"""
        pass

    def prepare_advanced_request(
    async def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """准备简单文本生成请求 - OpenAI兼容API的通用实现"""
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self.get_base_headers(api_key)

        messages = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})

        body = {
            "model": model.model_name,
            "messages": messages,
        }

        body = self.apply_config_override(model, body)

        return RequestData(url=url, headers=headers, body=body)

    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tools: list["LLMTool"] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """准备高级请求 - OpenAI兼容格式"""
        url = self.get_api_url(model, self.get_chat_endpoint())
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self.get_base_headers(api_key)
        openai_messages = self.convert_messages_to_openai_format(messages)

@@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
        }

        if tools:
            body["tools"] = tools
            openai_tools = []
            for tool in tools:
                if tool.type == "function" and tool.function:
                    openai_tools.append({"type": "function", "function": tool.function})
                elif tool.type == "mcp" and tool.mcp_session:
                    if callable(tool.mcp_session):
                        raise ValueError(
                            "适配器接收到未激活的 MCP 会话工厂。"
                            "会话工厂应该在 LLMModel.generate_response 中被激活。"
                        )
                    openai_tools.append(
                        tool.mcp_session.to_api_tool(api_type=self.api_type)
                    )
            if openai_tools:
                body["tools"] = openai_tools
        if tool_choice:
            body["tool_choice"] = tool_choice

@@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
        is_advanced: bool = False,
    ) -> ResponseData:
        """解析响应 - 直接使用基类的 OpenAI 格式解析"""
        _ = model, is_advanced  # 未使用的参数
        _ = model, is_advanced
        return self.parse_openai_response(response_json)

    def prepare_embedding_request(
@@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
        **kwargs: Any,
    ) -> RequestData:
        """准备嵌入请求 - OpenAI兼容格式"""
        _ = task_type  # 未使用的参数
        url = self.get_api_url(model, self.get_embedding_endpoint())
        _ = task_type
        url = self.get_api_url(model, self.get_embedding_endpoint(model))
        headers = self.get_base_headers(api_key)

        body = {
@@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
            "input": texts,
        }

        # 应用额外的配置参数
        if kwargs:
            body.update(kwargs)
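Since `prepare_advanced_request` now takes typed `LLMTool` objects instead of raw dicts, here is a caller-side sketch of a function tool that the conversion loop above serializes into OpenAI's `{"type": "function", ...}` shape (the weather schema and import path are illustrative assumptions; the constructor form matches the `LLMTool(type="function", function={...})` usage removed from api.py below):

```python
from zhenxun.services.llm.types import LLMTool  # import path assumed

# tool.type == "function" with tool.function set, so OpenAICompatAdapter
# emits {"type": "function", "function": {...}} into body["tools"].
weather_tool = LLMTool(
    type="function",
    function={
        "name": "get_weather",  # hypothetical function name
        "description": "查询指定城市的天气",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
)
```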
@@ -22,10 +22,8 @@ class LLMAdapterFactory:

        from .gemini import GeminiAdapter
        from .openai import OpenAIAdapter
        from .zhipu import ZhipuAdapter

        cls.register_adapter(OpenAIAdapter())
        cls.register_adapter(ZhipuAdapter())
        cls.register_adapter(GeminiAdapter())

    @classmethod
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
    from ..service import LLMModel
    from ..types.content import LLMMessage
    from ..types.enums import EmbeddingTaskType
    from ..types.models import LLMToolCall
    from ..types.models import LLMTool, LLMToolCall


class GeminiAdapter(BaseAdapter):
@@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):

        return headers

    def prepare_advanced_request(
    async def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tools: list["LLMTool"] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """准备高级请求"""
        return self._prepare_request(
            model, api_key, messages, config, tools, tool_choice
        )

    def _prepare_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """准备 Gemini API 请求 - 支持所有高级功能"""
        effective_config = config if config is not None else model._generation_config

        endpoint = self._get_gemini_endpoint(model, effective_config)
@@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
                    system_instruction_parts = [{"text": msg.content}]
                elif isinstance(msg.content, list):
                    system_instruction_parts = [
                        part.convert_for_api("gemini") for part in msg.content
                        await part.convert_for_api_async("gemini")
                        for part in msg.content
                    ]
                continue

@@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
                    current_parts.append({"text": msg.content})
                elif isinstance(msg.content, list):
                    for part_obj in msg.content:
                        current_parts.append(part_obj.convert_for_api("gemini"))
                        current_parts.append(
                            await part_obj.convert_for_api_async("gemini")
                        )
                gemini_contents.append({"role": "user", "parts": current_parts})

            elif msg.role == "assistant" or msg.role == "model":
@@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
                    current_parts.append({"text": msg.content})
                elif isinstance(msg.content, list):
                    for part_obj in msg.content:
                        current_parts.append(part_obj.convert_for_api("gemini"))
                        current_parts.append(
                            await part_obj.convert_for_api_async("gemini")
                        )

                if msg.tool_calls:
                    import json
@@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):

        all_tools_for_request = []
        if tools:
            for tool_item in tools:
                if isinstance(tool_item, dict):
                    if "name" in tool_item and "description" in tool_item:
                        all_tools_for_request.append(
                            {"functionDeclarations": [tool_item]}
            for tool in tools:
                if tool.type == "function" and tool.function:
                    all_tools_for_request.append(
                        {"functionDeclarations": [tool.function]}
                    )
                elif tool.type == "mcp" and tool.mcp_session:
                    if callable(tool.mcp_session):
                        raise ValueError(
                            "适配器接收到未激活的 MCP 会话工厂。"
                            "会话工厂应该在 LLMModel.generate_response 中被激活。"
                        )
                    else:
                        all_tools_for_request.append(tool_item)
                else:
                    all_tools_for_request.append(tool_item)
                    all_tools_for_request.append(
                        tool.mcp_session.to_api_tool(api_type=self.api_type)
                    )
                elif tool.type == "google_search":
                    all_tools_for_request.append({"googleSearch": {}})

        if effective_config:
            if getattr(effective_config, "enable_grounding", False):
@@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
                logger.debug("隐式启用代码执行工具。")

        if all_tools_for_request:
            gemini_api_tools = self._convert_tools_to_gemini_format(
                all_tools_for_request
            )
            if gemini_api_tools:
                body["tools"] = gemini_api_tools
            body["tools"] = all_tools_for_request

        final_tool_choice = tool_choice
        if final_tool_choice is None and effective_config:
@@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):

        return f"/v1beta/models/{model.model_name}:generateContent"

    def _convert_tools_to_gemini_format(
        self, tools: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """转换工具格式为Gemini格式"""
        gemini_tools = []

        for tool in tools:
            if tool.get("type") == "function":
                func = tool["function"]
                gemini_tool = {
                    "functionDeclarations": [
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    ]
                }
                gemini_tools.append(gemini_tool)
            elif tool.get("type") == "code_execution":
                gemini_tools.append(
                    {"codeExecution": {"language": tool.get("language", "python")}}
                )
            elif tool.get("type") == "google_search":
                gemini_tools.append({"googleSearch": {}})
            elif "googleSearch" in tool:
                gemini_tools.append({"googleSearch": tool["googleSearch"]})
            elif "codeExecution" in tool:
                gemini_tools.append({"codeExecution": tool["codeExecution"]})

        return gemini_tools

    def _convert_tool_choice_to_gemini(
        self, tool_choice_value: str | dict[str, Any]
    ) -> dict[str, Any]:
@@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
            for category, threshold in custom_safety_settings.items():
                safety_settings.append({"category": category, "threshold": threshold})
        else:
            from ..config.providers import get_gemini_safety_threshold

            threshold = get_gemini_safety_threshold()
            for category in safety_categories:
                safety_settings.append(
                    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
                )
                safety_settings.append({"category": category, "threshold": threshold})

        return safety_settings if safety_settings else None
@@ -1,12 +1,12 @@
"""
OpenAI API 适配器

支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。
支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。
"""

from typing import TYPE_CHECKING

from .base import OpenAICompatAdapter, RequestData
from .base import OpenAICompatAdapter

if TYPE_CHECKING:
    from ..service import LLMModel
@@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):

    @property
    def supported_api_types(self) -> list[str]:
        return ["openai", "deepseek", "general_openai_compat"]
        return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]

    def get_chat_endpoint(self) -> str:
    def get_chat_endpoint(self, model: "LLMModel") -> str:
        """返回聊天完成端点"""
        if model.api_type == "ark":
            return "/api/v3/chat/completions"
        if model.api_type == "zhipu":
            return "/api/paas/v4/chat/completions"
        return "/v1/chat/completions"

    def get_embedding_endpoint(self) -> str:
        """返回嵌入端点"""
    def get_embedding_endpoint(self, model: "LLMModel") -> str:
        """根据API类型返回嵌入端点"""
        if model.api_type == "zhipu":
            return "/v4/embeddings"
        return "/v1/embeddings"

    def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """准备简单文本生成请求 - OpenAI优化实现"""
        url = self.get_api_url(model, self.get_chat_endpoint())
        headers = self.get_base_headers(api_key)

        messages = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})

        body = {
            "model": model.model_name,
            "messages": messages,
        }

        body = self.apply_config_override(model, body)

        return RequestData(url=url, headers=headers, body=body)
@@ -1,57 +0,0 @@
"""
智谱 AI API 适配器

支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。
"""

from typing import TYPE_CHECKING

from .base import OpenAICompatAdapter, RequestData

if TYPE_CHECKING:
    from ..service import LLMModel


class ZhipuAdapter(OpenAICompatAdapter):
    """智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""

    @property
    def api_type(self) -> str:
        return "zhipu"

    @property
    def supported_api_types(self) -> list[str]:
        return ["zhipu"]

    def get_chat_endpoint(self) -> str:
        """返回智谱AI聊天完成端点"""
        return "/api/paas/v4/chat/completions"

    def get_embedding_endpoint(self) -> str:
        """返回智谱AI嵌入端点"""
        return "/v4/embeddings"

    def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """准备简单文本生成请求 - 智谱AI优化实现"""
        url = self.get_api_url(model, self.get_chat_endpoint())
        headers = self.get_base_headers(api_key)

        messages = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})

        body = {
            "model": model.model_name,
            "messages": messages,
        }

        body = self.apply_config_override(model, body)

        return RequestData(url=url, headers=headers, body=body)
@@ -2,6 +2,7 @@
LLM 服务的高级 API 接口
"""

import copy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@@ -14,6 +15,7 @@ from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .types import (
    EmbeddingTaskType,
    LLMContentPart,
@@ -56,6 +58,7 @@ class AIConfig:
    enable_gemini_safe_mode: bool = False
    enable_gemini_multimodal: bool = False
    enable_gemini_grounding: bool = False
    default_preserve_media_in_history: bool = False

    def __post_init__(self):
        """初始化后从配置中读取默认值"""
@@ -81,7 +84,7 @@ class AI:
        """
        初始化AI服务

        Args:
        参数:
            config: AI 配置.
            history: 可选的初始对话历史.
        """
@@ -93,16 +96,65 @@ class AI:
            self.history = []
            logger.info("AI session history cleared.")

    def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
        """
        净化用于存入历史记录的消息。
        将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
        """
        if not isinstance(message.content, list):
            return message

        sanitized_message = copy.deepcopy(message)
        content_list = sanitized_message.content
        if not isinstance(content_list, list):
            return sanitized_message

        new_content_parts: list[LLMContentPart] = []
        has_multimodal_content = False

        for part in content_list:
            if isinstance(part, LLMContentPart) and part.type == "text":
                new_content_parts.append(part)
            else:
                has_multimodal_content = True

        if has_multimodal_content:
            placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
            text_part_found = False
            for part in new_content_parts:
                if part.type == "text":
                    part.text = f"{placeholder} {part.text or ''}".strip()
                    text_part_found = True
                    break
            if not text_part_found:
                new_content_parts.insert(0, LLMContentPart.text_part(placeholder))

        sanitized_message.content = new_content_parts
        return sanitized_message

    async def chat(
        self,
        message: str | LLMMessage | list[LLMContentPart],
        *,
        model: ModelName = None,
        preserve_media_in_history: bool | None = None,
        **kwargs: Any,
    ) -> str:
        """
        进行一次聊天对话。
        此方法会自动使用和更新会话内的历史记录。

        参数:
            message: 用户输入的消息。
            model: 本次对话要使用的模型。
            preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
                - True: 保留,用于深度多轮媒体分析。
                - False: 不保留,替换为占位符,提高效率。
                - None (默认): 使用AI实例配置的默认值。
            **kwargs: 传递给模型的其他参数。

        返回:
            str: 模型的文本响应。
        """
        current_message: LLMMessage
        if isinstance(message, str):
@@ -127,7 +179,20 @@ class AI:
            final_messages, model, "聊天失败", kwargs
        )

        self.history.append(current_message)
        should_preserve = (
            preserve_media_in_history
            if preserve_media_in_history is not None
            else self.config.default_preserve_media_in_history
        )

        if should_preserve:
            logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
            self.history.append(current_message)
        else:
            logger.debug("高效模式:净化历史记录中的多模态消息。")
            sanitized_user_message = self._sanitize_message_for_history(current_message)
            self.history.append(sanitized_user_message)

        self.history.append(LLMMessage.assistant_text_response(response.text))

        return response.text
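A short sketch of the two history modes from the caller's side (import paths assumed; `unimsg_to_llm_parts` is exported in the module's `__all__` above):

```python
from nonebot_plugin_alconna import UniMessage

from zhenxun.services.llm import AI, unimsg_to_llm_parts  # paths assumed

async def two_modes(msg: UniMessage) -> str:
    ai = AI()
    parts = await unimsg_to_llm_parts(msg)
    # Efficient mode (default): after this turn the media parts stored in
    # history are replaced by the text placeholder built above.
    await ai.chat(parts)
    # Deep-analysis mode: keep the raw media so later turns can reuse it.
    return await ai.chat("再详细分析一下刚才的图片", preserve_media_in_history=True)
```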
@@ -140,7 +205,18 @@ class AI:
        timeout: int | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """代码执行"""
        """
        代码执行

        参数:
            prompt: 代码执行的提示词。
            model: 要使用的模型名称。
            timeout: 代码执行超时时间(秒)。
            **kwargs: 传递给模型的其他参数。

        返回:
            dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
        """
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"

        config = CommonOverrides.gemini_code_execution()
@@ -168,7 +244,18 @@ class AI:
        instruction: str = "",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """信息搜索 - 支持多模态输入"""
        """
        信息搜索 - 支持多模态输入

        参数:
            query: 搜索查询内容,支持文本或多模态消息。
            model: 要使用的模型名称。
            instruction: 搜索指令。
            **kwargs: 传递给模型的其他参数。

        返回:
            dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
        """
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
        config = CommonOverrides.gemini_grounding()
@@ -217,63 +304,69 @@ class AI:

    async def analyze(
        self,
        message: UniMessage,
        message: UniMessage | None,
        *,
        instruction: str = "",
        model: ModelName = None,
        tools: list[dict[str, Any]] | None = None,
        use_tools: list[str] | None = None,
        tool_config: dict[str, Any] | None = None,
        activated_tools: list[LLMTool] | None = None,
        history: list[LLMMessage] | None = None,
        **kwargs: Any,
    ) -> str | LLMResponse:
    ) -> LLMResponse:
        """
        内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
        这是处理复杂互动的主要方法。

        参数:
            message: 要分析的消息内容(支持多模态)。
            instruction: 分析指令。
            model: 要使用的模型名称。
            use_tools: 要使用的工具名称列表。
            tool_config: 工具配置。
            activated_tools: 已激活的工具列表。
            history: 对话历史记录。
            **kwargs: 传递给模型的其他参数。

        返回:
            LLMResponse: 模型的完整响应结果。
        """
        content_parts = await unimsg_to_llm_parts(message)
        content_parts = await unimsg_to_llm_parts(message or UniMessage())

        final_messages: list[LLMMessage] = []
        if history:
            final_messages.extend(history)

        if instruction:
            final_messages.append(LLMMessage.system(instruction))
            if not any(msg.role == "system" for msg in final_messages):
                final_messages.insert(0, LLMMessage.system(instruction))

        if not content_parts:
            if instruction:
            if instruction and not history:
                final_messages.append(LLMMessage.user(instruction))
            else:
            elif not history:
                raise LLMException(
                    "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
                )
        else:
            final_messages.append(LLMMessage.user(content_parts))

        llm_tools = None
        if tools:
            llm_tools = []
            for tool_dict in tools:
                if isinstance(tool_dict, dict):
                    if "name" in tool_dict and "description" in tool_dict:
                        llm_tool = LLMTool(
                            type="function",
                            function={
                                "name": tool_dict["name"],
                                "description": tool_dict["description"],
                                "parameters": tool_dict.get("parameters", {}),
                            },
                        )
                        llm_tools.append(llm_tool)
                    else:
                        llm_tools.append(LLMTool(**tool_dict))
                else:
                    llm_tools.append(tool_dict)
        llm_tools: list[LLMTool] | None = activated_tools
        if not llm_tools and use_tools:
            try:
                llm_tools = tool_registry.get_tools(use_tools)
                logger.debug(f"已从注册表加载工具定义: {use_tools}")
            except ValueError as e:
                raise LLMException(
                    f"加载工具定义失败: {e}",
                    code=LLMErrorCode.CONFIGURATION_ERROR,
                    cause=e,
                )

        tool_choice = None
        if tool_config:
            mode = tool_config.get("mode", "auto")
            if mode == "auto":
                tool_choice = "auto"
            elif mode == "any":
                tool_choice = "any"
            elif mode == "none":
                tool_choice = "none"
            if mode in ["auto", "any", "none"]:
                tool_choice = mode

        response = await self._execute_generation(
            final_messages,
@@ -284,9 +377,7 @@ class AI:
            tool_choice=tool_choice,
        )

        if response.tool_calls:
            return response
        return response.text
        return response
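Because `analyze` now always returns an `LLMResponse`, callers branch on `tool_calls` themselves instead of receiving `str | LLMResponse`. A sketch (import path assumed; the tool name comes from the preset `mcp_tools.json` introduced below):

```python
from nonebot_plugin_alconna import UniMessage

from zhenxun.services.llm import AI  # import path assumed

async def handle(msg: UniMessage) -> str:
    ai = AI()
    response = await ai.analyze(
        msg,
        instruction="你是一个助手,必要时调用工具。",
        use_tools=["sequential-thinking"],
        tool_config={"mode": "auto"},
    )
    if response.tool_calls:
        ...  # dispatch the tool calls, then continue the conversation
    return response.text
```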
    async def _execute_generation(
        self,
@@ -298,7 +389,7 @@ class AI:
        tool_choice: str | dict[str, Any] | None = None,
        base_config: LLMGenerationConfig | None = None,
    ) -> LLMResponse:
        """通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
        """通用的生成执行方法,封装模型获取和单次API调用"""
        try:
            resolved_model_name = self._resolve_model_name(
                model_name or self.config.model
@@ -311,7 +402,9 @@ class AI:
                resolved_model_name, override_config=final_config_dict
            ) as model_instance:
                return await model_instance.generate_response(
                    messages, tools=llm_tools, tool_choice=tool_choice
                    messages,
                    tools=llm_tools,
                    tool_choice=tool_choice,
                )
        except LLMException:
            raise
@@ -380,7 +473,18 @@ class AI:
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """生成文本嵌入向量"""
        """
        生成文本嵌入向量

        参数:
            texts: 要生成嵌入向量的文本或文本列表。
            model: 要使用的嵌入模型名称。
            task_type: 嵌入任务类型。
            **kwargs: 传递给模型的其他参数。

        返回:
            list[list[float]]: 文本的嵌入向量列表。
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
@@ -420,7 +524,17 @@ async def chat(
    model: ModelName = None,
    **kwargs: Any,
) -> str:
    """聊天对话便捷函数"""
    """
    聊天对话便捷函数

    参数:
        message: 用户输入的消息。
        model: 要使用的模型名称。
        **kwargs: 传递给模型的其他参数。

    返回:
        str: 模型的文本响应。
    """
    ai = AI()
    return await ai.chat(message, model=model, **kwargs)

@@ -432,7 +546,18 @@ async def code(
    timeout: int | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """代码执行便捷函数"""
    """
    代码执行便捷函数

    参数:
        prompt: 代码执行的提示词。
        model: 要使用的模型名称。
        timeout: 代码执行超时时间(秒)。
        **kwargs: 传递给模型的其他参数。

    返回:
        dict[str, Any]: 包含执行结果的字典。
    """
    ai = AI()
    return await ai.code(prompt, model=model, timeout=timeout, **kwargs)

@@ -444,45 +569,56 @@ async def search(
    instruction: str = "",
    **kwargs: Any,
) -> dict[str, Any]:
    """信息搜索便捷函数"""
    """
    信息搜索便捷函数

    参数:
        query: 搜索查询内容。
        model: 要使用的模型名称。
        instruction: 搜索指令。
        **kwargs: 传递给模型的其他参数。

    返回:
        dict[str, Any]: 包含搜索结果的字典。
    """
    ai = AI()
    return await ai.search(query, model=model, instruction=instruction, **kwargs)


async def analyze(
    message: UniMessage,
    message: UniMessage | None,
    *,
    instruction: str = "",
    model: ModelName = None,
    tools: list[dict[str, Any]] | None = None,
    use_tools: list[str] | None = None,
    tool_config: dict[str, Any] | None = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """内容分析便捷函数"""
    """
    内容分析便捷函数

    参数:
        message: 要分析的消息内容。
        instruction: 分析指令。
        model: 要使用的模型名称。
        use_tools: 要使用的工具名称列表。
        tool_config: 工具配置。
        **kwargs: 传递给模型的其他参数。

    返回:
        str | LLMResponse: 分析结果。
    """
    ai = AI()
    return await ai.analyze(
        message,
        instruction=instruction,
        model=model,
        tools=tools,
        use_tools=use_tools,
        tool_config=tool_config,
        **kwargs,
    )


async def analyze_with_images(
    text: str,
    images: list[str | Path | bytes] | str | Path | bytes,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """图片分析便捷函数"""
    message = create_multimodal_message(text=text, images=images)
    return await analyze(message, instruction=instruction, model=model, **kwargs)


async def analyze_multimodal(
    text: str | None = None,
    images: list[str | Path | bytes] | str | Path | bytes | None = None,
@@ -493,7 +629,21 @@ async def analyze_multimodal(
    model: ModelName = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """多模态分析便捷函数"""
    """
    多模态分析便捷函数

    参数:
        text: 文本内容。
        images: 图片文件路径、字节数据或列表。
        videos: 视频文件路径、字节数据或列表。
        audios: 音频文件路径、字节数据或列表。
        instruction: 分析指令。
        model: 要使用的模型名称。
        **kwargs: 传递给模型的其他参数。

    返回:
        str | LLMResponse: 分析结果。
    """
    message = create_multimodal_message(
        text=text, images=images, videos=videos, audios=audios
    )
@@ -510,7 +660,21 @@ async def search_multimodal(
    model: ModelName = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """多模态搜索便捷函数"""
    """
    多模态搜索便捷函数

    参数:
        text: 文本内容。
        images: 图片文件路径、字节数据或列表。
        videos: 视频文件路径、字节数据或列表。
        audios: 音频文件路径、字节数据或列表。
        instruction: 搜索指令。
        model: 要使用的模型名称。
        **kwargs: 传递给模型的其他参数。

    返回:
        dict[str, Any]: 包含搜索结果的字典。
    """
    message = create_multimodal_message(
        text=text, images=images, videos=videos, audios=audios
    )
@@ -525,6 +689,101 @@ async def embed(
    task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    **kwargs: Any,
) -> list[list[float]]:
    """文本嵌入便捷函数"""
    """
    文本嵌入便捷函数

    参数:
        texts: 要生成嵌入向量的文本或文本列表。
        model: 要使用的嵌入模型名称。
        task_type: 嵌入任务类型。
        **kwargs: 传递给模型的其他参数。

    返回:
        list[list[float]]: 文本的嵌入向量列表。
    """
    ai = AI()
    return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
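Usage sketch for the new embedding entry point (import paths and the model identifier are illustrative assumptions):

```python
from zhenxun.services.llm import embed  # import path assumed
from zhenxun.services.llm.types import EmbeddingTaskType  # path assumed

async def index_docs(docs: list[str]) -> list[list[float]]:
    # Returns one vector per input text.
    return await embed(
        docs,
        model="Gemini/text-embedding-004",  # hypothetical embedding model
        task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    )
```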
async def pipeline_chat(
    message: UniMessage | str | list[LLMContentPart],
    model_chain: list[ModelName],
    *,
    initial_instruction: str = "",
    final_instruction: str = "",
    **kwargs: Any,
) -> LLMResponse:
    """
    AI模型链式调用,前一个模型的输出作为下一个模型的输入。

    参数:
        message: 初始输入消息(支持多模态)
        model_chain: 模型名称列表
        initial_instruction: 第一个模型的系统指令
        final_instruction: 最后一个模型的系统指令
        **kwargs: 传递给模型实例的其他参数

    返回:
        LLMResponse: 最后一个模型的响应结果
    """
    if not model_chain:
        raise ValueError("模型链`model_chain`不能为空。")

    current_content: str | list[LLMContentPart]
    if isinstance(message, str):
        current_content = message
    elif isinstance(message, list):
        current_content = message
    else:
        current_content = await unimsg_to_llm_parts(message)

    final_response: LLMResponse | None = None

    for i, model_name in enumerate(model_chain):
        if not model_name:
            raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")

        is_first_step = i == 0
        is_last_step = i == len(model_chain) - 1

        messages_for_step: list[LLMMessage] = []
        instruction_for_step = ""
        if is_first_step and initial_instruction:
            instruction_for_step = initial_instruction
        elif is_last_step and final_instruction:
            instruction_for_step = final_instruction

        if instruction_for_step:
            messages_for_step.append(LLMMessage.system(instruction_for_step))

        messages_for_step.append(LLMMessage.user(current_content))

        logger.info(
            f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
            f"使用模型 '{model_name}' 进行处理..."
        )
        try:
            async with await get_model_instance(model_name, **kwargs) as model:
                response = await model.generate_response(messages_for_step)
                final_response = response
                current_content = response.text.strip()
                if not current_content and not is_last_step:
                    logger.warning(
                        f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
                    )
                    break

        except Exception as e:
            logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
            raise LLMException(
                f"流水线在模型 '{model_name}' 处执行失败: {e}",
                code=LLMErrorCode.GENERATION_FAILED,
                cause=e,
            )

    if final_response is None:
        raise LLMException(
            "AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
        )

    return final_response
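A minimal pipeline example: step one's text output becomes step two's input (import path and model names are assumptions):

```python
from zhenxun.services.llm import pipeline_chat  # import path assumed

async def summarize_then_translate(text: str) -> str:
    response = await pipeline_chat(
        text,
        ["DeepSeek/deepseek-chat", "Gemini/gemini-2.0-flash"],  # names assumed
        initial_instruction="提取这段文字的要点。",
        final_instruction="将要点翻译成英文。",
    )
    return response.text
```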
@@ -14,6 +14,8 @@ from .generation import (
from .presets import CommonOverrides
from .providers import (
    LLMConfig,
    ToolConfig,
    get_gemini_safety_threshold,
    get_llm_config,
    register_llm_configs,
    set_default_model,
@@ -25,8 +27,10 @@ __all__ = [
    "LLMConfig",
    "LLMGenerationConfig",
    "ModelConfigOverride",
    "ToolConfig",
    "apply_api_specific_mappings",
    "create_generation_config_from_kwargs",
    "get_gemini_safety_threshold",
    "get_llm_config",
    "register_llm_configs",
    "set_default_model",
@@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
            params["temperature"] = self.temperature

        if self.max_tokens is not None:
            if api_type in ["gemini", "gemini_native"]:
            if api_type == "gemini":
                params["maxOutputTokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens

        if api_type in ["gemini", "gemini_native"]:
        if api_type == "gemini":
            if self.top_k is not None:
                params["topK"] = self.top_k
            if self.top_p is not None:
@@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
        if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
            params["response_format"] = {"type": "json_object"}
            logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
        elif api_type in ["gemini", "gemini_native"]:
        elif api_type == "gemini":
            params["responseMimeType"] = "application/json"
            if self.response_schema:
                params["responseSchema"] = self.response_schema
            logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")

        if api_type in ["gemini", "gemini_native"]:
        if api_type == "gemini":
            if (
                self.response_format != ResponseFormat.JSON
                and self.response_mime_type is not None
@@ -214,7 +214,7 @@ def apply_api_specific_mappings(
    """应用API特定的参数映射"""
    mapped_params = params.copy()

    if api_type in ["gemini", "gemini_native"]:
    if api_type == "gemini":
        if "max_tokens" in mapped_params:
            mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
        if "top_k" in mapped_params:
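Effect of the narrowed mapping, sketched (the import path and the argument order of `apply_api_specific_mappings` are assumptions inferred from the hunk above):

```python
from zhenxun.services.llm.config import apply_api_specific_mappings  # path assumed

mapped = apply_api_specific_mappings({"max_tokens": 1024, "top_k": 40}, "gemini")
# -> {"maxOutputTokens": 1024, "topK": 40}; OpenAI-compatible api_types keep
#    the snake_case names unchanged, and "gemini_native" is no longer special-cased.
```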
@@ -71,14 +71,17 @@ class CommonOverrides:

    @staticmethod
    def gemini_safe() -> LLMGenerationConfig:
        """Gemini 安全模式:严格安全设置"""
        """Gemini 安全模式:使用配置的安全设置"""
        from .providers import get_gemini_safety_threshold

        threshold = get_gemini_safety_threshold()
        return LLMGenerationConfig(
            temperature=0.5,
            safety_settings={
                "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_HARASSMENT": threshold,
                "HARM_CATEGORY_HATE_SPEECH": threshold,
                "HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
                "HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
            },
        )
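With the new `gemini_safety_threshold` config item (registered below), the preset follows the configured value instead of a hard-coded one. A sketch, assuming the threshold was set to "BLOCK_ONLY_HIGH" in the bot config:

```python
from zhenxun.services.llm.config import CommonOverrides  # path assumed

config = CommonOverrides.gemini_safe()
# All four HARM_CATEGORY_* entries now carry the configured threshold,
# e.g. "BLOCK_ONLY_HIGH" instead of the previous fixed "BLOCK_MEDIUM_AND_ABOVE".
print(config.safety_settings["HARM_CATEGORY_HARASSMENT"])
```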
@@ -4,15 +4,33 @@ LLM 提供商配置管理
负责注册和管理 AI 服务提供商的配置项。
"""

from functools import lru_cache
import json
import sys
from typing import Any

from pydantic import BaseModel, Field

from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle

from ..types.models import ModelDetail, ProviderConfig


class ToolConfig(BaseModel):
    """MCP类型工具的配置定义"""

    type: str = "mcp"
    name: str = Field(..., description="工具的唯一名称标识")
    description: str | None = Field(None, description="工具功能的描述")
    mcp_config: dict[str, Any] | BaseModel = Field(
        ..., description="MCP服务器的特定配置"
    )


AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"

@@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
    providers: list[ProviderConfig] = Field(
        default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
    )
    mcp_tools: list[ToolConfig] = Field(
        default_factory=list, description="配置可用的外部MCP工具"
    )

    def get_provider_by_name(self, name: str) -> ProviderConfig | None:
        """根据名称获取提供商配置
@@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
    return [
        {
            "name": "DeepSeek",
            "api_key": "sk-******",
            "api_key": "YOUR_ARK_API_KEY",
            "api_base": "https://api.deepseek.com",
            "api_type": "openai",
            "models": [
@@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
                },
            ],
        },
        {
            "name": "ARK",
            "api_key": "YOUR_ARK_API_KEY",
            "api_base": "https://ark.cn-beijing.volces.com",
            "api_type": "ark",
            "models": [
                {"model_name": "deepseek-r1-250528"},
                {"model_name": "doubao-seed-1-6-250615"},
                {"model_name": "doubao-seed-1-6-flash-250615"},
                {"model_name": "doubao-seed-1-6-thinking-250615"},
            ],
        },
        {
            "name": "siliconflow",
            "api_key": "YOUR_ARK_API_KEY",
            "api_base": "https://api.siliconflow.cn",
            "api_type": "openai",
            "models": [
                {"model_name": "deepseek-ai/DeepSeek-V3"},
            ],
        },
        {
            "name": "GLM",
            "api_key": "",
            "api_key": "YOUR_ARK_API_KEY",
            "api_base": "https://open.bigmodel.cn",
            "api_type": "zhipu",
            "models": [
@@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
            "api_type": "gemini",
            "models": [
                {"model_name": "gemini-2.0-flash"},
                {"model_name": "gemini-2.5-flash-preview-05-20"},
                {"model_name": "gemini-2.5-flash"},
                {"model_name": "gemini-2.5-pro"},
                {"model_name": "gemini-2.5-flash-lite-preview-06-17"},
            ],
        },
    ]
def get_default_mcp_tools() -> dict[str, Any]:
    """
    获取默认的MCP工具配置,用于在文件不存在时创建。
    包含了 baidu-map, Context7, 和 sequential-thinking.
    """
    return {
        "mcpServers": {
            "baidu-map": {
                "command": "npx",
                "args": ["-y", "@baidumap/mcp-server-baidu-map"],
                "env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
                "description": "百度地图工具,提供地理编码、路线规划等功能。",
            },
            "sequential-thinking": {
                "command": "npx",
                "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
                "description": "顺序思维工具,用于帮助模型进行多步骤推理。",
            },
            "Context7": {
                "command": "npx",
                "args": ["-y", "@upstash/context7-mcp@latest"],
                "description": "Upstash 提供的上下文管理和记忆工具。",
            },
        }
    }
def register_llm_configs():
    """注册 LLM 服务的配置项"""
    logger.info("注册 LLM 服务的配置项")
@@ -214,6 +285,19 @@ def register_llm_configs():
        help="LLM服务请求重试的基础延迟时间(秒)",
        type=int,
    )
    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        "gemini_safety_threshold",
        "BLOCK_MEDIUM_AND_ABOVE",
        help=(
            "Gemini 安全过滤阈值 "
            "(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
            "BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
            "BLOCK_ONLY_HIGH: 只阻止高级别, "
            "BLOCK_NONE: 不阻止)"
        ),
        type=str,
    )

    Config.add_plugin_config(
        AI_CONFIG_GROUP,
@@ -225,24 +309,111 @@ def register_llm_configs():
    )
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
    """获取 LLM 配置实例

    返回:
        LLMConfig: LLM 配置实例
    """
    """获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
    ai_config = get_ai_config()

    llm_data_path = DATA_PATH / "llm"
    mcp_tools_path = llm_data_path / "mcp_tools.json"

    mcp_tools_list = []
    mcp_servers_dict = {}

    if not mcp_tools_path.exists():
        logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
        llm_data_path.mkdir(parents=True, exist_ok=True)
        default_mcp_config = get_default_mcp_tools()
        try:
            with mcp_tools_path.open("w", encoding="utf-8") as f:
                json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
            mcp_servers_dict = default_mcp_config.get("mcpServers", {})
        except Exception as e:
            logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
            mcp_servers_dict = {}
    else:
        try:
            with mcp_tools_path.open("r", encoding="utf-8") as f:
                mcp_data = json.load(f)
            mcp_servers_dict = mcp_data.get("mcpServers", {})
            if not isinstance(mcp_servers_dict, dict):
                logger.warning(
                    f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
                    f"将使用空配置。"
                )
                mcp_servers_dict = {}

        except json.JSONDecodeError as e:
            logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
        except Exception as e:
            logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
            mcp_servers_dict = {}

    if sys.platform == "win32":
        logger.debug("检测到Windows平台,正在调整MCP工具的npx命令...")
        for name, config in mcp_servers_dict.items():
            if isinstance(config, dict) and config.get("command") == "npx":
                logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
                original_args = config.get("args", [])
                config["command"] = "cmd"
                config["args"] = ["/c", "npx", *original_args]

    if mcp_servers_dict:
        mcp_tools_list = [
            {
                "name": name,
                "type": "mcp",
                "description": config.get("description", f"MCP tool for {name}"),
                "mcp_config": config,
            }
            for name, config in mcp_servers_dict.items()
            if isinstance(config, dict)
        ]

    from ..tools.registry import tool_registry

    for tool_dict in mcp_tools_list:
        if isinstance(tool_dict, dict):
            tool_name = tool_dict.get("name")
            if not tool_name:
                continue

            config_model = tool_registry.get_mcp_config_model(tool_name)
            if not config_model:
                logger.debug(
                    f"MCP工具 '{tool_name}' 没有注册其配置模型,"
                    f"将跳过特定配置验证,直接使用原始配置字典。"
                )
                continue

            mcp_config_data = tool_dict.get("mcp_config", {})
            try:
                parsed_mcp_config = parse_as(config_model, mcp_config_data)
                tool_dict["mcp_config"] = parsed_mcp_config
            except Exception as e:
                raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")

    config_data = {
        "default_model_name": ai_config.get("default_model_name"),
        "proxy": ai_config.get("proxy"),
        "timeout": ai_config.get("timeout", 180),
        "max_retries_llm": ai_config.get("max_retries_llm", 3),
        "retry_delay_llm": ai_config.get("retry_delay_llm", 2),
        "providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
        PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
        "mcp_tools": mcp_tools_list,
    }

    return LLMConfig(**config_data)
    return parse_as(LLMConfig, config_data)
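The Windows adjustment above, restated as a standalone runnable sketch with its before/after shape:

```python
def wrap_npx_for_windows(config: dict) -> dict:
    """Wrap an npx-based MCP server command so it runs under cmd on Windows."""
    if config.get("command") == "npx":
        return {**config, "command": "cmd", "args": ["/c", "npx", *config.get("args", [])]}
    return config

print(wrap_npx_for_windows({"command": "npx", "args": ["-y", "@baidumap/mcp-server-baidu-map"]}))
# {'command': 'cmd', 'args': ['/c', 'npx', '-y', '@baidumap/mcp-server-baidu-map']}
```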
def get_gemini_safety_threshold() -> str:
    """获取 Gemini 安全过滤阈值配置

    返回:
        str: 安全过滤阈值
    """
    ai_config = get_ai_config()
    return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")


def validate_llm_config() -> tuple[bool, list[str]]:
@@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
        logger.info("默认模型已清除")

    return True


@PriorityLifecycle.on_startup(priority=10)
async def _init_llm_config_on_startup():
    """
    在服务启动时主动调用一次 get_llm_config,
    以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
    """
    logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
    try:
        get_llm_config()
        logger.info("LLM 配置初始化完成。")
    except Exception as e:
        logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
@@ -49,12 +49,36 @@ class LLMHttpClient:
            max_keepalive_connections=self.config.max_keepalive_connections,
        )
        timeout = httpx.Timeout(self.config.timeout)

        client_kwargs = {}
        if self.config.proxy:
            try:
                version_parts = httpx.__version__.split(".")
                major = int(
                    "".join(c for c in version_parts[0] if c.isdigit())
                )
                minor = (
                    int("".join(c for c in version_parts[1] if c.isdigit()))
                    if len(version_parts) > 1
                    else 0
                )
                if (major, minor) >= (0, 28):
                    client_kwargs["proxy"] = self.config.proxy
                else:
                    client_kwargs["proxies"] = self.config.proxy
            except (ValueError, IndexError):
                client_kwargs["proxies"] = self.config.proxy
                logger.warning(
                    f"无法解析 httpx 版本 '{httpx.__version__}',"
                    "LLM模块将默认使用旧版 'proxies' 参数语法。"
                )

        self._client = httpx.AsyncClient(
            headers=headers,
            limits=limits,
            timeout=timeout,
            proxies=self.config.proxy,
            follow_redirects=True,
            **client_kwargs,
        )
        if self._client is None:
            raise LLMException(
@@ -156,7 +180,16 @@ async def create_llm_http_client(
    timeout: int = 180,
    proxy: str | None = None,
) -> LLMHttpClient:
    """创建LLM HTTP客户端"""
    """
    创建LLM HTTP客户端

    参数:
        timeout: 超时时间(秒)。
        proxy: 代理服务器地址。

    返回:
        LLMHttpClient: HTTP客户端实例。
    """
    config = HttpClientConfig(timeout=timeout, proxy=proxy)
    return LLMHttpClient(config)
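Caller-side sketch of the version-aware proxy handling (the import path and proxy URL are illustrative assumptions):

```python
from zhenxun.services.llm.core import create_llm_http_client  # path assumed

async def make_client():
    # On httpx >= 0.28 the proxy is passed as AsyncClient(proxy=...);
    # older versions receive the legacy proxies=... keyword instead.
    return await create_llm_http_client(timeout=60, proxy="http://127.0.0.1:7890")
```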
@@ -185,7 +218,20 @@ async def with_smart_retry(
    provider_name: str | None = None,
    **kwargs: Any,
) -> Any:
    """智能重试装饰器 - 支持Key轮询和错误分类"""
    """
    智能重试装饰器 - 支持Key轮询和错误分类

    参数:
        func: 要重试的异步函数。
        *args: 传递给函数的位置参数。
        retry_config: 重试配置。
        key_store: API密钥状态存储。
        provider_name: 提供商名称。
        **kwargs: 传递给函数的关键字参数。

    返回:
        Any: 函数执行结果。
    """
    config = retry_config or RetryConfig()
    last_exception: Exception | None = None
    failed_keys: set[str] = set()
@@ -294,7 +340,17 @@ class KeyStatusStore:
        api_keys: list[str],
        exclude_keys: set[str] | None = None,
    ) -> str | None:
        """获取下一个可用的API密钥(轮询策略)"""
        """
        获取下一个可用的API密钥(轮询策略)

        参数:
            provider_name: 提供商名称。
            api_keys: API密钥列表。
            exclude_keys: 要排除的密钥集合。

        返回:
            str | None: 可用的API密钥,如果没有可用密钥则返回None。
        """
        if not api_keys:
            return None

@@ -338,7 +394,13 @@ class KeyStatusStore:
        logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")

    async def record_failure(self, api_key: str, status_code: int | None):
        """记录失败使用"""
        """
        记录失败使用

        参数:
            api_key: API密钥。
            status_code: HTTP状态码。
        """
        key_id = self._get_key_id(api_key)
        async with self._lock:
            if status_code in [401, 403]:
@@ -356,7 +418,15 @@ class KeyStatusStore:
        logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")

    async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
        """获取密钥使用统计"""
        """
        获取密钥使用统计

        参数:
            api_keys: API密钥列表。

        返回:
            dict[str, dict]: 密钥统计信息字典。
        """
        stats = {}
        async with self._lock:
            for key in api_keys:
@@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
from .types.capabilities import get_model_capabilities

DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
@@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:


def get_configured_providers() -> list[ProviderConfig]:
    """从配置中获取Provider列表 - 简化版本"""
    """从配置中获取Provider列表 - 简化和修正版本"""
    ai_config = get_ai_config()
    providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
    if not isinstance(providers_raw, list):
    providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])

    if not isinstance(providers, list):
        logger.error(
            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
            f"将使用空列表。"
        )
        return []

    valid_providers = []
    for i, item in enumerate(providers_raw):
        if not isinstance(item, dict):
            logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
            continue

        try:
            if not item.get("name"):
                logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
                continue

            if not item.get("api_key"):
                logger.warning(
                    f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
                )
                continue

            if "api_type" not in item or not item["api_type"]:
                provider_name = item.get("name", "").lower()
                if "glm" in provider_name or "zhipu" in provider_name:
                    item["api_type"] = "zhipu"
                elif "gemini" in provider_name or "google" in provider_name:
                    item["api_type"] = "gemini"
                else:
                    item["api_type"] = "openai"

            if "api_base" not in item or not item["api_base"]:
                api_type = item.get("api_type")
                if api_type:
                    default_api_base = get_default_api_base_for_type(api_type)
                    if default_api_base:
                        item["api_base"] = default_api_base

            if "models" not in item:
                item["models"] = [{"model_name": item.get("name", "default")}]

            provider_conf = ProviderConfig(**item)
            valid_providers.append(provider_conf)

        except Exception as e:
            logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
    for i, item in enumerate(providers):
        if isinstance(item, ProviderConfig):
            if not item.api_base:
                default_api_base = get_default_api_base_for_type(item.api_type)
                if default_api_base:
                    item.api_base = default_api_base
            valid_providers.append(item)
        else:
            logger.warning(
                f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
                f"已跳过。实际类型: {type(item)}"
            )

    return valid_providers

@@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
def find_model_config(
    provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
    """在配置中查找指定的 Provider 和 ModelDetail
    """
    在配置中查找指定的 Provider 和 ModelDetail

    Args:
    参数:
        provider_name: 提供商名称
        model_name: 模型名称

    Returns:
        找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None
    返回:
        tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None
    """
    providers = get_configured_providers()

@@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis


def list_model_identifiers() -> dict[str, list[str]]:
    """列出所有模型的可用标识符
    """
    列出所有模型的可用标识符

    Returns:
        字典,键为模型的完整名称,值为该模型的所有可用标识符列表
    返回:
        dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表
    """
    providers = get_configured_providers()
    result = {}
@ -248,7 +224,16 @@ async def get_model_instance(
|
||||
provider_model_name: str | None = None,
|
||||
override_config: dict[str, Any] | None = None,
|
||||
) -> LLMModel:
|
||||
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
|
||||
"""
|
||||
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
|
||||
|
||||
参数:
|
||||
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
|
||||
override_config: 覆盖配置字典。
|
||||
|
||||
返回:
|
||||
LLMModel: 模型实例。
|
||||
"""
|
||||
cache_key = _make_cache_key(provider_model_name, override_config)
|
||||
cached_model = _get_cached_model(cache_key)
|
||||
if cached_model:
|
||||
@@ -292,6 +277,10 @@

     provider_config_found, model_detail_found = config_tuple_found

+    capabilities = get_model_capabilities(model_detail_found.model_name)
+
+    model_detail_found.is_embedding_model = capabilities.is_embedding_model
+
     ai_config = get_ai_config()
     global_proxy_setting = ai_config.get(PROXY_KEY)
     default_timeout = (

@@ -322,6 +311,7 @@
         model_detail=model_detail_found,
         key_store=key_store,
         http_client=shared_http_client,
+        capabilities=capabilities,
     )

     if override_config:

@@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:


 def set_global_default_model_name(provider_model_name: str | None) -> bool:
-    """设置全局默认模型名称"""
+    """
+    设置全局默认模型名称
+
+    参数:
+        provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
+
+    返回:
+        bool: 设置是否成功。
+    """
     if provider_model_name:
         prov_name, mod_name = parse_provider_model_string(provider_model_name)
         if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):

@@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:


 async def get_key_usage_stats() -> dict[str, Any]:
-    """获取所有Provider的Key使用统计"""
+    """
+    获取所有Provider的Key使用统计
+
+    返回:
+        dict[str, Any]: 包含所有Provider的Key使用统计信息。
+    """
     providers = get_configured_providers()
     stats = {}

@@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:


 async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
-    """重置指定Provider的Key状态"""
+    """
+    重置指定Provider的Key状态
+
+    参数:
+        provider_name: 提供商名称。
+        api_key: 要重置的特定API密钥,如果为None则重置所有密钥。
+
+    返回:
+        bool: 重置是否成功。
+    """
     providers = get_configured_providers()
     target_provider = None

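# Operational sketch for the two key-management helpers above (the provider
# name "GLM" is an assumption; both functions live in this manager module):
stats = await get_key_usage_stats()
for provider_name, provider_stats in stats.items():
    print(provider_name, provider_stats)

# Reset one specific key, or every key of the provider when api_key is None.
await reset_key_status("GLM", api_key=None)
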
@@ -6,11 +6,13 @@ LLM 模型实现类

 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Callable
+from contextlib import AsyncExitStack
 import json
 from typing import Any

 from zhenxun.services.log import logger

+from .adapters.base import RequestData
 from .config import LLMGenerationConfig
 from .config.providers import get_ai_config
 from .core import (

@@ -30,6 +32,8 @@ from .types import (
     ModelDetail,
     ProviderConfig,
 )
+from .types.capabilities import ModelCapabilities, ModelModality
+from .utils import _sanitize_request_body_for_logging


 class LLMModelBase(ABC):

@@ -42,7 +46,17 @@ class LLMModelBase(ABC):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> str:
-        """生成文本"""
+        """
+        生成文本
+
+        参数:
+            prompt: 输入提示词。
+            history: 对话历史记录。
+            **kwargs: 其他参数。
+
+        返回:
+            str: 生成的文本。
+        """
         pass

     @abstractmethod

@@ -54,7 +68,19 @@
         tool_choice: str | dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> LLMResponse:
-        """生成高级响应"""
+        """
+        生成高级响应
+
+        参数:
+            messages: 消息列表。
+            config: 生成配置。
+            tools: 工具列表。
+            tool_choice: 工具选择策略。
+            **kwargs: 其他参数。
+
+        返回:
+            LLMResponse: 模型响应。
+        """
         pass

     @abstractmethod

@@ -64,7 +90,17 @@
         task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
         **kwargs: Any,
     ) -> list[list[float]]:
-        """生成文本嵌入向量"""
+        """
+        生成文本嵌入向量
+
+        参数:
+            texts: 文本列表。
+            task_type: 嵌入任务类型。
+            **kwargs: 其他参数。
+
+        返回:
+            list[list[float]]: 嵌入向量列表。
+        """
         pass

@@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
         model_detail: ModelDetail,
         key_store: KeyStatusStore,
         http_client: LLMHttpClient,
+        capabilities: ModelCapabilities,
         config_override: LLMGenerationConfig | None = None,
     ):
         self.provider_config = provider_config
         self.model_detail = model_detail
         self.key_store = key_store
         self.http_client: LLMHttpClient = http_client
+        self.capabilities = capabilities
         self._generation_config = config_override

         self.provider_name = provider_config.name

@@ -99,6 +137,34 @@

         self._is_closed = False

+    def can_process_images(self) -> bool:
+        """检查模型是否支持图片作为输入。"""
+        return ModelModality.IMAGE in self.capabilities.input_modalities
+
+    def can_process_video(self) -> bool:
+        """检查模型是否支持视频作为输入。"""
+        return ModelModality.VIDEO in self.capabilities.input_modalities
+
+    def can_process_audio(self) -> bool:
+        """检查模型是否支持音频作为输入。"""
+        return ModelModality.AUDIO in self.capabilities.input_modalities
+
+    def can_generate_images(self) -> bool:
+        """检查模型是否支持生成图片。"""
+        return ModelModality.IMAGE in self.capabilities.output_modalities
+
+    def can_generate_audio(self) -> bool:
+        """检查模型是否支持生成音频 (TTS)。"""
+        return ModelModality.AUDIO in self.capabilities.output_modalities
+
+    def can_use_tools(self) -> bool:
+        """检查模型是否支持工具调用/函数调用。"""
+        return self.capabilities.supports_tool_calling
+
+    def is_embedding_model(self) -> bool:
+        """检查这是否是一个嵌入模型。"""
+        return self.capabilities.is_embedding_model
+
     async def _get_http_client(self) -> LLMHttpClient:
         """获取HTTP客户端"""
         if self.http_client.is_closed:

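# Capability-gating sketch built on the helpers above (the model name and
# the message/tool variables are placeholders, not part of the codebase):
model = await get_model_instance("Gemini/gemini-2.0-flash")

if model.is_embedding_model():
    vectors = await model.generate_embeddings(["some text"])
elif model.can_process_images():
    response = await model.generate_response(image_messages)
elif model.can_use_tools():
    response = await model.generate_response(text_messages, tools=my_tools)
else:
    response = await model.generate_response(text_messages)
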
@@ -135,24 +201,54 @@

         return selected_key

-    async def _execute_embedding_request(
+    async def _perform_api_call(
         self,
-        adapter,
-        texts: list[str],
-        task_type: EmbeddingTaskType | str,
-        http_client: LLMHttpClient,
+        prepare_request_func: Callable[[str], Awaitable["RequestData"]],
+        parse_response_func: Callable[[dict[str, Any]], Any],
+        http_client: "LLMHttpClient",
         failed_keys: set[str] | None = None,
-    ) -> list[list[float]]:
-        """执行单次嵌入请求 - 供重试机制调用"""
+        log_context: str = "API",
+    ) -> Any:
+        """
+        执行API调用的通用核心方法。
+
+        该方法封装了以下通用逻辑:
+        1. 选择API密钥。
+        2. 准备和记录请求。
+        3. 发送HTTP POST请求。
+        4. 处理HTTP错误和API特定错误。
+        5. 记录密钥使用状态。
+        6. 解析成功的响应。
+
+        参数:
+            prepare_request_func: 准备请求的函数。
+            parse_response_func: 解析响应的函数。
+            http_client: HTTP客户端。
+            failed_keys: 失败的密钥集合。
+            log_context: 日志上下文。
+
+        返回:
+            Any: 解析后的响应数据。
+        """
         api_key = await self._select_api_key(failed_keys)

         try:
-            request_data = adapter.prepare_embedding_request(
-                model=self,
-                api_key=api_key,
-                texts=texts,
-                task_type=task_type,
+            request_data = await prepare_request_func(api_key)
+
+            logger.info(
+                f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
+                f"[{log_context}]"
             )
+            logger.debug(f"📡 请求URL: {request_data.url}")
+            masked_key = (
+                f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
+            )
+            logger.debug(f"🔑 API密钥: {masked_key}")
+            logger.debug(f"📋 请求头: {dict(request_data.headers)}")
+
+            sanitized_body = _sanitize_request_body_for_logging(request_data.body)
+            request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
+            logger.debug(f"📦 请求体: {request_body_str}")

             http_response = await http_client.post(
                 request_data.url,

@@ -160,121 +256,16 @@
                 json=request_data.body,
             )

-            if http_response.status_code != 200:
-                error_text = http_response.text
-                logger.error(
-                    f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
-                )
-                await self.key_store.record_failure(api_key, http_response.status_code)
-
-                error_code = LLMErrorCode.API_REQUEST_FAILED
-                if http_response.status_code in [401, 403]:
-                    error_code = LLMErrorCode.API_KEY_INVALID
-                elif http_response.status_code == 429:
-                    error_code = LLMErrorCode.API_RATE_LIMITED
-
-                raise LLMException(
-                    f"HTTP嵌入请求失败: {http_response.status_code}",
-                    code=error_code,
-                    details={
-                        "status_code": http_response.status_code,
-                        "response": error_text,
-                        "api_key": api_key,
-                    },
-                )
-
-            try:
-                response_json = http_response.json()
-                adapter.validate_embedding_response(response_json)
-                embeddings = adapter.parse_embedding_response(response_json)
-            except Exception as e:
-                logger.error(f"解析嵌入响应失败: {e}", e=e)
-                await self.key_store.record_failure(api_key, None)
-                if isinstance(e, LLMException):
-                    raise
-                else:
-                    raise LLMException(
-                        f"解析API嵌入响应失败: {e}",
-                        code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                        cause=e,
-                    )
-
-            await self.key_store.record_success(api_key)
-            return embeddings
-
-        except LLMException:
-            raise
-        except Exception as e:
-            logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
-            await self.key_store.record_failure(api_key, None)
-            raise LLMException(
-                f"生成嵌入失败: {e}",
-                code=LLMErrorCode.EMBEDDING_FAILED,
-                cause=e,
-            )
-
-    async def _execute_with_smart_retry(
-        self,
-        adapter,
-        messages: list[LLMMessage],
-        config: LLMGenerationConfig | None,
-        tools_dict: list[dict[str, Any]] | None,
-        tool_choice: str | dict[str, Any] | None,
-        http_client: LLMHttpClient,
-    ):
-        """智能重试机制 - 使用统一的重试装饰器"""
-        ai_config = get_ai_config()
-        max_retries = ai_config.get("max_retries_llm", 3)
-        retry_delay = ai_config.get("retry_delay_llm", 2)
-        retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
-
-        return await with_smart_retry(
-            self._execute_single_request,
-            adapter,
-            messages,
-            config,
-            tools_dict,
-            tool_choice,
-            http_client,
-            retry_config=retry_config,
-            key_store=self.key_store,
-            provider_name=self.provider_name,
-        )
-
-    async def _execute_single_request(
-        self,
-        adapter,
-        messages: list[LLMMessage],
-        config: LLMGenerationConfig | None,
-        tools_dict: list[dict[str, Any]] | None,
-        tool_choice: str | dict[str, Any] | None,
-        http_client: LLMHttpClient,
-        failed_keys: set[str] | None = None,
-    ) -> LLMResponse:
-        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
-        api_key = await self._select_api_key(failed_keys)
-
-        try:
-            request_data = adapter.prepare_advanced_request(
-                model=self,
-                api_key=api_key,
-                messages=messages,
-                config=config,
-                tools=tools_dict,
-                tool_choice=tool_choice,
-            )
-
-            http_response = await http_client.post(
-                request_data.url,
-                headers=request_data.headers,
-                json=request_data.body,
-            )
+            logger.debug(f"📥 响应状态码: {http_response.status_code}")
+            logger.debug(f"📄 响应头: {dict(http_response.headers)}")

             if http_response.status_code != 200:
                 error_text = http_response.text
                 logger.error(
-                    f"HTTP请求失败: {http_response.status_code} - {error_text}"
+                    f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
+                    f"[{log_context}]"
                 )
+                logger.debug(f"💥 完整错误响应: {error_text}")

                 await self.key_store.record_failure(api_key, http_response.status_code)

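# The same status-code mapping as above, isolated as a sketch (it mirrors
# _perform_api_call's error handling; not a public helper in the codebase):
def classify_http_error(status_code: int) -> LLMErrorCode:
    # 401/403 usually mean the key itself is bad; 429 is rate limiting.
    if status_code in (401, 403):
        return LLMErrorCode.API_KEY_INVALID
    if status_code == 429:
        return LLMErrorCode.API_RATE_LIMITED
    return LLMErrorCode.API_REQUEST_FAILED
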
@@ -299,69 +290,165 @@

             try:
                 response_json = http_response.json()
-                response_data = adapter.parse_response(
-                    model=self,
-                    response_json=response_json,
-                    is_advanced=True,
-                )
-
-                from .types.models import LLMToolCall
-
-                response_tool_calls = []
-                if response_data.tool_calls:
-                    for tc_data in response_data.tool_calls:
-                        if isinstance(tc_data, LLMToolCall):
-                            response_tool_calls.append(tc_data)
-                        elif isinstance(tc_data, dict):
-                            try:
-                                response_tool_calls.append(LLMToolCall(**tc_data))
-                            except Exception as e:
-                                logger.warning(
-                                    f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
-                                    f"error: {e}"
-                                )
-                        else:
-                            logger.warning(f"工具调用数据格式未知: {tc_data}")
-
-                llm_response = LLMResponse(
-                    text=response_data.text,
-                    usage_info=response_data.usage_info,
-                    raw_response=response_data.raw_response,
-                    tool_calls=response_tool_calls if response_tool_calls else None,
-                    code_executions=response_data.code_executions,
-                    grounding_metadata=response_data.grounding_metadata,
-                    cache_info=response_data.cache_info,
+                response_json_str = json.dumps(
+                    response_json, ensure_ascii=False, indent=2
                 )
+                logger.debug(f"📋 响应JSON: {response_json_str}")
+                parsed_data = parse_response_func(response_json)

             except Exception as e:
-                logger.error(f"解析响应失败: {e}", e=e)
+                logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
                 await self.key_store.record_failure(api_key, None)

                 if isinstance(e, LLMException):
                     raise
                 else:
                     raise LLMException(
-                        f"解析API响应失败: {e}",
+                        f"解析API {log_context} 响应失败: {e}",
                         code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                         cause=e,
                     )

             await self.key_store.record_success(api_key)

-            return llm_response
+            logger.debug(f"✅ API密钥使用成功: {masked_key}")
+            logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
+            return parsed_data

         except LLMException:
             raise
         except Exception as e:
-            logger.error(f"生成响应时发生未预期错误: {e}", e=e)
+            error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
+            logger.error(error_log_msg, e=e)
             await self.key_store.record_failure(api_key, None)

             raise LLMException(
-                f"生成响应失败: {e}",
-                code=LLMErrorCode.GENERATION_FAILED,
+                error_log_msg,
+                code=LLMErrorCode.GENERATION_FAILED
+                if log_context == "Generation"
+                else LLMErrorCode.EMBEDDING_FAILED,
                 cause=e,
             )

+    async def _execute_embedding_request(
+        self,
+        adapter,
+        texts: list[str],
+        task_type: EmbeddingTaskType | str,
+        http_client: LLMHttpClient,
+        failed_keys: set[str] | None = None,
+    ) -> list[list[float]]:
+        """执行单次嵌入请求 - 供重试机制调用"""
+
+        async def prepare_request(api_key: str) -> RequestData:
+            return adapter.prepare_embedding_request(
+                model=self,
+                api_key=api_key,
+                texts=texts,
+                task_type=task_type,
+            )
+
+        def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
+            adapter.validate_embedding_response(response_json)
+            return adapter.parse_embedding_response(response_json)
+
+        return await self._perform_api_call(
+            prepare_request_func=prepare_request,
+            parse_response_func=parse_response,
+            http_client=http_client,
+            failed_keys=failed_keys,
+            log_context="Embedding",
+        )
+
+    async def _execute_with_smart_retry(
+        self,
+        adapter,
+        messages: list[LLMMessage],
+        config: LLMGenerationConfig | None,
+        tools: list[LLMTool] | None,
+        tool_choice: str | dict[str, Any] | None,
+        http_client: LLMHttpClient,
+    ):
+        """智能重试机制 - 使用统一的重试装饰器"""
+        ai_config = get_ai_config()
+        max_retries = ai_config.get("max_retries_llm", 3)
+        retry_delay = ai_config.get("retry_delay_llm", 2)
+        retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
+
+        return await with_smart_retry(
+            self._execute_single_request,
+            adapter,
+            messages,
+            config,
+            tools,
+            tool_choice,
+            http_client,
+            retry_config=retry_config,
+            key_store=self.key_store,
+            provider_name=self.provider_name,
+        )
+
+    async def _execute_single_request(
+        self,
+        adapter,
+        messages: list[LLMMessage],
+        config: LLMGenerationConfig | None,
+        tools: list[LLMTool] | None,
+        tool_choice: str | dict[str, Any] | None,
+        http_client: LLMHttpClient,
+        failed_keys: set[str] | None = None,
+    ) -> LLMResponse:
+        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
+
+        async def prepare_request(api_key: str) -> RequestData:
+            return await adapter.prepare_advanced_request(
+                model=self,
+                api_key=api_key,
+                messages=messages,
+                config=config,
+                tools=tools,
+                tool_choice=tool_choice,
+            )
+
+        def parse_response(response_json: dict[str, Any]) -> LLMResponse:
+            response_data = adapter.parse_response(
+                model=self,
+                response_json=response_json,
+                is_advanced=True,
+            )
+            from .types.models import LLMToolCall
+
+            response_tool_calls = []
+            if response_data.tool_calls:
+                for tc_data in response_data.tool_calls:
+                    if isinstance(tc_data, LLMToolCall):
+                        response_tool_calls.append(tc_data)
+                    elif isinstance(tc_data, dict):
+                        try:
+                            response_tool_calls.append(LLMToolCall(**tc_data))
+                        except Exception as e:
+                            logger.warning(
+                                f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
+                                f"error: {e}"
+                            )
+                    else:
+                        logger.warning(f"工具调用数据格式未知: {tc_data}")
+
+            return LLMResponse(
+                text=response_data.text,
+                usage_info=response_data.usage_info,
+                raw_response=response_data.raw_response,
+                tool_calls=response_tool_calls if response_tool_calls else None,
+                code_executions=response_data.code_executions,
+                grounding_metadata=response_data.grounding_metadata,
+                cache_info=response_data.cache_info,
+            )
+
+        return await self._perform_api_call(
+            prepare_request_func=prepare_request,
+            parse_response_func=parse_response,
+            http_client=http_client,
+            failed_keys=failed_keys,
+            log_context="Generation",
+        )

     async def close(self):
         """
         标记模型实例的当前使用周期结束。

@@ -400,7 +487,17 @@
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> str:
-        """生成文本 - 通过 generate_response 实现"""
+        """
+        生成文本 - 通过 generate_response 实现
+
+        参数:
+            prompt: 输入提示词。
+            history: 对话历史记录。
+            **kwargs: 其他参数。
+
+        返回:
+            str: 生成的文本。
+        """
         self._check_not_closed()

         messages: list[LLMMessage] = []

@@ -439,11 +536,21 @@
         config: LLMGenerationConfig | None = None,
         tools: list[LLMTool] | None = None,
         tool_choice: str | dict[str, Any] | None = None,
-        tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
-        max_tool_iterations: int = 5,
         **kwargs: Any,
     ) -> LLMResponse:
-        """生成高级响应 - 实现完整的工具调用循环"""
+        """
+        生成高级响应
+
+        参数:
+            messages: 消息列表。
+            config: 生成配置。
+            tools: 工具列表。
+            tool_choice: 工具选择策略。
+            **kwargs: 其他参数。
+
+        返回:
+            LLMResponse: 模型响应。
+        """
         self._check_not_closed()

         from .adapters import get_adapter_for_api_type

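# With tool_executor gone, tools are resolved from the registry and passed
# straight to generate_response (the tool name and the "auto" choice value
# are assumptions -- the tool must exist in data/llm/mcp_tools.json or be
# registered in code):
tools = tool_registry.get_tools(["tavily-search"])
response = await model.generate_response(
    messages,
    tools=tools,
    tool_choice="auto",
)
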
@@ -468,109 +575,43 @@
             merged_dict.update(config.to_dict())
         final_request_config = LLMGenerationConfig(**merged_dict)

-        tools_dict: list[dict[str, Any]] | None = None
-        if tools:
-            tools_dict = []
-            for tool in tools:
-                if hasattr(tool, "model_dump"):
-                    model_dump_func = getattr(tool, "model_dump")
-                    tools_dict.append(model_dump_func(exclude_none=True))
-                elif isinstance(tool, dict):
-                    tools_dict.append(tool)
-                else:
-                    try:
-                        tools_dict.append(dict(tool))
-                    except (TypeError, ValueError):
-                        logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
-
         http_client = await self._get_http_client()
-        current_messages = list(messages)

-        for iteration in range(max_tool_iterations):
-            logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
+        async with AsyncExitStack() as stack:
+            activated_tools = []
+            if tools:
+                for tool in tools:
+                    if tool.type == "mcp" and callable(tool.mcp_session):
+                        func_obj = getattr(tool.mcp_session, "func", None)
+                        tool_name = (
+                            getattr(func_obj, "__name__", "unknown")
+                            if func_obj
+                            else "unknown"
+                        )
+                        logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
+
+                        active_session = await stack.enter_async_context(
+                            tool.mcp_session()
+                        )
+
+                        activated_tools.append(
+                            LLMTool.from_mcp_session(
+                                session=active_session, annotations=tool.annotations
+                            )
+                        )
+                    else:
+                        activated_tools.append(tool)

             llm_response = await self._execute_with_smart_retry(
                 adapter,
-                current_messages,
+                messages,
                 final_request_config,
-                tools_dict if iteration == 0 else None,
-                tool_choice if iteration == 0 else None,
+                activated_tools if activated_tools else None,
+                tool_choice,
                 http_client,
             )

-            response_tool_calls = llm_response.tool_calls or []
-
-            if not response_tool_calls or not tool_executor:
-                logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
-                return llm_response
-
-            logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
-
-            assistant_message_content = llm_response.text if llm_response.text else ""
-            current_messages.append(
-                LLMMessage.assistant_tool_calls(
-                    content=assistant_message_content, tool_calls=response_tool_calls
-                )
-            )
-
-            tool_response_messages: list[LLMMessage] = []
-            for tool_call in response_tool_calls:
-                tool_name = tool_call.function.name
-                try:
-                    tool_args_dict = json.loads(tool_call.function.arguments)
-                    logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
-
-                    tool_result = await tool_executor(tool_name, tool_args_dict)
-                    logger.debug(
-                        f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
-                    )
-
-                    tool_response_messages.append(
-                        LLMMessage.tool_response(
-                            tool_call_id=tool_call.id,
-                            function_name=tool_name,
-                            result=tool_result,
-                        )
-                    )
-                except json.JSONDecodeError as e:
-                    logger.error(
-                        f"工具 '{tool_name}' 参数JSON解析失败: "
-                        f"{tool_call.function.arguments}, 错误: {e}"
-                    )
-                    tool_response_messages.append(
-                        LLMMessage.tool_response(
-                            tool_call_id=tool_call.id,
-                            function_name=tool_name,
-                            result={
-                                "error": "Argument JSON parsing failed",
-                                "details": str(e),
-                            },
-                        )
-                    )
-                except Exception as e:
-                    logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
-                    tool_response_messages.append(
-                        LLMMessage.tool_response(
-                            tool_call_id=tool_call.id,
-                            function_name=tool_name,
-                            result={
-                                "error": "Tool execution failed",
-                                "details": str(e),
-                            },
-                        )
-                    )
-
-            current_messages.extend(tool_response_messages)
-
-        logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
-        raise LLMException(
-            "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
-            code=LLMErrorCode.GENERATION_FAILED,
-            details={
-                "iterations": max_tool_iterations,
-                "last_messages": current_messages[-2:],
-            },
-        )
+        return llm_response

     async def generate_embeddings(
         self,

@@ -578,7 +619,17 @@
         task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
         **kwargs: Any,
     ) -> list[list[float]]:
-        """生成文本嵌入向量"""
+        """
+        生成文本嵌入向量
+
+        参数:
+            texts: 文本列表。
+            task_type: 嵌入任务类型。
+            **kwargs: 其他参数。
+
+        返回:
+            list[list[float]]: 嵌入向量列表。
+        """
         self._check_not_closed()
         if not texts:
             return []

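# Embedding usage sketch (the model name is an assumption -- any configured
# model whose capabilities mark it as an embedding model will do):
model = await get_model_instance("Gemini/gemini-embedding-exp")
vectors = await model.generate_embeddings(
    ["真寻是谁?", "明天天气如何?"],
    task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT,
)
assert len(vectors) == 2  # one vector per input text
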
zhenxun/services/llm/tools/__init__.py (new file)
@@ -0,0 +1,7 @@
"""
工具模块导出
"""

from .registry import tool_registry

__all__ = ["tool_registry"]

zhenxun/services/llm/tools/registry.py (new file)
@@ -0,0 +1,181 @@
"""
工具注册表

负责加载、管理和实例化来自配置的工具。
"""

from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import TYPE_CHECKING

from pydantic import BaseModel

from zhenxun.services.log import logger

from ..types import LLMTool

if TYPE_CHECKING:
    from ..config.providers import ToolConfig
    from ..types.protocols import MCPCompatible


class ToolRegistry:
    """工具注册表,用于管理和实例化配置的工具。"""

    def __init__(self):
        self._function_tools: dict[str, LLMTool] = {}

        self._mcp_config_models: dict[str, type[BaseModel]] = {}
        if TYPE_CHECKING:
            self._mcp_factories: dict[
                str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
            ] = {}
        else:
            self._mcp_factories: dict[str, Callable] = {}

        self._tool_configs: dict[str, "ToolConfig"] | None = None
        self._tool_cache: dict[str, "LLMTool"] = {}

    def _load_configs_if_needed(self):
        """如果尚未加载,则从主配置中加载MCP工具定义。"""
        if self._tool_configs is None:
            logger.debug("首次访问,正在加载MCP工具配置...")
            from ..config.providers import get_llm_config

            llm_config = get_llm_config()
            self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
            logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")

    def function_tool(
        self,
        name: str,
        description: str,
        parameters: dict,
        required: list[str] | None = None,
    ):
        """
        装饰器:在代码中注册一个简单的、无状态的函数工具。

        参数:
            name: 工具的唯一名称。
            description: 工具功能的描述。
            parameters: OpenAPI格式的函数参数schema的properties部分。
            required: 必需的参数列表。
        """

        def decorator(func: Callable):
            if name in self._function_tools or name in self._mcp_factories:
                logger.warning(f"正在覆盖已注册的工具: {name}")

            tool_definition = LLMTool.create(
                name=name,
                description=description,
                parameters=parameters,
                required=required,
            )
            self._function_tools[name] = tool_definition
            logger.info(f"已在代码中注册函数工具: '{name}'")
            tool_definition.annotations = tool_definition.annotations or {}
            tool_definition.annotations["executable"] = func
            return func

        return decorator

    def mcp_tool(self, name: str, config_model: type[BaseModel]):
        """
        装饰器:注册一个MCP工具及其配置模型。

        参数:
            name: 工具的唯一名称,必须与配置文件中的名称匹配。
            config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。
        """

        def decorator(factory_func: Callable):
            if name in self._mcp_factories:
                logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
            self._mcp_factories[name] = factory_func
            self._mcp_config_models[name] = config_model
            logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
            return factory_func

        return decorator

    def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
        """根据名称获取MCP工具的配置模型。"""
        return self._mcp_config_models.get(name)

    def register_mcp_factory(
        self,
        name: str,
        factory: Callable,
    ):
        """
        在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。

        参数:
            name: 工具的唯一名称,必须与配置文件中的名称匹配。
            factory: 一个返回异步生成器的可调用对象(会话工厂)。
        """
        if name in self._mcp_factories:
            logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
        self._mcp_factories[name] = factory
        logger.info(f"已注册 MCP 会话工厂: '{name}'")

    def get_tool(self, name: str) -> "LLMTool":
        """
        根据名称获取一个 LLMTool 定义。
        对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
        而不是一个已激活的会话。
        """
        logger.debug(f"🔍 请求获取工具定义: {name}")

        if name in self._tool_cache:
            logger.debug(f"✅ 从缓存中获取工具定义: {name}")
            return self._tool_cache[name]

        if name in self._function_tools:
            logger.debug(f"🛠️ 获取函数工具定义: {name}")
            tool = self._function_tools[name]
            self._tool_cache[name] = tool
            return tool

        self._load_configs_if_needed()
        if self._tool_configs is None or name not in self._tool_configs:
            known_tools = list(self._function_tools.keys()) + (
                list(self._tool_configs.keys()) if self._tool_configs else []
            )
            logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
            logger.debug(f"📋 可用工具定义列表: {known_tools}")
            raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")

        config = self._tool_configs[name]
        tool: "LLMTool"

        if name not in self._mcp_factories:
            logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
            available_factories = list(self._mcp_factories.keys())
            logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
            raise ValueError(
                f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
                "请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
            )

        logger.info(f"🔧 创建MCP工具定义: {name}")
        factory = self._mcp_factories[name]
        typed_mcp_config = config.mcp_config
        logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")

        configured_factory = partial(factory, config=typed_mcp_config)
        tool = LLMTool.from_mcp_session(session=configured_factory)

        self._tool_cache[name] = tool
        logger.debug(f"💾 MCP工具定义已缓存: {name}")
        return tool

    def get_tools(self, names: list[str]) -> list["LLMTool"]:
        """根据名称列表获取多个 LLMTool 实例。"""
        return [self.get_tool(name) for name in names]


tool_registry = ToolRegistry()

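# Registration sketch for the registry above (tool names, parameter schema
# and WeatherMCPConfig are illustrative, not part of the codebase):
from pydantic import BaseModel

from zhenxun.services.llm import tool_registry


@tool_registry.function_tool(
    name="get_time",
    description="获取当前时间",
    parameters={"timezone": {"type": "string", "description": "IANA时区名"}},
    required=["timezone"],
)
async def get_time(timezone: str) -> str:
    ...


class WeatherMCPConfig(BaseModel):
    api_key: str


@tool_registry.mcp_tool(name="weather", config_model=WeatherMCPConfig)
def weather_factory(config: WeatherMCPConfig):
    ...  # should return an async context manager yielding an MCPCompatible session


# "weather" must also be defined in data/llm/mcp_tools.json for get_tool to find it.
tools = tool_registry.get_tools(["get_time", "weather"])
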
@@ -4,6 +4,7 @@ LLM 类型定义模块

 统一导出所有核心类型、协议和异常定义。
 """

+from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
 from .content import (
     LLMContentPart,
     LLMMessage,

@@ -26,6 +27,7 @@ from .models import (
     ToolMetadata,
     UsageInfo,
 )
+from .protocols import MCPCompatible

 __all__ = [
     "EmbeddingTaskType",

@@ -41,8 +43,11 @@ __all__ = [
     "LLMTool",
     "LLMToolCall",
     "LLMToolFunction",
+    "MCPCompatible",
+    "ModelCapabilities",
     "ModelDetail",
     "ModelInfo",
+    "ModelModality",
     "ModelName",
     "ModelProvider",
     "ProviderConfig",

@@ -50,5 +55,6 @@ __all__ = [
     "ToolCategory",
     "ToolMetadata",
     "UsageInfo",
+    "get_model_capabilities",
     "get_user_friendly_error_message",
 ]

zhenxun/services/llm/types/capabilities.py (new file)
@@ -0,0 +1,128 @@
"""
LLM 模型能力定义模块

定义模型的输入输出模态、工具调用支持等核心能力。
"""

from enum import Enum
import fnmatch

from pydantic import BaseModel, Field


class ModelModality(str, Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    EMBEDDING = "embedding"


class ModelCapabilities(BaseModel):
    """定义一个模型的核心、稳定能力。"""

    input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
    output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
    supports_tool_calling: bool = False
    is_embedding_model: bool = False


STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
    input_modalities={ModelModality.TEXT},
    output_modalities={ModelModality.TEXT},
    supports_tool_calling=True,
)

GEMINI_CAPABILITIES = ModelCapabilities(
    input_modalities={
        ModelModality.TEXT,
        ModelModality.IMAGE,
        ModelModality.AUDIO,
        ModelModality.VIDEO,
    },
    output_modalities={ModelModality.TEXT},
    supports_tool_calling=True,
)

DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
    input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
    output_modalities={ModelModality.TEXT},
    supports_tool_calling=True,
)


MODEL_ALIAS_MAPPING: dict[str, str] = {
    "deepseek-v3*": "deepseek-chat",
    "deepseek-ai/DeepSeek-V3": "deepseek-chat",
    "deepseek-r1*": "deepseek-reasoner",
}


MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
    "gemini-*-tts": ModelCapabilities(
        input_modalities={ModelModality.TEXT},
        output_modalities={ModelModality.AUDIO},
    ),
    "gemini-*-native-audio-*": ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
        output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
        supports_tool_calling=True,
    ),
    "gemini-2.0-flash-preview-image-generation": ModelCapabilities(
        input_modalities={
            ModelModality.TEXT,
            ModelModality.IMAGE,
            ModelModality.AUDIO,
            ModelModality.VIDEO,
        },
        output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
        supports_tool_calling=True,
    ),
    "gemini-embedding-exp": ModelCapabilities(
        input_modalities={ModelModality.TEXT},
        output_modalities={ModelModality.EMBEDDING},
        is_embedding_model=True,
    ),
    "gemini-2.5-pro*": GEMINI_CAPABILITIES,
    "gemini-1.5-pro*": GEMINI_CAPABILITIES,
    "gemini-2.5-flash*": GEMINI_CAPABILITIES,
    "gemini-2.0-flash*": GEMINI_CAPABILITIES,
    "gemini-1.5-flash*": GEMINI_CAPABILITIES,
    "GLM-4V-Flash": ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
        output_modalities={ModelModality.TEXT},
        supports_tool_calling=True,
    ),
    "GLM-4V-Plus*": ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
        output_modalities={ModelModality.TEXT},
        supports_tool_calling=True,
    ),
    "glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
    "glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
    "doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
    "doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
    "deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
    "deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
}


def get_model_capabilities(model_name: str) -> ModelCapabilities:
    """
    从注册表获取模型能力,支持别名映射和通配符匹配。
    查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
    """
    canonical_name = model_name
    for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
        if fnmatch.fnmatch(model_name, alias_pattern):
            canonical_name = c_name
            break

    if canonical_name in MODEL_CAPABILITIES_REGISTRY:
        return MODEL_CAPABILITIES_REGISTRY[canonical_name]

    for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
        if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
            return capabilities

    return ModelCapabilities()

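# Lookup sketch for get_model_capabilities -- alias mapping first, then
# exact match, then wildcard match (model names taken from the registry above):
caps = get_model_capabilities("gemini-2.0-flash-001")   # wildcard "gemini-2.0-flash*"
assert caps.supports_tool_calling
assert ModelModality.IMAGE in caps.input_modalities

caps = get_model_capabilities("deepseek-v3-0324")       # alias -> "deepseek-chat"
assert not caps.is_embedding_model

caps = get_model_capabilities("totally-unknown-model")  # falls back to the text-only default
assert caps.input_modalities == {ModelModality.TEXT}
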
@@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
                 logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
             return None

-    def convert_for_api(self, api_type: str) -> dict[str, Any]:
+    async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
         """根据API类型转换多模态内容格式"""
+        from zhenxun.utils.http_utils import AsyncHttpx
+
         if self.type == "text":
             if api_type == "openai":
                 return {"type": "text", "text": self.text}

@@ -248,20 +250,23 @@
                     mime_type, data = base64_info
                     return {"inlineData": {"mimeType": mime_type, "data": data}}
                 else:
+                    # 如果无法解析 Base64 数据,抛出异常
                     raise ValueError(
                         f"无法解析Base64图像数据: {self.image_source[:50]}..."
                     )
-            else:
-                logger.warning(
-                    f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
-                )
-                return {
-                    "inlineData": {
-                        "mimeType": "image/jpeg",
-                        "data": self.image_source,
-                    }
-                }
+            elif self.is_image_url():
+                logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
+                try:
+                    image_bytes = await AsyncHttpx.get_content(self.image_source)
+                    mime_type = self.mime_type or "image/jpeg"
+                    base64_data = base64.b64encode(image_bytes).decode("utf-8")
+                    return {
+                        "inlineData": {"mimeType": mime_type, "data": base64_data}
+                    }
+                except Exception as e:
+                    logger.error(f"下载或编码URL图片失败: {e}", e=e)
+                    raise ValueError(f"无法处理图片URL: {e}")
+            else:
+                raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
         else:
             return {"type": "image_url", "image_url": {"url": self.image_source}}

@@ -4,13 +4,25 @@ LLM 数据模型定义

 包含模型信息、配置、工具定义和响应数据的模型类。
 """

+from collections.abc import Callable
+from contextlib import AbstractAsyncContextManager
 from dataclasses import dataclass, field
-from typing import Any
+from typing import TYPE_CHECKING, Any

 from pydantic import BaseModel, Field

 from .enums import ModelProvider, ToolCategory

+if TYPE_CHECKING:
+    from .protocols import MCPCompatible
+
+    MCPSessionType = (
+        MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
+    )
+else:
+    MCPCompatible = object
+    MCPSessionType = Any
+
 ModelName = str | None


@@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
 class LLMTool(BaseModel):
     """LLM 工具定义(支持 MCP 风格)"""

+    model_config = {"arbitrary_types_allowed": True}
+
     type: str = "function"
-    function: dict[str, Any]
+    function: dict[str, Any] | None = None
+    mcp_session: MCPSessionType = None
     annotations: dict[str, Any] | None = Field(default=None, description="工具注解")

+    def model_post_init(self, /, __context: Any) -> None:
+        """验证工具定义的有效性"""
+        _ = __context
+        if self.type == "function" and self.function is None:
+            raise ValueError("函数类型的工具必须包含 'function' 字段。")
+        if self.type == "mcp" and self.mcp_session is None:
+            raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
+
     @classmethod
     def create(
         cls,

@@ -111,7 +134,7 @@ class LLMTool(BaseModel):
         required: list[str] | None = None,
         annotations: dict[str, Any] | None = None,
     ) -> "LLMTool":
-        """创建工具"""
+        """创建函数工具"""
         function_def = {
             "name": name,
             "description": description,

@@ -123,6 +146,15 @@ class LLMTool(BaseModel):
         }
         return cls(type="function", function=function_def, annotations=annotations)

+    @classmethod
+    def from_mcp_session(
+        cls,
+        session: Any,
+        annotations: dict[str, Any] | None = None,
+    ) -> "LLMTool":
+        """从 MCP 会话创建工具"""
+        return cls(type="mcp", mcp_session=session, annotations=annotations)
+

 class LLMCodeExecution(BaseModel):
     """代码执行结果"""

zhenxun/services/llm/types/protocols.py (new file)
@@ -0,0 +1,24 @@
"""
LLM 模块的协议定义
"""

from typing import Any, Protocol


class MCPCompatible(Protocol):
    """
    一个协议,定义了与LLM模块兼容的MCP会话对象应具备的行为。
    任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible。
    """

    def to_api_tool(self, api_type: str) -> dict[str, Any]:
        """
        将此MCP会话转换为特定LLM提供商API所需的工具格式。

        参数:
            api_type: 目标API的类型 (例如 'gemini', 'openai')。

        返回:
            dict[str, Any]: 一个字典,代表可以在API请求中使用的工具定义。
        """
        ...

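# Structural-typing sketch: any object with a matching to_api_tool method
# satisfies MCPCompatible, no inheritance needed (DummySession is made up):
from typing import Any


class DummySession:
    def to_api_tool(self, api_type: str) -> dict[str, Any]:
        if api_type == "openai":
            return {"type": "function", "function": {"name": "dummy", "parameters": {}}}
        return {"functionDeclarations": [{"name": "dummy"}]}


session: MCPCompatible = DummySession()  # accepted by a static type checker
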
@@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
 """

 import base64
+import copy
 from pathlib import Path

+from nonebot.adapters import Message as PlatformMessage
 from nonebot_plugin_alconna.uniseg import (
     At,
     File,

@@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
 )

 from zhenxun.services.log import logger
+from zhenxun.utils.http_utils import AsyncHttpx

 from .types import LLMContentPart


@@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
     """
     将 UniMessage 实例转换为一个 LLMContentPart 列表。
     这是处理多模态输入的核心转换逻辑。
+
+    参数:
+        message: 要转换的UniMessage实例。
+
+    返回:
+        list[LLMContentPart]: 转换后的内容部分列表。
     """
     parts: list[LLMContentPart] = []
     for seg in message:

@@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
             if seg.path:
                 part = await LLMContentPart.from_path(seg.path)
             elif seg.url:
-                logger.warning(
-                    f"直接使用 URL 的 {type(seg).__name__} 段,"
-                    f"API 可能不支持: {seg.url}"
-                )
-                part = LLMContentPart.text_part(
-                    f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
-                )
-            elif hasattr(seg, "raw") and seg.raw:
+                try:
+                    logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
+                    media_bytes = await AsyncHttpx.get_content(seg.url)
+
+                    new_seg = copy.copy(seg)
+                    new_seg.raw = media_bytes
+                    seg = new_seg
+                    logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
+                except Exception as e:
+                    logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
+                    part = LLMContentPart.text_part(
+                        f"[下载媒体失败: {seg.name or seg.url}]"
+                    )
+
+                    if part:
+                        parts.append(part)
+                        continue
+
+            if hasattr(seg, "raw") and seg.raw:
                 mime_type = getattr(seg, "mimetype", None)
                 if isinstance(seg.raw, bytes):
                     b64_data = base64.b64encode(seg.raw).decode("utf-8")

@@ -127,50 +147,19 @@ def create_multimodal_message(
     audio_mimetypes: list[str] | str | None = None,
 ) -> UniMessage:
     """
-    创建多模态消息的便捷函数,方便第三方调用。
+    创建多模态消息的便捷函数

-    Args:
+    参数:
         text: 文本内容
         images: 图片数据,支持路径、字节数据或URL
-        videos: 视频数据,支持路径、字节数据或URL
-        audios: 音频数据,支持路径、字节数据或URL
-        image_mimetypes: 图片MIME类型,当images为bytes时需要指定
-        video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
-        audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
+        videos: 视频数据
+        audios: 音频数据
+        image_mimetypes: 图片MIME类型,bytes数据时需要指定
+        video_mimetypes: 视频MIME类型,bytes数据时需要指定
+        audio_mimetypes: 音频MIME类型,bytes数据时需要指定

-    Returns:
+    返回:
         UniMessage: 构建好的多模态消息
-
-    Examples:
-        # 纯文本
-        msg = create_multimodal_message("请分析这段文字")
-
-        # 文本 + 单张图片(路径)
-        msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
-
-        # 文本 + 多张图片
-        msg = create_multimodal_message(
-            "比较图片", images=["/path/1.jpg", "/path/2.jpg"]
-        )
-
-        # 文本 + 图片字节数据
-        msg = create_multimodal_message(
-            "分析", images=image_data, image_mimetypes="image/jpeg"
-        )
-
-        # 文本 + 视频
-        msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
-
-        # 文本 + 音频
-        msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
-
-        # 混合多模态
-        msg = create_multimodal_message(
-            "分析这些媒体文件",
-            images="/path/to/image.jpg",
-            videos="/path/to/video.mp4",
-            audios="/path/to/audio.wav"
-        )
     """
     message = UniMessage()

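# The inline examples were trimmed from the docstring above; a typical call
# still looks like this (paths are placeholders):
msg = create_multimodal_message(
    "分析这些媒体文件",
    images=["/path/1.jpg", "/path/2.jpg"],
    videos="/path/to/video.mp4",
)
parts = await unimsg_to_llm_parts(msg)
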
@@ -196,7 +185,7 @@ def _add_media_to_message(
     media_class: type,
     default_mimetype: str,
 ) -> None:
-    """添加媒体文件到 UniMessage 的辅助函数"""
+    """添加媒体文件到 UniMessage"""
     if not isinstance(media_items, list):
         media_items = [media_items]

@@ -216,3 +205,80 @@
         elif isinstance(item, bytes):
             mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
             message.append(media_class(raw=item, mimetype=mimetype))
+
+
+def message_to_unimessage(message: PlatformMessage) -> UniMessage:
+    """
+    将平台特定的 Message 对象转换为通用的 UniMessage。
+    主要用于处理引用消息等未被自动转换的消息体。
+
+    参数:
+        message: 平台特定的Message对象。
+
+    返回:
+        UniMessage: 转换后的通用消息对象。
+    """
+    uni_segments = []
+    for seg in message:
+        if seg.type == "text":
+            uni_segments.append(Text(seg.data.get("text", "")))
+        elif seg.type == "image":
+            uni_segments.append(Image(url=seg.data.get("url")))
+        elif seg.type == "record":
+            uni_segments.append(Voice(url=seg.data.get("url")))
+        elif seg.type == "video":
+            uni_segments.append(Video(url=seg.data.get("url")))
+        elif seg.type == "at":
+            uni_segments.append(At("user", str(seg.data.get("qq", ""))))
+        else:
+            logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
+
+    return UniMessage(uni_segments)
+
+
+def _sanitize_request_body_for_logging(body: dict) -> dict:
+    """
+    净化请求体用于日志记录,移除大数据字段并添加摘要信息
+
+    参数:
+        body: 原始请求体字典。
+
+    返回:
+        dict: 净化后的请求体字典。
+    """
+    try:
+        sanitized_body = copy.deepcopy(body)
+
+        if "contents" in sanitized_body and isinstance(
+            sanitized_body["contents"], list
+        ):
+            for content_item in sanitized_body["contents"]:
+                if "parts" in content_item and isinstance(content_item["parts"], list):
+                    media_summary = []
+                    new_parts = []
+                    for part in content_item["parts"]:
+                        if "inlineData" in part and isinstance(
+                            part["inlineData"], dict
+                        ):
+                            data = part["inlineData"].get("data")
+                            if isinstance(data, str):
+                                mime_type = part["inlineData"].get(
+                                    "mimeType", "unknown"
+                                )
+                                media_summary.append(f"{mime_type} ({len(data)} chars)")
+                                continue
+                        new_parts.append(part)
+
+                    if media_summary:
+                        summary_text = (
+                            f"[多模态内容: {len(media_summary)}个文件 - "
+                            f"{', '.join(media_summary)}]"
+                        )
+                        new_parts.insert(0, {"text": summary_text})
+
+                    content_item["parts"] = new_parts
+
+        return sanitized_body
+    except Exception as e:
+        logger.warning(f"日志净化失败: {e},将记录原始请求体。")
+        return body

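# Illustrative before/after for the sanitizer above, using a Gemini-style
# request body (the base64 payload is a placeholder):
body = {
    "contents": [
        {
            "parts": [
                {"text": "描述这张图"},
                {"inlineData": {"mimeType": "image/png", "data": "iVBORw0KGgo="}},
            ]
        }
    ]
}
safe = _sanitize_request_body_for_logging(body)
# safe["contents"][0]["parts"] now starts with a summary and drops the blob:
#   [{"text": "[多模态内容: 1个文件 - image/png (12 chars)]"}, {"text": "描述这张图"}]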