Mirror of https://github.com/zhenxun-org/zhenxun_bot.git, synced 2025-12-15 14:22:55 +08:00
✨ feat(llm): Overhaul the LLM service module with stronger multimodal and tool support (#1953)

* ✨ feat(llm): Overhaul the LLM service module with stronger multimodal and tool support

🚀 Core features
- Model chaining: new `pipeline_chat` for multi-step task flows
- Broader provider support: new adapters for ARK (Volcengine) and SiliconFlow
- Improved multimodal handling: URL media files are downloaded and converted, making input more flexible
- Conversation history: `AI.analyze` accepts prior messages as context and an optional UniMessage argument
- Embeddings and multimodal APIs: new `embed`, `analyze_multimodal`, and `search_multimodal` functions
- Model capability system: new `ModelCapabilities` for unified management of model features (multimodality, tool calling, and so on)

🔧 Architecture and refactoring
- MCP tool system rework: configuration moved out to `data/llm/mcp_tools.json`, with common tools preconfigured
- Unified API call logic: a shared `_perform_api_call` method removes duplicated code
- Cross-platform compatibility: MCP tool `npx` commands are wrapped automatically on Windows
- HTTP client: proxy configuration compatible across httpx versions (0.28+ adaptation)

🛠️ API and configuration
- Unified return type: `AI.analyze` now always returns `LLMResponse`
- Message conversion: new `message_to_unimessage` helper
- Gemini adapter: URL image download and encoding, configurable safety thresholds
- Caching: model instance caching and management
- Presets: extended `CommonOverrides` preset options
- History management: multimodal content can be replaced by text placeholders for efficiency

📚 Documentation and developer experience
- Full README rewrite: complete usage guide, API reference, and architecture overview
- Expanded docs: embedding models, cache management, tool registration
- Logging: more detailed debug output
- API cleanup: redundant functions removed, interfaces simplified

* 🎨 feat(llm): Unify docstring format across LLM service functions

* ✨ feat(llm): Add new models and simplify provider configuration loading

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 1e7ae38684
commit 48cbb2bf1d
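The features described above can be illustrated with a short usage sketch. The calls below follow the signatures added in this diff (`pipeline_chat`, `AI.analyze` with `use_tools`, `embed`); the import paths, the model identifiers other than "Gemini/gemini-2.0-flash", and the "baidu-map" tool name are illustrative assumptions, not values confirmed by this commit.

from nonebot_plugin_alconna import UniMessage  # assumed import path

from zhenxun.services.llm import AI, embed, pipeline_chat


async def demo() -> None:
    # Chain two models: each model's text output becomes the next model's input.
    # "Gemini/gemini-2.0-flash" appears in this diff as the default fallback model;
    # "DeepSeek/deepseek-chat" is an assumed identifier.
    result = await pipeline_chat(
        "List the pros and cons of MCP tools.",
        model_chain=["Gemini/gemini-2.0-flash", "DeepSeek/deepseek-chat"],
        initial_instruction="Extract the key facts.",
        final_instruction="Write a concise summary.",
    )
    print(result.text)

    # analyze() now always returns an LLMResponse; tools are loaded from the
    # registry by name ("baidu-map" assumes the default MCP tool is registered).
    ai = AI()
    response = await ai.analyze(
        UniMessage("How do I get from A to B?"),
        instruction="You are a route-planning assistant.",
        use_tools=["baidu-map"],
    )
    if response.tool_calls:
        ...  # dispatch the returned tool calls

    # Text embeddings via the new embed() convenience function.
    vectors = await embed(["hello", "world"])
    print(len(vectors), len(vectors[0]))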
File diff suppressed because it is too large
@ -10,10 +10,10 @@ from .api import (
|
|||||||
TaskType,
|
TaskType,
|
||||||
analyze,
|
analyze,
|
||||||
analyze_multimodal,
|
analyze_multimodal,
|
||||||
analyze_with_images,
|
|
||||||
chat,
|
chat,
|
||||||
code,
|
code,
|
||||||
embed,
|
embed,
|
||||||
|
pipeline_chat,
|
||||||
search,
|
search,
|
||||||
search_multimodal,
|
search_multimodal,
|
||||||
)
|
)
|
||||||
@ -35,6 +35,7 @@ from .manager import (
|
|||||||
list_model_identifiers,
|
list_model_identifiers,
|
||||||
set_global_default_model_name,
|
set_global_default_model_name,
|
||||||
)
|
)
|
||||||
|
from .tools import tool_registry
|
||||||
from .types import (
|
from .types import (
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
LLMContentPart,
|
LLMContentPart,
|
||||||
@ -43,6 +44,7 @@ from .types import (
|
|||||||
LLMMessage,
|
LLMMessage,
|
||||||
LLMResponse,
|
LLMResponse,
|
||||||
LLMTool,
|
LLMTool,
|
||||||
|
MCPCompatible,
|
||||||
ModelDetail,
|
ModelDetail,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
ModelProvider,
|
ModelProvider,
|
||||||
@ -51,7 +53,7 @@ from .types import (
|
|||||||
ToolMetadata,
|
ToolMetadata,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from .utils import create_multimodal_message, unimsg_to_llm_parts
|
from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AI",
|
"AI",
|
||||||
@ -65,6 +67,7 @@ __all__ = [
|
|||||||
"LLMMessage",
|
"LLMMessage",
|
||||||
"LLMResponse",
|
"LLMResponse",
|
||||||
"LLMTool",
|
"LLMTool",
|
||||||
|
"MCPCompatible",
|
||||||
"ModelDetail",
|
"ModelDetail",
|
||||||
"ModelInfo",
|
"ModelInfo",
|
||||||
"ModelName",
|
"ModelName",
|
||||||
@ -76,7 +79,6 @@ __all__ = [
|
|||||||
"UsageInfo",
|
"UsageInfo",
|
||||||
"analyze",
|
"analyze",
|
||||||
"analyze_multimodal",
|
"analyze_multimodal",
|
||||||
"analyze_with_images",
|
|
||||||
"chat",
|
"chat",
|
||||||
"clear_model_cache",
|
"clear_model_cache",
|
||||||
"code",
|
"code",
|
||||||
@ -88,9 +90,12 @@ __all__ = [
|
|||||||
"list_available_models",
|
"list_available_models",
|
||||||
"list_embedding_models",
|
"list_embedding_models",
|
||||||
"list_model_identifiers",
|
"list_model_identifiers",
|
||||||
|
"message_to_unimessage",
|
||||||
|
"pipeline_chat",
|
||||||
"register_llm_configs",
|
"register_llm_configs",
|
||||||
"search",
|
"search",
|
||||||
"search_multimodal",
|
"search_multimodal",
|
||||||
"set_global_default_model_name",
|
"set_global_default_model_name",
|
||||||
|
"tool_registry",
|
||||||
"unimsg_to_llm_parts",
|
"unimsg_to_llm_parts",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
|
|||||||
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
|
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
|
||||||
from .gemini import GeminiAdapter
|
from .gemini import GeminiAdapter
|
||||||
from .openai import OpenAIAdapter
|
from .openai import OpenAIAdapter
|
||||||
from .zhipu import ZhipuAdapter
|
|
||||||
|
|
||||||
LLMAdapterFactory.initialize()
|
LLMAdapterFactory.initialize()
|
||||||
|
|
||||||
@ -20,7 +19,6 @@ __all__ = [
|
|||||||
"OpenAICompatAdapter",
|
"OpenAICompatAdapter",
|
||||||
"RequestData",
|
"RequestData",
|
||||||
"ResponseData",
|
"ResponseData",
|
||||||
"ZhipuAdapter",
|
|
||||||
"get_adapter_for_api_type",
|
"get_adapter_for_api_type",
|
||||||
"register_adapter",
|
"register_adapter",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||||||
from ..service import LLMModel
|
from ..service import LLMModel
|
||||||
from ..types.content import LLMMessage
|
from ..types.content import LLMMessage
|
||||||
from ..types.enums import EmbeddingTaskType
|
from ..types.enums import EmbeddingTaskType
|
||||||
|
from ..types.models import LLMTool
|
||||||
|
|
||||||
|
|
||||||
class RequestData(BaseModel):
|
class RequestData(BaseModel):
|
||||||
@ -60,7 +61,7 @@ class BaseAdapter(ABC):
|
|||||||
"""支持的API类型列表"""
|
"""支持的API类型列表"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_simple_request(
|
async def prepare_simple_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
@ -86,7 +87,7 @@ class BaseAdapter(ABC):
|
|||||||
|
|
||||||
config = model._generation_config
|
config = model._generation_config
|
||||||
|
|
||||||
return self.prepare_advanced_request(
|
return await self.prepare_advanced_request(
|
||||||
model=model,
|
model=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@ -96,13 +97,13 @@ class BaseAdapter(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def prepare_advanced_request(
|
async def prepare_advanced_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
messages: list["LLMMessage"],
|
messages: list["LLMMessage"],
|
||||||
config: "LLMGenerationConfig | None" = None,
|
config: "LLMGenerationConfig | None" = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list["LLMTool"] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备高级请求"""
|
"""准备高级请求"""
|
||||||
@ -238,6 +239,9 @@ class BaseAdapter(ABC):
|
|||||||
message = choice.get("message", {})
|
message = choice.get("message", {})
|
||||||
content = message.get("content", "")
|
content = message.get("content", "")
|
||||||
|
|
||||||
|
if content:
|
||||||
|
content = content.strip()
|
||||||
|
|
||||||
parsed_tool_calls: list[LLMToolCall] | None = None
|
parsed_tool_calls: list[LLMToolCall] | None = None
|
||||||
if message_tool_calls := message.get("tool_calls"):
|
if message_tool_calls := message.get("tool_calls"):
|
||||||
from ..types.models import LLMToolFunction
|
from ..types.models import LLMToolFunction
|
||||||
@ -375,7 +379,7 @@ class BaseAdapter(ABC):
|
|||||||
if model.temperature is not None:
|
if model.temperature is not None:
|
||||||
base_config["temperature"] = model.temperature
|
base_config["temperature"] = model.temperature
|
||||||
if model.max_tokens is not None:
|
if model.max_tokens is not None:
|
||||||
if model.api_type in ["gemini", "gemini_native"]:
|
if model.api_type == "gemini":
|
||||||
base_config["maxOutputTokens"] = model.max_tokens
|
base_config["maxOutputTokens"] = model.max_tokens
|
||||||
else:
|
else:
|
||||||
base_config["max_tokens"] = model.max_tokens
|
base_config["max_tokens"] = model.max_tokens
|
||||||
@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_chat_endpoint(self) -> str:
|
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""子类必须实现,返回 chat completions 的端点"""
|
"""子类必须实现,返回 chat completions 的端点"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_embedding_endpoint(self) -> str:
|
def get_embedding_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""子类必须实现,返回 embeddings 的端点"""
|
"""子类必须实现,返回 embeddings 的端点"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_advanced_request(
|
async def prepare_simple_request(
|
||||||
|
self,
|
||||||
|
model: "LLMModel",
|
||||||
|
api_key: str,
|
||||||
|
prompt: str,
|
||||||
|
history: list[dict[str, str]] | None = None,
|
||||||
|
) -> RequestData:
|
||||||
|
"""准备简单文本生成请求 - OpenAI兼容API的通用实现"""
|
||||||
|
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||||
|
headers = self.get_base_headers(api_key)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if history:
|
||||||
|
messages.extend(history)
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
body = {
|
||||||
|
"model": model.model_name,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
body = self.apply_config_override(model, body)
|
||||||
|
|
||||||
|
return RequestData(url=url, headers=headers, body=body)
|
||||||
|
|
||||||
|
async def prepare_advanced_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
messages: list["LLMMessage"],
|
messages: list["LLMMessage"],
|
||||||
config: "LLMGenerationConfig | None" = None,
|
config: "LLMGenerationConfig | None" = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list["LLMTool"] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备高级请求 - OpenAI兼容格式"""
|
"""准备高级请求 - OpenAI兼容格式"""
|
||||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||||
headers = self.get_base_headers(api_key)
|
headers = self.get_base_headers(api_key)
|
||||||
openai_messages = self.convert_messages_to_openai_format(messages)
|
openai_messages = self.convert_messages_to_openai_format(messages)
|
||||||
|
|
||||||
@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
body["tools"] = tools
|
openai_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
if tool.type == "function" and tool.function:
|
||||||
|
openai_tools.append({"type": "function", "function": tool.function})
|
||||||
|
elif tool.type == "mcp" and tool.mcp_session:
|
||||||
|
if callable(tool.mcp_session):
|
||||||
|
raise ValueError(
|
||||||
|
"适配器接收到未激活的 MCP 会话工厂。"
|
||||||
|
"会话工厂应该在 LLMModel.generate_response 中被激活。"
|
||||||
|
)
|
||||||
|
openai_tools.append(
|
||||||
|
tool.mcp_session.to_api_tool(api_type=self.api_type)
|
||||||
|
)
|
||||||
|
if openai_tools:
|
||||||
|
body["tools"] = openai_tools
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
body["tool_choice"] = tool_choice
|
body["tool_choice"] = tool_choice
|
||||||
|
|
||||||
@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
is_advanced: bool = False,
|
is_advanced: bool = False,
|
||||||
) -> ResponseData:
|
) -> ResponseData:
|
||||||
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
|
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
|
||||||
_ = model, is_advanced # 未使用的参数
|
_ = model, is_advanced
|
||||||
return self.parse_openai_response(response_json)
|
return self.parse_openai_response(response_json)
|
||||||
|
|
||||||
def prepare_embedding_request(
|
def prepare_embedding_request(
|
||||||
@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备嵌入请求 - OpenAI兼容格式"""
|
"""准备嵌入请求 - OpenAI兼容格式"""
|
||||||
_ = task_type # 未使用的参数
|
_ = task_type
|
||||||
url = self.get_api_url(model, self.get_embedding_endpoint())
|
url = self.get_api_url(model, self.get_embedding_endpoint(model))
|
||||||
headers = self.get_base_headers(api_key)
|
headers = self.get_base_headers(api_key)
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter):
|
|||||||
"input": texts,
|
"input": texts,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 应用额外的配置参数
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
body.update(kwargs)
|
body.update(kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -22,10 +22,8 @@ class LLMAdapterFactory:
|
|||||||
|
|
||||||
from .gemini import GeminiAdapter
|
from .gemini import GeminiAdapter
|
||||||
from .openai import OpenAIAdapter
|
from .openai import OpenAIAdapter
|
||||||
from .zhipu import ZhipuAdapter
|
|
||||||
|
|
||||||
cls.register_adapter(OpenAIAdapter())
|
cls.register_adapter(OpenAIAdapter())
|
||||||
cls.register_adapter(ZhipuAdapter())
|
|
||||||
cls.register_adapter(GeminiAdapter())
|
cls.register_adapter(GeminiAdapter())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
from ..service import LLMModel
|
from ..service import LLMModel
|
||||||
from ..types.content import LLMMessage
|
from ..types.content import LLMMessage
|
||||||
from ..types.enums import EmbeddingTaskType
|
from ..types.enums import EmbeddingTaskType
|
||||||
from ..types.models import LLMToolCall
|
from ..types.models import LLMTool, LLMToolCall
|
||||||
|
|
||||||
|
|
||||||
class GeminiAdapter(BaseAdapter):
|
class GeminiAdapter(BaseAdapter):
|
||||||
@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def prepare_advanced_request(
|
async def prepare_advanced_request(
|
||||||
self,
|
self,
|
||||||
model: "LLMModel",
|
model: "LLMModel",
|
||||||
api_key: str,
|
api_key: str,
|
||||||
messages: list["LLMMessage"],
|
messages: list["LLMMessage"],
|
||||||
config: "LLMGenerationConfig | None" = None,
|
config: "LLMGenerationConfig | None" = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list["LLMTool"] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> RequestData:
|
) -> RequestData:
|
||||||
"""准备高级请求"""
|
"""准备高级请求"""
|
||||||
return self._prepare_request(
|
|
||||||
model, api_key, messages, config, tools, tool_choice
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_request(
|
|
||||||
self,
|
|
||||||
model: "LLMModel",
|
|
||||||
api_key: str,
|
|
||||||
messages: list["LLMMessage"],
|
|
||||||
config: "LLMGenerationConfig | None" = None,
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
) -> RequestData:
|
|
||||||
"""准备 Gemini API 请求 - 支持所有高级功能"""
|
|
||||||
effective_config = config if config is not None else model._generation_config
|
effective_config = config if config is not None else model._generation_config
|
||||||
|
|
||||||
endpoint = self._get_gemini_endpoint(model, effective_config)
|
endpoint = self._get_gemini_endpoint(model, effective_config)
|
||||||
@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
system_instruction_parts = [{"text": msg.content}]
|
system_instruction_parts = [{"text": msg.content}]
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
system_instruction_parts = [
|
system_instruction_parts = [
|
||||||
part.convert_for_api("gemini") for part in msg.content
|
await part.convert_for_api_async("gemini")
|
||||||
|
for part in msg.content
|
||||||
]
|
]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
current_parts.append({"text": msg.content})
|
current_parts.append({"text": msg.content})
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
for part_obj in msg.content:
|
for part_obj in msg.content:
|
||||||
current_parts.append(part_obj.convert_for_api("gemini"))
|
current_parts.append(
|
||||||
|
await part_obj.convert_for_api_async("gemini")
|
||||||
|
)
|
||||||
gemini_contents.append({"role": "user", "parts": current_parts})
|
gemini_contents.append({"role": "user", "parts": current_parts})
|
||||||
|
|
||||||
elif msg.role == "assistant" or msg.role == "model":
|
elif msg.role == "assistant" or msg.role == "model":
|
||||||
@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
current_parts.append({"text": msg.content})
|
current_parts.append({"text": msg.content})
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
for part_obj in msg.content:
|
for part_obj in msg.content:
|
||||||
current_parts.append(part_obj.convert_for_api("gemini"))
|
current_parts.append(
|
||||||
|
await part_obj.convert_for_api_async("gemini")
|
||||||
|
)
|
||||||
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
import json
|
import json
|
||||||
@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
|
|
||||||
all_tools_for_request = []
|
all_tools_for_request = []
|
||||||
if tools:
|
if tools:
|
||||||
for tool_item in tools:
|
for tool in tools:
|
||||||
if isinstance(tool_item, dict):
|
if tool.type == "function" and tool.function:
|
||||||
if "name" in tool_item and "description" in tool_item:
|
|
||||||
all_tools_for_request.append(
|
all_tools_for_request.append(
|
||||||
{"functionDeclarations": [tool_item]}
|
{"functionDeclarations": [tool.function]}
|
||||||
)
|
)
|
||||||
else:
|
elif tool.type == "mcp" and tool.mcp_session:
|
||||||
all_tools_for_request.append(tool_item)
|
if callable(tool.mcp_session):
|
||||||
else:
|
raise ValueError(
|
||||||
all_tools_for_request.append(tool_item)
|
"适配器接收到未激活的 MCP 会话工厂。"
|
||||||
|
"会话工厂应该在 LLMModel.generate_response 中被激活。"
|
||||||
|
)
|
||||||
|
all_tools_for_request.append(
|
||||||
|
tool.mcp_session.to_api_tool(api_type=self.api_type)
|
||||||
|
)
|
||||||
|
elif tool.type == "google_search":
|
||||||
|
all_tools_for_request.append({"googleSearch": {}})
|
||||||
|
|
||||||
if effective_config:
|
if effective_config:
|
||||||
if getattr(effective_config, "enable_grounding", False):
|
if getattr(effective_config, "enable_grounding", False):
|
||||||
@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
logger.debug("隐式启用代码执行工具。")
|
logger.debug("隐式启用代码执行工具。")
|
||||||
|
|
||||||
if all_tools_for_request:
|
if all_tools_for_request:
|
||||||
gemini_api_tools = self._convert_tools_to_gemini_format(
|
body["tools"] = all_tools_for_request
|
||||||
all_tools_for_request
|
|
||||||
)
|
|
||||||
if gemini_api_tools:
|
|
||||||
body["tools"] = gemini_api_tools
|
|
||||||
|
|
||||||
final_tool_choice = tool_choice
|
final_tool_choice = tool_choice
|
||||||
if final_tool_choice is None and effective_config:
|
if final_tool_choice is None and effective_config:
|
||||||
@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
|
|
||||||
return f"/v1beta/models/{model.model_name}:generateContent"
|
return f"/v1beta/models/{model.model_name}:generateContent"
|
||||||
|
|
||||||
def _convert_tools_to_gemini_format(
|
|
||||||
self, tools: list[dict[str, Any]]
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""转换工具格式为Gemini格式"""
|
|
||||||
gemini_tools = []
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
if tool.get("type") == "function":
|
|
||||||
func = tool["function"]
|
|
||||||
gemini_tool = {
|
|
||||||
"functionDeclarations": [
|
|
||||||
{
|
|
||||||
"name": func["name"],
|
|
||||||
"description": func.get("description", ""),
|
|
||||||
"parameters": func.get("parameters", {}),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
gemini_tools.append(gemini_tool)
|
|
||||||
elif tool.get("type") == "code_execution":
|
|
||||||
gemini_tools.append(
|
|
||||||
{"codeExecution": {"language": tool.get("language", "python")}}
|
|
||||||
)
|
|
||||||
elif tool.get("type") == "google_search":
|
|
||||||
gemini_tools.append({"googleSearch": {}})
|
|
||||||
elif "googleSearch" in tool:
|
|
||||||
gemini_tools.append({"googleSearch": tool["googleSearch"]})
|
|
||||||
elif "codeExecution" in tool:
|
|
||||||
gemini_tools.append({"codeExecution": tool["codeExecution"]})
|
|
||||||
|
|
||||||
return gemini_tools
|
|
||||||
|
|
||||||
def _convert_tool_choice_to_gemini(
|
def _convert_tool_choice_to_gemini(
|
||||||
self, tool_choice_value: str | dict[str, Any]
|
self, tool_choice_value: str | dict[str, Any]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter):
|
|||||||
for category, threshold in custom_safety_settings.items():
|
for category, threshold in custom_safety_settings.items():
|
||||||
safety_settings.append({"category": category, "threshold": threshold})
|
safety_settings.append({"category": category, "threshold": threshold})
|
||||||
else:
|
else:
|
||||||
|
from ..config.providers import get_gemini_safety_threshold
|
||||||
|
|
||||||
|
threshold = get_gemini_safety_threshold()
|
||||||
for category in safety_categories:
|
for category in safety_categories:
|
||||||
safety_settings.append(
|
safety_settings.append({"category": category, "threshold": threshold})
|
||||||
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
|
|
||||||
)
|
|
||||||
|
|
||||||
return safety_settings if safety_settings else None
|
return safety_settings if safety_settings else None
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
OpenAI API 适配器
|
OpenAI API 适配器
|
||||||
|
|
||||||
支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。
|
支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .base import OpenAICompatAdapter, RequestData
|
from .base import OpenAICompatAdapter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..service import LLMModel
|
from ..service import LLMModel
|
||||||
@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_api_types(self) -> list[str]:
|
def supported_api_types(self) -> list[str]:
|
||||||
return ["openai", "deepseek", "general_openai_compat"]
|
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
|
||||||
|
|
||||||
def get_chat_endpoint(self) -> str:
|
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""返回聊天完成端点"""
|
"""返回聊天完成端点"""
|
||||||
|
if model.api_type == "ark":
|
||||||
|
return "/api/v3/chat/completions"
|
||||||
|
if model.api_type == "zhipu":
|
||||||
|
return "/api/paas/v4/chat/completions"
|
||||||
return "/v1/chat/completions"
|
return "/v1/chat/completions"
|
||||||
|
|
||||||
def get_embedding_endpoint(self) -> str:
|
def get_embedding_endpoint(self, model: "LLMModel") -> str:
|
||||||
"""返回嵌入端点"""
|
"""根据API类型返回嵌入端点"""
|
||||||
|
if model.api_type == "zhipu":
|
||||||
|
return "/v4/embeddings"
|
||||||
return "/v1/embeddings"
|
return "/v1/embeddings"
|
||||||
|
|
||||||
def prepare_simple_request(
|
|
||||||
self,
|
|
||||||
model: "LLMModel",
|
|
||||||
api_key: str,
|
|
||||||
prompt: str,
|
|
||||||
history: list[dict[str, str]] | None = None,
|
|
||||||
) -> RequestData:
|
|
||||||
"""准备简单文本生成请求 - OpenAI优化实现"""
|
|
||||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
|
||||||
headers = self.get_base_headers(api_key)
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
if history:
|
|
||||||
messages.extend(history)
|
|
||||||
messages.append({"role": "user", "content": prompt})
|
|
||||||
|
|
||||||
body = {
|
|
||||||
"model": model.model_name,
|
|
||||||
"messages": messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
body = self.apply_config_override(model, body)
|
|
||||||
|
|
||||||
return RequestData(url=url, headers=headers, body=body)
|
|
||||||
|
|||||||
@ -1,57 +0,0 @@
|
|||||||
"""
|
|
||||||
智谱 AI API 适配器
|
|
||||||
|
|
||||||
支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from .base import OpenAICompatAdapter, RequestData
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..service import LLMModel
|
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAdapter(OpenAICompatAdapter):
|
|
||||||
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def api_type(self) -> str:
|
|
||||||
return "zhipu"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_api_types(self) -> list[str]:
|
|
||||||
return ["zhipu"]
|
|
||||||
|
|
||||||
def get_chat_endpoint(self) -> str:
|
|
||||||
"""返回智谱AI聊天完成端点"""
|
|
||||||
return "/api/paas/v4/chat/completions"
|
|
||||||
|
|
||||||
def get_embedding_endpoint(self) -> str:
|
|
||||||
"""返回智谱AI嵌入端点"""
|
|
||||||
return "/v4/embeddings"
|
|
||||||
|
|
||||||
def prepare_simple_request(
|
|
||||||
self,
|
|
||||||
model: "LLMModel",
|
|
||||||
api_key: str,
|
|
||||||
prompt: str,
|
|
||||||
history: list[dict[str, str]] | None = None,
|
|
||||||
) -> RequestData:
|
|
||||||
"""准备简单文本生成请求 - 智谱AI优化实现"""
|
|
||||||
url = self.get_api_url(model, self.get_chat_endpoint())
|
|
||||||
headers = self.get_base_headers(api_key)
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
if history:
|
|
||||||
messages.extend(history)
|
|
||||||
messages.append({"role": "user", "content": prompt})
|
|
||||||
|
|
||||||
body = {
|
|
||||||
"model": model.model_name,
|
|
||||||
"messages": messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
body = self.apply_config_override(model, body)
|
|
||||||
|
|
||||||
return RequestData(url=url, headers=headers, body=body)
|
|
||||||
@ -2,6 +2,7 @@
|
|||||||
LLM 服务的高级 API 接口
|
LLM 服务的高级 API 接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -14,6 +15,7 @@ from zhenxun.services.log import logger
|
|||||||
from .config import CommonOverrides, LLMGenerationConfig
|
from .config import CommonOverrides, LLMGenerationConfig
|
||||||
from .config.providers import get_ai_config
|
from .config.providers import get_ai_config
|
||||||
from .manager import get_global_default_model_name, get_model_instance
|
from .manager import get_global_default_model_name, get_model_instance
|
||||||
|
from .tools import tool_registry
|
||||||
from .types import (
|
from .types import (
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
LLMContentPart,
|
LLMContentPart,
|
||||||
@ -56,6 +58,7 @@ class AIConfig:
|
|||||||
enable_gemini_safe_mode: bool = False
|
enable_gemini_safe_mode: bool = False
|
||||||
enable_gemini_multimodal: bool = False
|
enable_gemini_multimodal: bool = False
|
||||||
enable_gemini_grounding: bool = False
|
enable_gemini_grounding: bool = False
|
||||||
|
default_preserve_media_in_history: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""初始化后从配置中读取默认值"""
|
"""初始化后从配置中读取默认值"""
|
||||||
@ -81,7 +84,7 @@ class AI:
|
|||||||
"""
|
"""
|
||||||
初始化AI服务
|
初始化AI服务
|
||||||
|
|
||||||
Args:
|
参数:
|
||||||
config: AI 配置.
|
config: AI 配置.
|
||||||
history: 可选的初始对话历史.
|
history: 可选的初始对话历史.
|
||||||
"""
|
"""
|
||||||
@ -93,16 +96,65 @@ class AI:
|
|||||||
self.history = []
|
self.history = []
|
||||||
logger.info("AI session history cleared.")
|
logger.info("AI session history cleared.")
|
||||||
|
|
||||||
|
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||||
|
"""
|
||||||
|
净化用于存入历史记录的消息。
|
||||||
|
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||||
|
"""
|
||||||
|
if not isinstance(message.content, list):
|
||||||
|
return message
|
||||||
|
|
||||||
|
sanitized_message = copy.deepcopy(message)
|
||||||
|
content_list = sanitized_message.content
|
||||||
|
if not isinstance(content_list, list):
|
||||||
|
return sanitized_message
|
||||||
|
|
||||||
|
new_content_parts: list[LLMContentPart] = []
|
||||||
|
has_multimodal_content = False
|
||||||
|
|
||||||
|
for part in content_list:
|
||||||
|
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||||
|
new_content_parts.append(part)
|
||||||
|
else:
|
||||||
|
has_multimodal_content = True
|
||||||
|
|
||||||
|
if has_multimodal_content:
|
||||||
|
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||||
|
text_part_found = False
|
||||||
|
for part in new_content_parts:
|
||||||
|
if part.type == "text":
|
||||||
|
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||||
|
text_part_found = True
|
||||||
|
break
|
||||||
|
if not text_part_found:
|
||||||
|
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||||
|
|
||||||
|
sanitized_message.content = new_content_parts
|
||||||
|
return sanitized_message
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str | LLMMessage | list[LLMContentPart],
|
message: str | LLMMessage | list[LLMContentPart],
|
||||||
*,
|
*,
|
||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
|
preserve_media_in_history: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
进行一次聊天对话。
|
进行一次聊天对话。
|
||||||
此方法会自动使用和更新会话内的历史记录。
|
此方法会自动使用和更新会话内的历史记录。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 用户输入的消息。
|
||||||
|
model: 本次对话要使用的模型。
|
||||||
|
preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
|
||||||
|
- True: 保留,用于深度多轮媒体分析。
|
||||||
|
- False: 不保留,替换为占位符,提高效率。
|
||||||
|
- None (默认): 使用AI实例配置的默认值。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 模型的文本响应。
|
||||||
"""
|
"""
|
||||||
current_message: LLMMessage
|
current_message: LLMMessage
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
@ -127,7 +179,20 @@ class AI:
|
|||||||
final_messages, model, "聊天失败", kwargs
|
final_messages, model, "聊天失败", kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
should_preserve = (
|
||||||
|
preserve_media_in_history
|
||||||
|
if preserve_media_in_history is not None
|
||||||
|
else self.config.default_preserve_media_in_history
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_preserve:
|
||||||
|
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
|
||||||
self.history.append(current_message)
|
self.history.append(current_message)
|
||||||
|
else:
|
||||||
|
logger.debug("高效模式:净化历史记录中的多模态消息。")
|
||||||
|
sanitized_user_message = self._sanitize_message_for_history(current_message)
|
||||||
|
self.history.append(sanitized_user_message)
|
||||||
|
|
||||||
self.history.append(LLMMessage.assistant_text_response(response.text))
|
self.history.append(LLMMessage.assistant_text_response(response.text))
|
||||||
|
|
||||||
return response.text
|
return response.text
|
||||||
@ -140,7 +205,18 @@ class AI:
|
|||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""代码执行"""
|
"""
|
||||||
|
代码执行
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prompt: 代码执行的提示词。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
timeout: 代码执行超时时间(秒)。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
|
||||||
|
"""
|
||||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
config = CommonOverrides.gemini_code_execution()
|
config = CommonOverrides.gemini_code_execution()
|
||||||
@ -168,7 +244,18 @@ class AI:
|
|||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""信息搜索 - 支持多模态输入"""
|
"""
|
||||||
|
信息搜索 - 支持多模态输入
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query: 搜索查询内容,支持文本或多模态消息。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
instruction: 搜索指令。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
|
||||||
|
"""
|
||||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||||
config = CommonOverrides.gemini_grounding()
|
config = CommonOverrides.gemini_grounding()
|
||||||
|
|
||||||
@ -217,63 +304,69 @@ class AI:
|
|||||||
|
|
||||||
async def analyze(
|
async def analyze(
|
||||||
self,
|
self,
|
||||||
message: UniMessage,
|
message: UniMessage | None,
|
||||||
*,
|
*,
|
||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
use_tools: list[str] | None = None,
|
||||||
tool_config: dict[str, Any] | None = None,
|
tool_config: dict[str, Any] | None = None,
|
||||||
|
activated_tools: list[LLMTool] | None = None,
|
||||||
|
history: list[LLMMessage] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str | LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
||||||
这是处理复杂互动的主要方法。
|
|
||||||
|
参数:
|
||||||
|
message: 要分析的消息内容(支持多模态)。
|
||||||
|
instruction: 分析指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
use_tools: 要使用的工具名称列表。
|
||||||
|
tool_config: 工具配置。
|
||||||
|
activated_tools: 已激活的工具列表。
|
||||||
|
history: 对话历史记录。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMResponse: 模型的完整响应结果。
|
||||||
"""
|
"""
|
||||||
content_parts = await unimsg_to_llm_parts(message)
|
content_parts = await unimsg_to_llm_parts(message or UniMessage())
|
||||||
|
|
||||||
final_messages: list[LLMMessage] = []
|
final_messages: list[LLMMessage] = []
|
||||||
|
if history:
|
||||||
|
final_messages.extend(history)
|
||||||
|
|
||||||
if instruction:
|
if instruction:
|
||||||
final_messages.append(LLMMessage.system(instruction))
|
if not any(msg.role == "system" for msg in final_messages):
|
||||||
|
final_messages.insert(0, LLMMessage.system(instruction))
|
||||||
|
|
||||||
if not content_parts:
|
if not content_parts:
|
||||||
if instruction:
|
if instruction and not history:
|
||||||
final_messages.append(LLMMessage.user(instruction))
|
final_messages.append(LLMMessage.user(instruction))
|
||||||
else:
|
elif not history:
|
||||||
raise LLMException(
|
raise LLMException(
|
||||||
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_messages.append(LLMMessage.user(content_parts))
|
final_messages.append(LLMMessage.user(content_parts))
|
||||||
|
|
||||||
llm_tools = None
|
llm_tools: list[LLMTool] | None = activated_tools
|
||||||
if tools:
|
if not llm_tools and use_tools:
|
||||||
llm_tools = []
|
try:
|
||||||
for tool_dict in tools:
|
llm_tools = tool_registry.get_tools(use_tools)
|
||||||
if isinstance(tool_dict, dict):
|
logger.debug(f"已从注册表加载工具定义: {use_tools}")
|
||||||
if "name" in tool_dict and "description" in tool_dict:
|
except ValueError as e:
|
||||||
llm_tool = LLMTool(
|
raise LLMException(
|
||||||
type="function",
|
f"加载工具定义失败: {e}",
|
||||||
function={
|
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||||
"name": tool_dict["name"],
|
cause=e,
|
||||||
"description": tool_dict["description"],
|
|
||||||
"parameters": tool_dict.get("parameters", {}),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
llm_tools.append(llm_tool)
|
|
||||||
else:
|
|
||||||
llm_tools.append(LLMTool(**tool_dict))
|
|
||||||
else:
|
|
||||||
llm_tools.append(tool_dict)
|
|
||||||
|
|
||||||
tool_choice = None
|
tool_choice = None
|
||||||
if tool_config:
|
if tool_config:
|
||||||
mode = tool_config.get("mode", "auto")
|
mode = tool_config.get("mode", "auto")
|
||||||
if mode == "auto":
|
if mode in ["auto", "any", "none"]:
|
||||||
tool_choice = "auto"
|
tool_choice = mode
|
||||||
elif mode == "any":
|
|
||||||
tool_choice = "any"
|
|
||||||
elif mode == "none":
|
|
||||||
tool_choice = "none"
|
|
||||||
|
|
||||||
response = await self._execute_generation(
|
response = await self._execute_generation(
|
||||||
final_messages,
|
final_messages,
|
||||||
@ -284,9 +377,7 @@ class AI:
|
|||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.tool_calls:
|
|
||||||
return response
|
return response
|
||||||
return response.text
|
|
||||||
|
|
||||||
async def _execute_generation(
|
async def _execute_generation(
|
||||||
self,
|
self,
|
||||||
@ -298,7 +389,7 @@ class AI:
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
base_config: LLMGenerationConfig | None = None,
|
base_config: LLMGenerationConfig | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
|
"""通用的生成执行方法,封装模型获取和单次API调用"""
|
||||||
try:
|
try:
|
||||||
resolved_model_name = self._resolve_model_name(
|
resolved_model_name = self._resolve_model_name(
|
||||||
model_name or self.config.model
|
model_name or self.config.model
|
||||||
@ -311,7 +402,9 @@ class AI:
|
|||||||
resolved_model_name, override_config=final_config_dict
|
resolved_model_name, override_config=final_config_dict
|
||||||
) as model_instance:
|
) as model_instance:
|
||||||
return await model_instance.generate_response(
|
return await model_instance.generate_response(
|
||||||
messages, tools=llm_tools, tool_choice=tool_choice
|
messages,
|
||||||
|
tools=llm_tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
except LLMException:
|
except LLMException:
|
||||||
raise
|
raise
|
||||||
@ -380,7 +473,18 @@ class AI:
|
|||||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""生成文本嵌入向量"""
|
"""
|
||||||
|
生成文本嵌入向量
|
||||||
|
|
||||||
|
参数:
|
||||||
|
texts: 要生成嵌入向量的文本或文本列表。
|
||||||
|
model: 要使用的嵌入模型名称。
|
||||||
|
task_type: 嵌入任务类型。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[list[float]]: 文本的嵌入向量列表。
|
||||||
|
"""
|
||||||
if isinstance(texts, str):
|
if isinstance(texts, str):
|
||||||
texts = [texts]
|
texts = [texts]
|
||||||
if not texts:
|
if not texts:
|
||||||
@ -420,7 +524,17 @@ async def chat(
|
|||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""聊天对话便捷函数"""
|
"""
|
||||||
|
聊天对话便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 用户输入的消息。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 模型的文本响应。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.chat(message, model=model, **kwargs)
|
return await ai.chat(message, model=model, **kwargs)
|
||||||
|
|
||||||
@ -432,7 +546,18 @@ async def code(
|
|||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""代码执行便捷函数"""
|
"""
|
||||||
|
代码执行便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
prompt: 代码执行的提示词。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
timeout: 代码执行超时时间(秒)。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含执行结果的字典。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
|
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
|
||||||
|
|
||||||
@ -444,45 +569,56 @@ async def search(
|
|||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""信息搜索便捷函数"""
|
"""
|
||||||
|
信息搜索便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query: 搜索查询内容。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
instruction: 搜索指令。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含搜索结果的字典。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.search(query, model=model, instruction=instruction, **kwargs)
|
return await ai.search(query, model=model, instruction=instruction, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def analyze(
|
async def analyze(
|
||||||
message: UniMessage,
|
message: UniMessage | None,
|
||||||
*,
|
*,
|
||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
use_tools: list[str] | None = None,
|
||||||
tool_config: dict[str, Any] | None = None,
|
tool_config: dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str | LLMResponse:
|
) -> str | LLMResponse:
|
||||||
"""内容分析便捷函数"""
|
"""
|
||||||
|
内容分析便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 要分析的消息内容。
|
||||||
|
instruction: 分析指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
use_tools: 要使用的工具名称列表。
|
||||||
|
tool_config: 工具配置。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | LLMResponse: 分析结果。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.analyze(
|
return await ai.analyze(
|
||||||
message,
|
message,
|
||||||
instruction=instruction,
|
instruction=instruction,
|
||||||
model=model,
|
model=model,
|
||||||
tools=tools,
|
use_tools=use_tools,
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def analyze_with_images(
|
|
||||||
text: str,
|
|
||||||
images: list[str | Path | bytes] | str | Path | bytes,
|
|
||||||
*,
|
|
||||||
instruction: str = "",
|
|
||||||
model: ModelName = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str | LLMResponse:
|
|
||||||
"""图片分析便捷函数"""
|
|
||||||
message = create_multimodal_message(text=text, images=images)
|
|
||||||
return await analyze(message, instruction=instruction, model=model, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
async def analyze_multimodal(
|
async def analyze_multimodal(
|
||||||
text: str | None = None,
|
text: str | None = None,
|
||||||
images: list[str | Path | bytes] | str | Path | bytes | None = None,
|
images: list[str | Path | bytes] | str | Path | bytes | None = None,
|
||||||
@ -493,7 +629,21 @@ async def analyze_multimodal(
|
|||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str | LLMResponse:
|
) -> str | LLMResponse:
|
||||||
"""多模态分析便捷函数"""
|
"""
|
||||||
|
多模态分析便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
text: 文本内容。
|
||||||
|
images: 图片文件路径、字节数据或列表。
|
||||||
|
videos: 视频文件路径、字节数据或列表。
|
||||||
|
audios: 音频文件路径、字节数据或列表。
|
||||||
|
instruction: 分析指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | LLMResponse: 分析结果。
|
||||||
|
"""
|
||||||
message = create_multimodal_message(
|
message = create_multimodal_message(
|
||||||
text=text, images=images, videos=videos, audios=audios
|
text=text, images=images, videos=videos, audios=audios
|
||||||
)
|
)
|
||||||
@ -510,7 +660,21 @@ async def search_multimodal(
|
|||||||
model: ModelName = None,
|
model: ModelName = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""多模态搜索便捷函数"""
|
"""
|
||||||
|
多模态搜索便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
text: 文本内容。
|
||||||
|
images: 图片文件路径、字节数据或列表。
|
||||||
|
videos: 视频文件路径、字节数据或列表。
|
||||||
|
audios: 音频文件路径、字节数据或列表。
|
||||||
|
instruction: 搜索指令。
|
||||||
|
model: 要使用的模型名称。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict[str, Any]: 包含搜索结果的字典。
|
||||||
|
"""
|
||||||
message = create_multimodal_message(
|
message = create_multimodal_message(
|
||||||
text=text, images=images, videos=videos, audios=audios
|
text=text, images=images, videos=videos, audios=audios
|
||||||
)
|
)
|
||||||
@ -525,6 +689,101 @@ async def embed(
|
|||||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""文本嵌入便捷函数"""
|
"""
|
||||||
|
文本嵌入便捷函数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
texts: 要生成嵌入向量的文本或文本列表。
|
||||||
|
model: 要使用的嵌入模型名称。
|
||||||
|
task_type: 嵌入任务类型。
|
||||||
|
**kwargs: 传递给模型的其他参数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[list[float]]: 文本的嵌入向量列表。
|
||||||
|
"""
|
||||||
ai = AI()
|
ai = AI()
|
||||||
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
|
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def pipeline_chat(
|
||||||
|
message: UniMessage | str | list[LLMContentPart],
|
||||||
|
model_chain: list[ModelName],
|
||||||
|
*,
|
||||||
|
initial_instruction: str = "",
|
||||||
|
final_instruction: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""
|
||||||
|
AI模型链式调用,前一个模型的输出作为下一个模型的输入。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: 初始输入消息(支持多模态)
|
||||||
|
model_chain: 模型名称列表
|
||||||
|
initial_instruction: 第一个模型的系统指令
|
||||||
|
final_instruction: 最后一个模型的系统指令
|
||||||
|
**kwargs: 传递给模型实例的其他参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
LLMResponse: 最后一个模型的响应结果
|
||||||
|
"""
|
||||||
|
if not model_chain:
|
||||||
|
raise ValueError("模型链`model_chain`不能为空。")
|
||||||
|
|
||||||
|
current_content: str | list[LLMContentPart]
|
||||||
|
if isinstance(message, str):
|
||||||
|
current_content = message
|
||||||
|
elif isinstance(message, list):
|
||||||
|
current_content = message
|
||||||
|
else:
|
||||||
|
current_content = await unimsg_to_llm_parts(message)
|
||||||
|
|
||||||
|
final_response: LLMResponse | None = None
|
||||||
|
|
||||||
|
for i, model_name in enumerate(model_chain):
|
||||||
|
if not model_name:
|
||||||
|
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
|
||||||
|
|
||||||
|
is_first_step = i == 0
|
||||||
|
is_last_step = i == len(model_chain) - 1
|
||||||
|
|
||||||
|
messages_for_step: list[LLMMessage] = []
|
||||||
|
instruction_for_step = ""
|
||||||
|
if is_first_step and initial_instruction:
|
||||||
|
instruction_for_step = initial_instruction
|
||||||
|
elif is_last_step and final_instruction:
|
||||||
|
instruction_for_step = final_instruction
|
||||||
|
|
||||||
|
if instruction_for_step:
|
||||||
|
messages_for_step.append(LLMMessage.system(instruction_for_step))
|
||||||
|
|
||||||
|
messages_for_step.append(LLMMessage.user(current_content))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
|
||||||
|
f"使用模型 '{model_name}' 进行处理..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with await get_model_instance(model_name, **kwargs) as model:
|
||||||
|
response = await model.generate_response(messages_for_step)
|
||||||
|
final_response = response
|
||||||
|
current_content = response.text.strip()
|
||||||
|
if not current_content and not is_last_step:
|
||||||
|
logger.warning(
|
||||||
|
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
|
||||||
|
raise LLMException(
|
||||||
|
f"流水线在模型 '{model_name}' 处执行失败: {e}",
|
||||||
|
code=LLMErrorCode.GENERATION_FAILED,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_response is None:
|
||||||
|
raise LLMException(
|
||||||
|
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_response
|
||||||
|
|||||||
@ -14,6 +14,8 @@ from .generation import (
|
|||||||
from .presets import CommonOverrides
|
from .presets import CommonOverrides
|
||||||
from .providers import (
|
from .providers import (
|
||||||
LLMConfig,
|
LLMConfig,
|
||||||
|
ToolConfig,
|
||||||
|
get_gemini_safety_threshold,
|
||||||
get_llm_config,
|
get_llm_config,
|
||||||
register_llm_configs,
|
register_llm_configs,
|
||||||
set_default_model,
|
set_default_model,
|
||||||
@ -25,8 +27,10 @@ __all__ = [
|
|||||||
"LLMConfig",
|
"LLMConfig",
|
||||||
"LLMGenerationConfig",
|
"LLMGenerationConfig",
|
||||||
"ModelConfigOverride",
|
"ModelConfigOverride",
|
||||||
|
"ToolConfig",
|
||||||
"apply_api_specific_mappings",
|
"apply_api_specific_mappings",
|
||||||
"create_generation_config_from_kwargs",
|
"create_generation_config_from_kwargs",
|
||||||
|
"get_gemini_safety_threshold",
|
||||||
"get_llm_config",
|
"get_llm_config",
|
||||||
"register_llm_configs",
|
"register_llm_configs",
|
||||||
"set_default_model",
|
"set_default_model",
|
||||||
|
|||||||
@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride):
|
|||||||
params["temperature"] = self.temperature
|
params["temperature"] = self.temperature
|
||||||
|
|
||||||
if self.max_tokens is not None:
|
if self.max_tokens is not None:
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
params["maxOutputTokens"] = self.max_tokens
|
params["maxOutputTokens"] = self.max_tokens
|
||||||
else:
|
else:
|
||||||
params["max_tokens"] = self.max_tokens
|
params["max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
if self.top_k is not None:
|
if self.top_k is not None:
|
||||||
params["topK"] = self.top_k
|
params["topK"] = self.top_k
|
||||||
if self.top_p is not None:
|
if self.top_p is not None:
|
||||||
@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride):
|
|||||||
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
||||||
params["response_format"] = {"type": "json_object"}
|
params["response_format"] = {"type": "json_object"}
|
||||||
logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
|
logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
|
||||||
elif api_type in ["gemini", "gemini_native"]:
|
elif api_type == "gemini":
|
||||||
params["responseMimeType"] = "application/json"
|
params["responseMimeType"] = "application/json"
|
||||||
if self.response_schema:
|
if self.response_schema:
|
||||||
params["responseSchema"] = self.response_schema
|
params["responseSchema"] = self.response_schema
|
||||||
logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
|
logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
|
||||||
|
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
if (
|
if (
|
||||||
self.response_format != ResponseFormat.JSON
|
self.response_format != ResponseFormat.JSON
|
||||||
and self.response_mime_type is not None
|
and self.response_mime_type is not None
|
||||||
@ -214,7 +214,7 @@ def apply_api_specific_mappings(
|
|||||||
"""应用API特定的参数映射"""
|
"""应用API特定的参数映射"""
|
||||||
mapped_params = params.copy()
|
mapped_params = params.copy()
|
||||||
|
|
||||||
if api_type in ["gemini", "gemini_native"]:
|
if api_type == "gemini":
|
||||||
if "max_tokens" in mapped_params:
|
if "max_tokens" in mapped_params:
|
||||||
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
|
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
|
||||||
if "top_k" in mapped_params:
|
if "top_k" in mapped_params:
|
||||||
|
|||||||
@ -71,14 +71,17 @@ class CommonOverrides:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gemini_safe() -> LLMGenerationConfig:
|
def gemini_safe() -> LLMGenerationConfig:
|
||||||
"""Gemini 安全模式:严格安全设置"""
|
"""Gemini 安全模式:使用配置的安全设置"""
|
||||||
|
from .providers import get_gemini_safety_threshold
|
||||||
|
|
||||||
|
threshold = get_gemini_safety_threshold()
|
||||||
return LLMGenerationConfig(
|
return LLMGenerationConfig(
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
safety_settings={
|
safety_settings={
|
||||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_HARASSMENT": threshold,
|
||||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_HATE_SPEECH": threshold,
|
||||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
|
||||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
|
"HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -4,15 +4,33 @@ LLM 提供商配置管理
|
|||||||
负责注册和管理 AI 服务提供商的配置项。
|
负责注册和管理 AI 服务提供商的配置项。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
|
from zhenxun.configs.path_config import DATA_PATH
|
||||||
|
from zhenxun.configs.utils import parse_as
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
|
||||||
from ..types.models import ModelDetail, ProviderConfig
|
from ..types.models import ModelDetail, ProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ToolConfig(BaseModel):
|
||||||
|
"""MCP类型工具的配置定义"""
|
||||||
|
|
||||||
|
type: str = "mcp"
|
||||||
|
name: str = Field(..., description="工具的唯一名称标识")
|
||||||
|
description: str | None = Field(None, description="工具功能的描述")
|
||||||
|
mcp_config: dict[str, Any] | BaseModel = Field(
|
||||||
|
..., description="MCP服务器的特定配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AI_CONFIG_GROUP = "AI"
|
AI_CONFIG_GROUP = "AI"
|
||||||
PROVIDERS_CONFIG_KEY = "PROVIDERS"
|
PROVIDERS_CONFIG_KEY = "PROVIDERS"
|
||||||
|
|
||||||
@ -38,6 +56,9 @@ class LLMConfig(BaseModel):
|
|||||||
providers: list[ProviderConfig] = Field(
|
providers: list[ProviderConfig] = Field(
|
||||||
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
|
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
|
||||||
)
|
)
|
||||||
|
mcp_tools: list[ToolConfig] = Field(
|
||||||
|
default_factory=list, description="配置可用的外部MCP工具"
|
||||||
|
)
|
||||||
|
|
||||||
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
|
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
|
||||||
"""根据名称获取提供商配置
|
"""根据名称获取提供商配置
|
||||||
@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]:
|
|||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"name": "DeepSeek",
|
"name": "DeepSeek",
|
||||||
"api_key": "sk-******",
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
"api_base": "https://api.deepseek.com",
|
"api_base": "https://api.deepseek.com",
|
||||||
"api_type": "openai",
|
"api_type": "openai",
|
||||||
"models": [
|
"models": [
|
||||||
@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ARK",
|
||||||
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
|
"api_base": "https://ark.cn-beijing.volces.com",
|
||||||
|
"api_type": "ark",
|
||||||
|
"models": [
|
||||||
|
{"model_name": "deepseek-r1-250528"},
|
||||||
|
{"model_name": "doubao-seed-1-6-250615"},
|
||||||
|
{"model_name": "doubao-seed-1-6-flash-250615"},
|
||||||
|
{"model_name": "doubao-seed-1-6-thinking-250615"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "siliconflow",
|
||||||
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
|
"api_base": "https://api.siliconflow.cn",
|
||||||
|
"api_type": "openai",
|
||||||
|
"models": [
|
||||||
|
{"model_name": "deepseek-ai/DeepSeek-V3"},
|
||||||
|
],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "GLM",
|
"name": "GLM",
|
||||||
"api_key": "",
|
"api_key": "YOUR_ARK_API_KEY",
|
||||||
"api_base": "https://open.bigmodel.cn",
|
"api_base": "https://open.bigmodel.cn",
|
||||||
"api_type": "zhipu",
|
"api_type": "zhipu",
|
||||||
"models": [
|
"models": [
|
||||||
@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]:
|
|||||||
"api_type": "gemini",
|
"api_type": "gemini",
|
||||||
"models": [
|
"models": [
|
||||||
{"model_name": "gemini-2.0-flash"},
|
{"model_name": "gemini-2.0-flash"},
|
||||||
{"model_name": "gemini-2.5-flash-preview-05-20"},
|
{"model_name": "gemini-2.5-flash"},
|
||||||
|
{"model_name": "gemini-2.5-pro"},
|
||||||
|
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
+
+
+def get_default_mcp_tools() -> dict[str, Any]:
+    """
+    获取默认的MCP工具配置,用于在文件不存在时创建。
+    包含了 baidu-map, Context7, 和 sequential-thinking.
+    """
+    return {
+        "mcpServers": {
+            "baidu-map": {
+                "command": "npx",
+                "args": ["-y", "@baidumap/mcp-server-baidu-map"],
+                "env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
+                "description": "百度地图工具,提供地理编码、路线规划等功能。",
+            },
+            "sequential-thinking": {
+                "command": "npx",
+                "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
+                "description": "顺序思维工具,用于帮助模型进行多步骤推理。",
+            },
+            "Context7": {
+                "command": "npx",
+                "args": ["-y", "@upstash/context7-mcp@latest"],
+                "description": "Upstash 提供的上下文管理和记忆工具。",
+            },
+        }
+    }

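The defaults above are written to `data/llm/mcp_tools.json` only when that file is missing; afterwards the JSON file is the source of truth. A minimal sketch of appending a custom server entry to it (the `my-server` name and its npm package are hypothetical):

    import json
    from pathlib import Path

    mcp_tools_path = Path("data/llm/mcp_tools.json")

    # Load the config created on first startup by get_llm_config().
    config = json.loads(mcp_tools_path.read_text(encoding="utf-8"))

    # Add a hypothetical MCP server; the keys mirror the default entries above.
    config["mcpServers"]["my-server"] = {
        "command": "npx",
        "args": ["-y", "@example/mcp-server-demo"],
        "description": "示例工具,仅作演示。",
    }

    mcp_tools_path.write_text(
        json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8"
    )
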
 def register_llm_configs():
     """注册 LLM 服务的配置项"""
     logger.info("注册 LLM 服务的配置项")
@@ -214,6 +285,19 @@ def register_llm_configs():
         help="LLM服务请求重试的基础延迟时间(秒)",
         type=int,
     )
+    Config.add_plugin_config(
+        AI_CONFIG_GROUP,
+        "gemini_safety_threshold",
+        "BLOCK_MEDIUM_AND_ABOVE",
+        help=(
+            "Gemini 安全过滤阈值 "
+            "(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
+            "BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, "
+            "BLOCK_ONLY_HIGH: 只阻止高级别, "
+            "BLOCK_NONE: 不阻止)"
+        ),
+        type=str,
+    )

     Config.add_plugin_config(
         AI_CONFIG_GROUP,
@@ -225,24 +309,111 @@ def register_llm_configs():
     )

+@lru_cache(maxsize=1)
 def get_llm_config() -> LLMConfig:
-    """获取 LLM 配置实例
-
-    返回:
-        LLMConfig: LLM 配置实例
-    """
+    """获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
     ai_config = get_ai_config()

+    llm_data_path = DATA_PATH / "llm"
+    mcp_tools_path = llm_data_path / "mcp_tools.json"
+
+    mcp_tools_list = []
+    mcp_servers_dict = {}
+
+    if not mcp_tools_path.exists():
+        logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
+        llm_data_path.mkdir(parents=True, exist_ok=True)
+        default_mcp_config = get_default_mcp_tools()
+        try:
+            with mcp_tools_path.open("w", encoding="utf-8") as f:
+                json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
+            mcp_servers_dict = default_mcp_config.get("mcpServers", {})
+        except Exception as e:
+            logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
+            mcp_servers_dict = {}
+    else:
+        try:
+            with mcp_tools_path.open("r", encoding="utf-8") as f:
+                mcp_data = json.load(f)
+            mcp_servers_dict = mcp_data.get("mcpServers", {})
+            if not isinstance(mcp_servers_dict, dict):
+                logger.warning(
+                    f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
+                    f"将使用空配置。"
+                )
+                mcp_servers_dict = {}
+
+        except json.JSONDecodeError as e:
+            logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
+        except Exception as e:
+            logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
+            mcp_servers_dict = {}
+
+    if sys.platform == "win32":
+        logger.debug("检测到Windows平台,正在调整MCP工具的npx命令...")
+        for name, config in mcp_servers_dict.items():
+            if isinstance(config, dict) and config.get("command") == "npx":
+                logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
+                original_args = config.get("args", [])
+                config["command"] = "cmd"
+                config["args"] = ["/c", "npx", *original_args]

+    if mcp_servers_dict:
+        mcp_tools_list = [
+            {
+                "name": name,
+                "type": "mcp",
+                "description": config.get("description", f"MCP tool for {name}"),
+                "mcp_config": config,
+            }
+            for name, config in mcp_servers_dict.items()
+            if isinstance(config, dict)
+        ]
+
+    from ..tools.registry import tool_registry
+
+    for tool_dict in mcp_tools_list:
+        if isinstance(tool_dict, dict):
+            tool_name = tool_dict.get("name")
+            if not tool_name:
+                continue
+
+            config_model = tool_registry.get_mcp_config_model(tool_name)
+            if not config_model:
+                logger.debug(
+                    f"MCP工具 '{tool_name}' 没有注册其配置模型,"
+                    f"将跳过特定配置验证,直接使用原始配置字典。"
+                )
+                continue
+
+            mcp_config_data = tool_dict.get("mcp_config", {})
+            try:
+                parsed_mcp_config = parse_as(config_model, mcp_config_data)
+                tool_dict["mcp_config"] = parsed_mcp_config
+            except Exception as e:
+                raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
+
     config_data = {
         "default_model_name": ai_config.get("default_model_name"),
         "proxy": ai_config.get("proxy"),
         "timeout": ai_config.get("timeout", 180),
         "max_retries_llm": ai_config.get("max_retries_llm", 3),
         "retry_delay_llm": ai_config.get("retry_delay_llm", 2),
-        "providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
+        PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
+        "mcp_tools": mcp_tools_list,
     }

-    return LLMConfig(**config_data)
+    return parse_as(LLMConfig, config_data)

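On Windows, `npx` is a `.cmd` shim that the MCP stdio transport cannot spawn directly, which is why the loader above rewrites the command through `cmd /c`. A minimal sketch of that transformation in isolation:

    import sys


    def wrap_npx_for_windows(server: dict) -> dict:
        """Mirrors the adjustment in get_llm_config: on win32 an
        {"command": "npx", ...} entry becomes cmd /c npx <args>."""
        if sys.platform == "win32" and server.get("command") == "npx":
            return {
                **server,
                "command": "cmd",
                "args": ["/c", "npx", *server.get("args", [])],
            }
        return server


    # {"command": "npx", "args": ["-y", "@upstash/context7-mcp@latest"]}
    # becomes, on Windows:
    # {"command": "cmd", "args": ["/c", "npx", "-y", "@upstash/context7-mcp@latest"]}
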
+
+
+def get_gemini_safety_threshold() -> str:
+    """获取 Gemini 安全过滤阈值配置
+
+    返回:
+        str: 安全过滤阈值
+    """
+    ai_config = get_ai_config()
+    return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE")
+
+
 def validate_llm_config() -> tuple[bool, list[str]]:
@@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
         logger.info("默认模型已清除")

     return True
+
+
+@PriorityLifecycle.on_startup(priority=10)
+async def _init_llm_config_on_startup():
+    """
+    在服务启动时主动调用一次 get_llm_config,
+    以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
+    """
+    logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
+    try:
+        get_llm_config()
+        logger.info("LLM 配置初始化完成。")
+    except Exception as e:
+        logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)

@@ -49,12 +49,36 @@ class LLMHttpClient:
                 max_keepalive_connections=self.config.max_keepalive_connections,
             )
             timeout = httpx.Timeout(self.config.timeout)
+
+            client_kwargs = {}
+            if self.config.proxy:
+                try:
+                    version_parts = httpx.__version__.split(".")
+                    major = int(
+                        "".join(c for c in version_parts[0] if c.isdigit())
+                    )
+                    minor = (
+                        int("".join(c for c in version_parts[1] if c.isdigit()))
+                        if len(version_parts) > 1
+                        else 0
+                    )
+                    if (major, minor) >= (0, 28):
+                        client_kwargs["proxy"] = self.config.proxy
+                    else:
+                        client_kwargs["proxies"] = self.config.proxy
+                except (ValueError, IndexError):
+                    client_kwargs["proxies"] = self.config.proxy
+                    logger.warning(
+                        f"无法解析 httpx 版本 '{httpx.__version__}',"
+                        "LLM模块将默认使用旧版 'proxies' 参数语法。"
+                    )
+
             self._client = httpx.AsyncClient(
                 headers=headers,
                 limits=limits,
                 timeout=timeout,
-                proxies=self.config.proxy,
                 follow_redirects=True,
+                **client_kwargs,
             )
         if self._client is None:
             raise LLMException(
@@ -156,7 +180,16 @@ async def create_llm_http_client(
     timeout: int = 180,
     proxy: str | None = None,
 ) -> LLMHttpClient:
-    """创建LLM HTTP客户端"""
+    """
+    创建LLM HTTP客户端
+
+    参数:
+        timeout: 超时时间(秒)。
+        proxy: 代理服务器地址。
+
+    返回:
+        LLMHttpClient: HTTP客户端实例。
+    """
     config = HttpClientConfig(timeout=timeout, proxy=proxy)
     return LLMHttpClient(config)

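A usage sketch for the factory above. The import location and the cleanup call are assumptions (only the function's own module is visible in this hunk), and the proxy URL is a placeholder:

    import asyncio

    # Assumed re-export location for the factory.
    from zhenxun.services.llm.core import create_llm_http_client


    async def main() -> None:
        client = await create_llm_http_client(timeout=60, proxy="http://127.0.0.1:7890")
        try:
            ...  # issue requests via client.post(url, headers=..., json=...)
        finally:
            await client.close()  # assumed cleanup method, not shown in this hunk


    asyncio.run(main())
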
@@ -185,7 +218,20 @@ async def with_smart_retry(
     provider_name: str | None = None,
     **kwargs: Any,
 ) -> Any:
-    """智能重试装饰器 - 支持Key轮询和错误分类"""
+    """
+    智能重试装饰器 - 支持Key轮询和错误分类
+
+    参数:
+        func: 要重试的异步函数。
+        *args: 传递给函数的位置参数。
+        retry_config: 重试配置。
+        key_store: API密钥状态存储。
+        provider_name: 提供商名称。
+        **kwargs: 传递给函数的关键字参数。
+
+    返回:
+        Any: 函数执行结果。
+    """
     config = retry_config or RetryConfig()
     last_exception: Exception | None = None
     failed_keys: set[str] = set()
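Despite the docstring calling it a decorator, `with_smart_retry` is invoked as a wrapper around a single coroutine. A sketch, assuming it, `RetryConfig`, and the shared `key_store` are importable from the core package and that the wrapper forwards a `failed_keys` set to each attempt; the `fetch_completion` coroutine is hypothetical:

    from zhenxun.services.llm.core import RetryConfig, key_store, with_smart_retry


    async def fetch_completion(prompt: str, failed_keys: set[str] | None = None) -> str:
        """Hypothetical single attempt; failed_keys lets it skip exhausted keys."""
        ...


    async def demo() -> str:
        return await with_smart_retry(
            fetch_completion,
            "你好",
            retry_config=RetryConfig(max_retries=3, retry_delay=2),
            key_store=key_store,
            provider_name="DeepSeek",
        )
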
@@ -294,7 +340,17 @@ class KeyStatusStore:
         api_keys: list[str],
         exclude_keys: set[str] | None = None,
     ) -> str | None:
-        """获取下一个可用的API密钥(轮询策略)"""
+        """
+        获取下一个可用的API密钥(轮询策略)
+
+        参数:
+            provider_name: 提供商名称。
+            api_keys: API密钥列表。
+            exclude_keys: 要排除的密钥集合。
+
+        返回:
+            str | None: 可用的API密钥,如果没有可用密钥则返回None。
+        """
         if not api_keys:
             return None

@@ -338,7 +394,13 @@ class KeyStatusStore:
         logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")

     async def record_failure(self, api_key: str, status_code: int | None):
-        """记录失败使用"""
+        """
+        记录失败使用
+
+        参数:
+            api_key: API密钥。
+            status_code: HTTP状态码。
+        """
         key_id = self._get_key_id(api_key)
         async with self._lock:
             if status_code in [401, 403]:
@@ -356,7 +418,15 @@ class KeyStatusStore:
         logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")

     async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
-        """获取密钥使用统计"""
+        """
+        获取密钥使用统计
+
+        参数:
+            api_keys: API密钥列表。
+
+        返回:
+            dict[str, dict]: 密钥统计信息字典。
+        """
         stats = {}
         async with self._lock:
             for key in api_keys:

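Taken together, the store reduces key rotation to a small call pattern. A sketch of the expected usage; the selection method's name is an assumption (this hunk shows only its parameters and docstring) and the key list is illustrative:

    from zhenxun.services.llm.core import key_store

    API_KEYS = ["sk-key-one", "sk-key-two"]  # illustrative


    async def call_with_rotation() -> None:
        # Assumed method name; the diff shows provider_name/api_keys/exclude_keys
        # parameters and a round-robin ("轮询策略") docstring.
        api_key = await key_store.get_next_available_key("DeepSeek", API_KEYS)
        if api_key is None:
            raise RuntimeError("没有可用的API密钥")
        try:
            ...  # perform the request with api_key
            await key_store.record_success(api_key)
        except Exception:
            await key_store.record_failure(api_key, status_code=429)
            raise
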
@@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf
 from .core import http_client_manager, key_store
 from .service import LLMModel
 from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
+from .types.capabilities import get_model_capabilities

 DEFAULT_MODEL_NAME_KEY = "default_model_name"
 PROXY_KEY = "proxy"
@@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:


 def get_configured_providers() -> list[ProviderConfig]:
-    """从配置中获取Provider列表 - 简化版本"""
+    """从配置中获取Provider列表 - 简化和修正版本"""
     ai_config = get_ai_config()
-    providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
-    if not isinstance(providers_raw, list):
+    providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
+
+    if not isinstance(providers, list):
         logger.error(
-            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
+            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
             f"将使用空列表。"
         )
         return []

     valid_providers = []
-    for i, item in enumerate(providers_raw):
-        if not isinstance(item, dict):
-            logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
-            continue
-
-        try:
-            if not item.get("name"):
-                logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
-                continue
-
-            if not item.get("api_key"):
-                logger.warning(
-                    f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
-                )
-                continue
-
-            if "api_type" not in item or not item["api_type"]:
-                provider_name = item.get("name", "").lower()
-                if "glm" in provider_name or "zhipu" in provider_name:
-                    item["api_type"] = "zhipu"
-                elif "gemini" in provider_name or "google" in provider_name:
-                    item["api_type"] = "gemini"
-                else:
-                    item["api_type"] = "openai"
-
-            if "api_base" not in item or not item["api_base"]:
-                api_type = item.get("api_type")
-                if api_type:
-                    default_api_base = get_default_api_base_for_type(api_type)
-                    if default_api_base:
-                        item["api_base"] = default_api_base
-
-            if "models" not in item:
-                item["models"] = [{"model_name": item.get("name", "default")}]
-
-            provider_conf = ProviderConfig(**item)
-            valid_providers.append(provider_conf)
-
-        except Exception as e:
-            logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
+    for i, item in enumerate(providers):
+        if isinstance(item, ProviderConfig):
+            if not item.api_base:
+                default_api_base = get_default_api_base_for_type(item.api_type)
+                if default_api_base:
+                    item.api_base = default_api_base
+            valid_providers.append(item)
+        else:
+            logger.warning(
+                f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
+                f"已跳过。实际类型: {type(item)}"
+            )

     return valid_providers

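The loop now assumes the configuration layer has already produced `ProviderConfig` objects, so the only normalization left is filling a missing `api_base`. A sketch of that step; the constructor keywords are inferred from the attribute accesses above and the values are illustrative:

    from zhenxun.services.llm.types import ProviderConfig

    provider = ProviderConfig(
        name="GLM",
        api_key="YOUR_GLM_API_KEY",
        api_type="zhipu",
        models=[{"model_name": "glm-4-flash"}],
    )

    if not provider.api_base:
        # In the real code this value comes from get_default_api_base_for_type().
        provider.api_base = "https://open.bigmodel.cn"
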
@@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]:
 def find_model_config(
     provider_name: str, model_name: str
 ) -> tuple[ProviderConfig, ModelDetail] | None:
-    """在配置中查找指定的 Provider 和 ModelDetail
+    """
+    在配置中查找指定的 Provider 和 ModelDetail

-    Args:
+    参数:
         provider_name: 提供商名称
         model_name: 模型名称

-    Returns:
-        找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None
+    返回:
+        tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None
     """
     providers = get_configured_providers()

@@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis


 def list_model_identifiers() -> dict[str, list[str]]:
-    """列出所有模型的可用标识符
+    """
+    列出所有模型的可用标识符

-    Returns:
-        字典,键为模型的完整名称,值为该模型的所有可用标识符列表
+    返回:
+        dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表
     """
     providers = get_configured_providers()
     result = {}
@@ -248,7 +224,16 @@ async def get_model_instance(
     provider_model_name: str | None = None,
     override_config: dict[str, Any] | None = None,
 ) -> LLMModel:
-    """根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
+    """
+    根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
+
+    参数:
+        provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
+        override_config: 覆盖配置字典。
+
+    返回:
+        LLMModel: 模型实例。
+    """
     cache_key = _make_cache_key(provider_model_name, override_config)
     cached_model = _get_cached_model(cache_key)
     if cached_model:
@@ -292,6 +277,10 @@ async def get_model_instance(

     provider_config_found, model_detail_found = config_tuple_found

+    capabilities = get_model_capabilities(model_detail_found.model_name)
+
+    model_detail_found.is_embedding_model = capabilities.is_embedding_model
+
     ai_config = get_ai_config()
     global_proxy_setting = ai_config.get(PROXY_KEY)
     default_timeout = (
@@ -322,6 +311,7 @@ async def get_model_instance(
         model_detail=model_detail_found,
         key_store=key_store,
         http_client=shared_http_client,
+        capabilities=capabilities,
     )

     if override_config:
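With capabilities resolved and attached at construction time, a caller gets a cached, capability-aware model in one call. A sketch, assuming `get_model_instance` is re-exported from the package root like the other manager helpers; the model identifier is illustrative:

    from zhenxun.services.llm import get_model_instance


    async def demo() -> None:
        # 'ProviderName/ModelName' format, per the docstring above.
        model = await get_model_instance("Gemini/gemini-2.0-flash")
        if model.can_process_images():
            ...  # safe to attach image content parts
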
@@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None:


 def set_global_default_model_name(provider_model_name: str | None) -> bool:
-    """设置全局默认模型名称"""
+    """
+    设置全局默认模型名称
+
+    参数:
+        provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。
+
+    返回:
+        bool: 设置是否成功。
+    """
     if provider_model_name:
         prov_name, mod_name = parse_provider_model_string(provider_model_name)
         if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
@@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:


 async def get_key_usage_stats() -> dict[str, Any]:
-    """获取所有Provider的Key使用统计"""
+    """
+    获取所有Provider的Key使用统计
+
+    返回:
+        dict[str, Any]: 包含所有Provider的Key使用统计信息。
+    """
     providers = get_configured_providers()
     stats = {}

@@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]:


 async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
-    """重置指定Provider的Key状态"""
+    """
+    重置指定Provider的Key状态
+
+    参数:
+        provider_name: 提供商名称。
+        api_key: 要重置的特定API密钥,如果为None则重置所有密钥。
+
+    返回:
+        bool: 重置是否成功。
+    """
     providers = get_configured_providers()
     target_provider = None

@@ -6,11 +6,13 @@ LLM 模型实现类

 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Callable
+from contextlib import AsyncExitStack
 import json
 from typing import Any

 from zhenxun.services.log import logger

+from .adapters.base import RequestData
 from .config import LLMGenerationConfig
 from .config.providers import get_ai_config
 from .core import (
@@ -30,6 +32,8 @@ from .types import (
     ModelDetail,
     ProviderConfig,
 )
+from .types.capabilities import ModelCapabilities, ModelModality
+from .utils import _sanitize_request_body_for_logging


 class LLMModelBase(ABC):
@@ -42,7 +46,17 @@ class LLMModelBase(ABC):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> str:
-        """生成文本"""
+        """
+        生成文本
+
+        参数:
+            prompt: 输入提示词。
+            history: 对话历史记录。
+            **kwargs: 其他参数。
+
+        返回:
+            str: 生成的文本。
+        """
         pass

     @abstractmethod
@@ -54,7 +68,19 @@ class LLMModelBase(ABC):
         tool_choice: str | dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> LLMResponse:
-        """生成高级响应"""
+        """
+        生成高级响应
+
+        参数:
+            messages: 消息列表。
+            config: 生成配置。
+            tools: 工具列表。
+            tool_choice: 工具选择策略。
+            **kwargs: 其他参数。
+
+        返回:
+            LLMResponse: 模型响应。
+        """
         pass

     @abstractmethod
@@ -64,7 +90,17 @@ class LLMModelBase(ABC):
         task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
         **kwargs: Any,
     ) -> list[list[float]]:
-        """生成文本嵌入向量"""
+        """
+        生成文本嵌入向量
+
+        参数:
+            texts: 文本列表。
+            task_type: 嵌入任务类型。
+            **kwargs: 其他参数。
+
+        返回:
+            list[list[float]]: 嵌入向量列表。
+        """
         pass


@@ -77,12 +113,14 @@ class LLMModel(LLMModelBase):
         model_detail: ModelDetail,
         key_store: KeyStatusStore,
         http_client: LLMHttpClient,
+        capabilities: ModelCapabilities,
         config_override: LLMGenerationConfig | None = None,
     ):
         self.provider_config = provider_config
         self.model_detail = model_detail
         self.key_store = key_store
         self.http_client: LLMHttpClient = http_client
+        self.capabilities = capabilities
         self._generation_config = config_override

         self.provider_name = provider_config.name
@@ -99,6 +137,34 @@ class LLMModel(LLMModelBase):

         self._is_closed = False

+    def can_process_images(self) -> bool:
+        """检查模型是否支持图片作为输入。"""
+        return ModelModality.IMAGE in self.capabilities.input_modalities
+
+    def can_process_video(self) -> bool:
+        """检查模型是否支持视频作为输入。"""
+        return ModelModality.VIDEO in self.capabilities.input_modalities
+
+    def can_process_audio(self) -> bool:
+        """检查模型是否支持音频作为输入。"""
+        return ModelModality.AUDIO in self.capabilities.input_modalities
+
+    def can_generate_images(self) -> bool:
+        """检查模型是否支持生成图片。"""
+        return ModelModality.IMAGE in self.capabilities.output_modalities
+
+    def can_generate_audio(self) -> bool:
+        """检查模型是否支持生成音频 (TTS)。"""
+        return ModelModality.AUDIO in self.capabilities.output_modalities
+
+    def can_use_tools(self) -> bool:
+        """检查模型是否支持工具调用/函数调用。"""
+        return self.capabilities.supports_tool_calling
+
+    def is_embedding_model(self) -> bool:
+        """检查这是否是一个嵌入模型。"""
+        return self.capabilities.is_embedding_model
+
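These predicates let callers fail fast before building a multimodal request. A sketch; the keyword arguments of `create_multimodal_message` are assumptions:

    from zhenxun.services.llm import create_multimodal_message, get_model_instance


    async def describe_image(image_bytes: bytes) -> str:
        model = await get_model_instance("Gemini/gemini-2.0-flash")  # illustrative id
        if not model.can_process_images():
            raise ValueError(f"{model.provider_name} 模型不支持图片输入")
        message = create_multimodal_message(text="描述这张图片", images=[image_bytes])
        response = await model.generate_response([message])
        return response.text or ""
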
     async def _get_http_client(self) -> LLMHttpClient:
         """获取HTTP客户端"""
         if self.http_client.is_closed:
@@ -135,24 +201,54 @@ class LLMModel(LLMModelBase):

         return selected_key

-    async def _execute_embedding_request(
+    async def _perform_api_call(
         self,
-        adapter,
-        texts: list[str],
-        task_type: EmbeddingTaskType | str,
-        http_client: LLMHttpClient,
+        prepare_request_func: Callable[[str], Awaitable["RequestData"]],
+        parse_response_func: Callable[[dict[str, Any]], Any],
+        http_client: "LLMHttpClient",
         failed_keys: set[str] | None = None,
-    ) -> list[list[float]]:
-        """执行单次嵌入请求 - 供重试机制调用"""
+        log_context: str = "API",
+    ) -> Any:
+        """
+        执行API调用的通用核心方法。
+
+        该方法封装了以下通用逻辑:
+        1. 选择API密钥。
+        2. 准备和记录请求。
+        3. 发送HTTP POST请求。
+        4. 处理HTTP错误和API特定错误。
+        5. 记录密钥使用状态。
+        6. 解析成功的响应。
+
+        参数:
+            prepare_request_func: 准备请求的函数。
+            parse_response_func: 解析响应的函数。
+            http_client: HTTP客户端。
+            failed_keys: 失败的密钥集合。
+            log_context: 日志上下文。
+
+        返回:
+            Any: 解析后的响应数据。
+        """
         api_key = await self._select_api_key(failed_keys)

         try:
-            request_data = adapter.prepare_embedding_request(
-                model=self,
-                api_key=api_key,
-                texts=texts,
-                task_type=task_type,
-            )
+            request_data = await prepare_request_func(api_key)
+
+            logger.info(
+                f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} "
+                f"[{log_context}]"
+            )
+            logger.debug(f"📡 请求URL: {request_data.url}")
+            masked_key = (
+                f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}"
+            )
+            logger.debug(f"🔑 API密钥: {masked_key}")
+            logger.debug(f"📋 请求头: {dict(request_data.headers)}")
+
+            sanitized_body = _sanitize_request_body_for_logging(request_data.body)
+            request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
+            logger.debug(f"📦 请求体: {request_body_str}")
+
             http_response = await http_client.post(
                 request_data.url,
@@ -160,121 +256,16 @@ class LLMModel(LLMModelBase):
                 json=request_data.body,
             )

-            if http_response.status_code != 200:
-                error_text = http_response.text
-                logger.error(
-                    f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
-                )
-                await self.key_store.record_failure(api_key, http_response.status_code)
-
-                error_code = LLMErrorCode.API_REQUEST_FAILED
-                if http_response.status_code in [401, 403]:
-                    error_code = LLMErrorCode.API_KEY_INVALID
-                elif http_response.status_code == 429:
-                    error_code = LLMErrorCode.API_RATE_LIMITED
-
-                raise LLMException(
-                    f"HTTP嵌入请求失败: {http_response.status_code}",
-                    code=error_code,
-                    details={
-                        "status_code": http_response.status_code,
-                        "response": error_text,
-                        "api_key": api_key,
-                    },
-                )
-
-            try:
-                response_json = http_response.json()
-                adapter.validate_embedding_response(response_json)
-                embeddings = adapter.parse_embedding_response(response_json)
-            except Exception as e:
-                logger.error(f"解析嵌入响应失败: {e}", e=e)
-                await self.key_store.record_failure(api_key, None)
-                if isinstance(e, LLMException):
-                    raise
-                else:
-                    raise LLMException(
-                        f"解析API嵌入响应失败: {e}",
-                        code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                        cause=e,
-                    )
-
-            await self.key_store.record_success(api_key)
-            return embeddings
-
-        except LLMException:
-            raise
-        except Exception as e:
-            logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
-            await self.key_store.record_failure(api_key, None)
-            raise LLMException(
-                f"生成嵌入失败: {e}",
-                code=LLMErrorCode.EMBEDDING_FAILED,
-                cause=e,
-            )
-
-    async def _execute_with_smart_retry(
-        self,
-        adapter,
-        messages: list[LLMMessage],
-        config: LLMGenerationConfig | None,
-        tools_dict: list[dict[str, Any]] | None,
-        tool_choice: str | dict[str, Any] | None,
-        http_client: LLMHttpClient,
-    ):
-        """智能重试机制 - 使用统一的重试装饰器"""
-        ai_config = get_ai_config()
-        max_retries = ai_config.get("max_retries_llm", 3)
-        retry_delay = ai_config.get("retry_delay_llm", 2)
-        retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
-
-        return await with_smart_retry(
-            self._execute_single_request,
-            adapter,
-            messages,
-            config,
-            tools_dict,
-            tool_choice,
-            http_client,
-            retry_config=retry_config,
-            key_store=self.key_store,
-            provider_name=self.provider_name,
-        )
-
-    async def _execute_single_request(
-        self,
-        adapter,
-        messages: list[LLMMessage],
-        config: LLMGenerationConfig | None,
-        tools_dict: list[dict[str, Any]] | None,
-        tool_choice: str | dict[str, Any] | None,
-        http_client: LLMHttpClient,
-        failed_keys: set[str] | None = None,
-    ) -> LLMResponse:
-        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
-        api_key = await self._select_api_key(failed_keys)
-
-        try:
-            request_data = adapter.prepare_advanced_request(
-                model=self,
-                api_key=api_key,
-                messages=messages,
-                config=config,
-                tools=tools_dict,
-                tool_choice=tool_choice,
-            )
-
-            http_response = await http_client.post(
-                request_data.url,
-                headers=request_data.headers,
-                json=request_data.body,
-            )
-
+            logger.debug(f"📥 响应状态码: {http_response.status_code}")
+            logger.debug(f"📄 响应头: {dict(http_response.headers)}")

             if http_response.status_code != 200:
                 error_text = http_response.text
                 logger.error(
-                    f"HTTP请求失败: {http_response.status_code} - {error_text}"
+                    f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
+                    f"[{log_context}]"
                 )
+                logger.debug(f"💥 完整错误响应: {error_text}")

                 await self.key_store.record_failure(api_key, http_response.status_code)

@@ -299,12 +290,129 @@ class LLMModel(LLMModelBase):

             try:
                 response_json = http_response.json()
+                response_json_str = json.dumps(
+                    response_json, ensure_ascii=False, indent=2
+                )
+                logger.debug(f"📋 响应JSON: {response_json_str}")
+                parsed_data = parse_response_func(response_json)
+
+            except Exception as e:
+                logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
+                await self.key_store.record_failure(api_key, None)
+                if isinstance(e, LLMException):
+                    raise
+                else:
+                    raise LLMException(
+                        f"解析API {log_context} 响应失败: {e}",
+                        code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                        cause=e,
+                    )
+
+            await self.key_store.record_success(api_key)
+            logger.debug(f"✅ API密钥使用成功: {masked_key}")
+            logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
+            return parsed_data
+
+        except LLMException:
+            raise
+        except Exception as e:
+            error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
+            logger.error(error_log_msg, e=e)
+            await self.key_store.record_failure(api_key, None)
+            raise LLMException(
+                error_log_msg,
+                code=LLMErrorCode.GENERATION_FAILED
+                if log_context == "Generation"
+                else LLMErrorCode.EMBEDDING_FAILED,
+                cause=e,
+            )
+
+    async def _execute_embedding_request(
+        self,
+        adapter,
+        texts: list[str],
+        task_type: EmbeddingTaskType | str,
+        http_client: LLMHttpClient,
+        failed_keys: set[str] | None = None,
+    ) -> list[list[float]]:
+        """执行单次嵌入请求 - 供重试机制调用"""
+
+        async def prepare_request(api_key: str) -> RequestData:
+            return adapter.prepare_embedding_request(
+                model=self,
+                api_key=api_key,
+                texts=texts,
+                task_type=task_type,
+            )
+
+        def parse_response(response_json: dict[str, Any]) -> list[list[float]]:
+            adapter.validate_embedding_response(response_json)
+            return adapter.parse_embedding_response(response_json)
+
+        return await self._perform_api_call(
+            prepare_request_func=prepare_request,
+            parse_response_func=parse_response,
+            http_client=http_client,
+            failed_keys=failed_keys,
+            log_context="Embedding",
+        )
+
+    async def _execute_with_smart_retry(
+        self,
+        adapter,
+        messages: list[LLMMessage],
+        config: LLMGenerationConfig | None,
+        tools: list[LLMTool] | None,
+        tool_choice: str | dict[str, Any] | None,
+        http_client: LLMHttpClient,
+    ):
+        """智能重试机制 - 使用统一的重试装饰器"""
+        ai_config = get_ai_config()
+        max_retries = ai_config.get("max_retries_llm", 3)
+        retry_delay = ai_config.get("retry_delay_llm", 2)
+        retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
+
+        return await with_smart_retry(
+            self._execute_single_request,
+            adapter,
+            messages,
+            config,
+            tools,
+            tool_choice,
+            http_client,
+            retry_config=retry_config,
+            key_store=self.key_store,
+            provider_name=self.provider_name,
+        )
+
+    async def _execute_single_request(
+        self,
+        adapter,
+        messages: list[LLMMessage],
+        config: LLMGenerationConfig | None,
+        tools: list[LLMTool] | None,
+        tool_choice: str | dict[str, Any] | None,
+        http_client: LLMHttpClient,
+        failed_keys: set[str] | None = None,
+    ) -> LLMResponse:
+        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
+
+        async def prepare_request(api_key: str) -> RequestData:
+            return await adapter.prepare_advanced_request(
+                model=self,
+                api_key=api_key,
+                messages=messages,
+                config=config,
+                tools=tools,
+                tool_choice=tool_choice,
+            )
+
+        def parse_response(response_json: dict[str, Any]) -> LLMResponse:
             response_data = adapter.parse_response(
                 model=self,
                 response_json=response_json,
                 is_advanced=True,
             )

             from .types.models import LLMToolCall

             response_tool_calls = []
@@ -323,7 +431,7 @@ class LLMModel(LLMModelBase):
                 else:
                     logger.warning(f"工具调用数据格式未知: {tc_data}")

-            llm_response = LLMResponse(
+            return LLMResponse(
                 text=response_data.text,
                 usage_info=response_data.usage_info,
                 raw_response=response_data.raw_response,
@@ -333,33 +441,12 @@ class LLMModel(LLMModelBase):
                 cache_info=response_data.cache_info,
             )

-        except Exception as e:
-            logger.error(f"解析响应失败: {e}", e=e)
-            await self.key_store.record_failure(api_key, None)
-
-            if isinstance(e, LLMException):
-                raise
-            else:
-                raise LLMException(
-                    f"解析API响应失败: {e}",
-                    code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                    cause=e,
-                )
-
-        await self.key_store.record_success(api_key)
-
-        return llm_response
-
-        except LLMException:
-            raise
-        except Exception as e:
-            logger.error(f"生成响应时发生未预期错误: {e}", e=e)
-            await self.key_store.record_failure(api_key, None)
-            raise LLMException(
-                f"生成响应失败: {e}",
-                code=LLMErrorCode.GENERATION_FAILED,
-                cause=e,
-            )
+        return await self._perform_api_call(
+            prepare_request_func=prepare_request,
+            parse_response_func=parse_response,
+            http_client=http_client,
+            failed_keys=failed_keys,
+            log_context="Generation",
+        )

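After this refactor each request type is a pair of callbacks handed to `_perform_api_call`, which owns key selection, request logging, error mapping, and key-state bookkeeping. A sketch of how a hypothetical new request type would plug into the same seam, written as an extra method for `LLMModel`; the endpoint and the response field are invented:

    from typing import Any


    async def _count_tokens(self, text: str, http_client: Any) -> int:
        """Hypothetical request type reusing the shared call pipeline."""

        async def prepare_request(api_key: str) -> Any:
            ...  # build RequestData(url=..., headers=..., body=...) here

        def parse_response(response_json: dict[str, Any]) -> int:
            return int(response_json["total_tokens"])  # assumed field

        return await self._perform_api_call(
            prepare_request_func=prepare_request,
            parse_response_func=parse_response,
            http_client=http_client,
            log_context="TokenCount",
        )
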
     async def close(self):
@@ -400,7 +487,17 @@ class LLMModel(LLMModelBase):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> str:
-        """生成文本 - 通过 generate_response 实现"""
+        """
+        生成文本 - 通过 generate_response 实现
+
+        参数:
+            prompt: 输入提示词。
+            history: 对话历史记录。
+            **kwargs: 其他参数。
+
+        返回:
+            str: 生成的文本。
+        """
         self._check_not_closed()

         messages: list[LLMMessage] = []
@@ -439,11 +536,21 @@ class LLMModel(LLMModelBase):
         config: LLMGenerationConfig | None = None,
         tools: list[LLMTool] | None = None,
         tool_choice: str | dict[str, Any] | None = None,
-        tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
-        max_tool_iterations: int = 5,
         **kwargs: Any,
     ) -> LLMResponse:
-        """生成高级响应 - 实现完整的工具调用循环"""
+        """
+        生成高级响应
+
+        参数:
+            messages: 消息列表。
+            config: 生成配置。
+            tools: 工具列表。
+            tool_choice: 工具选择策略。
+            **kwargs: 其他参数。
+
+        返回:
+            LLMResponse: 模型响应。
+        """
         self._check_not_closed()

         from .adapters import get_adapter_for_api_type
@@ -468,117 +575,61 @@ class LLMModel(LLMModelBase):
             merged_dict.update(config.to_dict())
             final_request_config = LLMGenerationConfig(**merged_dict)

-        tools_dict: list[dict[str, Any]] | None = None
-        if tools:
-            tools_dict = []
-            for tool in tools:
-                if hasattr(tool, "model_dump"):
-                    model_dump_func = getattr(tool, "model_dump")
-                    tools_dict.append(model_dump_func(exclude_none=True))
-                elif isinstance(tool, dict):
-                    tools_dict.append(tool)
-                else:
-                    try:
-                        tools_dict.append(dict(tool))
-                    except (TypeError, ValueError):
-                        logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
-
         http_client = await self._get_http_client()
-        current_messages = list(messages)

-        for iteration in range(max_tool_iterations):
-            logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
+        async with AsyncExitStack() as stack:
+            activated_tools = []
+            if tools:
+                for tool in tools:
+                    if tool.type == "mcp" and callable(tool.mcp_session):
+                        func_obj = getattr(tool.mcp_session, "func", None)
+                        tool_name = (
+                            getattr(func_obj, "__name__", "unknown")
+                            if func_obj
+                            else "unknown"
+                        )
+                        logger.debug(f"正在激活 MCP 工具会话: {tool_name}")

+                        active_session = await stack.enter_async_context(
+                            tool.mcp_session()
+                        )

+                        activated_tools.append(
+                            LLMTool.from_mcp_session(
+                                session=active_session, annotations=tool.annotations
+                            )
+                        )
+                    else:
+                        activated_tools.append(tool)

             llm_response = await self._execute_with_smart_retry(
                 adapter,
-                current_messages,
+                messages,
                 final_request_config,
-                tools_dict if iteration == 0 else None,
-                tool_choice if iteration == 0 else None,
+                activated_tools if activated_tools else None,
+                tool_choice,
                 http_client,
             )

-            response_tool_calls = llm_response.tool_calls or []
-
-            if not response_tool_calls or not tool_executor:
-                logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
             return llm_response

-            logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
-
-            assistant_message_content = llm_response.text if llm_response.text else ""
-            current_messages.append(
-                LLMMessage.assistant_tool_calls(
-                    content=assistant_message_content, tool_calls=response_tool_calls
-                )
-            )
-
-            tool_response_messages: list[LLMMessage] = []
-            for tool_call in response_tool_calls:
-                tool_name = tool_call.function.name
-                try:
-                    tool_args_dict = json.loads(tool_call.function.arguments)
-                    logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
-
-                    tool_result = await tool_executor(tool_name, tool_args_dict)
-                    logger.debug(
-                        f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
-                    )
-
-                    tool_response_messages.append(
-                        LLMMessage.tool_response(
-                            tool_call_id=tool_call.id,
-                            function_name=tool_name,
-                            result=tool_result,
-                        )
-                    )
-                except json.JSONDecodeError as e:
-                    logger.error(
-                        f"工具 '{tool_name}' 参数JSON解析失败: "
-                        f"{tool_call.function.arguments}, 错误: {e}"
-                    )
-                    tool_response_messages.append(
-                        LLMMessage.tool_response(
-                            tool_call_id=tool_call.id,
-                            function_name=tool_name,
-                            result={
-                                "error": "Argument JSON parsing failed",
-                                "details": str(e),
-                            },
-                        )
-                    )
-                except Exception as e:
-                    logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
-                    tool_response_messages.append(
-                        LLMMessage.tool_response(
-                            tool_call_id=tool_call.id,
-                            function_name=tool_name,
-                            result={
-                                "error": "Tool execution failed",
-                                "details": str(e),
-                            },
-                        )
-                    )
-
-            current_messages.extend(tool_response_messages)
-
-        logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
-        raise LLMException(
-            "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
-            code=LLMErrorCode.GENERATION_FAILED,
-            details={
-                "iterations": max_tool_iterations,
-                "last_messages": current_messages[-2:],
-            },
-        )

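The `AsyncExitStack` above keeps every activated MCP session open for exactly the duration of the request, however many tools are attached, and closes them in reverse order even when an error escapes. A minimal standalone sketch of the pattern:

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager


    @asynccontextmanager
    async def mcp_session(name: str):
        print(f"activate {name}")
        try:
            yield f"session:{name}"
        finally:
            print(f"close {name}")


    async def main() -> None:
        factories = [lambda: mcp_session("baidu-map"), lambda: mcp_session("Context7")]
        async with AsyncExitStack() as stack:
            sessions = [await stack.enter_async_context(f()) for f in factories]
            print(sessions)  # both sessions are live inside the block
        # leaving the block closes the sessions in reverse order


    asyncio.run(main())
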
     async def generate_embeddings(
         self,
         texts: list[str],
         task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
         **kwargs: Any,
     ) -> list[list[float]]:
-        """生成文本嵌入向量"""
+        """
+        生成文本嵌入向量
+
+        参数:
+            texts: 文本列表。
+            task_type: 嵌入任务类型。
+            **kwargs: 其他参数。
+
+        返回:
+            list[list[float]]: 嵌入向量列表。
+        """
         self._check_not_closed()
         if not texts:
             return []

zhenxun/services/llm/tools/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+"""
+工具模块导出
+"""
+
+from .registry import tool_registry
+
+__all__ = ["tool_registry"]

zhenxun/services/llm/tools/registry.py (new file, 181 lines)
@@ -0,0 +1,181 @@
+"""
+工具注册表
+
+负责加载、管理和实例化来自配置的工具。
+"""
+
+from collections.abc import Callable
+from contextlib import AbstractAsyncContextManager
+from functools import partial
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+
+from zhenxun.services.log import logger
+
+from ..types import LLMTool
+
+if TYPE_CHECKING:
+    from ..config.providers import ToolConfig
+    from ..types.protocols import MCPCompatible
+
+
+class ToolRegistry:
+    """工具注册表,用于管理和实例化配置的工具。"""
+
+    def __init__(self):
+        self._function_tools: dict[str, LLMTool] = {}
+        self._mcp_config_models: dict[str, type[BaseModel]] = {}
+        if TYPE_CHECKING:
+            self._mcp_factories: dict[
+                str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
+            ] = {}
+        else:
+            self._mcp_factories: dict[str, Callable] = {}
+
+        self._tool_configs: dict[str, "ToolConfig"] | None = None
+        self._tool_cache: dict[str, "LLMTool"] = {}
+
+    def _load_configs_if_needed(self):
+        """如果尚未加载,则从主配置中加载MCP工具定义。"""
+        if self._tool_configs is None:
+            logger.debug("首次访问,正在加载MCP工具配置...")
+            from ..config.providers import get_llm_config
+
+            llm_config = get_llm_config()
+            self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
+            logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
+
+    def function_tool(
+        self,
+        name: str,
+        description: str,
+        parameters: dict,
+        required: list[str] | None = None,
+    ):
+        """
+        装饰器:在代码中注册一个简单的、无状态的函数工具。
+
+        参数:
+            name: 工具的唯一名称。
+            description: 工具功能的描述。
+            parameters: OpenAPI格式的函数参数schema的properties部分。
+            required: 必需的参数列表。
+        """
+
+        def decorator(func: Callable):
+            if name in self._function_tools or name in self._mcp_factories:
+                logger.warning(f"正在覆盖已注册的工具: {name}")
+
+            tool_definition = LLMTool.create(
+                name=name,
+                description=description,
+                parameters=parameters,
+                required=required,
+            )
+            self._function_tools[name] = tool_definition
+            logger.info(f"已在代码中注册函数工具: '{name}'")
+            tool_definition.annotations = tool_definition.annotations or {}
+            tool_definition.annotations["executable"] = func
+            return func
+
+        return decorator
+
+    def mcp_tool(self, name: str, config_model: type[BaseModel]):
+        """
+        装饰器:注册一个MCP工具及其配置模型。
+
+        参数:
+            name: 工具的唯一名称,必须与配置文件中的名称匹配。
+            config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。
+        """
+
+        def decorator(factory_func: Callable):
+            if name in self._mcp_factories:
+                logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
+            self._mcp_factories[name] = factory_func
+            self._mcp_config_models[name] = config_model
+            logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
+            return factory_func
+
+        return decorator
+
+    def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
+        """根据名称获取MCP工具的配置模型。"""
+        return self._mcp_config_models.get(name)
+
+    def register_mcp_factory(
+        self,
+        name: str,
+        factory: Callable,
+    ):
+        """
+        在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。
+
+        参数:
+            name: 工具的唯一名称,必须与配置文件中的名称匹配。
+            factory: 一个返回异步生成器的可调用对象(会话工厂)。
+        """
+        if name in self._mcp_factories:
+            logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
+        self._mcp_factories[name] = factory
+        logger.info(f"已注册 MCP 会话工厂: '{name}'")
+
+    def get_tool(self, name: str) -> "LLMTool":
+        """
+        根据名称获取一个 LLMTool 定义。
+        对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
+        而不是一个已激活的会话。
+        """
+        logger.debug(f"🔍 请求获取工具定义: {name}")
+
+        if name in self._tool_cache:
+            logger.debug(f"✅ 从缓存中获取工具定义: {name}")
+            return self._tool_cache[name]
+
+        if name in self._function_tools:
+            logger.debug(f"🛠️ 获取函数工具定义: {name}")
+            tool = self._function_tools[name]
+            self._tool_cache[name] = tool
+            return tool
+
+        self._load_configs_if_needed()
+        if self._tool_configs is None or name not in self._tool_configs:
+            known_tools = list(self._function_tools.keys()) + (
+                list(self._tool_configs.keys()) if self._tool_configs else []
+            )
+            logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
+            logger.debug(f"📋 可用工具定义列表: {known_tools}")
+            raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
+
+        config = self._tool_configs[name]
+        tool: "LLMTool"
+
+        if name not in self._mcp_factories:
+            logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
+            available_factories = list(self._mcp_factories.keys())
+            logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
+            raise ValueError(
+                f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
+                "请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
+            )
+
+        logger.info(f"🔧 创建MCP工具定义: {name}")
+        factory = self._mcp_factories[name]
+        typed_mcp_config = config.mcp_config
+        logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
+
+        configured_factory = partial(factory, config=typed_mcp_config)
+        tool = LLMTool.from_mcp_session(session=configured_factory)
+
+        self._tool_cache[name] = tool
+        logger.debug(f"💾 MCP工具定义已缓存: {name}")
+        return tool
+
+    def get_tools(self, names: list[str]) -> list["LLMTool"]:
+        """根据名称列表获取多个 LLMTool 实例。"""
+        return [self.get_tool(name) for name in names]
+
+
+tool_registry = ToolRegistry()

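A usage sketch covering both registration paths; the tool names, parameter schema, config-model fields, and factory body are illustrative:

    from contextlib import asynccontextmanager

    from pydantic import BaseModel

    from zhenxun.services.llm import tool_registry


    @tool_registry.function_tool(
        name="get_weather",
        description="查询指定城市的天气",
        parameters={"city": {"type": "string", "description": "城市名称"}},
        required=["city"],
    )
    async def get_weather(city: str) -> dict:
        return {"city": city, "weather": "晴"}  # illustrative result


    class BaiduMapConfig(BaseModel):
        command: str
        args: list[str]
        env: dict[str, str] | None = None
        description: str | None = None


    @tool_registry.mcp_tool(name="baidu-map", config_model=BaiduMapConfig)
    @asynccontextmanager
    async def baidu_map_session(config: BaiduMapConfig):
        yield ...  # start and yield the MCP stdio session here


    tools = tool_registry.get_tools(["get_weather", "baidu-map"])
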
@@ -4,6 +4,7 @@ LLM 类型定义模块
 统一导出所有核心类型、协议和异常定义。
 """

+from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
 from .content import (
     LLMContentPart,
     LLMMessage,
@@ -26,6 +27,7 @@ from .models import (
     ToolMetadata,
     UsageInfo,
 )
+from .protocols import MCPCompatible

 __all__ = [
     "EmbeddingTaskType",
@@ -41,8 +43,11 @@ __all__ = [
     "LLMTool",
     "LLMToolCall",
     "LLMToolFunction",
+    "MCPCompatible",
+    "ModelCapabilities",
     "ModelDetail",
     "ModelInfo",
+    "ModelModality",
     "ModelName",
     "ModelProvider",
     "ProviderConfig",
@@ -50,5 +55,6 @@ __all__ = [
     "ToolCategory",
     "ToolMetadata",
     "UsageInfo",
+    "get_model_capabilities",
     "get_user_friendly_error_message",
 ]

zhenxun/services/llm/types/capabilities.py (new file, 128 lines)
@@ -0,0 +1,128 @@
"""
LLM 模型能力定义模块

定义模型的输入输出模态、工具调用支持等核心能力。
"""

import fnmatch
from enum import Enum

from pydantic import BaseModel, Field


class ModelModality(str, Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    EMBEDDING = "embedding"


class ModelCapabilities(BaseModel):
    """定义一个模型的核心、稳定能力。"""

    input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
    output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
    supports_tool_calling: bool = False
    is_embedding_model: bool = False


STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
    input_modalities={ModelModality.TEXT},
    output_modalities={ModelModality.TEXT},
    supports_tool_calling=True,
)

GEMINI_CAPABILITIES = ModelCapabilities(
    input_modalities={
        ModelModality.TEXT,
        ModelModality.IMAGE,
        ModelModality.AUDIO,
        ModelModality.VIDEO,
    },
    output_modalities={ModelModality.TEXT},
    supports_tool_calling=True,
)

DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
    input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
    output_modalities={ModelModality.TEXT},
    supports_tool_calling=True,
)


MODEL_ALIAS_MAPPING: dict[str, str] = {
    "deepseek-v3*": "deepseek-chat",
    "deepseek-ai/DeepSeek-V3": "deepseek-chat",
    "deepseek-r1*": "deepseek-reasoner",
}


MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
    "gemini-*-tts": ModelCapabilities(
        input_modalities={ModelModality.TEXT},
        output_modalities={ModelModality.AUDIO},
    ),
    "gemini-*-native-audio-*": ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
        output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
        supports_tool_calling=True,
    ),
    "gemini-2.0-flash-preview-image-generation": ModelCapabilities(
        input_modalities={
            ModelModality.TEXT,
            ModelModality.IMAGE,
            ModelModality.AUDIO,
            ModelModality.VIDEO,
        },
        output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
        supports_tool_calling=True,
    ),
    "gemini-embedding-exp": ModelCapabilities(
        input_modalities={ModelModality.TEXT},
        output_modalities={ModelModality.EMBEDDING},
        is_embedding_model=True,
    ),
    "gemini-2.5-pro*": GEMINI_CAPABILITIES,
    "gemini-1.5-pro*": GEMINI_CAPABILITIES,
    "gemini-2.5-flash*": GEMINI_CAPABILITIES,
    "gemini-2.0-flash*": GEMINI_CAPABILITIES,
    "gemini-1.5-flash*": GEMINI_CAPABILITIES,
    "GLM-4V-Flash": ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
        output_modalities={ModelModality.TEXT},
        supports_tool_calling=True,
    ),
    "GLM-4V-Plus*": ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
        output_modalities={ModelModality.TEXT},
        supports_tool_calling=True,
    ),
    "glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
    "glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
    "doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
    "doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
    "deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
    "deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
}


def get_model_capabilities(model_name: str) -> ModelCapabilities:
    """
    从注册表获取模型能力,支持别名映射和通配符匹配。

    查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值
    """
    canonical_name = model_name
    for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items():
        if fnmatch.fnmatch(model_name, alias_pattern):
            canonical_name = c_name
            break

    if canonical_name in MODEL_CAPABILITIES_REGISTRY:
        return MODEL_CAPABILITIES_REGISTRY[canonical_name]

    for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
        if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
            return capabilities

    return ModelCapabilities()
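A quick lookup sketch against the registry above; the model names are samples, and the behavior follows the documented order (alias, exact, wildcard, defaults):

caps = get_model_capabilities("gemini-2.0-flash-exp")   # wildcard "gemini-2.0-flash*"
assert ModelModality.VIDEO in caps.input_modalities

caps = get_model_capabilities("deepseek-v3-0324")       # alias -> "deepseek-chat"
assert caps.supports_tool_calling

caps = get_model_capabilities("some-unknown-model")     # falls back to defaults
assert caps.input_modalities == {ModelModality.TEXT}
assert not caps.supports_tool_calling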
zhenxun/services/llm/types/content.py
@@ -225,8 +225,10 @@ class LLMContentPart(BaseModel):
             logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
             return None

-    def convert_for_api(self, api_type: str) -> dict[str, Any]:
+    async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
         """根据API类型转换多模态内容格式"""
+        from zhenxun.utils.http_utils import AsyncHttpx
+
         if self.type == "text":
             if api_type == "openai":
                 return {"type": "text", "text": self.text}
@@ -248,20 +250,23 @@ class LLMContentPart(BaseModel):
                     mime_type, data = base64_info
                     return {"inlineData": {"mimeType": mime_type, "data": data}}
                 else:
-                    # 如果无法解析 Base64 数据,抛出异常
                     raise ValueError(
                         f"无法解析Base64图像数据: {self.image_source[:50]}..."
                     )
-            else:
-                logger.warning(
-                    f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
-                )
+            elif self.is_image_url():
+                logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
+                try:
+                    image_bytes = await AsyncHttpx.get_content(self.image_source)
+                    mime_type = self.mime_type or "image/jpeg"
+                    base64_data = base64.b64encode(image_bytes).decode("utf-8")
                     return {
-                        "inlineData": {
-                            "mimeType": "image/jpeg",
-                            "data": self.image_source,
-                        }
+                        "inlineData": {"mimeType": mime_type, "data": base64_data}
                     }
+                except Exception as e:
+                    logger.error(f"下载或编码URL图片失败: {e}", e=e)
+                    raise ValueError(f"无法处理图片URL: {e}")
+            else:
+                raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
         else:
             return {"type": "image_url", "image_url": {"url": self.image_source}}
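A hedged sketch of the new async conversion path. Building the part directly from its fields mirrors the attributes this hunk uses (`type`, `image_source`, `mime_type`); the real codebase may provide dedicated constructors instead:

part = LLMContentPart(type="image", image_source="https://example.com/photo.jpg")

gemini_payload = await part.convert_for_api_async("gemini")
# -> {"inlineData": {"mimeType": "image/jpeg", "data": "<base64 of downloaded bytes>"}}

openai_payload = await part.convert_for_api_async("openai")
# -> {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}}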
zhenxun/services/llm/types/models.py
@@ -4,13 +4,25 @@ LLM 数据模型定义
 包含模型信息、配置、工具定义和响应数据的模型类。
 """

+from collections.abc import Callable
+from contextlib import AbstractAsyncContextManager
 from dataclasses import dataclass, field
-from typing import Any
+from typing import TYPE_CHECKING, Any

 from pydantic import BaseModel, Field

 from .enums import ModelProvider, ToolCategory

+if TYPE_CHECKING:
+    from .protocols import MCPCompatible
+
+    MCPSessionType = (
+        MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
+    )
+else:
+    MCPCompatible = object
+    MCPSessionType = Any
+
 ModelName = str | None
@@ -98,10 +110,21 @@ class LLMToolCall(BaseModel):
 class LLMTool(BaseModel):
     """LLM 工具定义(支持 MCP 风格)"""

+    model_config = {"arbitrary_types_allowed": True}
+
     type: str = "function"
-    function: dict[str, Any]
+    function: dict[str, Any] | None = None
+    mcp_session: MCPSessionType = None
     annotations: dict[str, Any] | None = Field(default=None, description="工具注解")

+    def model_post_init(self, /, __context: Any) -> None:
+        """验证工具定义的有效性"""
+        _ = __context
+        if self.type == "function" and self.function is None:
+            raise ValueError("函数类型的工具必须包含 'function' 字段。")
+        if self.type == "mcp" and self.mcp_session is None:
+            raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
+
     @classmethod
     def create(
         cls,
@@ -111,7 +134,7 @@ class LLMTool(BaseModel):
         required: list[str] | None = None,
         annotations: dict[str, Any] | None = None,
     ) -> "LLMTool":
-        """创建工具"""
+        """创建函数工具"""
         function_def = {
             "name": name,
             "description": description,
@@ -123,6 +146,15 @@ class LLMTool(BaseModel):
         }
         return cls(type="function", function=function_def, annotations=annotations)

+    @classmethod
+    def from_mcp_session(
+        cls,
+        session: Any,
+        annotations: dict[str, Any] | None = None,
+    ) -> "LLMTool":
+        """从 MCP 会话创建工具"""
+        return cls(type="mcp", mcp_session=session, annotations=annotations)
+

 class LLMCodeExecution(BaseModel):
     """代码执行结果"""
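A short usage sketch for the two constructors. The `parameters=` keyword is an assumption, since the middle of `create()`'s signature falls outside this hunk; `some_session_factory` stands in for a zero-argument factory returning an async context manager (see `MCPSessionType`):

weather_tool = LLMTool.create(
    name="get_weather",
    description="查询指定城市的实时天气",
    parameters={"city": {"type": "string", "description": "城市名"}},  # assumed kwarg
    required=["city"],
)

mcp_tool = LLMTool.from_mcp_session(session=some_session_factory)

# model_post_init enforces consistency: LLMTool(type="mcp") raises ValueError
# because no mcp_session was supplied.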
zhenxun/services/llm/types/protocols.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
LLM 模块的协议定义
"""

from typing import Any, Protocol


class MCPCompatible(Protocol):
    """
    一个协议,定义了与LLM模块兼容的MCP会话对象应具备的行为。
    任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible。
    """

    def to_api_tool(self, api_type: str) -> dict[str, Any]:
        """
        将此MCP会话转换为特定LLM提供商API所需的工具格式。

        参数:
            api_type: 目标API的类型 (例如 'gemini', 'openai')。

        返回:
            dict[str, Any]: 一个字典,代表可以在API请求中使用的工具定义。
        """
        ...
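Since `MCPCompatible` is a `typing.Protocol`, conformance is structural; a minimal sketch with an illustrative class and payloads:

class EchoSession:
    """No inheritance required -- implementing to_api_tool is enough."""

    def to_api_tool(self, api_type: str) -> dict:
        if api_type == "openai":
            return {
                "type": "function",
                "function": {"name": "echo", "description": "回显输入", "parameters": {}},
            }
        return {"name": "echo", "description": "回显输入"}


session: MCPCompatible = EchoSession()  # passes static type checking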
zhenxun/services/llm/utils.py
@@ -3,8 +3,10 @@ LLM 模块的工具和转换函数
 """

 import base64
+import copy
 from pathlib import Path

+from nonebot.adapters import Message as PlatformMessage
 from nonebot_plugin_alconna.uniseg import (
     At,
     File,
@@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import (
 )

 from zhenxun.services.log import logger
+from zhenxun.utils.http_utils import AsyncHttpx

 from .types import LLMContentPart

@@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
     """
     将 UniMessage 实例转换为一个 LLMContentPart 列表。
     这是处理多模态输入的核心转换逻辑。
+
+    参数:
+        message: 要转换的UniMessage实例。
+
+    返回:
+        list[LLMContentPart]: 转换后的内容部分列表。
     """
     parts: list[LLMContentPart] = []
     for seg in message:
@@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
             if seg.path:
                 part = await LLMContentPart.from_path(seg.path)
             elif seg.url:
-                logger.warning(
-                    f"直接使用 URL 的 {type(seg).__name__} 段,"
-                    f"API 可能不支持: {seg.url}"
-                )
+                try:
+                    logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
+                    media_bytes = await AsyncHttpx.get_content(seg.url)
+
+                    new_seg = copy.copy(seg)
+                    new_seg.raw = media_bytes
+                    seg = new_seg
+                    logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
+                except Exception as e:
+                    logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
                     part = LLMContentPart.text_part(
-                    f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
+                        f"[下载媒体失败: {seg.name or seg.url}]"
                     )
-            elif hasattr(seg, "raw") and seg.raw:
+
+            if part:
+                parts.append(part)
+                continue
+
+            if hasattr(seg, "raw") and seg.raw:
                 mime_type = getattr(seg, "mimetype", None)
                 if isinstance(seg.raw, bytes):
                     b64_data = base64.b64encode(seg.raw).decode("utf-8")
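The caller-side effect of this hunk, sketched under the assumption that the bot can reach the URL and that alconna's `Image` accepts these keyword arguments:

msg = UniMessage([Image(url="https://example.com/cat.png", mimetype="image/png")])
parts = await unimsg_to_llm_parts(msg)
# parts[0] is now built from the downloaded bytes (handled by the `raw` branch
# below), instead of a "[IMAGE FILE: ...]" text placeholder.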
@@ -127,50 +147,19 @@ def create_multimodal_message(
     audio_mimetypes: list[str] | str | None = None,
 ) -> UniMessage:
     """
-    创建多模态消息的便捷函数,方便第三方调用。
+    创建多模态消息的便捷函数

-    Args:
+    参数:
         text: 文本内容
         images: 图片数据,支持路径、字节数据或URL
-        videos: 视频数据,支持路径、字节数据或URL
-        audios: 音频数据,支持路径、字节数据或URL
-        image_mimetypes: 图片MIME类型,当images为bytes时需要指定
-        video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
-        audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
+        videos: 视频数据
+        audios: 音频数据
+        image_mimetypes: 图片MIME类型,bytes数据时需要指定
+        video_mimetypes: 视频MIME类型,bytes数据时需要指定
+        audio_mimetypes: 音频MIME类型,bytes数据时需要指定

-    Returns:
+    返回:
         UniMessage: 构建好的多模态消息
-
-    Examples:
-        # 纯文本
-        msg = create_multimodal_message("请分析这段文字")
-
-        # 文本 + 单张图片(路径)
-        msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
-
-        # 文本 + 多张图片
-        msg = create_multimodal_message(
-            "比较图片", images=["/path/1.jpg", "/path/2.jpg"]
-        )
-
-        # 文本 + 图片字节数据
-        msg = create_multimodal_message(
-            "分析", images=image_data, image_mimetypes="image/jpeg"
-        )
-
-        # 文本 + 视频
-        msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
-
-        # 文本 + 音频
-        msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
-
-        # 混合多模态
-        msg = create_multimodal_message(
-            "分析这些媒体文件",
-            images="/path/to/image.jpg",
-            videos="/path/to/video.mp4",
-            audios="/path/to/audio.wav"
-        )
     """
     message = UniMessage()
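With the worked examples dropped from the docstring, a condensed usage sketch (paths are placeholders; the single `image_mimetypes` string applies to the bytes item):

msg = create_multimodal_message("分析这张图片", images="/path/to/image.jpg")

msg = create_multimodal_message(
    "比较这些媒体",
    images=[image_bytes, "/path/2.jpg"],  # image_bytes: bytes loaded elsewhere
    image_mimetypes="image/jpeg",
    videos="/path/to/video.mp4",
)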
@@ -196,7 +185,7 @@ def _add_media_to_message(
     media_class: type,
     default_mimetype: str,
 ) -> None:
-    """添加媒体文件到 UniMessage 的辅助函数"""
+    """添加媒体文件到 UniMessage"""
     if not isinstance(media_items, list):
         media_items = [media_items]
@@ -216,3 +205,80 @@ def _add_media_to_message(
         elif isinstance(item, bytes):
             mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
             message.append(media_class(raw=item, mimetype=mimetype))
+
+
+def message_to_unimessage(message: PlatformMessage) -> UniMessage:
+    """
+    将平台特定的 Message 对象转换为通用的 UniMessage。
+    主要用于处理引用消息等未被自动转换的消息体。
+
+    参数:
+        message: 平台特定的Message对象。
+
+    返回:
+        UniMessage: 转换后的通用消息对象。
+    """
+    uni_segments = []
+    for seg in message:
+        if seg.type == "text":
+            uni_segments.append(Text(seg.data.get("text", "")))
+        elif seg.type == "image":
+            uni_segments.append(Image(url=seg.data.get("url")))
+        elif seg.type == "record":
+            uni_segments.append(Voice(url=seg.data.get("url")))
+        elif seg.type == "video":
+            uni_segments.append(Video(url=seg.data.get("url")))
+        elif seg.type == "at":
+            uni_segments.append(At("user", str(seg.data.get("qq", ""))))
+        else:
+            logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
+
+    return UniMessage(uni_segments)
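A hedged usage sketch for `message_to_unimessage`, assuming OneBot v11-style segments whose `data` carries the keys read above; run inside an async handler:

from nonebot.adapters.onebot.v11 import Message, MessageSegment

reply = Message([
    MessageSegment(type="text", data={"text": "看看这张图"}),
    MessageSegment(type="image", data={"url": "https://example.com/a.jpg"}),
])

uni_msg = message_to_unimessage(reply)      # -> UniMessage([Text(...), Image(...)])
parts = await unimsg_to_llm_parts(uni_msg)  # URL gets downloaded and inlined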
+
+
+def _sanitize_request_body_for_logging(body: dict) -> dict:
+    """
+    净化请求体用于日志记录,移除大数据字段并添加摘要信息
+
+    参数:
+        body: 原始请求体字典。
+
+    返回:
+        dict: 净化后的请求体字典。
+    """
+    try:
+        sanitized_body = copy.deepcopy(body)
+
+        if "contents" in sanitized_body and isinstance(
+            sanitized_body["contents"], list
+        ):
+            for content_item in sanitized_body["contents"]:
+                if "parts" in content_item and isinstance(content_item["parts"], list):
+                    media_summary = []
+                    new_parts = []
+                    for part in content_item["parts"]:
+                        if "inlineData" in part and isinstance(
+                            part["inlineData"], dict
+                        ):
+                            data = part["inlineData"].get("data")
+                            if isinstance(data, str):
+                                mime_type = part["inlineData"].get(
+                                    "mimeType", "unknown"
+                                )
+                                media_summary.append(f"{mime_type} ({len(data)} chars)")
+                                continue
+                        new_parts.append(part)
+
+                    if media_summary:
+                        summary_text = (
+                            f"[多模态内容: {len(media_summary)}个文件 - "
+                            f"{', '.join(media_summary)}]"
+                        )
+                        new_parts.insert(0, {"text": summary_text})
+
+                    content_item["parts"] = new_parts
+
+        return sanitized_body
+    except Exception as e:
+        logger.warning(f"日志净化失败: {e},将记录原始请求体。")
+        return body
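An illustration of the sanitizer's effect on a Gemini-style request body (values are made up; the character count is what `len()` reports for the payload string):

body = {
    "contents": [
        {
            "parts": [
                {"text": "描述这张图"},
                {"inlineData": {"mimeType": "image/png", "data": "iVBOR" * 1600}},
            ]
        }
    ]
}

safe = _sanitize_request_body_for_logging(body)
# safe["contents"][0]["parts"] == [
#     {"text": "[多模态内容: 1个文件 - image/png (8000 chars)]"},
#     {"text": "描述这张图"},
# ]  -- the 8000-char base64 payload never reaches the logs.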