diff --git a/zhenxun/services/llm/README.md b/zhenxun/services/llm/README.md index 263be1e6..93394fdf 100644 --- a/zhenxun/services/llm/README.md +++ b/zhenxun/services/llm/README.md @@ -1,731 +1,559 @@ -# Zhenxun LLM 服务模块 -## 📑 目录 +--- -- [📖 概述](#-概述) -- [🌟 主要特性](#-主要特性) -- [🚀 快速开始](#-快速开始) -- [📚 API 参考](#-api-参考) -- [⚙️ 配置](#️-配置) -- [🔧 高级功能](#-高级功能) -- [🏗️ 架构设计](#️-架构设计) -- [🔌 支持的提供商](#-支持的提供商) -- [🎯 使用场景](#-使用场景) -- [📊 性能优化](#-性能优化) -- [🛠️ 故障排除](#️-故障排除) -- [❓ 常见问题](#-常见问题) -- [📝 示例项目](#-示例项目) -- [🤝 贡献](#-贡献) -- [📄 许可证](#-许可证) +# 🚀 Zhenxun LLM 服务模块 -## 📖 概述 +本模块是一个功能强大、高度可扩展的统一大语言模型(LLM)服务框架。它旨在将各种不同的 LLM 提供商(如 OpenAI、Gemini、智谱AI等)的 API 封装在一个统一、易于使用的接口之后,让开发者可以无缝切换和使用不同的模型,同时支持多模态输入、工具调用、智能重试和缓存等高级功能。 -Zhenxun LLM 服务模块是一个现代化的AI服务框架,提供统一的接口来访问多个大语言模型提供商。该模块采用模块化设计,支持异步操作、智能重试、Key轮询和负载均衡等高级功能。 +## 目录 -### 🌟 主要特性 +- [🚀 Zhenxun LLM 服务模块](#-zhenxun-llm-服务模块) + - [目录](#目录) + - [✨ 核心特性](#-核心特性) + - [🧠 核心概念](#-核心概念) + - [🛠️ 安装与配置](#️-安装与配置) + - [服务提供商配置 (`config.yaml`)](#服务提供商配置-configyaml) + - [MCP 工具配置 (`mcp_tools.json`)](#mcp-工具配置-mcp_toolsjson) + - [📘 使用指南](#-使用指南) + - [**等级1: 便捷函数** - 最快速的调用方式](#等级1-便捷函数---最快速的调用方式) + - [**等级2: `AI` 会话类** - 管理有状态的对话](#等级2-ai-会话类---管理有状态的对话) + - [**等级3: 直接模型控制** - `get_model_instance`](#等级3-直接模型控制---get_model_instance) + - [🌟 功能深度剖析](#-功能深度剖析) + - [精细化控制模型生成 (`LLMGenerationConfig` 与 `CommonOverrides`)](#精细化控制模型生成-llmgenerationconfig-与-commonoverrides) + - [赋予模型能力:工具使用 (Function Calling)](#赋予模型能力工具使用-function-calling) + - [1. 注册工具](#1-注册工具) + - [函数工具注册](#函数工具注册) + - [MCP工具注册](#mcp工具注册) + - [2. 调用带工具的模型](#2-调用带工具的模型) + - [处理多模态输入](#处理多模态输入) + - [🔧 高级主题与扩展](#-高级主题与扩展) + - [模型与密钥管理](#模型与密钥管理) + - [缓存管理](#缓存管理) + - [错误处理 (`LLMException`)](#错误处理-llmexception) + - [自定义适配器 (Adapter)](#自定义适配器-adapter) + - [📚 API 快速参考](#-api-快速参考) -- **多提供商支持**: OpenAI、Gemini、智谱AI、DeepSeek等 -- **统一接口**: 简洁一致的API设计 -- **智能Key轮询**: 自动负载均衡和故障转移 -- **异步高性能**: 基于asyncio的并发处理 -- **模型缓存**: 智能缓存机制提升性能 -- **工具调用**: 支持Function Calling -- **嵌入向量**: 文本向量化支持 -- **错误处理**: 完善的异常处理和重试机制 -- **多模态支持**: 文本、图像、音频、视频处理 -- **代码执行**: Gemini代码执行功能 -- **搜索增强**: Google搜索集成 +--- -## 🚀 快速开始 +## ✨ 核心特性 -### 基本使用 +- **多提供商支持**: 内置对 OpenAI、Gemini、智谱AI 等多种 API 的适配器,并可通过通用 OpenAI 兼容适配器轻松接入更多服务。 +- **统一的 API**: 提供从简单到高级的三层 API,满足不同场景的需求,无论是快速聊天还是复杂的分析任务。 +- **强大的工具调用 (Function Calling)**: 支持标准的函数调用和实验性的 MCP (Model Context Protocol) 工具,让 LLM 能够与外部世界交互。 +- **多模态能力**: 无缝集成 `UniMessage`,轻松处理文本、图片、音频、视频等混合输入,支持多模态搜索和分析。 +- **文本嵌入向量化**: 提供统一的嵌入接口,支持语义搜索、相似度计算和文本聚类等应用。 +- **智能重试与 Key 轮询**: 内置健壮的请求重试逻辑,当 API Key 失效或达到速率限制时,能自动轮询使用备用 Key。 +- **灵活的配置系统**: 通过配置文件和代码中的 `LLMGenerationConfig`,可以精细控制模型的生成行为(如温度、最大Token等)。 +- **高性能缓存机制**: 内置模型实例缓存,减少重复初始化开销,提供缓存管理和监控功能。 +- **丰富的配置预设**: 提供 `CommonOverrides` 类,包含创意模式、精确模式、JSON输出等多种常用配置预设。 +- **可扩展的适配器架构**: 开发者可以轻松编写自己的适配器来支持新的 LLM 服务。 + +## 🧠 核心概念 + +- **适配器 (Adapter)**: 这是连接我们统一接口和特定 LLM 提供商 API 的“翻译官”。例如,`GeminiAdapter` 知道如何将我们的标准请求格式转换为 Google Gemini API 需要的格式,并解析其响应。 +- **模型实例 (`LLMModel`)**: 这是框架中的核心操作对象,代表一个**具体配置好**的模型。例如,一个 `LLMModel` 实例可能代表使用特定 API Key、特定代理的 `Gemini/gemini-1.5-pro`。所有与模型交互的操作都通过这个类的实例进行。 +- **生成配置 (`LLMGenerationConfig`)**: 这是一个数据类,用于控制模型在生成内容时的行为,例如 `temperature` (温度)、`max_tokens` (最大输出长度)、`response_format` (响应格式) 等。 +- **工具 (Tool)**: 代表一个可以让 LLM 调用的函数。它可以是一个简单的 Python 函数,也可以是一个更复杂的、有状态的 MCP 服务。 +- **多模态内容 (`LLMContentPart`)**: 这是处理多模态输入的基础单元,一个 `LLMMessage` 可以包含多个 `LLMContentPart`,如一个文本部分和多个图片部分。 + +## 🛠️ 安装与配置 + +该模块作为 `zhenxun` 项目的一部分被集成,无需额外安装。核心配置主要涉及两个文件。 + +### 服务提供商配置 (`config.yaml`) + +核心配置位于项目 `/data/config.yaml` 文件中的 `AI` 部分。 + +```yaml +# /data/configs/config.yaml +AI: + # (可选) 全局默认模型,格式: "ProviderName/ModelName" + default_model_name: Gemini/gemini-2.5-flash + # (可选) 全局代理设置 + proxy: http://127.0.0.1:7890 + # (可选) 全局超时设置 (秒) + timeout: 180 + # (可选) Gemini 的安全过滤阈值 + gemini_safety_threshold: BLOCK_MEDIUM_AND_ABOVE + + # 配置你的AI服务提供商 + PROVIDERS: + # 示例1: Gemini + - name: Gemini + api_key: + - "AIzaSy_KEY_1" # 支持多个Key,会自动轮询 + - "AIzaSy_KEY_2" + api_base: https://generativelanguage.googleapis.com + api_type: gemini + models: + - model_name: gemini-2.5-pro + - model_name: gemini-2.5-flash + - model_name: gemini-2.0-flash + - model_name: embedding-001 + is_embedding_model: true # 标记为嵌入模型 + max_input_tokens: 2048 # 嵌入模型特有配置 + + # 示例2: 智谱AI + - name: GLM + api_key: "YOUR_ZHIPU_API_KEY" + api_type: zhipu # 适配器类型 + models: + - model_name: glm-4-flash + - model_name: glm-4-plus + temperature: 0.8 # 可以为特定模型设置默认温度 + + # 示例3: 一个兼容OpenAI的自定义服务 + - name: MyOpenAIService + api_key: "sk-my-custom-key" + api_base: "http://localhost:8080/v1" + api_type: general_openai_compat # 使用通用OpenAI兼容适配器 + models: + - model_name: Llama3-8B-Instruct + max_tokens: 2048 # 可以为特定模型设置默认最大Token +``` + +### MCP 工具配置 (`mcp_tools.json`) + +此文件位于 `/data/llm/mcp_tools.json`,用于配置通过 MCP 协议启动的外部工具服务。如果文件不存在,系统会自动创建一个包含示例的默认文件。 + +```json +{ + "mcpServers": { + "baidu-map": { + "command": "npx", + "args": ["-y", "@baidumap/mcp-server-baidu-map"], + "env": { + "BAIDU_MAP_API_KEY": "" + }, + "description": "百度地图工具,提供地理编码、路线规划等功能。" + }, + "sequential-thinking": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"], + "description": "顺序思维工具,用于帮助模型进行多步骤推理。" + } + } +} +``` + +## 📘 使用指南 + +我们提供了三层 API,以满足从简单到复杂的各种需求。 + +### **等级1: 便捷函数** - 最快速的调用方式 + +这些函数位于 `zhenxun.services.llm` 包的顶层,为你处理了所有的底层细节。 ```python -from zhenxun.services.llm import chat, code, search, analyze +from zhenxun.services.llm import chat, search, code, pipeline_chat, embed, analyze_multimodal, search_multimodal +from zhenxun.services.llm.utils import create_multimodal_message -# 简单聊天 -response = await chat("你好,请介绍一下自己") +# 1. 纯文本聊天 +response_text = await chat("你好,请用苏轼的风格写一首关于月亮的诗。") +print(response_text) + +# 2. 带网络搜索的问答 +search_result = await search("马斯克的Neuralink公司最近有什么新进展?") +print(search_result['text']) +# print(search_result['sources']) # 查看信息来源 + +# 3. 执行代码 +code_result = await code("用Python画一个心形图案。") +print(code_result['text']) # 包含代码和解释的回复 + +# 4. 链式调用 +image_msg = create_multimodal_message(images="path/to/cat.jpg") +final_poem = await pipeline_chat( + message=image_msg, + model_chain=["Gemini/gemini-1.5-pro", "GLM/glm-4-flash"], + initial_instruction="详细描述这只猫的外观和姿态。", + final_instruction="将上述描述凝练成一首可爱的短诗。" +) +print(final_poem.text) + +# 5. 文本嵌入向量生成 +texts_to_embed = ["今天天气真好", "我喜欢打篮球", "这部电影很感人"] +vectors = await embed(texts_to_embed, model="Gemini/embedding-001") +print(f"生成了 {len(vectors)} 个向量,每个向量维度: {len(vectors[0])}") + +# 6. 多模态分析便捷函数 +response = await analyze_multimodal( + text="请分析这张图片中的内容", + images="path/to/image.jpg", + model="Gemini/gemini-1.5-pro" +) print(response) -# 代码执行 -result = await code("计算斐波那契数列的前10项") -print(result["text"]) -print(result["code_executions"]) - -# 搜索功能 -search_result = await search("Python异步编程最佳实践") -print(search_result["text"]) - -# 多模态分析 -from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text -message = UniMessage([ - Text("分析这张图片"), - Image(path="image.jpg") -]) -analysis = await analyze(message, model="Gemini/gemini-2.0-flash") -print(analysis) -``` - -### 使用AI类 - -```python -from zhenxun.services.llm import AI, AIConfig, CommonOverrides - -# 创建AI实例 -ai = AI(AIConfig(model="OpenAI/gpt-4")) - -# 聊天对话 -response = await ai.chat("解释量子计算的基本原理") - -# 多模态分析 -from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text - -multimodal_msg = UniMessage([ - Text("这张图片显示了什么?"), - Image(path="image.jpg") -]) -result = await ai.analyze(multimodal_msg) - -# 便捷的多模态函数 -result = await analyze_with_images( - "分析这张图片", - images="image.jpg", - model="Gemini/gemini-2.0-flash" +# 7. 多模态搜索便捷函数 +search_result = await search_multimodal( + text="搜索与这张图片相关的信息", + images="path/to/image.jpg", + model="Gemini/gemini-1.5-pro" ) +print(search_result['text']) ``` -## 📚 API 参考 +### **等级2: `AI` 会话类** - 管理有状态的对话 -### 快速函数 - -#### `chat(message, *, model=None, **kwargs) -> str` -简单聊天对话 - -**参数:** -- `message`: 消息内容(字符串、LLMMessage或内容部分列表) -- `model`: 模型名称(可选) -- `**kwargs`: 额外配置参数 - -#### `code(prompt, *, model=None, timeout=None, **kwargs) -> dict` -代码执行功能 - -**返回:** -```python -{ - "text": "执行结果说明", - "code_executions": [{"code": "...", "output": "..."}], - "success": True -} -``` - -#### `search(query, *, model=None, instruction="", **kwargs) -> dict` -搜索增强生成 - -**返回:** -```python -{ - "text": "搜索结果和分析", - "grounding_metadata": {...}, - "success": True -} -``` - -#### `analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse` -高级分析功能,支持多模态输入和工具调用 - -#### `analyze_with_images(text, images, *, instruction="", model=None, **kwargs) -> str` -图片分析便捷函数 - -#### `analyze_multimodal(text=None, images=None, videos=None, audios=None, *, instruction="", model=None, **kwargs) -> str` -多模态分析便捷函数 - -#### `embed(texts, *, model=None, task_type="RETRIEVAL_DOCUMENT", **kwargs) -> list[list[float]]` -文本嵌入向量 - -### AI类方法 - -#### `AI.chat(message, *, model=None, **kwargs) -> str` -聊天对话方法,支持简单多模态输入 - -#### `AI.analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse` -高级分析方法,接收UniMessage进行多模态分析和工具调用 - -### 模型管理 +当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。 ```python -from zhenxun.services.llm import ( - get_model_instance, - list_available_models, - set_global_default_model_name, - clear_model_cache -) +from zhenxun.services.llm.api import AI, AIConfig -# 获取模型实例 -model = await get_model_instance("OpenAI/gpt-4o") +# 初始化一个AI会话,可以传入自定义配置 +ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7) +ai_session = AI(config=ai_config) -# 列出可用模型 -models = list_available_models() - -# 设置默认模型 -set_global_default_model_name("Gemini/gemini-2.0-flash") - -# 清理缓存 -clear_model_cache() -``` - -## ⚙️ 配置 - -### 预设配置 - -```python -from zhenxun.services.llm import CommonOverrides - -# 创意模式 -creative_config = CommonOverrides.creative() - -# 精确模式 -precise_config = CommonOverrides.precise() - -# Gemini特殊功能 -json_config = CommonOverrides.gemini_json() -thinking_config = CommonOverrides.gemini_thinking() -code_exec_config = CommonOverrides.gemini_code_execution() -grounding_config = CommonOverrides.gemini_grounding() -``` - -### 自定义配置 - -```python -from zhenxun.services.llm import LLMGenerationConfig - -config = LLMGenerationConfig( +# 更完整的AIConfig配置示例 +advanced_config = AIConfig( + model="GLM/glm-4-flash", + default_embedding_model="Gemini/embedding-001", # 默认嵌入模型 temperature=0.7, - max_tokens=2048, - top_p=0.9, - frequency_penalty=0.1, - presence_penalty=0.1, - stop=["END", "STOP"], - response_mime_type="application/json", - enable_code_execution=True, - enable_grounding=True + max_tokens=2000, + enable_cache=True, # 启用模型缓存 + enable_code=True, # 启用代码执行功能 + enable_search=True, # 启用搜索功能 + timeout=180, # 请求超时时间(秒) + # Gemini特定配置选项 + enable_gemini_json_mode=True, # 启用Gemini JSON模式 + enable_gemini_thinking=True, # 启用Gemini 思考模式 + enable_gemini_safe_mode=True, # 启用Gemini 安全模式 + enable_gemini_multimodal=True, # 启用Gemini 多模态优化 + enable_gemini_grounding=True, # 启用Gemini 信息来源关联 ) +advanced_session = AI(config=advanced_config) -response = await chat("你的问题", override_config=config) +# 进行连续对话 +await ai_session.chat("我最喜欢的城市是成都。") +response = await ai_session.chat("它有什么好吃的?") # AI会知道“它”指的是成都 +print(response) + +# 在同一个会话中,临时切换模型进行一次调用 +response_gemini = await ai_session.chat( + "从AI的角度分析一下成都的科技发展潜力。", + model="Gemini/gemini-1.5-pro" +) +print(response_gemini) + +# 清空历史,开始新一轮对话 +ai_session.clear_history() ``` -## 🔧 高级功能 +### **等级3: 直接模型控制** - `get_model_instance` -### 工具调用 (Function Calling) +这是最底层的 API,为你提供对模型实例的完全控制。推荐使用 `async with` 语句来优雅地管理模型实例的生命周期。 ```python -from zhenxun.services.llm import LLMTool, get_model_instance +from zhenxun.services.llm import get_model_instance, LLMMessage +from zhenxun.services.llm.config import LLMGenerationConfig -# 定义工具 -tools = [ - LLMTool( - name="get_weather", - description="获取天气信息", - parameters={ - "type": "object", - "properties": { - "city": {"type": "string", "description": "城市名称"} - }, - "required": ["city"] - } +# 1. 获取模型实例 +# get_model_instance 返回一个异步上下文管理器 +async with await get_model_instance("Gemini/gemini-1.5-pro") as model: + # 2. 准备消息列表 + messages = [ + LLMMessage.system("你是一个专业的营养师。"), + LLMMessage.user("我今天吃了汉堡和可乐,请给我一些健康建议。") + ] + + # 3. (可选) 定义本次调用的生成配置 + gen_config = LLMGenerationConfig( + temperature=0.2, # 更严谨的回复 + max_tokens=300 ) -] - -# 工具执行器 -async def tool_executor(tool_name: str, args: dict) -> str: - if tool_name == "get_weather": - return f"{args['city']}今天晴天,25°C" - return "未知工具" - -# 使用工具 -model = await get_model_instance("OpenAI/gpt-4") -response = await model.generate_response( - messages=[{"role": "user", "content": "北京天气如何?"}], - tools=tools, - tool_executor=tool_executor -) + + # 4. 生成响应 + response = await model.generate_response(messages, config=gen_config) + + # 5. 处理响应 + print(response.text) + if response.usage_info: + print(f"Token 消耗: {response.usage_info['total_tokens']}") ``` -### 多模态处理 +## 🌟 功能深度剖析 + +### 精细化控制模型生成 (`LLMGenerationConfig` 与 `CommonOverrides`) + +- **`LLMGenerationConfig`**: 一个 Pydantic 模型,用于覆盖模型的默认生成参数。 +- **`CommonOverrides`**: 一个包含多种常用配置预设的类,如 `creative()`, `precise()`, `gemini_json()` 等,能极大地简化配置过程。 ```python -from zhenxun.services.llm import create_multimodal_message, analyze_multimodal, analyze_with_images +from zhenxun.services.llm.config import LLMGenerationConfig, CommonOverrides -# 方法1:使用便捷函数 -result = await analyze_multimodal( - text="分析这些媒体文件", - images="image.jpg", - audios="audio.mp3", - model="Gemini/gemini-2.0-flash" +# LLMGenerationConfig 完整参数示例 +comprehensive_config = LLMGenerationConfig( + temperature=0.7, # 生成温度 (0.0-2.0) + max_tokens=1000, # 最大输出token数 + top_p=0.9, # 核采样参数 (0.0-1.0) + top_k=40, # Top-K采样参数 + frequency_penalty=0.0, # 频率惩罚 (-2.0-2.0) + presence_penalty=0.0, # 存在惩罚 (-2.0-2.0) + repetition_penalty=1.0, # 重复惩罚 (0.0-2.0) + stop=["END", "\n\n"], # 停止序列 + response_format={"type": "json_object"}, # 响应格式 + response_mime_type="application/json", # Gemini专用MIME类型 + response_schema={...}, # JSON响应模式 + thinking_budget=0.8, # Gemini思考预算 (0.0-1.0) + enable_code_execution=True, # 启用代码执行 + safety_settings={...}, # 安全设置 + response_modalities=["TEXT"], # 响应模态类型 ) -# 方法2:使用create_multimodal_message +# 创建一个配置,要求模型输出JSON格式 +json_config = LLMGenerationConfig( + temperature=0.1, + response_mime_type="application/json" # Gemini特有 +) +# 对于OpenAI兼容API,可以这样做 +json_config_openai = LLMGenerationConfig( + temperature=0.1, + response_format={"type": "json_object"} +) + +# 使用框架提供的预设 - 基础预设 +safe_config = CommonOverrides.gemini_safe() +creative_config = CommonOverrides.creative() +precise_config = CommonOverrides.precise() +balanced_config = CommonOverrides.balanced() + +# 更多实用预设 +concise_config = CommonOverrides.concise(max_tokens=50) # 简洁模式 +detailed_config = CommonOverrides.detailed(max_tokens=3000) # 详细模式 +json_config = CommonOverrides.gemini_json() # JSON输出模式 +thinking_config = CommonOverrides.gemini_thinking(budget=0.8) # 思考模式 + +# Gemini特定高级预设 +code_config = CommonOverrides.gemini_code_execution() # 代码执行模式 +grounding_config = CommonOverrides.gemini_grounding() # 信息来源关联模式 +multimodal_config = CommonOverrides.gemini_multimodal() # 多模态优化模式 + +# 在调用时传入config对象 +# await model.generate_response(messages, config=json_config) +``` + +### 赋予模型能力:工具使用 (Function Calling) + +工具调用让 LLM 能够与外部函数、API 或服务进行交互。 + +#### 1. 注册工具 + +##### 函数工具注册 + +使用 `@tool_registry.function_tool` 装饰器注册一个简单的函数工具。 + +```python +from zhenxun.services.llm import tool_registry + +@tool_registry.function_tool( + name="query_stock_price", + description="查询指定股票代码的当前价格。", + parameters={ + "stock_symbol": {"type": "string", "description": "股票代码, 例如 'AAPL' 或 'GOOG'"} + }, + required=["stock_symbol"] +) +async def query_stock_price(stock_symbol: str) -> dict: + """一个查询股票价格的伪函数""" + print(f"--- 正在查询 {stock_symbol} 的价格 ---") + if stock_symbol == "AAPL": + return {"symbol": "AAPL", "price": 175.50, "currency": "USD"} + return {"error": "未知的股票代码"} +``` + +##### MCP工具注册 + +对于更复杂的、有状态的工具,可以使用 `@tool_registry.mcp_tool` 装饰器注册MCP工具。 + +```python +from contextlib import asynccontextmanager +from pydantic import BaseModel +from zhenxun.services.llm import tool_registry + +# 定义工具的配置模型 +class MyToolConfig(BaseModel): + api_key: str + endpoint: str + timeout: int = 30 + +# 注册MCP工具 +@tool_registry.mcp_tool(name="my-custom-tool", config_model=MyToolConfig) +@asynccontextmanager +async def my_tool_factory(config: MyToolConfig): + """MCP工具工厂函数""" + # 初始化工具会话 + session = MyToolSession(config) + try: + await session.initialize() + yield session + finally: + await session.cleanup() +``` + +#### 2. 调用带工具的模型 + +在 `analyze` 或 `generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。 + +```python +from zhenxun.services.llm.api import analyze +from nonebot_plugin_alconna.uniseg import UniMessage + +response = await analyze( + UniMessage("帮我查一下苹果公司的股价"), + use_tools=["query_stock_price"] +) +print(response.text) # 输出应为 "苹果公司(AAPL)的当前股价为175.5美元。" 或类似内容 +``` + +### 处理多模态输入 + +本模块通过 `UniMessage` 和 `LLMContentPart` 完美支持多模态。 + +- **`create_multimodal_message`**: 推荐的、用于从代码中便捷地创建多模态消息的函数。 +- **`unimsg_to_llm_parts`**: 框架内部使用的核心转换函数,将 `UniMessage` 的各个段(文本、图片等)转换为 `LLMContentPart` 列表。 + +```python +from zhenxun.services.llm import analyze +from zhenxun.services.llm.utils import create_multimodal_message +from pathlib import Path + +# 从本地文件创建消息 message = create_multimodal_message( - text="分析这张图片和音频", - images="image.jpg", - audios="audio.mp3" + text="请分析这张图片和这个视频。图片里是什么?视频里发生了什么?", + images=[Path("path/to/your/image.jpg")], + videos=[Path("path/to/your/video.mp4")] ) -result = await analyze(message) +response = await analyze(message, model="Gemini/gemini-1.5-pro") +print(response.text) +``` -# 方法3:图片分析专用函数 -result = await analyze_with_images( - "这张图片显示了什么?", - images=["image1.jpg", "image2.jpg"] +## 🔧 高级主题与扩展 + +### 模型与密钥管理 + +模块提供了一些工具函数来管理你的模型配置。 + +```python +from zhenxun.services.llm.manager import ( + list_available_models, + list_embedding_models, + set_global_default_model_name, + get_global_default_model_name, + get_key_usage_stats, + reset_key_status ) -``` +from zhenxun.services.llm import clear_model_cache, get_cache_stats -## 🛠️ 故障排除 - -### 常见错误 - -1. **配置错误**: 检查API密钥和模型配置 -2. **网络问题**: 检查代理设置和网络连接 -3. **模型不可用**: 使用 `list_available_models()` 检查可用模型 -4. **超时错误**: 调整timeout参数或使用更快的模型 - -### 调试技巧 - -```python -from zhenxun.services.llm import get_cache_stats -from zhenxun.services.log import logger - -# 查看缓存状态 -stats = get_cache_stats() -print(f"缓存命中率: {stats['hit_rate']}") - -# 启用详细日志 -logger.setLevel("DEBUG") -``` - -## ❓ 常见问题 - - -### Q: 如何处理多模态输入? - -**A:** 有多种方式处理多模态输入: -```python -# 方法1:使用便捷函数 -result = await analyze_with_images("分析这张图片", images="image.jpg") - -# 方法2:使用analyze函数 -from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text -message = UniMessage([Text("分析这张图片"), Image(path="image.jpg")]) -result = await analyze(message) - -# 方法3:使用create_multimodal_message -from zhenxun.services.llm import create_multimodal_message -message = create_multimodal_message(text="分析这张图片", images="image.jpg") -result = await analyze(message) -``` - -### Q: 如何自定义工具调用? - -**A:** 使用analyze函数的tools参数: -```python -# 定义工具 -tools = [{ - "name": "calculator", - "description": "计算数学表达式", - "parameters": { - "type": "object", - "properties": { - "expression": {"type": "string", "description": "数学表达式"} - }, - "required": ["expression"] - } -}] - -# 使用工具 -from nonebot_plugin_alconna.uniseg import UniMessage, Text -message = UniMessage([Text("计算 2+3*4")]) -response = await analyze(message, tools=tools, tool_config={"mode": "auto"}) - -# 如果返回LLMResponse,说明有工具调用 -if hasattr(response, 'tool_calls'): - for tool_call in response.tool_calls: - print(f"调用工具: {tool_call.function.name}") - print(f"参数: {tool_call.function.arguments}") -``` - - -### Q: 如何确保输出格式? - -**A:** 使用结构化输出: -```python -# JSON格式输出 -config = CommonOverrides.gemini_json() - -# 自定义Schema -schema = { - "type": "object", - "properties": { - "answer": {"type": "string"}, - "confidence": {"type": "number"} - } -} -config = CommonOverrides.gemini_structured(schema) -``` - -## 📝 示例项目 - -### 完整示例 - -#### 1. 智能客服机器人 - -```python -from zhenxun.services.llm import AI, CommonOverrides -from typing import Dict, List - -class CustomerService: - def __init__(self): - self.ai = AI() - self.sessions: Dict[str, List[dict]] = {} - - async def handle_query(self, user_id: str, query: str) -> str: - # 获取或创建会话历史 - if user_id not in self.sessions: - self.sessions[user_id] = [] - - history = self.sessions[user_id] - - # 添加系统提示 - if not history: - history.append({ - "role": "system", - "content": "你是一个专业的客服助手,请友好、准确地回答用户问题。" - }) - - # 添加用户问题 - history.append({"role": "user", "content": query}) - - # 生成回复 - response = await self.ai.chat( - query, - history=history[-20:], # 保留最近20轮对话 - override_config=CommonOverrides.balanced() - ) - - # 保存回复到历史 - history.append({"role": "assistant", "content": response}) - - return response -``` - -#### 2. 文档智能问答 - -```python -from zhenxun.services.llm import embed, analyze -import numpy as np -from typing import List, Tuple - -class DocumentQA: - def __init__(self): - self.documents: List[str] = [] - self.embeddings: List[List[float]] = [] - - async def add_document(self, text: str): - """添加文档到知识库""" - self.documents.append(text) - - # 生成嵌入向量 - embedding = await embed([text]) - self.embeddings.extend(embedding) - - async def query(self, question: str, top_k: int = 3) -> str: - """查询文档并生成答案""" - if not self.documents: - return "知识库为空,请先添加文档。" - - # 生成问题的嵌入向量 - question_embedding = await embed([question]) - - # 计算相似度并找到最相关的文档 - similarities = [] - for doc_embedding in self.embeddings: - similarity = np.dot(question_embedding[0], doc_embedding) - similarities.append(similarity) - - # 获取最相关的文档 - top_indices = np.argsort(similarities)[-top_k:][::-1] - relevant_docs = [self.documents[i] for i in top_indices] - - # 构建上下文 - context = "\n\n".join(relevant_docs) - prompt = f""" -基于以下文档内容回答问题: - -文档内容: -{context} - -问题:{question} - -请基于文档内容给出准确的答案,如果文档中没有相关信息,请说明。 -""" - - result = await analyze(prompt) - return result["text"] -``` - -#### 3. 代码审查助手 - -```python -from zhenxun.services.llm import code, analyze -import os - -class CodeReviewer: - async def review_file(self, file_path: str) -> dict: - """审查代码文件""" - if not os.path.exists(file_path): - return {"error": "文件不存在"} - - with open(file_path, 'r', encoding='utf-8') as f: - code_content = f.read() - - prompt = f""" -请审查以下代码,提供详细的反馈: - -文件:{file_path} -代码: -``` -{code_content} -``` - -请从以下方面进行审查: -1. 代码质量和可读性 -2. 潜在的bug和安全问题 -3. 性能优化建议 -4. 最佳实践建议 -5. 代码风格问题 - -请以JSON格式返回结果。 -""" - - result = await analyze( - prompt, - model="DeepSeek/deepseek-coder", - override_config=CommonOverrides.gemini_json() - ) - - return { - "file": file_path, - "review": result["text"], - "success": True - } - - async def suggest_improvements(self, code: str, language: str = "python") -> str: - """建议代码改进""" - prompt = f""" -请改进以下{language}代码,使其更加高效、可读和符合最佳实践: - -原代码: -```{language} -{code} -``` - -请提供改进后的代码和说明。 -""" - - result = await code(prompt, model="DeepSeek/deepseek-coder") - return result["text"] -``` - - -## 🏗️ 架构设计 - -### 模块结构 - -``` -zhenxun/services/llm/ -├── __init__.py # 包入口,导入和暴露公共API -├── api.py # 高级API接口(AI类、便捷函数) -├── core.py # 核心基础设施(HTTP客户端、重试逻辑、KeyStore) -├── service.py # LLM模型实现类 -├── utils.py # 工具和转换函数 -├── manager.py # 模型管理和缓存 -├── adapters/ # 适配器模块 -│ ├── __init__.py # 适配器包入口 -│ ├── base.py # 基础适配器 -│ ├── factory.py # 适配器工厂 -│ ├── openai.py # OpenAI适配器 -│ ├── gemini.py # Gemini适配器 -│ └── zhipu.py # 智谱AI适配器 -├── config/ # 配置模块 -│ ├── __init__.py # 配置包入口 -│ ├── generation.py # 生成配置 -│ ├── presets.py # 预设配置 -│ └── providers.py # 提供商配置 -└── types/ # 类型定义 - ├── __init__.py # 类型包入口 - ├── content.py # 内容类型 - ├── enums.py # 枚举定义 - ├── exceptions.py # 异常定义 - └── models.py # 数据模型 -``` - -### 模块职责 - -- **`__init__.py`**: 纯粹的包入口,只负责导入和暴露公共API -- **`api.py`**: 高级API接口,包含AI类和所有便捷函数 -- **`core.py`**: 核心基础设施,包含HTTP客户端管理、重试逻辑和KeyStore -- **`service.py`**: LLM模型实现类,专注于模型逻辑 -- **`utils.py`**: 工具和转换函数,如多模态消息处理 -- **`manager.py`**: 模型管理和缓存机制 -- **`adapters/`**: 各大提供商的适配器模块,负责与不同API的交互 - - `base.py`: 定义适配器的基础接口 - - `factory.py`: 适配器工厂,用于动态加载和实例化适配器 - - `openai.py`: OpenAI API适配器 - - `gemini.py`: Google Gemini API适配器 - - `zhipu.py`: 智谱AI API适配器 -- **`config/`**: 配置管理模块 - - `generation.py`: 生成配置和预设 - - `presets.py`: 预设配置 - - `providers.py`: 提供商配置 -- **`types/`**: 类型定义模块 - - `content.py`: 内容类型定义 - - `enums.py`: 枚举定义 - - `exceptions.py`: 异常定义 - - `models.py`: 数据模型定义 - -## 🔌 支持的提供商 - -### OpenAI 兼容 - -- **OpenAI**: GPT-4o, GPT-3.5-turbo等 -- **DeepSeek**: deepseek-chat, deepseek-reasoner等 -- **其他OpenAI兼容API**: 支持自定义端点 - -```python -# OpenAI -await chat("Hello", model="OpenAI/gpt-4o") - -# DeepSeek -await chat("写代码", model="DeepSeek/deepseek-reasoner") -``` - -### Google Gemini - -- **Gemini Pro**: gemini-2.5-flash-preview-05-20 gemini-2.0-flash等 -- **特殊功能**: 代码执行、搜索增强、思考模式 - -```python -# 基础使用 -await chat("你好", model="Gemini/gemini-2.0-flash") - -# 代码执行 -await code("计算质数", model="Gemini/gemini-2.0-flash") - -# 搜索增强 -await search("最新AI发展", model="Gemini/gemini-2.5-flash-preview-05-20") -``` - -### 智谱AI - -- **GLM系列**: glm-4, glm-4v等 -- **支持功能**: 文本生成、多模态理解 - -```python -await chat("介绍北京", model="Zhipu/glm-4") -``` - -## 🎯 使用场景 - -### 1. 聊天机器人 - -```python -from zhenxun.services.llm import AI, CommonOverrides - -class ChatBot: - def __init__(self): - self.ai = AI() - self.history = [] - - async def chat(self, user_input: str) -> str: - # 添加历史记录 - self.history.append({"role": "user", "content": user_input}) - - # 生成回复 - response = await self.ai.chat( - user_input, - history=self.history[-10:], # 保留最近10轮对话 - override_config=CommonOverrides.balanced() - ) - - self.history.append({"role": "assistant", "content": response}) - return response -``` - -### 2. 代码助手 - -```python -async def code_assistant(task: str) -> dict: - """代码生成和执行助手""" - result = await code( - f"请帮我{task},并执行代码验证结果", - model="Gemini/gemini-2.0-flash", - timeout=60 - ) - - return { - "explanation": result["text"], - "code_blocks": result["code_executions"], - "success": result["success"] - } - -# 使用示例 -result = await code_assistant("实现快速排序算法") -``` - -### 3. 文档分析 - -```python -from zhenxun.services.llm import analyze_with_images - -async def analyze_document(image_path: str, question: str) -> str: - """分析文档图片并回答问题""" - result = await analyze_with_images( - f"请分析这个文档并回答:{question}", - images=image_path, - model="Gemini/gemini-2.0-flash" - ) - return result -``` - -### 4. 智能搜索 - -```python -async def smart_search(query: str) -> dict: - """智能搜索和总结""" - result = await search( - query, - model="Gemini/gemini-2.0-flash", - instruction="请提供准确、最新的信息,并注明信息来源" - ) - - return { - "summary": result["text"], - "sources": result.get("grounding_metadata", {}), - "confidence": result.get("confidence_score", 0.0) - } -``` - -## 🔧 配置管理 - - -### 动态配置 - -```python -from zhenxun.services.llm import set_global_default_model_name - -# 运行时更改默认模型 -set_global_default_model_name("OpenAI/gpt-4") - -# 检查可用模型 +# 列出所有在config.yaml中配置的可用模型 models = list_available_models() -for model in models: - print(f"{model.provider}/{model.name} - {model.description}") +print([m['full_name'] for m in models]) + +# 列出所有可用的嵌入模型 +embedding_models = list_embedding_models() +print([m['full_name'] for m in embedding_models]) + +# 动态设置全局默认模型 +success = set_global_default_model_name("GLM/glm-4-plus") + +# 获取所有Key的使用统计 +stats = await get_key_usage_stats() +print(stats) + +# 重置'Gemini'提供商的所有Key +await reset_key_status("Gemini") ``` +### 缓存管理 + +模块提供了模型实例缓存功能,可以提高性能并减少重复初始化的开销。 + +```python +from zhenxun.services.llm import clear_model_cache, get_cache_stats + +# 获取缓存统计信息 +stats = get_cache_stats() +print(f"缓存大小: {stats['cache_size']}/{stats['max_cache_size']}") +print(f"缓存TTL: {stats['cache_ttl']}秒") +print(f"已缓存模型: {stats['cached_models']}") + +# 清空模型缓存(在内存不足或需要强制重新初始化时使用) +clear_model_cache() +print("模型缓存已清空") +``` + +### 错误处理 (`LLMException`) + +所有模块内的预期错误都会被包装成 `LLMException`,方便统一处理。 + +```python +from zhenxun.services.llm import chat, LLMException, LLMErrorCode + +try: + await chat("test", model="InvalidProvider/invalid_model") +except LLMException as e: + print(f"捕获到LLM异常: {e}") + print(f"错误码: {e.code}") # 例如 LLMErrorCode.MODEL_NOT_FOUND + print(f"用户友好提示: {e.user_friendly_message}") +``` + +### 自定义适配器 (Adapter) + +如果你想支持一个新的、非 OpenAI 兼容的 LLM 服务,可以通过实现自己的适配器来完成。 + +1. **创建适配器类**: 继承 `BaseAdapter` 并实现其抽象方法。 + + ```python + # my_adapters/custom_adapter.py + from zhenxun.services.llm.adapters import BaseAdapter, RequestData, ResponseData + + class MyCustomAdapter(BaseAdapter): + @property + def api_type(self) -> str: return "my_custom_api" + + @property + def supported_api_types(self) -> list[str]: return ["my_custom_api"] + # ... 实现 prepare_advanced_request, parse_response 等方法 + ``` + +2. **注册适配器**: 在你的插件初始化代码中注册你的适配器。 + + ```python + from zhenxun.services.llm.adapters import register_adapter + from .my_adapters.custom_adapter import MyCustomAdapter + + register_adapter(MyCustomAdapter()) + ``` + +3. **在 `config.yaml` 中使用**: + + ```yaml + AI: + PROVIDERS: + - name: MyAwesomeLLM + api_key: "my-secret-key" + api_type: "my_custom_api" # 关键!使用你注册的 api_type + # ... + ``` + +## 📚 API 快速参考 + +| 类/函数 | 主要用途 | 推荐场景 | +| ------------------------------------- | ---------------------------------------------------------------------- | ------------------------------------------------------------ | +| `llm.chat()` | 进行简单的、无状态的文本对话。 | 快速实现单轮问答。 | +| `llm.search()` | 执行带网络搜索的问答。 | 需要最新信息或回答事实性问题时。 | +| `llm.code()` | 请求模型执行代码。 | 计算、数据处理、代码生成等。 | +| `llm.pipeline_chat()` | 将多个模型串联,处理复杂任务流。 | 需要多模型协作完成的任务,如“图生文再润色”。 | +| `llm.analyze()` | 处理复杂的多模态输入 (`UniMessage`) 和工具调用。 | 插件中处理用户命令,需要解析图片、at、回复等复杂消息时。 | +| `llm.AI` (类) | 管理一个有状态的、连续的对话会话。 | 需要实现上下文关联的连续对话机器人。 | +| `llm.get_model_instance()` | 获取一个底层的、可直接控制的 `LLMModel` 实例。 | 需要对模型进行最精细控制的复杂或自定义场景。 | +| `llm.config.LLMGenerationConfig` (类) | 定义模型生成的具体参数,如温度、最大Token等。 | 当需要微调模型输出风格或格式时。 | +| `llm.tools.tool_registry` (实例) | 注册和管理可供LLM调用的函数工具。 | 当你想让LLM拥有与外部世界交互的能力时。 | +| `llm.embed()` | 生成文本的嵌入向量表示。 | 语义搜索、相似度计算、文本聚类等。 | +| `llm.search_multimodal()` | 执行带网络搜索的多模态问答。 | 需要基于图片、视频等多模态内容进行搜索时。 | +| `llm.analyze_multimodal()` | 便捷的多模态分析函数。 | 直接分析文本、图片、视频、音频等多模态内容。 | +| `llm.AIConfig` (类) | AI会话的配置类,包含模型、温度等参数。 | 配置AI会话的行为和特性。 | +| `llm.clear_model_cache()` | 清空模型实例缓存。 | 内存管理或强制重新初始化模型时。 | +| `llm.get_cache_stats()` | 获取模型缓存的统计信息。 | 监控缓存使用情况和性能优化。 | +| `llm.list_embedding_models()` | 列出所有可用的嵌入模型。 | 选择合适的嵌入模型进行向量化任务。 | +| `llm.config.CommonOverrides` (类) | 提供常用的配置预设,如创意模式、精确模式等。 | 快速应用常见的模型配置组合。 | +| `llm.utils.create_multimodal_message` | 便捷地从文本、图片、音视频等数据创建 `UniMessage`。 | 在代码中以编程方式构建多模态输入时。 | \ No newline at end of file diff --git a/zhenxun/services/llm/__init__.py b/zhenxun/services/llm/__init__.py index ff09ef7a..62a0003f 100644 --- a/zhenxun/services/llm/__init__.py +++ b/zhenxun/services/llm/__init__.py @@ -10,10 +10,10 @@ from .api import ( TaskType, analyze, analyze_multimodal, - analyze_with_images, chat, code, embed, + pipeline_chat, search, search_multimodal, ) @@ -35,6 +35,7 @@ from .manager import ( list_model_identifiers, set_global_default_model_name, ) +from .tools import tool_registry from .types import ( EmbeddingTaskType, LLMContentPart, @@ -43,6 +44,7 @@ from .types import ( LLMMessage, LLMResponse, LLMTool, + MCPCompatible, ModelDetail, ModelInfo, ModelProvider, @@ -51,7 +53,7 @@ from .types import ( ToolMetadata, UsageInfo, ) -from .utils import create_multimodal_message, unimsg_to_llm_parts +from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts __all__ = [ "AI", @@ -65,6 +67,7 @@ __all__ = [ "LLMMessage", "LLMResponse", "LLMTool", + "MCPCompatible", "ModelDetail", "ModelInfo", "ModelName", @@ -76,7 +79,6 @@ __all__ = [ "UsageInfo", "analyze", "analyze_multimodal", - "analyze_with_images", "chat", "clear_model_cache", "code", @@ -88,9 +90,12 @@ __all__ = [ "list_available_models", "list_embedding_models", "list_model_identifiers", + "message_to_unimessage", + "pipeline_chat", "register_llm_configs", "search", "search_multimodal", "set_global_default_model_name", + "tool_registry", "unimsg_to_llm_parts", ] diff --git a/zhenxun/services/llm/adapters/__init__.py b/zhenxun/services/llm/adapters/__init__.py index 93ed9d31..773d3ed2 100644 --- a/zhenxun/services/llm/adapters/__init__.py +++ b/zhenxun/services/llm/adapters/__init__.py @@ -8,7 +8,6 @@ from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter from .gemini import GeminiAdapter from .openai import OpenAIAdapter -from .zhipu import ZhipuAdapter LLMAdapterFactory.initialize() @@ -20,7 +19,6 @@ __all__ = [ "OpenAICompatAdapter", "RequestData", "ResponseData", - "ZhipuAdapter", "get_adapter_for_api_type", "register_adapter", ] diff --git a/zhenxun/services/llm/adapters/base.py b/zhenxun/services/llm/adapters/base.py index 499f9248..60258f7c 100644 --- a/zhenxun/services/llm/adapters/base.py +++ b/zhenxun/services/llm/adapters/base.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from ..service import LLMModel from ..types.content import LLMMessage from ..types.enums import EmbeddingTaskType + from ..types.models import LLMTool class RequestData(BaseModel): @@ -60,7 +61,7 @@ class BaseAdapter(ABC): """支持的API类型列表""" pass - def prepare_simple_request( + async def prepare_simple_request( self, model: "LLMModel", api_key: str, @@ -86,7 +87,7 @@ class BaseAdapter(ABC): config = model._generation_config - return self.prepare_advanced_request( + return await self.prepare_advanced_request( model=model, api_key=api_key, messages=messages, @@ -96,13 +97,13 @@ class BaseAdapter(ABC): ) @abstractmethod - def prepare_advanced_request( + async def prepare_advanced_request( self, model: "LLMModel", api_key: str, messages: list["LLMMessage"], config: "LLMGenerationConfig | None" = None, - tools: list[dict[str, Any]] | None = None, + tools: list["LLMTool"] | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> RequestData: """准备高级请求""" @@ -238,6 +239,9 @@ class BaseAdapter(ABC): message = choice.get("message", {}) content = message.get("content", "") + if content: + content = content.strip() + parsed_tool_calls: list[LLMToolCall] | None = None if message_tool_calls := message.get("tool_calls"): from ..types.models import LLMToolFunction @@ -375,7 +379,7 @@ class BaseAdapter(ABC): if model.temperature is not None: base_config["temperature"] = model.temperature if model.max_tokens is not None: - if model.api_type in ["gemini", "gemini_native"]: + if model.api_type == "gemini": base_config["maxOutputTokens"] = model.max_tokens else: base_config["max_tokens"] = model.max_tokens @@ -401,26 +405,51 @@ class OpenAICompatAdapter(BaseAdapter): """ @abstractmethod - def get_chat_endpoint(self) -> str: + def get_chat_endpoint(self, model: "LLMModel") -> str: """子类必须实现,返回 chat completions 的端点""" pass @abstractmethod - def get_embedding_endpoint(self) -> str: + def get_embedding_endpoint(self, model: "LLMModel") -> str: """子类必须实现,返回 embeddings 的端点""" pass - def prepare_advanced_request( + async def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 - OpenAI兼容API的通用实现""" + url = self.get_api_url(model, self.get_chat_endpoint(model)) + headers = self.get_base_headers(api_key) + + messages = [] + if history: + messages.extend(history) + messages.append({"role": "user", "content": prompt}) + + body = { + "model": model.model_name, + "messages": messages, + } + + body = self.apply_config_override(model, body) + + return RequestData(url=url, headers=headers, body=body) + + async def prepare_advanced_request( self, model: "LLMModel", api_key: str, messages: list["LLMMessage"], config: "LLMGenerationConfig | None" = None, - tools: list[dict[str, Any]] | None = None, + tools: list["LLMTool"] | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> RequestData: """准备高级请求 - OpenAI兼容格式""" - url = self.get_api_url(model, self.get_chat_endpoint()) + url = self.get_api_url(model, self.get_chat_endpoint(model)) headers = self.get_base_headers(api_key) openai_messages = self.convert_messages_to_openai_format(messages) @@ -430,7 +459,21 @@ class OpenAICompatAdapter(BaseAdapter): } if tools: - body["tools"] = tools + openai_tools = [] + for tool in tools: + if tool.type == "function" and tool.function: + openai_tools.append({"type": "function", "function": tool.function}) + elif tool.type == "mcp" and tool.mcp_session: + if callable(tool.mcp_session): + raise ValueError( + "适配器接收到未激活的 MCP 会话工厂。" + "会话工厂应该在 LLMModel.generate_response 中被激活。" + ) + openai_tools.append( + tool.mcp_session.to_api_tool(api_type=self.api_type) + ) + if openai_tools: + body["tools"] = openai_tools if tool_choice: body["tool_choice"] = tool_choice @@ -444,7 +487,7 @@ class OpenAICompatAdapter(BaseAdapter): is_advanced: bool = False, ) -> ResponseData: """解析响应 - 直接使用基类的 OpenAI 格式解析""" - _ = model, is_advanced # 未使用的参数 + _ = model, is_advanced return self.parse_openai_response(response_json) def prepare_embedding_request( @@ -456,8 +499,8 @@ class OpenAICompatAdapter(BaseAdapter): **kwargs: Any, ) -> RequestData: """准备嵌入请求 - OpenAI兼容格式""" - _ = task_type # 未使用的参数 - url = self.get_api_url(model, self.get_embedding_endpoint()) + _ = task_type + url = self.get_api_url(model, self.get_embedding_endpoint(model)) headers = self.get_base_headers(api_key) body = { @@ -465,7 +508,6 @@ class OpenAICompatAdapter(BaseAdapter): "input": texts, } - # 应用额外的配置参数 if kwargs: body.update(kwargs) diff --git a/zhenxun/services/llm/adapters/factory.py b/zhenxun/services/llm/adapters/factory.py index 8652fc67..9f2a8b64 100644 --- a/zhenxun/services/llm/adapters/factory.py +++ b/zhenxun/services/llm/adapters/factory.py @@ -22,10 +22,8 @@ class LLMAdapterFactory: from .gemini import GeminiAdapter from .openai import OpenAIAdapter - from .zhipu import ZhipuAdapter cls.register_adapter(OpenAIAdapter()) - cls.register_adapter(ZhipuAdapter()) cls.register_adapter(GeminiAdapter()) @classmethod diff --git a/zhenxun/services/llm/adapters/gemini.py b/zhenxun/services/llm/adapters/gemini.py index 0ca22185..3e614d3f 100644 --- a/zhenxun/services/llm/adapters/gemini.py +++ b/zhenxun/services/llm/adapters/gemini.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from ..service import LLMModel from ..types.content import LLMMessage from ..types.enums import EmbeddingTaskType - from ..types.models import LLMToolCall + from ..types.models import LLMTool, LLMToolCall class GeminiAdapter(BaseAdapter): @@ -38,30 +38,16 @@ class GeminiAdapter(BaseAdapter): return headers - def prepare_advanced_request( + async def prepare_advanced_request( self, model: "LLMModel", api_key: str, messages: list["LLMMessage"], config: "LLMGenerationConfig | None" = None, - tools: list[dict[str, Any]] | None = None, + tools: list["LLMTool"] | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> RequestData: """准备高级请求""" - return self._prepare_request( - model, api_key, messages, config, tools, tool_choice - ) - - def _prepare_request( - self, - model: "LLMModel", - api_key: str, - messages: list["LLMMessage"], - config: "LLMGenerationConfig | None" = None, - tools: list[dict[str, Any]] | None = None, - tool_choice: str | dict[str, Any] | None = None, - ) -> RequestData: - """准备 Gemini API 请求 - 支持所有高级功能""" effective_config = config if config is not None else model._generation_config endpoint = self._get_gemini_endpoint(model, effective_config) @@ -78,7 +64,8 @@ class GeminiAdapter(BaseAdapter): system_instruction_parts = [{"text": msg.content}] elif isinstance(msg.content, list): system_instruction_parts = [ - part.convert_for_api("gemini") for part in msg.content + await part.convert_for_api_async("gemini") + for part in msg.content ] continue @@ -87,7 +74,9 @@ class GeminiAdapter(BaseAdapter): current_parts.append({"text": msg.content}) elif isinstance(msg.content, list): for part_obj in msg.content: - current_parts.append(part_obj.convert_for_api("gemini")) + current_parts.append( + await part_obj.convert_for_api_async("gemini") + ) gemini_contents.append({"role": "user", "parts": current_parts}) elif msg.role == "assistant" or msg.role == "model": @@ -95,7 +84,9 @@ class GeminiAdapter(BaseAdapter): current_parts.append({"text": msg.content}) elif isinstance(msg.content, list): for part_obj in msg.content: - current_parts.append(part_obj.convert_for_api("gemini")) + current_parts.append( + await part_obj.convert_for_api_async("gemini") + ) if msg.tool_calls: import json @@ -154,16 +145,22 @@ class GeminiAdapter(BaseAdapter): all_tools_for_request = [] if tools: - for tool_item in tools: - if isinstance(tool_item, dict): - if "name" in tool_item and "description" in tool_item: - all_tools_for_request.append( - {"functionDeclarations": [tool_item]} + for tool in tools: + if tool.type == "function" and tool.function: + all_tools_for_request.append( + {"functionDeclarations": [tool.function]} + ) + elif tool.type == "mcp" and tool.mcp_session: + if callable(tool.mcp_session): + raise ValueError( + "适配器接收到未激活的 MCP 会话工厂。" + "会话工厂应该在 LLMModel.generate_response 中被激活。" ) - else: - all_tools_for_request.append(tool_item) - else: - all_tools_for_request.append(tool_item) + all_tools_for_request.append( + tool.mcp_session.to_api_tool(api_type=self.api_type) + ) + elif tool.type == "google_search": + all_tools_for_request.append({"googleSearch": {}}) if effective_config: if getattr(effective_config, "enable_grounding", False): @@ -183,11 +180,7 @@ class GeminiAdapter(BaseAdapter): logger.debug("隐式启用代码执行工具。") if all_tools_for_request: - gemini_api_tools = self._convert_tools_to_gemini_format( - all_tools_for_request - ) - if gemini_api_tools: - body["tools"] = gemini_api_tools + body["tools"] = all_tools_for_request final_tool_choice = tool_choice if final_tool_choice is None and effective_config: @@ -241,38 +234,6 @@ class GeminiAdapter(BaseAdapter): return f"/v1beta/models/{model.model_name}:generateContent" - def _convert_tools_to_gemini_format( - self, tools: list[dict[str, Any]] - ) -> list[dict[str, Any]]: - """转换工具格式为Gemini格式""" - gemini_tools = [] - - for tool in tools: - if tool.get("type") == "function": - func = tool["function"] - gemini_tool = { - "functionDeclarations": [ - { - "name": func["name"], - "description": func.get("description", ""), - "parameters": func.get("parameters", {}), - } - ] - } - gemini_tools.append(gemini_tool) - elif tool.get("type") == "code_execution": - gemini_tools.append( - {"codeExecution": {"language": tool.get("language", "python")}} - ) - elif tool.get("type") == "google_search": - gemini_tools.append({"googleSearch": {}}) - elif "googleSearch" in tool: - gemini_tools.append({"googleSearch": tool["googleSearch"]}) - elif "codeExecution" in tool: - gemini_tools.append({"codeExecution": tool["codeExecution"]}) - - return gemini_tools - def _convert_tool_choice_to_gemini( self, tool_choice_value: str | dict[str, Any] ) -> dict[str, Any]: @@ -395,10 +356,11 @@ class GeminiAdapter(BaseAdapter): for category, threshold in custom_safety_settings.items(): safety_settings.append({"category": category, "threshold": threshold}) else: + from ..config.providers import get_gemini_safety_threshold + + threshold = get_gemini_safety_threshold() for category in safety_categories: - safety_settings.append( - {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"} - ) + safety_settings.append({"category": category, "threshold": threshold}) return safety_settings if safety_settings else None diff --git a/zhenxun/services/llm/adapters/openai.py b/zhenxun/services/llm/adapters/openai.py index 046f0277..c7e73a13 100644 --- a/zhenxun/services/llm/adapters/openai.py +++ b/zhenxun/services/llm/adapters/openai.py @@ -1,12 +1,12 @@ """ OpenAI API 适配器 -支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。 +支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。 """ from typing import TYPE_CHECKING -from .base import OpenAICompatAdapter, RequestData +from .base import OpenAICompatAdapter if TYPE_CHECKING: from ..service import LLMModel @@ -21,37 +21,18 @@ class OpenAIAdapter(OpenAICompatAdapter): @property def supported_api_types(self) -> list[str]: - return ["openai", "deepseek", "general_openai_compat"] + return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"] - def get_chat_endpoint(self) -> str: + def get_chat_endpoint(self, model: "LLMModel") -> str: """返回聊天完成端点""" + if model.api_type == "ark": + return "/api/v3/chat/completions" + if model.api_type == "zhipu": + return "/api/paas/v4/chat/completions" return "/v1/chat/completions" - def get_embedding_endpoint(self) -> str: - """返回嵌入端点""" + def get_embedding_endpoint(self, model: "LLMModel") -> str: + """根据API类型返回嵌入端点""" + if model.api_type == "zhipu": + return "/v4/embeddings" return "/v1/embeddings" - - def prepare_simple_request( - self, - model: "LLMModel", - api_key: str, - prompt: str, - history: list[dict[str, str]] | None = None, - ) -> RequestData: - """准备简单文本生成请求 - OpenAI优化实现""" - url = self.get_api_url(model, self.get_chat_endpoint()) - headers = self.get_base_headers(api_key) - - messages = [] - if history: - messages.extend(history) - messages.append({"role": "user", "content": prompt}) - - body = { - "model": model.model_name, - "messages": messages, - } - - body = self.apply_config_override(model, body) - - return RequestData(url=url, headers=headers, body=body) diff --git a/zhenxun/services/llm/adapters/zhipu.py b/zhenxun/services/llm/adapters/zhipu.py deleted file mode 100644 index e5eb032f..00000000 --- a/zhenxun/services/llm/adapters/zhipu.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -智谱 AI API 适配器 - -支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。 -""" - -from typing import TYPE_CHECKING - -from .base import OpenAICompatAdapter, RequestData - -if TYPE_CHECKING: - from ..service import LLMModel - - -class ZhipuAdapter(OpenAICompatAdapter): - """智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口""" - - @property - def api_type(self) -> str: - return "zhipu" - - @property - def supported_api_types(self) -> list[str]: - return ["zhipu"] - - def get_chat_endpoint(self) -> str: - """返回智谱AI聊天完成端点""" - return "/api/paas/v4/chat/completions" - - def get_embedding_endpoint(self) -> str: - """返回智谱AI嵌入端点""" - return "/v4/embeddings" - - def prepare_simple_request( - self, - model: "LLMModel", - api_key: str, - prompt: str, - history: list[dict[str, str]] | None = None, - ) -> RequestData: - """准备简单文本生成请求 - 智谱AI优化实现""" - url = self.get_api_url(model, self.get_chat_endpoint()) - headers = self.get_base_headers(api_key) - - messages = [] - if history: - messages.extend(history) - messages.append({"role": "user", "content": prompt}) - - body = { - "model": model.model_name, - "messages": messages, - } - - body = self.apply_config_override(model, body) - - return RequestData(url=url, headers=headers, body=body) diff --git a/zhenxun/services/llm/api.py b/zhenxun/services/llm/api.py index 7aaed437..d9606f80 100644 --- a/zhenxun/services/llm/api.py +++ b/zhenxun/services/llm/api.py @@ -2,6 +2,7 @@ LLM 服务的高级 API 接口 """ +import copy from dataclasses import dataclass from enum import Enum from pathlib import Path @@ -14,6 +15,7 @@ from zhenxun.services.log import logger from .config import CommonOverrides, LLMGenerationConfig from .config.providers import get_ai_config from .manager import get_global_default_model_name, get_model_instance +from .tools import tool_registry from .types import ( EmbeddingTaskType, LLMContentPart, @@ -56,6 +58,7 @@ class AIConfig: enable_gemini_safe_mode: bool = False enable_gemini_multimodal: bool = False enable_gemini_grounding: bool = False + default_preserve_media_in_history: bool = False def __post_init__(self): """初始化后从配置中读取默认值""" @@ -81,7 +84,7 @@ class AI: """ 初始化AI服务 - Args: + 参数: config: AI 配置. history: 可选的初始对话历史. """ @@ -93,16 +96,65 @@ class AI: self.history = [] logger.info("AI session history cleared.") + def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage: + """ + 净化用于存入历史记录的消息。 + 将非文本的多模态内容部分替换为文本占位符,以避免重复处理。 + """ + if not isinstance(message.content, list): + return message + + sanitized_message = copy.deepcopy(message) + content_list = sanitized_message.content + if not isinstance(content_list, list): + return sanitized_message + + new_content_parts: list[LLMContentPart] = [] + has_multimodal_content = False + + for part in content_list: + if isinstance(part, LLMContentPart) and part.type == "text": + new_content_parts.append(part) + else: + has_multimodal_content = True + + if has_multimodal_content: + placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]" + text_part_found = False + for part in new_content_parts: + if part.type == "text": + part.text = f"{placeholder} {part.text or ''}".strip() + text_part_found = True + break + if not text_part_found: + new_content_parts.insert(0, LLMContentPart.text_part(placeholder)) + + sanitized_message.content = new_content_parts + return sanitized_message + async def chat( self, message: str | LLMMessage | list[LLMContentPart], *, model: ModelName = None, + preserve_media_in_history: bool | None = None, **kwargs: Any, ) -> str: """ 进行一次聊天对话。 此方法会自动使用和更新会话内的历史记录。 + + 参数: + message: 用户输入的消息。 + model: 本次对话要使用的模型。 + preserve_media_in_history: 是否在历史记录中保留原始多模态信息。 + - True: 保留,用于深度多轮媒体分析。 + - False: 不保留,替换为占位符,提高效率。 + - None (默认): 使用AI实例配置的默认值。 + **kwargs: 传递给模型的其他参数。 + + 返回: + str: 模型的文本响应。 """ current_message: LLMMessage if isinstance(message, str): @@ -127,7 +179,20 @@ class AI: final_messages, model, "聊天失败", kwargs ) - self.history.append(current_message) + should_preserve = ( + preserve_media_in_history + if preserve_media_in_history is not None + else self.config.default_preserve_media_in_history + ) + + if should_preserve: + logger.debug("深度分析模式:在历史记录中保留原始多模态消息。") + self.history.append(current_message) + else: + logger.debug("高效模式:净化历史记录中的多模态消息。") + sanitized_user_message = self._sanitize_message_for_history(current_message) + self.history.append(sanitized_user_message) + self.history.append(LLMMessage.assistant_text_response(response.text)) return response.text @@ -140,7 +205,18 @@ class AI: timeout: int | None = None, **kwargs: Any, ) -> dict[str, Any]: - """代码执行""" + """ + 代码执行 + + 参数: + prompt: 代码执行的提示词。 + model: 要使用的模型名称。 + timeout: 代码执行超时时间(秒)。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。 + """ resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" config = CommonOverrides.gemini_code_execution() @@ -168,7 +244,18 @@ class AI: instruction: str = "", **kwargs: Any, ) -> dict[str, Any]: - """信息搜索 - 支持多模态输入""" + """ + 信息搜索 - 支持多模态输入 + + 参数: + query: 搜索查询内容,支持文本或多模态消息。 + model: 要使用的模型名称。 + instruction: 搜索指令。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段 + """ resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" config = CommonOverrides.gemini_grounding() @@ -217,63 +304,69 @@ class AI: async def analyze( self, - message: UniMessage, + message: UniMessage | None, *, instruction: str = "", model: ModelName = None, - tools: list[dict[str, Any]] | None = None, + use_tools: list[str] | None = None, tool_config: dict[str, Any] | None = None, + activated_tools: list[LLMTool] | None = None, + history: list[LLMMessage] | None = None, **kwargs: Any, - ) -> str | LLMResponse: + ) -> LLMResponse: """ 内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。 - 这是处理复杂互动的主要方法。 + + 参数: + message: 要分析的消息内容(支持多模态)。 + instruction: 分析指令。 + model: 要使用的模型名称。 + use_tools: 要使用的工具名称列表。 + tool_config: 工具配置。 + activated_tools: 已激活的工具列表。 + history: 对话历史记录。 + **kwargs: 传递给模型的其他参数。 + + 返回: + LLMResponse: 模型的完整响应结果。 """ - content_parts = await unimsg_to_llm_parts(message) + content_parts = await unimsg_to_llm_parts(message or UniMessage()) final_messages: list[LLMMessage] = [] + if history: + final_messages.extend(history) + if instruction: - final_messages.append(LLMMessage.system(instruction)) + if not any(msg.role == "system" for msg in final_messages): + final_messages.insert(0, LLMMessage.system(instruction)) if not content_parts: - if instruction: + if instruction and not history: final_messages.append(LLMMessage.user(instruction)) - else: + elif not history: raise LLMException( "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED ) else: final_messages.append(LLMMessage.user(content_parts)) - llm_tools = None - if tools: - llm_tools = [] - for tool_dict in tools: - if isinstance(tool_dict, dict): - if "name" in tool_dict and "description" in tool_dict: - llm_tool = LLMTool( - type="function", - function={ - "name": tool_dict["name"], - "description": tool_dict["description"], - "parameters": tool_dict.get("parameters", {}), - }, - ) - llm_tools.append(llm_tool) - else: - llm_tools.append(LLMTool(**tool_dict)) - else: - llm_tools.append(tool_dict) + llm_tools: list[LLMTool] | None = activated_tools + if not llm_tools and use_tools: + try: + llm_tools = tool_registry.get_tools(use_tools) + logger.debug(f"已从注册表加载工具定义: {use_tools}") + except ValueError as e: + raise LLMException( + f"加载工具定义失败: {e}", + code=LLMErrorCode.CONFIGURATION_ERROR, + cause=e, + ) tool_choice = None if tool_config: mode = tool_config.get("mode", "auto") - if mode == "auto": - tool_choice = "auto" - elif mode == "any": - tool_choice = "any" - elif mode == "none": - tool_choice = "none" + if mode in ["auto", "any", "none"]: + tool_choice = mode response = await self._execute_generation( final_messages, @@ -284,9 +377,7 @@ class AI: tool_choice=tool_choice, ) - if response.tool_calls: - return response - return response.text + return response async def _execute_generation( self, @@ -298,7 +389,7 @@ class AI: tool_choice: str | dict[str, Any] | None = None, base_config: LLMGenerationConfig | None = None, ) -> LLMResponse: - """通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑""" + """通用的生成执行方法,封装模型获取和单次API调用""" try: resolved_model_name = self._resolve_model_name( model_name or self.config.model @@ -311,7 +402,9 @@ class AI: resolved_model_name, override_config=final_config_dict ) as model_instance: return await model_instance.generate_response( - messages, tools=llm_tools, tool_choice=tool_choice + messages, + tools=llm_tools, + tool_choice=tool_choice, ) except LLMException: raise @@ -380,7 +473,18 @@ class AI: task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, **kwargs: Any, ) -> list[list[float]]: - """生成文本嵌入向量""" + """ + 生成文本嵌入向量 + + 参数: + texts: 要生成嵌入向量的文本或文本列表。 + model: 要使用的嵌入模型名称。 + task_type: 嵌入任务类型。 + **kwargs: 传递给模型的其他参数。 + + 返回: + list[list[float]]: 文本的嵌入向量列表。 + """ if isinstance(texts, str): texts = [texts] if not texts: @@ -420,7 +524,17 @@ async def chat( model: ModelName = None, **kwargs: Any, ) -> str: - """聊天对话便捷函数""" + """ + 聊天对话便捷函数 + + 参数: + message: 用户输入的消息。 + model: 要使用的模型名称。 + **kwargs: 传递给模型的其他参数。 + + 返回: + str: 模型的文本响应。 + """ ai = AI() return await ai.chat(message, model=model, **kwargs) @@ -432,7 +546,18 @@ async def code( timeout: int | None = None, **kwargs: Any, ) -> dict[str, Any]: - """代码执行便捷函数""" + """ + 代码执行便捷函数 + + 参数: + prompt: 代码执行的提示词。 + model: 要使用的模型名称。 + timeout: 代码执行超时时间(秒)。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含执行结果的字典。 + """ ai = AI() return await ai.code(prompt, model=model, timeout=timeout, **kwargs) @@ -444,45 +569,56 @@ async def search( instruction: str = "", **kwargs: Any, ) -> dict[str, Any]: - """信息搜索便捷函数""" + """ + 信息搜索便捷函数 + + 参数: + query: 搜索查询内容。 + model: 要使用的模型名称。 + instruction: 搜索指令。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含搜索结果的字典。 + """ ai = AI() return await ai.search(query, model=model, instruction=instruction, **kwargs) async def analyze( - message: UniMessage, + message: UniMessage | None, *, instruction: str = "", model: ModelName = None, - tools: list[dict[str, Any]] | None = None, + use_tools: list[str] | None = None, tool_config: dict[str, Any] | None = None, **kwargs: Any, ) -> str | LLMResponse: - """内容分析便捷函数""" + """ + 内容分析便捷函数 + + 参数: + message: 要分析的消息内容。 + instruction: 分析指令。 + model: 要使用的模型名称。 + use_tools: 要使用的工具名称列表。 + tool_config: 工具配置。 + **kwargs: 传递给模型的其他参数。 + + 返回: + str | LLMResponse: 分析结果。 + """ ai = AI() return await ai.analyze( message, instruction=instruction, model=model, - tools=tools, + use_tools=use_tools, tool_config=tool_config, **kwargs, ) -async def analyze_with_images( - text: str, - images: list[str | Path | bytes] | str | Path | bytes, - *, - instruction: str = "", - model: ModelName = None, - **kwargs: Any, -) -> str | LLMResponse: - """图片分析便捷函数""" - message = create_multimodal_message(text=text, images=images) - return await analyze(message, instruction=instruction, model=model, **kwargs) - - async def analyze_multimodal( text: str | None = None, images: list[str | Path | bytes] | str | Path | bytes | None = None, @@ -493,7 +629,21 @@ async def analyze_multimodal( model: ModelName = None, **kwargs: Any, ) -> str | LLMResponse: - """多模态分析便捷函数""" + """ + 多模态分析便捷函数 + + 参数: + text: 文本内容。 + images: 图片文件路径、字节数据或列表。 + videos: 视频文件路径、字节数据或列表。 + audios: 音频文件路径、字节数据或列表。 + instruction: 分析指令。 + model: 要使用的模型名称。 + **kwargs: 传递给模型的其他参数。 + + 返回: + str | LLMResponse: 分析结果。 + """ message = create_multimodal_message( text=text, images=images, videos=videos, audios=audios ) @@ -510,7 +660,21 @@ async def search_multimodal( model: ModelName = None, **kwargs: Any, ) -> dict[str, Any]: - """多模态搜索便捷函数""" + """ + 多模态搜索便捷函数 + + 参数: + text: 文本内容。 + images: 图片文件路径、字节数据或列表。 + videos: 视频文件路径、字节数据或列表。 + audios: 音频文件路径、字节数据或列表。 + instruction: 搜索指令。 + model: 要使用的模型名称。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含搜索结果的字典。 + """ message = create_multimodal_message( text=text, images=images, videos=videos, audios=audios ) @@ -525,6 +689,101 @@ async def embed( task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, **kwargs: Any, ) -> list[list[float]]: - """文本嵌入便捷函数""" + """ + 文本嵌入便捷函数 + + 参数: + texts: 要生成嵌入向量的文本或文本列表。 + model: 要使用的嵌入模型名称。 + task_type: 嵌入任务类型。 + **kwargs: 传递给模型的其他参数。 + + 返回: + list[list[float]]: 文本的嵌入向量列表。 + """ ai = AI() return await ai.embed(texts, model=model, task_type=task_type, **kwargs) + + +async def pipeline_chat( + message: UniMessage | str | list[LLMContentPart], + model_chain: list[ModelName], + *, + initial_instruction: str = "", + final_instruction: str = "", + **kwargs: Any, +) -> LLMResponse: + """ + AI模型链式调用,前一个模型的输出作为下一个模型的输入。 + + 参数: + message: 初始输入消息(支持多模态) + model_chain: 模型名称列表 + initial_instruction: 第一个模型的系统指令 + final_instruction: 最后一个模型的系统指令 + **kwargs: 传递给模型实例的其他参数 + + 返回: + LLMResponse: 最后一个模型的响应结果 + """ + if not model_chain: + raise ValueError("模型链`model_chain`不能为空。") + + current_content: str | list[LLMContentPart] + if isinstance(message, str): + current_content = message + elif isinstance(message, list): + current_content = message + else: + current_content = await unimsg_to_llm_parts(message) + + final_response: LLMResponse | None = None + + for i, model_name in enumerate(model_chain): + if not model_name: + raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。") + + is_first_step = i == 0 + is_last_step = i == len(model_chain) - 1 + + messages_for_step: list[LLMMessage] = [] + instruction_for_step = "" + if is_first_step and initial_instruction: + instruction_for_step = initial_instruction + elif is_last_step and final_instruction: + instruction_for_step = final_instruction + + if instruction_for_step: + messages_for_step.append(LLMMessage.system(instruction_for_step)) + + messages_for_step.append(LLMMessage.user(current_content)) + + logger.info( + f"Pipeline Step [{i + 1}/{len(model_chain)}]: " + f"使用模型 '{model_name}' 进行处理..." + ) + try: + async with await get_model_instance(model_name, **kwargs) as model: + response = await model.generate_response(messages_for_step) + final_response = response + current_content = response.text.strip() + if not current_content and not is_last_step: + logger.warning( + f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。" + ) + break + + except Exception as e: + logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e) + raise LLMException( + f"流水线在模型 '{model_name}' 处执行失败: {e}", + code=LLMErrorCode.GENERATION_FAILED, + cause=e, + ) + + if final_response is None: + raise LLMException( + "AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED + ) + + return final_response diff --git a/zhenxun/services/llm/config/__init__.py b/zhenxun/services/llm/config/__init__.py index 09fd9599..41021a92 100644 --- a/zhenxun/services/llm/config/__init__.py +++ b/zhenxun/services/llm/config/__init__.py @@ -14,6 +14,8 @@ from .generation import ( from .presets import CommonOverrides from .providers import ( LLMConfig, + ToolConfig, + get_gemini_safety_threshold, get_llm_config, register_llm_configs, set_default_model, @@ -25,8 +27,10 @@ __all__ = [ "LLMConfig", "LLMGenerationConfig", "ModelConfigOverride", + "ToolConfig", "apply_api_specific_mappings", "create_generation_config_from_kwargs", + "get_gemini_safety_threshold", "get_llm_config", "register_llm_configs", "set_default_model", diff --git a/zhenxun/services/llm/config/generation.py b/zhenxun/services/llm/config/generation.py index a143dedd..a452ae1f 100644 --- a/zhenxun/services/llm/config/generation.py +++ b/zhenxun/services/llm/config/generation.py @@ -111,12 +111,12 @@ class LLMGenerationConfig(ModelConfigOverride): params["temperature"] = self.temperature if self.max_tokens is not None: - if api_type in ["gemini", "gemini_native"]: + if api_type == "gemini": params["maxOutputTokens"] = self.max_tokens else: params["max_tokens"] = self.max_tokens - if api_type in ["gemini", "gemini_native"]: + if api_type == "gemini": if self.top_k is not None: params["topK"] = self.top_k if self.top_p is not None: @@ -151,13 +151,13 @@ class LLMGenerationConfig(ModelConfigOverride): if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: params["response_format"] = {"type": "json_object"} logger.debug(f"为 {api_type} 启用 JSON 对象输出模式") - elif api_type in ["gemini", "gemini_native"]: + elif api_type == "gemini": params["responseMimeType"] = "application/json" if self.response_schema: params["responseSchema"] = self.response_schema logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式") - if api_type in ["gemini", "gemini_native"]: + if api_type == "gemini": if ( self.response_format != ResponseFormat.JSON and self.response_mime_type is not None @@ -214,7 +214,7 @@ def apply_api_specific_mappings( """应用API特定的参数映射""" mapped_params = params.copy() - if api_type in ["gemini", "gemini_native"]: + if api_type == "gemini": if "max_tokens" in mapped_params: mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens") if "top_k" in mapped_params: diff --git a/zhenxun/services/llm/config/presets.py b/zhenxun/services/llm/config/presets.py index 7a6023d5..aa4b6c21 100644 --- a/zhenxun/services/llm/config/presets.py +++ b/zhenxun/services/llm/config/presets.py @@ -71,14 +71,17 @@ class CommonOverrides: @staticmethod def gemini_safe() -> LLMGenerationConfig: - """Gemini 安全模式:严格安全设置""" + """Gemini 安全模式:使用配置的安全设置""" + from .providers import get_gemini_safety_threshold + + threshold = get_gemini_safety_threshold() return LLMGenerationConfig( temperature=0.5, safety_settings={ - "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", - "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE", - "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE", - "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HARASSMENT": threshold, + "HARM_CATEGORY_HATE_SPEECH": threshold, + "HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold, + "HARM_CATEGORY_DANGEROUS_CONTENT": threshold, }, ) diff --git a/zhenxun/services/llm/config/providers.py b/zhenxun/services/llm/config/providers.py index 8f4dea80..a39e32c9 100644 --- a/zhenxun/services/llm/config/providers.py +++ b/zhenxun/services/llm/config/providers.py @@ -4,15 +4,33 @@ LLM 提供商配置管理 负责注册和管理 AI 服务提供商的配置项。 """ +from functools import lru_cache +import json +import sys from typing import Any from pydantic import BaseModel, Field from zhenxun.configs.config import Config +from zhenxun.configs.path_config import DATA_PATH +from zhenxun.configs.utils import parse_as from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from ..types.models import ModelDetail, ProviderConfig + +class ToolConfig(BaseModel): + """MCP类型工具的配置定义""" + + type: str = "mcp" + name: str = Field(..., description="工具的唯一名称标识") + description: str | None = Field(None, description="工具功能的描述") + mcp_config: dict[str, Any] | BaseModel = Field( + ..., description="MCP服务器的特定配置" + ) + + AI_CONFIG_GROUP = "AI" PROVIDERS_CONFIG_KEY = "PROVIDERS" @@ -38,6 +56,9 @@ class LLMConfig(BaseModel): providers: list[ProviderConfig] = Field( default_factory=list, description="配置多个 AI 服务提供商及其模型信息" ) + mcp_tools: list[ToolConfig] = Field( + default_factory=list, description="配置可用的外部MCP工具" + ) def get_provider_by_name(self, name: str) -> ProviderConfig | None: """根据名称获取提供商配置 @@ -132,7 +153,7 @@ def get_default_providers() -> list[dict[str, Any]]: return [ { "name": "DeepSeek", - "api_key": "sk-******", + "api_key": "YOUR_ARK_API_KEY", "api_base": "https://api.deepseek.com", "api_type": "openai", "models": [ @@ -146,9 +167,30 @@ def get_default_providers() -> list[dict[str, Any]]: }, ], }, + { + "name": "ARK", + "api_key": "YOUR_ARK_API_KEY", + "api_base": "https://ark.cn-beijing.volces.com", + "api_type": "ark", + "models": [ + {"model_name": "deepseek-r1-250528"}, + {"model_name": "doubao-seed-1-6-250615"}, + {"model_name": "doubao-seed-1-6-flash-250615"}, + {"model_name": "doubao-seed-1-6-thinking-250615"}, + ], + }, + { + "name": "siliconflow", + "api_key": "YOUR_ARK_API_KEY", + "api_base": "https://api.siliconflow.cn", + "api_type": "openai", + "models": [ + {"model_name": "deepseek-ai/DeepSeek-V3"}, + ], + }, { "name": "GLM", - "api_key": "", + "api_key": "YOUR_ARK_API_KEY", "api_base": "https://open.bigmodel.cn", "api_type": "zhipu", "models": [ @@ -167,12 +209,41 @@ def get_default_providers() -> list[dict[str, Any]]: "api_type": "gemini", "models": [ {"model_name": "gemini-2.0-flash"}, - {"model_name": "gemini-2.5-flash-preview-05-20"}, + {"model_name": "gemini-2.5-flash"}, + {"model_name": "gemini-2.5-pro"}, + {"model_name": "gemini-2.5-flash-lite-preview-06-17"}, ], }, ] +def get_default_mcp_tools() -> dict[str, Any]: + """ + 获取默认的MCP工具配置,用于在文件不存在时创建。 + 包含了 baidu-map, Context7, 和 sequential-thinking. + """ + return { + "mcpServers": { + "baidu-map": { + "command": "npx", + "args": ["-y", "@baidumap/mcp-server-baidu-map"], + "env": {"BAIDU_MAP_API_KEY": ""}, + "description": "百度地图工具,提供地理编码、路线规划等功能。", + }, + "sequential-thinking": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"], + "description": "顺序思维工具,用于帮助模型进行多步骤推理。", + }, + "Context7": { + "command": "npx", + "args": ["-y", "@upstash/context7-mcp@latest"], + "description": "Upstash 提供的上下文管理和记忆工具。", + }, + } + } + + def register_llm_configs(): """注册 LLM 服务的配置项""" logger.info("注册 LLM 服务的配置项") @@ -214,6 +285,19 @@ def register_llm_configs(): help="LLM服务请求重试的基础延迟时间(秒)", type=int, ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "gemini_safety_threshold", + "BLOCK_MEDIUM_AND_ABOVE", + help=( + "Gemini 安全过滤阈值 " + "(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, " + "BLOCK_MEDIUM_AND_ABOVE: 阻止中等级别及以上, " + "BLOCK_ONLY_HIGH: 只阻止高级别, " + "BLOCK_NONE: 不阻止)" + ), + type=str, + ) Config.add_plugin_config( AI_CONFIG_GROUP, @@ -225,24 +309,111 @@ def register_llm_configs(): ) +@lru_cache(maxsize=1) def get_llm_config() -> LLMConfig: - """获取 LLM 配置实例 - - 返回: - LLMConfig: LLM 配置实例 - """ + """获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具""" ai_config = get_ai_config() + llm_data_path = DATA_PATH / "llm" + mcp_tools_path = llm_data_path / "mcp_tools.json" + + mcp_tools_list = [] + mcp_servers_dict = {} + + if not mcp_tools_path.exists(): + logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。") + llm_data_path.mkdir(parents=True, exist_ok=True) + default_mcp_config = get_default_mcp_tools() + try: + with mcp_tools_path.open("w", encoding="utf-8") as f: + json.dump(default_mcp_config, f, ensure_ascii=False, indent=2) + mcp_servers_dict = default_mcp_config.get("mcpServers", {}) + except Exception as e: + logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e) + mcp_servers_dict = {} + else: + try: + with mcp_tools_path.open("r", encoding="utf-8") as f: + mcp_data = json.load(f) + mcp_servers_dict = mcp_data.get("mcpServers", {}) + if not isinstance(mcp_servers_dict, dict): + logger.warning( + f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典," + f"将使用空配置。" + ) + mcp_servers_dict = {} + + except json.JSONDecodeError as e: + logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e) + except Exception as e: + logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e) + mcp_servers_dict = {} + + if sys.platform == "win32": + logger.debug("检测到Windows平台,正在调整MCP工具的npx命令...") + for name, config in mcp_servers_dict.items(): + if isinstance(config, dict) and config.get("command") == "npx": + logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。") + original_args = config.get("args", []) + config["command"] = "cmd" + config["args"] = ["/c", "npx", *original_args] + + if mcp_servers_dict: + mcp_tools_list = [ + { + "name": name, + "type": "mcp", + "description": config.get("description", f"MCP tool for {name}"), + "mcp_config": config, + } + for name, config in mcp_servers_dict.items() + if isinstance(config, dict) + ] + + from ..tools.registry import tool_registry + + for tool_dict in mcp_tools_list: + if isinstance(tool_dict, dict): + tool_name = tool_dict.get("name") + if not tool_name: + continue + + config_model = tool_registry.get_mcp_config_model(tool_name) + if not config_model: + logger.debug( + f"MCP工具 '{tool_name}' 没有注册其配置模型," + f"将跳过特定配置验证,直接使用原始配置字典。" + ) + continue + + mcp_config_data = tool_dict.get("mcp_config", {}) + try: + parsed_mcp_config = parse_as(config_model, mcp_config_data) + tool_dict["mcp_config"] = parsed_mcp_config + except Exception as e: + raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}") + config_data = { "default_model_name": ai_config.get("default_model_name"), "proxy": ai_config.get("proxy"), "timeout": ai_config.get("timeout", 180), "max_retries_llm": ai_config.get("max_retries_llm", 3), "retry_delay_llm": ai_config.get("retry_delay_llm", 2), - "providers": ai_config.get(PROVIDERS_CONFIG_KEY, []), + PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []), + "mcp_tools": mcp_tools_list, } - return LLMConfig(**config_data) + return parse_as(LLMConfig, config_data) + + +def get_gemini_safety_threshold() -> str: + """获取 Gemini 安全过滤阈值配置 + + 返回: + str: 安全过滤阈值 + """ + ai_config = get_ai_config() + return ai_config.get("gemini_safety_threshold", "BLOCK_MEDIUM_AND_ABOVE") def validate_llm_config() -> tuple[bool, list[str]]: @@ -326,3 +497,17 @@ def set_default_model(provider_model_name: str | None) -> bool: logger.info("默认模型已清除") return True + + +@PriorityLifecycle.on_startup(priority=10) +async def _init_llm_config_on_startup(): + """ + 在服务启动时主动调用一次 get_llm_config, + 以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。 + """ + logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...") + try: + get_llm_config() + logger.info("LLM 配置初始化完成。") + except Exception as e: + logger.error(f"LLM 配置初始化时发生错误: {e}", e=e) diff --git a/zhenxun/services/llm/core.py b/zhenxun/services/llm/core.py index ffd900cf..56591701 100644 --- a/zhenxun/services/llm/core.py +++ b/zhenxun/services/llm/core.py @@ -49,12 +49,36 @@ class LLMHttpClient: max_keepalive_connections=self.config.max_keepalive_connections, ) timeout = httpx.Timeout(self.config.timeout) + + client_kwargs = {} + if self.config.proxy: + try: + version_parts = httpx.__version__.split(".") + major = int( + "".join(c for c in version_parts[0] if c.isdigit()) + ) + minor = ( + int("".join(c for c in version_parts[1] if c.isdigit())) + if len(version_parts) > 1 + else 0 + ) + if (major, minor) >= (0, 28): + client_kwargs["proxy"] = self.config.proxy + else: + client_kwargs["proxies"] = self.config.proxy + except (ValueError, IndexError): + client_kwargs["proxies"] = self.config.proxy + logger.warning( + f"无法解析 httpx 版本 '{httpx.__version__}'," + "LLM模块将默认使用旧版 'proxies' 参数语法。" + ) + self._client = httpx.AsyncClient( headers=headers, limits=limits, timeout=timeout, - proxies=self.config.proxy, follow_redirects=True, + **client_kwargs, ) if self._client is None: raise LLMException( @@ -156,7 +180,16 @@ async def create_llm_http_client( timeout: int = 180, proxy: str | None = None, ) -> LLMHttpClient: - """创建LLM HTTP客户端""" + """ + 创建LLM HTTP客户端 + + 参数: + timeout: 超时时间(秒)。 + proxy: 代理服务器地址。 + + 返回: + LLMHttpClient: HTTP客户端实例。 + """ config = HttpClientConfig(timeout=timeout, proxy=proxy) return LLMHttpClient(config) @@ -185,7 +218,20 @@ async def with_smart_retry( provider_name: str | None = None, **kwargs: Any, ) -> Any: - """智能重试装饰器 - 支持Key轮询和错误分类""" + """ + 智能重试装饰器 - 支持Key轮询和错误分类 + + 参数: + func: 要重试的异步函数。 + *args: 传递给函数的位置参数。 + retry_config: 重试配置。 + key_store: API密钥状态存储。 + provider_name: 提供商名称。 + **kwargs: 传递给函数的关键字参数。 + + 返回: + Any: 函数执行结果。 + """ config = retry_config or RetryConfig() last_exception: Exception | None = None failed_keys: set[str] = set() @@ -294,7 +340,17 @@ class KeyStatusStore: api_keys: list[str], exclude_keys: set[str] | None = None, ) -> str | None: - """获取下一个可用的API密钥(轮询策略)""" + """ + 获取下一个可用的API密钥(轮询策略) + + 参数: + provider_name: 提供商名称。 + api_keys: API密钥列表。 + exclude_keys: 要排除的密钥集合。 + + 返回: + str | None: 可用的API密钥,如果没有可用密钥则返回None。 + """ if not api_keys: return None @@ -338,7 +394,13 @@ class KeyStatusStore: logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}") async def record_failure(self, api_key: str, status_code: int | None): - """记录失败使用""" + """ + 记录失败使用 + + 参数: + api_key: API密钥。 + status_code: HTTP状态码。 + """ key_id = self._get_key_id(api_key) async with self._lock: if status_code in [401, 403]: @@ -356,7 +418,15 @@ class KeyStatusStore: logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}") async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]: - """获取密钥使用统计""" + """ + 获取密钥使用统计 + + 参数: + api_keys: API密钥列表。 + + 返回: + dict[str, dict]: 密钥统计信息字典。 + """ stats = {} async with self._lock: for key in api_keys: diff --git a/zhenxun/services/llm/manager.py b/zhenxun/services/llm/manager.py index f23dfa50..f0e9c560 100644 --- a/zhenxun/services/llm/manager.py +++ b/zhenxun/services/llm/manager.py @@ -17,6 +17,7 @@ from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_conf from .core import http_client_manager, key_store from .service import LLMModel from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig +from .types.capabilities import get_model_capabilities DEFAULT_MODEL_NAME_KEY = "default_model_name" PROXY_KEY = "proxy" @@ -115,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None: def get_configured_providers() -> list[ProviderConfig]: - """从配置中获取Provider列表 - 简化版本""" + """从配置中获取Provider列表 - 简化和修正版本""" ai_config = get_ai_config() - providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, []) - if not isinstance(providers_raw, list): + providers = ai_config.get(PROVIDERS_CONFIG_KEY, []) + + if not isinstance(providers, list): logger.error( - f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表," + f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表," f"将使用空列表。" ) return [] valid_providers = [] - for i, item in enumerate(providers_raw): - if not isinstance(item, dict): - logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。") - continue - - try: - if not item.get("name"): - logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。") - continue - - if not item.get("api_key"): - logger.warning( - f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。" - ) - continue - - if "api_type" not in item or not item["api_type"]: - provider_name = item.get("name", "").lower() - if "glm" in provider_name or "zhipu" in provider_name: - item["api_type"] = "zhipu" - elif "gemini" in provider_name or "google" in provider_name: - item["api_type"] = "gemini" - else: - item["api_type"] = "openai" - - if "api_base" not in item or not item["api_base"]: - api_type = item.get("api_type") - if api_type: - default_api_base = get_default_api_base_for_type(api_type) - if default_api_base: - item["api_base"] = default_api_base - - if "models" not in item: - item["models"] = [{"model_name": item.get("name", "default")}] - - provider_conf = ProviderConfig(**item) - valid_providers.append(provider_conf) - - except Exception as e: - logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。") + for i, item in enumerate(providers): + if isinstance(item, ProviderConfig): + if not item.api_base: + default_api_base = get_default_api_base_for_type(item.api_type) + if default_api_base: + item.api_base = default_api_base + valid_providers.append(item) + else: + logger.warning( + f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象," + f"已跳过。实际类型: {type(item)}" + ) return valid_providers @@ -173,14 +147,15 @@ def get_configured_providers() -> list[ProviderConfig]: def find_model_config( provider_name: str, model_name: str ) -> tuple[ProviderConfig, ModelDetail] | None: - """在配置中查找指定的 Provider 和 ModelDetail + """ + 在配置中查找指定的 Provider 和 ModelDetail - Args: + 参数: provider_name: 提供商名称 model_name: 模型名称 - Returns: - 找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None + 返回: + tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None """ providers = get_configured_providers() @@ -221,10 +196,11 @@ def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> lis def list_model_identifiers() -> dict[str, list[str]]: - """列出所有模型的可用标识符 + """ + 列出所有模型的可用标识符 - Returns: - 字典,键为模型的完整名称,值为该模型的所有可用标识符列表 + 返回: + dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表 """ providers = get_configured_providers() result = {} @@ -248,7 +224,16 @@ async def get_model_instance( provider_model_name: str | None = None, override_config: dict[str, Any] | None = None, ) -> LLMModel: - """根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)""" + """ + 根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本) + + 参数: + provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。 + override_config: 覆盖配置字典。 + + 返回: + LLMModel: 模型实例。 + """ cache_key = _make_cache_key(provider_model_name, override_config) cached_model = _get_cached_model(cache_key) if cached_model: @@ -292,6 +277,10 @@ async def get_model_instance( provider_config_found, model_detail_found = config_tuple_found + capabilities = get_model_capabilities(model_detail_found.model_name) + + model_detail_found.is_embedding_model = capabilities.is_embedding_model + ai_config = get_ai_config() global_proxy_setting = ai_config.get(PROXY_KEY) default_timeout = ( @@ -322,6 +311,7 @@ async def get_model_instance( model_detail=model_detail_found, key_store=key_store, http_client=shared_http_client, + capabilities=capabilities, ) if override_config: @@ -357,7 +347,15 @@ def get_global_default_model_name() -> str | None: def set_global_default_model_name(provider_model_name: str | None) -> bool: - """设置全局默认模型名称""" + """ + 设置全局默认模型名称 + + 参数: + provider_model_name: 模型名称,格式为 'ProviderName/ModelName'。 + + 返回: + bool: 设置是否成功。 + """ if provider_model_name: prov_name, mod_name = parse_provider_model_string(provider_model_name) if not prov_name or not mod_name or not find_model_config(prov_name, mod_name): @@ -377,7 +375,12 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool: async def get_key_usage_stats() -> dict[str, Any]: - """获取所有Provider的Key使用统计""" + """ + 获取所有Provider的Key使用统计 + + 返回: + dict[str, Any]: 包含所有Provider的Key使用统计信息。 + """ providers = get_configured_providers() stats = {} @@ -400,7 +403,16 @@ async def get_key_usage_stats() -> dict[str, Any]: async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool: - """重置指定Provider的Key状态""" + """ + 重置指定Provider的Key状态 + + 参数: + provider_name: 提供商名称。 + api_key: 要重置的特定API密钥,如果为None则重置所有密钥。 + + 返回: + bool: 重置是否成功。 + """ providers = get_configured_providers() target_provider = None diff --git a/zhenxun/services/llm/service.py b/zhenxun/services/llm/service.py index d054ca9b..587b15cc 100644 --- a/zhenxun/services/llm/service.py +++ b/zhenxun/services/llm/service.py @@ -6,11 +6,13 @@ LLM 模型实现类 from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable +from contextlib import AsyncExitStack import json from typing import Any from zhenxun.services.log import logger +from .adapters.base import RequestData from .config import LLMGenerationConfig from .config.providers import get_ai_config from .core import ( @@ -30,6 +32,8 @@ from .types import ( ModelDetail, ProviderConfig, ) +from .types.capabilities import ModelCapabilities, ModelModality +from .utils import _sanitize_request_body_for_logging class LLMModelBase(ABC): @@ -42,7 +46,17 @@ class LLMModelBase(ABC): history: list[dict[str, str]] | None = None, **kwargs: Any, ) -> str: - """生成文本""" + """ + 生成文本 + + 参数: + prompt: 输入提示词。 + history: 对话历史记录。 + **kwargs: 其他参数。 + + 返回: + str: 生成的文本。 + """ pass @abstractmethod @@ -54,7 +68,19 @@ class LLMModelBase(ABC): tool_choice: str | dict[str, Any] | None = None, **kwargs: Any, ) -> LLMResponse: - """生成高级响应""" + """ + 生成高级响应 + + 参数: + messages: 消息列表。 + config: 生成配置。 + tools: 工具列表。 + tool_choice: 工具选择策略。 + **kwargs: 其他参数。 + + 返回: + LLMResponse: 模型响应。 + """ pass @abstractmethod @@ -64,7 +90,17 @@ class LLMModelBase(ABC): task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, **kwargs: Any, ) -> list[list[float]]: - """生成文本嵌入向量""" + """ + 生成文本嵌入向量 + + 参数: + texts: 文本列表。 + task_type: 嵌入任务类型。 + **kwargs: 其他参数。 + + 返回: + list[list[float]]: 嵌入向量列表。 + """ pass @@ -77,12 +113,14 @@ class LLMModel(LLMModelBase): model_detail: ModelDetail, key_store: KeyStatusStore, http_client: LLMHttpClient, + capabilities: ModelCapabilities, config_override: LLMGenerationConfig | None = None, ): self.provider_config = provider_config self.model_detail = model_detail self.key_store = key_store self.http_client: LLMHttpClient = http_client + self.capabilities = capabilities self._generation_config = config_override self.provider_name = provider_config.name @@ -99,6 +137,34 @@ class LLMModel(LLMModelBase): self._is_closed = False + def can_process_images(self) -> bool: + """检查模型是否支持图片作为输入。""" + return ModelModality.IMAGE in self.capabilities.input_modalities + + def can_process_video(self) -> bool: + """检查模型是否支持视频作为输入。""" + return ModelModality.VIDEO in self.capabilities.input_modalities + + def can_process_audio(self) -> bool: + """检查模型是否支持音频作为输入。""" + return ModelModality.AUDIO in self.capabilities.input_modalities + + def can_generate_images(self) -> bool: + """检查模型是否支持生成图片。""" + return ModelModality.IMAGE in self.capabilities.output_modalities + + def can_generate_audio(self) -> bool: + """检查模型是否支持生成音频 (TTS)。""" + return ModelModality.AUDIO in self.capabilities.output_modalities + + def can_use_tools(self) -> bool: + """检查模型是否支持工具调用/函数调用。""" + return self.capabilities.supports_tool_calling + + def is_embedding_model(self) -> bool: + """检查这是否是一个嵌入模型。""" + return self.capabilities.is_embedding_model + async def _get_http_client(self) -> LLMHttpClient: """获取HTTP客户端""" if self.http_client.is_closed: @@ -135,24 +201,54 @@ class LLMModel(LLMModelBase): return selected_key - async def _execute_embedding_request( + async def _perform_api_call( self, - adapter, - texts: list[str], - task_type: EmbeddingTaskType | str, - http_client: LLMHttpClient, + prepare_request_func: Callable[[str], Awaitable["RequestData"]], + parse_response_func: Callable[[dict[str, Any]], Any], + http_client: "LLMHttpClient", failed_keys: set[str] | None = None, - ) -> list[list[float]]: - """执行单次嵌入请求 - 供重试机制调用""" + log_context: str = "API", + ) -> Any: + """ + 执行API调用的通用核心方法。 + + 该方法封装了以下通用逻辑: + 1. 选择API密钥。 + 2. 准备和记录请求。 + 3. 发送HTTP POST请求。 + 4. 处理HTTP错误和API特定错误。 + 5. 记录密钥使用状态。 + 6. 解析成功的响应。 + + 参数: + prepare_request_func: 准备请求的函数。 + parse_response_func: 解析响应的函数。 + http_client: HTTP客户端。 + failed_keys: 失败的密钥集合。 + log_context: 日志上下文。 + + 返回: + Any: 解析后的响应数据。 + """ api_key = await self._select_api_key(failed_keys) try: - request_data = adapter.prepare_embedding_request( - model=self, - api_key=api_key, - texts=texts, - task_type=task_type, + request_data = await prepare_request_func(api_key) + + logger.info( + f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} " + f"[{log_context}]" ) + logger.debug(f"📡 请求URL: {request_data.url}") + masked_key = ( + f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}" + ) + logger.debug(f"🔑 API密钥: {masked_key}") + logger.debug(f"📋 请求头: {dict(request_data.headers)}") + + sanitized_body = _sanitize_request_body_for_logging(request_data.body) + request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2) + logger.debug(f"📦 请求体: {request_body_str}") http_response = await http_client.post( request_data.url, @@ -160,121 +256,16 @@ class LLMModel(LLMModelBase): json=request_data.body, ) - if http_response.status_code != 200: - error_text = http_response.text - logger.error( - f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}" - ) - await self.key_store.record_failure(api_key, http_response.status_code) - - error_code = LLMErrorCode.API_REQUEST_FAILED - if http_response.status_code in [401, 403]: - error_code = LLMErrorCode.API_KEY_INVALID - elif http_response.status_code == 429: - error_code = LLMErrorCode.API_RATE_LIMITED - - raise LLMException( - f"HTTP嵌入请求失败: {http_response.status_code}", - code=error_code, - details={ - "status_code": http_response.status_code, - "response": error_text, - "api_key": api_key, - }, - ) - - try: - response_json = http_response.json() - adapter.validate_embedding_response(response_json) - embeddings = adapter.parse_embedding_response(response_json) - except Exception as e: - logger.error(f"解析嵌入响应失败: {e}", e=e) - await self.key_store.record_failure(api_key, None) - if isinstance(e, LLMException): - raise - else: - raise LLMException( - f"解析API嵌入响应失败: {e}", - code=LLMErrorCode.RESPONSE_PARSE_ERROR, - cause=e, - ) - - await self.key_store.record_success(api_key) - return embeddings - - except LLMException: - raise - except Exception as e: - logger.error(f"生成嵌入时发生未预期错误: {e}", e=e) - await self.key_store.record_failure(api_key, None) - raise LLMException( - f"生成嵌入失败: {e}", - code=LLMErrorCode.EMBEDDING_FAILED, - cause=e, - ) - - async def _execute_with_smart_retry( - self, - adapter, - messages: list[LLMMessage], - config: LLMGenerationConfig | None, - tools_dict: list[dict[str, Any]] | None, - tool_choice: str | dict[str, Any] | None, - http_client: LLMHttpClient, - ): - """智能重试机制 - 使用统一的重试装饰器""" - ai_config = get_ai_config() - max_retries = ai_config.get("max_retries_llm", 3) - retry_delay = ai_config.get("retry_delay_llm", 2) - retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay) - - return await with_smart_retry( - self._execute_single_request, - adapter, - messages, - config, - tools_dict, - tool_choice, - http_client, - retry_config=retry_config, - key_store=self.key_store, - provider_name=self.provider_name, - ) - - async def _execute_single_request( - self, - adapter, - messages: list[LLMMessage], - config: LLMGenerationConfig | None, - tools_dict: list[dict[str, Any]] | None, - tool_choice: str | dict[str, Any] | None, - http_client: LLMHttpClient, - failed_keys: set[str] | None = None, - ) -> LLMResponse: - """执行单次请求 - 供重试机制调用,直接返回 LLMResponse""" - api_key = await self._select_api_key(failed_keys) - - try: - request_data = adapter.prepare_advanced_request( - model=self, - api_key=api_key, - messages=messages, - config=config, - tools=tools_dict, - tool_choice=tool_choice, - ) - - http_response = await http_client.post( - request_data.url, - headers=request_data.headers, - json=request_data.body, - ) + logger.debug(f"📥 响应状态码: {http_response.status_code}") + logger.debug(f"📄 响应头: {dict(http_response.headers)}") if http_response.status_code != 200: error_text = http_response.text logger.error( - f"HTTP请求失败: {http_response.status_code} - {error_text}" + f"❌ HTTP请求失败: {http_response.status_code} - {error_text} " + f"[{log_context}]" ) + logger.debug(f"💥 完整错误响应: {error_text}") await self.key_store.record_failure(api_key, http_response.status_code) @@ -299,69 +290,165 @@ class LLMModel(LLMModelBase): try: response_json = http_response.json() - response_data = adapter.parse_response( - model=self, - response_json=response_json, - is_advanced=True, - ) - - from .types.models import LLMToolCall - - response_tool_calls = [] - if response_data.tool_calls: - for tc_data in response_data.tool_calls: - if isinstance(tc_data, LLMToolCall): - response_tool_calls.append(tc_data) - elif isinstance(tc_data, dict): - try: - response_tool_calls.append(LLMToolCall(**tc_data)) - except Exception as e: - logger.warning( - f"无法将工具调用数据转换为LLMToolCall: {tc_data}, " - f"error: {e}" - ) - else: - logger.warning(f"工具调用数据格式未知: {tc_data}") - - llm_response = LLMResponse( - text=response_data.text, - usage_info=response_data.usage_info, - raw_response=response_data.raw_response, - tool_calls=response_tool_calls if response_tool_calls else None, - code_executions=response_data.code_executions, - grounding_metadata=response_data.grounding_metadata, - cache_info=response_data.cache_info, + response_json_str = json.dumps( + response_json, ensure_ascii=False, indent=2 ) + logger.debug(f"📋 响应JSON: {response_json_str}") + parsed_data = parse_response_func(response_json) except Exception as e: - logger.error(f"解析响应失败: {e}", e=e) + logger.error(f"解析 {log_context} 响应失败: {e}", e=e) await self.key_store.record_failure(api_key, None) - if isinstance(e, LLMException): raise else: raise LLMException( - f"解析API响应失败: {e}", + f"解析API {log_context} 响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e, ) await self.key_store.record_success(api_key) - - return llm_response + logger.debug(f"✅ API密钥使用成功: {masked_key}") + logger.info(f"🎯 LLM响应解析完成 [{log_context}]") + return parsed_data except LLMException: raise except Exception as e: - logger.error(f"生成响应时发生未预期错误: {e}", e=e) + error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}" + logger.error(error_log_msg, e=e) await self.key_store.record_failure(api_key, None) - raise LLMException( - f"生成响应失败: {e}", - code=LLMErrorCode.GENERATION_FAILED, + error_log_msg, + code=LLMErrorCode.GENERATION_FAILED + if log_context == "Generation" + else LLMErrorCode.EMBEDDING_FAILED, cause=e, ) + async def _execute_embedding_request( + self, + adapter, + texts: list[str], + task_type: EmbeddingTaskType | str, + http_client: LLMHttpClient, + failed_keys: set[str] | None = None, + ) -> list[list[float]]: + """执行单次嵌入请求 - 供重试机制调用""" + + async def prepare_request(api_key: str) -> RequestData: + return adapter.prepare_embedding_request( + model=self, + api_key=api_key, + texts=texts, + task_type=task_type, + ) + + def parse_response(response_json: dict[str, Any]) -> list[list[float]]: + adapter.validate_embedding_response(response_json) + return adapter.parse_embedding_response(response_json) + + return await self._perform_api_call( + prepare_request_func=prepare_request, + parse_response_func=parse_response, + http_client=http_client, + failed_keys=failed_keys, + log_context="Embedding", + ) + + async def _execute_with_smart_retry( + self, + adapter, + messages: list[LLMMessage], + config: LLMGenerationConfig | None, + tools: list[LLMTool] | None, + tool_choice: str | dict[str, Any] | None, + http_client: LLMHttpClient, + ): + """智能重试机制 - 使用统一的重试装饰器""" + ai_config = get_ai_config() + max_retries = ai_config.get("max_retries_llm", 3) + retry_delay = ai_config.get("retry_delay_llm", 2) + retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay) + + return await with_smart_retry( + self._execute_single_request, + adapter, + messages, + config, + tools, + tool_choice, + http_client, + retry_config=retry_config, + key_store=self.key_store, + provider_name=self.provider_name, + ) + + async def _execute_single_request( + self, + adapter, + messages: list[LLMMessage], + config: LLMGenerationConfig | None, + tools: list[LLMTool] | None, + tool_choice: str | dict[str, Any] | None, + http_client: LLMHttpClient, + failed_keys: set[str] | None = None, + ) -> LLMResponse: + """执行单次请求 - 供重试机制调用,直接返回 LLMResponse""" + + async def prepare_request(api_key: str) -> RequestData: + return await adapter.prepare_advanced_request( + model=self, + api_key=api_key, + messages=messages, + config=config, + tools=tools, + tool_choice=tool_choice, + ) + + def parse_response(response_json: dict[str, Any]) -> LLMResponse: + response_data = adapter.parse_response( + model=self, + response_json=response_json, + is_advanced=True, + ) + from .types.models import LLMToolCall + + response_tool_calls = [] + if response_data.tool_calls: + for tc_data in response_data.tool_calls: + if isinstance(tc_data, LLMToolCall): + response_tool_calls.append(tc_data) + elif isinstance(tc_data, dict): + try: + response_tool_calls.append(LLMToolCall(**tc_data)) + except Exception as e: + logger.warning( + f"无法将工具调用数据转换为LLMToolCall: {tc_data}, " + f"error: {e}" + ) + else: + logger.warning(f"工具调用数据格式未知: {tc_data}") + + return LLMResponse( + text=response_data.text, + usage_info=response_data.usage_info, + raw_response=response_data.raw_response, + tool_calls=response_tool_calls if response_tool_calls else None, + code_executions=response_data.code_executions, + grounding_metadata=response_data.grounding_metadata, + cache_info=response_data.cache_info, + ) + + return await self._perform_api_call( + prepare_request_func=prepare_request, + parse_response_func=parse_response, + http_client=http_client, + failed_keys=failed_keys, + log_context="Generation", + ) + async def close(self): """ 标记模型实例的当前使用周期结束。 @@ -400,7 +487,17 @@ class LLMModel(LLMModelBase): history: list[dict[str, str]] | None = None, **kwargs: Any, ) -> str: - """生成文本 - 通过 generate_response 实现""" + """ + 生成文本 - 通过 generate_response 实现 + + 参数: + prompt: 输入提示词。 + history: 对话历史记录。 + **kwargs: 其他参数。 + + 返回: + str: 生成的文本。 + """ self._check_not_closed() messages: list[LLMMessage] = [] @@ -439,11 +536,21 @@ class LLMModel(LLMModelBase): config: LLMGenerationConfig | None = None, tools: list[LLMTool] | None = None, tool_choice: str | dict[str, Any] | None = None, - tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, - max_tool_iterations: int = 5, **kwargs: Any, ) -> LLMResponse: - """生成高级响应 - 实现完整的工具调用循环""" + """ + 生成高级响应 + + 参数: + messages: 消息列表。 + config: 生成配置。 + tools: 工具列表。 + tool_choice: 工具选择策略。 + **kwargs: 其他参数。 + + 返回: + LLMResponse: 模型响应。 + """ self._check_not_closed() from .adapters import get_adapter_for_api_type @@ -468,109 +575,43 @@ class LLMModel(LLMModelBase): merged_dict.update(config.to_dict()) final_request_config = LLMGenerationConfig(**merged_dict) - tools_dict: list[dict[str, Any]] | None = None - if tools: - tools_dict = [] - for tool in tools: - if hasattr(tool, "model_dump"): - model_dump_func = getattr(tool, "model_dump") - tools_dict.append(model_dump_func(exclude_none=True)) - elif isinstance(tool, dict): - tools_dict.append(tool) - else: - try: - tools_dict.append(dict(tool)) - except (TypeError, ValueError): - logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。") - http_client = await self._get_http_client() - current_messages = list(messages) - for iteration in range(max_tool_iterations): - logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}") + async with AsyncExitStack() as stack: + activated_tools = [] + if tools: + for tool in tools: + if tool.type == "mcp" and callable(tool.mcp_session): + func_obj = getattr(tool.mcp_session, "func", None) + tool_name = ( + getattr(func_obj, "__name__", "unknown") + if func_obj + else "unknown" + ) + logger.debug(f"正在激活 MCP 工具会话: {tool_name}") + + active_session = await stack.enter_async_context( + tool.mcp_session() + ) + + activated_tools.append( + LLMTool.from_mcp_session( + session=active_session, annotations=tool.annotations + ) + ) + else: + activated_tools.append(tool) llm_response = await self._execute_with_smart_retry( adapter, - current_messages, + messages, final_request_config, - tools_dict if iteration == 0 else None, - tool_choice if iteration == 0 else None, + activated_tools if activated_tools else None, + tool_choice, http_client, ) - response_tool_calls = llm_response.tool_calls or [] - - if not response_tool_calls or not tool_executor: - logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。") - return llm_response - - logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。") - - assistant_message_content = llm_response.text if llm_response.text else "" - current_messages.append( - LLMMessage.assistant_tool_calls( - content=assistant_message_content, tool_calls=response_tool_calls - ) - ) - - tool_response_messages: list[LLMMessage] = [] - for tool_call in response_tool_calls: - tool_name = tool_call.function.name - try: - tool_args_dict = json.loads(tool_call.function.arguments) - logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}") - - tool_result = await tool_executor(tool_name, tool_args_dict) - logger.debug( - f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..." - ) - - tool_response_messages.append( - LLMMessage.tool_response( - tool_call_id=tool_call.id, - function_name=tool_name, - result=tool_result, - ) - ) - except json.JSONDecodeError as e: - logger.error( - f"工具 '{tool_name}' 参数JSON解析失败: " - f"{tool_call.function.arguments}, 错误: {e}" - ) - tool_response_messages.append( - LLMMessage.tool_response( - tool_call_id=tool_call.id, - function_name=tool_name, - result={ - "error": "Argument JSON parsing failed", - "details": str(e), - }, - ) - ) - except Exception as e: - logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e) - tool_response_messages.append( - LLMMessage.tool_response( - tool_call_id=tool_call.id, - function_name=tool_name, - result={ - "error": "Tool execution failed", - "details": str(e), - }, - ) - ) - - current_messages.extend(tool_response_messages) - - logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。") - raise LLMException( - "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。", - code=LLMErrorCode.GENERATION_FAILED, - details={ - "iterations": max_tool_iterations, - "last_messages": current_messages[-2:], - }, - ) + return llm_response async def generate_embeddings( self, @@ -578,7 +619,17 @@ class LLMModel(LLMModelBase): task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, **kwargs: Any, ) -> list[list[float]]: - """生成文本嵌入向量""" + """ + 生成文本嵌入向量 + + 参数: + texts: 文本列表。 + task_type: 嵌入任务类型。 + **kwargs: 其他参数。 + + 返回: + list[list[float]]: 嵌入向量列表。 + """ self._check_not_closed() if not texts: return [] diff --git a/zhenxun/services/llm/tools/__init__.py b/zhenxun/services/llm/tools/__init__.py new file mode 100644 index 00000000..3c62ed2a --- /dev/null +++ b/zhenxun/services/llm/tools/__init__.py @@ -0,0 +1,7 @@ +""" +工具模块导出 +""" + +from .registry import tool_registry + +__all__ = ["tool_registry"] diff --git a/zhenxun/services/llm/tools/registry.py b/zhenxun/services/llm/tools/registry.py new file mode 100644 index 00000000..daa0c796 --- /dev/null +++ b/zhenxun/services/llm/tools/registry.py @@ -0,0 +1,181 @@ +""" +工具注册表 + +负责加载、管理和实例化来自配置的工具。 +""" + +from collections.abc import Callable +from contextlib import AbstractAsyncContextManager +from functools import partial +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from zhenxun.services.log import logger + +from ..types import LLMTool + +if TYPE_CHECKING: + from ..config.providers import ToolConfig + from ..types.protocols import MCPCompatible + + +class ToolRegistry: + """工具注册表,用于管理和实例化配置的工具。""" + + def __init__(self): + self._function_tools: dict[str, LLMTool] = {} + + self._mcp_config_models: dict[str, type[BaseModel]] = {} + if TYPE_CHECKING: + self._mcp_factories: dict[ + str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]] + ] = {} + else: + self._mcp_factories: dict[str, Callable] = {} + + self._tool_configs: dict[str, "ToolConfig"] | None = None + self._tool_cache: dict[str, "LLMTool"] = {} + + def _load_configs_if_needed(self): + """如果尚未加载,则从主配置中加载MCP工具定义。""" + if self._tool_configs is None: + logger.debug("首次访问,正在加载MCP工具配置...") + from ..config.providers import get_llm_config + + llm_config = get_llm_config() + self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools} + logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。") + + def function_tool( + self, + name: str, + description: str, + parameters: dict, + required: list[str] | None = None, + ): + """ + 装饰器:在代码中注册一个简单的、无状态的函数工具。 + + 参数: + name: 工具的唯一名称。 + description: 工具功能的描述。 + parameters: OpenAPI格式的函数参数schema的properties部分。 + required: 必需的参数列表。 + """ + + def decorator(func: Callable): + if name in self._function_tools or name in self._mcp_factories: + logger.warning(f"正在覆盖已注册的工具: {name}") + + tool_definition = LLMTool.create( + name=name, + description=description, + parameters=parameters, + required=required, + ) + self._function_tools[name] = tool_definition + logger.info(f"已在代码中注册函数工具: '{name}'") + tool_definition.annotations = tool_definition.annotations or {} + tool_definition.annotations["executable"] = func + return func + + return decorator + + def mcp_tool(self, name: str, config_model: type[BaseModel]): + """ + 装饰器:注册一个MCP工具及其配置模型。 + + 参数: + name: 工具的唯一名称,必须与配置文件中的名称匹配。 + config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。 + """ + + def decorator(factory_func: Callable): + if name in self._mcp_factories: + logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}") + self._mcp_factories[name] = factory_func + self._mcp_config_models[name] = config_model + logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})") + return factory_func + + return decorator + + def get_mcp_config_model(self, name: str) -> type[BaseModel] | None: + """根据名称获取MCP工具的配置模型。""" + return self._mcp_config_models.get(name) + + def register_mcp_factory( + self, + name: str, + factory: Callable, + ): + """ + 在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。 + + 参数: + name: 工具的唯一名称,必须与配置文件中的名称匹配。 + factory: 一个返回异步生成器的可调用对象(会话工厂)。 + """ + if name in self._mcp_factories: + logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}") + self._mcp_factories[name] = factory + logger.info(f"已注册 MCP 会话工厂: '{name}'") + + def get_tool(self, name: str) -> "LLMTool": + """ + 根据名称获取一个 LLMTool 定义。 + 对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂, + 而不是一个已激活的会话。 + """ + logger.debug(f"🔍 请求获取工具定义: {name}") + + if name in self._tool_cache: + logger.debug(f"✅ 从缓存中获取工具定义: {name}") + return self._tool_cache[name] + + if name in self._function_tools: + logger.debug(f"🛠️ 获取函数工具定义: {name}") + tool = self._function_tools[name] + self._tool_cache[name] = tool + return tool + + self._load_configs_if_needed() + if self._tool_configs is None or name not in self._tool_configs: + known_tools = list(self._function_tools.keys()) + ( + list(self._tool_configs.keys()) if self._tool_configs else [] + ) + logger.error(f"❌ 未找到名为 '{name}' 的工具定义") + logger.debug(f"📋 可用工具定义列表: {known_tools}") + raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}") + + config = self._tool_configs[name] + tool: "LLMTool" + + if name not in self._mcp_factories: + logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数") + available_factories = list(self._mcp_factories.keys()) + logger.debug(f"📋 已注册的MCP工厂: {available_factories}") + raise ValueError( + f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。" + "请使用 `@tool_registry.mcp_tool` 装饰器进行注册。" + ) + + logger.info(f"🔧 创建MCP工具定义: {name}") + factory = self._mcp_factories[name] + typed_mcp_config = config.mcp_config + logger.debug(f"📋 MCP工具配置: {typed_mcp_config}") + + configured_factory = partial(factory, config=typed_mcp_config) + tool = LLMTool.from_mcp_session(session=configured_factory) + + self._tool_cache[name] = tool + logger.debug(f"💾 MCP工具定义已缓存: {name}") + return tool + + def get_tools(self, names: list[str]) -> list["LLMTool"]: + """根据名称列表获取多个 LLMTool 实例。""" + return [self.get_tool(name) for name in names] + + +tool_registry = ToolRegistry() diff --git a/zhenxun/services/llm/types/__init__.py b/zhenxun/services/llm/types/__init__.py index ebae4185..f01bc291 100644 --- a/zhenxun/services/llm/types/__init__.py +++ b/zhenxun/services/llm/types/__init__.py @@ -4,6 +4,7 @@ LLM 类型定义模块 统一导出所有核心类型、协议和异常定义。 """ +from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities from .content import ( LLMContentPart, LLMMessage, @@ -26,6 +27,7 @@ from .models import ( ToolMetadata, UsageInfo, ) +from .protocols import MCPCompatible __all__ = [ "EmbeddingTaskType", @@ -41,8 +43,11 @@ __all__ = [ "LLMTool", "LLMToolCall", "LLMToolFunction", + "MCPCompatible", + "ModelCapabilities", "ModelDetail", "ModelInfo", + "ModelModality", "ModelName", "ModelProvider", "ProviderConfig", @@ -50,5 +55,6 @@ __all__ = [ "ToolCategory", "ToolMetadata", "UsageInfo", + "get_model_capabilities", "get_user_friendly_error_message", ] diff --git a/zhenxun/services/llm/types/capabilities.py b/zhenxun/services/llm/types/capabilities.py new file mode 100644 index 00000000..fc25cf7e --- /dev/null +++ b/zhenxun/services/llm/types/capabilities.py @@ -0,0 +1,128 @@ +""" +LLM 模型能力定义模块 + +定义模型的输入输出模态、工具调用支持等核心能力。 +""" + +from enum import Enum +import fnmatch + +from pydantic import BaseModel, Field + + +class ModelModality(str, Enum): + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + EMBEDDING = "embedding" + + +class ModelCapabilities(BaseModel): + """定义一个模型的核心、稳定能力。""" + + input_modalities: set[ModelModality] = Field(default={ModelModality.TEXT}) + output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT}) + supports_tool_calling: bool = False + is_embedding_model: bool = False + + +STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities( + input_modalities={ModelModality.TEXT}, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, +) + +GEMINI_CAPABILITIES = ModelCapabilities( + input_modalities={ + ModelModality.TEXT, + ModelModality.IMAGE, + ModelModality.AUDIO, + ModelModality.VIDEO, + }, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, +) + +DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities( + input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO}, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, +) + + +MODEL_ALIAS_MAPPING: dict[str, str] = { + "deepseek-v3*": "deepseek-chat", + "deepseek-ai/DeepSeek-V3": "deepseek-chat", + "deepseek-r1*": "deepseek-reasoner", +} + + +MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = { + "gemini-*-tts": ModelCapabilities( + input_modalities={ModelModality.TEXT}, + output_modalities={ModelModality.AUDIO}, + ), + "gemini-*-native-audio-*": ModelCapabilities( + input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO}, + output_modalities={ModelModality.TEXT, ModelModality.AUDIO}, + supports_tool_calling=True, + ), + "gemini-2.0-flash-preview-image-generation": ModelCapabilities( + input_modalities={ + ModelModality.TEXT, + ModelModality.IMAGE, + ModelModality.AUDIO, + ModelModality.VIDEO, + }, + output_modalities={ModelModality.TEXT, ModelModality.IMAGE}, + supports_tool_calling=True, + ), + "gemini-embedding-exp": ModelCapabilities( + input_modalities={ModelModality.TEXT}, + output_modalities={ModelModality.EMBEDDING}, + is_embedding_model=True, + ), + "gemini-2.5-pro*": GEMINI_CAPABILITIES, + "gemini-1.5-pro*": GEMINI_CAPABILITIES, + "gemini-2.5-flash*": GEMINI_CAPABILITIES, + "gemini-2.0-flash*": GEMINI_CAPABILITIES, + "gemini-1.5-flash*": GEMINI_CAPABILITIES, + "GLM-4V-Flash": ModelCapabilities( + input_modalities={ModelModality.TEXT, ModelModality.IMAGE}, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, + ), + "GLM-4V-Plus*": ModelCapabilities( + input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO}, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, + ), + "glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES, + "glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES, + "doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES, + "doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES, + "deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES, + "deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES, +} + + +def get_model_capabilities(model_name: str) -> ModelCapabilities: + """ + 从注册表获取模型能力,支持别名映射和通配符匹配。 + 查找顺序: 1. 标准化名称 -> 2. 精确匹配 -> 3. 通配符匹配 -> 4. 默认值 + """ + canonical_name = model_name + for alias_pattern, c_name in MODEL_ALIAS_MAPPING.items(): + if fnmatch.fnmatch(model_name, alias_pattern): + canonical_name = c_name + break + + if canonical_name in MODEL_CAPABILITIES_REGISTRY: + return MODEL_CAPABILITIES_REGISTRY[canonical_name] + + for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items(): + if "*" in pattern and fnmatch.fnmatch(model_name, pattern): + return capabilities + + return ModelCapabilities() diff --git a/zhenxun/services/llm/types/content.py b/zhenxun/services/llm/types/content.py index 54887bc3..9dc10821 100644 --- a/zhenxun/services/llm/types/content.py +++ b/zhenxun/services/llm/types/content.py @@ -225,8 +225,10 @@ class LLMContentPart(BaseModel): logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...") return None - def convert_for_api(self, api_type: str) -> dict[str, Any]: + async def convert_for_api_async(self, api_type: str) -> dict[str, Any]: """根据API类型转换多模态内容格式""" + from zhenxun.utils.http_utils import AsyncHttpx + if self.type == "text": if api_type == "openai": return {"type": "text", "text": self.text} @@ -248,20 +250,23 @@ class LLMContentPart(BaseModel): mime_type, data = base64_info return {"inlineData": {"mimeType": mime_type, "data": data}} else: - # 如果无法解析 Base64 数据,抛出异常 raise ValueError( f"无法解析Base64图像数据: {self.image_source[:50]}..." ) - else: - logger.warning( - f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}" - ) - return { - "inlineData": { - "mimeType": "image/jpeg", - "data": self.image_source, + elif self.is_image_url(): + logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}") + try: + image_bytes = await AsyncHttpx.get_content(self.image_source) + mime_type = self.mime_type or "image/jpeg" + base64_data = base64.b64encode(image_bytes).decode("utf-8") + return { + "inlineData": {"mimeType": mime_type, "data": base64_data} } - } + except Exception as e: + logger.error(f"下载或编码URL图片失败: {e}", e=e) + raise ValueError(f"无法处理图片URL: {e}") + else: + raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...") else: return {"type": "image_url", "image_url": {"url": self.image_source}} diff --git a/zhenxun/services/llm/types/models.py b/zhenxun/services/llm/types/models.py index c5f541bc..ce574d53 100644 --- a/zhenxun/services/llm/types/models.py +++ b/zhenxun/services/llm/types/models.py @@ -4,13 +4,25 @@ LLM 数据模型定义 包含模型信息、配置、工具定义和响应数据的模型类。 """ +from collections.abc import Callable +from contextlib import AbstractAsyncContextManager from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field from .enums import ModelProvider, ToolCategory +if TYPE_CHECKING: + from .protocols import MCPCompatible + + MCPSessionType = ( + MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None + ) +else: + MCPCompatible = object + MCPSessionType = Any + ModelName = str | None @@ -98,10 +110,21 @@ class LLMToolCall(BaseModel): class LLMTool(BaseModel): """LLM 工具定义(支持 MCP 风格)""" + model_config = {"arbitrary_types_allowed": True} + type: str = "function" - function: dict[str, Any] + function: dict[str, Any] | None = None + mcp_session: MCPSessionType = None annotations: dict[str, Any] | None = Field(default=None, description="工具注解") + def model_post_init(self, /, __context: Any) -> None: + """验证工具定义的有效性""" + _ = __context + if self.type == "function" and self.function is None: + raise ValueError("函数类型的工具必须包含 'function' 字段。") + if self.type == "mcp" and self.mcp_session is None: + raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。") + @classmethod def create( cls, @@ -111,7 +134,7 @@ class LLMTool(BaseModel): required: list[str] | None = None, annotations: dict[str, Any] | None = None, ) -> "LLMTool": - """创建工具""" + """创建函数工具""" function_def = { "name": name, "description": description, @@ -123,6 +146,15 @@ class LLMTool(BaseModel): } return cls(type="function", function=function_def, annotations=annotations) + @classmethod + def from_mcp_session( + cls, + session: Any, + annotations: dict[str, Any] | None = None, + ) -> "LLMTool": + """从 MCP 会话创建工具""" + return cls(type="mcp", mcp_session=session, annotations=annotations) + class LLMCodeExecution(BaseModel): """代码执行结果""" diff --git a/zhenxun/services/llm/types/protocols.py b/zhenxun/services/llm/types/protocols.py new file mode 100644 index 00000000..1ab1ace2 --- /dev/null +++ b/zhenxun/services/llm/types/protocols.py @@ -0,0 +1,24 @@ +""" +LLM 模块的协议定义 +""" + +from typing import Any, Protocol + + +class MCPCompatible(Protocol): + """ + 一个协议,定义了与LLM模块兼容的MCP会话对象应具备的行为。 + 任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible。 + """ + + def to_api_tool(self, api_type: str) -> dict[str, Any]: + """ + 将此MCP会话转换为特定LLM提供商API所需的工具格式。 + + 参数: + api_type: 目标API的类型 (例如 'gemini', 'openai')。 + + 返回: + dict[str, Any]: 一个字典,代表可以在API请求中使用的工具定义。 + """ + ... diff --git a/zhenxun/services/llm/utils.py b/zhenxun/services/llm/utils.py index 3610df27..d5e9177d 100644 --- a/zhenxun/services/llm/utils.py +++ b/zhenxun/services/llm/utils.py @@ -3,8 +3,10 @@ LLM 模块的工具和转换函数 """ import base64 +import copy from pathlib import Path +from nonebot.adapters import Message as PlatformMessage from nonebot_plugin_alconna.uniseg import ( At, File, @@ -17,6 +19,7 @@ from nonebot_plugin_alconna.uniseg import ( ) from zhenxun.services.log import logger +from zhenxun.utils.http_utils import AsyncHttpx from .types import LLMContentPart @@ -25,6 +28,12 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]: """ 将 UniMessage 实例转换为一个 LLMContentPart 列表。 这是处理多模态输入的核心转换逻辑。 + + 参数: + message: 要转换的UniMessage实例。 + + 返回: + list[LLMContentPart]: 转换后的内容部分列表。 """ parts: list[LLMContentPart] = [] for seg in message: @@ -51,14 +60,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]: if seg.path: part = await LLMContentPart.from_path(seg.path) elif seg.url: - logger.warning( - f"直接使用 URL 的 {type(seg).__name__} 段," - f"API 可能不支持: {seg.url}" - ) - part = LLMContentPart.text_part( - f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]" - ) - elif hasattr(seg, "raw") and seg.raw: + try: + logger.debug(f"检测到媒体URL,开始下载: {seg.url}") + media_bytes = await AsyncHttpx.get_content(seg.url) + + new_seg = copy.copy(seg) + new_seg.raw = media_bytes + seg = new_seg + logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes") + except Exception as e: + logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}") + part = LLMContentPart.text_part( + f"[下载媒体失败: {seg.name or seg.url}]" + ) + + if part: + parts.append(part) + continue + + if hasattr(seg, "raw") and seg.raw: mime_type = getattr(seg, "mimetype", None) if isinstance(seg.raw, bytes): b64_data = base64.b64encode(seg.raw).decode("utf-8") @@ -127,50 +147,19 @@ def create_multimodal_message( audio_mimetypes: list[str] | str | None = None, ) -> UniMessage: """ - 创建多模态消息的便捷函数,方便第三方调用。 + 创建多模态消息的便捷函数 - Args: + 参数: text: 文本内容 images: 图片数据,支持路径、字节数据或URL - videos: 视频数据,支持路径、字节数据或URL - audios: 音频数据,支持路径、字节数据或URL - image_mimetypes: 图片MIME类型,当images为bytes时需要指定 - video_mimetypes: 视频MIME类型,当videos为bytes时需要指定 - audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定 + videos: 视频数据 + audios: 音频数据 + image_mimetypes: 图片MIME类型,bytes数据时需要指定 + video_mimetypes: 视频MIME类型,bytes数据时需要指定 + audio_mimetypes: 音频MIME类型,bytes数据时需要指定 - Returns: + 返回: UniMessage: 构建好的多模态消息 - - Examples: - # 纯文本 - msg = create_multimodal_message("请分析这段文字") - - # 文本 + 单张图片(路径) - msg = create_multimodal_message("分析图片", images="/path/to/image.jpg") - - # 文本 + 多张图片 - msg = create_multimodal_message( - "比较图片", images=["/path/1.jpg", "/path/2.jpg"] - ) - - # 文本 + 图片字节数据 - msg = create_multimodal_message( - "分析", images=image_data, image_mimetypes="image/jpeg" - ) - - # 文本 + 视频 - msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4") - - # 文本 + 音频 - msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav") - - # 混合多模态 - msg = create_multimodal_message( - "分析这些媒体文件", - images="/path/to/image.jpg", - videos="/path/to/video.mp4", - audios="/path/to/audio.wav" - ) """ message = UniMessage() @@ -196,7 +185,7 @@ def _add_media_to_message( media_class: type, default_mimetype: str, ) -> None: - """添加媒体文件到 UniMessage 的辅助函数""" + """添加媒体文件到 UniMessage""" if not isinstance(media_items, list): media_items = [media_items] @@ -216,3 +205,80 @@ def _add_media_to_message( elif isinstance(item, bytes): mimetype = mime_list[i] if i < len(mime_list) else default_mimetype message.append(media_class(raw=item, mimetype=mimetype)) + + +def message_to_unimessage(message: PlatformMessage) -> UniMessage: + """ + 将平台特定的 Message 对象转换为通用的 UniMessage。 + 主要用于处理引用消息等未被自动转换的消息体。 + + 参数: + message: 平台特定的Message对象。 + + 返回: + UniMessage: 转换后的通用消息对象。 + """ + uni_segments = [] + for seg in message: + if seg.type == "text": + uni_segments.append(Text(seg.data.get("text", ""))) + elif seg.type == "image": + uni_segments.append(Image(url=seg.data.get("url"))) + elif seg.type == "record": + uni_segments.append(Voice(url=seg.data.get("url"))) + elif seg.type == "video": + uni_segments.append(Video(url=seg.data.get("url"))) + elif seg.type == "at": + uni_segments.append(At("user", str(seg.data.get("qq", "")))) + else: + logger.debug(f"跳过不支持的平台消息段类型: {seg.type}") + + return UniMessage(uni_segments) + + +def _sanitize_request_body_for_logging(body: dict) -> dict: + """ + 净化请求体用于日志记录,移除大数据字段并添加摘要信息 + + 参数: + body: 原始请求体字典。 + + 返回: + dict: 净化后的请求体字典。 + """ + try: + sanitized_body = copy.deepcopy(body) + + if "contents" in sanitized_body and isinstance( + sanitized_body["contents"], list + ): + for content_item in sanitized_body["contents"]: + if "parts" in content_item and isinstance(content_item["parts"], list): + media_summary = [] + new_parts = [] + for part in content_item["parts"]: + if "inlineData" in part and isinstance( + part["inlineData"], dict + ): + data = part["inlineData"].get("data") + if isinstance(data, str): + mime_type = part["inlineData"].get( + "mimeType", "unknown" + ) + media_summary.append(f"{mime_type} ({len(data)} chars)") + continue + new_parts.append(part) + + if media_summary: + summary_text = ( + f"[多模态内容: {len(media_summary)}个文件 - " + f"{', '.join(media_summary)}]" + ) + new_parts.insert(0, {"text": summary_text}) + + content_item["parts"] = new_parts + + return sanitized_body + except Exception as e: + logger.warning(f"日志净化失败: {e},将记录原始请求体。") + return body