diff --git a/zhenxun/services/llm/README.md b/zhenxun/services/llm/README.md
new file mode 100644
index 00000000..263be1e6
--- /dev/null
+++ b/zhenxun/services/llm/README.md
@@ -0,0 +1,731 @@

# Zhenxun LLM 服务模块

## 📑 目录

- [📖 概述](#-概述)
- [🌟 主要特性](#-主要特性)
- [🚀 快速开始](#-快速开始)
- [📚 API 参考](#-api-参考)
- [⚙️ 配置](#️-配置)
- [🔧 高级功能](#-高级功能)
- [🏗️ 架构设计](#️-架构设计)
- [🔌 支持的提供商](#-支持的提供商)
- [🎯 使用场景](#-使用场景)
- [📊 性能优化](#-性能优化)
- [🛠️ 故障排除](#️-故障排除)
- [❓ 常见问题](#-常见问题)
- [📝 示例项目](#-示例项目)
- [🤝 贡献](#-贡献)
- [📄 许可证](#-许可证)

## 📖 概述

Zhenxun LLM 服务模块是一个现代化的 AI 服务框架,提供统一的接口来访问多个大语言模型提供商。该模块采用模块化设计,支持异步操作、智能重试、Key 轮询和负载均衡等高级功能。

### 🌟 主要特性

- **多提供商支持**: OpenAI、Gemini、智谱AI、DeepSeek等
- **统一接口**: 简洁一致的API设计
- **智能Key轮询**: 自动负载均衡和故障转移
- **异步高性能**: 基于asyncio的并发处理
- **模型缓存**: 智能缓存机制提升性能
- **工具调用**: 支持Function Calling
- **嵌入向量**: 文本向量化支持
- **错误处理**: 完善的异常处理和重试机制
- **多模态支持**: 文本、图像、音频、视频处理
- **代码执行**: Gemini代码执行功能
- **搜索增强**: Google搜索集成

## 🚀 快速开始

### 基本使用

```python
from zhenxun.services.llm import chat, code, search, analyze

# 简单聊天
response = await chat("你好,请介绍一下自己")
print(response)

# 代码执行
result = await code("计算斐波那契数列的前10项")
print(result["text"])
print(result["code_executions"])

# 搜索功能
search_result = await search("Python异步编程最佳实践")
print(search_result["text"])

# 多模态分析
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
message = UniMessage([
    Text("分析这张图片"),
    Image(path="image.jpg")
])
analysis = await analyze(message, model="Gemini/gemini-2.0-flash")
print(analysis)
```

### 使用AI类

```python
from zhenxun.services.llm import AI, AIConfig, analyze_with_images

# 创建AI实例
ai = AI(AIConfig(model="OpenAI/gpt-4"))

# 聊天对话
response = await ai.chat("解释量子计算的基本原理")

# 多模态分析
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text

multimodal_msg = UniMessage([
    Text("这张图片显示了什么?"),
    Image(path="image.jpg")
])
result = await ai.analyze(multimodal_msg)

# 便捷的多模态函数
result = await analyze_with_images(
    "分析这张图片",
    images="image.jpg",
    model="Gemini/gemini-2.0-flash"
)
```

## 📚 API 参考

### 快速函数

#### `chat(message, *, model=None, **kwargs) -> str`
简单聊天对话

**参数:**
- `message`: 消息内容(字符串、LLMMessage或内容部分列表)
- `model`: 模型名称(可选)
- `**kwargs`: 额外配置参数

#### `code(prompt, *, model=None, timeout=None, **kwargs) -> dict`
代码执行功能

**返回:**
```python
{
    "text": "执行结果说明",
    "code_executions": [{"code": "...", "output": "..."}],
    "success": True
}
```

#### `search(query, *, model=None, instruction="", **kwargs) -> dict`
搜索增强生成

**返回:**
```python
{
    "text": "搜索结果和分析",
    "sources": [...],   # 信息来源(grounding attributions)
    "queries": [...],   # 实际执行的搜索查询
    "success": True
}
```

#### `analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse`
高级分析功能,支持多模态输入和工具调用

#### `analyze_with_images(text, images, *, instruction="", model=None, **kwargs) -> str | LLMResponse`
图片分析便捷函数

#### `analyze_multimodal(text=None, images=None, videos=None, audios=None, *, instruction="", model=None, **kwargs) -> str | LLMResponse`
多模态分析便捷函数

#### `embed(texts, *, model=None, task_type="RETRIEVAL_DOCUMENT", **kwargs) -> list[list[float]]`
文本嵌入向量

### AI类方法

#### `AI.chat(message, *, model=None, **kwargs) -> str`
聊天对话方法,支持简单多模态输入

#### `AI.analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse`
高级分析方法,接收UniMessage进行多模态分析和工具调用

### 模型管理

```python
from zhenxun.services.llm import (
    get_model_instance,
    list_available_models,
    set_global_default_model_name,
    clear_model_cache
)

# 获取模型实例
model = await get_model_instance("OpenAI/gpt-4o")

# 列出可用模型
models = list_available_models()

# 设置默认模型
set_global_default_model_name("Gemini/gemini-2.0-flash")

# 清理缓存
clear_model_cache()
```
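`get_model_instance` 返回的模型实例还可以作为异步上下文管理器使用,并通过 `override_config` 传入生成参数字典——这正是 `api.py` 内部采用的调用方式。下面是一个最小示意(模型名与参数仅作举例):

```python
from zhenxun.services.llm import LLMMessage, get_model_instance

async def ask(prompt: str) -> str:
    # override_config 接受一个生成参数字典(键与 LLMGenerationConfig 字段对应)
    async with await get_model_instance(
        "Gemini/gemini-2.0-flash", override_config={"temperature": 0.3}
    ) as model:
        response = await model.generate_response([LLMMessage.user(prompt)])
        return response.text
```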
## ⚙️ 配置

### 预设配置

```python
from zhenxun.services.llm import CommonOverrides

# 创意模式
creative_config = CommonOverrides.creative()

# 精确模式
precise_config = CommonOverrides.precise()

# Gemini特殊功能
json_config = CommonOverrides.gemini_json()
thinking_config = CommonOverrides.gemini_thinking()
code_exec_config = CommonOverrides.gemini_code_execution()
grounding_config = CommonOverrides.gemini_grounding()
```

### 自定义配置

```python
from zhenxun.services.llm import LLMGenerationConfig, chat

config = LLMGenerationConfig(
    temperature=0.7,
    max_tokens=2048,
    top_p=0.9,
    frequency_penalty=0.1,
    presence_penalty=0.1,
    stop=["END", "STOP"],
    response_mime_type="application/json",
    enable_code_execution=True,
    enable_grounding=True
)

# 快捷函数通过 **kwargs 接收配置覆盖项
response = await chat("你的问题", **config.to_dict())
```

## 🔧 高级功能

### 工具调用 (Function Calling)

```python
from zhenxun.services.llm import LLMMessage, LLMTool, get_model_instance

# 定义工具
tools = [
    LLMTool(
        type="function",
        function={
            "name": "get_weather",
            "description": "获取天气信息",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "城市名称"}
                },
                "required": ["city"]
            },
        },
    )
]

# 工具执行器
async def tool_executor(tool_name: str, args: dict) -> str:
    if tool_name == "get_weather":
        return f"{args['city']}今天晴天,25°C"
    return "未知工具"

# 使用工具
model = await get_model_instance("OpenAI/gpt-4")
response = await model.generate_response(
    messages=[LLMMessage.user("北京天气如何?")],
    tools=tools,
    tool_executor=tool_executor
)
```

### 多模态处理

```python
from zhenxun.services.llm import create_multimodal_message, analyze_multimodal, analyze_with_images

# 方法1:使用便捷函数
result = await analyze_multimodal(
    text="分析这些媒体文件",
    images="image.jpg",
    audios="audio.mp3",
    model="Gemini/gemini-2.0-flash"
)

# 方法2:使用create_multimodal_message
message = create_multimodal_message(
    text="分析这张图片和音频",
    images="image.jpg",
    audios="audio.mp3"
)
result = await analyze(message)

# 方法3:图片分析专用函数
result = await analyze_with_images(
    "这张图片显示了什么?",
    images=["image1.jpg", "image2.jpg"]
)
```

## 🛠️ 故障排除

### 常见错误

1. **配置错误**: 检查API密钥和模型配置
2. **网络问题**: 检查代理设置和网络连接
3. **模型不可用**: 使用 `list_available_models()` 检查可用模型
4. **超时错误**: 调整timeout参数或使用更快的模型

### 调试技巧

```python
from zhenxun.services.llm import get_cache_stats

# 查看缓存状态
stats = get_cache_stats()
print(f"缓存命中率: {stats['hit_rate']}")

# 日志级别可在 bot 的日志配置中调整为 DEBUG,以获取更详细的请求/响应日志
```

## ❓ 常见问题

### Q: 如何处理多模态输入?

**A:** 有多种方式处理多模态输入:
```python
# 方法1:使用便捷函数
result = await analyze_with_images("分析这张图片", images="image.jpg")

# 方法2:使用analyze函数
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
message = UniMessage([Text("分析这张图片"), Image(path="image.jpg")])
result = await analyze(message)

# 方法3:使用create_multimodal_message
from zhenxun.services.llm import create_multimodal_message
message = create_multimodal_message(text="分析这张图片", images="image.jpg")
result = await analyze(message)
```
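`create_multimodal_message` 的 `images` / `videos` / `audios` 参数同样接受本地 `Path` 或原始 `bytes`(见上文 `analyze_multimodal` 的签名)。下面是传入字节数据的示意:

```python
from pathlib import Path
from zhenxun.services.llm import analyze, create_multimodal_message

image_bytes = Path("image.jpg").read_bytes()
message = create_multimodal_message(text="分析这张图片", images=image_bytes)
result = await analyze(message)
```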
### Q: 如何自定义工具调用?

**A:** 使用analyze函数的tools参数:
```python
# 定义工具
tools = [{
    "name": "calculator",
    "description": "计算数学表达式",
    "parameters": {
        "type": "object",
        "properties": {
            "expression": {"type": "string", "description": "数学表达式"}
        },
        "required": ["expression"]
    }
}]

# 使用工具
from nonebot_plugin_alconna.uniseg import UniMessage, Text
message = UniMessage([Text("计算 2+3*4")])
response = await analyze(message, tools=tools, tool_config={"mode": "auto"})

# 如果返回LLMResponse,说明有工具调用
if hasattr(response, 'tool_calls'):
    for tool_call in response.tool_calls:
        print(f"调用工具: {tool_call.function.name}")
        print(f"参数: {tool_call.function.arguments}")
```

### Q: 如何确保输出格式?

**A:** 使用结构化输出:
```python
# JSON格式输出
config = CommonOverrides.gemini_json()

# 自定义Schema
schema = {
    "type": "object",
    "properties": {
        "answer": {"type": "string"},
        "confidence": {"type": "number"}
    }
}
config = CommonOverrides.gemini_structured(schema)
```

## 📝 示例项目

### 完整示例

#### 1. 智能客服机器人

```python
from zhenxun.services.llm import LLMMessage, get_model_instance

class CustomerService:
    def __init__(self):
        self.sessions: dict[str, list[LLMMessage]] = {}

    async def handle_query(self, user_id: str, query: str) -> str:
        # 获取或创建会话历史,首条为系统提示
        history = self.sessions.setdefault(
            user_id,
            [LLMMessage.system("你是一个专业的客服助手,请友好、准确地回答用户问题。")],
        )

        # 添加用户问题
        history.append(LLMMessage.user(query))

        # 多轮对话:向模型实例传入完整消息列表(保留最近20条)
        async with await get_model_instance("OpenAI/gpt-4o") as model:
            response = await model.generate_response(history[-20:])

        # 保存回复到历史
        reply = response.text
        history.append(LLMMessage(role="assistant", content=reply))

        return reply
```

#### 2. 文档智能问答

```python
from zhenxun.services.llm import chat, embed
import numpy as np

class DocumentQA:
    def __init__(self):
        self.documents: list[str] = []
        self.embeddings: list[list[float]] = []

    async def add_document(self, text: str):
        """添加文档到知识库"""
        self.documents.append(text)

        # 生成嵌入向量
        embedding = await embed([text])
        self.embeddings.extend(embedding)

    async def query(self, question: str, top_k: int = 3) -> str:
        """查询文档并生成答案"""
        if not self.documents:
            return "知识库为空,请先添加文档。"

        # 生成问题的嵌入向量
        question_embedding = await embed([question])

        # 计算相似度并找到最相关的文档(此处以点积作为相似度)
        similarities = [
            np.dot(question_embedding[0], doc_embedding)
            for doc_embedding in self.embeddings
        ]

        # 获取最相关的文档
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        relevant_docs = [self.documents[i] for i in top_indices]

        # 构建上下文
        context = "\n\n".join(relevant_docs)
        prompt = f"""
基于以下文档内容回答问题:

文档内容:
{context}

问题:{question}

请基于文档内容给出准确的答案,如果文档中没有相关信息,请说明。
"""

        # chat 直接返回字符串回复
        return await chat(prompt)
```
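`DocumentQA` 的一个最小使用示意(文档内容为虚构示例,检索质量取决于所配置的嵌入模型):

```python
qa = DocumentQA()
await qa.add_document("Zhenxun LLM 模块提供 chat、analyze、embed 等统一接口。")
await qa.add_document("嵌入向量可通过 embed() 函数批量生成。")

answer = await qa.query("这个模块提供哪些接口?")
print(answer)
```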
#### 3. 代码审查助手

```python
from zhenxun.services.llm import chat
import os

class CodeReviewer:
    async def review_file(self, file_path: str) -> dict:
        """审查代码文件"""
        if not os.path.exists(file_path):
            return {"error": "文件不存在"}

        with open(file_path, 'r', encoding='utf-8') as f:
            code_content = f.read()

        prompt = f"""
请审查以下代码,提供详细的反馈:

文件:{file_path}
代码:
```
{code_content}
```

请从以下方面进行审查:
1. 代码质量和可读性
2. 潜在的bug和安全问题
3. 性能优化建议
4. 最佳实践建议
5. 代码风格问题

请以JSON格式返回结果。
"""

        review = await chat(prompt, model="DeepSeek/deepseek-coder")

        return {
            "file": file_path,
            "review": review,
            "success": True
        }

    async def suggest_improvements(self, source_code: str, language: str = "python") -> str:
        """建议代码改进"""
        prompt = f"""
请改进以下{language}代码,使其更加高效、可读和符合最佳实践:

原代码:
```{language}
{source_code}
```

请提供改进后的代码和说明。
"""

        return await chat(prompt, model="DeepSeek/deepseek-coder")
```

## 🏗️ 架构设计

### 模块结构

```
zhenxun/services/llm/
├── __init__.py          # 包入口,导入和暴露公共API
├── api.py               # 高级API接口(AI类、便捷函数)
├── core.py              # 核心基础设施(HTTP客户端、重试逻辑、KeyStore)
├── service.py           # LLM模型实现类
├── utils.py             # 工具和转换函数
├── manager.py           # 模型管理和缓存
├── adapters/            # 适配器模块
│   ├── __init__.py      # 适配器包入口
│   ├── base.py          # 基础适配器
│   ├── factory.py       # 适配器工厂
│   ├── openai.py        # OpenAI适配器
│   ├── gemini.py        # Gemini适配器
│   └── zhipu.py         # 智谱AI适配器
├── config/              # 配置模块
│   ├── __init__.py      # 配置包入口
│   ├── generation.py    # 生成配置
│   ├── presets.py       # 预设配置
│   └── providers.py     # 提供商配置
└── types/               # 类型定义
    ├── __init__.py      # 类型包入口
    ├── content.py       # 内容类型
    ├── enums.py         # 枚举定义
    ├── exceptions.py    # 异常定义
    └── models.py        # 数据模型
```

### 模块职责

- **`__init__.py`**: 纯粹的包入口,只负责导入和暴露公共API
- **`api.py`**: 高级API接口,包含AI类和所有便捷函数
- **`core.py`**: 核心基础设施,包含HTTP客户端管理、重试逻辑和KeyStore
- **`service.py`**: LLM模型实现类,专注于模型逻辑
- **`utils.py`**: 工具和转换函数,如多模态消息处理
- **`manager.py`**: 模型管理和缓存机制
- **`adapters/`**: 各大提供商的适配器模块,负责与不同API的交互
  - `base.py`: 定义适配器的基础接口
  - `factory.py`: 适配器工厂,用于动态加载和实例化适配器
  - `openai.py`: OpenAI API适配器
  - `gemini.py`: Google Gemini API适配器
  - `zhipu.py`: 智谱AI API适配器
- **`config/`**: 配置管理模块
  - `generation.py`: 生成配置
  - `presets.py`: 预设配置
  - `providers.py`: 提供商配置
- **`types/`**: 类型定义模块
  - `content.py`: 内容类型定义
  - `enums.py`: 枚举定义
  - `exceptions.py`: 异常定义
  - `models.py`: 数据模型定义

## 🔌 支持的提供商

### OpenAI 兼容

- **OpenAI**: GPT-4o, GPT-3.5-turbo等
- **DeepSeek**: deepseek-chat, deepseek-reasoner等
- **其他OpenAI兼容API**: 支持自定义端点

```python
# OpenAI
await chat("Hello", model="OpenAI/gpt-4o")

# DeepSeek
await chat("写代码", model="DeepSeek/deepseek-reasoner")
```

### Google Gemini

- **Gemini 系列**: gemini-2.5-flash-preview-05-20、gemini-2.0-flash 等
- **特殊功能**: 代码执行、搜索增强、思考模式

```python
# 基础使用
await chat("你好", model="Gemini/gemini-2.0-flash")

# 代码执行
await code("计算质数", model="Gemini/gemini-2.0-flash")

# 搜索增强
await search("最新AI发展", model="Gemini/gemini-2.5-flash-preview-05-20")
```

### 智谱AI

- **GLM系列**: glm-4, glm-4v等
- **支持功能**: 文本生成、多模态理解

```python
await chat("介绍北京", model="Zhipu/glm-4")
```

## 🎯 使用场景

### 1. 聊天机器人

```python
from zhenxun.services.llm import LLMMessage, get_model_instance

class ChatBot:
    def __init__(self):
        self.history: list[LLMMessage] = []

    async def chat(self, user_input: str) -> str:
        # 添加历史记录
        self.history.append(LLMMessage.user(user_input))

        # 携带最近10条消息作为上下文
        async with await get_model_instance("Gemini/gemini-2.0-flash") as model:
            response = await model.generate_response(self.history[-10:])

        reply = response.text
        self.history.append(LLMMessage(role="assistant", content=reply))
        return reply
```
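`ChatBot` 的使用示意——第二轮调用会携带第一轮的上下文:

```python
bot = ChatBot()
print(await bot.chat("你好,请记住我叫小明"))
print(await bot.chat("我叫什么名字?"))  # 历史消息会一并发送给模型
```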
### 2. 代码助手

```python
async def code_assistant(task: str) -> dict:
    """代码生成和执行助手"""
    result = await code(
        f"请帮我{task},并执行代码验证结果",
        model="Gemini/gemini-2.0-flash",
        timeout=60
    )

    return {
        "explanation": result["text"],
        "code_blocks": result["code_executions"],
        "success": result["success"]
    }

# 使用示例
result = await code_assistant("实现快速排序算法")
```

### 3. 文档分析

```python
from zhenxun.services.llm import analyze_with_images

async def analyze_document(image_path: str, question: str) -> str:
    """分析文档图片并回答问题"""
    result = await analyze_with_images(
        f"请分析这个文档并回答:{question}",
        images=image_path,
        model="Gemini/gemini-2.0-flash"
    )
    return result
```

### 4. 智能搜索

```python
async def smart_search(query: str) -> dict:
    """智能搜索和总结"""
    result = await search(
        query,
        model="Gemini/gemini-2.0-flash",
        instruction="请提供准确、最新的信息,并注明信息来源"
    )

    return {
        "summary": result["text"],
        "sources": result.get("sources", []),
        "queries": result.get("queries", [])
    }
```

## 🔧 配置管理

### 动态配置

```python
from zhenxun.services.llm import list_available_models, set_global_default_model_name

# 运行时更改默认模型
set_global_default_model_name("OpenAI/gpt-4")

# 检查可用模型
models = list_available_models()
for model in models:
    print(f"{model.provider}/{model.name} - {model.description}")
```

diff --git a/zhenxun/services/llm/__init__.py b/zhenxun/services/llm/__init__.py
new file mode 100644
index 00000000..ff09ef7a
--- /dev/null
+++ b/zhenxun/services/llm/__init__.py
@@ -0,0 +1,96 @@
"""
LLM 服务模块 - 公共 API 入口

提供统一的 AI 服务调用接口、核心类型定义和模型管理功能。
"""

from .api import (
    AI,
    AIConfig,
    TaskType,
    analyze,
    analyze_multimodal,
    analyze_with_images,
    chat,
    code,
    embed,
    search,
    search_multimodal,
)
from .config import (
    CommonOverrides,
    LLMGenerationConfig,
    register_llm_configs,
)

register_llm_configs()
from .api import ModelName
from .manager import (
    clear_model_cache,
    get_cache_stats,
    get_global_default_model_name,
    get_model_instance,
    list_available_models,
    list_embedding_models,
    list_model_identifiers,
    set_global_default_model_name,
)
from .types import (
    EmbeddingTaskType,
    LLMContentPart,
    LLMErrorCode,
    LLMException,
    LLMMessage,
    LLMResponse,
    LLMTool,
    ModelDetail,
    ModelInfo,
    ModelProvider,
    ResponseFormat,
    ToolCategory,
    ToolMetadata,
    UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts

__all__ = [
    "AI",
    "AIConfig",
    "CommonOverrides",
    "EmbeddingTaskType",
    "LLMContentPart",
    "LLMErrorCode",
    "LLMException",
    "LLMGenerationConfig",
    "LLMMessage",
    "LLMResponse",
    "LLMTool",
    "ModelDetail",
    "ModelInfo",
    "ModelName",
    "ModelProvider",
    "ResponseFormat",
    "TaskType",
    "ToolCategory",
    "ToolMetadata",
    "UsageInfo",
    "analyze",
    "analyze_multimodal",
    "analyze_with_images",
    "chat",
    "clear_model_cache",
    "code",
    "create_multimodal_message",
    "embed",
    "get_cache_stats",
    "get_global_default_model_name",
    "get_model_instance",
    "list_available_models",
    "list_embedding_models",
    "list_model_identifiers",
    "register_llm_configs",
    "search",
    "search_multimodal",
    "set_global_default_model_name",
    "unimsg_to_llm_parts",
]

diff --git a/zhenxun/services/llm/adapters/__init__.py b/zhenxun/services/llm/adapters/__init__.py
new file mode 100644
index 00000000..93ed9d31
--- /dev/null
+++ b/zhenxun/services/llm/adapters/__init__.py
@@ -0,0 +1,26 @@
"""
LLM 适配器模块

提供不同LLM服务商的API适配器实现,统一接口调用方式。
"""

from .base import 
BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData +from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter +from .gemini import GeminiAdapter +from .openai import OpenAIAdapter +from .zhipu import ZhipuAdapter + +LLMAdapterFactory.initialize() + +__all__ = [ + "BaseAdapter", + "GeminiAdapter", + "LLMAdapterFactory", + "OpenAIAdapter", + "OpenAICompatAdapter", + "RequestData", + "ResponseData", + "ZhipuAdapter", + "get_adapter_for_api_type", + "register_adapter", +] diff --git a/zhenxun/services/llm/adapters/base.py b/zhenxun/services/llm/adapters/base.py new file mode 100644 index 00000000..f94c22cd --- /dev/null +++ b/zhenxun/services/llm/adapters/base.py @@ -0,0 +1,477 @@ +""" +LLM 适配器基类和通用数据结构 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel + +from zhenxun.services.log import logger + +from ..types.exceptions import LLMErrorCode, LLMException +from ..types.models import LLMToolCall + +if TYPE_CHECKING: + from ..config.generation import LLMGenerationConfig + from ..service import LLMModel + from ..types.content import LLMMessage + from ..types.enums import EmbeddingTaskType + + +class RequestData(BaseModel): + """请求数据封装""" + + url: str + headers: dict[str, str] + body: dict[str, Any] + + +class ResponseData(BaseModel): + """响应数据封装 - 支持所有高级功能""" + + text: str + usage_info: dict[str, Any] | None = None + raw_response: dict[str, Any] | None = None + tool_calls: list[LLMToolCall] | None = None + code_executions: list[Any] | None = None + grounding_metadata: Any | None = None + cache_info: Any | None = None + + code_execution_results: list[dict[str, Any]] | None = None + search_results: list[dict[str, Any]] | None = None + function_calls: list[dict[str, Any]] | None = None + safety_ratings: list[dict[str, Any]] | None = None + citations: list[dict[str, Any]] | None = None + + +class BaseAdapter(ABC): + """LLM API适配器基类""" + + @property + @abstractmethod + def api_type(self) -> str: + """API类型标识""" + pass + + @property + @abstractmethod + def supported_api_types(self) -> list[str]: + """支持的API类型列表""" + pass + + def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 + + 默认实现:将简单请求转换为高级请求格式 + 子类可以重写此方法以提供特定的优化实现 + """ + from ..types.content import LLMMessage + + messages: list[LLMMessage] = [] + + if history: + for msg in history: + role = msg.get("role", "user") + content = msg.get("content", "") + messages.append(LLMMessage(role=role, content=content)) + + messages.append(LLMMessage(role="user", content=prompt)) + + config = model._generation_config + + return self.prepare_advanced_request( + model=model, + api_key=api_key, + messages=messages, + config=config, + tools=None, + tool_choice=None, + ) + + @abstractmethod + def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备高级请求""" + pass + + @abstractmethod + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析API响应""" + pass + + @abstractmethod + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + task_type: "EmbeddingTaskType | str", + **kwargs: Any, + ) -> RequestData: + 
"""准备文本嵌入请求""" + pass + + @abstractmethod + def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]: + """解析文本嵌入响应""" + pass + + def validate_embedding_response(self, response_json: dict[str, Any]) -> None: + """验证嵌入API响应""" + if "error" in response_json: + error_info = response_json["error"] + msg = ( + error_info.get("message", str(error_info)) + if isinstance(error_info, dict) + else str(error_info) + ) + raise LLMException( + f"嵌入API错误: {msg}", code=LLMErrorCode.EMBEDDING_FAILED, details=response_json + ) + + def get_api_url(self, model: "LLMModel", endpoint: str) -> str: + """构建API URL""" + if not model.api_base: + raise LLMException( + f"模型 {model.model_name} 的 api_base 未设置", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + return f"{model.api_base.rstrip('/')}{endpoint}" + + def get_base_headers(self, api_key: str) -> dict[str, str]: + """获取基础请求头""" + from zhenxun.utils.user_agent import get_user_agent + + headers = get_user_agent() + headers.update( + { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + ) + return headers + + def convert_messages_to_openai_format(self, messages: list["LLMMessage"]) -> list[dict[str, Any]]: + """将LLMMessage转换为OpenAI格式 - 通用方法""" + openai_messages: list[dict[str, Any]] = [] + for msg in messages: + openai_msg: dict[str, Any] = {"role": msg.role} + + if msg.role == "tool": + openai_msg["tool_call_id"] = msg.tool_call_id + openai_msg["name"] = msg.name + openai_msg["content"] = msg.content + else: + if isinstance(msg.content, str): + openai_msg["content"] = msg.content + else: + content_parts = [] + for part in msg.content: + if part.type == "text": + content_parts.append({"type": "text", "text": part.text}) + elif part.type == "image": + content_parts.append( + {"type": "image_url", "image_url": {"url": part.image_source}} + ) + openai_msg["content"] = content_parts + + if msg.role == "assistant" and msg.tool_calls: + assistant_tool_calls = [] + for call in msg.tool_calls: + assistant_tool_calls.append( + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": call.function.arguments, + }, + } + ) + openai_msg["tool_calls"] = assistant_tool_calls + + if msg.name and msg.role != "tool": + openai_msg["name"] = msg.name + + openai_messages.append(openai_msg) + return openai_messages + + def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData: + """解析OpenAI格式的响应 - 通用方法""" + self.validate_response(response_json) + + try: + choices = response_json.get("choices", []) + if not choices: + logger.debug("OpenAI响应中没有choices,可能为空回复或流结束。") + return ResponseData(text="", raw_response=response_json) + + choice = choices[0] + message = choice.get("message", {}) + content = message.get("content", "") + + parsed_tool_calls: list[LLMToolCall] | None = None + if message_tool_calls := message.get("tool_calls"): + from ..types.models import LLMToolFunction + + parsed_tool_calls = [] + for tc_data in message_tool_calls: + try: + if tc_data.get("type") == "function": + parsed_tool_calls.append( + LLMToolCall( + id=tc_data["id"], + function=LLMToolFunction( + name=tc_data["function"]["name"], + arguments=tc_data["function"]["arguments"], + ), + ) + ) + except KeyError as e: + logger.warning(f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}") + except Exception as e: + logger.warning(f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}") + if not parsed_tool_calls: + parsed_tool_calls = None + + final_text = content if content is not None else "" + if not final_text and 
parsed_tool_calls: + final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。" + + usage_info = response_json.get("usage") + + return ResponseData( + text=final_text, + tool_calls=parsed_tool_calls, + usage_info=usage_info, + raw_response=response_json, + ) + + except Exception as e: + logger.error(f"解析OpenAI格式响应失败: {e}", e=e) + raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e) + + def validate_response(self, response_json: dict[str, Any]) -> None: + """验证API响应,解析不同API的错误结构""" + if "error" in response_json: + error_info = response_json["error"] + + if isinstance(error_info, dict): + error_message = error_info.get("message", "未知错误") + error_code = error_info.get("code", "unknown") + error_type = error_info.get("type", "api_error") + + error_code_mapping = { + "invalid_api_key": LLMErrorCode.API_KEY_INVALID, + "authentication_failed": LLMErrorCode.API_KEY_INVALID, + "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED, + "quota_exceeded": LLMErrorCode.API_RATE_LIMITED, + "model_not_found": LLMErrorCode.MODEL_NOT_FOUND, + "invalid_model": LLMErrorCode.MODEL_NOT_FOUND, + "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + } + + llm_error_code = error_code_mapping.get(error_code, LLMErrorCode.API_RESPONSE_INVALID) + + logger.error(f"API返回错误: {error_message} (代码: {error_code}, 类型: {error_type})") + else: + error_message = str(error_info) + error_code = "unknown" + llm_error_code = LLMErrorCode.API_RESPONSE_INVALID + + logger.error(f"API返回错误: {error_message}") + + raise LLMException( + f"API请求失败: {error_message}", + code=llm_error_code, + details={"api_error": error_info, "error_code": error_code}, + ) + + if "candidates" in response_json: + candidates = response_json.get("candidates", []) + if candidates: + candidate = candidates[0] + finish_reason = candidate.get("finishReason") + if finish_reason in ["SAFETY", "RECITATION"]: + safety_ratings = candidate.get("safetyRatings", []) + logger.warning(f"Gemini内容被安全过滤: {finish_reason}, 安全评级: {safety_ratings}") + raise LLMException( + f"内容被安全过滤: {finish_reason}", + code=LLMErrorCode.CONTENT_FILTERED, + details={ + "finish_reason": finish_reason, + "safety_ratings": safety_ratings, + }, + ) + + if not response_json: + logger.error("API返回空响应") + raise LLMException( + "API返回空响应", + code=LLMErrorCode.API_RESPONSE_INVALID, + details={"response": response_json}, + ) + + def _apply_generation_config( + self, + model: "LLMModel", + config: "LLMGenerationConfig | None" = None, + ) -> dict[str, Any]: + """通用的配置应用逻辑""" + if config is not None: + return config.to_api_params(model.api_type, model.model_name) + + if model._generation_config is not None: + return model._generation_config.to_api_params(model.api_type, model.model_name) + + base_config = {} + if model.temperature is not None: + base_config["temperature"] = model.temperature + if model.max_tokens is not None: + if model.api_type in ["gemini", "gemini_native"]: + base_config["maxOutputTokens"] = model.max_tokens + else: + base_config["max_tokens"] = model.max_tokens + + return base_config + + def apply_config_override( + self, + model: "LLMModel", + body: dict[str, Any], + config: "LLMGenerationConfig | None" = None, + ) -> dict[str, Any]: + """应用配置覆盖""" + config_params = self._apply_generation_config(model, config) + body.update(config_params) + return body + + +class OpenAICompatAdapter(BaseAdapter): + """ + 处理所有 OpenAI 兼容 API 的通用适配器。 + 消除 OpenAIAdapter 和 ZhipuAdapter 之间的代码重复。 + """ + + 
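    # 模板方法模式(说明性注释):本基类统一实现 OpenAI 兼容的
    # 请求构造与响应/嵌入结果解析;子类(如 openai.py、zhipu.py 中的适配器)
    # 只需通过下面两个抽象方法给出各自的 REST 端点路径。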
@abstractmethod + def get_chat_endpoint(self) -> str: + """子类必须实现,返回 chat completions 的端点""" + pass + + @abstractmethod + def get_embedding_endpoint(self) -> str: + """子类必须实现,返回 embeddings 的端点""" + pass + + def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备高级请求 - OpenAI兼容格式""" + url = self.get_api_url(model, self.get_chat_endpoint()) + headers = self.get_base_headers(api_key) + openai_messages = self.convert_messages_to_openai_format(messages) + + body = { + "model": model.model_name, + "messages": openai_messages, + } + + if tools: + body["tools"] = tools + if tool_choice: + body["tool_choice"] = tool_choice + + body = self.apply_config_override(model, body, config) + return RequestData(url=url, headers=headers, body=body) + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析响应 - 直接使用基类的 OpenAI 格式解析""" + return self.parse_openai_response(response_json) + + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + task_type: "EmbeddingTaskType | str", + **kwargs: Any, + ) -> RequestData: + """准备嵌入请求 - OpenAI兼容格式""" + url = self.get_api_url(model, self.get_embedding_endpoint()) + headers = self.get_base_headers(api_key) + + body = { + "model": model.model_name, + "input": texts, + } + + # 应用额外的配置参数 + if kwargs: + body.update(kwargs) + + return RequestData(url=url, headers=headers, body=body) + + def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]: + """解析嵌入响应 - OpenAI兼容格式""" + self.validate_embedding_response(response_json) + + try: + data = response_json.get("data", []) + if not data: + raise LLMException( + "嵌入响应中没有数据", + code=LLMErrorCode.EMBEDDING_FAILED, + details=response_json, + ) + + embeddings = [] + for item in data: + if "embedding" in item: + embeddings.append(item["embedding"]) + else: + raise LLMException( + "嵌入响应格式错误:缺少embedding字段", + code=LLMErrorCode.EMBEDDING_FAILED, + details=item, + ) + + return embeddings + + except Exception as e: + logger.error(f"解析嵌入响应失败: {e}", e=e) + raise LLMException( + f"解析嵌入响应失败: {e}", + code=LLMErrorCode.EMBEDDING_FAILED, + cause=e, + ) diff --git a/zhenxun/services/llm/adapters/factory.py b/zhenxun/services/llm/adapters/factory.py new file mode 100644 index 00000000..8652fc67 --- /dev/null +++ b/zhenxun/services/llm/adapters/factory.py @@ -0,0 +1,78 @@ +""" +LLM 适配器工厂类 +""" + +from typing import ClassVar + +from ..types.exceptions import LLMErrorCode, LLMException +from .base import BaseAdapter + + +class LLMAdapterFactory: + """LLM适配器工厂类""" + + _adapters: ClassVar[dict[str, BaseAdapter]] = {} + _api_type_mapping: ClassVar[dict[str, str]] = {} + + @classmethod + def initialize(cls) -> None: + """初始化默认适配器""" + if cls._adapters: + return + + from .gemini import GeminiAdapter + from .openai import OpenAIAdapter + from .zhipu import ZhipuAdapter + + cls.register_adapter(OpenAIAdapter()) + cls.register_adapter(ZhipuAdapter()) + cls.register_adapter(GeminiAdapter()) + + @classmethod + def register_adapter(cls, adapter: BaseAdapter) -> None: + """注册适配器""" + adapter_key = adapter.api_type + cls._adapters[adapter_key] = adapter + + for api_type in adapter.supported_api_types: + cls._api_type_mapping[api_type] = adapter_key + + @classmethod + def get_adapter(cls, 
api_type: str) -> BaseAdapter: + """获取适配器""" + cls.initialize() + + adapter_key = cls._api_type_mapping.get(api_type) + if not adapter_key: + raise LLMException( + f"不支持的API类型: {api_type}", + code=LLMErrorCode.UNKNOWN_API_TYPE, + details={ + "api_type": api_type, + "supported_types": list(cls._api_type_mapping.keys()), + }, + ) + + return cls._adapters[adapter_key] + + @classmethod + def list_supported_types(cls) -> list[str]: + """列出所有支持的API类型""" + cls.initialize() + return list(cls._api_type_mapping.keys()) + + @classmethod + def list_adapters(cls) -> dict[str, BaseAdapter]: + """列出所有注册的适配器""" + cls.initialize() + return cls._adapters.copy() + + +def get_adapter_for_api_type(api_type: str) -> BaseAdapter: + """获取指定API类型的适配器""" + return LLMAdapterFactory.get_adapter(api_type) + + +def register_adapter(adapter: BaseAdapter) -> None: + """注册新的适配器""" + LLMAdapterFactory.register_adapter(adapter) diff --git a/zhenxun/services/llm/adapters/gemini.py b/zhenxun/services/llm/adapters/gemini.py new file mode 100644 index 00000000..3c6a7681 --- /dev/null +++ b/zhenxun/services/llm/adapters/gemini.py @@ -0,0 +1,508 @@ +""" +Gemini API 适配器 +""" + +from typing import TYPE_CHECKING, Any + +from zhenxun.services.log import logger + +from ..types.exceptions import LLMErrorCode, LLMException +from .base import BaseAdapter, RequestData, ResponseData + +if TYPE_CHECKING: + from ..config.generation import LLMGenerationConfig + from ..service import LLMModel + from ..types.content import LLMMessage + from ..types.enums import EmbeddingTaskType + from ..types.models import LLMToolCall + + +class GeminiAdapter(BaseAdapter): + """Gemini API 适配器""" + + @property + def api_type(self) -> str: + return "gemini" + + @property + def supported_api_types(self) -> list[str]: + return ["gemini"] + + def get_base_headers(self, api_key: str) -> dict[str, str]: + """获取基础请求头""" + from zhenxun.utils.user_agent import get_user_agent + + headers = get_user_agent() + headers.update({"Content-Type": "application/json"}) + headers["x-goog-api-key"] = api_key + + return headers + + def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备高级请求""" + return self._prepare_request(model, api_key, messages, config, tools, tool_choice) + + def _prepare_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> RequestData: + """准备 Gemini API 请求 - 支持所有高级功能""" + effective_config = config if config is not None else model._generation_config + + endpoint = self._get_gemini_endpoint(model, effective_config) + url = self.get_api_url(model, endpoint) + headers = self.get_base_headers(api_key) + + gemini_contents: list[dict[str, Any]] = [] + system_instruction_parts: list[dict[str, Any]] | None = None + + for msg in messages: + current_parts: list[dict[str, Any]] = [] + if msg.role == "system": + if isinstance(msg.content, str): + system_instruction_parts = [{"text": msg.content}] + elif isinstance(msg.content, list): + system_instruction_parts = [part.convert_for_api("gemini") for part in msg.content] + continue + + elif msg.role == "user": + if isinstance(msg.content, str): + current_parts.append({"text": msg.content}) + elif isinstance(msg.content, list): 
+ for part_obj in msg.content: + current_parts.append(part_obj.convert_for_api("gemini")) + gemini_contents.append({"role": "user", "parts": current_parts}) + + elif msg.role == "assistant" or msg.role == "model": + if isinstance(msg.content, str) and msg.content: + current_parts.append({"text": msg.content}) + elif isinstance(msg.content, list): + for part_obj in msg.content: + current_parts.append(part_obj.convert_for_api("gemini")) + + if msg.tool_calls: + import json + + for call in msg.tool_calls: + current_parts.append( + { + "functionCall": { + "name": call.function.name, + "args": json.loads(call.function.arguments), + } + } + ) + if current_parts: + gemini_contents.append({"role": "model", "parts": current_parts}) + + elif msg.role == "tool": + if not msg.name: + raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。") + + import json + + try: + content_str = msg.content if isinstance(msg.content, str) else str(msg.content) + tool_result_obj = json.loads(content_str) + except json.JSONDecodeError: + content_str = msg.content if isinstance(msg.content, str) else str(msg.content) + logger.warning( + f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. 包装为原始字符串。" + ) + tool_result_obj = {"raw_output": content_str} + + current_parts.append( + { + "functionResponse": { + "name": msg.name, + "response": tool_result_obj, + } + } + ) + gemini_contents.append({"role": "function", "parts": current_parts}) + + body: dict[str, Any] = {"contents": gemini_contents} + + if system_instruction_parts: + body["systemInstruction"] = {"parts": system_instruction_parts} + + all_tools_for_request = [] + if tools: + for tool_item in tools: + if isinstance(tool_item, dict): + if "name" in tool_item and "description" in tool_item: + all_tools_for_request.append({"functionDeclarations": [tool_item]}) + else: + all_tools_for_request.append(tool_item) + else: + all_tools_for_request.append(tool_item) + + if effective_config: + if getattr(effective_config, "enable_grounding", False): + has_explicit_gs_tool = any("googleSearch" in tool_item for tool_item in all_tools_for_request) + if not has_explicit_gs_tool: + all_tools_for_request.append({"googleSearch": {}}) + logger.debug("隐式启用 Google Search 工具进行信息来源关联。") + + if getattr(effective_config, "enable_code_execution", False): + has_explicit_ce_tool = any( + "codeExecution" in tool_item for tool_item in all_tools_for_request + ) + if not has_explicit_ce_tool: + all_tools_for_request.append({"codeExecution": {}}) + logger.debug("隐式启用代码执行工具。") + + if all_tools_for_request: + gemini_api_tools = self._convert_tools_to_gemini_format(all_tools_for_request) + if gemini_api_tools: + body["tools"] = gemini_api_tools + + final_tool_choice = tool_choice + if final_tool_choice is None and effective_config: + final_tool_choice = getattr(effective_config, "tool_choice", None) + + if final_tool_choice: + if isinstance(final_tool_choice, str): + mode_upper = final_tool_choice.upper() + if mode_upper in ["AUTO", "NONE", "ANY"]: + body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}} + else: + body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice) + else: + body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice) + + final_generation_config = self._build_gemini_generation_config(model, effective_config) + if final_generation_config: + body["generationConfig"] = final_generation_config + + safety_settings = self._build_safety_settings(effective_config) + if safety_settings: + body["safetySettings"] = safety_settings + + return 
RequestData(url=url, headers=headers, body=body) + + def apply_config_override( + self, + model: "LLMModel", + body: dict[str, Any], + config: "LLMGenerationConfig | None" = None, + ) -> dict[str, Any]: + """应用配置覆盖 - Gemini 不需要额外的配置覆盖""" + return body + + def _get_gemini_endpoint(self, model: "LLMModel", config: "LLMGenerationConfig | None" = None) -> str: + """根据配置选择Gemini API端点""" + if config: + if getattr(config, "enable_code_execution", False): + return f"/v1beta/models/{model.model_name}:generateContent" + + if getattr(config, "enable_grounding", False): + return f"/v1beta/models/{model.model_name}:generateContent" + + return f"/v1beta/models/{model.model_name}:generateContent" + + def _convert_tools_to_gemini_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """转换工具格式为Gemini格式""" + gemini_tools = [] + + for tool in tools: + if tool.get("type") == "function": + func = tool["function"] + gemini_tool = { + "functionDeclarations": [ + { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ] + } + gemini_tools.append(gemini_tool) + elif tool.get("type") == "code_execution": + gemini_tools.append({"codeExecution": {"language": tool.get("language", "python")}}) + elif tool.get("type") == "google_search": + gemini_tools.append({"googleSearch": {}}) + elif "googleSearch" in tool: + gemini_tools.append({"googleSearch": tool["googleSearch"]}) + elif "codeExecution" in tool: + gemini_tools.append({"codeExecution": tool["codeExecution"]}) + + return gemini_tools + + def _convert_tool_choice_to_gemini(self, tool_choice_value: str | dict[str, Any]) -> dict[str, Any]: + """转换工具选择策略为Gemini格式""" + if isinstance(tool_choice_value, str): + mode_upper = tool_choice_value.upper() + if mode_upper in ["AUTO", "NONE", "ANY"]: + return {"functionCallingConfig": {"mode": mode_upper}} + else: + logger.warning(f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。回退到 AUTO。") + return {"functionCallingConfig": {"mode": "AUTO"}} + + elif isinstance(tool_choice_value, dict): + if tool_choice_value.get("type") == "function" and "function" in tool_choice_value: + func_name = tool_choice_value["function"].get("name") + if func_name: + return { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [func_name], + } + } + else: + logger.warning(f"tool_choice dict 中的函数名无效: {tool_choice_value}。回退到 AUTO。") + return {"functionCallingConfig": {"mode": "AUTO"}} + + elif "functionCallingConfig" in tool_choice_value: + return {"functionCallingConfig": tool_choice_value["functionCallingConfig"]} + + else: + logger.warning(f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。") + return {"functionCallingConfig": {"mode": "AUTO"}} + + logger.warning(f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。") + return {"functionCallingConfig": {"mode": "AUTO"}} + + def _build_gemini_generation_config( + self, model: "LLMModel", config: "LLMGenerationConfig | None" = None + ) -> dict[str, Any]: + """构建Gemini生成配置""" + generation_config: dict[str, Any] = {} + + effective_config = config if config is not None else model._generation_config + + if effective_config: + base_api_params = effective_config.to_api_params(api_type="gemini", model_name=model.model_name) + generation_config.update(base_api_params) + + if getattr(effective_config, "response_mime_type", None): + generation_config["responseMimeType"] = effective_config.response_mime_type + + if getattr(effective_config, "response_schema", None): + generation_config["responseSchema"] = 
effective_config.response_schema + + thinking_budget = getattr(effective_config, "thinking_budget", None) + if thinking_budget is not None: + if "thinkingConfig" not in generation_config: + generation_config["thinkingConfig"] = {} + generation_config["thinkingConfig"]["thinkingBudget"] = thinking_budget + + if getattr(effective_config, "response_modalities", None): + modalities = effective_config.response_modalities + if isinstance(modalities, list): + generation_config["responseModalities"] = [m.upper() for m in modalities] + elif isinstance(modalities, str): + generation_config["responseModalities"] = [modalities.upper()] + + generation_config = {k: v for k, v in generation_config.items() if v is not None} + + if generation_config: + param_keys = list(generation_config.keys()) + logger.debug(f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: {param_keys}") + + return generation_config + + def _build_safety_settings( + self, config: "LLMGenerationConfig | None" = None + ) -> list[dict[str, Any]] | None: + """构建安全设置""" + if not config: + return None + + safety_settings = [] + + safety_categories = [ + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + ] + + custom_safety_settings = getattr(config, "safety_settings", None) + if custom_safety_settings: + for category, threshold in custom_safety_settings.items(): + safety_settings.append({"category": category, "threshold": threshold}) + else: + for category in safety_categories: + safety_settings.append({"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}) + + return safety_settings if safety_settings else None + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析API响应""" + return self._parse_response(model, response_json, is_advanced) + + def _parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析 Gemini API 响应""" + _ = is_advanced + self.validate_response(response_json) + + try: + candidates = response_json.get("candidates", []) + if not candidates: + logger.debug("Gemini响应中没有candidates。") + return ResponseData(text="", raw_response=response_json) + + candidate = candidates[0] + + if candidate.get("finishReason") in ["RECITATION", "OTHER"] and not candidate.get("content"): + logger.warning( + f"Gemini candidate finished with reason '{candidate.get('finishReason')}' and no content." 
+ ) + return ResponseData( + text="", raw_response=response_json, usage_info=response_json.get("usageMetadata") + ) + + content_data = candidate.get("content", {}) + parts = content_data.get("parts", []) + + text_content = "" + parsed_tool_calls: list["LLMToolCall"] | None = None + + for part in parts: + if "text" in part: + text_content += part["text"] + elif "functionCall" in part: + if parsed_tool_calls is None: + parsed_tool_calls = [] + fc_data = part["functionCall"] + try: + import json + + from ..types.models import LLMToolCall, LLMToolFunction + + parsed_tool_calls.append( + LLMToolCall( + id=f"call_{model.provider_name}_{len(parsed_tool_calls)}", + function=LLMToolFunction( + name=fc_data["name"], + arguments=json.dumps(fc_data["args"]), + ), + ) + ) + except KeyError as e: + logger.warning(f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}") + except Exception as e: + logger.warning(f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}") + elif "codeExecutionResult" in part: + result = part["codeExecutionResult"] + if result.get("outcome") == "OK": + output = result.get("output", "") + text_content += f"\n[代码执行结果]:\n{output}\n" + else: + text_content += f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n" + + usage_info = response_json.get("usageMetadata") + + grounding_metadata_obj = None + if grounding_data := candidate.get("groundingMetadata"): + try: + from ..types.models import LLMGroundingMetadata + + grounding_metadata_obj = LLMGroundingMetadata(**grounding_data) + except Exception as e: + logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}") + + return ResponseData( + text=text_content, + tool_calls=parsed_tool_calls, + usage_info=usage_info, + raw_response=response_json, + grounding_metadata=grounding_metadata_obj, + ) + + except Exception as e: + logger.error(f"解析 Gemini 响应失败: {e}", e=e) + raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e) + + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + task_type: "EmbeddingTaskType | str", + **kwargs: Any, + ) -> RequestData: + """准备文本嵌入请求""" + api_model_name = model.model_name + if not api_model_name.startswith("models/"): + api_model_name = f"models/{api_model_name}" + + url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents") + headers = self.get_base_headers(api_key) + + requests_payload = [] + for text_content in texts: + request_item: dict[str, Any] = { + "content": {"parts": [{"text": text_content}]}, + } + + from ..types.enums import EmbeddingTaskType + + if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT: + request_item["task_type"] = str(task_type).upper() + if title := kwargs.get("title"): + request_item["title"] = title + if output_dimensionality := kwargs.get("output_dimensionality"): + request_item["output_dimensionality"] = output_dimensionality + + requests_payload.append(request_item) + + body = {"requests": requests_payload} + return RequestData(url=url, headers=headers, body=body) + + def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]: + """解析文本嵌入响应""" + try: + embeddings_data = response_json["embeddings"] + return [item["values"] for item in embeddings_data] + except KeyError as e: + logger.error(f"解析Gemini嵌入响应时缺少键: {e}. 响应: {response_json}") + raise LLMException( + "Gemini嵌入响应格式错误", code=LLMErrorCode.RESPONSE_PARSE_ERROR, details={"error": str(e)} + ) + except Exception as e: + logger.error(f"解析Gemini嵌入响应时发生未知错误: {e}. 
响应: {response_json}") + raise LLMException( + f"解析Gemini嵌入响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e + ) + + def validate_embedding_response(self, response_json: dict[str, Any]) -> None: + """验证嵌入响应""" + super().validate_embedding_response(response_json) + if "embeddings" not in response_json or not isinstance(response_json["embeddings"], list): + raise LLMException( + "Gemini嵌入响应缺少'embeddings'字段或格式不正确", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + details=response_json, + ) + for item in response_json["embeddings"]: + if "values" not in item: + raise LLMException( + "Gemini嵌入响应的条目中缺少'values'字段", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + details=response_json, + ) diff --git a/zhenxun/services/llm/adapters/openai.py b/zhenxun/services/llm/adapters/openai.py new file mode 100644 index 00000000..046f0277 --- /dev/null +++ b/zhenxun/services/llm/adapters/openai.py @@ -0,0 +1,57 @@ +""" +OpenAI API 适配器 + +支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务。 +""" + +from typing import TYPE_CHECKING + +from .base import OpenAICompatAdapter, RequestData + +if TYPE_CHECKING: + from ..service import LLMModel + + +class OpenAIAdapter(OpenAICompatAdapter): + """OpenAI兼容API适配器""" + + @property + def api_type(self) -> str: + return "openai" + + @property + def supported_api_types(self) -> list[str]: + return ["openai", "deepseek", "general_openai_compat"] + + def get_chat_endpoint(self) -> str: + """返回聊天完成端点""" + return "/v1/chat/completions" + + def get_embedding_endpoint(self) -> str: + """返回嵌入端点""" + return "/v1/embeddings" + + def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 - OpenAI优化实现""" + url = self.get_api_url(model, self.get_chat_endpoint()) + headers = self.get_base_headers(api_key) + + messages = [] + if history: + messages.extend(history) + messages.append({"role": "user", "content": prompt}) + + body = { + "model": model.model_name, + "messages": messages, + } + + body = self.apply_config_override(model, body) + + return RequestData(url=url, headers=headers, body=body) diff --git a/zhenxun/services/llm/adapters/zhipu.py b/zhenxun/services/llm/adapters/zhipu.py new file mode 100644 index 00000000..e5eb032f --- /dev/null +++ b/zhenxun/services/llm/adapters/zhipu.py @@ -0,0 +1,57 @@ +""" +智谱 AI API 适配器 + +支持智谱 AI 的 GLM 系列模型,使用 OpenAI 兼容的接口格式。 +""" + +from typing import TYPE_CHECKING + +from .base import OpenAICompatAdapter, RequestData + +if TYPE_CHECKING: + from ..service import LLMModel + + +class ZhipuAdapter(OpenAICompatAdapter): + """智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口""" + + @property + def api_type(self) -> str: + return "zhipu" + + @property + def supported_api_types(self) -> list[str]: + return ["zhipu"] + + def get_chat_endpoint(self) -> str: + """返回智谱AI聊天完成端点""" + return "/api/paas/v4/chat/completions" + + def get_embedding_endpoint(self) -> str: + """返回智谱AI嵌入端点""" + return "/v4/embeddings" + + def prepare_simple_request( + self, + model: "LLMModel", + api_key: str, + prompt: str, + history: list[dict[str, str]] | None = None, + ) -> RequestData: + """准备简单文本生成请求 - 智谱AI优化实现""" + url = self.get_api_url(model, self.get_chat_endpoint()) + headers = self.get_base_headers(api_key) + + messages = [] + if history: + messages.extend(history) + messages.append({"role": "user", "content": prompt}) + + body = { + "model": model.model_name, + "messages": messages, + } + + body = self.apply_config_override(model, body) + + return RequestData(url=url, headers=headers, 
body=body) diff --git a/zhenxun/services/llm/api.py b/zhenxun/services/llm/api.py new file mode 100644 index 00000000..a4ffe90f --- /dev/null +++ b/zhenxun/services/llm/api.py @@ -0,0 +1,475 @@ +""" +LLM 服务的高级 API 接口 +""" + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any + +from nonebot_plugin_alconna.uniseg import UniMessage + +from zhenxun.services.log import logger + +from .config import CommonOverrides, LLMGenerationConfig +from .config.providers import get_ai_config +from .manager import get_global_default_model_name, get_model_instance +from .types import ( + EmbeddingTaskType, + LLMContentPart, + LLMErrorCode, + LLMException, + LLMMessage, + LLMResponse, + LLMTool, + ModelName, +) +from .utils import create_multimodal_message, unimsg_to_llm_parts + + +class TaskType(Enum): + """任务类型枚举""" + + CHAT = "chat" + CODE = "code" + SEARCH = "search" + ANALYSIS = "analysis" + GENERATION = "generation" + MULTIMODAL = "multimodal" + + +@dataclass +class AIConfig: + """AI配置类 - 简化版本""" + + model: ModelName = None + default_embedding_model: ModelName = None + temperature: float | None = None + max_tokens: int | None = None + enable_cache: bool = False + enable_code: bool = False + enable_search: bool = False + timeout: int | None = None + + enable_gemini_json_mode: bool = False + enable_gemini_thinking: bool = False + enable_gemini_safe_mode: bool = False + enable_gemini_multimodal: bool = False + enable_gemini_grounding: bool = False + + def __post_init__(self): + """初始化后从配置中读取默认值""" + ai_config = get_ai_config() + if self.model is None: + self.model = ai_config.get("default_model_name") + if self.timeout is None: + self.timeout = ai_config.get("timeout", 180) + + +class AI: + """统一的AI服务类 - 平衡设计版本 + + 提供三层API: + 1. 简单方法:ai.chat(), ai.code(), ai.search() + 2. 标准方法:ai.analyze() 支持复杂参数 + 3. 高级方法:通过get_model_instance()直接访问 + """ + + def __init__(self, config: AIConfig | None = None): + """初始化AI服务""" + self.config = config or AIConfig() + + async def chat( + self, + message: str | LLMMessage | list[LLMContentPart], + *, + model: ModelName = None, + **kwargs: Any, + ) -> str: + """聊天对话 - 支持简单多模态输入""" + llm_messages: list[LLMMessage] + + if isinstance(message, str): + llm_messages = [LLMMessage.user(message)] + elif isinstance(message, list) and all(isinstance(part, LLMContentPart) for part in message): + llm_messages = [LLMMessage.user(message)] + elif isinstance(message, LLMMessage): + llm_messages = [message] + else: + raise LLMException( + f"AI.chat 不支持的消息类型: {type(message)}. " + "请使用 str, LLMMessage, 或 list[LLMContentPart]. 
" + "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().", + code=LLMErrorCode.API_REQUEST_FAILED, + ) + + response = await self._execute_generation(llm_messages, model, "聊天失败", kwargs) + return response.text + + async def code( + self, + prompt: str, + *, + model: ModelName = None, + timeout: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """代码执行""" + resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" + + config = CommonOverrides.gemini_code_execution() + if timeout: + config.custom_params = config.custom_params or {} + config.custom_params["code_execution_timeout"] = timeout + + messages = [LLMMessage.user(prompt)] + + response = await self._execute_generation( + messages, resolved_model, "代码执行失败", kwargs, base_config=config + ) + + return { + "text": response.text, + "code_executions": response.code_executions or [], + "success": True, + } + + async def search( + self, query: str | UniMessage, *, model: ModelName = None, instruction: str = "", **kwargs: Any + ) -> dict[str, Any]: + """信息搜索 - 支持多模态输入""" + resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" + config = CommonOverrides.gemini_grounding() + + if isinstance(query, str): + messages = [LLMMessage.user(query)] + elif isinstance(query, UniMessage): + content_parts = await unimsg_to_llm_parts(query) + + final_messages: list[LLMMessage] = [] + if instruction: + final_messages.append(LLMMessage.system(instruction)) + + if not content_parts: + if instruction: + final_messages.append(LLMMessage.user(instruction)) + else: + raise LLMException("搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED) + else: + final_messages.append(LLMMessage.user(content_parts)) + + messages = final_messages + else: + raise LLMException( + f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.", + code=LLMErrorCode.API_REQUEST_FAILED, + ) + + response = await self._execute_generation( + messages, resolved_model, "信息搜索失败", kwargs, base_config=config + ) + + result = { + "text": response.text, + "sources": [], + "queries": [], + "success": True, + } + + if response.grounding_metadata: + result["sources"] = response.grounding_metadata.grounding_attributions or [] + result["queries"] = response.grounding_metadata.web_search_queries or [] + + return result + + async def analyze( + self, + message: UniMessage, + *, + instruction: str = "", + model: ModelName = None, + tools: list[dict[str, Any]] | None = None, + tool_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> str | LLMResponse: + """ + 内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。 + 这是处理复杂互动的主要方法。 + """ + content_parts = await unimsg_to_llm_parts(message) + + final_messages: list[LLMMessage] = [] + if instruction: + final_messages.append(LLMMessage.system(instruction)) + + if not content_parts: + if instruction: + final_messages.append(LLMMessage.user(instruction)) + else: + raise LLMException("分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED) + else: + final_messages.append(LLMMessage.user(content_parts)) + + llm_tools = None + if tools: + llm_tools = [] + for tool_dict in tools: + if isinstance(tool_dict, dict): + if "name" in tool_dict and "description" in tool_dict: + llm_tool = LLMTool( + type="function", + function={ + "name": tool_dict["name"], + "description": tool_dict["description"], + "parameters": tool_dict.get("parameters", {}), + }, + ) + llm_tools.append(llm_tool) + else: + llm_tools.append(LLMTool(**tool_dict)) + else: + llm_tools.append(tool_dict) + + tool_choice = None + if tool_config: + mode = tool_config.get("mode", "auto") + if mode == 
"auto": + tool_choice = "auto" + elif mode == "any": + tool_choice = "any" + elif mode == "none": + tool_choice = "none" + + response = await self._execute_generation( + final_messages, model, "内容分析失败", kwargs, llm_tools=llm_tools, tool_choice=tool_choice + ) + + if response.tool_calls: + return response + return response.text + + async def _execute_generation( + self, + messages: list[LLMMessage], + model_name: ModelName, + error_message: str, + config_overrides: dict[str, Any], + llm_tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + base_config: LLMGenerationConfig | None = None, + ) -> LLMResponse: + """通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑""" + try: + resolved_model_name = self._resolve_model_name(model_name or self.config.model) + final_config_dict = self._merge_config(config_overrides, base_config=base_config) + + async with await get_model_instance( + resolved_model_name, override_config=final_config_dict + ) as model_instance: + return await model_instance.generate_response( + messages, tools=llm_tools, tool_choice=tool_choice + ) + except LLMException: + raise + except Exception as e: + logger.error(f"{error_message}: {e}", e=e) + raise LLMException(f"{error_message}: {e}", cause=e) + + def _resolve_model_name(self, model_name: ModelName) -> str: + """解析模型名称""" + if model_name: + return model_name + + default_model = get_global_default_model_name() + if default_model: + return default_model + + raise LLMException( + "未指定模型名称且未设置全局默认模型", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + + def _merge_config( + self, + user_config: dict[str, Any], + base_config: LLMGenerationConfig | None = None, + ) -> dict[str, Any]: + """合并配置""" + final_config = {} + if base_config: + final_config.update(base_config.to_dict()) + + if self.config.temperature is not None: + final_config["temperature"] = self.config.temperature + if self.config.max_tokens is not None: + final_config["max_tokens"] = self.config.max_tokens + + if self.config.enable_cache: + final_config["enable_caching"] = True + if self.config.enable_code: + final_config["enable_code_execution"] = True + if self.config.enable_search: + final_config["enable_grounding"] = True + + if self.config.enable_gemini_json_mode: + final_config["response_mime_type"] = "application/json" + if self.config.enable_gemini_thinking: + final_config["thinking_budget"] = 0.8 + if self.config.enable_gemini_safe_mode: + final_config["safety_settings"] = CommonOverrides.gemini_safe().safety_settings + if self.config.enable_gemini_multimodal: + final_config.update(CommonOverrides.gemini_multimodal().to_dict()) + if self.config.enable_gemini_grounding: + final_config["enable_grounding"] = True + + final_config.update(user_config) + + return final_config + + async def embed( + self, + texts: list[str] | str, + *, + model: ModelName = None, + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """生成文本嵌入向量""" + if isinstance(texts, str): + texts = [texts] + if not texts: + return [] + + try: + resolved_model_str = model or self.config.default_embedding_model or self.config.model + if not resolved_model_str: + raise LLMException( + "使用 embed 功能时必须指定嵌入模型名称,或在 AIConfig 中配置 default_embedding_model。", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + resolved_model_str = self._resolve_model_name(resolved_model_str) + + async with await get_model_instance( + resolved_model_str, + override_config=None, + ) as embedding_model_instance: + return await 
embedding_model_instance.generate_embeddings( + texts, task_type=task_type, **kwargs + ) + except LLMException: + raise + except Exception as e: + logger.error(f"文本嵌入失败: {e}", e=e) + raise LLMException(f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e) + + +async def chat( + message: str | LLMMessage | list[LLMContentPart], + *, + model: ModelName = None, + **kwargs: Any, +) -> str: + """聊天对话便捷函数""" + ai = AI() + return await ai.chat(message, model=model, **kwargs) + + +async def code( + prompt: str, + *, + model: ModelName = None, + timeout: int | None = None, + **kwargs: Any, +) -> dict[str, Any]: + """代码执行便捷函数""" + ai = AI() + return await ai.code(prompt, model=model, timeout=timeout, **kwargs) + + +async def search( + query: str | UniMessage, + *, + model: ModelName = None, + instruction: str = "", + **kwargs: Any, +) -> dict[str, Any]: + """信息搜索便捷函数""" + ai = AI() + return await ai.search(query, model=model, instruction=instruction, **kwargs) + + +async def analyze( + message: UniMessage, + *, + instruction: str = "", + model: ModelName = None, + tools: list[dict[str, Any]] | None = None, + tool_config: dict[str, Any] | None = None, + **kwargs: Any, +) -> str | LLMResponse: + """内容分析便捷函数""" + ai = AI() + return await ai.analyze( + message, + instruction=instruction, + model=model, + tools=tools, + tool_config=tool_config, + **kwargs, + ) + + +async def analyze_with_images( + text: str, + images: list[str | Path | bytes] | str | Path | bytes, + *, + instruction: str = "", + model: ModelName = None, + **kwargs: Any, +) -> str | LLMResponse: + """图片分析便捷函数""" + message = create_multimodal_message(text=text, images=images) + return await analyze(message, instruction=instruction, model=model, **kwargs) + + +async def analyze_multimodal( + text: str | None = None, + images: list[str | Path | bytes] | str | Path | bytes | None = None, + videos: list[str | Path | bytes] | str | Path | bytes | None = None, + audios: list[str | Path | bytes] | str | Path | bytes | None = None, + *, + instruction: str = "", + model: ModelName = None, + **kwargs: Any, +) -> str | LLMResponse: + """多模态分析便捷函数""" + message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios) + return await analyze(message, instruction=instruction, model=model, **kwargs) + + +async def search_multimodal( + text: str | None = None, + images: list[str | Path | bytes] | str | Path | bytes | None = None, + videos: list[str | Path | bytes] | str | Path | bytes | None = None, + audios: list[str | Path | bytes] | str | Path | bytes | None = None, + *, + instruction: str = "", + model: ModelName = None, + **kwargs: Any, +) -> dict[str, Any]: + """多模态搜索便捷函数""" + message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios) + ai = AI() + return await ai.search(message, model=model, instruction=instruction, **kwargs) + + +async def embed( + texts: list[str] | str, + *, + model: ModelName = None, + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, +) -> list[list[float]]: + """文本嵌入便捷函数""" + ai = AI() + return await ai.embed(texts, model=model, task_type=task_type, **kwargs) diff --git a/zhenxun/services/llm/config/__init__.py b/zhenxun/services/llm/config/__init__.py new file mode 100644 index 00000000..778a04bd --- /dev/null +++ b/zhenxun/services/llm/config/__init__.py @@ -0,0 +1,25 @@ +""" +LLM 配置模块 + +提供生成配置、预设配置和配置验证功能。 +""" + +from .generation import ( + LLMGenerationConfig, + ModelConfigOverride, + apply_api_specific_mappings, + 
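
api.py 末尾的 `embed()` 便捷函数可直接把文本转为向量。一个最小示意如下,其中的模型名只是假设,需替换为配置中实际存在的嵌入模型:

```python
from zhenxun.services.llm import embed

vectors = await embed(
    ["真寻是一只可爱的bot", "LLM 服务模块文档"],
    model="Gemini/text-embedding-004",  # 假设的嵌入模型名
    task_type="RETRIEVAL_DOCUMENT",
)
print(len(vectors), len(vectors[0]))  # 文本条数 × 向量维度
```
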
create_generation_config_from_kwargs, + validate_override_params, +) +from .presets import CommonOverrides +from .providers import register_llm_configs + +__all__ = [ + "CommonOverrides", + "LLMGenerationConfig", + "ModelConfigOverride", + "apply_api_specific_mappings", + "create_generation_config_from_kwargs", + "register_llm_configs", + "validate_override_params", +] diff --git a/zhenxun/services/llm/config/generation.py b/zhenxun/services/llm/config/generation.py new file mode 100644 index 00000000..a7ad9171 --- /dev/null +++ b/zhenxun/services/llm/config/generation.py @@ -0,0 +1,220 @@ +""" +LLM 生成配置相关类和函数 +""" + +from typing import Any + +from pydantic import BaseModel, Field + +from zhenxun.services.log import logger + +from ..types.enums import ResponseFormat +from ..types.exceptions import LLMErrorCode, LLMException + + +class ModelConfigOverride(BaseModel): + """模型配置覆盖参数""" + + temperature: float | None = Field(default=None, ge=0.0, le=2.0, description="生成温度") + max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数") + top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数") + top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数") + frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="频率惩罚") + presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="存在惩罚") + repetition_penalty: float | None = Field(default=None, ge=0.0, le=2.0, description="重复惩罚") + + stop: list[str] | str | None = Field(default=None, description="停止序列") + + response_format: ResponseFormat | dict[str, Any] | None = Field( + default=None, description="期望的响应格式" + ) + response_mime_type: str | None = Field(default=None, description="响应MIME类型(Gemini专用)") + response_schema: dict[str, Any] | None = Field(default=None, description="JSON响应模式") + thinking_budget: float | None = Field(default=None, ge=0.0, le=1.0, description="思考预算") + safety_settings: dict[str, str] | None = Field(default=None, description="安全设置") + response_modalities: list[str] | None = Field(default=None, description="响应模态类型") + + enable_code_execution: bool | None = Field(default=None, description="是否启用代码执行") + enable_grounding: bool | None = Field(default=None, description="是否启用信息来源关联") + enable_caching: bool | None = Field(default=None, description="是否启用响应缓存") + + custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数") + + def to_dict(self) -> dict[str, Any]: + """转换为字典,排除None值""" + result = {} + for key, value in self.model_dump().items(): + if value is not None: + if key == "custom_params" and isinstance(value, dict): + result.update(value) + else: + result[key] = value + return result + + def merge_with_base_config( + self, + base_temperature: float | None = None, + base_max_tokens: int | None = None, + ) -> dict[str, Any]: + """与基础配置合并,覆盖参数优先""" + merged = {} + + if base_temperature is not None: + merged["temperature"] = base_temperature + if base_max_tokens is not None: + merged["max_tokens"] = base_max_tokens + + override_dict = self.to_dict() + merged.update(override_dict) + + return merged + + +class LLMGenerationConfig(ModelConfigOverride): + """LLM 生成配置,继承模型配置覆盖参数""" + + def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]: + """转换为API参数,支持不同API类型的参数名映射""" + _ = model_name + params = {} + + if self.temperature is not None: + params["temperature"] = self.temperature + + if self.max_tokens is not None: + if api_type in ["gemini", "gemini_native"]: + params["maxOutputTokens"] = 
self.max_tokens + else: + params["max_tokens"] = self.max_tokens + + if api_type in ["gemini", "gemini_native"]: + if self.top_k is not None: + params["topK"] = self.top_k + if self.top_p is not None: + params["topP"] = self.top_p + else: + if self.top_k is not None: + params["top_k"] = self.top_k + if self.top_p is not None: + params["top_p"] = self.top_p + + if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]: + if self.frequency_penalty is not None: + params["frequency_penalty"] = self.frequency_penalty + if self.presence_penalty is not None: + params["presence_penalty"] = self.presence_penalty + + if self.repetition_penalty is not None: + if api_type == "openai": + logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") + else: + params["repetition_penalty"] = self.repetition_penalty + + # 处理 response_format 参数 + if self.response_format is not None: + if isinstance(self.response_format, dict): + # 直接使用字典格式的 response_format(如 {'type': 'json_object'}) + if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: + params["response_format"] = self.response_format + logger.debug(f"为 {api_type} 使用自定义 response_format: {self.response_format}") + elif self.response_format == ResponseFormat.JSON: + if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: + params["response_format"] = {"type": "json_object"} + logger.debug(f"为 {api_type} 启用 JSON 对象输出模式") + elif api_type in ["gemini", "gemini_native"]: + params["responseMimeType"] = "application/json" + if self.response_schema: + params["responseSchema"] = self.response_schema + logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式") + + if api_type in ["gemini", "gemini_native"]: + if self.response_format != ResponseFormat.JSON and self.response_mime_type is not None: + params["responseMimeType"] = self.response_mime_type + logger.debug(f"使用显式设置的 responseMimeType: {self.response_mime_type}") + + if self.response_schema is not None and "responseSchema" not in params: + params["responseSchema"] = self.response_schema + if self.thinking_budget is not None: + params["thinkingBudget"] = self.thinking_budget + if self.safety_settings is not None: + params["safetySettings"] = self.safety_settings + if self.response_modalities is not None: + params["responseModalities"] = self.response_modalities + + if self.custom_params: + custom_mapped = apply_api_specific_mappings(self.custom_params, api_type) + params.update(custom_mapped) + + logger.debug(f"为{api_type}转换配置参数: {len(params)}个参数") + return params + + +def validate_override_params( + override_config: dict[str, Any] | LLMGenerationConfig | None, +) -> LLMGenerationConfig: + """验证和标准化覆盖参数""" + if override_config is None: + return LLMGenerationConfig() + + if isinstance(override_config, dict): + try: + filtered_config = {k: v for k, v in override_config.items() if v is not None} + return LLMGenerationConfig(**filtered_config) + except Exception as e: + logger.warning(f"覆盖配置参数验证失败: {e}") + raise LLMException( + f"无效的覆盖配置参数: {e}", + code=LLMErrorCode.CONFIGURATION_ERROR, + cause=e, + ) + + return override_config + + +def apply_api_specific_mappings(params: dict[str, Any], api_type: str) -> dict[str, Any]: + """应用API特定的参数映射""" + mapped_params = params.copy() + + if api_type in ["gemini", "gemini_native"]: + if "max_tokens" in mapped_params: + mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens") + if "top_k" in mapped_params: + mapped_params["topK"] = mapped_params.pop("top_k") + if "top_p" in mapped_params: + mapped_params["topP"] = 
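
`to_api_params` 的映射效果可以直接观察:同一份配置在不同 `api_type` 下会生成不同的参数名;未在字段中声明的参数则经 `custom_params` 透传。下面是一个示意(注释中的输出仅为大致形态):

```python
from zhenxun.services.llm import LLMGenerationConfig
from zhenxun.services.llm.config import create_generation_config_from_kwargs

cfg = LLMGenerationConfig(temperature=0.7, max_tokens=1024, top_k=40)

cfg.to_api_params("openai", "gpt-4o")
# ≈ {"temperature": 0.7, "max_tokens": 1024, "top_k": 40}

cfg.to_api_params("gemini", "gemini-2.0-flash")
# ≈ {"temperature": 0.7, "maxOutputTokens": 1024, "topK": 40}

# 未知字段自动归入 custom_params
cfg2 = create_generation_config_from_kwargs(temperature=0.3, code_execution_timeout=30)
cfg2.custom_params  # {"code_execution_timeout": 30}
```
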
mapped_params.pop("top_p") + + unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"] + for param in unsupported: + if param in mapped_params: + logger.warning(f"Gemini 原生API不支持参数 '{param}',已忽略") + mapped_params.pop(param) + + elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]: + if "repetition_penalty" in mapped_params and api_type == "openai": + logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") + mapped_params.pop("repetition_penalty") + + if "stop" in mapped_params: + stop_value = mapped_params["stop"] + if isinstance(stop_value, str): + mapped_params["stop"] = [stop_value] + + return mapped_params + + +def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig: + """从关键字参数创建生成配置""" + known_fields = set(LLMGenerationConfig.model_fields.keys()) + known_params = {} + custom_params = {} + + for key, value in kwargs.items(): + if key in known_fields: + known_params[key] = value + else: + custom_params[key] = value + + if custom_params: + known_params["custom_params"] = custom_params + + return LLMGenerationConfig(**known_params) diff --git a/zhenxun/services/llm/config/presets.py b/zhenxun/services/llm/config/presets.py new file mode 100644 index 00000000..04a72dab --- /dev/null +++ b/zhenxun/services/llm/config/presets.py @@ -0,0 +1,155 @@ +""" +LLM 预设配置 + +提供常用的配置预设,特别是针对 Gemini 的高级功能。 +""" + +from typing import Any + +from .generation import LLMGenerationConfig + + +class CommonOverrides: + """常用的配置覆盖预设""" + + @staticmethod + def creative() -> LLMGenerationConfig: + """创意模式:高温度,鼓励创新""" + return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1) + + @staticmethod + def precise() -> LLMGenerationConfig: + """精确模式:低温度,确定性输出""" + return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0) + + @staticmethod + def balanced() -> LLMGenerationConfig: + """平衡模式:中等温度""" + return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0) + + @staticmethod + def concise(max_tokens: int = 100) -> LLMGenerationConfig: + """简洁模式:限制输出长度""" + return LLMGenerationConfig(temperature=0.3, max_tokens=max_tokens, stop=["\n\n", "。", "!", "?"]) + + @staticmethod + def detailed(max_tokens: int = 2000) -> LLMGenerationConfig: + """详细模式:鼓励详细输出""" + return LLMGenerationConfig(temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1) + + @staticmethod + def gemini_json() -> LLMGenerationConfig: + """Gemini JSON模式:强制JSON输出""" + return LLMGenerationConfig(temperature=0.3, response_mime_type="application/json") + + @staticmethod + def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig: + """Gemini 思考模式:使用思考预算""" + return LLMGenerationConfig(temperature=0.7, thinking_budget=budget) + + @staticmethod + def gemini_creative() -> LLMGenerationConfig: + """Gemini 创意模式:高温度创意输出""" + return LLMGenerationConfig(temperature=0.9, top_p=0.95) + + @staticmethod + def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig: + """Gemini 结构化输出:自定义JSON模式""" + return LLMGenerationConfig( + temperature=0.3, + response_mime_type="application/json", + response_schema=schema, + ) + + @staticmethod + def gemini_safe() -> LLMGenerationConfig: + """Gemini 安全模式:严格安全设置""" + return LLMGenerationConfig( + temperature=0.5, + safety_settings={ + "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE", + }, + ) + + @staticmethod + def 
gemini_multimodal() -> LLMGenerationConfig: + """Gemini 多模态模式:优化多模态处理""" + return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8) + + @staticmethod + def gemini_code_execution() -> LLMGenerationConfig: + """Gemini 代码执行模式:启用代码执行功能""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=4096, + enable_code_execution=True, + custom_params={"code_execution_timeout": 30}, + ) + + @staticmethod + def gemini_grounding() -> LLMGenerationConfig: + """Gemini 信息来源关联模式:启用Google搜索""" + return LLMGenerationConfig( + temperature=0.5, + max_tokens=4096, + enable_grounding=True, + custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}}, + ) + + @staticmethod + def gemini_cached() -> LLMGenerationConfig: + """Gemini 缓存模式:启用响应缓存""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=2048, + enable_caching=True, + ) + + @staticmethod + def gemini_advanced() -> LLMGenerationConfig: + """Gemini 高级模式:启用所有高级功能""" + return LLMGenerationConfig( + temperature=0.5, + max_tokens=4096, + enable_code_execution=True, + enable_grounding=True, + enable_caching=True, + custom_params={ + "code_execution_timeout": 30, + "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}, + }, + ) + + @staticmethod + def gemini_research() -> LLMGenerationConfig: + """Gemini 研究模式:思考+搜索+结构化输出""" + return LLMGenerationConfig( + temperature=0.6, + max_tokens=4096, + thinking_budget=0.8, + enable_grounding=True, + response_mime_type="application/json", + custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}}, + ) + + @staticmethod + def gemini_analysis() -> LLMGenerationConfig: + """Gemini 分析模式:深度思考+详细输出""" + return LLMGenerationConfig( + temperature=0.4, + max_tokens=6000, + thinking_budget=0.9, + top_p=0.8, + ) + + @staticmethod + def gemini_fast_response() -> LLMGenerationConfig: + """Gemini 快速响应模式:低延迟+简洁输出""" + return LLMGenerationConfig( + temperature=0.3, + max_tokens=512, + top_p=0.8, + ) diff --git a/zhenxun/services/llm/config/providers.py b/zhenxun/services/llm/config/providers.py new file mode 100644 index 00000000..bdb1c584 --- /dev/null +++ b/zhenxun/services/llm/config/providers.py @@ -0,0 +1,107 @@ +""" +LLM 提供商配置管理 + +负责注册和管理 AI 服务提供商的配置项。 +""" + +from zhenxun.configs.config import Config +from zhenxun.services.log import logger + +from ..types.models import ProviderConfig + +AI_CONFIG_GROUP = "AI" +PROVIDERS_CONFIG_KEY = "PROVIDERS" + + +def get_ai_config(): + """获取 AI 配置组""" + return Config.get(AI_CONFIG_GROUP) + + +def register_llm_configs(): + """注册 LLM 服务的配置项""" + logger.info("注册 LLM 服务的配置项") + Config.add_plugin_config( + AI_CONFIG_GROUP, + "default_model_name", + None, + help="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)", + type=str, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "proxy", + None, + help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", + type=str, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "timeout", + 180, + help="LLM服务API请求超时时间(秒)", + type=int, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "max_retries_llm", + 3, + help="LLM服务请求失败时的最大重试次数", + type=int, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + "retry_delay_llm", + 2, + help="LLM服务请求重试的基础延迟时间(秒)", + type=int, + ) + Config.add_plugin_config( + AI_CONFIG_GROUP, + PROVIDERS_CONFIG_KEY, + [ + { + "name": "DeepSeek", + "api_key": "sk-******", + "api_base": "https://api.deepseek.com", + "api_type": "openai", + "models": [ + { + "model_name": "deepseek-chat", + "max_tokens": 4096, + "temperature": 0.7, + }, + { + 
"model_name": "deepseek-reasoner", + }, + ], + }, + { + "name": "GLM", + "api_key": "", + "api_base": "https://open.bigmodel.cn", + "api_type": "zhipu", + "models": [ + {"model_name": "glm-4-flash"}, + {"model_name": "glm-4-plus"}, + ], + }, + { + "name": "Gemini", + "api_key": [ + "AIzaSy*****************************", + "AIzaSy*****************************", + "AIzaSy*****************************", + ], + "api_base": "https://generativelanguage.googleapis.com", + "api_type": "gemini", + "models": [ + {"model_name": "gemini-2.0-flash"}, + {"model_name": "gemini-2.5-flash-preview-05-20"}, + ], + }, + ], + help="配置多个 AI 服务提供商及其模型信息 (列表)", + default_value=[], + type=list[ProviderConfig], + ) diff --git a/zhenxun/services/llm/core.py b/zhenxun/services/llm/core.py new file mode 100644 index 00000000..1c1c67aa --- /dev/null +++ b/zhenxun/services/llm/core.py @@ -0,0 +1,341 @@ +""" +LLM 核心基础设施模块 + +包含执行 LLM 请求所需的底层组件,如 HTTP 客户端、API Key 存储和智能重试逻辑。 +""" + +import asyncio +from typing import Any + +import httpx +from pydantic import BaseModel + +from zhenxun.services.log import logger +from zhenxun.utils.user_agent import get_user_agent + +from .types import ProviderConfig +from .types.exceptions import LLMErrorCode, LLMException + + +class HttpClientConfig(BaseModel): + """HTTP客户端配置""" + + timeout: int = 180 + max_connections: int = 100 + max_keepalive_connections: int = 20 + proxy: str | None = None + + +class LLMHttpClient: + """LLM服务专用HTTP客户端""" + + def __init__(self, config: HttpClientConfig | None = None): + self.config = config or HttpClientConfig() + self._client: httpx.AsyncClient | None = None + self._active_requests = 0 + self._lock = asyncio.Lock() + + async def _ensure_client_initialized(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + async with self._lock: + if self._client is None or self._client.is_closed: + logger.debug( + f"LLMHttpClient: Initializing new httpx.AsyncClient with config: {self.config}" + ) + headers = get_user_agent() + limits = httpx.Limits( + max_connections=self.config.max_connections, + max_keepalive_connections=self.config.max_keepalive_connections, + ) + timeout = httpx.Timeout(self.config.timeout) + self._client = httpx.AsyncClient( + headers=headers, + limits=limits, + timeout=timeout, + proxy=self.config.proxy, + follow_redirects=True, + ) + if self._client is None: + raise LLMException("HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR) + return self._client + + async def post(self, url: str, **kwargs: Any) -> httpx.Response: + client = await self._ensure_client_initialized() + async with self._lock: + self._active_requests += 1 + try: + return await client.post(url, **kwargs) + finally: + async with self._lock: + self._active_requests -= 1 + + async def close(self): + async with self._lock: + if self._client and not self._client.is_closed: + logger.debug( + f"LLMHttpClient: Closing with config: {self.config}. " + f"Active requests: {self._active_requests}" + ) + if self._active_requests > 0: + logger.warning( + f"LLMHttpClient: Closing while {self._active_requests} requests are still active." 
+ ) + await self._client.aclose() + self._client = None + logger.debug(f"LLMHttpClient for config {self.config} definitively closed.") + + @property + def is_closed(self) -> bool: + return self._client is None or self._client.is_closed + + +class LLMHttpClientManager: + """管理 LLMHttpClient 实例的工厂和池""" + + def __init__(self): + self._clients: dict[tuple[int, str | None], LLMHttpClient] = {} + self._lock = asyncio.Lock() + + def _get_client_key(self, provider_config: ProviderConfig) -> tuple[int, str | None]: + return (provider_config.timeout, provider_config.proxy) + + async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient: + key = self._get_client_key(provider_config) + async with self._lock: + client = self._clients.get(key) + if client and not client.is_closed: + logger.debug(f"LLMHttpClientManager: Reusing existing LLMHttpClient for key: {key}") + return client + + if client and client.is_closed: + logger.debug( + f"LLMHttpClientManager: Found a closed client for key {key}. Creating a new one." + ) + + logger.debug(f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}") + http_client_config = HttpClientConfig( + timeout=provider_config.timeout, proxy=provider_config.proxy + ) + new_client = LLMHttpClient(config=http_client_config) + self._clients[key] = new_client + return new_client + + async def shutdown(self): + async with self._lock: + logger.info(f"LLMHttpClientManager: Shutting down. Closing {len(self._clients)} client(s).") + close_tasks = [ + client.close() for client in self._clients.values() if client and not client.is_closed + ] + if close_tasks: + await asyncio.gather(*close_tasks, return_exceptions=True) + self._clients.clear() + logger.info("LLMHttpClientManager: Shutdown complete.") + + +http_client_manager = LLMHttpClientManager() + + +async def create_llm_http_client( + timeout: int = 180, + proxy: str | None = None, +) -> LLMHttpClient: + """创建LLM HTTP客户端""" + config = HttpClientConfig(timeout=timeout, proxy=proxy) + return LLMHttpClient(config) + + +class RetryConfig: + """重试配置""" + + def __init__( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + exponential_backoff: bool = True, + key_rotation: bool = True, + ): + self.max_retries = max_retries + self.retry_delay = retry_delay + self.exponential_backoff = exponential_backoff + self.key_rotation = key_rotation + + +async def with_smart_retry( + func, + *args, + retry_config: RetryConfig | None = None, + key_store: "KeyStatusStore | None" = None, + provider_name: str | None = None, + **kwargs: Any, +) -> Any: + """智能重试装饰器 - 支持Key轮询和错误分类""" + config = retry_config or RetryConfig() + last_exception: Exception | None = None + failed_keys: set[str] = set() + + for attempt in range(config.max_retries + 1): + try: + if config.key_rotation and "failed_keys" in func.__code__.co_varnames: + kwargs["failed_keys"] = failed_keys + + return await func(*args, **kwargs) + + except LLMException as e: + last_exception = e + + if e.code in [LLMErrorCode.API_KEY_INVALID, LLMErrorCode.API_QUOTA_EXCEEDED]: + if hasattr(e, "details") and e.details and "api_key" in e.details: + failed_keys.add(e.details["api_key"]) + if key_store and provider_name: + await key_store.record_failure(e.details["api_key"], e.details.get("status_code")) + + should_retry = _should_retry_llm_error(e, attempt, config.max_retries) + if not should_retry: + logger.error(f"不可重试的错误,停止重试: {e}") + raise + + if attempt < config.max_retries: + wait_time = config.retry_delay + if config.exponential_backoff: + wait_time *= 
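
指数退避下,第 attempt 次失败后的等待时间为 `retry_delay * 2**attempt`;例如 `retry_delay=2.0` 时依次约为 2s、4s、8s。`with_smart_retry` 还会在被包装函数签名含 `failed_keys` 参数时自动注入失败密钥集合。一个最小示意(`fetch_once` 为假设的请求函数):

```python
from zhenxun.services.llm.core import RetryConfig, with_smart_retry

async def fetch_once(failed_keys: set[str] | None = None):
    """假设的单次请求;抛出可重试的 LLMException 时会被自动重试"""
    ...

result = await with_smart_retry(
    fetch_once,
    retry_config=RetryConfig(max_retries=3, retry_delay=2.0),
)
```
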
2**attempt + logger.warning(f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}") + await asyncio.sleep(wait_time) + else: + logger.error(f"重试{config.max_retries}次后仍然失败: {e}") + + except Exception as e: + last_exception = e + logger.error(f"非LLM异常,停止重试: {e}") + raise LLMException( + f"操作失败: {e}", + code=LLMErrorCode.GENERATION_FAILED, + cause=e, + ) + + if last_exception: + raise last_exception + else: + raise RuntimeError("重试函数未能正常执行且未捕获到异常") + + +def _should_retry_llm_error(error: LLMException, attempt: int, max_retries: int) -> bool: + """判断LLM错误是否应该重试""" + non_retryable_errors = { + LLMErrorCode.MODEL_NOT_FOUND, + LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + LLMErrorCode.USER_LOCATION_NOT_SUPPORTED, + LLMErrorCode.CONFIGURATION_ERROR, + } + + if error.code in non_retryable_errors: + return False + + retryable_errors = { + LLMErrorCode.API_REQUEST_FAILED, + LLMErrorCode.API_TIMEOUT, + LLMErrorCode.API_RATE_LIMITED, + LLMErrorCode.API_RESPONSE_INVALID, + LLMErrorCode.RESPONSE_PARSE_ERROR, + LLMErrorCode.GENERATION_FAILED, + LLMErrorCode.CONTENT_FILTERED, + LLMErrorCode.API_KEY_INVALID, + LLMErrorCode.API_QUOTA_EXCEEDED, + } + + if error.code in retryable_errors: + if error.code == LLMErrorCode.API_QUOTA_EXCEEDED: + return attempt < min(2, max_retries) + elif error.code == LLMErrorCode.CONTENT_FILTERED: + return attempt < min(1, max_retries) + return True + + return False + + +class KeyStatusStore: + """API Key 状态管理存储 - 优化版本,支持轮询和负载均衡""" + + def __init__(self): + self._key_status: dict[str, bool] = {} + self._key_usage_count: dict[str, int] = {} + self._key_last_used: dict[str, float] = {} + self._provider_key_index: dict[str, int] = {} + self._lock = asyncio.Lock() + + async def get_next_available_key( + self, provider_name: str, api_keys: list[str], exclude_keys: set[str] | None = None + ) -> str | None: + """获取下一个可用的API密钥(轮询策略)""" + if not api_keys: + return None + + exclude_keys = exclude_keys or set() + available_keys = [ + key for key in api_keys if key not in exclude_keys and self._key_status.get(key, True) + ] + + if not available_keys: + return api_keys[0] if api_keys else None + + async with self._lock: + current_index = self._provider_key_index.get(provider_name, 0) + + selected_key = available_keys[current_index % len(available_keys)] + + self._provider_key_index[provider_name] = (current_index + 1) % len(available_keys) + + import time + + self._key_usage_count[selected_key] = self._key_usage_count.get(selected_key, 0) + 1 + self._key_last_used[selected_key] = time.time() + + logger.debug( + f"轮询选择API密钥: {self._get_key_id(selected_key)} " + f"(使用次数: {self._key_usage_count[selected_key]})" + ) + + return selected_key + + async def record_success(self, api_key: str): + """记录成功使用""" + async with self._lock: + self._key_status[api_key] = True + logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}") + + async def record_failure(self, api_key: str, status_code: int | None): + """记录失败使用""" + key_id = self._get_key_id(api_key) + async with self._lock: + if status_code in [401, 403]: + self._key_status[api_key] = False + logger.warning(f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})") + else: + logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})") + + async def reset_key_status(self, api_key: str): + """重置密钥状态(用于恢复机制)""" + async with self._lock: + self._key_status[api_key] = True + logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}") + + async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]: + """获取密钥使用统计""" + stats = {} + async with self._lock: + for key in 
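
`KeyStatusStore` 按提供商维护一个轮询游标,并对 401/403 失败的密钥自动摘除。一个最小示意(密钥与提供商名均为假设值):

```python
from zhenxun.services.llm.core import KeyStatusStore

store = KeyStatusStore()
keys = ["sk-aaa", "sk-bbb", "sk-ccc"]  # 假设的密钥列表

await store.get_next_available_key("MyProvider", keys)  # 轮询取第 1 个
await store.get_next_available_key("MyProvider", keys)  # 轮询取第 2 个

await store.record_failure("sk-aaa", 401)  # 401/403 → 标记为不可用,后续轮询跳过
await store.get_next_available_key("MyProvider", keys, exclude_keys={"sk-bbb"})
```
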
api_keys: + key_id = self._get_key_id(key) + stats[key_id] = { + "available": self._key_status.get(key, True), + "usage_count": self._key_usage_count.get(key, 0), + "last_used": self._key_last_used.get(key, 0), + } + return stats + + def _get_key_id(self, api_key: str) -> str: + """获取API密钥的标识符(用于日志)""" + if len(api_key) <= 8: + return api_key + return f"{api_key[:4]}...{api_key[-4:]}" + + +key_store = KeyStatusStore() diff --git a/zhenxun/services/llm/manager.py b/zhenxun/services/llm/manager.py new file mode 100644 index 00000000..6ba2db79 --- /dev/null +++ b/zhenxun/services/llm/manager.py @@ -0,0 +1,393 @@ +""" +LLM 模型管理器 + +负责模型实例的创建、缓存、配置管理和生命周期管理。 +""" + +import hashlib +import json +import time +from typing import Any + +from zhenxun.configs.config import Config +from zhenxun.services.log import logger + +from .config import validate_override_params +from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config +from .core import http_client_manager, key_store +from .service import LLMModel +from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig + +DEFAULT_MODEL_NAME_KEY = "default_model_name" +PROXY_KEY = "proxy" +TIMEOUT_KEY = "timeout" + +_model_cache: dict[str, tuple[LLMModel, float]] = {} +_cache_ttl = 3600 +_max_cache_size = 10 + + +def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]: + """解析 'ProviderName/ModelName' 格式的字符串""" + if not name_str or "/" not in name_str: + return None, None + parts = name_str.split("/", 1) + if len(parts) == 2 and parts[0].strip() and parts[1].strip(): + return parts[0].strip(), parts[1].strip() + return None, None + + +def _make_cache_key(provider_model_name: str | None, override_config: dict | None) -> str: + """生成缓存键""" + config_str = json.dumps(override_config, sort_keys=True) if override_config else "None" + key_data = f"{provider_model_name}:{config_str}" + return hashlib.md5(key_data.encode()).hexdigest() + + +def _get_cached_model(cache_key: str) -> LLMModel | None: + """从缓存获取模型""" + if cache_key in _model_cache: + model, created_time = _model_cache[cache_key] + current_time = time.time() + + if current_time - created_time > _cache_ttl: + del _model_cache[cache_key] + logger.debug(f"模型缓存已过期: {cache_key}") + return None + + if model._is_closed: + logger.debug( + f"缓存的模型 {cache_key} ({model.provider_name}/{model.model_name}) " + f"处于_is_closed=True状态,重置为False以供复用。" + ) + model._is_closed = False + + logger.debug(f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}") + return model + return None + + +def _cache_model(cache_key: str, model: LLMModel): + """缓存模型实例""" + current_time = time.time() + + if len(_model_cache) >= _max_cache_size: + oldest_key = min(_model_cache.keys(), key=lambda k: _model_cache[k][1]) + del _model_cache[oldest_key] + + _model_cache[cache_key] = (model, current_time) + + +def clear_model_cache(): + """清空模型缓存""" + global _model_cache + _model_cache.clear() + logger.info("已清空模型缓存") + + +def get_cache_stats() -> dict[str, Any]: + """获取缓存统计信息""" + return { + "cache_size": len(_model_cache), + "max_cache_size": _max_cache_size, + "cache_ttl": _cache_ttl, + "cached_models": list(_model_cache.keys()), + } + + +def get_default_api_base_for_type(api_type: str) -> str | None: + """根据API类型获取默认的API基础地址""" + default_api_bases = { + "openai": "https://api.openai.com", + "deepseek": "https://api.deepseek.com", + "zhipu": "https://open.bigmodel.cn", + "gemini": "https://generativelanguage.googleapis.com", + "general_openai_compat": None, + } + 
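
模型缓存键由「模型全名 + 排序后的覆盖配置 JSON」经 MD5 得到,实例按创建时间做 TTL 过期与容量淘汰。缓存状态可以这样查看(字段与上文 `get_cache_stats` 的实现一一对应):

```python
from zhenxun.services.llm import clear_model_cache, get_cache_stats

stats = get_cache_stats()
print(stats["cache_size"], stats["max_cache_size"], stats["cache_ttl"])
print(stats["cached_models"])  # 缓存键(MD5)列表

clear_model_cache()  # 需要强制重建模型实例时手动清空
```
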
+ return default_api_bases.get(api_type) + + +def get_configured_providers() -> list[ProviderConfig]: + """从配置中获取Provider列表 - 简化版本""" + ai_config = get_ai_config() + providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, []) + if not isinstance(providers_raw, list): + logger.error(f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,将使用空列表。") + return [] + + valid_providers = [] + for i, item in enumerate(providers_raw): + if not isinstance(item, dict): + logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。") + continue + + try: + if not item.get("name"): + logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。") + continue + + if not item.get("api_key"): + logger.warning(f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。") + continue + + if "api_type" not in item or not item["api_type"]: + provider_name = item.get("name", "").lower() + if "glm" in provider_name or "zhipu" in provider_name: + item["api_type"] = "zhipu" + elif "gemini" in provider_name or "google" in provider_name: + item["api_type"] = "gemini" + else: + item["api_type"] = "openai" + + if "api_base" not in item or not item["api_base"]: + api_type = item.get("api_type") + if api_type: + default_api_base = get_default_api_base_for_type(api_type) + if default_api_base: + item["api_base"] = default_api_base + + if "models" not in item: + item["models"] = [{"model_name": item.get("name", "default")}] + + provider_conf = ProviderConfig(**item) + valid_providers.append(provider_conf) + + except Exception as e: + logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。") + + return valid_providers + + +def find_model_config(provider_name: str, model_name: str) -> tuple[ProviderConfig, ModelDetail] | None: + """在配置中查找指定的 Provider 和 ModelDetail + + Args: + provider_name: 提供商名称 + model_name: 模型名称 + + Returns: + 找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None + """ + providers = get_configured_providers() + + for provider in providers: + if provider.name.lower() == provider_name.lower(): + for model_detail in provider.models: + if model_detail.model_name.lower() == model_name.lower(): + return provider, model_detail + + return None + + +def list_available_models() -> list[dict[str, Any]]: + """列出所有配置的可用模型""" + providers = get_configured_providers() + model_list = [] + for provider in providers: + for model_detail in provider.models: + model_info = { + "provider_name": provider.name, + "model_name": model_detail.model_name, + "full_name": f"{provider.name}/{model_detail.model_name}", + "api_type": provider.api_type or "auto-detect", + "api_base": provider.api_base, + "is_available": model_detail.is_available, + "is_embedding_model": model_detail.is_embedding_model, + "available_identifiers": _get_model_identifiers(provider.name, model_detail), + } + model_list.append(model_info) + return model_list + + +def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]: + """获取模型的所有可用标识符""" + return [f"{provider_name}/{model_detail.model_name}"] + + +def list_model_identifiers() -> dict[str, list[str]]: + """列出所有模型的可用标识符 + + Returns: + 字典,键为模型的完整名称,值为该模型的所有可用标识符列表 + """ + providers = get_configured_providers() + result = {} + + for provider in providers: + for model_detail in provider.models: + full_name = f"{provider.name}/{model_detail.model_name}" + identifiers = _get_model_identifiers(provider.name, model_detail) + result[full_name] = identifiers + + return result + + +def list_embedding_models() -> list[dict[str, Any]]: + """列出所有配置的嵌入模型""" + all_models = list_available_models() + return [model for model in all_models if 
model.get("is_embedding_model", False)] + + +async def get_model_instance( + provider_model_name: str | None = None, + override_config: dict[str, Any] | None = None, +) -> LLMModel: + """根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)""" + cache_key = _make_cache_key(provider_model_name, override_config) + cached_model = _get_cached_model(cache_key) + if cached_model: + if override_config: + validated_override = validate_override_params(override_config) + if cached_model._generation_config != validated_override: + cached_model._generation_config = validated_override + logger.debug( + f"对缓存模型 {provider_model_name} 应用新的覆盖配置: {validated_override.to_dict()}" + ) + return cached_model + + resolved_model_name_str = provider_model_name + if resolved_model_name_str is None: + resolved_model_name_str = get_global_default_model_name() + if resolved_model_name_str is None: + available_models_list = list_available_models() + if not available_models_list: + raise LLMException("未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR) + resolved_model_name_str = available_models_list[0]["full_name"] + logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}") + + prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str) + if not prov_name_str or not mod_name_str: + raise LLMException( + f"无效的模型名称格式: '{resolved_model_name_str}'", code=LLMErrorCode.MODEL_NOT_FOUND + ) + + config_tuple_found = find_model_config(prov_name_str, mod_name_str) + if not config_tuple_found: + all_models = list_available_models() + raise LLMException( + f"未找到模型: '{resolved_model_name_str}'. 可用: {[m['full_name'] for m in all_models]}", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + + provider_config_found, model_detail_found = config_tuple_found + + ai_config = get_ai_config() + global_proxy_setting = ai_config.get(PROXY_KEY) + default_timeout = provider_config_found.timeout if provider_config_found.timeout is not None else 180 + global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout) + + config_for_http_client = ProviderConfig( + name=provider_config_found.name, + api_key=provider_config_found.api_key, + models=provider_config_found.models, + timeout=global_timeout_setting, + proxy=global_proxy_setting, + api_base=provider_config_found.api_base, + api_type=provider_config_found.api_type, + openai_compat=provider_config_found.openai_compat, + temperature=provider_config_found.temperature, + max_tokens=provider_config_found.max_tokens, + ) + + shared_http_client = await http_client_manager.get_client(config_for_http_client) + + try: + model_instance = LLMModel( + provider_config=config_for_http_client, + model_detail=model_detail_found, + key_store=key_store, + http_client=shared_http_client, + ) + + if override_config: + validated_override_params = validate_override_params(override_config) + model_instance._generation_config = validated_override_params + logger.debug( + f"为新模型 {resolved_model_name_str} 应用配置覆盖: {validated_override_params.to_dict()}" + ) + + _cache_model(cache_key, model_instance) + logger.debug(f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}") + return model_instance + except LLMException: + raise + except Exception as e: + logger.error(f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e) + raise LLMException( + f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}", + code=LLMErrorCode.MODEL_INIT_FAILED, + cause=e, + ) + + +def get_global_default_model_name() -> str | None: + """获取全局默认模型名称""" + ai_config = get_ai_config() + return 
ai_config.get(DEFAULT_MODEL_NAME_KEY) + + +def set_global_default_model_name(provider_model_name: str | None) -> bool: + """设置全局默认模型名称""" + if provider_model_name: + prov_name, mod_name = parse_provider_model_string(provider_model_name) + if not prov_name or not mod_name or not find_model_config(prov_name, mod_name): + logger.error(f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。") + return False + + Config.set_config(AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True) + if provider_model_name: + logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}") + else: + logger.info("LLM 服务全局默认模型已清除。") + return True + + +async def get_key_usage_stats() -> dict[str, Any]: + """获取所有Provider的Key使用统计""" + providers = get_configured_providers() + stats = {} + + for provider in providers: + provider_stats = await key_store.get_key_stats( + [provider.api_key] if isinstance(provider.api_key, str) else provider.api_key + ) + stats[provider.name] = { + "total_keys": len([provider.api_key] if isinstance(provider.api_key, str) else provider.api_key), + "key_stats": provider_stats, + } + + return stats + + +async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool: + """重置指定Provider的Key状态""" + providers = get_configured_providers() + target_provider = None + + for provider in providers: + if provider.name.lower() == provider_name.lower(): + target_provider = provider + break + + if not target_provider: + logger.error(f"未找到Provider: {provider_name}") + return False + + provider_keys = ( + [target_provider.api_key] if isinstance(target_provider.api_key, str) else target_provider.api_key + ) + + if api_key: + if api_key in provider_keys: + await key_store.reset_key_status(api_key) + logger.info(f"已重置Provider '{provider_name}' 的指定Key状态") + return True + else: + logger.error(f"指定的Key不属于Provider '{provider_name}'") + return False + else: + for key in provider_keys: + await key_store.reset_key_status(key) + logger.info(f"已重置Provider '{provider_name}' 的所有Key状态") + return True diff --git a/zhenxun/services/llm/service.py b/zhenxun/services/llm/service.py new file mode 100644 index 00000000..7a0c95d3 --- /dev/null +++ b/zhenxun/services/llm/service.py @@ -0,0 +1,594 @@ +""" +LLM 模型实现类 + +包含 LLM 模型的抽象基类和具体实现,负责与各种 AI 提供商的 API 交互。 +""" + +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +import json +from typing import Any + +from zhenxun.services.log import logger + +from .config import LLMGenerationConfig +from .config.providers import get_ai_config +from .core import ( + KeyStatusStore, + LLMHttpClient, + RetryConfig, + http_client_manager, + with_smart_retry, +) +from .types import ( + EmbeddingTaskType, + LLMErrorCode, + LLMException, + LLMMessage, + LLMResponse, + LLMTool, + ModelDetail, + ProviderConfig, +) + + +class LLMModelBase(ABC): + """LLM模型抽象基类""" + + @abstractmethod + async def generate_text( + self, + prompt: str, + history: list[dict[str, str]] | None = None, + **kwargs: Any, + ) -> str: + """生成文本""" + pass + + @abstractmethod + async def generate_response( + self, + messages: list[LLMMessage], + config: LLMGenerationConfig | None = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """生成高级响应""" + pass + + @abstractmethod + async def generate_embeddings( + self, + texts: list[str], + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """生成文本嵌入向量""" + pass + + +class 
LLMModel(LLMModelBase): + """LLM 模型实现类""" + + def __init__( + self, + provider_config: ProviderConfig, + model_detail: ModelDetail, + key_store: KeyStatusStore, + http_client: LLMHttpClient, + config_override: LLMGenerationConfig | None = None, + ): + self.provider_config = provider_config + self.model_detail = model_detail + self.key_store = key_store + self.http_client: LLMHttpClient = http_client + self._generation_config = config_override + + self.provider_name = provider_config.name + self.api_type = provider_config.api_type + self.api_base = provider_config.api_base + self.api_keys = ( + [provider_config.api_key] if isinstance(provider_config.api_key, str) else provider_config.api_key + ) + self.model_name = model_detail.model_name + self.temperature = model_detail.temperature + self.max_tokens = model_detail.max_tokens + + self._is_closed = False + + async def _get_http_client(self) -> LLMHttpClient: + """获取HTTP客户端""" + if self.http_client.is_closed: + logger.debug( + f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,正在获取新的客户端" + ) + self.http_client = await http_client_manager.get_client(self.provider_config) + return self.http_client + + async def _select_api_key(self, failed_keys: set[str] | None = None) -> str: + """选择可用的API密钥(使用轮询策略)""" + if not self.api_keys: + raise LLMException( + f"提供商 {self.provider_name} 没有配置API密钥", + code=LLMErrorCode.NO_AVAILABLE_KEYS, + ) + + selected_key = await self.key_store.get_next_available_key( + self.provider_name, self.api_keys, failed_keys + ) + + if not selected_key: + raise LLMException( + f"提供商 {self.provider_name} 的所有API密钥当前都不可用", + code=LLMErrorCode.NO_AVAILABLE_KEYS, + details={"total_keys": len(self.api_keys), "failed_keys": len(failed_keys or set())}, + ) + + return selected_key + + async def _execute_embedding_request( + self, + adapter, + texts: list[str], + task_type: EmbeddingTaskType | str, + http_client: LLMHttpClient, + failed_keys: set[str] | None = None, + ) -> list[list[float]]: + """执行单次嵌入请求 - 供重试机制调用""" + api_key = await self._select_api_key(failed_keys) + + try: + request_data = adapter.prepare_embedding_request( + model=self, + api_key=api_key, + texts=texts, + task_type=task_type, + ) + + http_response = await http_client.post( + request_data.url, + headers=request_data.headers, + json=request_data.body, + ) + + if http_response.status_code != 200: + error_text = http_response.text + logger.error(f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}") + await self.key_store.record_failure(api_key, http_response.status_code) + + error_code = LLMErrorCode.API_REQUEST_FAILED + if http_response.status_code in [401, 403]: + error_code = LLMErrorCode.API_KEY_INVALID + elif http_response.status_code == 429: + error_code = LLMErrorCode.API_RATE_LIMITED + + raise LLMException( + f"HTTP嵌入请求失败: {http_response.status_code}", + code=error_code, + details={ + "status_code": http_response.status_code, + "response": error_text, + "api_key": api_key, + }, + ) + + try: + response_json = http_response.json() + adapter.validate_embedding_response(response_json) + embeddings = adapter.parse_embedding_response(response_json) + except Exception as e: + logger.error(f"解析嵌入响应失败: {e}", e=e) + await self.key_store.record_failure(api_key, None) + if isinstance(e, LLMException): + raise + else: + raise LLMException( + f"解析API嵌入响应失败: {e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + await self.key_store.record_success(api_key) + return embeddings + + except LLMException: + raise + except Exception as e: + 
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e) + await self.key_store.record_failure(api_key, None) + raise LLMException( + f"生成嵌入失败: {e}", + code=LLMErrorCode.EMBEDDING_FAILED, + cause=e, + ) + + async def _execute_with_smart_retry( + self, + adapter, + messages: list[LLMMessage], + config: LLMGenerationConfig | None, + tools_dict: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, + http_client: LLMHttpClient, + ): + """智能重试机制 - 使用统一的重试装饰器""" + ai_config = get_ai_config() + max_retries = ai_config.get("max_retries_llm", 3) + retry_delay = ai_config.get("retry_delay_llm", 2) + retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay) + + return await with_smart_retry( + self._execute_single_request, + adapter, + messages, + config, + tools_dict, + tool_choice, + http_client, + retry_config=retry_config, + key_store=self.key_store, + provider_name=self.provider_name, + ) + + async def _execute_single_request( + self, + adapter, + messages: list[LLMMessage], + config: LLMGenerationConfig | None, + tools_dict: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, + http_client: LLMHttpClient, + failed_keys: set[str] | None = None, + ) -> LLMResponse: + """执行单次请求 - 供重试机制调用,直接返回 LLMResponse""" + api_key = await self._select_api_key(failed_keys) + + try: + request_data = adapter.prepare_advanced_request( + model=self, + api_key=api_key, + messages=messages, + config=config, + tools=tools_dict, + tool_choice=tool_choice, + ) + + http_response = await http_client.post( + request_data.url, + headers=request_data.headers, + json=request_data.body, + ) + + if http_response.status_code != 200: + error_text = http_response.text + logger.error(f"HTTP请求失败: {http_response.status_code} - {error_text}") + + await self.key_store.record_failure(api_key, http_response.status_code) + + if http_response.status_code in [401, 403]: + error_code = LLMErrorCode.API_KEY_INVALID + elif http_response.status_code == 429: + error_code = LLMErrorCode.API_RATE_LIMITED + elif http_response.status_code in [402, 413]: + error_code = LLMErrorCode.API_QUOTA_EXCEEDED + else: + error_code = LLMErrorCode.API_REQUEST_FAILED + + raise LLMException( + f"HTTP请求失败: {http_response.status_code}", + code=error_code, + details={ + "status_code": http_response.status_code, + "response": error_text, + "api_key": api_key, + }, + ) + + try: + response_json = http_response.json() + response_data = adapter.parse_response( + model=self, + response_json=response_json, + is_advanced=True, + ) + + from .types.models import LLMToolCall + + response_tool_calls = [] + if response_data.tool_calls: + for tc_data in response_data.tool_calls: + if isinstance(tc_data, LLMToolCall): + response_tool_calls.append(tc_data) + elif isinstance(tc_data, dict): + try: + response_tool_calls.append(LLMToolCall(**tc_data)) + except Exception as e: + logger.warning(f"无法将工具调用数据转换为LLMToolCall: {tc_data}, error: {e}") + else: + logger.warning(f"工具调用数据格式未知: {tc_data}") + + llm_response = LLMResponse( + text=response_data.text, + usage_info=response_data.usage_info, + raw_response=response_data.raw_response, + tool_calls=response_tool_calls if response_tool_calls else None, + code_executions=response_data.code_executions, + grounding_metadata=response_data.grounding_metadata, + cache_info=response_data.cache_info, + ) + + except Exception as e: + logger.error(f"解析响应失败: {e}", e=e) + await self.key_store.record_failure(api_key, None) + + if isinstance(e, LLMException): + raise + else: + raise LLMException( + f"解析API响应失败: 
{e}", + code=LLMErrorCode.RESPONSE_PARSE_ERROR, + cause=e, + ) + + await self.key_store.record_success(api_key) + + return llm_response + + except LLMException: + raise + except Exception as e: + logger.error(f"生成响应时发生未预期错误: {e}", e=e) + await self.key_store.record_failure(api_key, None) + + raise LLMException( + f"生成响应失败: {e}", + code=LLMErrorCode.GENERATION_FAILED, + cause=e, + ) + + async def close(self): + """ + 标记模型实例的当前使用周期结束。 + 共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。 + """ + if self._is_closed: + return + self._is_closed = True + logger.debug(f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)") + + async def __aenter__(self): + if self._is_closed: + logger.debug(f"Re-entering context for closed LLMModel {self}. Resetting _is_closed to False.") + self._is_closed = False + self._check_not_closed() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + _ = exc_type, exc_val, exc_tb + await self.close() + + def _check_not_closed(self): + """检查实例是否已关闭""" + if self._is_closed: + raise RuntimeError(f"LLMModel实例已关闭: {self}") + + async def generate_text( + self, + prompt: str, + history: list[dict[str, str]] | None = None, + **kwargs: Any, + ) -> str: + """生成文本 - 通过 generate_response 实现""" + self._check_not_closed() + + messages: list[LLMMessage] = [] + + if history: + for msg in history: + role = msg.get("role", "user") + content_text = msg.get("content", "") + messages.append(LLMMessage(role=role, content=content_text)) + + messages.append(LLMMessage.user(prompt)) + + request_specific_config_dict = { + k: v for k, v in kwargs.items() if k in LLMGenerationConfig.model_fields + } + request_specific_config = None + if request_specific_config_dict: + request_specific_config = LLMGenerationConfig(**request_specific_config_dict) + + for key in request_specific_config_dict: + kwargs.pop(key, None) + + response = await self.generate_response( + messages, + config=request_specific_config, + **kwargs, + ) + return response.text + + async def generate_response( + self, + messages: list[LLMMessage], + config: LLMGenerationConfig | None = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, + max_tool_iterations: int = 5, + **kwargs: Any, + ) -> LLMResponse: + """生成高级响应 - 实现完整的工具调用循环""" + self._check_not_closed() + + from .adapters import get_adapter_for_api_type + from .config.generation import create_generation_config_from_kwargs + + adapter = get_adapter_for_api_type(self.api_type) + if not adapter: + raise LLMException( + f"未找到适用于 API 类型 '{self.api_type}' 的适配器", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + + final_request_config = self._generation_config or LLMGenerationConfig() + if kwargs: + kwargs_config = create_generation_config_from_kwargs(**kwargs) + merged_dict = final_request_config.to_dict() + merged_dict.update(kwargs_config.to_dict()) + final_request_config = LLMGenerationConfig(**merged_dict) + + if config is not None: + merged_dict = final_request_config.to_dict() + merged_dict.update(config.to_dict()) + final_request_config = LLMGenerationConfig(**merged_dict) + + tools_dict: list[dict[str, Any]] | None = None + if tools: + tools_dict = [] + for tool in tools: + if hasattr(tool, "model_dump"): + tools_dict.append(tool.model_dump(exclude_none=True)) + elif isinstance(tool, dict): + tools_dict.append(tool) + else: + try: + tools_dict.append(dict(tool)) + except (TypeError, ValueError): + logger.warning(f"工具 '{tool}' 
无法转换为字典,已忽略。") + + http_client = await self._get_http_client() + current_messages = list(messages) + + for iteration in range(max_tool_iterations): + logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}") + + llm_response = await self._execute_with_smart_retry( + adapter, + current_messages, + final_request_config, + tools_dict if iteration == 0 else None, + tool_choice if iteration == 0 else None, + http_client, + ) + + response_tool_calls = llm_response.tool_calls or [] + + if not response_tool_calls or not tool_executor: + logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。") + return llm_response + + logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。") + + assistant_message_content = llm_response.text if llm_response.text else "" + current_messages.append( + LLMMessage.assistant_tool_calls( + content=assistant_message_content, tool_calls=response_tool_calls + ) + ) + + tool_response_messages: list[LLMMessage] = [] + for tool_call in response_tool_calls: + tool_name = tool_call.function.name + try: + tool_args_dict = json.loads(tool_call.function.arguments) + logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}") + + tool_result = await tool_executor(tool_name, tool_args_dict) + logger.debug(f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}...") + + tool_response_messages.append( + LLMMessage.tool_response( + tool_call_id=tool_call.id, + function_name=tool_name, + result=tool_result, + ) + ) + except json.JSONDecodeError as e: + logger.error( + f"工具 '{tool_name}' 参数JSON解析失败: {tool_call.function.arguments}, 错误: {e}" + ) + tool_response_messages.append( + LLMMessage.tool_response( + tool_call_id=tool_call.id, + function_name=tool_name, + result={"error": "Argument JSON parsing failed", "details": str(e)}, + ) + ) + except Exception as e: + logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e) + tool_response_messages.append( + LLMMessage.tool_response( + tool_call_id=tool_call.id, + function_name=tool_name, + result={"error": "Tool execution failed", "details": str(e)}, + ) + ) + + current_messages.extend(tool_response_messages) + + logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。") + raise LLMException( + "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。", + code=LLMErrorCode.GENERATION_FAILED, + details={"iterations": max_tool_iterations, "last_messages": current_messages[-2:]}, + ) + + async def generate_embeddings( + self, + texts: list[str], + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """生成文本嵌入向量""" + self._check_not_closed() + if not texts: + return [] + + from .adapters import get_adapter_for_api_type + + adapter = get_adapter_for_api_type(self.api_type) + if not adapter: + raise LLMException( + f"未找到适用于 API 类型 '{self.api_type}' 的嵌入适配器", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + + http_client = await self._get_http_client() + + ai_config = get_ai_config() + default_max_retries = ai_config.get("max_retries_llm", 3) + default_retry_delay = ai_config.get("retry_delay_llm", 2) + max_retries_embed = kwargs.get("max_retries_embed", max(1, default_max_retries // 2)) + retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2) + + retry_config = RetryConfig( + max_retries=max_retries_embed, + retry_delay=retry_delay_embed, + exponential_backoff=True, + key_rotation=True, + ) + + return await with_smart_retry( + self._execute_embedding_request, + adapter, + texts, + task_type, + http_client, + retry_config=retry_config, + key_store=self.key_store, + provider_name=self.provider_name, + ) + + def 
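
一轮工具调用结束后,上面的循环会把消息序列扩展成「assistant 携带 tool_calls → 若干条 tool 响应」,再交回模型进行下一轮。其形态大致如下(工具名与调用 ID 均为假设,`llm_response.tool_calls` 取自上一轮响应):

```python
from zhenxun.services.llm.types import LLMMessage

# current_messages 在一轮工具调用后的大致形态:
messages = [
    LLMMessage.user("帮我查一下库存"),
    LLMMessage.assistant_tool_calls(content="", tool_calls=llm_response.tool_calls),
    LLMMessage.tool_response(
        tool_call_id="call_0",        # 假设的调用 ID
        function_name="query_stock",  # 假设的工具名
        result={"count": 3},
    ),
]
```
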
__str__(self) -> str: + status = "closed" if self._is_closed else "active" + return f"LLMModel({self.provider_name}/{self.model_name}, {status})" + + def __repr__(self) -> str: + status = "closed" if self._is_closed else "active" + return ( + f"LLMModel(provider={self.provider_name}, model={self.model_name}, " + f"api_type={self.api_type}, status={status})" + ) diff --git a/zhenxun/services/llm/types/__init__.py b/zhenxun/services/llm/types/__init__.py new file mode 100644 index 00000000..ebae4185 --- /dev/null +++ b/zhenxun/services/llm/types/__init__.py @@ -0,0 +1,54 @@ +""" +LLM 类型定义模块 + +统一导出所有核心类型、协议和异常定义。 +""" + +from .content import ( + LLMContentPart, + LLMMessage, + LLMResponse, +) +from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory +from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message +from .models import ( + LLMCacheInfo, + LLMCodeExecution, + LLMGroundingAttribution, + LLMGroundingMetadata, + LLMTool, + LLMToolCall, + LLMToolFunction, + ModelDetail, + ModelInfo, + ModelName, + ProviderConfig, + ToolMetadata, + UsageInfo, +) + +__all__ = [ + "EmbeddingTaskType", + "LLMCacheInfo", + "LLMCodeExecution", + "LLMContentPart", + "LLMErrorCode", + "LLMException", + "LLMGroundingAttribution", + "LLMGroundingMetadata", + "LLMMessage", + "LLMResponse", + "LLMTool", + "LLMToolCall", + "LLMToolFunction", + "ModelDetail", + "ModelInfo", + "ModelName", + "ModelProvider", + "ProviderConfig", + "ResponseFormat", + "ToolCategory", + "ToolMetadata", + "UsageInfo", + "get_user_friendly_error_message", +] diff --git a/zhenxun/services/llm/types/content.py b/zhenxun/services/llm/types/content.py new file mode 100644 index 00000000..e24c0568 --- /dev/null +++ b/zhenxun/services/llm/types/content.py @@ -0,0 +1,379 @@ +""" +LLM 内容类型定义 + +包含多模态内容部分、消息和响应的数据模型。 +""" + +import base64 +import mimetypes +from pathlib import Path +from typing import Any + +import aiofiles +from pydantic import BaseModel + +from zhenxun.services.log import logger + + +class LLMContentPart(BaseModel): + """LLM 消息内容部分 - 支持多模态内容""" + + type: str + text: str | None = None + image_source: str | None = None + audio_source: str | None = None + video_source: str | None = None + document_source: str | None = None + file_uri: str | None = None + file_source: str | None = None + url: str | None = None + mime_type: str | None = None + metadata: dict[str, Any] | None = None + + def model_post_init(self, /, __context: Any) -> None: + """验证内容部分的有效性""" + _ = __context + validation_rules = { + "text": lambda: self.text, + "image": lambda: self.image_source, + "audio": lambda: self.audio_source, + "video": lambda: self.video_source, + "document": lambda: self.document_source, + "file": lambda: self.file_uri or self.file_source, + "url": lambda: self.url, + } + + if self.type in validation_rules: + if not validation_rules[self.type](): + raise ValueError(f"{self.type}类型的内容部分必须包含相应字段") + + @classmethod + def text_part(cls, text: str) -> "LLMContentPart": + """创建文本内容部分""" + return cls(type="text", text=text) + + @classmethod + def image_url_part(cls, url: str) -> "LLMContentPart": + """创建图片URL内容部分""" + return cls(type="image", image_source=url) + + @classmethod + def image_base64_part(cls, data: str, mime_type: str = "image/png") -> "LLMContentPart": + """创建Base64图片内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="image", image_source=data_url) + + @classmethod + def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart": + 
"""创建音频URL内容部分""" + return cls(type="audio", audio_source=url, mime_type=mime_type) + + @classmethod + def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart": + """创建视频URL内容部分""" + return cls(type="video", video_source=url, mime_type=mime_type) + + @classmethod + def video_base64_part(cls, data: str, mime_type: str = "video/mp4") -> "LLMContentPart": + """创建Base64视频内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="video", video_source=data_url, mime_type=mime_type) + + @classmethod + def audio_base64_part(cls, data: str, mime_type: str = "audio/wav") -> "LLMContentPart": + """创建Base64音频内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="audio", audio_source=data_url, mime_type=mime_type) + + @classmethod + def file_uri_part( + cls, + file_uri: str, + mime_type: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> "LLMContentPart": + """创建Gemini File API URI内容部分""" + return cls( + type="file", + file_uri=file_uri, + mime_type=mime_type, + metadata=metadata or {}, + ) + + @classmethod + async def from_path(cls, path_like: str | Path, target_api: str | None = None) -> "LLMContentPart | None": + """ + 从本地文件路径创建 LLMContentPart。 + 自动检测MIME类型,并根据类型(如图片)可能加载为Base64。 + target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64) + """ + try: + path = Path(path_like) + if not path.exists() or not path.is_file(): + logger.warning(f"文件不存在或不是一个文件: {path}") + return None + + mime_type, _ = mimetypes.guess_type(path.resolve().as_uri()) + + if not mime_type: + logger.warning(f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。") + try: + async with aiofiles.open(path, encoding="utf-8") as f: + text_content = await f.read() + return cls.text_part(text_content) + except Exception as e: + logger.error(f"读取文本文件 {path.name} 失败: {e}") + return None + + if mime_type.startswith("image/"): + if target_api == "gemini" or not path.is_absolute(): + try: + async with aiofiles.open(path, "rb") as f: + img_bytes = await f.read() + base64_data = base64.b64encode(img_bytes).decode("utf-8") + return cls.image_base64_part(data=base64_data, mime_type=mime_type) + except Exception as e: + logger.error(f"读取或编码图片文件 {path.name} 失败: {e}") + return None + else: + logger.warning( + f"为本地图片路径 {path.name} 生成 image_url_part。实际API可能不支持 file:// URI。考虑使用Base64或公网URL。" # noqa: E501 + ) + return cls.image_url_part(url=path.resolve().as_uri()) + elif mime_type.startswith("audio/"): + return cls.audio_url_part(url=path.resolve().as_uri(), mime_type=mime_type) + elif mime_type.startswith("video/"): + if target_api == "gemini": + # 对于 Gemini API,将视频转换为 base64 + try: + async with aiofiles.open(path, "rb") as f: + video_bytes = await f.read() + base64_data = base64.b64encode(video_bytes).decode("utf-8") + return cls.video_base64_part(data=base64_data, mime_type=mime_type) + except Exception as e: + logger.error(f"读取或编码视频文件 {path.name} 失败: {e}") + return None + else: + return cls.video_url_part(url=path.resolve().as_uri(), mime_type=mime_type) + elif ( + mime_type.startswith("text/") + or mime_type == "application/json" + or mime_type == "application/xml" + ): + try: + async with aiofiles.open(path, encoding="utf-8") as f: + text_content = await f.read() + return cls.text_part(text_content) + except Exception as e: + logger.error(f"读取文本类文件 {path.name} 失败: {e}") + return None + else: + logger.info(f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。") + return cls.file_uri_part( + file_uri=path.resolve().as_uri(), + mime_type=mime_type, + metadata={"name": path.name, "source": 
"local_path"}, + ) + + except Exception as e: + logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}") + return None + + def is_image_url(self) -> bool: + """检查图像源是否为URL""" + if not self.image_source: + return False + return self.image_source.startswith(("http://", "https://")) + + def is_image_base64(self) -> bool: + """检查图像源是否为Base64 Data URL""" + if not self.image_source: + return False + return self.image_source.startswith("data:") + + def get_base64_data(self) -> tuple[str, str] | None: + """从Data URL中提取Base64数据和MIME类型""" + if not self.is_image_base64() or not self.image_source: + return None + + try: + header, data = self.image_source.split(",", 1) + mime_part = header.split(";")[0].replace("data:", "") + return mime_part, data + except (ValueError, IndexError): + logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...") + return None + + def convert_for_api(self, api_type: str) -> dict[str, Any]: + """根据API类型转换多模态内容格式""" + if self.type == "text": + if api_type == "openai": + return {"type": "text", "text": self.text} + elif api_type == "gemini": + return {"text": self.text} + else: + return {"type": "text", "text": self.text} + + elif self.type == "image": + if not self.image_source: + raise ValueError("图像类型的内容必须包含image_source") + + if api_type == "openai": + return {"type": "image_url", "image_url": {"url": self.image_source}} + elif api_type == "gemini": + if self.is_image_base64(): + base64_info = self.get_base64_data() + if base64_info: + mime_type, data = base64_info + return {"inlineData": {"mimeType": mime_type, "data": data}} + else: + # 如果无法解析 Base64 数据,抛出异常 + raise ValueError(f"无法解析Base64图像数据: {self.image_source[:50]}...") + else: + logger.warning(f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}") + return { + "inlineData": { + "mimeType": "image/jpeg", + "data": self.image_source, + } + } + else: + return {"type": "image_url", "image_url": {"url": self.image_source}} + + elif self.type == "video": + if not self.video_source: + raise ValueError("视频类型的内容必须包含video_source") + + if api_type == "gemini": + # Gemini 支持视频,但需要通过 File API 上传 + if self.video_source.startswith("data:"): + # 处理 base64 视频数据 + try: + header, data = self.video_source.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + return {"inlineData": {"mimeType": mime_type, "data": data}} + except (ValueError, IndexError): + raise ValueError(f"无法解析Base64视频数据: {self.video_source[:50]}...") + else: + # 对于 URL 或其他格式,暂时不支持直接内联 + raise ValueError("Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL") + else: + # 其他 API 可能不支持视频 + raise ValueError(f"API类型 '{api_type}' 不支持视频内容") + + elif self.type == "audio": + if not self.audio_source: + raise ValueError("音频类型的内容必须包含audio_source") + + if api_type == "gemini": + # Gemini 支持音频,处理方式类似视频 + if self.audio_source.startswith("data:"): + try: + header, data = self.audio_source.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + return {"inlineData": {"mimeType": mime_type, "data": data}} + except (ValueError, IndexError): + raise ValueError(f"无法解析Base64音频数据: {self.audio_source[:50]}...") + else: + raise ValueError("Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL") + else: + raise ValueError(f"API类型 '{api_type}' 不支持音频内容") + + elif self.type == "file": + if api_type == "gemini" and self.file_uri: + return {"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}} + elif self.file_source: + file_name = self.metadata.get("name", "file") if self.metadata else "file" + if api_type == "gemini": + return {"text": f"[文件: 
{file_name}]\n{self.file_source}"} + else: + return { + "type": "text", + "text": f"[文件: {file_name}]\n{self.file_source}", + } + else: + raise ValueError("文件类型的内容必须包含file_uri或file_source") + + else: + raise ValueError(f"不支持的内容类型: {self.type}") + + +class LLMMessage(BaseModel): + """LLM 消息""" + + role: str + content: str | list[LLMContentPart] + name: str | None = None + tool_calls: list[Any] | None = None + tool_call_id: str | None = None + + def model_post_init(self, /, __context: Any) -> None: + """验证消息的有效性""" + _ = __context + if self.role == "tool": + if not self.tool_call_id: + raise ValueError("工具角色的消息必须包含 tool_call_id") + if not self.name: + raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)") + if self.role == "tool" and not isinstance(self.content, str): + logger.warning( + f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 将尝试转换为字符串。" + ) + try: + self.content = str(self.content) + except Exception as e: + raise ValueError(f"无法将工具角色的内容转换为字符串: {e}") + + @classmethod + def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage": + """创建用户消息""" + return cls(role="user", content=content) + + @classmethod + def assistant_tool_calls( + cls, + tool_calls: list[Any], + content: str | list[LLMContentPart] = "", + ) -> "LLMMessage": + """创建助手请求工具调用的消息""" + return cls(role="assistant", content=content, tool_calls=tool_calls) + + @classmethod + def assistant_text_response(cls, content: str | list[LLMContentPart]) -> "LLMMessage": + """创建助手纯文本回复的消息""" + return cls(role="assistant", content=content, tool_calls=None) + + @classmethod + def tool_response( + cls, + tool_call_id: str, + function_name: str, + result: Any, + ) -> "LLMMessage": + """创建工具执行结果的消息""" + import json + + try: + content_str = json.dumps(result) + except TypeError as e: + logger.error(f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 
错误: {e}") + content_str = json.dumps({"error": "Tool result not JSON serializable", "details": str(e)}) + + return cls(role="tool", content=content_str, tool_call_id=tool_call_id, name=function_name) + + @classmethod + def system(cls, content: str) -> "LLMMessage": + """创建系统消息""" + return cls(role="system", content=content) + + +class LLMResponse(BaseModel): + """LLM 响应""" + + text: str + usage_info: dict[str, Any] | None = None + raw_response: dict[str, Any] | None = None + tool_calls: list[Any] | None = None + code_executions: list[Any] | None = None + grounding_metadata: Any | None = None + cache_info: Any | None = None diff --git a/zhenxun/services/llm/types/enums.py b/zhenxun/services/llm/types/enums.py new file mode 100644 index 00000000..718a52ef --- /dev/null +++ b/zhenxun/services/llm/types/enums.py @@ -0,0 +1,67 @@ +""" +LLM 枚举类型定义 +""" + +from enum import Enum, auto + + +class ModelProvider(Enum): + """模型提供商枚举""" + + OPENAI = "openai" + GEMINI = "gemini" + ZHIXPU = "zhipu" + CUSTOM = "custom" + + +class ResponseFormat(Enum): + """响应格式枚举""" + + TEXT = "text" + JSON = "json" + MULTIMODAL = "multimodal" + + +class EmbeddingTaskType(str, Enum): + """文本嵌入任务类型 (主要用于Gemini)""" + + RETRIEVAL_QUERY = "RETRIEVAL_QUERY" + RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT" + SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY" + CLASSIFICATION = "CLASSIFICATION" + CLUSTERING = "CLUSTERING" + QUESTION_ANSWERING = "QUESTION_ANSWERING" + FACT_VERIFICATION = "FACT_VERIFICATION" + + +class ToolCategory(Enum): + """工具分类枚举""" + + FILE_SYSTEM = auto() + NETWORK = auto() + SYSTEM_INFO = auto() + CALCULATION = auto() + DATA_PROCESSING = auto() + CUSTOM = auto() + + +class LLMErrorCode(Enum): + """LLM 服务相关的错误代码枚举""" + + MODEL_INIT_FAILED = 2000 + MODEL_NOT_FOUND = 2001 + API_REQUEST_FAILED = 2002 + API_RESPONSE_INVALID = 2003 + API_KEY_INVALID = 2004 + API_QUOTA_EXCEEDED = 2005 + API_TIMEOUT = 2006 + API_RATE_LIMITED = 2007 + NO_AVAILABLE_KEYS = 2008 + UNKNOWN_API_TYPE = 2009 + CONFIGURATION_ERROR = 2010 + RESPONSE_PARSE_ERROR = 2011 + CONTEXT_LENGTH_EXCEEDED = 2012 + CONTENT_FILTERED = 2013 + USER_LOCATION_NOT_SUPPORTED = 2014 + GENERATION_FAILED = 2015 + EMBEDDING_FAILED = 2016 diff --git a/zhenxun/services/llm/types/exceptions.py b/zhenxun/services/llm/types/exceptions.py new file mode 100644 index 00000000..9621c09d --- /dev/null +++ b/zhenxun/services/llm/types/exceptions.py @@ -0,0 +1,76 @@ +""" +LLM 异常类型定义 +""" + +from typing import Any + +from .enums import LLMErrorCode + + +class LLMException(Exception): + """LLM 服务相关的基础异常类""" + + def __init__( + self, + message: str, + code: LLMErrorCode = LLMErrorCode.API_REQUEST_FAILED, + details: dict[str, Any] | None = None, + recoverable: bool = True, + cause: Exception | None = None, + ): + self.message = message + self.code = code + self.details = details or {} + self.recoverable = recoverable + self.cause = cause + super().__init__(message) + + def __str__(self) -> str: + if self.details: + return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})" + return f"{self.message} (错误码: {self.code.name})" + + @property + def user_friendly_message(self) -> str: + """返回适合向用户展示的错误消息""" + error_messages = { + LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到,请检查配置或联系管理员。", + LLMErrorCode.API_KEY_INVALID: "API密钥无效,请联系管理员更新配置。", + LLMErrorCode.API_QUOTA_EXCEEDED: "API使用配额已用尽,请稍后再试或联系管理员。", + LLMErrorCode.API_TIMEOUT: "AI服务响应超时,请稍后再试。", + LLMErrorCode.API_RATE_LIMITED: "请求过于频繁,已被AI服务限流,请稍后再试。", + LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败,请联系管理员检查配置。", + 
LLMErrorCode.NO_AVAILABLE_KEYS: "当前所有API密钥均不可用,请稍后再试或联系管理员。", + LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: ( + "当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。" + ), + LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败,请稍后再试。", + LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常,请稍后再试。", + LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误,请联系管理员。", + LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。", + LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。", + LLMErrorCode.RESPONSE_PARSE_ERROR: "AI服务响应解析失败,请稍后再试。", + LLMErrorCode.UNKNOWN_API_TYPE: "不支持的AI服务类型,请联系管理员。", + } + return error_messages.get(self.code, "AI服务暂时不可用,请稍后再试。") + + +def get_user_friendly_error_message(error: Exception) -> str: + """将任何异常转换为用户友好的错误消息""" + if isinstance(error, LLMException): + return error.user_friendly_message + + error_str = str(error).lower() + + if "timeout" in error_str or "超时" in error_str: + return "请求超时,请稍后再试。" + elif "connection" in error_str or "连接" in error_str: + return "网络连接失败,请检查网络后重试。" + elif "permission" in error_str or "权限" in error_str: + return "权限不足,请联系管理员。" + elif "not found" in error_str or "未找到" in error_str: + return "请求的资源未找到,请检查配置。" + elif "invalid" in error_str or "无效" in error_str: + return "请求参数无效,请检查输入。" + else: + return "服务暂时不可用,请稍后再试。" diff --git a/zhenxun/services/llm/types/models.py b/zhenxun/services/llm/types/models.py new file mode 100644 index 00000000..c5f541bc --- /dev/null +++ b/zhenxun/services/llm/types/models.py @@ -0,0 +1,160 @@ +""" +LLM 数据模型定义 + +包含模型信息、配置、工具定义和响应数据的模型类。 +""" + +from dataclasses import dataclass, field +from typing import Any + +from pydantic import BaseModel, Field + +from .enums import ModelProvider, ToolCategory + +ModelName = str | None + + +@dataclass(frozen=True) +class ModelInfo: + """模型信息(不可变数据类)""" + + name: str + provider: ModelProvider + max_tokens: int = 4096 + supports_tools: bool = False + supports_vision: bool = False + supports_audio: bool = False + cost_per_1k_tokens: float = 0.0 + + +@dataclass +class UsageInfo: + """使用信息数据类""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + cost: float = 0.0 + + @property + def efficiency_ratio(self) -> float: + """计算效率比(输出/输入)""" + return self.completion_tokens / max(self.prompt_tokens, 1) + + +@dataclass +class ToolMetadata: + """工具元数据""" + + name: str + description: str + category: ToolCategory + read_only: bool = True + destructive: bool = False + open_world: bool = False + parameters: dict[str, Any] = field(default_factory=dict) + required_params: list[str] = field(default_factory=list) + + +class ModelDetail(BaseModel): + """模型详细信息""" + + model_name: str + is_available: bool = True + is_embedding_model: bool = False + temperature: float | None = None + max_tokens: int | None = None + + +class ProviderConfig(BaseModel): + """LLM 提供商配置""" + + name: str = Field(..., description="Provider 的唯一名称标识") + api_key: str | list[str] = Field(..., description="API Key 或 Key 列表") + api_base: str | None = Field(None, description="API Base URL,如果为空则使用默认值") + api_type: str = Field(default="openai", description="API 类型") + openai_compat: bool = Field(default=False, description="是否使用 OpenAI 兼容模式") + temperature: float | None = Field(default=0.7, description="默认温度参数") + max_tokens: int | None = Field(default=None, description="默认最大输出 token 限制") + models: list[ModelDetail] = Field(..., description="支持的模型列表") + timeout: int = Field(default=180, description="请求超时时间") + proxy: str | None = Field(default=None, description="代理设置") + + +class LLMToolFunction(BaseModel): + """LLM 工具函数定义""" + + name: str + arguments: str 
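+
+# 用法示意(假设的工具名与参数,仅作演示):arguments 是 JSON 编码的字符串,
+# 使用方需自行反序列化,generate_response 中的工具调用循环即如此处理:
+#     func = LLMToolFunction(name="get_time", arguments='{"tz": "Asia/Shanghai"}')
+#     json.loads(func.arguments)  # -> {"tz": "Asia/Shanghai"}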
+ + +class LLMToolCall(BaseModel): + """LLM 工具调用""" + + id: str + function: LLMToolFunction + + +class LLMTool(BaseModel): + """LLM 工具定义(支持 MCP 风格)""" + + type: str = "function" + function: dict[str, Any] + annotations: dict[str, Any] | None = Field(default=None, description="工具注解") + + @classmethod + def create( + cls, + name: str, + description: str, + parameters: dict[str, Any], + required: list[str] | None = None, + annotations: dict[str, Any] | None = None, + ) -> "LLMTool": + """创建工具""" + function_def = { + "name": name, + "description": description, + "parameters": { + "type": "object", + "properties": parameters, + "required": required or [], + }, + } + return cls(type="function", function=function_def, annotations=annotations) + + +class LLMCodeExecution(BaseModel): + """代码执行结果""" + + code: str + output: str | None = None + error: str | None = None + execution_time: float | None = None + files_generated: list[str] | None = None + + +class LLMGroundingAttribution(BaseModel): + """信息来源关联""" + + title: str | None = None + uri: str | None = None + snippet: str | None = None + confidence_score: float | None = None + + +class LLMGroundingMetadata(BaseModel): + """信息来源关联元数据""" + + web_search_queries: list[str] | None = None + grounding_attributions: list[LLMGroundingAttribution] | None = None + search_suggestions: list[dict[str, Any]] | None = None + + +class LLMCacheInfo(BaseModel): + """缓存信息""" + + cache_hit: bool = False + cache_key: str | None = None + cache_ttl: int | None = None + created_at: str | None = None diff --git a/zhenxun/services/llm/utils.py b/zhenxun/services/llm/utils.py new file mode 100644 index 00000000..2e41f34a --- /dev/null +++ b/zhenxun/services/llm/utils.py @@ -0,0 +1,183 @@ +""" +LLM 模块的工具和转换函数 +""" + +import base64 +from pathlib import Path + +from nonebot_plugin_alconna.uniseg import At, File, Image, Reply, Text, UniMessage, Video, Voice + +from zhenxun.services.log import logger + +from .types import LLMContentPart + + +async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]: + """ + 将 UniMessage 实例转换为一个 LLMContentPart 列表。 + 这是处理多模态输入的核心转换逻辑。 + """ + parts: list[LLMContentPart] = [] + for seg in message: + part = None + if isinstance(seg, Text): + if seg.text.strip(): + part = LLMContentPart.text_part(seg.text) + elif isinstance(seg, Image): + if seg.path: + part = await LLMContentPart.from_path(seg.path, target_api="gemini") + elif seg.url: + part = LLMContentPart.image_url_part(seg.url) + elif hasattr(seg, "raw") and seg.raw: + mime_type = getattr(seg, "mimetype", "image/png") if hasattr(seg, "mimetype") else "image/png" + if isinstance(seg.raw, bytes): + b64_data = base64.b64encode(seg.raw).decode("utf-8") + part = LLMContentPart.image_base64_part(b64_data, mime_type) + + elif isinstance(seg, File | Voice | Video): + if seg.path: + part = await LLMContentPart.from_path(seg.path) + elif seg.url: + logger.warning(f"直接使用 URL 的 {type(seg).__name__} 段,API 可能不支持: {seg.url}") + part = LLMContentPart.text_part(f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]") + elif hasattr(seg, "raw") and seg.raw: + mime_type = getattr(seg, "mimetype", None) + if isinstance(seg.raw, bytes): + b64_data = base64.b64encode(seg.raw).decode("utf-8") + + if isinstance(seg, Video): + if not mime_type: + mime_type = "video/mp4" + part = LLMContentPart.video_base64_part(data=b64_data, mime_type=mime_type) + logger.debug(f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes") + elif isinstance(seg, Voice): + if not mime_type: + mime_type = "audio/wav" + part = 
LLMContentPart.audio_base64_part(data=b64_data, mime_type=mime_type) + logger.debug(f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes") + else: + part = LLMContentPart.text_part( + f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]" + ) + logger.debug(f"处理其他文件字节数据: {mime_type}, 大小: {len(seg.raw)} bytes") + + elif isinstance(seg, At): + if seg.flag == "all": + part = LLMContentPart.text_part("[Mentioned Everyone]") + else: + part = LLMContentPart.text_part(f"[Mentioned user: {seg.target}]") + + elif isinstance(seg, Reply): + if seg.msg: + try: + extract_method = getattr(seg.msg, "extract_plain_text", None) + if extract_method and callable(extract_method): + reply_text = str(extract_method()).strip() + else: + reply_text = str(seg.msg).strip() + if reply_text: + part = LLMContentPart.text_part(f'[Replied to: "{reply_text[:50]}..."]') + except Exception: + part = LLMContentPart.text_part("[Replied to a message]") + + if part: + parts.append(part) + + return parts + + +def create_multimodal_message( + text: str | None = None, + images: list[str | Path | bytes] | str | Path | bytes | None = None, + videos: list[str | Path | bytes] | str | Path | bytes | None = None, + audios: list[str | Path | bytes] | str | Path | bytes | None = None, + image_mimetypes: list[str] | str | None = None, + video_mimetypes: list[str] | str | None = None, + audio_mimetypes: list[str] | str | None = None, +) -> UniMessage: + """ + 创建多模态消息的便捷函数,方便第三方调用。 + + Args: + text: 文本内容 + images: 图片数据,支持路径、字节数据或URL + videos: 视频数据,支持路径、字节数据或URL + audios: 音频数据,支持路径、字节数据或URL + image_mimetypes: 图片MIME类型,当images为bytes时需要指定 + video_mimetypes: 视频MIME类型,当videos为bytes时需要指定 + audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定 + + Returns: + UniMessage: 构建好的多模态消息 + + Examples: + # 纯文本 + msg = create_multimodal_message("请分析这段文字") + + # 文本 + 单张图片(路径) + msg = create_multimodal_message("分析图片", images="/path/to/image.jpg") + + # 文本 + 多张图片 + msg = create_multimodal_message("比较图片", images=["/path/1.jpg", "/path/2.jpg"]) + + # 文本 + 图片字节数据 + msg = create_multimodal_message("分析", images=image_data, image_mimetypes="image/jpeg") + + # 文本 + 视频 + msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4") + + # 文本 + 音频 + msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav") + + # 混合多模态 + msg = create_multimodal_message( + "分析这些媒体文件", + images="/path/to/image.jpg", + videos="/path/to/video.mp4", + audios="/path/to/audio.wav" + ) + """ + message = UniMessage() + + if text: + message.append(Text(text)) + + if images is not None: + _add_media_to_message(message, images, image_mimetypes, Image, "image/png") + + if videos is not None: + _add_media_to_message(message, videos, video_mimetypes, Video, "video/mp4") + + if audios is not None: + _add_media_to_message(message, audios, audio_mimetypes, Voice, "audio/wav") + + return message + + +def _add_media_to_message( + message: UniMessage, + media_items: list[str | Path | bytes] | str | Path | bytes, + mimetypes: list[str] | str | None, + media_class: type, + default_mimetype: str, +) -> None: + """添加媒体文件到 UniMessage 的辅助函数""" + if not isinstance(media_items, list): + media_items = [media_items] + + mime_list = [] + if mimetypes is not None: + if isinstance(mimetypes, str): + mime_list = [mimetypes] * len(media_items) + else: + mime_list = list(mimetypes) + + for i, item in enumerate(media_items): + if isinstance(item, str | Path): + if str(item).startswith(("http://", "https://")): + message.append(media_class(url=str(item))) + else: + message.append(media_class(path=Path(item))) + elif 
isinstance(item, bytes): + mimetype = mime_list[i] if i < len(mime_list) else default_mimetype + message.append(media_class(raw=item, mimetype=mimetype))
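+
+
+# 用法示意(img1_bytes/img2_bytes 为假设的图片字节数据,仅作演示):
+# 传入单个字符串 mimetype 时会广播到所有同类媒体项:
+#     msg = create_multimodal_message(
+#         "对比这两张图片", images=[img1_bytes, img2_bytes], image_mimetypes="image/jpeg"
+#     )
+#     parts = await unimsg_to_llm_parts(msg)
+#     # -> [文本部分, 两个 Base64 图片的 LLMContentPart]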