From bba90e62db73e769c9b6dc1e8752b1709b3a1e2e Mon Sep 17 00:00:00 2001
From: webjoin111 <455457521@qq.com>
Date: Sun, 7 Dec 2025 18:57:55 +0800
Subject: [PATCH 1/4] =?UTF-8?q?=E2=99=BB=EF=B8=8F?= refactor(llm): restructure the
 LLM service architecture with a middleware pipeline and componentized adapters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- [Refactor] Core LLM service architecture:
  - Introduce a middleware pipeline that unifies the request lifecycle
    (retries, key selection, logging, network requests).
  - Rework adapters into a componentized design, separating config mapping,
    message conversion, response parsing, and tool serialization.
  - Remove the `with_smart_retry` decorator; its duties are taken over by
    the middleware.
  - Remove `LLMToolExecutor`; tool execution is folded into `ToolInvoker`.
- [Feature] Enhanced configuration system:
  - `LLMGenerationConfig` adopts a componentized structure (Core, Reasoning,
    Visual, Output, Safety, ToolConfig).
  - Add `GenConfigBuilder` for semantic, fluent config construction.
  - Add `LLMEmbeddingConfig` for embedding-specific configuration.
  - Migrate `CommonOverrides` to the new config structure.
- [Feature] Strengthened tool system:
  - Introduce `ToolInvoker` for more flexible tool execution, with callback
    support and structured errors.
  - The `function_tool` decorator supports dynamic Pydantic model creation
    and dependency injection (`ToolParam`, `RunContext`).
  - Support platform-native tools (`GeminiCodeExecution`,
    `GeminiGoogleSearch`, `GeminiUrlContext`).
- [Feature] Advanced generation and embeddings:
  - `generate_structured` supports an In-Context Validation and Repair (IVR)
    loop and AutoCoT (chain-of-thought) wrapping.
  - Add the `embed_query` and `embed_documents` convenience embedding APIs.
  - `OpenAIImageAdapter` supports OpenAI-compatible image generation.
  - `SmartAdapter` routes requests intelligently by model name.
- [Refactor] Message and type system:
  - Extend `LLMContentPart` to cover more modalities and code-execution
    content.
  - Update the `LLMMessage` and `LLMResponse` structures to carry
    `content_parts` and chain-of-thought signatures.
  - Unify `LLMErrorCode` with user-friendly error messages, including more
    detailed network/proxy error hints.
  - `pyproject.toml`: remove `bilireq`, add `json_repair`.
- [Improvement] Logging and debugging:
  - Introduce `DebugLogOptions` for fine-grained control over log redaction.
  - Harden the log sanitizer to handle more kinds of sensitive data and long
    strings.
- [Cleanup] Delete deprecated modules:
  - `zhenxun/services/llm/memory.py`
  - `zhenxun/services/llm/executor.py`
  - `zhenxun/services/llm/config/presets.py`
  - `zhenxun/services/llm/types/content.py`
  - `zhenxun/services/llm/types/enums.py`
  - `zhenxun/services/llm/tools/__init__.py`
  - `zhenxun/services/llm/tools/manager.py`
---
 pyproject.toml                                | 2 +-
 zhenxun/services/llm/__init__.py              | 24 +-
 zhenxun/services/llm/adapters/__init__.py     | 4 +-
 zhenxun/services/llm/adapters/base.py         | 372 ++++---
 .../llm/adapters/components/__init__.py       | 1 +
 .../adapters/components/gemini_components.py  | 606 +++++++++
 .../llm/adapters/components/interfaces.py     | 43 +
 .../adapters/components/openai_components.py  | 347 ++++++
 zhenxun/services/llm/adapters/factory.py      | 113 +-
 zhenxun/services/llm/adapters/gemini.py       | 580 +++------
 zhenxun/services/llm/adapters/openai.py       | 568 +++++++++-
 zhenxun/services/llm/api.py                   | 297 +++---
 zhenxun/services/llm/config/__init__.py       | 12 +-
 zhenxun/services/llm/config/generation.py     | 611 +++++----
 zhenxun/services/llm/config/presets.py        | 172 ---
 zhenxun/services/llm/config/providers.py      | 130 ++-
 zhenxun/services/llm/core.py                  | 153 +--
 zhenxun/services/llm/executor.py              | 193 ----
 zhenxun/services/llm/manager.py               | 31 +-
 zhenxun/services/llm/memory.py                | 55 -
 zhenxun/services/llm/service.py               | 990 +++++++------
 zhenxun/services/llm/session.py               | 577 ++++++--
 zhenxun/services/llm/tools.py                 | 839 +++++++++++
 zhenxun/services/llm/tools/__init__.py        | 13 -
 zhenxun/services/llm/tools/manager.py         | 293 ------
 zhenxun/services/llm/types/__init__.py        | 32 +-
 zhenxun/services/llm/types/capabilities.py    | 151 ++-
 zhenxun/services/llm/types/content.py         | 434 --------
 zhenxun/services/llm/types/enums.py           | 78 --
 zhenxun/services/llm/types/exceptions.py      | 61 +-
 zhenxun/services/llm/types/models.py          | 518 ++++++++-
 zhenxun/services/llm/types/protocols.py       | 95 +-
 zhenxun/services/llm/utils.py                 | 562 +++++---
 zhenxun/utils/log_sanitizer.py                | 183 +++-
 zhenxun/utils/pydantic_compat.py              | 44 +-
 35 files changed, 6087 insertions(+), 3097 deletions(-)
 create mode 100644 zhenxun/services/llm/adapters/components/__init__.py
 create mode 100644 zhenxun/services/llm/adapters/components/gemini_components.py
 create mode 100644 zhenxun/services/llm/adapters/components/interfaces.py
 create mode 100644 zhenxun/services/llm/adapters/components/openai_components.py
 delete mode 100644 zhenxun/services/llm/config/presets.py
 delete mode 100644 zhenxun/services/llm/executor.py
 delete mode 100644 zhenxun/services/llm/memory.py
 create mode 100644 zhenxun/services/llm/tools.py
 delete mode 100644 zhenxun/services/llm/tools/__init__.py
 delete mode 100644 zhenxun/services/llm/tools/manager.py
 delete mode 100644 zhenxun/services/llm/types/content.py
 delete mode 100644 zhenxun/services/llm/types/enums.py
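Note on the new SmartAdapter routing (kept below the `---` fold, so it stays
out of the commit itself): the adapter picks its delegate by glob-matching the
lowercased model name against `_ROUTING_RULES` and falls back to the OpenAI
adapter for everything else. A minimal standalone sketch of that selection
logic, mirroring the `fnmatch` loop added in
zhenxun/services/llm/adapters/factory.py; `route_api_type` is a hypothetical
helper name used here for illustration, not part of this patch:

    # Sketch of SmartAdapter's delegate selection; mirrors the rules in
    # factory.py below. route_api_type is illustrative only.
    import fnmatch

    ROUTING_RULES = [("*nano-banana*", "gemini"), ("*gemini*", "gemini")]
    DEFAULT_API_TYPE = "openai"

    def route_api_type(model_name: str) -> str:
        name = model_name.lower()
        for pattern, api_type in ROUTING_RULES:
            if fnmatch.fnmatch(name, pattern):  # first matching glob wins
                return api_type
        return DEFAULT_API_TYPE  # unmatched models use the OpenAI adapter

    assert route_api_type("gemini-2.5-flash") == "gemini"
    assert route_api_type("nano-banana-pro") == "gemini"
    assert route_api_type("gpt-4o-mini") == "openai"

Per the factory code, an explicit `api_type` on the model detail bypasses
these rules entirely, and resolved delegates are cached per model name.
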
diff --git a/pyproject.toml b/pyproject.toml
index a2dfbc82..615cad87 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
 imagehash = "^4.3.1"
 cn2an = "^0.5.22"
 dateparser = "^1.2.0"
-bilireq = ">=0.2.10"
 python-jose = { extras = ["cryptography"], version = "^3.3.0" }
 python-multipart = "^0.0.9"
 aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -46,6 +45,7 @@ tenacity = "^9.0.0"
 nonebot-plugin-uninfo = ">=0.7.3"
 nonebot-plugin-waiter = "^0.8.1"
 multidict = ">=6.0.0,!=6.3.2"
+json_repair = "^0.54.0"
 
 redis = { version = ">=5", optional = true }
 asyncpg = { version = ">=0.20.0", optional = true }
diff --git a/zhenxun/services/llm/__init__.py b/zhenxun/services/llm/__init__.py
index 3bce85ce..eaa6f8a3 100644
--- a/zhenxun/services/llm/__init__.py
+++ b/zhenxun/services/llm/__init__.py
@@ -9,13 +9,15 @@ from .api import (
     code,
     create_image,
     embed,
+    embed_documents,
+    embed_query,
     generate,
     generate_structured,
-    run_with_tools,
     search,
 )
 from .config import (
     CommonOverrides,
+    GenConfigBuilder,
     LLMGenerationConfig,
     register_llm_configs,
 )
@@ -32,8 +34,8 @@ from .manager import (
     list_model_identifiers,
     set_global_default_model_name,
 )
-from .session import AI, AIConfig
-from .tools import function_tool, tool_provider_manager
+from .session import AI, AIConfig, MemoryProcessor, set_default_memory_backend
+from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager
 from .types import (
     EmbeddingTaskType,
     LLMContentPart,
@@ -50,6 +52,11 @@ from .types import (
     ToolMetadata,
     UsageInfo,
 )
+from .types.models import (
+    GeminiCodeExecution,
+    GeminiGoogleSearch,
+    GeminiUrlContext,
+)
 from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts
 
 __all__ = [
@@ -57,19 +64,26 @@ __all__ = [
     "AIConfig",
     "CommonOverrides",
     "EmbeddingTaskType",
+    "GeminiCodeExecution",
+    "GeminiGoogleSearch",
+    "GeminiUrlContext",
+    "GenConfigBuilder",
     "LLMContentPart",
     "LLMErrorCode",
     "LLMException",
     "LLMGenerationConfig",
     "LLMMessage",
     "LLMResponse",
+    "MemoryProcessor",
     "ModelDetail",
     "ModelInfo",
     "ModelName",
     "ModelProvider",
     "ResponseFormat",
+    "RunContext",
     "TaskType",
     "ToolCategory",
+    "ToolInvoker",
     "ToolMetadata",
     "UsageInfo",
     "chat",
@@ -78,6 +92,8 @@ __all__ = [
     "create_image",
     "create_multimodal_message",
     "embed",
+    "embed_documents",
+    "embed_query",
     "function_tool",
     "generate",
     "generate_structured",
@@ -89,8 +105,8 @@ __all__ = [
     "list_model_identifiers",
     "message_to_unimessage",
     "register_llm_configs",
-    "run_with_tools",
     "search",
+    "set_default_memory_backend",
     "set_global_default_model_name",
     "tool_provider_manager",
     "unimsg_to_llm_parts",
diff 
--git a/zhenxun/services/llm/adapters/__init__.py b/zhenxun/services/llm/adapters/__init__.py index 773d3ed2..d296fb33 100644 --- a/zhenxun/services/llm/adapters/__init__.py +++ b/zhenxun/services/llm/adapters/__init__.py @@ -7,16 +7,18 @@ LLM 适配器模块 from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter from .gemini import GeminiAdapter -from .openai import OpenAIAdapter +from .openai import DeepSeekAdapter, OpenAIAdapter, OpenAIImageAdapter LLMAdapterFactory.initialize() __all__ = [ "BaseAdapter", + "DeepSeekAdapter", "GeminiAdapter", "LLMAdapterFactory", "OpenAIAdapter", "OpenAICompatAdapter", + "OpenAIImageAdapter", "RequestData", "ResponseData", "get_adapter_for_api_type", diff --git a/zhenxun/services/llm/adapters/base.py b/zhenxun/services/llm/adapters/base.py index 67888816..ca19aaeb 100644 --- a/zhenxun/services/llm/adapters/base.py +++ b/zhenxun/services/llm/adapters/base.py @@ -3,24 +3,26 @@ LLM 适配器基类和通用数据结构 """ from abc import ABC, abstractmethod -import base64 -import binascii import json +from pathlib import Path from typing import TYPE_CHECKING, Any +import uuid +import httpx from pydantic import BaseModel +from zhenxun.configs.path_config import TEMP_PATH from zhenxun.services.log import logger +from ..types import LLMContentPart from ..types.exceptions import LLMErrorCode, LLMException from ..types.models import LLMToolCall if TYPE_CHECKING: - from ..config.generation import LLMGenerationConfig + from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig from ..service import LLMModel - from ..types.content import LLMMessage - from ..types.enums import EmbeddingTaskType - from ..types.protocols import ToolExecutable + from ..types import LLMMessage + from ..types.models import ToolChoice class RequestData(BaseModel): @@ -29,19 +31,23 @@ class RequestData(BaseModel): url: str headers: dict[str, str] body: dict[str, Any] + files: dict[str, Any] | list[tuple[str, Any]] | None = None class ResponseData(BaseModel): """响应数据封装 - 支持所有高级功能""" text: str - images: list[bytes] | None = None + content_parts: list[LLMContentPart] | None = None + images: list[bytes | Path] | None = None usage_info: dict[str, Any] | None = None raw_response: dict[str, Any] | None = None tool_calls: list[LLMToolCall] | None = None code_executions: list[Any] | None = None grounding_metadata: Any | None = None cache_info: Any | None = None + thought_text: str | None = None + thought_signature: str | None = None code_execution_results: list[dict[str, Any]] | None = None search_results: list[dict[str, Any]] | None = None @@ -50,9 +56,33 @@ class ResponseData(BaseModel): citations: list[dict[str, Any]] | None = None +def process_image_data(image_data: bytes) -> bytes | Path: + """ + 处理图片数据:若超过 2MB 则保存到临时目录,避免占用内存。 + """ + max_inline_size = 2 * 1024 * 1024 + if len(image_data) > max_inline_size: + save_dir = TEMP_PATH / "llm" + save_dir.mkdir(parents=True, exist_ok=True) + file_name = f"{uuid.uuid4()}.png" + file_path = save_dir / file_name + file_path.write_bytes(image_data) + logger.info( + f"图片数据过大 ({len(image_data)} bytes),已保存到临时文件: {file_path}", + "LLMAdapter", + ) + return file_path.resolve() + return image_data + + class BaseAdapter(ABC): """LLM API适配器基类""" + @property + def log_sanitization_context(self) -> str: + """用于日志清洗的上下文名称,默认 'default'""" + return "default" + @property @abstractmethod def api_type(self) -> str: @@ -77,7 +107,7 @@ class BaseAdapter(ABC): 默认实现:将简单请求转换为高级请求格式 
子类可以重写此方法以提供特定的优化实现 """ - from ..types.content import LLMMessage + from ..types import LLMMessage messages: list[LLMMessage] = [] @@ -107,8 +137,8 @@ class BaseAdapter(ABC): api_key: str, messages: list["LLMMessage"], config: "LLMGenerationConfig | None" = None, - tools: dict[str, "ToolExecutable"] | None = None, - tool_choice: str | dict[str, Any] | None = None, + tools: list[Any] | None = None, + tool_choice: "str | dict[str, Any] | ToolChoice | None" = None, ) -> RequestData: """准备高级请求""" pass @@ -129,8 +159,7 @@ class BaseAdapter(ABC): model: "LLMModel", api_key: str, texts: list[str], - task_type: "EmbeddingTaskType | str", - **kwargs: Any, + config: "LLMEmbeddingConfig", ) -> RequestData: """准备文本嵌入请求""" pass @@ -142,9 +171,16 @@ class BaseAdapter(ABC): """解析文本嵌入响应""" pass + @abstractmethod + def convert_generation_config( + self, config: "LLMGenerationConfig", model: "LLMModel" + ) -> dict[str, Any]: + """将通用生成配置转换为特定API的参数字典""" + pass + def validate_embedding_response(self, response_json: dict[str, Any]) -> None: """验证嵌入API响应""" - if "error" in response_json: + if response_json.get("error"): error_info = response_json["error"] msg = ( error_info.get("message", str(error_info)) @@ -179,158 +215,9 @@ class BaseAdapter(ABC): ) return headers - def convert_messages_to_openai_format( - self, messages: list["LLMMessage"] - ) -> list[dict[str, Any]]: - """将LLMMessage转换为OpenAI格式 - 通用方法""" - openai_messages: list[dict[str, Any]] = [] - for msg in messages: - openai_msg: dict[str, Any] = {"role": msg.role} - - if msg.role == "tool": - openai_msg["tool_call_id"] = msg.tool_call_id - openai_msg["name"] = msg.name - openai_msg["content"] = msg.content - else: - if isinstance(msg.content, str): - openai_msg["content"] = msg.content - else: - content_parts = [] - for part in msg.content: - if part.type == "text": - content_parts.append({"type": "text", "text": part.text}) - elif part.type == "image": - content_parts.append( - { - "type": "image_url", - "image_url": {"url": part.image_source}, - } - ) - openai_msg["content"] = content_parts - - if msg.role == "assistant" and msg.tool_calls: - assistant_tool_calls = [] - for call in msg.tool_calls: - assistant_tool_calls.append( - { - "id": call.id, - "type": "function", - "function": { - "name": call.function.name, - "arguments": call.function.arguments, - }, - } - ) - openai_msg["tool_calls"] = assistant_tool_calls - - if msg.name and msg.role != "tool": - openai_msg["name"] = msg.name - - openai_messages.append(openai_msg) - return openai_messages - - def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData: - """解析OpenAI格式的响应 - 通用方法""" - self.validate_response(response_json) - - try: - choices = response_json.get("choices", []) - if not choices: - logger.debug("OpenAI响应中没有choices,可能为空回复或流结束。") - return ResponseData(text="", raw_response=response_json) - - choice = choices[0] - message = choice.get("message", {}) - content = message.get("content", "") - - if content: - content = content.strip() - - images_bytes: list[bytes] = [] - if content and content.startswith("{") and content.endswith("}"): - try: - content_json = json.loads(content) - if "b64_json" in content_json: - images_bytes.append(base64.b64decode(content_json["b64_json"])) - content = "[图片已生成]" - elif "data" in content_json and isinstance( - content_json["data"], str - ): - images_bytes.append(base64.b64decode(content_json["data"])) - content = "[图片已生成]" - - except (json.JSONDecodeError, KeyError, binascii.Error): - pass - elif ( - "images" in message - and 
isinstance(message["images"], list) - and message["images"] - ): - image_info = message["images"][0] - if image_info.get("type") == "image_url": - image_url_obj = image_info.get("image_url", {}) - url_str = image_url_obj.get("url", "") - if url_str.startswith("data:image/png;base64,"): - try: - b64_data = url_str.split(",", 1)[1] - images_bytes.append(base64.b64decode(b64_data)) - content = content if content else "[图片已生成]" - except (IndexError, binascii.Error) as e: - logger.warning(f"解析OpenRouter Base64图片数据失败: {e}") - - parsed_tool_calls: list[LLMToolCall] | None = None - if message_tool_calls := message.get("tool_calls"): - from ..types.models import LLMToolFunction - - parsed_tool_calls = [] - for tc_data in message_tool_calls: - try: - if tc_data.get("type") == "function": - parsed_tool_calls.append( - LLMToolCall( - id=tc_data["id"], - function=LLMToolFunction( - name=tc_data["function"]["name"], - arguments=tc_data["function"]["arguments"], - ), - ) - ) - except KeyError as e: - logger.warning( - f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}" - ) - except Exception as e: - logger.warning( - f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}" - ) - if not parsed_tool_calls: - parsed_tool_calls = None - - final_text = content if content is not None else "" - if not final_text and parsed_tool_calls: - final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。" - - usage_info = response_json.get("usage") - - return ResponseData( - text=final_text, - tool_calls=parsed_tool_calls, - usage_info=usage_info, - images=images_bytes if images_bytes else None, - raw_response=response_json, - ) - - except Exception as e: - logger.error(f"解析OpenAI格式响应失败: {e}", e=e) - raise LLMException( - f"解析API响应失败: {e}", - code=LLMErrorCode.RESPONSE_PARSE_ERROR, - cause=e, - ) - def validate_response(self, response_json: dict[str, Any]) -> None: """验证API响应,解析不同API的错误结构""" - if "error" in response_json: + if response_json.get("error"): error_info = response_json["error"] if isinstance(error_info, dict): @@ -341,12 +228,15 @@ class BaseAdapter(ABC): error_code_mapping = { "invalid_api_key": LLMErrorCode.API_KEY_INVALID, "authentication_failed": LLMErrorCode.API_KEY_INVALID, + "insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED, "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED, "quota_exceeded": LLMErrorCode.API_RATE_LIMITED, "model_not_found": LLMErrorCode.MODEL_NOT_FOUND, "invalid_model": LLMErrorCode.MODEL_NOT_FOUND, "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + "invalid_request_error": LLMErrorCode.INVALID_PARAMETER, + "invalid_parameter": LLMErrorCode.INVALID_PARAMETER, } llm_error_code = error_code_mapping.get( @@ -405,23 +295,12 @@ class BaseAdapter(ABC): ) -> dict[str, Any]: """通用的配置应用逻辑""" if config is not None: - return config.to_api_params(model.api_type, model.model_name) + return self.convert_generation_config(config, model) - if model._generation_config is not None: - return model._generation_config.to_api_params( - model.api_type, model.model_name - ) + if model._generation_config: + return self.convert_generation_config(model._generation_config, model) - base_config = {} - if model.temperature is not None: - base_config["temperature"] = model.temperature - if model.max_tokens is not None: - if model.api_type == "gemini": - base_config["maxOutputTokens"] = model.max_tokens - else: - base_config["max_tokens"] = model.max_tokens - - return base_config + return {} def apply_config_override( self, @@ -434,12 +313,96 @@ class 
BaseAdapter(ABC): body.update(config_params) return body + def handle_http_error(self, response: httpx.Response) -> LLMException | None: + """ + 处理 HTTP 错误响应。 + 如果响应状态码表示成功 (200),返回 None;否则构造 LLMException 供外部捕获。 + """ + if response.status_code == 200: + return None + + error_text = response.content.decode("utf-8", errors="ignore") + error_status = "" + error_msg = error_text + try: + error_json = json.loads(error_text) + if isinstance(error_json, dict) and "error" in error_json: + error_info = error_json["error"] + if isinstance(error_info, dict): + error_msg = error_info.get("message", error_msg) + raw_status = error_info.get("status") or error_info.get("code") + error_status = str(raw_status) if raw_status is not None else "" + elif error_info is not None: + error_msg = str(error_info) + error_status = error_msg + except Exception: + pass + + status_upper = error_status.upper() if error_status else "" + text_upper = error_text.upper() + + error_code = LLMErrorCode.API_REQUEST_FAILED + if response.status_code == 400: + if ( + "FAILED_PRECONDITION" in status_upper + or "LOCATION IS NOT SUPPORTED" in text_upper + ): + error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED + elif "INVALID_ARGUMENT" in status_upper: + error_code = LLMErrorCode.INVALID_PARAMETER + elif "API_KEY_INVALID" in text_upper or "API KEY NOT VALID" in text_upper: + error_code = LLMErrorCode.API_KEY_INVALID + else: + error_code = LLMErrorCode.INVALID_PARAMETER + elif response.status_code in [401, 403]: + if error_msg and ( + "country" in error_msg.lower() + or "region" in error_msg.lower() + or "unsupported" in error_msg.lower() + ): + error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED + elif "PERMISSION_DENIED" in status_upper: + error_code = LLMErrorCode.API_KEY_INVALID + else: + error_code = LLMErrorCode.API_KEY_INVALID + elif response.status_code == 404: + error_code = LLMErrorCode.MODEL_NOT_FOUND + elif response.status_code == 429: + if ( + "RESOURCE_EXHAUSTED" in status_upper + or "INSUFFICIENT_QUOTA" in status_upper + or ("quota" in error_msg.lower() if error_msg else False) + ): + error_code = LLMErrorCode.API_QUOTA_EXCEEDED + else: + error_code = LLMErrorCode.API_RATE_LIMITED + elif response.status_code in [402, 413]: + error_code = LLMErrorCode.API_QUOTA_EXCEEDED + elif response.status_code == 422: + error_code = LLMErrorCode.GENERATION_FAILED + elif response.status_code >= 500: + error_code = LLMErrorCode.API_TIMEOUT + + return LLMException( + f"HTTP请求失败: {response.status_code} ({error_status or 'Unknown'})", + code=error_code, + details={ + "status_code": response.status_code, + "api_status": error_status, + "response": error_text, + }, + ) + class OpenAICompatAdapter(BaseAdapter): """ 处理所有 OpenAI 兼容 API 的通用适配器。 """ + @property + def log_sanitization_context(self) -> str: + return "openai_request" + @abstractmethod def get_chat_endpoint(self, model: "LLMModel") -> str: """子类必须实现,返回 chat completions 的端点""" @@ -481,8 +444,8 @@ class OpenAICompatAdapter(BaseAdapter): api_key: str, messages: list["LLMMessage"], config: "LLMGenerationConfig | None" = None, - tools: dict[str, "ToolExecutable"] | None = None, - tool_choice: str | dict[str, Any] | None = None, + tools: list[Any] | None = None, + tool_choice: "str | dict[str, Any] | ToolChoice | None" = None, ) -> RequestData: """准备高级请求 - OpenAI兼容格式""" url = self.get_api_url(model, self.get_chat_endpoint(model)) @@ -494,28 +457,44 @@ class OpenAICompatAdapter(BaseAdapter): "X-Title": "Zhenxun Bot", } ) - openai_messages = 
self.convert_messages_to_openai_format(messages) + from .components.openai_components import OpenAIMessageConverter + + converter = OpenAIMessageConverter() + openai_messages = converter.convert_messages(messages) body = { "model": model.model_name, "messages": openai_messages, } + openai_tools: list[dict[str, Any]] | None = None + executables: list[Any] = [] if tools: + for tool in tools: + if hasattr(tool, "get_definition"): + executables.append(tool) + + if executables: import asyncio from zhenxun.utils.pydantic_compat import model_dump definition_tasks = [ - executable.get_definition() for executable in tools.values() + executable.get_definition() for executable in executables ] - openai_tools = await asyncio.gather(*definition_tasks) - if openai_tools: - body["tools"] = [ + tool_defs = [] + if definition_tasks: + tool_defs = await asyncio.gather(*definition_tasks) + + if tool_defs: + openai_tools = [ {"type": "function", "function": model_dump(tool)} - for tool in openai_tools + for tool in tool_defs ] + if openai_tools: + body["tools"] = openai_tools + if tool_choice: body["tool_choice"] = tool_choice @@ -528,20 +507,21 @@ class OpenAICompatAdapter(BaseAdapter): response_json: dict[str, Any], is_advanced: bool = False, ) -> ResponseData: - """解析响应 - 直接使用基类的 OpenAI 格式解析""" + """解析响应 - 直接使用组件化 ResponseParser""" _ = model, is_advanced - return self.parse_openai_response(response_json) + from .components.openai_components import OpenAIResponseParser + + parser = OpenAIResponseParser() + return parser.parse(response_json) def prepare_embedding_request( self, model: "LLMModel", api_key: str, texts: list[str], - task_type: "EmbeddingTaskType | str", - **kwargs: Any, + config: "LLMEmbeddingConfig", ) -> RequestData: """准备嵌入请求 - OpenAI兼容格式""" - _ = task_type url = self.get_api_url(model, self.get_embedding_endpoint(model)) headers = self.get_base_headers(api_key) @@ -550,8 +530,14 @@ class OpenAICompatAdapter(BaseAdapter): "input": texts, } - if kwargs: - body.update(kwargs) + if config.output_dimensionality: + body["dimensions"] = config.output_dimensionality + + if config.task_type: + body["task"] = config.task_type + + if config.encoding_format and config.encoding_format != "float": + body["encoding_format"] = config.encoding_format return RequestData(url=url, headers=headers, body=body) diff --git a/zhenxun/services/llm/adapters/components/__init__.py b/zhenxun/services/llm/adapters/components/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/zhenxun/services/llm/adapters/components/__init__.py @@ -0,0 +1 @@ + diff --git a/zhenxun/services/llm/adapters/components/gemini_components.py b/zhenxun/services/llm/adapters/components/gemini_components.py new file mode 100644 index 00000000..b2b9015f --- /dev/null +++ b/zhenxun/services/llm/adapters/components/gemini_components.py @@ -0,0 +1,606 @@ +import base64 +import json +from pathlib import Path +from typing import Any + +from zhenxun.services.llm.adapters.base import ResponseData, process_image_data +from zhenxun.services.llm.adapters.components.interfaces import ( + ConfigMapper, + MessageConverter, + ResponseParser, + ToolSerializer, +) +from zhenxun.services.llm.config.generation import ( + ImageAspectRatio, + LLMGenerationConfig, + ReasoningEffort, + ResponseFormat, +) +from zhenxun.services.llm.config.providers import get_gemini_safety_threshold +from zhenxun.services.llm.types import ( + CodeExecutionOutcome, + LLMContentPart, + LLMMessage, +) +from zhenxun.services.llm.types.capabilities import 
ModelCapabilities +from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException +from zhenxun.services.llm.types.models import ( + LLMGroundingAttribution, + LLMGroundingMetadata, + LLMToolCall, + LLMToolFunction, + ModelDetail, + ToolDefinition, +) +from zhenxun.services.llm.utils import ( + resolve_json_schema_refs, + sanitize_schema_for_llm, +) +from zhenxun.services.log import logger +from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.pydantic_compat import model_copy, model_dump + + +class GeminiConfigMapper(ConfigMapper): + def map_config( + self, + config: LLMGenerationConfig, + model_detail: ModelDetail | None = None, + capabilities: ModelCapabilities | None = None, + ) -> dict[str, Any]: + params: dict[str, Any] = {} + + if config.core: + if config.core.temperature is not None: + params["temperature"] = config.core.temperature + if config.core.max_tokens is not None: + params["maxOutputTokens"] = config.core.max_tokens + if config.core.top_k is not None: + params["topK"] = config.core.top_k + if config.core.top_p is not None: + params["topP"] = config.core.top_p + + if config.output: + if config.output.response_format == ResponseFormat.JSON: + params["responseMimeType"] = "application/json" + if config.output.response_schema: + params["responseJsonSchema"] = config.output.response_schema + elif config.output.response_mime_type is not None: + params["responseMimeType"] = config.output.response_mime_type + + if ( + config.output.response_schema is not None + and "responseJsonSchema" not in params + ): + params["responseJsonSchema"] = config.output.response_schema + if config.output.response_modalities: + params["responseModalities"] = config.output.response_modalities + + if config.tool_config: + fc_config: dict[str, Any] = {"mode": config.tool_config.mode} + if ( + config.tool_config.allowed_function_names + and config.tool_config.mode == "ANY" + ): + builtins = {"code_execution", "google_search", "google_map"} + user_funcs = [ + name + for name in config.tool_config.allowed_function_names + if name not in builtins + ] + if user_funcs: + fc_config["allowedFunctionNames"] = user_funcs + params["toolConfig"] = {"functionCallingConfig": fc_config} + + if config.reasoning: + thinking_config = params.setdefault("thinkingConfig", {}) + + if config.reasoning.budget_tokens is not None: + if ( + config.reasoning.budget_tokens <= 0 + or config.reasoning.budget_tokens >= 1 + ): + budget_value = int(config.reasoning.budget_tokens) + else: + budget_value = int(config.reasoning.budget_tokens * 32768) + thinking_config["thinkingBudget"] = budget_value + elif config.reasoning.effort: + if config.reasoning.effort == ReasoningEffort.MEDIUM: + thinking_config["thinkingLevel"] = "HIGH" + else: + thinking_config["thinkingLevel"] = config.reasoning.effort.value + + if config.reasoning.show_thoughts is not None: + thinking_config["includeThoughts"] = config.reasoning.show_thoughts + elif capabilities and capabilities.reasoning_visibility == "visible": + thinking_config["includeThoughts"] = True + + if config.visual: + image_config: dict[str, Any] = {} + + if config.visual.aspect_ratio is not None: + ar_value = ( + config.visual.aspect_ratio.value + if isinstance(config.visual.aspect_ratio, ImageAspectRatio) + else config.visual.aspect_ratio + ) + image_config["aspectRatio"] = ar_value + + if config.visual.resolution: + image_config["imageSize"] = config.visual.resolution + + if image_config: + params["imageConfig"] = image_config + + if config.visual.media_resolution: + 
media_value = config.visual.media_resolution.upper() + if not media_value.startswith("MEDIA_RESOLUTION_"): + media_value = f"MEDIA_RESOLUTION_{media_value}" + params["mediaResolution"] = media_value + + if config.custom_params: + mapped_custom = config.custom_params.copy() + if "max_tokens" in mapped_custom: + mapped_custom["maxOutputTokens"] = mapped_custom.pop("max_tokens") + if "top_k" in mapped_custom: + mapped_custom["topK"] = mapped_custom.pop("top_k") + if "top_p" in mapped_custom: + mapped_custom["topP"] = mapped_custom.pop("top_p") + + for key in ( + "code_execution_timeout", + "grounding_config", + "dynamic_threshold", + "user_location", + "reflexion_retries", + ): + mapped_custom.pop(key, None) + + for unsupported in [ + "frequency_penalty", + "presence_penalty", + "repetition_penalty", + ]: + if unsupported in mapped_custom: + mapped_custom.pop(unsupported) + + params.update(mapped_custom) + + safety_settings: list[dict[str, Any]] = [] + if config.safety and config.safety.safety_settings: + for category, threshold in config.safety.safety_settings.items(): + safety_settings.append({"category": category, "threshold": threshold}) + else: + threshold = get_gemini_safety_threshold() + for category in [ + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + ]: + safety_settings.append({"category": category, "threshold": threshold}) + + if safety_settings: + params["safetySettings"] = safety_settings + + return params + + +class GeminiMessageConverter(MessageConverter): + async def convert_part(self, part: LLMContentPart) -> dict[str, Any]: + """将单个内容部分转换为 Gemini API 格式""" + + def _get_gemini_resolution_dict() -> dict[str, Any]: + if part.media_resolution: + value = part.media_resolution.upper() + if not value.startswith("MEDIA_RESOLUTION_"): + value = f"MEDIA_RESOLUTION_{value}" + return {"media_resolution": {"level": value}} + return {} + + if part.type == "text": + return {"text": part.text} + + if part.type == "thought": + return {"text": part.thought_text, "thought": True} + + if part.type == "image": + if not part.image_source: + raise ValueError("图像类型的内容必须包含image_source") + + if part.is_image_base64(): + base64_info = part.get_base64_data() + if base64_info: + mime_type, data = base64_info + payload = {"inlineData": {"mimeType": mime_type, "data": data}} + payload.update(_get_gemini_resolution_dict()) + return payload + raise ValueError(f"无法解析Base64图像数据: {part.image_source[:50]}...") + if part.is_image_url(): + logger.debug(f"正在为Gemini下载并编码URL图片: {part.image_source}") + try: + image_bytes = await AsyncHttpx.get_content(part.image_source) + mime_type = part.mime_type or "image/jpeg" + base64_data = base64.b64encode(image_bytes).decode("utf-8") + payload = { + "inlineData": {"mimeType": mime_type, "data": base64_data} + } + payload.update(_get_gemini_resolution_dict()) + return payload + except Exception as e: + logger.error(f"下载或编码URL图片失败: {e}", e=e) + raise ValueError(f"无法处理图片URL: {e}") + raise ValueError(f"不支持的图像源格式: {part.image_source[:50]}...") + + if part.type == "video": + if not part.video_source: + raise ValueError("视频类型的内容必须包含video_source") + + if part.video_source.startswith("data:"): + try: + header, data = part.video_source.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + payload = {"inlineData": {"mimeType": mime_type, "data": data}} + payload.update(_get_gemini_resolution_dict()) + return payload + except (ValueError, IndexError): + raise ValueError( + 
f"无法解析Base64视频数据: {part.video_source[:50]}..." + ) + raise ValueError( + "Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL" + ) + + if part.type == "audio": + if not part.audio_source: + raise ValueError("音频类型的内容必须包含audio_source") + + if part.audio_source.startswith("data:"): + try: + header, data = part.audio_source.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + payload = {"inlineData": {"mimeType": mime_type, "data": data}} + payload.update(_get_gemini_resolution_dict()) + return payload + except (ValueError, IndexError): + raise ValueError( + f"无法解析Base64音频数据: {part.audio_source[:50]}..." + ) + raise ValueError( + "Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL" + ) + + if part.type == "file": + if part.file_uri: + payload = { + "fileData": {"mimeType": part.mime_type, "fileUri": part.file_uri} + } + payload.update(_get_gemini_resolution_dict()) + return payload + if part.file_source: + file_name = ( + part.metadata.get("name", "file") if part.metadata else "file" + ) + return {"text": f"[文件: {file_name}]\n{part.file_source}"} + raise ValueError("文件类型的内容必须包含file_uri或file_source") + + raise ValueError(f"不支持的内容类型: {part.type}") + + async def convert_messages_async( + self, messages: list[LLMMessage] + ) -> list[dict[str, Any]]: + gemini_contents: list[dict[str, Any]] = [] + + for msg in messages: + current_parts: list[dict[str, Any]] = [] + if msg.role == "system": + continue + + elif msg.role == "user": + if isinstance(msg.content, str): + current_parts.append({"text": msg.content}) + elif isinstance(msg.content, list): + for part_obj in msg.content: + current_parts.append(await self.convert_part(part_obj)) + gemini_contents.append({"role": "user", "parts": current_parts}) + + elif msg.role == "assistant" or msg.role == "model": + if isinstance(msg.content, str) and msg.content: + current_parts.append({"text": msg.content}) + elif isinstance(msg.content, list): + for part_obj in msg.content: + part_dict = await self.convert_part(part_obj) + + if "executableCode" in part_dict: + part_dict["executable_code"] = part_dict.pop( + "executableCode" + ) + + if "codeExecutionResult" in part_dict: + part_dict["code_execution_result"] = part_dict.pop( + "codeExecutionResult" + ) + + if ( + part_obj.metadata + and "thought_signature" in part_obj.metadata + ): + part_dict["thoughtSignature"] = part_obj.metadata[ + "thought_signature" + ] + current_parts.append(part_dict) + + if msg.tool_calls: + for call in msg.tool_calls: + fc_part = { + "functionCall": { + "name": call.function.name, + "args": json.loads(call.function.arguments), + } + } + if call.thought_signature: + fc_part["thoughtSignature"] = call.thought_signature + current_parts.append(fc_part) + if current_parts: + gemini_contents.append({"role": "model", "parts": current_parts}) + + elif msg.role == "tool": + if not msg.name: + raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。") + + try: + content_str = ( + msg.content + if isinstance(msg.content, str) + else str(msg.content) + ) + tool_result_obj = json.loads(content_str) + except json.JSONDecodeError: + content_str = ( + msg.content + if isinstance(msg.content, str) + else str(msg.content) + ) + tool_result_obj = {"raw_output": content_str} + + if isinstance(tool_result_obj, list): + final_response_payload = {"result": tool_result_obj} + elif not isinstance(tool_result_obj, dict): + final_response_payload = {"result": tool_result_obj} + else: + final_response_payload = tool_result_obj + + current_parts.append( + { + "functionResponse": { + "name": msg.name, + "response": 
final_response_payload, + } + } + ) + if gemini_contents and gemini_contents[-1]["role"] == "function": + gemini_contents[-1]["parts"].extend(current_parts) + else: + gemini_contents.append({"role": "function", "parts": current_parts}) + + return gemini_contents + + def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]: + raise NotImplementedError("Use convert_messages_async for Gemini") + + +class GeminiToolSerializer(ToolSerializer): + def serialize_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]: + function_declarations: list[dict[str, Any]] = [] + for tool_def in tools: + tool_copy = model_copy(tool_def) + tool_copy.parameters = resolve_json_schema_refs(tool_copy.parameters) + tool_copy.parameters = sanitize_schema_for_llm( + tool_copy.parameters, api_type="gemini" + ) + function_declarations.append(model_dump(tool_copy)) + return function_declarations + + +class GeminiResponseParser(ResponseParser): + def validate_response(self, response_json: dict[str, Any]) -> None: + if error := response_json.get("error"): + code = error.get("code") + message = error.get("message", "") + status = error.get("status") + details = error.get("details", []) + + if code == 429 or status == "RESOURCE_EXHAUSTED": + is_quota = any( + d.get("reason") in ("QUOTA_EXCEEDED", "SERVICE_DISABLED") + for d in details + if isinstance(d, dict) + ) + if is_quota or "quota" in message.lower(): + raise LLMException( + f"Gemini配额耗尽: {message}", + code=LLMErrorCode.API_QUOTA_EXCEEDED, + details=error, + ) + raise LLMException( + f"Gemini速率限制: {message}", + code=LLMErrorCode.API_RATE_LIMITED, + details=error, + ) + + if code == 400 or status in ("INVALID_ARGUMENT", "FAILED_PRECONDITION"): + raise LLMException( + f"Gemini参数错误: {message}", + code=LLMErrorCode.INVALID_PARAMETER, + details=error, + recoverable=False, + ) + + if prompt_feedback := response_json.get("promptFeedback"): + if block_reason := prompt_feedback.get("blockReason"): + raise LLMException( + f"内容被安全过滤: {block_reason}", + code=LLMErrorCode.CONTENT_FILTERED, + details={ + "block_reason": block_reason, + "safety_ratings": prompt_feedback.get("safetyRatings"), + }, + ) + + def parse(self, response_json: dict[str, Any]) -> ResponseData: + self.validate_response(response_json) + + if "image_generation" in response_json and isinstance( + response_json["image_generation"], dict + ): + candidates_source = response_json["image_generation"] + else: + candidates_source = response_json + + candidates = candidates_source.get("candidates", []) + usage_info = response_json.get("usageMetadata") + + if not candidates: + return ResponseData(text="", raw_response=response_json) + + candidate = candidates[0] + thought_signature: str | None = None + + content_data = candidate.get("content", {}) + parts = content_data.get("parts", []) + + text_content = "" + images_payload: list[bytes | Path] = [] + parsed_tool_calls: list[LLMToolCall] | None = None + parsed_code_executions: list[dict[str, Any]] = [] + content_parts: list[LLMContentPart] = [] + thought_summary_parts: list[str] = [] + answer_parts = [] + + for part in parts: + part_signature = part.get("thoughtSignature") + if part_signature and thought_signature is None: + thought_signature = part_signature + part_metadata: dict[str, Any] | None = None + if part_signature: + part_metadata = {"thought_signature": part_signature} + + if part.get("thought") is True: + t_text = part.get("text", "") + thought_summary_parts.append(t_text) + 
content_parts.append(LLMContentPart.thought_part(t_text)) + + elif "text" in part: + answer_parts.append(part["text"]) + c_part = LLMContentPart( + type="text", text=part["text"], metadata=part_metadata + ) + content_parts.append(c_part) + + elif "thoughtSummary" in part: + thought_summary_parts.append(part["thoughtSummary"]) + content_parts.append( + LLMContentPart.thought_part(part["thoughtSummary"]) + ) + + elif "inlineData" in part: + inline_data = part["inlineData"] + if "data" in inline_data: + decoded = base64.b64decode(inline_data["data"]) + images_payload.append(process_image_data(decoded)) + + elif "functionCall" in part: + if parsed_tool_calls is None: + parsed_tool_calls = [] + fc_data = part["functionCall"] + fc_sig = part_signature + try: + call_id = f"call_gemini_{len(parsed_tool_calls)}" + parsed_tool_calls.append( + LLMToolCall( + id=call_id, + thought_signature=fc_sig, + function=LLMToolFunction( + name=fc_data["name"], + arguments=json.dumps(fc_data["args"]), + ), + ) + ) + except Exception as e: + logger.warning( + f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}" + ) + elif "executableCode" in part: + exec_code = part["executableCode"] + lang = exec_code.get("language", "PYTHON") + code = exec_code.get("code", "") + content_parts.append(LLMContentPart.executable_code_part(lang, code)) + answer_parts.append(f"\n[生成代码 ({lang})]:\n```python\n{code}\n```\n") + + elif "codeExecutionResult" in part: + result = part["codeExecutionResult"] + outcome = result.get("outcome", CodeExecutionOutcome.OUTCOME_UNKNOWN) + output = result.get("output", "") + + content_parts.append( + LLMContentPart.execution_result_part(outcome, output) + ) + + parsed_code_executions.append(result) + + if outcome == CodeExecutionOutcome.OUTCOME_OK: + answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n") + else: + answer_parts.append(f"\n[代码执行失败 ({outcome})]:\n{output}\n") + + full_answer = "".join(answer_parts).strip() + text_content = full_answer + final_thought_text = ( + "\n\n".join(thought_summary_parts).strip() + if thought_summary_parts + else None + ) + + grounding_metadata_obj = None + if grounding_data := candidate.get("groundingMetadata"): + try: + sep_content = None + sep_field = grounding_data.get("searchEntryPoint") + if isinstance(sep_field, dict): + sep_content = sep_field.get("renderedContent") + + attributions = [] + if chunks := grounding_data.get("groundingChunks"): + for chunk in chunks: + if web := chunk.get("web"): + attributions.append( + LLMGroundingAttribution( + title=web.get("title"), + uri=web.get("uri"), + snippet=web.get("snippet"), + confidence_score=None, + ) + ) + + grounding_metadata_obj = LLMGroundingMetadata( + web_search_queries=grounding_data.get("webSearchQueries"), + grounding_attributions=attributions or None, + search_suggestions=grounding_data.get("searchSuggestions"), + search_entry_point=sep_content, + map_widget_token=grounding_data.get("googleMapsWidgetContextToken"), + ) + except Exception as e: + logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}") + + return ResponseData( + text=text_content, + tool_calls=parsed_tool_calls, + code_executions=parsed_code_executions if parsed_code_executions else None, + content_parts=content_parts if content_parts else None, + images=images_payload if images_payload else None, + usage_info=usage_info, + raw_response=response_json, + grounding_metadata=grounding_metadata_obj, + thought_text=final_thought_text, + thought_signature=thought_signature, + ) diff --git 
a/zhenxun/services/llm/adapters/components/interfaces.py b/zhenxun/services/llm/adapters/components/interfaces.py new file mode 100644 index 00000000..d5eb2be1 --- /dev/null +++ b/zhenxun/services/llm/adapters/components/interfaces.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Any + +from zhenxun.services.llm.adapters.base import ResponseData +from zhenxun.services.llm.config.generation import LLMGenerationConfig +from zhenxun.services.llm.types import LLMMessage +from zhenxun.services.llm.types.capabilities import ModelCapabilities +from zhenxun.services.llm.types.models import ModelDetail, ToolDefinition + + +class ConfigMapper(ABC): + @abstractmethod + def map_config( + self, + config: LLMGenerationConfig, + model_detail: ModelDetail | None = None, + capabilities: ModelCapabilities | None = None, + ) -> dict[str, Any]: + """将通用生成配置转换为特定 API 的参数字典""" + ... + + +class MessageConverter(ABC): + @abstractmethod + def convert_messages( + self, messages: list[LLMMessage] + ) -> list[dict[str, Any]] | dict[str, Any]: + """将通用消息列表转换为特定 API 的消息格式""" + ... + + +class ToolSerializer(ABC): + @abstractmethod + def serialize_tools(self, tools: list[ToolDefinition]) -> Any: + """将通用工具定义转换为特定 API 的工具格式""" + ... + + +class ResponseParser(ABC): + @abstractmethod + def parse(self, response_json: dict[str, Any]) -> ResponseData: + """将特定 API 的响应解析为通用响应数据""" + ... diff --git a/zhenxun/services/llm/adapters/components/openai_components.py b/zhenxun/services/llm/adapters/components/openai_components.py new file mode 100644 index 00000000..fcbb9997 --- /dev/null +++ b/zhenxun/services/llm/adapters/components/openai_components.py @@ -0,0 +1,347 @@ +import base64 +import binascii +import json +from pathlib import Path +from typing import Any + +from zhenxun.services.llm.adapters.base import ResponseData, process_image_data +from zhenxun.services.llm.adapters.components.interfaces import ( + ConfigMapper, + MessageConverter, + ResponseParser, + ToolSerializer, +) +from zhenxun.services.llm.config.generation import ( + ImageAspectRatio, + LLMGenerationConfig, + ResponseFormat, + StructuredOutputStrategy, +) +from zhenxun.services.llm.types import LLMMessage +from zhenxun.services.llm.types.capabilities import ModelCapabilities +from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException +from zhenxun.services.llm.types.models import ( + LLMToolCall, + LLMToolFunction, + ModelDetail, + ToolDefinition, +) +from zhenxun.services.llm.utils import sanitize_schema_for_llm +from zhenxun.services.log import logger +from zhenxun.utils.pydantic_compat import model_dump + + +class OpenAIConfigMapper(ConfigMapper): + def __init__(self, api_type: str = "openai"): + self.api_type = api_type + + def map_config( + self, + config: LLMGenerationConfig, + model_detail: ModelDetail | None = None, + capabilities: ModelCapabilities | None = None, + ) -> dict[str, Any]: + params: dict[str, Any] = {} + strategy = config.output.structured_output_strategy if config.output else None + if strategy is None: + strategy = ( + StructuredOutputStrategy.TOOL_CALL + if self.api_type == "deepseek" + else StructuredOutputStrategy.NATIVE + ) + + if config.core: + if config.core.temperature is not None: + params["temperature"] = config.core.temperature + if config.core.max_tokens is not None: + params["max_tokens"] = config.core.max_tokens + if config.core.top_k is not None: + params["top_k"] = config.core.top_k + if config.core.top_p is not None: + params["top_p"] = config.core.top_p + if 
config.core.frequency_penalty is not None: + params["frequency_penalty"] = config.core.frequency_penalty + if config.core.presence_penalty is not None: + params["presence_penalty"] = config.core.presence_penalty + if config.core.stop is not None: + params["stop"] = config.core.stop + + if config.core.repetition_penalty is not None: + if self.api_type == "openai": + logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") + else: + params["repetition_penalty"] = config.core.repetition_penalty + + if config.reasoning and config.reasoning.effort: + params["reasoning_effort"] = config.reasoning.effort.value.lower() + + if config.output: + if isinstance(config.output.response_format, dict): + params["response_format"] = config.output.response_format + elif ( + config.output.response_format == ResponseFormat.JSON + and strategy == StructuredOutputStrategy.NATIVE + ): + if config.output.response_schema: + sanitized = sanitize_schema_for_llm( + config.output.response_schema, api_type="openai" + ) + params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "structured_response", + "schema": sanitized, + "strict": True, + }, + } + else: + params["response_format"] = {"type": "json_object"} + + if config.tool_config: + mode = config.tool_config.mode + if mode == "NONE": + params["tool_choice"] = "none" + elif mode == "AUTO": + params["tool_choice"] = "auto" + elif mode == "ANY": + params["tool_choice"] = "required" + + if config.visual and config.visual.aspect_ratio: + size_map = { + ImageAspectRatio.SQUARE: "1024x1024", + ImageAspectRatio.LANDSCAPE_16_9: "1792x1024", + ImageAspectRatio.PORTRAIT_9_16: "1024x1792", + } + ar = config.visual.aspect_ratio + if isinstance(ar, ImageAspectRatio): + mapped_size = size_map.get(ar) + if mapped_size: + params["size"] = mapped_size + elif isinstance(ar, str): + params["size"] = ar + + if config.custom_params: + mapped_custom = config.custom_params.copy() + if "repetition_penalty" in mapped_custom and self.api_type == "openai": + mapped_custom.pop("repetition_penalty") + + if "stop" in mapped_custom: + stop_value = mapped_custom["stop"] + if isinstance(stop_value, str): + mapped_custom["stop"] = [stop_value] + + params.update(mapped_custom) + + return params + + +class OpenAIMessageConverter(MessageConverter): + def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]: + openai_messages: list[dict[str, Any]] = [] + for msg in messages: + openai_msg: dict[str, Any] = {"role": msg.role} + + if msg.role == "tool": + openai_msg["tool_call_id"] = msg.tool_call_id + openai_msg["name"] = msg.name + openai_msg["content"] = msg.content + else: + if isinstance(msg.content, str): + openai_msg["content"] = msg.content + else: + content_parts = [] + for part in msg.content: + if part.type == "text": + content_parts.append({"type": "text", "text": part.text}) + elif part.type == "image": + content_parts.append( + { + "type": "image_url", + "image_url": {"url": part.image_source}, + } + ) + openai_msg["content"] = content_parts + + if msg.role == "assistant" and msg.tool_calls: + assistant_tool_calls = [] + for call in msg.tool_calls: + assistant_tool_calls.append( + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": call.function.arguments, + }, + } + ) + openai_msg["tool_calls"] = assistant_tool_calls + + if msg.name and msg.role != "tool": + openai_msg["name"] = msg.name + + openai_messages.append(openai_msg) + return openai_messages + + +class 
OpenAIToolSerializer(ToolSerializer): + def serialize_tools( + self, tools: list[ToolDefinition] + ) -> list[dict[str, Any]] | None: + if not tools: + return None + + openai_tools = [] + for tool in tools: + tool_dict = model_dump(tool) + parameters = tool_dict.get("parameters") + if parameters: + tool_dict["parameters"] = sanitize_schema_for_llm( + parameters, api_type="openai" + ) + tool_dict["strict"] = True + openai_tools.append({"type": "function", "function": tool_dict}) + return openai_tools + + +class OpenAIResponseParser(ResponseParser): + def validate_response(self, response_json: dict[str, Any]) -> None: + if response_json.get("error"): + error_info = response_json["error"] + if isinstance(error_info, dict): + error_message = error_info.get("message", "未知错误") + error_code = error_info.get("code", "unknown") + + error_code_mapping = { + "invalid_api_key": LLMErrorCode.API_KEY_INVALID, + "authentication_failed": LLMErrorCode.API_KEY_INVALID, + "insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED, + "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED, + "quota_exceeded": LLMErrorCode.API_RATE_LIMITED, + "model_not_found": LLMErrorCode.MODEL_NOT_FOUND, + "invalid_model": LLMErrorCode.MODEL_NOT_FOUND, + "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, + "invalid_request_error": LLMErrorCode.INVALID_PARAMETER, + "invalid_parameter": LLMErrorCode.INVALID_PARAMETER, + } + + llm_error_code = error_code_mapping.get( + error_code, LLMErrorCode.API_RESPONSE_INVALID + ) + else: + error_message = str(error_info) + error_code = "unknown" + llm_error_code = LLMErrorCode.API_RESPONSE_INVALID + + raise LLMException( + f"API请求失败: {error_message}", + code=llm_error_code, + details={"api_error": error_info, "error_code": error_code}, + ) + + def parse(self, response_json: dict[str, Any]) -> ResponseData: + self.validate_response(response_json) + + choices = response_json.get("choices", []) + if not choices: + return ResponseData(text="", raw_response=response_json) + + choice = choices[0] + message = choice.get("message", {}) + content = message.get("content", "") + reasoning_content = message.get("reasoning_content", None) + refusal = message.get("refusal") + + if refusal: + raise LLMException( + f"模型拒绝生成请求: {refusal}", + code=LLMErrorCode.CONTENT_FILTERED, + details={"refusal": refusal}, + recoverable=False, + ) + + if content: + content = content.strip() + + images_payload: list[bytes | Path] = [] + if content and content.startswith("{") and content.endswith("}"): + try: + content_json = json.loads(content) + if "b64_json" in content_json: + b64_str = content_json["b64_json"] + if isinstance(b64_str, str) and b64_str.startswith("data:"): + b64_str = b64_str.split(",", 1)[1] + decoded = base64.b64decode(b64_str) + images_payload.append(process_image_data(decoded)) + content = "[图片已生成]" + elif "data" in content_json and isinstance(content_json["data"], str): + b64_str = content_json["data"] + if b64_str.startswith("data:"): + b64_str = b64_str.split(",", 1)[1] + decoded = base64.b64decode(b64_str) + images_payload.append(process_image_data(decoded)) + content = "[图片已生成]" + + except (json.JSONDecodeError, KeyError, binascii.Error): + pass + elif ( + "images" in message + and isinstance(message["images"], list) + and message["images"] + ): + for image_info in message["images"]: + if image_info.get("type") == "image_url": + image_url_obj = image_info.get("image_url", {}) + url_str = image_url_obj.get("url", "") + if 
url_str.startswith("data:image"): + try: + b64_data = url_str.split(",", 1)[1] + decoded = base64.b64decode(b64_data) + images_payload.append(process_image_data(decoded)) + except (IndexError, binascii.Error) as e: + logger.warning(f"解析OpenRouter Base64图片数据失败: {e}") + + if images_payload: + content = content if content else "[图片已生成]" + + parsed_tool_calls: list[LLMToolCall] | None = None + if message_tool_calls := message.get("tool_calls"): + parsed_tool_calls = [] + for tc_data in message_tool_calls: + try: + if tc_data.get("type") == "function": + parsed_tool_calls.append( + LLMToolCall( + id=tc_data["id"], + function=LLMToolFunction( + name=tc_data["function"]["name"], + arguments=tc_data["function"]["arguments"], + ), + ) + ) + except KeyError as e: + logger.warning( + f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}" + ) + except Exception as e: + logger.warning( + f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}" + ) + if not parsed_tool_calls: + parsed_tool_calls = None + + final_text = content if content is not None else "" + if not final_text and parsed_tool_calls: + final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。" + + usage_info = response_json.get("usage") + + return ResponseData( + text=final_text, + tool_calls=parsed_tool_calls, + usage_info=usage_info, + images=images_payload if images_payload else None, + raw_response=response_json, + thought_text=reasoning_content, + ) diff --git a/zhenxun/services/llm/adapters/factory.py b/zhenxun/services/llm/adapters/factory.py index 9f2a8b64..a21349e4 100644 --- a/zhenxun/services/llm/adapters/factory.py +++ b/zhenxun/services/llm/adapters/factory.py @@ -2,10 +2,17 @@ LLM 适配器工厂类 """ -from typing import ClassVar +import fnmatch +from typing import TYPE_CHECKING, Any, ClassVar from ..types.exceptions import LLMErrorCode, LLMException -from .base import BaseAdapter +from ..types.models import ToolChoice +from .base import BaseAdapter, RequestData, ResponseData + +if TYPE_CHECKING: + from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig + from ..service import LLMModel + from ..types import LLMMessage class LLMAdapterFactory: @@ -21,10 +28,13 @@ class LLMAdapterFactory: return from .gemini import GeminiAdapter - from .openai import OpenAIAdapter + from .openai import DeepSeekAdapter, OpenAIAdapter, OpenAIImageAdapter cls.register_adapter(OpenAIAdapter()) + cls.register_adapter(DeepSeekAdapter()) cls.register_adapter(GeminiAdapter()) + cls.register_adapter(SmartAdapter()) + cls.register_adapter(OpenAIImageAdapter()) @classmethod def register_adapter(cls, adapter: BaseAdapter) -> None: @@ -74,3 +84,100 @@ def get_adapter_for_api_type(api_type: str) -> BaseAdapter: def register_adapter(adapter: BaseAdapter) -> None: """注册新的适配器""" LLMAdapterFactory.register_adapter(adapter) + + +class SmartAdapter(BaseAdapter): + """ + 智能路由适配器。 + 本身不处理序列化,而是根据规则委托给 OpenAIAdapter 或 GeminiAdapter。 + """ + + @property + def log_sanitization_context(self) -> str: + return "openai_request" + + _ROUTING_RULES: ClassVar[list[tuple[str, str]]] = [ + ("*nano-banana*", "gemini"), + ("*gemini*", "gemini"), + ] + _DEFAULT_API_TYPE: ClassVar[str] = "openai" + + def __init__(self): + self._adapter_cache: dict[str, BaseAdapter] = {} + + @property + def api_type(self) -> str: + return "smart" + + @property + def supported_api_types(self) -> list[str]: + return ["smart"] + + def _get_delegate_adapter(self, model: "LLMModel") -> BaseAdapter: + """ + 核心路由逻辑:决定使用哪个适配器 (带缓存) + """ + if model.model_detail.api_type: + return get_adapter_for_api_type(model.model_detail.api_type) + 
+ model_name = model.model_name + if model_name in self._adapter_cache: + return self._adapter_cache[model_name] + + target_api_type = self._DEFAULT_API_TYPE + model_name_lower = model_name.lower() + + for pattern, api_type in self._ROUTING_RULES: + if fnmatch.fnmatch(model_name_lower, pattern): + target_api_type = api_type + break + + adapter = get_adapter_for_api_type(target_api_type) + self._adapter_cache[model_name] = adapter + return adapter + + async def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[Any] | None = None, + tool_choice: "str | dict[str, Any] | ToolChoice | None" = None, + ) -> RequestData: + adapter = self._get_delegate_adapter(model) + return await adapter.prepare_advanced_request( + model, api_key, messages, config, tools, tool_choice + ) + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + adapter = self._get_delegate_adapter(model) + return adapter.parse_response(model, response_json, is_advanced) + + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + config: "LLMEmbeddingConfig", + ) -> RequestData: + adapter = self._get_delegate_adapter(model) + return adapter.prepare_embedding_request(model, api_key, texts, config) + + def parse_embedding_response( + self, response_json: dict[str, Any] + ) -> list[list[float]]: + return get_adapter_for_api_type("openai").parse_embedding_response( + response_json + ) + + def convert_generation_config( + self, config: "LLMGenerationConfig", model: "LLMModel" + ) -> dict[str, Any]: + adapter = self._get_delegate_adapter(model) + return adapter.convert_generation_config(config, model) diff --git a/zhenxun/services/llm/adapters/gemini.py b/zhenxun/services/llm/adapters/gemini.py index 444a9023..109569d2 100644 --- a/zhenxun/services/llm/adapters/gemini.py +++ b/zhenxun/services/llm/adapters/gemini.py @@ -2,27 +2,35 @@ Gemini API 适配器 """ -import base64 from typing import TYPE_CHECKING, Any from zhenxun.services.log import logger +from ..config.generation import ResponseFormat +from ..types import LLMContentPart from ..types.exceptions import LLMErrorCode, LLMException -from ..utils import sanitize_schema_for_llm +from ..types.models import BasePlatformTool, ToolChoice from .base import BaseAdapter, RequestData, ResponseData +from .components.gemini_components import ( + GeminiConfigMapper, + GeminiMessageConverter, + GeminiResponseParser, + GeminiToolSerializer, +) if TYPE_CHECKING: - from ..config.generation import LLMGenerationConfig + from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig from ..service import LLMModel - from ..types.content import LLMMessage - from ..types.enums import EmbeddingTaskType - from ..types.models import LLMToolCall - from ..types.protocols import ToolExecutable + from ..types import LLMMessage class GeminiAdapter(BaseAdapter): """Gemini API 适配器""" + @property + def log_sanitization_context(self) -> str: + return "gemini_request" + @property def api_type(self) -> str: return "gemini" @@ -47,110 +55,76 @@ class GeminiAdapter(BaseAdapter): api_key: str, messages: list["LLMMessage"], config: "LLMGenerationConfig | None" = None, - tools: dict[str, "ToolExecutable"] | None = None, - tool_choice: str | dict[str, Any] | None = None, + tools: list[Any] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, ) -> RequestData: 
"""准备高级请求""" effective_config = config if config is not None else model._generation_config + if tools: + from ..types.models import GeminiUrlContext + + context_urls: list[str] = [] + for tool in tools: + if isinstance(tool, GeminiUrlContext): + context_urls.extend(tool.urls) + + if context_urls and messages: + last_msg = messages[-1] + if last_msg.role == "user": + url_text = "\n\n[Context URLs]:\n" + "\n".join(context_urls) + if isinstance(last_msg.content, str): + last_msg.content += url_text + elif isinstance(last_msg.content, list): + last_msg.content.append(LLMContentPart.text_part(url_text)) + + has_function_tools = False + if tools: + has_function_tools = any(hasattr(tool, "get_definition") for tool in tools) + + is_structured = False + if effective_config and effective_config.output: + if ( + effective_config.output.response_schema + or effective_config.output.response_format == ResponseFormat.JSON + or effective_config.output.response_mime_type == "application/json" + ): + is_structured = True + + if (has_function_tools or is_structured) and effective_config: + if effective_config.reasoning is None: + from ..config.generation import ReasoningConfig + + effective_config.reasoning = ReasoningConfig() + + if ( + effective_config.reasoning.budget_tokens is None + and effective_config.reasoning.effort is None + ): + reason_desc = "工具调用" if has_function_tools else "结构化输出" + logger.debug( + f"检测到{reason_desc},自动为模型 {model.model_name} 开启思维链增强" + ) + effective_config.reasoning.budget_tokens = -1 + endpoint = self._get_gemini_endpoint(model, effective_config) url = self.get_api_url(model, endpoint) headers = self.get_base_headers(api_key) - gemini_contents: list[dict[str, Any]] = [] + converter = GeminiMessageConverter() system_instruction_parts: list[dict[str, Any]] | None = None - for msg in messages: - current_parts: list[dict[str, Any]] = [] if msg.role == "system": if isinstance(msg.content, str): system_instruction_parts = [{"text": msg.content}] elif isinstance(msg.content, list): system_instruction_parts = [ - await part.convert_for_api_async("gemini") + await converter.convert_part(part) for part in msg.content ] continue - elif msg.role == "user": - if isinstance(msg.content, str): - current_parts.append({"text": msg.content}) - elif isinstance(msg.content, list): - for part_obj in msg.content: - current_parts.append( - await part_obj.convert_for_api_async("gemini") - ) - gemini_contents.append({"role": "user", "parts": current_parts}) - - elif msg.role == "assistant" or msg.role == "model": - if isinstance(msg.content, str) and msg.content: - current_parts.append({"text": msg.content}) - elif isinstance(msg.content, list): - for part_obj in msg.content: - current_parts.append( - await part_obj.convert_for_api_async("gemini") - ) - - if msg.tool_calls: - import json - - for call in msg.tool_calls: - current_parts.append( - { - "functionCall": { - "name": call.function.name, - "args": json.loads(call.function.arguments), - } - } - ) - if current_parts: - gemini_contents.append({"role": "model", "parts": current_parts}) - - elif msg.role == "tool": - if not msg.name: - raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。") - - import json - - try: - content_str = ( - msg.content - if isinstance(msg.content, str) - else str(msg.content) - ) - tool_result_obj = json.loads(content_str) - except json.JSONDecodeError: - content_str = ( - msg.content - if isinstance(msg.content, str) - else str(msg.content) - ) - logger.warning( - f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. 
" - f"包装为原始字符串。" - ) - tool_result_obj = {"raw_output": content_str} - - if isinstance(tool_result_obj, list): - logger.debug( - f"工具 '{msg.name}' 的返回结果是列表," - f"正在为Gemini API包装为JSON对象。" - ) - final_response_payload = {"result": tool_result_obj} - elif not isinstance(tool_result_obj, dict): - final_response_payload = {"result": tool_result_obj} - else: - final_response_payload = tool_result_obj - - current_parts.append( - { - "functionResponse": { - "name": msg.name, - "response": final_response_payload, - } - } - ) - gemini_contents.append({"role": "function", "parts": current_parts}) + gemini_contents = await converter.convert_messages_async(messages) body: dict[str, Any] = {"contents": gemini_contents} @@ -158,75 +132,78 @@ class GeminiAdapter(BaseAdapter): body["systemInstruction"] = {"parts": system_instruction_parts} all_tools_for_request = [] + has_user_functions = False if tools: - import asyncio + from ..types.protocols import ToolExecutable - from zhenxun.utils.pydantic_compat import model_dump + function_tools: list[ToolExecutable] = [] + gemini_tools_dict: dict[str, Any] = {} - definition_tasks = [ - executable.get_definition() for executable in tools.values() - ] - tool_definitions = await asyncio.gather(*definition_tasks) + for tool in tools: + if isinstance(tool, BasePlatformTool): + declaration = tool.get_tool_declaration() + if declaration: + gemini_tools_dict.update(declaration) + elif hasattr(tool, "get_definition"): + function_tools.append(tool) - function_declarations = [] - for tool_def in tool_definitions: - tool_def.parameters = sanitize_schema_for_llm( - tool_def.parameters, api_type="gemini" - ) - function_declarations.append(model_dump(tool_def)) + if function_tools: + import asyncio - if function_declarations: - all_tools_for_request.append( - {"functionDeclarations": function_declarations} - ) + definition_tasks = [ + executable.get_definition() for executable in function_tools + ] + tool_definitions = await asyncio.gather(*definition_tasks) - if effective_config: - if getattr(effective_config, "enable_grounding", False): - has_explicit_gs_tool = any( - "googleSearch" in tool_item for tool_item in all_tools_for_request - ) - if not has_explicit_gs_tool: - all_tools_for_request.append({"googleSearch": {}}) - logger.debug("隐式启用 Google Search 工具进行信息来源关联。") + serializer = GeminiToolSerializer() + function_declarations = serializer.serialize_tools(tool_definitions) - if getattr(effective_config, "enable_code_execution", False): - has_explicit_ce_tool = any( - "codeExecution" in tool_item for tool_item in all_tools_for_request - ) - if not has_explicit_ce_tool: - all_tools_for_request.append({"codeExecution": {}}) - logger.debug("隐式启用代码执行工具。") + if function_declarations: + gemini_tools_dict["functionDeclarations"] = function_declarations + has_user_functions = True + + if gemini_tools_dict: + all_tools_for_request.append(gemini_tools_dict) if all_tools_for_request: body["tools"] = all_tools_for_request - final_tool_choice = tool_choice - if final_tool_choice is None and effective_config: - final_tool_choice = getattr(effective_config, "tool_choice", None) + tool_config_updates: dict[str, Any] = {} + if ( + effective_config + and effective_config.custom_params + and "user_location" in effective_config.custom_params + ): + tool_config_updates["retrievalConfig"] = { + "latLng": effective_config.custom_params["user_location"] + } - if final_tool_choice: - if isinstance(final_tool_choice, str): - mode_upper = final_tool_choice.upper() - if mode_upper in ["AUTO", "NONE", 
"ANY"]: - body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}} - else: - body["toolConfig"] = self._convert_tool_choice_to_gemini( - final_tool_choice - ) - else: - body["toolConfig"] = self._convert_tool_choice_to_gemini( - final_tool_choice + if tool_config_updates: + body.setdefault("toolConfig", {}).update(tool_config_updates) + + converted_params: dict[str, Any] = {} + if effective_config: + converted_params = self.convert_generation_config(effective_config, model) + + if converted_params: + if "toolConfig" in converted_params: + tool_config_payload = converted_params.pop("toolConfig") + fc_config = tool_config_payload.get("functionCallingConfig") + should_apply_fc = has_user_functions or ( + fc_config and fc_config.get("mode") == "NONE" ) + if should_apply_fc: + body.setdefault("toolConfig", {}).update(tool_config_payload) + elif fc_config and fc_config.get("mode") != "AUTO": + logger.debug( + "Gemini: 忽略针对纯内置工具的 functionCallingConfig (API限制)" + ) - final_generation_config = self._build_gemini_generation_config( - model, effective_config - ) - if final_generation_config: - body["generationConfig"] = final_generation_config + if "safetySettings" in converted_params: + body["safetySettings"] = converted_params.pop("safetySettings") - safety_settings = self._build_safety_settings(effective_config) - if safety_settings: - body["safetySettings"] = safety_settings + if converted_params: + body["generationConfig"] = converted_params return RequestData(url=url, headers=headers, body=body) @@ -242,317 +219,56 @@ class GeminiAdapter(BaseAdapter): def _get_gemini_endpoint( self, model: "LLMModel", config: "LLMGenerationConfig | None" = None ) -> str: - """根据配置选择Gemini API端点""" - if config: - if getattr(config, "enable_code_execution", False): - return f"/v1beta/models/{model.model_name}:generateContent" - - if getattr(config, "enable_grounding", False): - return f"/v1beta/models/{model.model_name}:generateContent" - + """返回Gemini generateContent 端点""" return f"/v1beta/models/{model.model_name}:generateContent" - def _convert_tool_choice_to_gemini( - self, tool_choice_value: str | dict[str, Any] - ) -> dict[str, Any]: - """转换工具选择策略为Gemini格式""" - if isinstance(tool_choice_value, str): - mode_upper = tool_choice_value.upper() - if mode_upper in ["AUTO", "NONE", "ANY"]: - return {"functionCallingConfig": {"mode": mode_upper}} - else: - logger.warning( - f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。" - f"回退到 AUTO。" - ) - return {"functionCallingConfig": {"mode": "AUTO"}} - - elif isinstance(tool_choice_value, dict): - if ( - tool_choice_value.get("type") == "function" - and "function" in tool_choice_value - ): - func_name = tool_choice_value["function"].get("name") - if func_name: - return { - "functionCallingConfig": { - "mode": "ANY", - "allowedFunctionNames": [func_name], - } - } - else: - logger.warning( - f"tool_choice dict 中的函数名无效: {tool_choice_value}。" - f"回退到 AUTO。" - ) - return {"functionCallingConfig": {"mode": "AUTO"}} - - elif "functionCallingConfig" in tool_choice_value: - return { - "functionCallingConfig": tool_choice_value["functionCallingConfig"] - } - - else: - logger.warning( - f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。" - ) - return {"functionCallingConfig": {"mode": "AUTO"}} - - logger.warning( - f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。" - ) - return {"functionCallingConfig": {"mode": "AUTO"}} - - def _build_gemini_generation_config( - self, model: "LLMModel", config: "LLMGenerationConfig | None" = None - ) -> dict[str, Any]: - 
"""构建Gemini生成配置""" - effective_config = config if config is not None else model._generation_config - - if not effective_config: - return {} - - generation_config = effective_config.to_api_params( - api_type="gemini", model_name=model.model_name - ) - - if generation_config: - param_keys = list(generation_config.keys()) - logger.debug( - f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: " - f"{param_keys}" - ) - - return generation_config - - def _build_safety_settings( - self, config: "LLMGenerationConfig | None" = None - ) -> list[dict[str, Any]] | None: - """构建安全设置""" - if not config: - return None - - safety_settings = [] - - safety_categories = [ - "HARM_CATEGORY_HARASSMENT", - "HARM_CATEGORY_HATE_SPEECH", - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "HARM_CATEGORY_DANGEROUS_CONTENT", - ] - - custom_safety_settings = getattr(config, "safety_settings", None) - if custom_safety_settings: - for category, threshold in custom_safety_settings.items(): - safety_settings.append({"category": category, "threshold": threshold}) - else: - from ..config.providers import get_gemini_safety_threshold - - threshold = get_gemini_safety_threshold() - for category in safety_categories: - safety_settings.append({"category": category, "threshold": threshold}) - - return safety_settings if safety_settings else None - - def validate_response(self, response_json: dict[str, Any]) -> None: - """验证 Gemini API 响应,增加对 promptFeedback 的检查""" - super().validate_response(response_json) - - if prompt_feedback := response_json.get("promptFeedback"): - if block_reason := prompt_feedback.get("blockReason"): - logger.warning( - f"Gemini 内容因 promptFeedback 被安全过滤: {block_reason}" - ) - raise LLMException( - f"内容被安全过滤: {block_reason}", - code=LLMErrorCode.CONTENT_FILTERED, - details={ - "block_reason": block_reason, - "safety_ratings": prompt_feedback.get("safetyRatings"), - }, - ) - def parse_response( self, model: "LLMModel", response_json: dict[str, Any], is_advanced: bool = False, - ) -> ResponseData: - """解析API响应""" - return self._parse_response(model, response_json, is_advanced) - - def _parse_response( - self, - model: "LLMModel", - response_json: dict[str, Any], - is_advanced: bool = False, ) -> ResponseData: """解析 Gemini API 响应""" - _ = is_advanced - self.validate_response(response_json) - - try: - if "image_generation" in response_json and isinstance( - response_json["image_generation"], dict - ): - candidates_source = response_json["image_generation"] - else: - candidates_source = response_json - - candidates = candidates_source.get("candidates", []) - usage_info = response_json.get("usageMetadata") - - if not candidates: - logger.debug("Gemini响应中没有candidates。") - return ResponseData(text="", raw_response=response_json) - - candidate = candidates[0] - - if candidate.get("finishReason") in [ - "RECITATION", - "OTHER", - ] and not candidate.get("content"): - logger.warning( - f"Gemini candidate finished with reason " - f"'{candidate.get('finishReason')}' and no content." 
- ) - return ResponseData( - text="", - raw_response=response_json, - usage_info=response_json.get("usageMetadata"), - ) - - content_data = candidate.get("content", {}) - parts = content_data.get("parts", []) - - text_content = "" - images_bytes: list[bytes] = [] - parsed_tool_calls: list["LLMToolCall"] | None = None - thought_summary_parts = [] - answer_parts = [] - - for part in parts: - if "text" in part: - answer_parts.append(part["text"]) - elif "thought" in part: - thought_summary_parts.append(part["thought"]) - elif "thoughtSummary" in part: - thought_summary_parts.append(part["thoughtSummary"]) - elif "inlineData" in part: - inline_data = part["inlineData"] - if "data" in inline_data: - images_bytes.append(base64.b64decode(inline_data["data"])) - - elif "functionCall" in part: - if parsed_tool_calls is None: - parsed_tool_calls = [] - fc_data = part["functionCall"] - try: - import json - - from ..types.models import LLMToolCall, LLMToolFunction - - call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}" - parsed_tool_calls.append( - LLMToolCall( - id=call_id, - function=LLMToolFunction( - name=fc_data["name"], - arguments=json.dumps(fc_data["args"]), - ), - ) - ) - except KeyError as e: - logger.warning( - f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}" - ) - except Exception as e: - logger.warning( - f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}" - ) - elif "codeExecutionResult" in part: - result = part["codeExecutionResult"] - if result.get("outcome") == "OK": - output = result.get("output", "") - answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n") - else: - answer_parts.append( - f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n" - ) - - if thought_summary_parts: - full_thought_summary = "\n".join(thought_summary_parts).strip() - full_answer = "".join(answer_parts).strip() - - formatted_parts = [] - if full_thought_summary: - formatted_parts.append(f"🤔 **思考过程**\n\n{full_thought_summary}") - if full_answer: - separator = "\n\n---\n\n" if full_thought_summary else "" - formatted_parts.append(f"{separator}✅ **回答**\n\n{full_answer}") - - text_content = "".join(formatted_parts) - else: - text_content = "".join(answer_parts) - - usage_info = response_json.get("usageMetadata") - - grounding_metadata_obj = None - if grounding_data := candidate.get("groundingMetadata"): - try: - from ..types.models import LLMGroundingMetadata - - grounding_metadata_obj = LLMGroundingMetadata(**grounding_data) - except Exception as e: - logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}") - - return ResponseData( - text=text_content, - tool_calls=parsed_tool_calls, - images=images_bytes if images_bytes else None, - usage_info=usage_info, - raw_response=response_json, - grounding_metadata=grounding_metadata_obj, - ) - - except Exception as e: - logger.error(f"解析 Gemini 响应失败: {e}", e=e) - raise LLMException( - f"解析API响应失败: {e}", - code=LLMErrorCode.RESPONSE_PARSE_ERROR, - cause=e, - ) + _ = model, is_advanced + parser = GeminiResponseParser() + return parser.parse(response_json) def prepare_embedding_request( self, model: "LLMModel", api_key: str, texts: list[str], - task_type: "EmbeddingTaskType | str", - **kwargs: Any, + config: "LLMEmbeddingConfig", ) -> RequestData: """准备文本嵌入请求""" api_model_name = model.model_name if not api_model_name.startswith("models/"): api_model_name = f"models/{api_model_name}" - url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents") + if not model.api_base: + raise LLMException( + f"模型 {model.model_name} 的 api_base 未设置", + 
code=LLMErrorCode.CONFIGURATION_ERROR, + ) + + base_url = model.api_base.rstrip("/") + url = f"{base_url}/v1beta/{api_model_name}:batchEmbedContents" headers = self.get_base_headers(api_key) requests_payload = [] for text_content in texts: + safe_text = text_content if text_content else " " request_item: dict[str, Any] = { - "content": {"parts": [{"text": text_content}]}, + "model": api_model_name, + "content": {"parts": [{"text": safe_text}]}, } - from ..types.enums import EmbeddingTaskType - - if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT: - request_item["task_type"] = str(task_type).upper() - if title := kwargs.get("title"): - request_item["title"] = title - if output_dimensionality := kwargs.get("output_dimensionality"): - request_item["output_dimensionality"] = output_dimensionality + if config.task_type: + request_item["task_type"] = str(config.task_type).upper() + if config.title: + request_item["title"] = config.title + if config.output_dimensionality: + request_item["output_dimensionality"] = config.output_dimensionality requests_payload.append(request_item) @@ -601,3 +317,9 @@ class GeminiAdapter(BaseAdapter): code=LLMErrorCode.RESPONSE_PARSE_ERROR, details=response_json, ) + + def convert_generation_config( + self, config: "LLMGenerationConfig", model: "LLMModel" + ) -> dict[str, Any]: + mapper = GeminiConfigMapper() + return mapper.map_config(config, model.model_detail, model.capabilities) diff --git a/zhenxun/services/llm/adapters/openai.py b/zhenxun/services/llm/adapters/openai.py index a992d7ee..16613524 100644 --- a/zhenxun/services/llm/adapters/openai.py +++ b/zhenxun/services/llm/adapters/openai.py @@ -1,15 +1,181 @@ """ OpenAI API 适配器 -支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。 +支持 OpenAI、智谱AI 等 OpenAI 兼容的 API 服务。 """ -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +import base64 +from pathlib import Path +from typing import TYPE_CHECKING, Any -from .base import OpenAICompatAdapter +import json_repair + +from zhenxun.services.llm.config.generation import ImageAspectRatio +from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException +from zhenxun.services.log import logger +from zhenxun.utils.http_utils import AsyncHttpx + +from ..types import StructuredOutputStrategy +from ..types.models import ToolChoice +from ..utils import sanitize_schema_for_llm +from .base import ( + BaseAdapter, + OpenAICompatAdapter, + RequestData, + ResponseData, + process_image_data, +) +from .components.openai_components import ( + OpenAIConfigMapper, + OpenAIMessageConverter, + OpenAIResponseParser, + OpenAIToolSerializer, +) if TYPE_CHECKING: + from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig from ..service import LLMModel + from ..types import LLMMessage + + +class APIProtocol(ABC): + """API 协议策略基类""" + + @abstractmethod + def build_request_body( + self, + model: "LLMModel", + messages: list["LLMMessage"], + tools: list[dict[str, Any]] | None, + tool_choice: Any, + ) -> dict[str, Any]: + """构建不同协议下的请求体""" + pass + + @abstractmethod + def parse_response(self, response_json: dict[str, Any]) -> ResponseData: + """解析不同协议下的响应""" + pass + + +class StandardProtocol(APIProtocol): + """标准 OpenAI 协议策略""" + + def __init__(self, adapter: "OpenAICompatAdapter"): + self.adapter = adapter + + def build_request_body( + self, + model: "LLMModel", + messages: list["LLMMessage"], + tools: list[dict[str, Any]] | None, + tool_choice: Any, + ) -> dict[str, Any]: + converter = OpenAIMessageConverter() + openai_messages = 
converter.convert_messages(messages) + body: dict[str, Any] = { + "model": model.model_name, + "messages": openai_messages, + } + if tools: + body["tools"] = tools + if tool_choice: + body["tool_choice"] = tool_choice + return body + + def parse_response(self, response_json: dict[str, Any]) -> ResponseData: + parser = OpenAIResponseParser() + return parser.parse(response_json) + + +class ResponsesProtocol(APIProtocol): + """/v1/responses 新版协议策略""" + + def __init__(self, adapter: "OpenAICompatAdapter"): + self.adapter = adapter + + def build_request_body( + self, + model: "LLMModel", + messages: list["LLMMessage"], + tools: list[dict[str, Any]] | None, + tool_choice: Any, + ) -> dict[str, Any]: + input_items: list[dict[str, Any]] = [] + + for msg in messages: + role = msg.role + content_list: list[dict[str, Any]] = [] + raw_contents = ( + msg.content if isinstance(msg.content, list) else [msg.content] + ) + + for part in raw_contents: + if part is None: + continue + if isinstance(part, str): + content_list.append({"type": "input_text", "text": part}) + continue + + if hasattr(part, "type"): + part_type = getattr(part, "type", None) + if part_type == "text": + content_list.append( + {"type": "input_text", "text": getattr(part, "text", "")} + ) + elif part_type == "image": + content_list.append( + { + "type": "input_image", + "image_url": getattr(part, "image_source", ""), + } + ) + continue + + if isinstance(part, dict): + part_type = part.get("type") + if part_type == "text": + content_list.append( + {"type": "input_text", "text": part.get("text", "")} + ) + elif part_type in {"image", "image_url"}: + image_src = part.get("image_url") or part.get( + "image_source", "" + ) + content_list.append( + { + "type": "input_image", + "image_url": image_src, + } + ) + + input_items.append({"role": role, "content": content_list}) + + body: dict[str, Any] = { + "model": model.model_name, + "input": input_items, + } + if tools: + body["tools"] = tools + if tool_choice: + body["tool_choice"] = tool_choice + return body + + def parse_response(self, response_json: dict[str, Any]) -> ResponseData: + self.adapter.validate_response(response_json) + text_content = "" + for item in response_json.get("output", []): + if item.get("type") == "message" and item.get("role") == "assistant": + for content_item in item.get("content", []): + if content_item.get("type") == "output_text": + text_content += content_item.get("text", "") + + return ResponseData( + text=text_content, + usage_info=response_json.get("usage"), + raw_response=response_json, + ) class OpenAIAdapter(OpenAICompatAdapter): @@ -23,23 +189,411 @@ class OpenAIAdapter(OpenAICompatAdapter): def supported_api_types(self) -> list[str]: return [ "openai", - "deepseek", "zhipu", - "general_openai_compat", "ark", "openrouter", + "openai_responses", ] def get_chat_endpoint(self, model: "LLMModel") -> str: """返回聊天完成端点""" - if model.api_type == "ark": + if model.model_detail.endpoint: + return model.model_detail.endpoint + + current_api_type = model.model_detail.api_type or model.api_type + + if current_api_type == "openai_responses": + return "/v1/responses" + if current_api_type == "ark": return "/api/v3/chat/completions" - if model.api_type == "zhipu": + if current_api_type == "zhipu": return "/api/paas/v4/chat/completions" return "/v1/chat/completions" + def _get_protocol_strategy(self, model: "LLMModel") -> APIProtocol: + """根据 API 类型获取对应的处理策略""" + current_api_type = model.model_detail.api_type or model.api_type + if current_api_type == "openai_responses": + 
return ResponsesProtocol(self) + return StandardProtocol(self) + def get_embedding_endpoint(self, model: "LLMModel") -> str: """根据API类型返回嵌入端点""" if model.api_type == "zhipu": return "/v4/embeddings" return "/v1/embeddings" + + def convert_generation_config( + self, config: "LLMGenerationConfig", model: "LLMModel" + ) -> dict[str, Any]: + mapper = OpenAIConfigMapper(api_type=self.api_type) + return mapper.map_config(config, model.model_detail, model.capabilities) + + async def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[Any] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, + ) -> "RequestData": + """根据不同协议策略构建高级请求""" + url = self.get_api_url(model, self.get_chat_endpoint(model)) + headers = self.get_base_headers(api_key) + if model.api_type == "openrouter": + headers.update( + { + "HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot", + "X-Title": "Zhenxun Bot", + } + ) + + default_config = getattr(model, "_generation_config", None) + effective_config = config if config is not None else default_config + structured_strategy = ( + effective_config.output.structured_output_strategy + if effective_config and effective_config.output + else None + ) + if structured_strategy is None: + structured_strategy = StructuredOutputStrategy.NATIVE + + openai_tools: list[dict[str, Any]] | None = None + executables: list[Any] = [] + if tools: + if isinstance(tools, dict): + executables = list(tools.values()) + else: + for tool in tools: + if hasattr(tool, "get_definition"): + executables.append(tool) + + definition_tasks = [executable.get_definition() for executable in executables] + tool_defs: list[Any] = [] + if definition_tasks: + import asyncio + + tool_defs = await asyncio.gather(*definition_tasks) + + if tool_defs: + serializer = OpenAIToolSerializer() + openai_tools = serializer.serialize_tools(tool_defs) + + final_tool_choice = tool_choice + if final_tool_choice is None: + if ( + effective_config + and effective_config.tool_config + and effective_config.tool_config.mode == "ANY" + ): + allowed = effective_config.tool_config.allowed_function_names + if allowed: + if len(allowed) == 1: + final_tool_choice = { + "type": "function", + "function": {"name": allowed[0]}, + } + else: + logger.warning( + "OpenAI API 不支持多个 allowed_function_names,降级为" + " required。" + ) + final_tool_choice = "required" + else: + final_tool_choice = "required" + + if ( + structured_strategy == StructuredOutputStrategy.TOOL_CALL + and effective_config + and effective_config.output + and effective_config.output.response_schema + ): + sanitized_schema = sanitize_schema_for_llm( + effective_config.output.response_schema, api_type="openai" + ) + structured_tool = { + "type": "function", + "function": { + "name": "return_structured_response", + "description": "Return the final structured response.", + "parameters": sanitized_schema, + "strict": True if model.api_type != "deepseek" else False, + }, + } + if openai_tools is None: + openai_tools = [] + openai_tools.append(structured_tool) + final_tool_choice = { + "type": "function", + "function": {"name": "return_structured_response"}, + } + + protocol_strategy = self._get_protocol_strategy(model) + body = protocol_strategy.build_request_body( + model=model, + messages=messages, + tools=openai_tools, + tool_choice=final_tool_choice, + ) + + body = self.apply_config_override(model, body, config) + + if final_tool_choice is not None: 
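+            # 说明性注释(推测意图):协议策略构建请求体时可能已写入
+            # tool_choice,而 apply_config_override 也可能依据 ToolConfig
+            # 修改该字段;此处再次写入,确保上方计算得到的
+            # final_tool_choice 不被覆盖配置冲掉。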
+ body["tool_choice"] = final_tool_choice + + response_format = body.get("response_format", {}) + inject_prompt = ( + structured_strategy == StructuredOutputStrategy.NATIVE + and isinstance(response_format, dict) + and response_format.get("type") == "json_object" + ) + + if inject_prompt: + messages_list = body.get("messages", []) + has_json_keyword = False + for msg in messages_list: + content = msg.get("content") + if isinstance(content, str) and "json" in content.lower(): + has_json_keyword = True + break + if isinstance(content, list): + for part in content: + if ( + isinstance(part, dict) + and part.get("type") == "text" + and "json" in part.get("text", "").lower() + ): + has_json_keyword = True + break + if has_json_keyword: + break + + if not has_json_keyword: + injection_text = ( + "请务必输出合法的 JSON 格式,避免额外的文本、Markdown 或解释。" + ) + system_msg = next( + (m for m in messages_list if m.get("role") == "system"), None + ) + if system_msg: + if isinstance(system_msg.get("content"), str): + system_msg["content"] += " " + injection_text + elif isinstance(system_msg.get("content"), list): + system_msg["content"].append( + {"type": "text", "text": injection_text} + ) + else: + messages_list.insert( + 0, {"role": "system", "content": injection_text} + ) + body["messages"] = messages_list + + return RequestData(url=url, headers=headers, body=body) + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + """解析响应 - 使用策略模式委托处理""" + _ = is_advanced + protocol_strategy = self._get_protocol_strategy(model) + response_data = protocol_strategy.parse_response(response_json) + + if response_data.tool_calls: + target_tool = next( + ( + tc + for tc in response_data.tool_calls + if tc.function.name == "return_structured_response" + ), + None, + ) + if target_tool: + response_data.text = json_repair.repair_json( + target_tool.function.arguments + ) + remaining = [ + tc + for tc in response_data.tool_calls + if tc.function.name != "return_structured_response" + ] + response_data.tool_calls = remaining or None + + return response_data + + +class DeepSeekAdapter(OpenAIAdapter): + """DeepSeek 专用适配器 (基于 OpenAI 协议)""" + + @property + def api_type(self) -> str: + return "deepseek" + + @property + def supported_api_types(self) -> list[str]: + return ["deepseek"] + + +class OpenAIImageAdapter(BaseAdapter): + """OpenAI 图像生成/编辑适配器""" + + @property + def api_type(self) -> str: + return "openai_image" + + @property + def log_sanitization_context(self) -> str: + return "openai_request" + + @property + def supported_api_types(self) -> list[str]: + return ["openai_image", "nano_banana"] + + async def prepare_advanced_request( + self, + model: "LLMModel", + api_key: str, + messages: list["LLMMessage"], + config: "LLMGenerationConfig | None" = None, + tools: list[Any] | None = None, + tool_choice: "str | dict[str, Any] | ToolChoice | None" = None, + ) -> RequestData: + _ = tools, tool_choice + effective_config = config if config is not None else model._generation_config + headers = self.get_base_headers(api_key) + + prompt = "" + images_bytes_list: list[bytes] = [] + + for msg in reversed(messages): + if msg.role != "user": + continue + if isinstance(msg.content, str): + prompt = msg.content + elif isinstance(msg.content, list): + for part in msg.content: + if part.type == "text" and not prompt: + prompt = part.text + elif part.type == "image": + if part.is_image_base64(): + if b64_data := part.get_base64_data(): + _, b64_str = b64_data + 
images_bytes_list.append(base64.b64decode(b64_str)) + elif part.is_image_url() and part.image_source: + images_bytes_list.append( + await AsyncHttpx.get_content(part.image_source) + ) + if prompt: + break + + if not prompt and not images_bytes_list: + raise LLMException( + "图像生成需要提供 Prompt", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) + + body: dict[str, Any] = { + "model": model.model_name, + "prompt": prompt, + "response_format": "b64_json", + } + + if effective_config: + if effective_config.visual: + if effective_config.visual.aspect_ratio: + ar = effective_config.visual.aspect_ratio + size_map = { + ImageAspectRatio.SQUARE: "1024x1024", + ImageAspectRatio.LANDSCAPE_16_9: "1792x1024", + ImageAspectRatio.PORTRAIT_9_16: "1024x1792", + } + if isinstance(ar, ImageAspectRatio) and ar in size_map: + body["size"] = size_map[ar] + body["aspect_ratio"] = ar.value + elif isinstance(ar, str): + if "x" in ar: + body["size"] = ar + else: + body["aspect_ratio"] = ar + + if effective_config.visual.resolution: + res_val = effective_config.visual.resolution + if not isinstance(res_val, str): + res_val = getattr(res_val, "value", res_val) + body["image_size"] = res_val + + if effective_config.custom_params: + body.update(effective_config.custom_params) + + if images_bytes_list: + b64_images = [] + for img_bytes in images_bytes_list: + b64_str = base64.b64encode(img_bytes).decode("utf-8") + b64_images.append(b64_str) + body["image"] = b64_images + + endpoint = "/v1/images/generations" + url = self.get_api_url(model, endpoint) + return RequestData(url=url, headers=headers, body=body) + + def parse_response( + self, + model: "LLMModel", + response_json: dict[str, Any], + is_advanced: bool = False, + ) -> ResponseData: + _ = model, is_advanced + self.validate_response(response_json) + + images_data: list[bytes | Path] = [] + data_list = response_json.get("data", []) + + for item in data_list: + if "b64_json" in item: + try: + b64_str = item["b64_json"] + if b64_str.startswith("data:"): + b64_str = b64_str.split(",", 1)[1] + img = base64.b64decode(b64_str) + images_data.append(process_image_data(img)) + except Exception as exc: + logger.error(f"Base64 解码失败: {exc}") + elif "url" in item: + logger.warning( + f"API 返回了 URL 而不是 Base64: {item.get('url', 'unknown')}" + ) + + text_summary = ( + f"已生成 {len(images_data)} 张图片。" + if images_data + else "图像生成接口调用成功,但未解析到图片数据。" + ) + + return ResponseData( + text=text_summary, + images=images_data if images_data else None, + raw_response=response_json, + ) + + def prepare_embedding_request( + self, + model: "LLMModel", + api_key: str, + texts: list[str], + config: "LLMEmbeddingConfig", + ) -> RequestData: + raise NotImplementedError("OpenAIImageAdapter 不支持 Embedding") + + def parse_embedding_response( + self, response_json: dict[str, Any] + ) -> list[list[float]]: + raise NotImplementedError("OpenAIImageAdapter 不支持 Embedding") + + def convert_generation_config( + self, config: "LLMGenerationConfig", model: "LLMModel" + ) -> dict[str, Any]: + _ = config, model + return {} diff --git a/zhenxun/services/llm/api.py b/zhenxun/services/llm/api.py index 23da7f1b..1824f67b 100644 --- a/zhenxun/services/llm/api.py +++ b/zhenxun/services/llm/api.py @@ -2,6 +2,7 @@ LLM 服务的高级 API 接口 - 便捷函数入口 (无状态) """ +from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any, TypeVar, overload @@ -11,19 +12,24 @@ from pydantic import BaseModel from zhenxun.services.log import logger from .config import CommonOverrides -from .config.generation import 
LLMGenerationConfig, create_generation_config_from_kwargs
+from .config.generation import (
+    GenConfigBuilder,
+    LLMEmbeddingConfig,
+    LLMGenerationConfig,
+    OutputConfig,
+)
 from .manager import get_model_instance
 from .session import AI
-from .tools.manager import tool_provider_manager
 from .types import (
-    EmbeddingTaskType,
     LLMContentPart,
     LLMErrorCode,
     LLMException,
     LLMMessage,
     LLMResponse,
     ModelName,
+    ToolChoice,
 )
+from .types.exceptions import get_user_friendly_error_message
 from .utils import create_multimodal_message
 
 T = TypeVar("T", bound=BaseModel)
@@ -34,9 +40,10 @@ async def chat(
     *,
     model: ModelName = None,
     instruction: str | None = None,
-    tools: list[dict[str, Any] | str] | None = None,
-    tool_choice: str | dict[str, Any] | None = None,
-    **kwargs: Any,
+    tools: list[Any] | None = None,
+    tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+    config: LLMGenerationConfig | GenConfigBuilder | None = None,
+    timeout: float | None = None,
 ) -> LLMResponse:
     """
     无状态的聊天对话便捷函数,通过临时的AI会话实例与LLM模型交互。
@@ -47,14 +54,13 @@ async def chat(
         instruction: 系统指令,用于指导AI的行为和回复风格。
-        tools: 可用的工具列表,支持字典配置或字符串标识符。
+        tools: 可用的工具列表(支持函数工具与平台原生工具)。
         tool_choice: 工具选择策略,控制AI如何选择和使用工具。
-        **kwargs: 额外的生成配置参数,会被转换为LLMGenerationConfig。
+        config: (可选) 生成配置对象,将与默认配置合并后传递。
+        timeout: (可选) HTTP 请求超时时间(秒)。
 
     返回:
         LLMResponse: 包含AI回复内容、使用信息和工具调用等的完整响应对象。
     """
     try:
-        config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
-
         ai_session = AI()
 
         return await ai_session.chat(
@@ -64,12 +70,14 @@ async def chat(
             tools=tools,
             tool_choice=tool_choice,
             config=config,
+            timeout=timeout,
         )
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"执行 chat 函数失败: {e}", e=e)
-        raise LLMException(f"聊天执行失败: {e}", cause=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"执行 chat 函数失败: {e} | 建议: {friendly_msg}", e=e)
+        raise LLMException(f"聊天执行失败: {friendly_msg}", cause=e)
 
 
 async def code(
@@ -77,7 +85,6 @@ async def code(
     *,
     model: ModelName = None,
     timeout: int | None = None,
-    **kwargs: Any,
 ) -> LLMResponse:
     """
     无状态的代码执行便捷函数,支持在沙箱环境中执行代码。
 
     参数:
         prompt: 代码执行的提示词,描述要执行的代码任务。
-        model: 要使用的模型名称,默认使用Gemini/gemini-2.0-flash。
+        model: 要使用的模型名称,如果为None则使用默认模型。
         timeout: 代码执行超时时间(秒),防止长时间运行的代码阻塞。
-        **kwargs: 额外的生成配置参数。
 
     返回:
         LLMResponse: 包含代码执行结果的完整响应对象。
     """
-    resolved_model = model or "Gemini/gemini-2.0-flash"
+    resolved_model = model
 
     config = CommonOverrides.gemini_code_execution()
 
     if timeout:
         config.custom_params = config.custom_params or {}
         config.custom_params["code_execution_timeout"] = timeout
 
-    final_config = config.to_dict()
-    final_config.update(kwargs)
-
-    return await chat(prompt, model=resolved_model, **final_config)
-
-
-async def search(
-    query: str | UniMessage | LLMMessage | list[LLMContentPart],
-    *,
-    model: ModelName = None,
-    instruction: str = (
-        "你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
-        "根据用户的查询找到最相关的信息,并进行总结和回答。"
-    ),
-    **kwargs: Any,
-) -> LLMResponse:
-    """
-    无状态的信息搜索便捷函数,利用搜索工具获取实时信息。
-
-    参数:
-        query: 搜索查询内容,支持多种输入格式。
-        model: 要使用的模型名称,如果为None则使用默认模型。
-        instruction: 搜索任务的系统指令,指导AI如何处理搜索结果。
-        **kwargs: 额外的生成配置参数。
-
-    返回:
-        LLMResponse: 包含搜索结果和AI整合回复的完整响应对象。
-    """
-    logger.debug("执行无状态 'search' 任务...")
-    search_config = CommonOverrides.gemini_grounding()
-
-    final_config = search_config.to_dict()
-    final_config.update(kwargs)
-
-    return await chat(
-        query,
-        model=model,
-        instruction=instruction,
-        **final_config,
-    )
+    return await chat(prompt, model=resolved_model, config=config)
 
 
 async def embed(
     texts: list[str] | str,
     *,
     model: ModelName = None,
-    task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
-    **kwargs: Any,
+    config: LLMEmbeddingConfig | None = None,
 ) -> list[list[float]]:
     """
     无状态的文本嵌入便捷函数,将文本转换为向量表示。
@@ -153,8 +119,7 @@
     参数:
         texts: 要生成嵌入的文本内容,支持单个字符串或字符串列表。
         model: 要使用的嵌入模型名称,如果为None则使用默认模型。
-        task_type: 嵌入任务类型,影响向量的优化方向(如检索、分类等)。
-        **kwargs: 额外的模型配置参数。
+        config: 嵌入配置对象。
 
     返回:
         list[list[float]]: 文本对应的嵌入向量列表,每个向量为浮点数列表。
 
@@ -164,27 +129,71 @@
     if not texts:
         return []
 
+    final_config = config or LLMEmbeddingConfig()
+
     try:
         async with await get_model_instance(model) as model_instance:
-            return await model_instance.generate_embeddings(
-                texts, task_type=task_type, **kwargs
-            )
+            return await model_instance.generate_embeddings(texts, config=final_config)
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"文本嵌入失败: {e}", e=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"文本嵌入失败: {e} | 建议: {friendly_msg}", e=e)
         raise LLMException(
-            f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
+            f"文本嵌入失败: {friendly_msg}",
+            code=LLMErrorCode.EMBEDDING_FAILED,
+            cause=e,
        )
 
 
+async def embed_query(
+    text: str,
+    *,
+    model: ModelName = None,
+    dimensions: int | None = None,
+) -> list[float]:
+    """
+    语义化便捷 API:为检索查询生成嵌入。
+    """
+    config = LLMEmbeddingConfig(
+        task_type="RETRIEVAL_QUERY",
+        output_dimensionality=dimensions,
+    )
+    vectors = await embed([text], model=model, config=config)
+    return vectors[0] if vectors else []
+
+
+async def embed_documents(
+    texts: list[str],
+    *,
+    model: ModelName = None,
+    dimensions: int | None = None,
+    title: str | None = None,
+) -> list[list[float]]:
+    """
+    语义化便捷 API:为文档集合生成嵌入。
+    """
+    config = LLMEmbeddingConfig(
+        task_type="RETRIEVAL_DOCUMENT",
+        output_dimensionality=dimensions,
+        title=title,
+    )
+    return await embed(texts, model=model, config=config)
+
+
 async def generate_structured(
     message: str | LLMMessage | list[LLMContentPart],
     response_model: type[T],
     *,
     model: ModelName = None,
+    tools: list[Any] | None = None,
+    tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+    max_validation_retries: int | None = None,
+    validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
+    error_prompt_template: str | None = None,
+    auto_thinking: bool = False,
     instruction: str | None = None,
-    **kwargs: Any,
+    timeout: float | None = None,
 ) -> T:
     """
     无状态地生成结构化响应,并自动解析为指定的Pydantic模型。
 
@@ -192,39 +201,48 @@
     参数:
         message: 用户输入的消息内容,支持多种格式。
         response_model: 用于解析和验证响应的Pydantic模型类。
+        tools: 可用的工具列表(支持函数工具与平台原生工具)。
+        tool_choice: 工具选择策略,控制AI如何选择和使用工具。
+        max_validation_retries: 校验失败时的最大重试次数,默认为 None(使用全局配置)。
+        validation_callback: 自定义校验回调函数,抛出异常视为校验失败。
+        error_prompt_template: 自定义错误反馈提示词模板。
+        auto_thinking: 是否自动开启思维链 (CoT) 包装,适用于不支持原生思考的模型。
         model: 要使用的模型名称,如果为None则使用默认模型。
         instruction: 系统指令,用于指导AI生成符合要求的结构化输出。
-        **kwargs: 额外的生成配置参数。
+        timeout: HTTP 请求超时时间(秒)。
 
     返回:
         T: 解析后的Pydantic模型实例,类型为response_model指定的类型。
     """
     try:
-        config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
-
         ai_session = AI()
 
         return await ai_session.generate_structured(
             message,
             response_model,
             model=model,
+            tools=tools,
+            tool_choice=tool_choice,
+            max_validation_retries=max_validation_retries,
+            validation_callback=validation_callback,
+            error_prompt_template=error_prompt_template,
+            auto_thinking=auto_thinking,
            instruction=instruction,
-            config=config,
+            timeout=timeout,
        )
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"生成结构化响应失败: {e}", e=e)
-        raise LLMException(f"生成结构化响应失败: {e}", cause=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        
logger.error(f"生成结构化响应失败: {e} | 建议: {friendly_msg}", e=e) + raise LLMException(f"生成结构化响应失败: {friendly_msg}", cause=e) async def generate( messages: list[LLMMessage], *, model: ModelName = None, - tools: list[dict[str, Any] | str] | None = None, - tool_choice: str | dict[str, Any] | None = None, - **kwargs: Any, + tools: list[Any] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, + config: LLMGenerationConfig | GenConfigBuilder | None = None, ) -> LLMResponse: """ 根据完整的消息列表生成一次性响应,这是一个无状态的底层函数。 @@ -234,103 +252,57 @@ async def generate( model: 要使用的模型名称,如果为None则使用默认模型。 tools: 可用的工具列表,支持字典配置或字符串标识符。 tool_choice: 工具选择策略,控制AI如何选择和使用工具。 - **kwargs: 额外的生成配置参数,会覆盖默认配置。 + config: (可选) 生成配置对象,将与默认配置合并后传递。 返回: LLMResponse: 包含AI回复内容、使用信息和工具调用等的完整响应对象。 """ try: + if isinstance(config, GenConfigBuilder): + config = config.build() + async with await get_model_instance( - model, override_config=kwargs + model, override_config=None ) as model_instance: return await model_instance.generate_response( messages, - tools=tools, # type: ignore + config=config, + tools=tools, # type: ignore[arg-type] tool_choice=tool_choice, ) except LLMException: raise except Exception as e: - logger.error(f"生成响应失败: {e}", e=e) - raise LLMException(f"生成响应失败: {e}", cause=e) - - -async def run_with_tools( - message: str | UniMessage | LLMMessage | list[LLMContentPart], - *, - model: ModelName = None, - instruction: str | None = None, - tools: list[str], - max_cycles: int = 5, - **kwargs: Any, -) -> LLMResponse: - """ - 无状态地执行一个带本地Python函数的LLM调用循环。 - - 参数: - message: 用户输入。 - model: 使用的模型。 - instruction: 系统指令。 - tools: 要使用的本地函数工具名称列表 (必须已通过 @function_tool 注册)。 - max_cycles: 最大工具调用循环次数。 - **kwargs: 额外的生成配置参数。 - - 返回: - LLMResponse: 包含最终回复的响应对象。 - """ - from .executor import ExecutionConfig, LLMToolExecutor - from .utils import normalize_to_llm_messages - - messages = await normalize_to_llm_messages(message, instruction) - - async with await get_model_instance( - model, override_config=kwargs - ) as model_instance: - resolved_tools = await tool_provider_manager.get_function_tools(tools) - if not resolved_tools: - logger.warning( - "run_with_tools 未找到任何可用的本地函数工具,将作为普通聊天执行。" - ) - return await model_instance.generate_response(messages, tools=None) - - executor = LLMToolExecutor(model_instance) - config = ExecutionConfig(max_cycles=max_cycles) - final_history = await executor.run(messages, resolved_tools, config) - - for msg in reversed(final_history): - if msg.role == "assistant": - text = msg.content if isinstance(msg.content, str) else str(msg.content) - return LLMResponse(text=text, tool_calls=msg.tool_calls) - - raise LLMException( - "带工具的执行循环未能产生有效的助手回复。", code=LLMErrorCode.GENERATION_FAILED - ) + friendly_msg = get_user_friendly_error_message(e) + logger.error(f"生成响应失败: {e} | 建议: {friendly_msg}", e=e) + raise LLMException(f"生成响应失败: {friendly_msg}", cause=e) async def _generate_image_from_message( message: UniMessage, model: ModelName = None, - **kwargs: Any, + config: LLMGenerationConfig | GenConfigBuilder | None = None, ) -> LLMResponse: """ [内部] 从 UniMessage 生成图片的核心辅助函数。 """ from .utils import normalize_to_llm_messages - config = ( - create_generation_config_from_kwargs(**kwargs) - if kwargs - else LLMGenerationConfig() - ) + if isinstance(config, GenConfigBuilder): + config = config.build() + + config = config or LLMGenerationConfig() config.validation_policy = {"require_image": True} - config.response_modalities = ["IMAGE", "TEXT"] + if config.output is None: + config.output = OutputConfig() + 
config.output.response_modalities = ["IMAGE", "TEXT"] try: messages = await normalize_to_llm_messages(message) async with await get_model_instance(model) as model_instance: - if not model_instance.can_generate_images(): + if not model_instance.can_generate_images: raise LLMException( f"模型 '{model_instance.provider_name}/{model_instance.model_name}'" f"不支持图片生成", @@ -347,8 +319,9 @@ async def _generate_image_from_message( except LLMException: raise except Exception as e: - logger.error(f"执行图片生成时发生未知错误: {e}", e=e) - raise LLMException(f"图片生成失败: {e}", cause=e) + friendly_msg = get_user_friendly_error_message(e) + logger.error(f"执行图片生成时发生未知错误: {e} | 建议: {friendly_msg}", e=e) + raise LLMException(f"图片生成失败: {friendly_msg}", cause=e) @overload @@ -357,7 +330,6 @@ async def create_image( *, images: None = None, model: ModelName = None, - **kwargs: Any, ) -> LLMResponse: """根据文本提示生成一张新图片。""" ... @@ -369,7 +341,6 @@ async def create_image( *, images: list[Path | bytes | str] | Path | bytes | str, model: ModelName = None, - **kwargs: Any, ) -> LLMResponse: """在给定图片的基础上,根据文本提示进行编辑或重新生成。""" ... @@ -380,7 +351,7 @@ async def create_image( *, images: list[Path | bytes | str] | Path | bytes | str | None = None, model: ModelName = None, - **kwargs: Any, + config: LLMGenerationConfig | GenConfigBuilder | None = None, ) -> LLMResponse: """ 智能图片生成/编辑函数。 @@ -400,4 +371,42 @@ async def create_image( message = create_multimodal_message(text=text_prompt, images=image_list) - return await _generate_image_from_message(message, model=model, **kwargs) + return await _generate_image_from_message(message, model=model, config=config) + + +async def search( + query: str | UniMessage | LLMMessage | list[LLMContentPart], + *, + model: ModelName = None, + instruction: str = ( + "你是一位强大的信息检索和整合专家。请利用可用的搜索工具," + "根据用户的查询找到最相关的信息,并进行总结和回答。" + ), + config: LLMGenerationConfig | GenConfigBuilder | None = None, +) -> LLMResponse: + """ + 无状态的信息搜索便捷函数,利用搜索工具获取实时信息。 + + 参数: + query: 搜索查询内容,支持多种输入格式。 + model: 要使用的模型名称,如果为None则使用默认模型。 + config: (可选) 生成配置对象,将与预设配置合并后传递。 + instruction: 搜索任务的系统指令,指导AI如何处理搜索结果。 + + 返回: + LLMResponse: 包含搜索结果和AI整合回复的完整响应对象。 + """ + logger.debug("执行无状态 'search' 任务...") + search_config = CommonOverrides.gemini_grounding() + + if isinstance(config, GenConfigBuilder): + config = config.build() + + final_config = search_config.merge_with(config) + + return await chat( + query, + model=model, + instruction=instruction, + config=final_config, + ) diff --git a/zhenxun/services/llm/config/__init__.py b/zhenxun/services/llm/config/__init__.py index f43792c4..345347a0 100644 --- a/zhenxun/services/llm/config/__init__.py +++ b/zhenxun/services/llm/config/__init__.py @@ -5,13 +5,12 @@ LLM 配置模块 """ from .generation import ( + CommonOverrides, + GenConfigBuilder, + LLMEmbeddingConfig, LLMGenerationConfig, - ModelConfigOverride, - apply_api_specific_mappings, - create_generation_config_from_kwargs, validate_override_params, ) -from .presets import CommonOverrides from .providers import ( LLMConfig, get_gemini_safety_threshold, @@ -23,11 +22,10 @@ from .providers import ( __all__ = [ "CommonOverrides", + "GenConfigBuilder", "LLMConfig", + "LLMEmbeddingConfig", "LLMGenerationConfig", - "ModelConfigOverride", - "apply_api_specific_mappings", - "create_generation_config_from_kwargs", "get_gemini_safety_threshold", "get_llm_config", "register_llm_configs", diff --git a/zhenxun/services/llm/config/generation.py b/zhenxun/services/llm/config/generation.py index 9f132b8a..45560b89 100644 --- 
a/zhenxun/services/llm/config/generation.py +++ b/zhenxun/services/llm/config/generation.py @@ -3,209 +3,397 @@ LLM 生成配置相关类和函数 """ from collections.abc import Callable -from typing import Any +from enum import Enum +from typing import Any, Literal +from typing_extensions import Self from pydantic import BaseModel, ConfigDict, Field from zhenxun.services.log import logger -from zhenxun.utils.pydantic_compat import model_dump +from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_validate -from ..types import LLMResponse -from ..types.enums import ResponseFormat +from ..types import LLMResponse, ResponseFormat, StructuredOutputStrategy from ..types.exceptions import LLMErrorCode, LLMException +from .providers import get_gemini_safety_threshold -class ModelConfigOverride(BaseModel): - """模型配置覆盖参数""" +class ReasoningEffort(str, Enum): + """推理努力程度枚举""" + + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + + +class ImageAspectRatio(str, Enum): + """图像宽高比枚举""" + + SQUARE = "1:1" + LANDSCAPE_16_9 = "16:9" + PORTRAIT_9_16 = "9:16" + LANDSCAPE_4_3 = "4:3" + PORTRAIT_3_4 = "3:4" + LANDSCAPE_3_2 = "3:2" + PORTRAIT_2_3 = "2:3" + + +class ImageResolution(str, Enum): + """图像分辨率/质量枚举""" + + STANDARD = "STANDARD" + HD = "HD" + + +class CoreConfig(BaseModel): + """核心生成参数""" temperature: float | None = Field( default=None, ge=0.0, le=2.0, description="生成温度" ) + """生成温度""" max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数") + """最大输出token数""" top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数") + """核采样参数""" top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数") + """Top-K采样参数""" frequency_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="频率惩罚" ) + """频率惩罚""" presence_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="存在惩罚" ) + """存在惩罚""" repetition_penalty: float | None = Field( default=None, ge=0.0, le=2.0, description="重复惩罚" ) - + """重复惩罚""" stop: list[str] | str | None = Field(default=None, description="停止序列") + """停止序列""" + + +class ReasoningConfig(BaseModel): + """推理能力配置""" + + effort: ReasoningEffort | None = Field( + default=None, description="推理努力程度 (适用于 O1, Gemini 3)" + ) + """推理努力程度 (适用于 O1, Gemini 3)""" + budget_tokens: int | None = Field( + default=None, description="具体的思考 Token 预算 (适用于 Gemini 2.5)" + ) + """具体的思考 Token 预算 (适用于 Gemini 2.5)""" + show_thoughts: bool | None = Field( + default=None, description="是否在响应中显式包含思维链内容" + ) + """是否在响应中显式包含思维链内容""" + + +class VisualConfig(BaseModel): + """视觉生成配置""" + + aspect_ratio: ImageAspectRatio | str | None = Field( + default=None, description="宽高比" + ) + """宽高比""" + resolution: ImageResolution | str | None = Field( + default=None, description="生成质量/分辨率" + ) + """生成质量/分辨率""" + media_resolution: str | None = Field( + default=None, + description="输入媒体的解析度 (Gemini 3+): 'LOW', 'MEDIUM', 'HIGH'", + ) + """输入媒体的解析度 (Gemini 3+): 'LOW', 'MEDIUM', 'HIGH'""" + style: str | None = Field( + default=None, description="图像风格 (如 DALL-E 3 vivid/natural)" + ) + """图像风格 (如 DALL-E 3 vivid/natural)""" + + +class OutputConfig(BaseModel): + """输出格式控制""" response_format: ResponseFormat | dict[str, Any] | None = Field( default=None, description="期望的响应格式" ) + """期望的响应格式""" response_mime_type: str | None = Field( default=None, description="响应MIME类型(Gemini专用)" ) + """响应MIME类型(Gemini专用)""" response_schema: dict[str, Any] | None = Field( default=None, description="JSON响应模式" ) - thinking_budget: float | None = Field( - default=None, ge=0.0, le=1.0, 
description="思考预算" - ) - include_thoughts: bool | None = Field( - default=None, description="是否在响应中包含思维过程(Gemini专用)" - ) - safety_settings: dict[str, str] | None = Field(default=None, description="安全设置") + """JSON响应模式""" response_modalities: list[str] | None = Field( - default=None, description="响应模态类型" + default=None, description="响应模态类型 (TEXT, IMAGE, AUDIO)" ) + """响应模态类型 (TEXT, IMAGE, AUDIO)""" + structured_output_strategy: StructuredOutputStrategy | str | None = Field( + default=None, description="结构化输出策略 (NATIVE/TOOL_CALL/PROMPT)" + ) + """结构化输出策略 (NATIVE/TOOL_CALL/PROMPT)""" - enable_code_execution: bool | None = Field( - default=None, description="是否启用代码执行" + +class SafetyConfig(BaseModel): + """安全设置""" + + safety_settings: dict[str, str] | None = Field(default=None, description="安全设置") + """安全设置""" + + +class ToolConfig(BaseModel): + """工具调用控制配置""" + + mode: Literal["AUTO", "ANY", "NONE"] = Field( + default="AUTO", + description="工具调用模式: AUTO(自动), ANY(强制), NONE(禁用)", ) - enable_grounding: bool | None = Field( - default=None, description="是否启用信息来源关联" + """工具调用模式: AUTO(自动), ANY(强制), NONE(禁用)""" + allowed_function_names: list[str] | None = Field( + default=None, + description="当 mode 为 ANY 时,允许调用的函数名称白名单", ) + """当 mode 为 ANY 时,允许调用的函数名称白名单""" + + +class LLMGenerationConfig(BaseModel): + """ + LLM 生成配置 + 采用组件化设计,不再扁平化参数。 + """ + + core: CoreConfig | None = Field(default=None, description="基础生成参数") + """基础生成参数""" + reasoning: ReasoningConfig | None = Field(default=None, description="推理能力配置") + """推理能力配置""" + visual: VisualConfig | None = Field(default=None, description="视觉生成配置") + """视觉生成配置""" + output: OutputConfig | None = Field(default=None, description="输出格式配置") + """输出格式配置""" + safety: SafetyConfig | None = Field(default=None, description="安全配置") + """安全配置""" + tool_config: ToolConfig | None = Field(default=None, description="工具调用策略配置") + """工具调用策略配置""" + enable_caching: bool | None = Field(default=None, description="是否启用响应缓存") + """是否启用响应缓存""" custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数") + """自定义参数""" validation_policy: dict[str, Any] | None = Field( default=None, description="声明式的响应验证策略 (例如: {'require_image': True})" ) + """声明式的响应验证策略 (例如: {'require_image': True})""" response_validator: Callable[[LLMResponse], None] | None = Field( - default=None, description="一个高级回调函数,用于验证响应,验证失败时应抛出异常" + default=None, + description="一个高级回调函数,用于验证响应,验证失败时应抛出异常", ) + """一个高级回调函数,用于验证响应,验证失败时应抛出异常""" model_config = ConfigDict(arbitrary_types_allowed=True) + @classmethod + def builder(cls) -> "GenConfigBuilder": + """创建一个新的配置构建器""" + return GenConfigBuilder() + def to_dict(self) -> dict[str, Any]: - """转换为字典,排除None值""" + """ + 转换为字典,排除None值。 + 注意:这会返回嵌套结构的字典。适配器需要处理这种嵌套。 + """ + return model_dump(self, exclude_none=True) - model_data = model_dump(self, exclude_none=True) + def merge_with(self, other: "LLMGenerationConfig | None") -> "LLMGenerationConfig": + """ + 与另一个配置对象进行深度合并。 + other 中的非 None 字段会覆盖当前配置中的对应字段。 + 返回一个新的配置对象,原对象不变。 + """ + if not other: + return model_copy(self, deep=True) - result = {} - for key, value in model_data.items(): - if key == "custom_params" and isinstance(value, dict): - result.update(value) - else: - result[key] = value + new_config = model_copy(self, deep=True) - return result + def _merge_component(base_comp, override_comp, comp_cls): + if override_comp is None: + return base_comp + if base_comp is None: + return override_comp + updates = model_dump(override_comp, exclude_none=True) + return model_copy(base_comp, update=updates) - 
def merge_with_base_config( + new_config.core = _merge_component(new_config.core, other.core, CoreConfig) + new_config.reasoning = _merge_component( + new_config.reasoning, other.reasoning, ReasoningConfig + ) + new_config.visual = _merge_component( + new_config.visual, other.visual, VisualConfig + ) + new_config.output = _merge_component( + new_config.output, other.output, OutputConfig + ) + new_config.safety = _merge_component( + new_config.safety, other.safety, SafetyConfig + ) + new_config.tool_config = _merge_component( + new_config.tool_config, other.tool_config, ToolConfig + ) + + if other.enable_caching is not None: + new_config.enable_caching = other.enable_caching + + if other.custom_params: + if new_config.custom_params is None: + new_config.custom_params = {} + new_config.custom_params.update(other.custom_params) + + if other.validation_policy: + if new_config.validation_policy is None: + new_config.validation_policy = {} + new_config.validation_policy.update(other.validation_policy) + + if other.response_validator: + new_config.response_validator = other.response_validator + + return new_config + + +class LLMEmbeddingConfig(BaseModel): + """Embedding 专用配置""" + + task_type: str | None = Field(default=None, description="任务类型 (Gemini/Jina)") + """任务类型 (Gemini/Jina)""" + output_dimensionality: int | None = Field( + default=None, description="输出维度/压缩维度 (Gemini/Jina/OpenAI)" + ) + """输出维度/压缩维度 (Gemini/Jina/OpenAI)""" + title: str | None = Field( + default=None, description="仅用于 Gemini RETRIEVAL_DOCUMENT 任务的标题" + ) + """仅用于 Gemini RETRIEVAL_DOCUMENT 任务的标题""" + encoding_format: str | None = Field( + default="float", description="编码格式 (float/base64)" + ) + """编码格式 (float/base64)""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class GenConfigBuilder: + """ + LLM 生成配置的语义化构建器。 + 设计原则:高频业务场景优先,低频参数命名空间化。 + """ + + def __init__(self): + self._config = LLMGenerationConfig() + + def _ensure_core(self) -> CoreConfig: + if self._config.core is None: + self._config.core = CoreConfig() + return self._config.core + + def _ensure_output(self) -> OutputConfig: + if self._config.output is None: + self._config.output = OutputConfig() + return self._config.output + + def _ensure_reasoning(self) -> ReasoningConfig: + if self._config.reasoning is None: + self._config.reasoning = ReasoningConfig() + return self._config.reasoning + + def as_json(self, schema: dict[str, Any] | None = None) -> Self: + """ + [高频] 强制模型输出 JSON 格式。 + """ + out = self._ensure_output() + out.response_format = ResponseFormat.JSON + if schema: + out.response_schema = schema + return self + + def enable_thinking( + self, budget_tokens: int = -1, show_thoughts: bool = False + ) -> Self: + """ + [高频] 启用模型的思考/推理能力 (如 Gemini 2.0 Flash Thinking, DeepSeek R1)。 + """ + reasoning = self._ensure_reasoning() + reasoning.budget_tokens = budget_tokens + reasoning.show_thoughts = show_thoughts + return self + + def config_core( self, - base_temperature: float | None = None, - base_max_tokens: int | None = None, - ) -> dict[str, Any]: - """与基础配置合并,覆盖参数优先""" - merged = {} + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + stop: list[str] | str | None = None, + frequency_penalty: float | None = None, + presence_penalty: float | None = None, + ) -> Self: + """ + [低频] 配置核心生成参数。 + """ + core = self._ensure_core() + if temperature is not None: + core.temperature = temperature + if max_tokens is not None: + core.max_tokens = max_tokens + if top_p is not 
None: + core.top_p = top_p + if top_k is not None: + core.top_k = top_k + if stop is not None: + core.stop = stop + if frequency_penalty is not None: + core.frequency_penalty = frequency_penalty + if presence_penalty is not None: + core.presence_penalty = presence_penalty + return self - if base_temperature is not None: - merged["temperature"] = base_temperature - if base_max_tokens is not None: - merged["max_tokens"] = base_max_tokens + def config_safety(self, settings: dict[str, str]) -> Self: + """ + [低频] 配置安全过滤设置。 + """ + if self._config.safety is None: + self._config.safety = SafetyConfig() + self._config.safety.safety_settings = settings + return self - override_dict = self.to_dict() - merged.update(override_dict) + def config_visual( + self, + aspect_ratio: ImageAspectRatio | str | None = None, + resolution: ImageResolution | str | None = None, + ) -> Self: + """ + [低频] 配置视觉生成参数 (DALL-E 3 / Gemini Imagen)。 + """ + if self._config.visual is None: + self._config.visual = VisualConfig() + if aspect_ratio: + self._config.visual.aspect_ratio = aspect_ratio + if resolution: + self._config.visual.resolution = resolution + return self - return merged + def set_custom_param(self, key: str, value: Any) -> Self: + """设置特定于厂商的自定义参数""" + if self._config.custom_params is None: + self._config.custom_params = {} + self._config.custom_params[key] = value + return self - -class LLMGenerationConfig(ModelConfigOverride): - """LLM 生成配置,继承模型配置覆盖参数""" - - def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]: - """转换为API参数,支持不同API类型的参数名映射""" - _ = model_name - params = {} - - if self.temperature is not None: - params["temperature"] = self.temperature - - if self.max_tokens is not None: - if api_type == "gemini": - params["maxOutputTokens"] = self.max_tokens - else: - params["max_tokens"] = self.max_tokens - - if api_type == "gemini": - if self.top_k is not None: - params["topK"] = self.top_k - if self.top_p is not None: - params["topP"] = self.top_p - else: - if self.top_k is not None: - params["top_k"] = self.top_k - if self.top_p is not None: - params["top_p"] = self.top_p - - if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]: - if self.frequency_penalty is not None: - params["frequency_penalty"] = self.frequency_penalty - if self.presence_penalty is not None: - params["presence_penalty"] = self.presence_penalty - - if self.repetition_penalty is not None: - if api_type == "openai": - logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") - else: - params["repetition_penalty"] = self.repetition_penalty - - if self.response_format is not None: - if isinstance(self.response_format, dict): - if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: - params["response_format"] = self.response_format - logger.debug( - f"为 {api_type} 使用自定义 response_format: " - f"{self.response_format}" - ) - elif self.response_format == ResponseFormat.JSON: - if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]: - params["response_format"] = {"type": "json_object"} - logger.debug(f"为 {api_type} 启用 JSON 对象输出模式") - elif api_type == "gemini": - params["responseMimeType"] = "application/json" - if self.response_schema: - params["responseSchema"] = self.response_schema - logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式") - - if self.custom_params: - custom_mapped = apply_api_specific_mappings(self.custom_params, api_type) - params.update(custom_mapped) - - if api_type == "gemini": - if ( - self.response_format != ResponseFormat.JSON - and 
self.response_mime_type is not None - ): - params["responseMimeType"] = self.response_mime_type - logger.debug( - f"使用显式设置的 responseMimeType: {self.response_mime_type}" - ) - - if self.response_schema is not None and "responseSchema" not in params: - params["responseSchema"] = self.response_schema - - if self.thinking_budget is not None or self.include_thoughts is not None: - thinking_config = params.setdefault("thinkingConfig", {}) - - if self.thinking_budget is not None: - max_budget = 24576 - budget_value = int(self.thinking_budget * max_budget) - thinking_config["thinkingBudget"] = budget_value - logger.debug( - f"已将 thinking_budget (float: {self.thinking_budget}) " - f"转换为 Gemini API 的整数格式: {budget_value}" - ) - - if self.include_thoughts is not None: - thinking_config["includeThoughts"] = self.include_thoughts - logger.debug(f"已设置 includeThoughts: {self.include_thoughts}") - - if self.safety_settings is not None: - params["safetySettings"] = self.safety_settings - if self.response_modalities is not None: - params["responseModalities"] = self.response_modalities - - logger.debug(f"为{api_type}转换配置参数: {len(params)}个参数") - return params + def build(self) -> LLMGenerationConfig: + """构建最终的配置对象""" + return self._config def validate_override_params( @@ -215,12 +403,12 @@ def validate_override_params( if override_config is None: return LLMGenerationConfig() + if isinstance(override_config, LLMGenerationConfig): + return override_config + if isinstance(override_config, dict): try: - filtered_config = { - k: v for k, v in override_config.items() if v is not None - } - return LLMGenerationConfig(**filtered_config) + return model_validate(LLMGenerationConfig, override_config) except Exception as e: logger.warning(f"覆盖配置参数验证失败: {e}") raise LLMException( @@ -229,56 +417,107 @@ def validate_override_params( cause=e, ) - return override_config + raise LLMException( + f"不支持的配置类型: {type(override_config)}", + code=LLMErrorCode.CONFIGURATION_ERROR, + ) -def apply_api_specific_mappings( - params: dict[str, Any], api_type: str -) -> dict[str, Any]: - """应用API特定的参数映射""" - mapped_params = params.copy() +class CommonOverrides: + """常用的配置覆盖预设""" - if api_type == "gemini": - if "max_tokens" in mapped_params: - mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens") - if "top_k" in mapped_params: - mapped_params["topK"] = mapped_params.pop("top_k") - if "top_p" in mapped_params: - mapped_params["topP"] = mapped_params.pop("top_p") + @staticmethod + def gemini_json() -> LLMGenerationConfig: + """Gemini JSON模式:强制JSON输出""" + return LLMGenerationConfig( + core=CoreConfig(), + output=OutputConfig( + response_format=ResponseFormat.JSON, + response_mime_type="application/json", + ), + ) - unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"] - for param in unsupported: - if param in mapped_params: - logger.warning(f"Gemini 原生API不支持参数 '{param}',已忽略") - mapped_params.pop(param) + @staticmethod + def gemini_2_5_thinking(tokens: int = -1) -> LLMGenerationConfig: + """Gemini 2.5 思考模式:默认 -1 (动态思考),0 为禁用,>=1024 为固定预算""" + return LLMGenerationConfig( + core=CoreConfig(temperature=1.0), + reasoning=ReasoningConfig(budget_tokens=tokens, show_thoughts=True), + ) - elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]: - if "repetition_penalty" in mapped_params and api_type == "openai": - logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略") - mapped_params.pop("repetition_penalty") + @staticmethod + def gemini_3_thinking(level: str = "HIGH") -> LLMGenerationConfig: + 
"""Gemini 3 深度思考模式:使用思考等级""" + try: + effort = ReasoningEffort(level.upper()) + except ValueError: + effort = ReasoningEffort.HIGH - if "stop" in mapped_params: - stop_value = mapped_params["stop"] - if isinstance(stop_value, str): - mapped_params["stop"] = [stop_value] + return LLMGenerationConfig( + core=CoreConfig(), + reasoning=ReasoningConfig(effort=effort, show_thoughts=True), + ) - return mapped_params + @staticmethod + def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig: + """Gemini 结构化输出:自定义JSON模式""" + return LLMGenerationConfig( + core=CoreConfig(), + output=OutputConfig( + response_mime_type="application/json", response_schema=schema + ), + ) + @staticmethod + def gemini_safe() -> LLMGenerationConfig: + """Gemini 安全模式:使用配置的安全设置""" + threshold = get_gemini_safety_threshold() + return LLMGenerationConfig( + core=CoreConfig(), + safety=SafetyConfig( + safety_settings={ + "HARM_CATEGORY_HARASSMENT": threshold, + "HARM_CATEGORY_HATE_SPEECH": threshold, + "HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold, + "HARM_CATEGORY_DANGEROUS_CONTENT": threshold, + } + ), + ) -def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig: - """从关键字参数创建生成配置""" - model_fields = getattr(LLMGenerationConfig, "model_fields", {}) - known_fields = set(model_fields.keys()) - known_params = {} - custom_params = {} + @staticmethod + def gemini_code_execution() -> LLMGenerationConfig: + """Gemini 代码执行模式:启用代码执行功能""" + return LLMGenerationConfig( + core=CoreConfig(), + custom_params={"code_execution_timeout": 30}, + ) - for key, value in kwargs.items(): - if key in known_fields: - known_params[key] = value - else: - custom_params[key] = value + @staticmethod + def gemini_grounding() -> LLMGenerationConfig: + """Gemini 信息来源关联模式:启用Google搜索""" + return LLMGenerationConfig( + core=CoreConfig(), + custom_params={ + "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}} + }, + ) - if custom_params: - known_params["custom_params"] = custom_params + @staticmethod + def gemini_nano_banana(aspect_ratio: str = "16:9") -> LLMGenerationConfig: + """Gemini Nano Banana Pro:自定义比例生图""" + try: + ar = ImageAspectRatio(aspect_ratio) + except ValueError: + ar = ImageAspectRatio.LANDSCAPE_16_9 - return LLMGenerationConfig(**known_params) + return LLMGenerationConfig( + core=CoreConfig(), + visual=VisualConfig(aspect_ratio=ar), + ) + + @staticmethod + def gemini_high_res() -> LLMGenerationConfig: + """Gemini 3: 强制使用高解析度处理输入媒体""" + return LLMGenerationConfig( + visual=VisualConfig(media_resolution="HIGH", resolution=ImageResolution.HD) + ) diff --git a/zhenxun/services/llm/config/presets.py b/zhenxun/services/llm/config/presets.py deleted file mode 100644 index aa4b6c21..00000000 --- a/zhenxun/services/llm/config/presets.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -LLM 预设配置 - -提供常用的配置预设,特别是针对 Gemini 的高级功能。 -""" - -from typing import Any - -from .generation import LLMGenerationConfig - - -class CommonOverrides: - """常用的配置覆盖预设""" - - @staticmethod - def creative() -> LLMGenerationConfig: - """创意模式:高温度,鼓励创新""" - return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1) - - @staticmethod - def precise() -> LLMGenerationConfig: - """精确模式:低温度,确定性输出""" - return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0) - - @staticmethod - def balanced() -> LLMGenerationConfig: - """平衡模式:中等温度""" - return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0) - - @staticmethod - def concise(max_tokens: int = 100) -> LLMGenerationConfig: - 
"""简洁模式:限制输出长度""" - return LLMGenerationConfig( - temperature=0.3, - max_tokens=max_tokens, - stop=["\n\n", "。", "!", "?"], - ) - - @staticmethod - def detailed(max_tokens: int = 2000) -> LLMGenerationConfig: - """详细模式:鼓励详细输出""" - return LLMGenerationConfig( - temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1 - ) - - @staticmethod - def gemini_json() -> LLMGenerationConfig: - """Gemini JSON模式:强制JSON输出""" - return LLMGenerationConfig( - temperature=0.3, response_mime_type="application/json" - ) - - @staticmethod - def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig: - """Gemini 思考模式:使用思考预算""" - return LLMGenerationConfig(temperature=0.7, thinking_budget=budget) - - @staticmethod - def gemini_creative() -> LLMGenerationConfig: - """Gemini 创意模式:高温度创意输出""" - return LLMGenerationConfig(temperature=0.9, top_p=0.95) - - @staticmethod - def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig: - """Gemini 结构化输出:自定义JSON模式""" - return LLMGenerationConfig( - temperature=0.3, - response_mime_type="application/json", - response_schema=schema, - ) - - @staticmethod - def gemini_safe() -> LLMGenerationConfig: - """Gemini 安全模式:使用配置的安全设置""" - from .providers import get_gemini_safety_threshold - - threshold = get_gemini_safety_threshold() - return LLMGenerationConfig( - temperature=0.5, - safety_settings={ - "HARM_CATEGORY_HARASSMENT": threshold, - "HARM_CATEGORY_HATE_SPEECH": threshold, - "HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold, - "HARM_CATEGORY_DANGEROUS_CONTENT": threshold, - }, - ) - - @staticmethod - def gemini_multimodal() -> LLMGenerationConfig: - """Gemini 多模态模式:优化多模态处理""" - return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8) - - @staticmethod - def gemini_code_execution() -> LLMGenerationConfig: - """Gemini 代码执行模式:启用代码执行功能""" - return LLMGenerationConfig( - temperature=0.3, - max_tokens=4096, - enable_code_execution=True, - custom_params={"code_execution_timeout": 30}, - ) - - @staticmethod - def gemini_grounding() -> LLMGenerationConfig: - """Gemini 信息来源关联模式:启用Google搜索""" - return LLMGenerationConfig( - temperature=0.5, - max_tokens=4096, - enable_grounding=True, - custom_params={ - "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}} - }, - ) - - @staticmethod - def gemini_cached() -> LLMGenerationConfig: - """Gemini 缓存模式:启用响应缓存""" - return LLMGenerationConfig( - temperature=0.3, - max_tokens=2048, - enable_caching=True, - ) - - @staticmethod - def gemini_advanced() -> LLMGenerationConfig: - """Gemini 高级模式:启用所有高级功能""" - return LLMGenerationConfig( - temperature=0.5, - max_tokens=4096, - enable_code_execution=True, - enable_grounding=True, - enable_caching=True, - custom_params={ - "code_execution_timeout": 30, - "grounding_config": { - "dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"} - }, - }, - ) - - @staticmethod - def gemini_research() -> LLMGenerationConfig: - """Gemini 研究模式:思考+搜索+结构化输出""" - return LLMGenerationConfig( - temperature=0.6, - max_tokens=4096, - thinking_budget=0.8, - enable_grounding=True, - response_mime_type="application/json", - custom_params={ - "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}} - }, - ) - - @staticmethod - def gemini_analysis() -> LLMGenerationConfig: - """Gemini 分析模式:深度思考+详细输出""" - return LLMGenerationConfig( - temperature=0.4, - max_tokens=6000, - thinking_budget=0.9, - top_p=0.8, - ) - - @staticmethod - def gemini_fast_response() -> LLMGenerationConfig: - """Gemini 快速响应模式:低延迟+简洁输出""" - return LLMGenerationConfig( - temperature=0.3, - max_tokens=512, 
- top_p=0.8, - ) diff --git a/zhenxun/services/llm/config/providers.py b/zhenxun/services/llm/config/providers.py index f3895481..b6cc0c16 100644 --- a/zhenxun/services/llm/config/providers.py +++ b/zhenxun/services/llm/config/providers.py @@ -13,6 +13,7 @@ from zhenxun.configs.config import Config from zhenxun.configs.utils import parse_as from zhenxun.services.log import logger from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from zhenxun.utils.pydantic_compat import model_dump from ..core import key_store from ..tools import tool_provider_manager @@ -22,6 +23,39 @@ AI_CONFIG_GROUP = "AI" PROVIDERS_CONFIG_KEY = "PROVIDERS" +class DebugLogOptions(BaseModel): + """调试日志细粒度控制""" + + show_tools: bool = Field( + default=True, description="是否在日志中显示工具定义(JSON Schema)" + ) + show_schema: bool = Field( + default=True, description="是否在日志中显示结构化输出Schema(response_format)" + ) + show_safety: bool = Field( + default=True, description="是否在日志中显示安全设置(safetySettings)" + ) + + def __bool__(self) -> bool: + """支持 bool(debug_options) 的语法,方便兼容旧逻辑。""" + return self.show_tools or self.show_schema or self.show_safety + + +class ClientSettings(BaseModel): + """LLM 客户端通用设置""" + + timeout: int = Field(default=300, description="API请求超时时间(秒)") + max_retries: int = Field(default=3, description="请求失败时的最大重试次数") + retry_delay: int = Field(default=2, description="请求重试的基础延迟时间(秒)") + structured_retries: int = Field( + default=2, description="结构化生成校验失败时的最大重试次数 (IVR)" + ) + proxy: str | None = Field( + default=None, + description="网络代理,例如 http://127.0.0.1:7890", + ) + + class LLMConfig(BaseModel): """LLM 服务配置类""" @@ -29,20 +63,16 @@ class LLMConfig(BaseModel): default=None, description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)", ) - proxy: str | None = Field( - default=None, - description="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", - ) - timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)") - max_retries_llm: int = Field( - default=3, description="LLM服务请求失败时的最大重试次数" - ) - retry_delay_llm: int = Field( - default=2, description="LLM服务请求重试的基础延迟时间(秒)" + client_settings: ClientSettings = Field( + default_factory=ClientSettings, description="客户端连接与重试配置" ) providers: list[ProviderConfig] = Field( default_factory=list, description="配置多个 AI 服务提供商及其模型信息" ) + debug_log: DebugLogOptions | bool = Field( + default_factory=DebugLogOptions, + description="LLM请求日志详情开关。支持 bool (全开/全关) 或 dict (细粒度控制)。", + ) def get_provider_by_name(self, name: str) -> ProviderConfig | None: """根据名称获取提供商配置 @@ -226,36 +256,29 @@ def register_llm_configs(): ) Config.add_plugin_config( AI_CONFIG_GROUP, - "proxy", - llm_config.proxy, - help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", - type=str, + "client_settings", + model_dump(llm_config.client_settings), + help=( + "LLM客户端高级设置。\n" + "包含: timeout(超时秒数), max_retries(重试次数), " + "retry_delay(重试延迟), structured_retries(结构化生成重试), proxy(代理)" + ), + type=dict, ) Config.add_plugin_config( AI_CONFIG_GROUP, - "timeout", - llm_config.timeout, - help="LLM服务API请求超时时间(秒)", - type=int, - ) - Config.add_plugin_config( - AI_CONFIG_GROUP, - "max_retries_llm", - llm_config.max_retries_llm, - help="LLM服务请求失败时的最大重试次数", - type=int, - ) - Config.add_plugin_config( - AI_CONFIG_GROUP, - "retry_delay_llm", - llm_config.retry_delay_llm, - help="LLM服务请求重试的基础延迟时间(秒)", - type=int, + "debug_log", + {"show_tools": True, "show_schema": True, "show_safety": True}, + help=( + "LLM日志详情开关。示例: {'show_tools': True, 'show_schema': False, " + "'show_safety': False}" + ), + type=dict, ) Config.add_plugin_config( 
AI_CONFIG_GROUP, "gemini_safety_threshold", - "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_NONE", help=( "Gemini 安全过滤阈值 " "(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, " @@ -270,7 +293,20 @@ def register_llm_configs(): AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_default_providers(), - help="配置多个 AI 服务提供商及其模型信息", + help=( + "配置多个 AI 服务提供商及其模型信息。\n" + "注意:可以在特定模型配置下添加 'api_type' 以覆盖提供商的全局设置。\n" + "支持的 api_type 包括:\n" + "- 'openai': 标准 OpenAI 格式 (DeepSeek, SiliconFlow, Moonshot 等)\n" + "- 'gemini': Google Gemini API\n" + "- 'zhipu': 智谱 AI (GLM)\n" + "- 'ark': 字节跳动火山引擎 (Doubao)\n" + "- 'openrouter': OpenRouter 聚合平台\n" + "- 'openai_image': OpenAI 兼容的图像生成接口 (DALL-E)\n" + "- 'openai_responses': 支持新版 responses 格式的 OpenAI 兼容接口\n" + "- 'smart': 智能路由模式 (主要用于第三方中转场景,自动根据模型名" + "分发请求到 openai 或 gemini)" + ), default_value=[], type=list[ProviderConfig], ) @@ -278,15 +314,21 @@ def register_llm_configs(): @lru_cache(maxsize=1) def get_llm_config() -> LLMConfig: - """获取 LLM 配置实例,不再加载 MCP 工具配置""" + """获取 LLM 配置实例""" ai_config = get_ai_config() + raw_debug = ai_config.get("debug_log", False) + if isinstance(raw_debug, bool): + debug_log_val = DebugLogOptions( + show_tools=raw_debug, show_schema=raw_debug, show_safety=raw_debug + ) + else: + debug_log_val = raw_debug + config_data = { "default_model_name": ai_config.get("default_model_name"), - "proxy": ai_config.get("proxy"), - "timeout": ai_config.get("timeout", 180), - "max_retries_llm": ai_config.get("max_retries_llm", 3), - "retry_delay_llm": ai_config.get("retry_delay_llm", 2), + "client_settings": ai_config.get("client_settings", {}), + "debug_log": debug_log_val, PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []), } @@ -314,14 +356,14 @@ def validate_llm_config() -> tuple[bool, list[str]]: try: llm_config = get_llm_config() - if llm_config.timeout <= 0: + if llm_config.client_settings.timeout <= 0: errors.append("timeout 必须大于 0") - if llm_config.max_retries_llm < 0: - errors.append("max_retries_llm 不能小于 0") + if llm_config.client_settings.max_retries < 0: + errors.append("max_retries 不能小于 0") - if llm_config.retry_delay_llm <= 0: - errors.append("retry_delay_llm 必须大于 0") + if llm_config.client_settings.retry_delay <= 0: + errors.append("retry_delay 必须大于 0") if not llm_config.providers: errors.append("至少需要配置一个 AI 服务提供商") diff --git a/zhenxun/services/llm/core.py b/zhenxun/services/llm/core.py index 6ab846d0..03781cbe 100644 --- a/zhenxun/services/llm/core.py +++ b/zhenxun/services/llm/core.py @@ -254,7 +254,7 @@ class KeyStats: if total_calls == 0: return KeyStatus.UNUSED - if self.success_rate < 80: + if self.success_rate < 70: return KeyStatus.ERROR if total_calls >= 5 and self.avg_latency > 15000: @@ -292,96 +292,6 @@ class RetryConfig: self.key_rotation = key_rotation -async def with_smart_retry( - func, - *args, - retry_config: RetryConfig | None = None, - key_store: "KeyStatusStore | None" = None, - provider_name: str | None = None, - **kwargs: Any, -) -> Any: - """ - 智能重试装饰器 - 支持Key轮询和错误分类 - - 参数: - func: 要重试的异步函数。 - *args: 传递给函数的位置参数。 - retry_config: 重试配置。 - key_store: API密钥状态存储。 - provider_name: 提供商名称。 - **kwargs: 传递给函数的关键字参数。 - - 返回: - Any: 函数执行结果。 - """ - config = retry_config or RetryConfig() - last_exception: Exception | None = None - failed_keys: set[str] = set() - - model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None) - all_provider_keys = model_instance.api_keys if model_instance else [] - - for attempt in range(config.max_retries + 1): - try: - if config.key_rotation and "failed_keys" in func.__code__.co_varnames: - 
kwargs["failed_keys"] = failed_keys - - start_time = time.monotonic() - result = await func(*args, **kwargs) - latency = (time.monotonic() - start_time) * 1000 - - if key_store and isinstance(result, tuple) and len(result) == 2: - _, api_key_used = result - if api_key_used: - await key_store.record_success(api_key_used, latency) - return result - else: - return result - - except LLMException as e: - last_exception = e - api_key_in_use = e.details.get("api_key") - - if api_key_in_use: - failed_keys.add(api_key_in_use) - if key_store and provider_name and len(all_provider_keys) > 1: - status_code = e.details.get("status_code") - error_message = f"({e.code.name}) {e.message}" - await key_store.record_failure( - api_key_in_use, status_code, error_message - ) - - should_retry = _should_retry_llm_error(e, attempt, config.max_retries) - if not should_retry: - logger.error(f"不可重试的错误,停止重试: {e}") - raise - - if attempt < config.max_retries: - wait_time = config.retry_delay - if config.exponential_backoff: - wait_time *= 2**attempt - logger.warning( - f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}" - ) - await asyncio.sleep(wait_time) - else: - logger.error(f"重试{config.max_retries}次后仍然失败: {e}") - - except Exception as e: - last_exception = e - logger.error(f"非LLM异常,停止重试: {e}") - raise LLMException( - f"操作失败: {e}", - code=LLMErrorCode.GENERATION_FAILED, - cause=e, - ) - - if last_exception: - raise last_exception - else: - raise RuntimeError("重试函数未能正常执行且未捕获到异常") - - def _should_retry_llm_error( error: LLMException, attempt: int, max_retries: int ) -> bool: @@ -390,7 +300,9 @@ def _should_retry_llm_error( LLMErrorCode.MODEL_NOT_FOUND, LLMErrorCode.CONTEXT_LENGTH_EXCEEDED, LLMErrorCode.USER_LOCATION_NOT_SUPPORTED, + LLMErrorCode.INVALID_PARAMETER, LLMErrorCode.CONFIGURATION_ERROR, + LLMErrorCode.API_KEY_INVALID, } if error.code in non_retryable_errors: @@ -404,15 +316,12 @@ def _should_retry_llm_error( LLMErrorCode.RESPONSE_PARSE_ERROR, LLMErrorCode.GENERATION_FAILED, LLMErrorCode.CONTENT_FILTERED, - LLMErrorCode.API_KEY_INVALID, LLMErrorCode.API_QUOTA_EXCEEDED, } if error.code in retryable_errors: if error.code == LLMErrorCode.API_QUOTA_EXCEEDED: return attempt < min(2, max_retries) - elif error.code == LLMErrorCode.CONTENT_FILTERED: - return attempt < min(1, max_retries) return True return False @@ -558,14 +467,68 @@ class KeyStatusStore: now = time.time() cooldown_duration = 300 - if status_code in [401, 403, 404]: + location_not_supported = error_message and ( + "USER_LOCATION_NOT_SUPPORTED" in error_message + or "User location is not supported" in error_message + ) + if location_not_supported: + logger.warning( + f"API Key {key_id} 请求失败,原因是地区不支持 (Gemini)。" + " 这通常是代理节点问题,Key 本身可能是正常的。跳过冷却。" + ) + async with self._lock: + stats = self._key_stats.setdefault(api_key, KeyStats()) + stats.failure_count += 1 + stats.last_error_info = error_message[:256] + await self._save_to_file_internal() + return + + if error_message and ( + "API_QUOTA_EXCEEDED" in error_message + or "insufficient_quota" in error_message.lower() + ): + cooldown_duration = 3600 + logger.warning(f"API Key {key_id} 额度耗尽,冷却 1 小时。") + + is_key_invalid = status_code == 401 or ( + status_code == 400 + and error_message + and ( + "API_KEY_INVALID" in error_message + or "API key not valid" in error_message + ) + ) + + if is_key_invalid: cooldown_duration = 31536000 log_level = "error" log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}" + elif status_code == 403: + cooldown_duration = 3600 + log_level = "warning" + log_message = 
f"API密钥权限不足或地区不支持(403),冷却1小时: {key_id}" + elif status_code == 404: + log_level = "error" + log_message = "API请求返回 404 (未找到),可能是模型名称错误或接口地址" + f"错误,不冷却密钥: {key_id}" + elif status_code == 422: + cooldown_duration = 0 + log_level = "warning" + log_message = f"API请求无法处理(422),可能是生成故障,不冷却密钥: {key_id}" elif status_code == 429: cooldown_duration = 60 log_level = "warning" log_message = f"API密钥被限流,冷却60秒: {key_id}" + elif error_message and ( + "ConnectError" in error_message + or "NetworkError" in error_message + or "Connection refused" in error_message + or "RemoteProtocolError" in error_message + or "ProxyError" in error_message + ): + cooldown_duration = 0 + log_level = "warning" + log_message = f"网络连接层异常(代理/DNS),不冷却密钥: {key_id}" else: log_level = "warning" log_message = f"API密钥遇到临时性错误,冷却{cooldown_duration}秒: {key_id}" diff --git a/zhenxun/services/llm/executor.py b/zhenxun/services/llm/executor.py deleted file mode 100644 index b731e520..00000000 --- a/zhenxun/services/llm/executor.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -LLM 轻量级工具执行器 - -提供驱动 LLM 与本地函数工具之间交互的核心循环。 -""" - -import asyncio -from enum import Enum -import json -from typing import Any - -from pydantic import BaseModel, Field - -from zhenxun.services.log import logger -from zhenxun.utils.decorator.retry import Retry -from zhenxun.utils.pydantic_compat import model_dump - -from .service import LLMModel -from .types import ( - LLMErrorCode, - LLMException, - LLMMessage, - ToolExecutable, - ToolResult, -) - - -class ExecutionConfig(BaseModel): - """ - 轻量级执行器的配置。 - """ - - max_cycles: int = Field(default=5, description="工具调用循环的最大次数。") - - -class ToolErrorType(str, Enum): - """结构化工具错误的类型枚举。""" - - TOOL_NOT_FOUND = "ToolNotFound" - INVALID_ARGUMENTS = "InvalidArguments" - EXECUTION_ERROR = "ExecutionError" - USER_CANCELLATION = "UserCancellation" - - -class ToolErrorResult(BaseModel): - """一个结构化的工具执行错误模型,用于返回给 LLM。""" - - error_type: ToolErrorType = Field(..., description="错误的类型。") - message: str = Field(..., description="对错误的详细描述。") - is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。") - - def model_dump(self, **kwargs): - return model_dump(self, **kwargs) - - -def _is_exception_retryable(e: Exception) -> bool: - """判断一个异常是否应该触发重试。""" - if isinstance(e, LLMException): - retryable_codes = { - LLMErrorCode.API_REQUEST_FAILED, - LLMErrorCode.API_TIMEOUT, - LLMErrorCode.API_RATE_LIMITED, - } - return e.code in retryable_codes - return True - - -class LLMToolExecutor: - """ - 一个通用的执行器,负责驱动 LLM 与工具之间的多轮交互。 - """ - - def __init__(self, model: LLMModel): - self.model = model - - async def run( - self, - messages: list[LLMMessage], - tools: dict[str, ToolExecutable], - config: ExecutionConfig | None = None, - ) -> list[LLMMessage]: - """ - 执行完整的思考-行动循环。 - """ - effective_config = config or ExecutionConfig() - execution_history = list(messages) - - for i in range(effective_config.max_cycles): - response = await self.model.generate_response( - execution_history, tools=tools - ) - - assistant_message = LLMMessage( - role="assistant", - content=response.text, - tool_calls=response.tool_calls, - ) - execution_history.append(assistant_message) - - if not response.tool_calls: - logger.info("✅ LLMToolExecutor:模型未请求工具调用,执行结束。") - return execution_history - - logger.info( - f"🛠️ LLMToolExecutor:模型请求并行调用 {len(response.tool_calls)} 个工具" - ) - tool_results = await self._execute_tools_parallel_safely( - response.tool_calls, - tools, - ) - execution_history.extend(tool_results) - - raise LLMException( - f"超过最大工具调用循环次数 
({effective_config.max_cycles})。", - code=LLMErrorCode.GENERATION_FAILED, - ) - - async def _execute_single_tool_safely( - self, tool_call: Any, available_tools: dict[str, ToolExecutable] - ) -> tuple[Any, ToolResult]: - """安全地执行单个工具调用。""" - tool_name = tool_call.function.name - arguments = {} - - try: - if tool_call.function.arguments: - arguments = json.loads(tool_call.function.arguments) - except json.JSONDecodeError as e: - error_result = ToolErrorResult( - error_type=ToolErrorType.INVALID_ARGUMENTS, - message=f"参数解析失败: {e}", - is_retryable=False, - ) - return tool_call, ToolResult(output=model_dump(error_result)) - - try: - executable = available_tools.get(tool_name) - if not executable: - raise LLMException( - f"Tool '{tool_name}' not found.", - code=LLMErrorCode.CONFIGURATION_ERROR, - ) - - @Retry.simple( - stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None - ) - async def execute_with_retry(): - return await executable.execute(**arguments) - - execution_result = await execute_with_retry() - if execution_result is None: - raise LLMException("工具执行在多次重试后仍然失败。") - - return tool_call, execution_result - except Exception as e: - error_type = ToolErrorType.EXECUTION_ERROR - is_retryable = _is_exception_retryable(e) - if ( - isinstance(e, LLMException) - and e.code == LLMErrorCode.CONFIGURATION_ERROR - ): - error_type = ToolErrorType.TOOL_NOT_FOUND - is_retryable = False - - error_result = ToolErrorResult( - error_type=error_type, message=str(e), is_retryable=is_retryable - ) - return tool_call, ToolResult(output=model_dump(error_result)) - - async def _execute_tools_parallel_safely( - self, - tool_calls: list[Any], - available_tools: dict[str, ToolExecutable], - ) -> list[LLMMessage]: - """并行执行所有工具调用,并对每个调用的错误进行隔离。""" - if not tool_calls: - return [] - - tasks = [ - self._execute_single_tool_safely(call, available_tools) - for call in tool_calls - ] - results = await asyncio.gather(*tasks) - - tool_messages = [ - LLMMessage.tool_response( - tool_call_id=original_call.id, - function_name=original_call.function.name, - result=result.output, - ) - for original_call, result in results - ] - return tool_messages diff --git a/zhenxun/services/llm/manager.py b/zhenxun/services/llm/manager.py index dbe6d675..e69f9cec 100644 --- a/zhenxun/services/llm/manager.py +++ b/zhenxun/services/llm/manager.py @@ -13,15 +13,19 @@ from zhenxun.services.log import logger from zhenxun.utils.pydantic_compat import dump_json_safely from .config import validate_override_params -from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config +from .config.generation import LLMGenerationConfig +from .config.providers import ( + AI_CONFIG_GROUP, + PROVIDERS_CONFIG_KEY, + get_ai_config, + get_llm_config, +) from .core import http_client_manager, key_store from .service import LLMModel from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig from .types.capabilities import get_model_capabilities DEFAULT_MODEL_NAME_KEY = "default_model_name" -PROXY_KEY = "proxy" -TIMEOUT_KEY = "timeout" _model_cache: dict[str, tuple[LLMModel, float]] = {} _cache_ttl = 3600 @@ -39,7 +43,8 @@ def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | def _make_cache_key( - provider_model_name: str | None, override_config: dict | None + provider_model_name: str | None, + override_config: dict | LLMGenerationConfig | None, ) -> str: """生成缓存键""" config_str = ( @@ -115,11 +120,12 @@ def get_default_api_base_for_type(api_type: str) -> str | None: """根据API类型获取默认的API基础地址""" 
default_api_bases = { "openai": "https://api.openai.com", - "deepseek": "https://api.deepseek.com", + "deepseek": "https://api.deepseek.com/beta", "zhipu": "https://open.bigmodel.cn", "gemini": "https://generativelanguage.googleapis.com", "openrouter": "https://openrouter.ai/api", - "general_openai_compat": None, + "smart": None, + "openai_responses": None, } return default_api_bases.get(api_type) @@ -244,7 +250,7 @@ def list_embedding_models() -> list[dict[str, Any]]: async def get_model_instance( provider_model_name: str | None = None, - override_config: dict[str, Any] | None = None, + override_config: dict[str, Any] | LLMGenerationConfig | None = None, ) -> LLMModel: """ 根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本) @@ -303,21 +309,20 @@ async def get_model_instance( model_detail_found.is_embedding_model = capabilities.is_embedding_model - ai_config = get_ai_config() - global_proxy_setting = ai_config.get(PROXY_KEY) + llm_config = get_llm_config() + client_settings = llm_config.client_settings default_timeout = ( provider_config_found.timeout if provider_config_found.timeout is not None - else 180 + else client_settings.timeout ) - global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout) config_for_http_client = ProviderConfig( name=provider_config_found.name, api_key=provider_config_found.api_key, models=provider_config_found.models, - timeout=global_timeout_setting, - proxy=global_proxy_setting, + timeout=default_timeout, + proxy=client_settings.proxy, api_base=provider_config_found.api_base, api_type=provider_config_found.api_type, openai_compat=provider_config_found.openai_compat, diff --git a/zhenxun/services/llm/memory.py b/zhenxun/services/llm/memory.py deleted file mode 100644 index d983090d..00000000 --- a/zhenxun/services/llm/memory.py +++ /dev/null @@ -1,55 +0,0 @@ -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any - -from .types import LLMMessage - - -class BaseMemory(ABC): - """ - 记忆系统的抽象基类。 - 定义了任何记忆后端都必须实现的接口。 - """ - - @abstractmethod - async def get_history(self, session_id: str) -> list[LLMMessage]: - """根据会话ID获取历史记录。""" - raise NotImplementedError - - @abstractmethod - async def add_message(self, session_id: str, message: LLMMessage) -> None: - """向指定会话添加一条消息。""" - raise NotImplementedError - - @abstractmethod - async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: - """向指定会话添加多条消息。""" - raise NotImplementedError - - @abstractmethod - async def clear_history(self, session_id: str) -> None: - """清空指定会话的历史记录。""" - raise NotImplementedError - - -class InMemoryMemory(BaseMemory): - """ - 一个简单的、默认的内存记忆后端。 - 将历史记录存储在进程内存中的字典里。 - """ - - def __init__(self, **kwargs: Any): - self._history: dict[str, list[LLMMessage]] = defaultdict(list) - - async def get_history(self, session_id: str) -> list[LLMMessage]: - return self._history.get(session_id, []).copy() - - async def add_message(self, session_id: str, message: LLMMessage) -> None: - self._history[session_id].append(message) - - async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: - self._history[session_id].extend(messages) - - async def clear_history(self, session_id: str) -> None: - if session_id in self._history: - del self._history[session_id] diff --git a/zhenxun/services/llm/service.py b/zhenxun/services/llm/service.py index 5b95bdf4..ecefd3a0 100644 --- a/zhenxun/services/llm/service.py +++ b/zhenxun/services/llm/service.py @@ -5,41 +5,87 @@ LLM 模型实现类 """ from abc import ABC, abstractmethod 
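下方 service.py 引入的中间件管道通过 add_middleware 暴露了扩展点:用户注册的中间件位于管道最外层,先于 Retry -> Logging -> KeySelection -> Network 执行。以下是一个最小使用草稿,仅作示意、不属于本补丁本身;其中 TimingMiddleware、demo 为虚构名称,"Provider/Model" 需替换为实际配置的 "ProviderName/ModelName" 标识:

    import time

    from zhenxun.services.llm.manager import get_model_instance
    from zhenxun.services.llm.service import (
        BaseLLMMiddleware,
        LLMContext,
        NextCall,
    )
    from zhenxun.services.llm.types import LLMResponse


    class TimingMiddleware(BaseLLMMiddleware):
        """示例:统计单次请求总耗时,并写入 runtime_state(虚构示例)。"""

        async def __call__(
            self, context: LLMContext, next_call: NextCall
        ) -> LLMResponse:
            start = time.monotonic()
            try:
                return await next_call(context)
            finally:
                # runtime_state 是 LLMContext 提供的运行时临时状态存储
                context.runtime_state["elapsed_ms"] = (
                    time.monotonic() - start
                ) * 1000


    async def demo() -> None:
        model = await get_model_instance("Provider/Model")
        model.add_middleware(TimingMiddleware())
        # 此后经由 model.generate_response(...) 发起的请求都会经过该中间件

由于用户中间件包裹在 RetryMiddleware 之外,上述计时覆盖的是包含全部重试在内的总耗时。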
+import asyncio from collections.abc import Awaitable, Callable import json -from typing import Any, TypeVar +import re +import time +from typing import Any, Literal, TypeVar, cast -from pydantic import BaseModel +import httpx +from pydantic import BaseModel, ConfigDict, Field from zhenxun.services.log import logger +from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.log_sanitizer import sanitize_for_logging from zhenxun.utils.pydantic_compat import dump_json_safely -from .adapters.base import RequestData +from .adapters.base import BaseAdapter, RequestData, process_image_data from .config import LLMGenerationConfig -from .config.providers import get_ai_config +from .config.generation import LLMEmbeddingConfig +from .config.providers import get_llm_config from .core import ( KeyStatusStore, LLMHttpClient, RetryConfig, + _should_retry_llm_error, http_client_manager, - with_smart_retry, ) from .types import ( - EmbeddingTaskType, LLMErrorCode, LLMException, LLMMessage, LLMResponse, + LLMToolCall, ModelDetail, ProviderConfig, - ToolExecutable, + ToolChoice, ) from .types.capabilities import ModelCapabilities, ModelModality T = TypeVar("T", bound=BaseModel) +class LLMContext(BaseModel): + """LLM 执行上下文,用于在中间件管道中传递请求状态""" + + messages: list[LLMMessage] + config: LLMGenerationConfig | LLMEmbeddingConfig + tools: list[Any] | None + tool_choice: str | dict[str, Any] | ToolChoice | None + timeout: float | None + extra: dict[str, Any] = Field(default_factory=dict) + request_type: Literal["generation", "embedding"] = "generation" + runtime_state: dict[str, Any] = Field( + default_factory=dict, + description="中间件运行时的临时状态存储(api_key, retry_count等)", + ) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +NextCall = Callable[[LLMContext], Awaitable[LLMResponse]] +LLMMiddleware = Callable[[LLMContext, NextCall], Awaitable[LLMResponse]] + + +class BaseLLMMiddleware(ABC): + """LLM 中间件抽象基类""" + + @abstractmethod + async def __call__(self, context: LLMContext, next_call: NextCall) -> LLMResponse: + """ + 执行中间件逻辑 + + Args: + context: 请求上下文,包含配置和运行时状态 + next_call: 调用链中的下一个处理函数 + + Returns: + LLMResponse: 模型响应结果 + """ + pass + + class LLMModelBase(ABC): """LLM模型抽象基类""" @@ -48,9 +94,9 @@ class LLMModelBase(ABC): self, messages: list[LLMMessage], config: LLMGenerationConfig | None = None, - tools: dict[str, ToolExecutable] | None = None, - tool_choice: str | dict[str, Any] | None = None, - **kwargs: Any, + tools: list[Any] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, + timeout: float | None = None, ) -> LLMResponse: """生成高级响应""" pass @@ -59,8 +105,7 @@ class LLMModelBase(ABC): async def generate_embeddings( self, texts: list[str], - task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, - **kwargs: Any, + config: LLMEmbeddingConfig, ) -> list[list[float]]: """生成文本嵌入向量""" pass @@ -98,35 +143,115 @@ class LLMModel(LLMModelBase): self.max_tokens = model_detail.max_tokens self._is_closed = False + self._ref_count = 0 + self._middlewares: list[LLMMiddleware] = [] + def _has_modality(self, modality: ModelModality, is_input: bool = True) -> bool: + target_set = ( + self.capabilities.input_modalities + if is_input + else self.capabilities.output_modalities + ) + return modality in target_set + + @property def can_process_images(self) -> bool: """检查模型是否支持图片作为输入。""" - return ModelModality.IMAGE in self.capabilities.input_modalities + return self._has_modality(ModelModality.IMAGE) + @property def can_process_video(self) -> bool: 
"""检查模型是否支持视频作为输入。""" - return ModelModality.VIDEO in self.capabilities.input_modalities + return self._has_modality(ModelModality.VIDEO) + @property def can_process_audio(self) -> bool: """检查模型是否支持音频作为输入。""" - return ModelModality.AUDIO in self.capabilities.input_modalities + return self._has_modality(ModelModality.AUDIO) + @property def can_generate_images(self) -> bool: """检查模型是否支持生成图片。""" - return ModelModality.IMAGE in self.capabilities.output_modalities + return self._has_modality(ModelModality.IMAGE, is_input=False) + @property def can_generate_audio(self) -> bool: """检查模型是否支持生成音频 (TTS)。""" - return ModelModality.AUDIO in self.capabilities.output_modalities - - def can_use_tools(self) -> bool: - """检查模型是否支持工具调用/函数调用。""" - return self.capabilities.supports_tool_calling + return self._has_modality(ModelModality.AUDIO, is_input=False) + @property def is_embedding_model(self) -> bool: """检查这是否是一个嵌入模型。""" return self.capabilities.is_embedding_model + def add_middleware(self, middleware: LLMMiddleware) -> None: + """注册一个中间件到处理管道的最外层""" + self._middlewares.append(middleware) + + def _build_pipeline(self) -> NextCall: + """ + 构建完整的中间件调用链。顺序为: + 用户自定义中间件 -> Retry -> Logging -> KeySelection -> Network (终结者) + """ + from .adapters import get_adapter_for_api_type + + client_settings = get_llm_config().client_settings + retry_config = RetryConfig( + max_retries=client_settings.max_retries, + retry_delay=client_settings.retry_delay, + ) + adapter = get_adapter_for_api_type(self.api_type) + + network_middleware = NetworkRequestMiddleware(self, adapter) + + async def terminal_handler(ctx: LLMContext) -> LLMResponse: + async def _noop(_: LLMContext) -> LLMResponse: + raise RuntimeError("NetworkRequestMiddleware 不应调用 next_call") + + return await network_middleware(ctx, _noop) + + def _wrap(middleware: LLMMiddleware, next_call: NextCall) -> NextCall: + async def _handler(inner_ctx: LLMContext) -> LLMResponse: + return await middleware(inner_ctx, next_call) + + return _handler + + handler: NextCall = terminal_handler + handler = _wrap( + KeySelectionMiddleware(self.key_store, self.provider_name, self.api_keys), + handler, + ) + handler = _wrap( + LoggingMiddleware(self.provider_name, self.model_name), + handler, + ) + handler = _wrap( + RetryMiddleware(retry_config, self.key_store), + handler, + ) + + for middleware in reversed(self._middlewares): + handler = _wrap(middleware, handler) + + return handler + + def _get_effective_api_type(self) -> str: + """ + 获取实际生效的 API 类型。 + 主要用于 Smart 模式下,判断日志净化应该使用哪种格式。 + """ + if self.api_type != "smart": + return self.api_type + + if self.model_detail.api_type: + return self.model_detail.api_type + if ( + "gemini" in self.model_name.lower() + and "openai" not in self.model_name.lower() + ): + return "gemini" + return "openai" + async def _get_http_client(self) -> LLMHttpClient: """获取HTTP客户端""" if self.http_client.is_closed: @@ -163,307 +288,6 @@ class LLMModel(LLMModelBase): return selected_key - async def _perform_api_call( - self, - prepare_request_func: Callable[[str], Awaitable["RequestData"]], - parse_response_func: Callable[[dict[str, Any]], Any], - http_client: "LLMHttpClient", - failed_keys: set[str] | None = None, - log_context: str = "API", - ) -> tuple[Any, str]: - """执行API调用的通用核心方法""" - api_key = await self._select_api_key(failed_keys) - - try: - request_data = await prepare_request_func(api_key) - - logger.info( - f"🌐 发起LLM请求 - 模型: {self.provider_name}/{self.model_name} " - f"[{log_context}]" - ) - logger.debug(f"📡 请求URL: {request_data.url}") - 
masked_key = ( - f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}" - ) - logger.debug(f"🔑 API密钥: {masked_key}") - logger.debug(f"📋 请求头: {dict(request_data.headers)}") - - sanitizer_req_context_map = {"gemini": "gemini_request"} - sanitizer_req_context = sanitizer_req_context_map.get( - self.api_type, "openai_request" - ) - sanitized_body = sanitize_for_logging( - request_data.body, context=sanitizer_req_context - ) - request_body_str = dump_json_safely( - sanitized_body, ensure_ascii=False, indent=2 - ) - logger.debug(f"📦 请求体: {request_body_str}") - - http_response = await http_client.post( - request_data.url, - headers=request_data.headers, - content=dump_json_safely(request_data.body, ensure_ascii=False), - ) - - logger.debug(f"📥 响应状态码: {http_response.status_code}") - logger.debug(f"📄 响应头: {dict(http_response.headers)}") - - response_bytes = await http_response.aread() - logger.debug(f"📦 响应体已完整读取 ({len(response_bytes)} bytes)") - - if http_response.status_code != 200: - error_text = response_bytes.decode("utf-8", errors="ignore") - logger.error( - f"❌ HTTP请求失败: {http_response.status_code} - {error_text} " - f"[{log_context}]" - ) - logger.debug(f"💥 完整错误响应: {error_text}") - - await self.key_store.record_failure( - api_key, http_response.status_code, error_text - ) - - if http_response.status_code in [401, 403]: - error_code = LLMErrorCode.API_KEY_INVALID - elif http_response.status_code == 429: - error_code = LLMErrorCode.API_RATE_LIMITED - elif http_response.status_code in [402, 413]: - error_code = LLMErrorCode.API_QUOTA_EXCEEDED - else: - error_code = LLMErrorCode.API_REQUEST_FAILED - - raise LLMException( - f"HTTP请求失败: {http_response.status_code}", - code=error_code, - details={ - "status_code": http_response.status_code, - "response": error_text, - "api_key": api_key, - }, - ) - - try: - response_json = json.loads(response_bytes) - - sanitizer_context_map = {"gemini": "gemini_response"} - sanitizer_context = sanitizer_context_map.get( - self.api_type, "openai_response" - ) - - sanitized_for_log = sanitize_for_logging( - response_json, context=sanitizer_context - ) - - response_json_str = json.dumps( - sanitized_for_log, ensure_ascii=False, indent=2 - ) - logger.debug(f"📋 响应JSON: {response_json_str}") - parsed_data = parse_response_func(response_json) - except Exception as e: - logger.error(f"解析 {log_context} 响应失败: {e}", e=e) - await self.key_store.record_failure(api_key, None, str(e)) - if isinstance(e, LLMException): - raise - else: - raise LLMException( - f"解析API {log_context} 响应失败: {e}", - code=LLMErrorCode.RESPONSE_PARSE_ERROR, - cause=e, - ) - - logger.info(f"🎯 LLM响应解析完成 [{log_context}]") - return parsed_data, api_key - - except LLMException: - raise - except Exception as e: - error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}" - logger.error(error_log_msg, e=e) - await self.key_store.record_failure(api_key, None, str(e)) - raise LLMException( - error_log_msg, - code=LLMErrorCode.GENERATION_FAILED - if log_context == "Generation" - else LLMErrorCode.EMBEDDING_FAILED, - cause=e, - ) - - async def _execute_embedding_request( - self, - adapter, - texts: list[str], - task_type: EmbeddingTaskType | str, - http_client: LLMHttpClient, - failed_keys: set[str] | None = None, - ) -> list[list[float]]: - """执行单次嵌入请求 - 供重试机制调用""" - - async def prepare_request(api_key: str) -> RequestData: - return adapter.prepare_embedding_request( - model=self, - api_key=api_key, - texts=texts, - task_type=task_type, - ) - - def parse_response(response_json: dict[str, Any]) -> 
list[list[float]]: - adapter.validate_embedding_response(response_json) - return adapter.parse_embedding_response(response_json) - - parsed_data, _api_key_used = await self._perform_api_call( - prepare_request_func=prepare_request, - parse_response_func=parse_response, - http_client=http_client, - failed_keys=failed_keys, - log_context="Embedding", - ) - return parsed_data - - async def _execute_with_smart_retry( - self, - adapter, - messages: list[LLMMessage], - config: LLMGenerationConfig | None, - tools: dict[str, ToolExecutable] | None, - tool_choice: str | dict[str, Any] | None, - http_client: LLMHttpClient, - ): - """智能重试机制 - 使用统一的重试装饰器""" - ai_config = get_ai_config() - max_retries = ai_config.get("max_retries_llm", 3) - retry_delay = ai_config.get("retry_delay_llm", 2) - retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay) - - return await with_smart_retry( - self._execute_single_request, - adapter, - messages, - config, - tools, - tool_choice, - http_client, - retry_config=retry_config, - key_store=self.key_store, - provider_name=self.provider_name, - ) - - async def _execute_single_request( - self, - adapter, - messages: list[LLMMessage], - config: LLMGenerationConfig | None, - tools: dict[str, ToolExecutable] | None, - tool_choice: str | dict[str, Any] | None, - http_client: LLMHttpClient, - failed_keys: set[str] | None = None, - ) -> tuple[LLMResponse, str]: - """执行单次请求 - 供重试机制调用,直接返回 LLMResponse 和使用的 key""" - - async def prepare_request(api_key: str) -> RequestData: - return await adapter.prepare_advanced_request( - model=self, - api_key=api_key, - messages=messages, - config=config, - tools=tools, - tool_choice=tool_choice, - ) - - def parse_response(response_json: dict[str, Any]) -> LLMResponse: - response_data = adapter.parse_response( - model=self, - response_json=response_json, - is_advanced=True, - ) - from .types.models import LLMToolCall - - response_tool_calls = [] - if response_data.tool_calls: - for tc_data in response_data.tool_calls: - if isinstance(tc_data, LLMToolCall): - response_tool_calls.append(tc_data) - elif isinstance(tc_data, dict): - try: - response_tool_calls.append(LLMToolCall(**tc_data)) - except Exception as e: - logger.warning( - f"无法将工具调用数据转换为LLMToolCall: {tc_data}, " - f"error: {e}" - ) - else: - logger.warning(f"工具调用数据格式未知: {tc_data}") - - return LLMResponse( - text=response_data.text, - usage_info=response_data.usage_info, - images=response_data.images, - raw_response=response_data.raw_response, - tool_calls=response_tool_calls if response_tool_calls else None, - code_executions=response_data.code_executions, - grounding_metadata=response_data.grounding_metadata, - cache_info=response_data.cache_info, - ) - - parsed_data, api_key_used = await self._perform_api_call( - prepare_request_func=prepare_request, - parse_response_func=parse_response, - http_client=http_client, - failed_keys=failed_keys, - log_context="Generation", - ) - - if config: - if config.response_validator: - try: - config.response_validator(parsed_data) - except Exception as e: - raise LLMException( - f"响应内容未通过自定义验证器: {e}", - code=LLMErrorCode.API_RESPONSE_INVALID, - details={"validator_error": str(e)}, - cause=e, - ) from e - - policy = config.validation_policy - if policy: - if policy.get("require_image") and not parsed_data.images: - if self.api_type == "gemini" and parsed_data.raw_response: - usage_metadata = parsed_data.raw_response.get( - "usageMetadata", {} - ) - prompt_token_details = usage_metadata.get( - "promptTokensDetails", [] - ) - 
prompt_had_image = any( - detail.get("modality") == "IMAGE" - for detail in prompt_token_details - ) - - if prompt_had_image: - raise LLMException( - "响应验证失败:模型接收了图片输入但未生成图片。", - code=LLMErrorCode.API_RESPONSE_INVALID, - details={ - "policy": policy, - "text_response": parsed_data.text, - "raw_response": parsed_data.raw_response, - }, - ) - else: - logger.debug("Gemini提示词中未包含图片,跳过图片要求重试。") - else: - raise LLMException( - "响应验证失败:要求返回图片但未找到图片数据。", - code=LLMErrorCode.API_RESPONSE_INVALID, - details={ - "policy": policy, - "text_response": parsed_data.text, - }, - ) - - return parsed_data, api_key_used - async def close(self): """标记模型实例的当前使用周期结束""" if self._is_closed: @@ -481,108 +305,102 @@ class LLMModel(LLMModelBase): ) self._is_closed = False self._check_not_closed() + self._ref_count += 1 return self async def __aexit__(self, exc_type, exc_val, exc_tb): """异步上下文管理器出口""" _ = exc_type, exc_val, exc_tb - await self.close() + self._ref_count -= 1 + if self._ref_count <= 0: + self._ref_count = 0 + await self.close() def _check_not_closed(self): """检查实例是否已关闭""" if self._is_closed: raise RuntimeError(f"LLMModel实例已关闭: {self}") + async def _execute_core_generation(self, context: LLMContext) -> LLMResponse: + """ + [内核] 执行核心生成逻辑:构建管道并执行。 + 此方法作为中间件管道的终点被调用。 + """ + pipeline_handler = self._build_pipeline() + return await pipeline_handler(context) + async def generate_response( self, messages: list[LLMMessage], config: LLMGenerationConfig | None = None, - tools: dict[str, ToolExecutable] | None = None, - tool_choice: str | dict[str, Any] | None = None, - **kwargs: Any, + tools: list[Any] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, + timeout: float | None = None, ) -> LLMResponse: """ - 生成高级响应。 - 此方法现在只执行 *单次* LLM API 调用,并将结果(包括工具调用请求)返回。 + 生成高级响应 (支持中间件管道)。 """ self._check_not_closed() - from .adapters import get_adapter_for_api_type - from .config.generation import create_generation_config_from_kwargs + if self._generation_config and config: + final_request_config = self._generation_config.merge_with(config) + elif config: + final_request_config = config + else: + final_request_config = self._generation_config or LLMGenerationConfig() - final_request_config = self._generation_config or LLMGenerationConfig() - if kwargs: - kwargs_config = create_generation_config_from_kwargs(**kwargs) - merged_dict = final_request_config.to_dict() - merged_dict.update(kwargs_config.to_dict()) - final_request_config = LLMGenerationConfig(**merged_dict) + normalized_tools: list[Any] | None = None + if tools: + if isinstance(tools, dict): + normalized_tools = list(tools.values()) + elif isinstance(tools, list): + normalized_tools = tools + else: + normalized_tools = [tools] - if config is not None: - merged_dict = final_request_config.to_dict() - merged_dict.update(config.to_dict()) - final_request_config = LLMGenerationConfig(**merged_dict) - - adapter = get_adapter_for_api_type(self.api_type) - http_client = await self._get_http_client() - - response, _ = await self._execute_with_smart_retry( - adapter, - messages, - final_request_config, - tools, - tool_choice, - http_client, + context = LLMContext( + messages=messages, + config=final_request_config, + tools=normalized_tools, + tool_choice=tool_choice, + timeout=timeout, ) - return response + return await self._execute_core_generation(context) async def generate_embeddings( self, texts: list[str], - task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, - **kwargs: Any, + config: LLMEmbeddingConfig | None = 
None, ) -> list[list[float]]: """生成文本嵌入向量""" self._check_not_closed() if not texts: return [] - from .adapters import get_adapter_for_api_type + final_config = config or LLMEmbeddingConfig() - adapter = get_adapter_for_api_type(self.api_type) - if not adapter: + context = LLMContext( + messages=[], + config=final_config, + tools=None, + tool_choice=None, + timeout=None, + request_type="embedding", + extra={"texts": texts}, + ) + + pipeline = self._build_pipeline() + response = await pipeline(context) + embeddings = ( + response.cache_info.get("embeddings") if response.cache_info else None + ) + if embeddings is None: raise LLMException( - f"未找到适用于 API 类型 '{self.api_type}' 的嵌入适配器", - code=LLMErrorCode.CONFIGURATION_ERROR, + "嵌入请求未返回 embeddings 数据", + code=LLMErrorCode.EMBEDDING_FAILED, ) - - http_client = await self._get_http_client() - - ai_config = get_ai_config() - default_max_retries = ai_config.get("max_retries_llm", 3) - default_retry_delay = ai_config.get("retry_delay_llm", 2) - max_retries_embed = kwargs.get( - "max_retries_embed", max(1, default_max_retries // 2) - ) - retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2) - - retry_config = RetryConfig( - max_retries=max_retries_embed, - retry_delay=retry_delay_embed, - exponential_backoff=True, - key_rotation=True, - ) - - return await with_smart_retry( - self._execute_embedding_request, - adapter, - texts, - task_type, - http_client, - retry_config=retry_config, - key_store=self.key_store, - provider_name=self.provider_name, - ) + return embeddings def __str__(self) -> str: status = "closed" if self._is_closed else "active" @@ -594,3 +412,401 @@ class LLMModel(LLMModelBase): f"LLMModel(provider={self.provider_name}, model={self.model_name}, " f"api_type={self.api_type}, status={status})" ) + + +class RetryMiddleware(BaseLLMMiddleware): + """ + 重试中间件:处理异常捕获与重试循环 + """ + + def __init__(self, retry_config: RetryConfig, key_store: KeyStatusStore): + self.retry_config = retry_config + self.key_store = key_store + + async def __call__(self, context: LLMContext, next_call: NextCall) -> LLMResponse: + last_exception: Exception | None = None + total_attempts = self.retry_config.max_retries + 1 + + for attempt in range(total_attempts): + try: + context.runtime_state["attempt"] = attempt + 1 + return await next_call(context) + + except LLMException as e: + last_exception = e + api_key = context.runtime_state.get("api_key") + + if api_key: + status_code = e.details.get("status_code") + error_msg = f"({e.code.name}) {e.message}" + await self.key_store.record_failure(api_key, status_code, error_msg) + + if not _should_retry_llm_error( + e, attempt, self.retry_config.max_retries + ): + raise e + + if attempt == total_attempts - 1: + raise e + + wait_time = self.retry_config.retry_delay + if self.retry_config.exponential_backoff: + wait_time *= 2**attempt + + logger.warning( + f"请求失败,{wait_time:.2f}秒后重试" + f" (第{attempt + 1}/{self.retry_config.max_retries}次重试): {e}" + ) + await asyncio.sleep(wait_time) + + except Exception as e: + logger.error(f"非预期异常,停止重试: {e}", e=e) + raise e + + if last_exception: + raise last_exception + raise LLMException("重试循环异常结束") + + +class KeySelectionMiddleware(BaseLLMMiddleware): + """ + 密钥选择中间件:负责轮询获取可用 API Key + """ + + def __init__( + self, key_store: KeyStatusStore, provider_name: str, api_keys: list[str] + ): + self.key_store = key_store + self.provider_name = provider_name + self.api_keys = api_keys + self._failed_keys: set[str] = set() + + async def __call__(self, context: LLMContext, 
next_call: NextCall) -> LLMResponse: + selected_key = await self.key_store.get_next_available_key( + self.provider_name, self.api_keys, exclude_keys=self._failed_keys + ) + + if not selected_key: + raise LLMException( + f"提供商 {self.provider_name} 无可用 API Key", + code=LLMErrorCode.NO_AVAILABLE_KEYS, + ) + + context.runtime_state["api_key"] = selected_key + + try: + response = await next_call(context) + return response + except LLMException as e: + self._failed_keys.add(selected_key) + masked = f"{selected_key[:8]}..." + if isinstance(e.details, dict): + e.details["api_key"] = masked + raise e + + +class LoggingMiddleware(BaseLLMMiddleware): + """ + 日志中间件:负责请求和响应的日志记录与脱敏 + """ + + def __init__( + self, provider_name: str, model_name: str, log_context: str = "Generation" + ): + self.provider_name = provider_name + self.model_name = model_name + self.log_context = log_context + + async def __call__(self, context: LLMContext, next_call: NextCall) -> LLMResponse: + attempt = context.runtime_state.get("attempt", 1) + api_key = context.runtime_state.get("api_key", "unknown") + masked_key = f"{api_key[:8]}..." + + logger.info( + f"🌐 发起LLM请求 (尝试 {attempt}) - {self.provider_name}/{self.model_name} " + f"[{self.log_context}] Key: {masked_key}" + ) + + try: + start_time = time.monotonic() + response = await next_call(context) + duration = (time.monotonic() - start_time) * 1000 + logger.info(f"🎯 LLM响应成功 [{self.log_context}] 耗时: {duration:.2f}ms") + return response + except Exception as e: + logger.error(f"❌ 请求异常 [{self.log_context}]: {type(e).__name__} - {e}") + raise e + + +class NetworkRequestMiddleware(BaseLLMMiddleware): + """ + 网络请求中间件:执行 Adapter 转换和 HTTP 请求 + """ + + def __init__(self, model_instance: "LLMModel", adapter: "BaseAdapter"): + self.model = model_instance + self.http_client = model_instance.http_client + self.adapter = adapter + self.key_store = model_instance.key_store + + async def __call__(self, context: LLMContext, next_call: NextCall) -> LLMResponse: + api_key = context.runtime_state["api_key"] + + request_data: RequestData + gen_config: LLMGenerationConfig | None = None + embed_config: LLMEmbeddingConfig | None = None + + if context.request_type == "embedding": + embed_config = cast(LLMEmbeddingConfig, context.config) + texts = (context.extra or {}).get("texts", []) + request_data = self.adapter.prepare_embedding_request( + model=self.model, + api_key=api_key, + texts=texts, + config=embed_config, + ) + else: + gen_config = cast(LLMGenerationConfig, context.config) + request_data = await self.adapter.prepare_advanced_request( + model=self.model, + api_key=api_key, + messages=context.messages, + config=gen_config, + tools=context.tools, + tool_choice=context.tool_choice, + ) + + masked_key = ( + f"{api_key[:8]}...{api_key[-4:] if len(api_key) > 12 else '***'}" + if api_key + else "N/A" + ) + logger.debug(f"🔑 API密钥: {masked_key}") + logger.debug(f"📡 请求URL: {request_data.url}") + logger.debug(f"📋 请求头: {dict(request_data.headers)}") + + if self.model.api_type == "smart": + effective_type = self.model._get_effective_api_type() + sanitizer_req_context = f"{effective_type}_request" + else: + sanitizer_req_context = self.adapter.log_sanitization_context + sanitized_body = sanitize_for_logging( + request_data.body, context=sanitizer_req_context + ) + + if request_data.files and isinstance(sanitized_body, dict): + file_info: list[str] = [] + file_count = 0 + if isinstance(request_data.files, list): + file_count = len(request_data.files) + for key, value in request_data.files: + filename = 
( + value[0] + if isinstance(value, tuple) and len(value) > 0 + else "..." + ) + file_info.append(f"{key}='{filename}'") + elif isinstance(request_data.files, dict): + file_count = len(request_data.files) + file_info = list(request_data.files.keys()) + + sanitized_body["[MULTIPART_FILES]"] = f"Count: {file_count} | {file_info}" + + request_body_str = dump_json_safely( + sanitized_body, ensure_ascii=False, indent=2 + ) + logger.debug(f"📦 请求体: {request_body_str}") + + start_time = time.monotonic() + try: + http_response = await self.http_client.post( + request_data.url, + headers=request_data.headers, + content=dump_json_safely(request_data.body, ensure_ascii=False) + if not request_data.files + else None, + data=request_data.body if request_data.files else None, + files=request_data.files, + timeout=context.timeout, + ) + + logger.debug(f"📥 响应状态码: {http_response.status_code}") + + if exception := self.adapter.handle_http_error(http_response): + error_text = http_response.content.decode("utf-8", errors="ignore") + logger.debug(f"💥 完整错误响应: {error_text}") + await self.key_store.record_failure( + api_key, http_response.status_code, error_text + ) + raise exception + + response_bytes = await http_response.aread() + logger.debug(f"📦 响应体已完整读取 ({len(response_bytes)} bytes)") + + response_json = json.loads(response_bytes) + + sanitizer_resp_context = sanitizer_req_context.replace( + "_request", "_response" + ) + if sanitizer_resp_context == sanitizer_req_context: + sanitizer_resp_context = f"{sanitizer_req_context}_response" + + sanitized_response = sanitize_for_logging( + response_json, context=sanitizer_resp_context + ) + response_json_str = json.dumps( + sanitized_response, ensure_ascii=False, indent=2 + ) + logger.debug(f"📋 响应JSON: {response_json_str}") + + if context.request_type == "embedding": + self.adapter.validate_embedding_response(response_json) + embeddings = self.adapter.parse_embedding_response(response_json) + latency = (time.monotonic() - start_time) * 1000 + await self.key_store.record_success(api_key, latency) + + return LLMResponse( + text="", + raw_response=response_json, + cache_info={"embeddings": embeddings}, + ) + + response_data = self.adapter.parse_response( + self.model, response_json, is_advanced=True + ) + + should_rescue_image = ( + gen_config + and gen_config.validation_policy + and gen_config.validation_policy.get("require_image") + ) + if ( + should_rescue_image + and not response_data.images + and response_data.text + and gen_config + ): + markdown_matches = re.findall( + r"(!?\[.*?\]\((https?://[^\)]+)\))", response_data.text + ) + if markdown_matches: + logger.info( + f"检测到 {len(markdown_matches)} " + "个资源链接,尝试自动下载并清洗。" + ) + if response_data.images is None: + response_data.images = [] + + downloaded_urls = set() + for full_tag, url in markdown_matches: + try: + if url not in downloaded_urls: + content = await AsyncHttpx.get_content(url) + response_data.images.append(process_image_data(content)) + downloaded_urls.add(url) + response_data.text = response_data.text.replace( + full_tag, "" + ) + except Exception as exc: + logger.warning( + f"自动下载生成的图片失败: {url}, 错误: {exc}" + ) + response_data.text = response_data.text.strip() + + latency = (time.monotonic() - start_time) * 1000 + await self.key_store.record_success(api_key, latency) + + response_tool_calls: list[LLMToolCall] = [] + if response_data.tool_calls: + for tc_data in response_data.tool_calls: + if isinstance(tc_data, LLMToolCall): + response_tool_calls.append(tc_data) + elif isinstance(tc_data, dict): + 
try: + response_tool_calls.append(LLMToolCall(**tc_data)) + except Exception: + pass + + final_response = LLMResponse( + text=response_data.text, + content_parts=response_data.content_parts, + usage_info=response_data.usage_info, + images=response_data.images, + raw_response=response_data.raw_response, + tool_calls=response_tool_calls if response_tool_calls else None, + code_executions=response_data.code_executions, + grounding_metadata=response_data.grounding_metadata, + cache_info=response_data.cache_info, + thought_text=response_data.thought_text, + thought_signature=response_data.thought_signature, + ) + + if context.request_type == "generation" and gen_config: + if gen_config.response_validator: + try: + gen_config.response_validator(final_response) + except Exception as exc: + raise LLMException( + f"响应内容未通过自定义验证器: {exc}", + code=LLMErrorCode.API_RESPONSE_INVALID, + details={"validator_error": str(exc)}, + cause=exc, + ) from exc + + policy = gen_config.validation_policy + if policy: + effective_type = self.model._get_effective_api_type() + if policy.get("require_image") and not final_response.images: + if effective_type == "gemini" and response_data.raw_response: + usage_metadata = response_data.raw_response.get( + "usageMetadata", {} + ) + prompt_token_details = usage_metadata.get( + "promptTokensDetails", [] + ) + prompt_had_image = any( + detail.get("modality") == "IMAGE" + for detail in prompt_token_details + ) + + if prompt_had_image: + raise LLMException( + "响应验证失败:模型接收了图片输入但未生成图片。", + code=LLMErrorCode.API_RESPONSE_INVALID, + details={ + "policy": policy, + "text_response": final_response.text, + "raw_response": response_data.raw_response, + }, + ) + else: + logger.debug( + "Gemini提示词中未包含图片,跳过图片要求重试。" + ) + else: + raise LLMException( + "响应验证失败:要求返回图片但未找到图片数据。", + code=LLMErrorCode.API_RESPONSE_INVALID, + details={ + "policy": policy, + "text_response": final_response.text, + }, + ) + + return final_response + + except Exception as e: + if isinstance(e, LLMException): + raise e + + logger.error(f"解析响应失败或发生未知错误: {e}") + + if not isinstance(e, httpx.NetworkError | httpx.TimeoutException): + await self.key_store.record_failure(api_key, None, str(e)) + + raise LLMException( + f"网络请求异常: {type(e).__name__} - {e}", + code=LLMErrorCode.API_REQUEST_FAILED, + details={"api_key": masked_key}, + cause=e, + ) diff --git a/zhenxun/services/llm/session.py b/zhenxun/services/llm/session.py index 59937cf0..c7df1476 100644 --- a/zhenxun/services/llm/session.py +++ b/zhenxun/services/llm/session.py @@ -4,30 +4,34 @@ LLM 服务 - 会话客户端 提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。 """ +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Awaitable, Callable import copy from dataclasses import dataclass, field import json -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import uuid -from jinja2 import Environment -from nonebot.compat import type_validate_json +from jinja2 import Template +from nonebot.utils import is_coroutine_callable from nonebot_plugin_alconna.uniseg import UniMessage -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from zhenxun.services.log import logger -from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_json_schema +from zhenxun.utils.pydantic_compat import model_json_schema from .config import ( CommonOverrides, + GenConfigBuilder, + LLMEmbeddingConfig, LLMGenerationConfig, ) -from .config.providers import get_ai_config +from .config.generation import 
OutputConfig +from .config.providers import get_ai_config, get_llm_config from .manager import get_global_default_model_name, get_model_instance -from .memory import BaseMemory, InMemoryMemory -from .tools.manager import tool_provider_manager +from .tools import tool_provider_manager from .types import ( - EmbeddingTaskType, LLMContentPart, LLMErrorCode, LLMException, @@ -35,19 +39,28 @@ from .types import ( LLMResponse, ModelName, ResponseFormat, + StructuredOutputStrategy, + ToolChoice, ToolExecutable, ToolProvider, ) -from .utils import normalize_to_llm_messages +from .types.models import ( + GeminiCodeExecution, + GeminiGoogleSearch, +) +from .utils import ( + create_cot_wrapper, + normalize_to_llm_messages, + parse_and_validate_json, + should_apply_autocot, +) T = TypeVar("T", bound=BaseModel) -jinja_env = Environment(autoescape=False) - @dataclass class AIConfig: - """AI配置类 - [重构后] 简化版本""" + """AI配置类""" model: ModelName = None default_embedding_model: ModelName = None @@ -61,6 +74,98 @@ class AIConfig: self.model = ai_config.get("default_model_name") +class BaseMemory(ABC): + """记忆系统的抽象基类。""" + + @abstractmethod + async def get_history(self, session_id: str) -> list[LLMMessage]: + raise NotImplementedError + + @abstractmethod + async def add_message(self, session_id: str, message: LLMMessage) -> None: + raise NotImplementedError + + @abstractmethod + async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + raise NotImplementedError + + @abstractmethod + async def clear_history(self, session_id: str) -> None: + raise NotImplementedError + + +class InMemoryMemory(BaseMemory): + """一个简单的、默认的内存记忆后端。""" + + def __init__(self, max_messages: int = 50, **kwargs: Any): + self._history: dict[str, list[LLMMessage]] = defaultdict(list) + self._max_messages = max_messages + + def _trim_history(self, session_id: str) -> None: + """修剪历史记录,确保不超过最大长度,同时保留 System Prompt""" + history = self._history[session_id] + if len(history) <= self._max_messages: + return + + has_system = history and history[0].role == "system" + + if has_system: + keep_count = max(0, self._max_messages - 1) + self._history[session_id] = [history[0], *history[-keep_count:]] + else: + self._history[session_id] = history[-self._max_messages :] + + async def get_history(self, session_id: str) -> list[LLMMessage]: + return self._history.get(session_id, []).copy() + + async def add_message(self, session_id: str, message: LLMMessage) -> None: + self._history[session_id].append(message) + self._trim_history(session_id) + + async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + self._history[session_id].extend(messages) + self._trim_history(session_id) + + async def clear_history(self, session_id: str) -> None: + if session_id in self._history: + del self._history[session_id] + + +class MemoryProcessor(ABC): + """记忆处理器接口""" + + @abstractmethod + async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None: + pass + + +_default_memory_factory: Callable[[], BaseMemory] | None = None + + +def set_default_memory_backend(factory: Callable[[], BaseMemory]): + """ + 设置全局默认记忆后端工厂,允许统一替换会话的记忆实现。 + """ + global _default_memory_factory + _default_memory_factory = factory + + +def _get_default_memory() -> BaseMemory: + if _default_memory_factory: + return _default_memory_factory() + return InMemoryMemory() + + +DEFAULT_IVR_TEMPLATE = ( + "你的响应未能通过结构校验。\n" + "错误详情: {error_msg}\n\n" + "请执行以下步骤进行修正:\n" + "1. 反思:分析为什么会出现这个错误。\n" + "2. 
修正:生成一个新的、符合 Schema 要求的 JSON 对象。\n" + "请直接输出修正后的 JSON,不要包含 Markdown 标记或其他解释。" +) + + class AI: """ 统一的AI服务类 - 提供了带记忆的会话接口。 @@ -73,6 +178,7 @@ class AI: config: AIConfig | None = None, memory: BaseMemory | None = None, default_generation_config: LLMGenerationConfig | None = None, + processors: list[MemoryProcessor] | None = None, ): """ 初始化AI服务 @@ -81,24 +187,45 @@ class AI: session_id: 唯一的会话ID,用于隔离记忆。 config: AI 配置. memory: 可选的自定义记忆后端。如果为None,则使用默认的InMemoryMemory。 - default_generation_config: (新增) 此AI实例的默认生成配置。 + default_generation_config: 此AI实例的默认生成配置。 + processors: 记忆处理器列表,在添加记忆后触发。 """ self.session_id = session_id or str(uuid.uuid4()) self.config = config or AIConfig() - self.memory = memory or InMemoryMemory() + self.memory = memory or _get_default_memory() self.default_generation_config = ( default_generation_config or LLMGenerationConfig() ) + self.processors = processors or [] global_providers = tool_provider_manager._providers config_providers = self.config.tool_providers self._tool_providers = list(dict.fromkeys(global_providers + config_providers)) + self.message_buffer: list[LLMMessage] = [] async def clear_history(self): """清空当前会话的历史记录。""" await self.memory.clear_history(self.session_id) logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})") + async def add_observation( + self, message: str | UniMessage | LLMMessage | list[LLMContentPart] + ): + """ + 将一条观察消息加入缓冲区,不立即触发模型调用。 + + 返回: + int: 缓冲区中消息的数量。 + """ + current_message = await self._normalize_input_to_message(message) + self.message_buffer.append(current_message) + content_preview = str(current_message.content)[:50] + logger.debug( + f"[放入观察] {content_preview} (缓冲区大小: {len(self.message_buffer)})", + "AI_MEMORY", + ) + return len(self.message_buffer) + async def add_user_message_to_history( self, message: str | LLMMessage | list[LLMContentPart] ): @@ -161,7 +288,7 @@ class AI: self, message: str | UniMessage | LLMMessage | list[LLMContentPart] ) -> LLMMessage: """ - [重构后] 内部辅助方法,将各种输入类型统一转换为单个 LLMMessage 对象。 + 内部辅助方法,将各种输入类型统一转换为单个 LLMMessage 对象。 它调用共享的工具函数并提取最后一条消息(通常是用户输入)。 """ messages = await normalize_to_llm_messages(message) @@ -172,17 +299,79 @@ class AI: ) return messages[-1] + async def generate_internal( + self, + messages: list[LLMMessage], + *, + model: ModelName = None, + config: LLMGenerationConfig | GenConfigBuilder | None = None, + tools: list[Any] | dict[str, ToolExecutable] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, + timeout: float | None = None, + model_instance: Any = None, + ) -> LLMResponse: + """ + 内部生成核心方法,负责配置合并、工具解析和模型调用。 + 此方法不处理历史记录的存储,供 AgentExecutor 或 chat 方法调用。 + """ + final_config = self.default_generation_config + if isinstance(config, GenConfigBuilder): + config = config.build() + + if config: + final_config = final_config.merge_with(config) + + final_tools_list = [] + if tools: + if isinstance(tools, dict): + final_tools_list = list(tools.values()) + elif isinstance(tools, list): + to_resolve: list[Any] = [] + for t in tools: + if isinstance(t, str | dict): + to_resolve.append(t) + else: + final_tools_list.append(t) + + if to_resolve: + resolved_dict = await self._resolve_tools(to_resolve) + final_tools_list.extend(resolved_dict.values()) + + if model_instance: + return await model_instance.generate_response( + messages, + config=final_config, + tools=final_tools_list if final_tools_list else None, + tool_choice=tool_choice, + timeout=timeout, + ) + + resolved_model_name = self._resolve_model_name(model or self.config.model) + async with await 
get_model_instance( + resolved_model_name, + override_config=None, + ) as instance: + return await instance.generate_response( + messages, + config=final_config, + tools=final_tools_list if final_tools_list else None, + tool_choice=tool_choice, + timeout=timeout, + ) + async def chat( self, - message: str | UniMessage | LLMMessage | list[LLMContentPart], + message: str | UniMessage | LLMMessage | list[LLMContentPart] | None, *, model: ModelName = None, instruction: str | None = None, template_vars: dict[str, Any] | None = None, preserve_media_in_history: bool | None = None, - tools: list[dict[str, Any] | str] | dict[str, ToolExecutable] | None = None, - tool_choice: str | dict[str, Any] | None = None, - config: LLMGenerationConfig | None = None, + tools: list[Any] | dict[str, ToolExecutable] | None = None, + tool_choice: str | dict[str, Any] | ToolChoice | None = None, + config: LLMGenerationConfig | GenConfigBuilder | None = None, + use_buffer: bool = False, + timeout: float | None = None, ) -> LLMResponse: """ 核心交互方法,管理会话历史并执行单次LLM调用。 @@ -198,18 +387,27 @@ class AI: tools: 可用的工具列表或工具字典,支持临时工具和预配置工具。 tool_choice: 工具选择策略,控制AI如何选择和使用工具。 config: 生成配置对象,用于覆盖默认的生成参数。 + use_buffer: 是否刷新并包含消息缓冲区的内容,在此次对话中一次性提交。 + timeout: HTTP 请求超时时间(秒)。 返回: LLMResponse: 包含AI回复、工具调用请求、使用信息等的完整响应对象。 """ - current_message = await self._normalize_input_to_message(message) + messages_to_add: list[LLMMessage] = [] + if message: + current_message = await self._normalize_input_to_message(message) + messages_to_add.append(current_message) + + if use_buffer and self.message_buffer: + messages_to_add = self.message_buffer + messages_to_add + self.message_buffer.clear() messages_for_run = [] final_instruction = instruction if final_instruction and template_vars: try: - template = jinja_env.from_string(final_instruction) + template = Template(final_instruction) final_instruction = template.render(**template_vars) logger.debug(f"渲染后的系统指令: {final_instruction}") except Exception as e: @@ -220,51 +418,55 @@ class AI: current_history = await self.memory.get_history(self.session_id) messages_for_run.extend(current_history) - messages_for_run.append(current_message) + messages_for_run.extend(messages_to_add) try: - resolved_model_name = self._resolve_model_name(model or self.config.model) - - final_config = model_copy(self.default_generation_config, deep=True) - if config: - update_dict = model_dump(config, exclude_unset=True) - final_config = model_copy(final_config, update=update_dict) - - ad_hoc_tools = None - if tools: - if isinstance(tools, dict): - ad_hoc_tools = tools - else: - ad_hoc_tools = await self._resolve_tools(tools) - - async with await get_model_instance( - resolved_model_name, - override_config=final_config.to_dict(), - ) as model_instance: - response = await model_instance.generate_response( - messages_for_run, tools=ad_hoc_tools, tool_choice=tool_choice - ) + response = await self.generate_internal( + messages_for_run, + model=model, + config=config, + tools=tools, + tool_choice=tool_choice, + timeout=timeout, + ) should_preserve = ( preserve_media_in_history if preserve_media_in_history is not None else self.config.default_preserve_media_in_history ) - user_msg_to_store = ( - current_message - if should_preserve - else self._sanitize_message_for_history(current_message) - ) - assistant_response_msg = LLMMessage.assistant_text_response(response.text) - if response.tool_calls: - assistant_response_msg = LLMMessage.assistant_tool_calls( - response.tool_calls, response.text + msgs_to_store: list[LLMMessage] = [] + 
for msg in messages_to_add:
+                store_msg = (
+                    msg if should_preserve else self._sanitize_message_for_history(msg)
+                )
+                msgs_to_store.append(store_msg)
+
+            if response.content_parts:
+                assistant_response_msg = LLMMessage(
+                    role="assistant",
+                    content=response.content_parts,
+                    tool_calls=response.tool_calls,
+                )
+            else:
+                assistant_response_msg = LLMMessage.assistant_text_response(
+                    response.text
+                )
+                if response.tool_calls:
+                    assistant_response_msg = LLMMessage.assistant_tool_calls(
+                        response.tool_calls, response.text
+                    )
 
             await self.memory.add_messages(
-                self.session_id, [user_msg_to_store, assistant_response_msg]
+                self.session_id, [*msgs_to_store, assistant_response_msg]
             )
 
+            if self.processors:
+                for processor in self.processors:
+                    await processor.process(
+                        self.session_id, [*msgs_to_store, assistant_response_msg]
+                    )
+
             return response
 
         except Exception as e:
@@ -280,7 +482,7 @@ class AI:
         *,
         model: ModelName = None,
         timeout: int | None = None,
-        config: LLMGenerationConfig | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
     ) -> LLMResponse:
         """
         代码执行
@@ -294,16 +496,18 @@
         返回:
             LLMResponse: 包含执行结果的完整响应对象。
         """
-        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
+        resolved_model = model or self.config.model
 
         code_config = CommonOverrides.gemini_code_execution()
         if timeout:
             code_config.custom_params = code_config.custom_params or {}
             code_config.custom_params["code_execution_timeout"] = timeout
 
+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
         if config:
-            update_dict = model_dump(config, exclude_unset=True)
-            code_config = model_copy(code_config, update=update_dict)
+            code_config = code_config.merge_with(config)
 
         return await self.chat(prompt, model=resolved_model, config=code_config)
 
@@ -317,7 +521,7 @@
             "根据用户的查询找到最相关的信息,并进行总结和回答。"
         ),
         template_vars: dict[str, Any] | None = None,
-        config: LLMGenerationConfig | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
     ) -> LLMResponse:
         """
         信息搜索的便捷入口,原生支持多模态查询。
         """
         logger.info("执行 'search' 任务...")
 
         search_config = CommonOverrides.gemini_grounding()
+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
         if config:
-            update_dict = model_dump(config, exclude_unset=True)
-            search_config = model_copy(search_config, update=update_dict)
+            search_config = search_config.merge_with(config)
 
         return await self.chat(
             query,
@@ -339,21 +545,31 @@ class AI:
 
     async def generate_structured(
         self,
-        message: str | LLMMessage | list[LLMContentPart],
+        message: str | LLMMessage | list[LLMContentPart] | None,
         response_model: type[T],
         *,
         model: ModelName = None,
+        tools: list[Any] | dict[str, ToolExecutable] | None = None,
+        tool_choice: str | dict[str, Any] | ToolChoice | None = None,
         instruction: str | None = None,
-        config: LLMGenerationConfig | None = None,
+        timeout: float | None = None,
+        template_vars: dict[str, Any] | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
+        max_validation_retries: int | None = None,
+        validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
+        error_prompt_template: str | None = None,
+        auto_thinking: bool | None = None,
    ) -> T:
        """
        生成结构化响应,并自动解析为指定的Pydantic模型。
 
         参数:
-            message: 用户输入的消息内容,支持多种格式。
+            message: 用户输入的消息内容,支持多种格式。为None时只使用历史+缓冲区。
             response_model: 用于解析和验证响应的Pydantic模型类。
             model: 要使用的模型名称,如果为None则使用配置中的默认模型。
+            tools: 可用的工具列表或工具字典;NATIVE 策略下会被忽略。
+            tool_choice: 工具选择策略,控制模型如何调用工具。
             instruction: 本次调用的特定系统指令,会与JSON Schema指令合并。
+            timeout: HTTP 请求超时时间(秒)。
+            template_vars: 系统指令中的模板变量,用于动态渲染。
             config: 生成配置对象,用于覆盖默认的生成参数。
+            max_validation_retries: 结构校验失败时的最大 IVR 修复重试次数,
+                为 None 时读取全局 client_settings.structured_retries。
+            validation_callback: 校验通过后的额外回调,支持同步或异步;
+                回调抛出异常同样会触发 IVR 重试。
+            error_prompt_template: IVR 修复提示模板,须包含 {error_msg} 占位符。
+            auto_thinking: 是否启用 AutoCoT 思维链包装;为 None 时
+                根据模型能力与配置自动判断。
 
         返回:
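下面是 `generate_structured` 的一个最小用法示意(仅为草图,基于本补丁暴露的公开 API;其中 `UserIntent` 模型、会话 ID 与提示文本均为演示用的假设,不对应仓库中的真实定义):

```python
from pydantic import BaseModel

from zhenxun.services.llm import AI


class UserIntent(BaseModel):
    """演示用的结构化输出模型(假设定义)"""

    intent: str
    confidence: float


async def demo() -> None:
    ai = AI(session_id="structured-demo")
    # 校验失败时按 max_validation_retries 进入 IVR(校验-修复)循环;
    # auto_thinking 留空时由模型能力与配置决定是否包装 AutoCoT。
    result = await ai.generate_structured(
        "帮我把明早八点的闹钟改到九点",
        UserIntent,
        max_validation_retries=2,
    )
    print(result.intent, result.confidence)
```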
@@ -362,6 +578,48 @@ class AI:
         异常:
             LLMException: 如果模型返回的不是有效的JSON或验证失败。
         """
+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
+        final_config = self.default_generation_config.merge_with(config)
+
+        if final_config is None:
+            final_config = LLMGenerationConfig()
+
+        if max_validation_retries is None:
+            max_validation_retries = (
+                get_llm_config().client_settings.structured_retries
+            )
+
+        resolved_model_name = self._resolve_model_name(model or self.config.model)
+
+        request_autocot = True if auto_thinking is None else auto_thinking
+        effective_auto_thinking = should_apply_autocot(
+            request_autocot, resolved_model_name, final_config
+        )
+
+        target_model: type[T] = response_model
+        if effective_auto_thinking:
+            target_model = cast(type[T], create_cot_wrapper(response_model))
+            response_model = target_model
+
+            cot_instruction = (
+                "请务必先在 `reasoning` 字段中进行详细的一步步推理,确保逻辑正确,"
+                "然后再填充 `result` 字段。"
+            )
+            if instruction:
+                instruction = f"{instruction}\n\n{cot_instruction}"
+            else:
+                instruction = cot_instruction
+
+        final_instruction = instruction
+        if final_instruction and template_vars:
+            try:
+                template = Template(final_instruction)
+                final_instruction = template.render(**template_vars)
+            except Exception as e:
+                logger.error(f"渲染结构化指令模板失败: {e}", e=e)
+
         try:
             json_schema = model_json_schema(response_model)
         except AttributeError:
@@ -369,41 +627,149 @@ class AI:
 
         schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
 
-        system_prompt = (
-            (f"{instruction}\n\n" if instruction else "")
-            + "你必须严格按照以下 JSON Schema 格式进行响应。"
-            + "不要包含任何额外的解释、注释或代码块标记,只返回纯粹的 JSON 对象。\n\n"
+        prompt_prefix = f"{final_instruction}\n\n" if final_instruction else ""
+        structured_strategy = (
+            final_config.output.structured_output_strategy
+            if final_config.output
+            else StructuredOutputStrategy.NATIVE
         )
-        system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
+        if structured_strategy == StructuredOutputStrategy.TOOL_CALL:
+            system_prompt = prompt_prefix + "请调用提供的工具提交结构化数据。"
+        else:
+            system_prompt = (
+                prompt_prefix
+                + "请严格按照以下 JSON Schema 格式进行响应。不应包含任何额外的解释、"
+                "注释或代码块标记,只返回一个合法的 JSON 对象。\n\n"
+            )
+            system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
 
-        final_config = model_copy(config) if config else LLMGenerationConfig()
-
-        final_config.response_format = ResponseFormat.JSON
-        final_config.response_schema = json_schema
-
-        response = await self.chat(
-            message, model=model, instruction=system_prompt, config=final_config
-        )
 
-        try:
-            return type_validate_json(response_model, response.text)
-        except ValidationError as e:
-            logger.error(f"LLM结构化输出验证失败: {e}", e=e)
-            raise LLMException(
-                "LLM返回的JSON未能通过结构验证。",
-                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                details={"raw_response": response.text, "validation_error": str(e)},
-                cause=e,
-            )
-        except Exception as e:
-            logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
-            raise LLMException(
-                "解析LLM的JSON输出时失败。",
-                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                details={"raw_response": response.text},
-                cause=e,
+        final_tools_list: list[ToolExecutable] | None = None
+        if structured_strategy != StructuredOutputStrategy.NATIVE:
+            if tools:
+                final_tools_list = []
+                if isinstance(tools, dict):
+                    final_tools_list = list(tools.values())
+                elif isinstance(tools, list):
+                    to_resolve: list[Any] = []
+                    for t in tools:
+                        if isinstance(t, str | dict):
+                            to_resolve.append(t)
+                        else:
+                            final_tools_list.append(t)
+                    if to_resolve:
+                        resolved_dict = 
await self._resolve_tools(to_resolve) + final_tools_list.extend(resolved_dict.values()) + elif tools: + logger.warning( + "检测到在 generate_structured (NATIVE 策略) 中传入了 tools。" + "为了避免 API 冲突(Gemini)及输出歧义(OpenAI),这些" + "tools 将被本次请求忽略。" + "若需使用工具,请使用 chat() 方法或 Agent 流程。" ) + if final_config.output is None: + final_config.output = OutputConfig() + + final_config.output.response_format = ResponseFormat.JSON + final_config.output.response_schema = json_schema + + messages_for_run = [LLMMessage.system(system_prompt)] + current_history = await self.memory.get_history(self.session_id) + messages_for_run.extend(current_history) + messages_for_run.extend(self.message_buffer) + if message: + normalized_message = await self._normalize_input_to_message(message) + messages_for_run.append(normalized_message) + + ivr_messages = list(messages_for_run) + last_exception: Exception | None = None + + for attempt in range(max_validation_retries + 1): + current_response_text: str = "" + + async with await get_model_instance( + resolved_model_name, + override_config=None, + ) as model_instance: + response = await model_instance.generate_response( + ivr_messages, + config=final_config, + tools=final_tools_list if final_tools_list else None, + tool_choice=tool_choice, + timeout=timeout, + ) + current_response_text = response.text + + try: + parsed_obj = parse_and_validate_json(response.text, target_model) + + final_obj: T = cast(T, parsed_obj) + if effective_auto_thinking: + logger.debug( + f"AutoCoT 思考过程: {getattr(parsed_obj, 'reasoning', '')}" + ) + final_obj = cast(T, getattr(parsed_obj, "result")) + + if validation_callback: + if is_coroutine_callable(validation_callback): + await validation_callback(final_obj) + else: + validation_callback(final_obj) + + return final_obj + + except Exception as e: + is_llm_error = isinstance(e, LLMException) + llm_error: LLMException | None = ( + cast(LLMException, e) if is_llm_error else None + ) + last_exception = e + + if attempt < max_validation_retries: + error_msg = ( + llm_error.details.get("validation_error", str(e)) + if llm_error + else str(e) + ) + raw_response = current_response_text or ( + llm_error.details.get("raw_response", "") if llm_error else "" + ) + logger.warning( + f"结构化校验失败 (尝试 {attempt + 1}/" + f"{max_validation_retries + 1})。正在尝试 IVR 修复... 
错误:" + f"{error_msg}" + ) + + if raw_response: + ivr_messages.append( + LLMMessage.assistant_text_response(raw_response) + ) + else: + logger.warning( + "IVR 警告: 无法获取上一轮生成的原始文本," + "模型将在无上下文情况下尝试修复。" + ) + + template = error_prompt_template or DEFAULT_IVR_TEMPLATE + feedback_prompt = template.format(error_msg=error_msg) + ivr_messages.append(LLMMessage.user(feedback_prompt)) + continue + + if llm_error and not llm_error.recoverable: + raise llm_error + + if last_exception: + raise last_exception + raise LLMException( + "IVR 循环异常结束,未能生成有效结果。", code=LLMErrorCode.GENERATION_FAILED + ) + def _resolve_model_name(self, model_name: ModelName) -> str: """解析模型名称""" if model_name: @@ -423,8 +789,7 @@ class AI: texts: list[str] | str, *, model: ModelName = None, - task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, - **kwargs: Any, + config: LLMEmbeddingConfig | None = None, ) -> list[list[float]]: """ 生成文本嵌入向量,将文本转换为数值向量表示。 @@ -432,14 +797,13 @@ class AI: 参数: texts: 要生成嵌入的文本内容,支持单个字符串或字符串列表。 model: 嵌入模型名称,如果为None则使用配置中的默认嵌入模型。 - task_type: 嵌入任务类型,影响向量的优化方向(如检索、分类等)。 - **kwargs: 传递给嵌入模型的额外参数。 + config: 嵌入配置 返回: list[list[float]]: 文本对应的嵌入向量列表,每个向量为浮点数列表。 异常: - LLMException: 如果嵌入生成失败或模型配置错误。 + LLMException: 当嵌入生成失败或模型配置错误时抛出 """ if isinstance(texts, str): texts = [texts] @@ -452,18 +816,20 @@ class AI: ) if not resolved_model_str: raise LLMException( - "使用 embed 功能时必须指定嵌入模型名称," - "或在 AIConfig 中配置 default_embedding_model。", + "使用 embed 方法时未指定嵌入模型名称," + "且 AIConfig 未设置 default_embedding_model。", code=LLMErrorCode.MODEL_NOT_FOUND, ) resolved_model_str = self._resolve_model_name(resolved_model_str) + final_config = config or LLMEmbeddingConfig() + async with await get_model_instance( resolved_model_str, override_config=None, ) as embedding_model_instance: return await embedding_model_instance.generate_embeddings( - texts, task_type=task_type, **kwargs + texts, config=final_config ) except LLMException: raise @@ -484,6 +850,15 @@ class AI: resolved: dict[str, ToolExecutable] = {} for config in tool_configs: + if isinstance(config, str): + if config == "google_search": + resolved[config] = GeminiGoogleSearch() # type: ignore[arg-type] + continue + elif config == "code_execution": + resolved[config] = GeminiCodeExecution() # type: ignore[arg-type] + continue + elif config == "url_context": + pass name = config if isinstance(config, str) else config.get("name") if not name: raise LLMException( diff --git a/zhenxun/services/llm/tools.py b/zhenxun/services/llm/tools.py new file mode 100644 index 00000000..bbf1b9ed --- /dev/null +++ b/zhenxun/services/llm/tools.py @@ -0,0 +1,839 @@ +""" +工具模块 + +整合了工具参数解析器、工具提供者管理器与工具执行逻辑,便于在 LLM 服务层统一调用。 +""" + +import asyncio +from collections.abc import Callable +from enum import Enum +import inspect +import json +import re +import time +from typing import ( + Annotated, + Any, + Optional, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) +from typing_extensions import override + +from httpx import NetworkError, TimeoutException + +try: + import ujson as fast_json +except ImportError: + fast_json = json + +import nonebot +from nonebot.dependencies import Dependent, Param +from nonebot.internal.adapter import Bot, Event +from nonebot.internal.params import ( + BotParam, + DefaultParam, + DependParam, + DependsInner, + EventParam, + StateParam, +) +from pydantic import BaseModel, Field, ValidationError, create_model +from pydantic.fields import FieldInfo + +from zhenxun.services.log import logger +from zhenxun.utils.decorator.retry import 
Retry +from zhenxun.utils.pydantic_compat import model_dump, model_fields, model_json_schema + +from .types import ( + LLMErrorCode, + LLMException, + LLMMessage, + LLMToolCall, + ToolExecutable, + ToolProvider, + ToolResult, +) +from .types.models import ToolDefinition +from .types.protocols import BaseCallbackHandler, ToolCallData + + +class ToolParam(Param): + """ + 工具参数提取器。 + + 用于在自定义工具函数(Function Tool)中,从 LLM 解析出的参数字典 + (`state["_tool_params"]`) + 中提取特定的参数值。通常配合 `Annotated` 和依赖注入系统使用。 + """ + + def __init__(self, *args: Any, name: str, **kwargs: Any): + super().__init__(*args, **kwargs) + self.name = name + + def __repr__(self) -> str: + return f"ToolParam(name={self.name})" + + @classmethod + @override + def _check_param( + cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...] + ) -> Optional["ToolParam"]: + if param.default is not inspect.Parameter.empty and isinstance( + param.default, DependsInner + ): + return None + + if get_origin(param.annotation) is Annotated: + for arg in get_args(param.annotation): + if isinstance(arg, DependsInner): + return None + + if param.kind not in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + return cls(name=param.name) + return None + + @override + async def _solve(self, **kwargs: Any) -> Any: + state: dict[str, Any] = kwargs.get("state", {}) + tool_params = state.get("_tool_params", {}) + if self.name in tool_params: + return tool_params[self.name] + return None + + +class RunContext(BaseModel): + """ + 依赖注入容器(DI Container),保留原有上下文信息的同时提升获取类型的能力。 + """ + + session_id: str | None = None + scope: dict[str, Any] = Field(default_factory=dict) + extra: dict[str, Any] = Field(default_factory=dict) + + class Config: + arbitrary_types_allowed = True + + +class RunContextParam(Param): + """自动注入 RunContext 的参数解析器""" + + @classmethod + def _check_param( + cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...] 
+ ) -> Optional["RunContextParam"]: + if param.annotation is RunContext: + return cls() + return None + + async def _solve(self, **kwargs: Any) -> Any: + state = kwargs.get("state", {}) + return state.get("_agent_context") + + +def _parse_docstring_params(docstring: str | None) -> dict[str, str]: + """ + 解析文档字符串,提取参数描述。 + 支持 Google Style (Args:), ReST Style (:param:), 和中文风格 (参数:)。 + """ + if not docstring: + return {} + + params: dict[str, str] = {} + lines = docstring.splitlines() + + rest_pattern = re.compile(r"[:@]param\s+(\w+)\s*:?\s*(.*)") + found_rest = False + for line in lines: + match = rest_pattern.search(line) + if match: + params[match.group(1)] = match.group(2).strip() + found_rest = True + + if found_rest: + return params + + section_header_pattern = re.compile( + r"^\s*(?:Args|Arguments|Parameters|参数)\s*[::]\s*$" + ) + + param_section_active = False + google_pattern = re.compile(r"^\s*(\**\w+)(?:\s*\(.*?\))?\s*[::]\s*(.*)") + + for line in lines: + stripped_line = line.strip() + if not stripped_line: + continue + + if section_header_pattern.match(line): + param_section_active = True + continue + + if param_section_active: + if ( + stripped_line.endswith(":") or stripped_line.endswith(":") + ) and not google_pattern.match(line): + param_section_active = False + continue + + match = google_pattern.match(line) + if match: + name = match.group(1).lstrip("*") + desc = match.group(2).strip() + params[name] = desc + + return params + + +def _create_dynamic_model(func: Callable) -> type[BaseModel]: + """根据函数签名动态创建 Pydantic 模型""" + sig = inspect.signature(func) + doc_params = _parse_docstring_params(func.__doc__) + type_hints = get_type_hints(func, include_extras=True) + + fields = {} + for name, param in sig.parameters.items(): + if name in ("self", "cls"): + continue + + annotation = type_hints.get(name, Any) + default = param.default + + is_run_context = False + if annotation is RunContext: + is_run_context = True + else: + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + if RunContext in args: + is_run_context = True + + if is_run_context: + continue + + if default is not inspect.Parameter.empty and isinstance(default, DependsInner): + continue + + if get_origin(annotation) is Annotated: + args = get_args(annotation) + if any(isinstance(arg, DependsInner) for arg in args): + continue + + description = doc_params.get(name) + if isinstance(default, FieldInfo): + if description and not getattr(default, "description", None): + default.description = description + fields[name] = (annotation, default) + else: + if default is inspect.Parameter.empty: + default = ... 
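+                # 无显式默认值的参数以 Ellipsis 标记为必填;从 docstring 解析出的
+                # 参数描述写入 Field(description=...),最终出现在工具暴露给
+                # LLM 的 JSON Schema 中。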
+ fields[name] = (annotation, Field(default, description=description)) + + return create_model(f"{func.__name__}Params", **fields) + + +class FunctionExecutable(ToolExecutable): + """一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。""" + + def __init__( + self, + func: Callable, + name: str, + description: str, + params_model: type[BaseModel] | None = None, + unpack_args: bool = False, + ): + self._func = func + self._name = name + self._description = description + self._params_model = params_model + self._unpack_args = unpack_args + + self.dependent = Dependent[Any].parse( + call=func, + allow_types=( + DependParam, + BotParam, + EventParam, + StateParam, + RunContextParam, + ToolParam, + DefaultParam, + ), + ) + + async def get_definition(self) -> ToolDefinition: + if not self._params_model: + return ToolDefinition( + name=self._name, + description=self._description, + parameters={"type": "object", "properties": {}}, + ) + + schema = model_json_schema(self._params_model) + + return ToolDefinition( + name=self._name, + description=self._description, + parameters={ + "type": "object", + "properties": schema.get("properties", {}), + "required": schema.get("required", []), + }, + ) + + async def execute( + self, context: RunContext | None = None, **kwargs: Any + ) -> ToolResult: + context = context or RunContext() + + tool_arguments = kwargs + + if self._params_model: + try: + _fields = model_fields(self._params_model) + validation_input = { + key: value for key, value in kwargs.items() if key in _fields + } + + validated_params = self._params_model(**validation_input) + + if not self._unpack_args: + pass + else: + validated_dict = model_dump(validated_params) + tool_arguments = validated_dict + + except ValidationError as e: + error_msgs = [] + for err in e.errors(): + loc = ".".join(str(x) for x in err["loc"]) + msg = err["msg"] + error_msgs.append(f"Parameter '{loc}': {msg}") + + formatted_error = "; ".join(error_msgs) + error_payload = { + "error_type": "InvalidArguments", + "message": f"Parameter validation failed: {formatted_error}", + "is_retryable": True, + } + return ToolResult( + output=json.dumps(error_payload, ensure_ascii=False), + display_content=f"Validation Error: {formatted_error}", + ) + except Exception as e: + logger.error( + f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e + ) + raise + + state = { + "_tool_params": tool_arguments, + "_agent_context": context, + } + + bot: Bot | None = None + if context and context.scope.get("bot"): + bot = context.scope.get("bot") + if not bot: + try: + bot = nonebot.get_bot() + except ValueError: + pass + + event: Event | None = None + if context and context.scope.get("event"): + event = context.scope.get("event") + + raw_result = await self.dependent( + bot=bot, + event=event, + state=state, + ) + + return ToolResult(output=raw_result, display_content=str(raw_result)) + + +class BuiltinFunctionToolProvider(ToolProvider): + """一个内置的 ToolProvider,用于处理通过装饰器注册的函数。""" + + def __init__(self): + self._functions: dict[str, dict[str, Any]] = {} + + def register( + self, + name: str, + func: Callable, + description: str, + params_model: type[BaseModel] | None = None, + unpack_args: bool = False, + ): + self._functions[name] = { + "func": func, + "description": description, + "params_model": params_model, + "unpack_args": unpack_args, + } + + async def initialize(self) -> None: + pass + + async def discover_tools( + self, + allowed_servers: list[str] | None = None, + excluded_servers: list[str] | None = None, + ) -> dict[str, ToolExecutable]: + executables 
= {} + for name, info in self._functions.items(): + executables[name] = FunctionExecutable( + func=info["func"], + name=name, + description=info["description"], + params_model=info["params_model"], + unpack_args=info.get("unpack_args", False), + ) + return executables + + async def get_tool_executable( + self, name: str, config: dict[str, Any] + ) -> ToolExecutable | None: + if config.get("type", "function") == "function" and name in self._functions: + info = self._functions[name] + return FunctionExecutable( + func=info["func"], + name=name, + description=info["description"], + params_model=info["params_model"], + unpack_args=info.get("unpack_args", False), + ) + return None + + +class ToolProviderManager: + """工具提供者的中心化管理器,采用单例模式。""" + + _instance: "ToolProviderManager | None" = None + + def __new__(cls) -> "ToolProviderManager": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if hasattr(self, "_initialized") and self._initialized: + return + + self._providers: list[ToolProvider] = [] + self._resolved_tools: dict[str, ToolExecutable] | None = None + self._init_lock = asyncio.Lock() + self._init_promise: asyncio.Task | None = None + self._builtin_function_provider = BuiltinFunctionToolProvider() + self.register(self._builtin_function_provider) + self._initialized = True + + def register(self, provider: ToolProvider): + """注册一个新的 ToolProvider。""" + if provider not in self._providers: + self._providers.append(provider) + logger.info(f"已注册工具提供者: {provider.__class__.__name__}") + + def function_tool( + self, + name: str, + description: str, + params_model: type[BaseModel] | None = None, + ): + """装饰器:将一个函数注册为内置工具。""" + + def decorator(func: Callable): + if name in self._builtin_function_provider._functions: + logger.warning(f"正在覆盖已注册的函数工具: {name}") + + final_model = params_model + unpack_args = False + if final_model is None: + final_model = _create_dynamic_model(func) + unpack_args = True + + self._builtin_function_provider.register( + name=name, + func=func, + description=description, + params_model=final_model, + unpack_args=unpack_args, + ) + logger.info(f"已注册函数工具: '{name}'") + return func + + return decorator + + async def initialize(self) -> None: + """懒加载初始化所有已注册的 ToolProvider。""" + if not self._init_promise: + async with self._init_lock: + if not self._init_promise: + self._init_promise = asyncio.create_task( + self._initialize_providers() + ) + await self._init_promise + + async def _initialize_providers(self) -> None: + """内部初始化逻辑。""" + logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...") + init_tasks = [provider.initialize() for provider in self._providers] + await asyncio.gather(*init_tasks, return_exceptions=True) + logger.info("所有工具提供者初始化完成。") + + async def get_resolved_tools( + self, + allowed_servers: list[str] | None = None, + excluded_servers: list[str] | None = None, + ) -> dict[str, ToolExecutable]: + """ + 获取所有已发现和解析的工具。 + 此方法会触发懒加载初始化,并根据是否传入过滤器来决定是否使用全局缓存。 + """ + await self.initialize() + + has_filters = allowed_servers is not None or excluded_servers is not None + + if not has_filters and self._resolved_tools is not None: + logger.debug("使用全局工具缓存。") + return self._resolved_tools + + if has_filters: + logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。") + logger.debug( + f"过滤器详情: allowed_servers={allowed_servers}, " + f"excluded_servers={excluded_servers}" + ) + else: + logger.info("未应用过滤器,开始全局工具发现...") + + all_tools: dict[str, ToolExecutable] = {} + + discover_tasks = [] + for provider in self._providers: + sig 
= inspect.signature(provider.discover_tools) + params_to_pass = {} + if "allowed_servers" in sig.parameters: + params_to_pass["allowed_servers"] = allowed_servers + if "excluded_servers" in sig.parameters: + params_to_pass["excluded_servers"] = excluded_servers + + discover_tasks.append(provider.discover_tools(**params_to_pass)) + + results = await asyncio.gather(*discover_tasks, return_exceptions=True) + + for i, provider_result in enumerate(results): + provider_name = self._providers[i].__class__.__name__ + if isinstance(provider_result, dict): + logger.debug( + f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。" + ) + for name, executable in provider_result.items(): + if name in all_tools: + logger.warning( + f"发现重复的工具名称 '{name}',后发现的将覆盖前者。" + ) + all_tools[name] = executable + elif isinstance(provider_result, Exception): + logger.error( + f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}" + ) + + if not has_filters: + self._resolved_tools = all_tools + logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。") + else: + logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。") + + return all_tools + + async def resolve_specific_tools( + self, tool_names: list[str] + ) -> dict[str, ToolExecutable]: + """ + 仅解析指定名称的工具,避免触发全量工具发现。 + """ + resolved: dict[str, ToolExecutable] = {} + if not tool_names: + return resolved + + await self.initialize() + + for name in tool_names: + config: dict[str, Any] = {"name": name} + for provider in self._providers: + try: + executable = await provider.get_tool_executable(name, config) + except Exception as exc: + logger.error( + f"provider '{provider.__class__.__name__}' 在解析工具 '{name}'" + f"时出错: {exc}", + e=exc, + ) + continue + + if executable: + resolved[name] = executable + break + else: + logger.warning(f"没有找到名为 '{name}' 的工具,已跳过。") + + return resolved + + async def get_function_tools( + self, names: list[str] | None = None + ) -> dict[str, ToolExecutable]: + """ + 仅从内置的函数提供者中解析指定的工具。 + """ + all_function_tools = await self._builtin_function_provider.discover_tools() + if names is None: + return all_function_tools + + resolved_tools = {} + for name in names: + if name in all_function_tools: + resolved_tools[name] = all_function_tools[name] + else: + logger.warning( + f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。" + ) + return resolved_tools + + +tool_provider_manager = ToolProviderManager() +function_tool = tool_provider_manager.function_tool + + +class ToolErrorType(str, Enum): + """结构化工具错误的类型枚举。""" + + TOOL_NOT_FOUND = "ToolNotFound" + INVALID_ARGUMENTS = "InvalidArguments" + EXECUTION_ERROR = "ExecutionError" + USER_CANCELLATION = "UserCancellation" + + +class ToolErrorResult(BaseModel): + """一个结构化的工具执行错误模型。""" + + error_type: ToolErrorType = Field(..., description="错误的类型。") + message: str = Field(..., description="对错误的详细描述。") + is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。") + + +class ToolInvoker: + """ + 全能工具执行器。 + 负责接收工具调用请求,解析参数,触发回调,执行工具,并返回标准化的结果。 + """ + + def __init__(self, callbacks: list[BaseCallbackHandler] | None = None): + self.callbacks = callbacks or [] + + async def _trigger_callbacks(self, event_name: str, *args, **kwargs: Any) -> None: + if not self.callbacks: + return + tasks = [ + getattr(handler, event_name)(*args, **kwargs) + for handler in self.callbacks + if hasattr(handler, event_name) + ] + await asyncio.gather(*tasks, return_exceptions=True) + + async def execute_tool_call( + self, + tool_call: LLMToolCall, + available_tools: dict[str, ToolExecutable], + context: Any | None = None, + ) -> 
tuple[LLMToolCall, ToolResult]: + tool_name = tool_call.function.name + arguments_str = tool_call.function.arguments + arguments: dict[str, Any] = {} + + try: + if arguments_str: + arguments = json.loads(arguments_str) + except json.JSONDecodeError as e: + error_result = ToolErrorResult( + error_type=ToolErrorType.INVALID_ARGUMENTS, + message=f"参数解析失败: {e}", + is_retryable=False, + ) + return tool_call, ToolResult(output=model_dump(error_result)) + + tool_data = ToolCallData(tool_name=tool_name, tool_args=arguments) + pre_calculated_result: ToolResult | None = None + for handler in self.callbacks: + res = await handler.on_tool_start(tool_call, tool_data) + if isinstance(res, ToolCallData): + tool_data = res + arguments = tool_data.tool_args + tool_call.function.arguments = json.dumps(arguments, ensure_ascii=False) + elif isinstance(res, ToolResult): + pre_calculated_result = res + break + + if pre_calculated_result: + return tool_call, pre_calculated_result + + executable = available_tools.get(tool_name) + if not executable: + error_result = ToolErrorResult( + error_type=ToolErrorType.TOOL_NOT_FOUND, + message=f"Tool '{tool_name}' not found.", + is_retryable=False, + ) + return tool_call, ToolResult(output=model_dump(error_result)) + + from .config.providers import get_llm_config + + if not get_llm_config().debug_log: + try: + definition = await executable.get_definition() + schema_payload = getattr(definition, "parameters", {}) + schema_json = fast_json.dumps( + schema_payload, + ensure_ascii=False, + ) + logger.debug( + f"🔍 [JIT Schema] {tool_name}: {schema_json}", + "ToolInvoker", + ) + except Exception as e: + logger.trace(f"JIT Schema logging failed: {e}") + + start_t = time.monotonic() + result: ToolResult | None = None + error: Exception | None = None + + try: + + @Retry.simple(stop_max_attempt=2, wait_fixed_seconds=1) + async def execute_with_retry(): + return await executable.execute(context=context, **arguments) + + result = await execute_with_retry() + except ValidationError as e: + error = e + error_msgs = [] + for err in e.errors(): + loc = ".".join(str(x) for x in err["loc"]) + msg = err["msg"] + error_msgs.append(f"参数 '{loc}': {msg}") + + formatted_error = "; ".join(error_msgs) + error_result = ToolErrorResult( + error_type=ToolErrorType.INVALID_ARGUMENTS, + message=f"参数验证失败。请根据错误修正你的输入: {formatted_error}", + is_retryable=True, + ) + result = ToolResult(output=model_dump(error_result)) + except (TimeoutException, NetworkError) as e: + error = e + error_result = ToolErrorResult( + error_type=ToolErrorType.EXECUTION_ERROR, + message=f"工具执行网络超时或连接失败: {e!s}", + is_retryable=False, + ) + result = ToolResult(output=model_dump(error_result)) + except Exception as e: + error = e + error_type = ToolErrorType.EXECUTION_ERROR + if ( + isinstance(e, LLMException) + and e.code == LLMErrorCode.CONFIGURATION_ERROR + ): + error_type = ToolErrorType.TOOL_NOT_FOUND + is_retryable = False + + is_retryable = False + + error_result = ToolErrorResult( + error_type=error_type, message=str(e), is_retryable=is_retryable + ) + result = ToolResult(output=model_dump(error_result)) + + duration = time.monotonic() - start_t + + await self._trigger_callbacks( + "on_tool_end", + result=result, + error=error, + tool_call=tool_call, + duration=duration, + ) + + if result is None: + raise LLMException("工具执行未返回任何结果。") + + return tool_call, result + + async def execute_batch( + self, + tool_calls: list[LLMToolCall], + available_tools: dict[str, ToolExecutable], + context: Any | None = None, + ) -> 
list[LLMMessage]: + if not tool_calls: + return [] + + tasks = [ + self.execute_tool_call(call, available_tools, context) + for call in tool_calls + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + tool_messages: list[LLMMessage] = [] + for index, result_pair in enumerate(results): + original_call = tool_calls[index] + + if isinstance(result_pair, Exception): + logger.error( + f"工具执行发生未捕获异常: {original_call.function.name}, " + f"错误: {result_pair}" + ) + tool_messages.append( + LLMMessage.tool_response( + tool_call_id=original_call.id, + function_name=original_call.function.name, + result={ + "error": f"System Execution Error: {result_pair}", + "status": "failed", + }, + ) + ) + continue + + tool_call_result = cast(tuple[LLMToolCall, ToolResult], result_pair) + _, tool_result = tool_call_result + tool_messages.append( + LLMMessage.tool_response( + tool_call_id=original_call.id, + function_name=original_call.function.name, + result=tool_result.output, + ) + ) + return tool_messages + + +__all__ = [ + "RunContext", + "RunContextParam", + "ToolErrorResult", + "ToolErrorType", + "ToolInvoker", + "ToolParam", + "function_tool", + "tool_provider_manager", +] diff --git a/zhenxun/services/llm/tools/__init__.py b/zhenxun/services/llm/tools/__init__.py deleted file mode 100644 index ffb1c691..00000000 --- a/zhenxun/services/llm/tools/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -工具模块导出 -""" - -from .manager import tool_provider_manager - -function_tool = tool_provider_manager.function_tool - - -__all__ = [ - "function_tool", - "tool_provider_manager", -] diff --git a/zhenxun/services/llm/tools/manager.py b/zhenxun/services/llm/tools/manager.py deleted file mode 100644 index 6cbf3136..00000000 --- a/zhenxun/services/llm/tools/manager.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -工具提供者管理器 - -负责注册、生命周期管理(包括懒加载)和统一提供所有工具。 -""" - -import asyncio -from collections.abc import Callable -import inspect -from typing import Any - -from pydantic import BaseModel - -from zhenxun.services.log import logger -from zhenxun.utils.pydantic_compat import model_json_schema - -from ..types import ToolExecutable, ToolProvider -from ..types.models import ToolDefinition, ToolResult - - -class FunctionExecutable(ToolExecutable): - """一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。""" - - def __init__( - self, - func: Callable, - name: str, - description: str, - params_model: type[BaseModel] | None, - ): - self._func = func - self._name = name - self._description = description - self._params_model = params_model - - async def get_definition(self) -> ToolDefinition: - if not self._params_model: - return ToolDefinition( - name=self._name, - description=self._description, - parameters={"type": "object", "properties": {}}, - ) - - schema = model_json_schema(self._params_model) - - return ToolDefinition( - name=self._name, - description=self._description, - parameters={ - "type": "object", - "properties": schema.get("properties", {}), - "required": schema.get("required", []), - }, - ) - - async def execute(self, **kwargs: Any) -> ToolResult: - raw_result: Any - - if self._params_model: - try: - params_instance = self._params_model(**kwargs) - - if inspect.iscoroutinefunction(self._func): - raw_result = await self._func(params_instance) - else: - loop = asyncio.get_event_loop() - raw_result = await loop.run_in_executor( - None, lambda: self._func(params_instance) - ) - except Exception as e: - logger.error( - f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e - ) - raise - else: - if inspect.iscoroutinefunction(self._func): 
- raw_result = await self._func(**kwargs) - else: - loop = asyncio.get_event_loop() - raw_result = await loop.run_in_executor( - None, lambda: self._func(**kwargs) - ) - - return ToolResult(output=raw_result, display_content=str(raw_result)) - - -class BuiltinFunctionToolProvider(ToolProvider): - """一个内置的 ToolProvider,用于处理通过装饰器注册的函数。""" - - def __init__(self): - self._functions: dict[str, dict[str, Any]] = {} - - def register( - self, - name: str, - func: Callable, - description: str, - params_model: type[BaseModel] | None, - ): - self._functions[name] = { - "func": func, - "description": description, - "params_model": params_model, - } - - async def initialize(self) -> None: - pass - - async def discover_tools( - self, - allowed_servers: list[str] | None = None, - excluded_servers: list[str] | None = None, - ) -> dict[str, ToolExecutable]: - executables = {} - for name, info in self._functions.items(): - executables[name] = FunctionExecutable( - func=info["func"], - name=name, - description=info["description"], - params_model=info["params_model"], - ) - return executables - - async def get_tool_executable( - self, name: str, config: dict[str, Any] - ) -> ToolExecutable | None: - if config.get("type") == "function" and name in self._functions: - info = self._functions[name] - return FunctionExecutable( - func=info["func"], - name=name, - description=info["description"], - params_model=info["params_model"], - ) - return None - - -class ToolProviderManager: - """工具提供者的中心化管理器,采用单例模式。""" - - _instance: "ToolProviderManager | None" = None - - def __new__(cls) -> "ToolProviderManager": - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - if hasattr(self, "_initialized") and self._initialized: - return - - self._providers: list[ToolProvider] = [] - self._resolved_tools: dict[str, ToolExecutable] | None = None - self._init_lock = asyncio.Lock() - self._init_promise: asyncio.Task | None = None - self._builtin_function_provider = BuiltinFunctionToolProvider() - self.register(self._builtin_function_provider) - self._initialized = True - - def register(self, provider: ToolProvider): - """注册一个新的 ToolProvider。""" - if provider not in self._providers: - self._providers.append(provider) - logger.info(f"已注册工具提供者: {provider.__class__.__name__}") - - def function_tool( - self, - name: str, - description: str, - params_model: type[BaseModel] | None = None, - ): - """装饰器:将一个函数注册为内置工具。""" - - def decorator(func: Callable): - if name in self._builtin_function_provider._functions: - logger.warning(f"正在覆盖已注册的函数工具: {name}") - - self._builtin_function_provider.register( - name=name, - func=func, - description=description, - params_model=params_model, - ) - logger.info(f"已注册函数工具: '{name}'") - return func - - return decorator - - async def initialize(self) -> None: - """懒加载初始化所有已注册的 ToolProvider。""" - if not self._init_promise: - async with self._init_lock: - if not self._init_promise: - self._init_promise = asyncio.create_task( - self._initialize_providers() - ) - await self._init_promise - - async def _initialize_providers(self) -> None: - """内部初始化逻辑。""" - logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...") - init_tasks = [provider.initialize() for provider in self._providers] - await asyncio.gather(*init_tasks, return_exceptions=True) - logger.info("所有工具提供者初始化完成。") - - async def get_resolved_tools( - self, - allowed_servers: list[str] | None = None, - excluded_servers: list[str] | None = None, - ) -> dict[str, ToolExecutable]: - """ - 获取所有已发现和解析的工具。 
- 此方法会触发懒加载初始化,并根据是否传入过滤器来决定是否使用全局缓存。 - """ - await self.initialize() - - has_filters = allowed_servers is not None or excluded_servers is not None - - if not has_filters and self._resolved_tools is not None: - logger.debug("使用全局工具缓存。") - return self._resolved_tools - - if has_filters: - logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。") - logger.debug( - f"过滤器详情: allowed_servers={allowed_servers}, " - f"excluded_servers={excluded_servers}" - ) - else: - logger.info("未应用过滤器,开始全局工具发现...") - - all_tools: dict[str, ToolExecutable] = {} - - discover_tasks = [] - for provider in self._providers: - sig = inspect.signature(provider.discover_tools) - params_to_pass = {} - if "allowed_servers" in sig.parameters: - params_to_pass["allowed_servers"] = allowed_servers - if "excluded_servers" in sig.parameters: - params_to_pass["excluded_servers"] = excluded_servers - - discover_tasks.append(provider.discover_tools(**params_to_pass)) - - results = await asyncio.gather(*discover_tasks, return_exceptions=True) - - for i, provider_result in enumerate(results): - provider_name = self._providers[i].__class__.__name__ - if isinstance(provider_result, dict): - logger.debug( - f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。" - ) - for name, executable in provider_result.items(): - if name in all_tools: - logger.warning( - f"发现重复的工具名称 '{name}',后发现的将覆盖前者。" - ) - all_tools[name] = executable - elif isinstance(provider_result, Exception): - logger.error( - f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}" - ) - - if not has_filters: - self._resolved_tools = all_tools - logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。") - else: - logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。") - - return all_tools - - async def get_function_tools( - self, names: list[str] | None = None - ) -> dict[str, ToolExecutable]: - """ - 仅从内置的函数提供者中解析指定的工具。 - """ - all_function_tools = await self._builtin_function_provider.discover_tools() - if names is None: - return all_function_tools - - resolved_tools = {} - for name in names: - if name in all_function_tools: - resolved_tools[name] = all_function_tools[name] - else: - logger.warning( - f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。" - ) - return resolved_tools - - -tool_provider_manager = ToolProviderManager() diff --git a/zhenxun/services/llm/types/__init__.py b/zhenxun/services/llm/types/__init__.py index 183e71db..4a16924a 100644 --- a/zhenxun/services/llm/types/__init__.py +++ b/zhenxun/services/llm/types/__init__.py @@ -5,30 +5,32 @@ LLM 类型定义模块 """ from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities -from .content import ( - LLMContentPart, - LLMMessage, - LLMResponse, -) -from .enums import ( - EmbeddingTaskType, - ModelProvider, - ResponseFormat, - TaskType, - ToolCategory, -) from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message from .models import ( + CodeExecutionOutcome, + EmbeddingTaskType, + GeminiCodeExecution, + GeminiGoogleSearch, + GeminiUrlContext, LLMCacheInfo, LLMCodeExecution, + LLMContentPart, LLMGroundingAttribution, LLMGroundingMetadata, + LLMMessage, + LLMResponse, LLMToolCall, LLMToolFunction, ModelDetail, ModelInfo, ModelName, + ModelProvider, ProviderConfig, + ResponseFormat, + StructuredOutputStrategy, + TaskType, + ToolCategory, + ToolChoice, ToolMetadata, ToolResult, UsageInfo, @@ -36,7 +38,11 @@ from .models import ( from .protocols import ToolExecutable, ToolProvider __all__ = [ + "CodeExecutionOutcome", "EmbeddingTaskType", + "GeminiCodeExecution", + "GeminiGoogleSearch", + 
"GeminiUrlContext", "LLMCacheInfo", "LLMCodeExecution", "LLMContentPart", @@ -56,8 +62,10 @@ __all__ = [ "ModelProvider", "ProviderConfig", "ResponseFormat", + "StructuredOutputStrategy", "TaskType", "ToolCategory", + "ToolChoice", "ToolExecutable", "ToolMetadata", "ToolProvider", diff --git a/zhenxun/services/llm/types/capabilities.py b/zhenxun/services/llm/types/capabilities.py index 2e083708..4e7ccd09 100644 --- a/zhenxun/services/llm/types/capabilities.py +++ b/zhenxun/services/llm/types/capabilities.py @@ -6,6 +6,7 @@ LLM 模型能力定义模块 from enum import Enum import fnmatch +from typing import Literal from pydantic import BaseModel, Field @@ -20,6 +21,35 @@ class ModelModality(str, Enum): EMBEDDING = "embedding" +class ReasoningMode(str, Enum): + """推理/思考模式类型""" + + NONE = "none" + BUDGET = "budget" + LEVEL = "level" + EFFORT = "effort" + + +PATTERNS_GEMINI_2_5 = [ + "gemini-2.5*", + "gemini-flash*", + "gemini*lite*", + "gemini-flash-latest", +] + +PATTERNS_GEMINI_3 = [ + "gemini-3*", + "gemini-exp*", +] + +PATTERNS_OPENAI_REASONING = [ + "o1-*", + "o3-*", + "deepseek-r1*", + "deepseek-reasoner", +] + + class ModelCapabilities(BaseModel): """定义一个模型的核心、稳定能力。""" @@ -27,6 +57,8 @@ class ModelCapabilities(BaseModel): output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT}) supports_tool_calling: bool = False is_embedding_model: bool = False + reasoning_mode: ReasoningMode = ReasoningMode.NONE + reasoning_visibility: Literal["visible", "hidden", "none"] = "none" STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities( @@ -35,7 +67,7 @@ STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities( supports_tool_calling=True, ) -GEMINI_CAPABILITIES = ModelCapabilities( +CAP_GEMINI_2_5 = ModelCapabilities( input_modalities={ ModelModality.TEXT, ModelModality.IMAGE, @@ -44,21 +76,44 @@ GEMINI_CAPABILITIES = ModelCapabilities( }, output_modalities={ModelModality.TEXT}, supports_tool_calling=True, + reasoning_mode=ReasoningMode.BUDGET, + reasoning_visibility="visible", ) -GEMINI_IMAGE_GEN_CAPABILITIES = ModelCapabilities( +CAP_GEMINI_3 = ModelCapabilities( + input_modalities={ + ModelModality.TEXT, + ModelModality.IMAGE, + ModelModality.AUDIO, + ModelModality.VIDEO, + }, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, + reasoning_mode=ReasoningMode.LEVEL, + reasoning_visibility="visible", +) + +CAP_GEMINI_IMAGE_GEN = ModelCapabilities( input_modalities={ModelModality.TEXT, ModelModality.IMAGE}, output_modalities={ModelModality.TEXT, ModelModality.IMAGE}, supports_tool_calling=True, ) -GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES = ModelCapabilities( +CAP_OPENAI_REASONING = ModelCapabilities( + input_modalities={ModelModality.TEXT, ModelModality.IMAGE}, + output_modalities={ModelModality.TEXT}, + supports_tool_calling=True, + reasoning_mode=ReasoningMode.EFFORT, + reasoning_visibility="hidden", +) + +CAP_GPT_ADVANCED = ModelCapabilities( input_modalities={ModelModality.TEXT, ModelModality.IMAGE}, output_modalities={ModelModality.TEXT}, supports_tool_calling=True, ) -GPT_MULTIMODAL_IO_CAPABILITIES = ModelCapabilities( +CAP_GPT_MULTIMODAL_IO = ModelCapabilities( input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.IMAGE}, output_modalities={ModelModality.TEXT, ModelModality.AUDIO}, supports_tool_calling=True, @@ -76,6 +131,12 @@ GPT_VIDEO_GENERATION_CAPABILITIES = ModelCapabilities( supports_tool_calling=True, ) +EMBEDDING_CAPABILITIES = ModelCapabilities( + input_modalities={ModelModality.TEXT}, + output_modalities={ModelModality.EMBEDDING}, + 
is_embedding_model=True, +) + DEFAULT_PERMISSIVE_CAPABILITIES = ModelCapabilities( input_modalities={ ModelModality.TEXT, @@ -107,17 +168,33 @@ MODEL_ALIAS_MAPPING: dict[str, str] = { } -MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = { - "gemini-*-tts": ModelCapabilities( +def _build_registry() -> dict[str, ModelCapabilities]: + """构建模型能力注册表,展开模式列表以减少冗余""" + registry: dict[str, ModelCapabilities] = {} + + def register_family(patterns: list[str], cap: ModelCapabilities) -> None: + for pattern in patterns: + registry[pattern] = cap + + register_family( + ["*gemini-*-image-preview*", "gemini-*-image*"], CAP_GEMINI_IMAGE_GEN + ) + + register_family(PATTERNS_GEMINI_2_5, CAP_GEMINI_2_5) + register_family(PATTERNS_GEMINI_3, CAP_GEMINI_3) + + register_family(PATTERNS_OPENAI_REASONING, CAP_OPENAI_REASONING) + + registry["gemini-*-tts"] = ModelCapabilities( input_modalities={ModelModality.TEXT}, output_modalities={ModelModality.AUDIO}, - ), - "gemini-*-native-audio-*": ModelCapabilities( + ) + registry["gemini-*-native-audio-*"] = ModelCapabilities( input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO}, output_modalities={ModelModality.TEXT, ModelModality.AUDIO}, supports_tool_calling=True, - ), - "gemini-2.0-flash-preview-image-generation": ModelCapabilities( + ) + registry["gemini-2.0-flash-preview-image-generation"] = ModelCapabilities( input_modalities={ ModelModality.TEXT, ModelModality.IMAGE, @@ -126,39 +203,39 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = { }, output_modalities={ModelModality.TEXT, ModelModality.IMAGE}, supports_tool_calling=True, - ), - "gemini-embedding-exp": ModelCapabilities( - input_modalities={ModelModality.TEXT}, - output_modalities={ModelModality.EMBEDDING}, - is_embedding_model=True, - ), - "*gemini-*-image-preview*": GEMINI_IMAGE_GEN_CAPABILITIES, - "gemini-*-pro*": GEMINI_CAPABILITIES, - "gemini-*-flash*": GEMINI_CAPABILITIES, - "GLM-4V-Flash": ModelCapabilities( + ) + + registry["GLM-4V-Flash"] = ModelCapabilities( input_modalities={ModelModality.TEXT, ModelModality.IMAGE}, output_modalities={ModelModality.TEXT}, supports_tool_calling=True, - ), - "GLM-4V-Plus*": ModelCapabilities( + ) + registry["GLM-4V-Plus*"] = ModelCapabilities( input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO}, output_modalities={ModelModality.TEXT}, supports_tool_calling=True, - ), - "glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES, - "glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES, - "doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES, - "doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES, - "deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES, - "deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES, - "gpt-5*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES, - "gpt-4.1*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES, - "gpt-4o*": GPT_MULTIMODAL_IO_CAPABILITIES, - "o3*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES, - "o4-mini*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES, - "gpt image*": GPT_IMAGE_GENERATION_CAPABILITIES, - "sora*": GPT_VIDEO_GENERATION_CAPABILITIES, -} + ) + + register_family( + ["glm-4-*", "glm-z1-*", "deepseek-chat"], STANDARD_TEXT_TOOL_CAPABILITIES + ) + register_family( + ["doubao-seed-*", "doubao-1-5-thinking-vision-pro"], + DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES, + ) + + register_family(["gpt-5*", "gpt-4.1*", "o4-mini*"], CAP_GPT_ADVANCED) + registry["gpt-4o*"] = CAP_GPT_MULTIMODAL_IO + + registry["gpt image*"] = GPT_IMAGE_GENERATION_CAPABILITIES + registry["sora*"] = 
GPT_VIDEO_GENERATION_CAPABILITIES + + registry["*embedding*"] = EMBEDDING_CAPABILITIES + + return registry + + +MODEL_CAPABILITIES_REGISTRY = _build_registry() def get_model_capabilities(model_name: str) -> ModelCapabilities: diff --git a/zhenxun/services/llm/types/content.py b/zhenxun/services/llm/types/content.py deleted file mode 100644 index 2ceee5a4..00000000 --- a/zhenxun/services/llm/types/content.py +++ /dev/null @@ -1,434 +0,0 @@ -""" -LLM 内容类型定义 - -包含多模态内容部分、消息和响应的数据模型。 -""" - -import base64 -import mimetypes -from pathlib import Path -from typing import Any - -import aiofiles -from pydantic import BaseModel - -from zhenxun.services.log import logger - - -class LLMContentPart(BaseModel): - """LLM 消息内容部分 - 支持多模态内容""" - - type: str - text: str | None = None - image_source: str | None = None - audio_source: str | None = None - video_source: str | None = None - document_source: str | None = None - file_uri: str | None = None - file_source: str | None = None - url: str | None = None - mime_type: str | None = None - metadata: dict[str, Any] | None = None - - def model_post_init(self, /, __context: Any) -> None: - """验证内容部分的有效性""" - _ = __context - validation_rules = { - "text": lambda: self.text, - "image": lambda: self.image_source, - "audio": lambda: self.audio_source, - "video": lambda: self.video_source, - "document": lambda: self.document_source, - "file": lambda: self.file_uri or self.file_source, - "url": lambda: self.url, - } - - if self.type in validation_rules: - if not validation_rules[self.type](): - raise ValueError(f"{self.type}类型的内容部分必须包含相应字段") - - @classmethod - def text_part(cls, text: str) -> "LLMContentPart": - """创建文本内容部分""" - return cls(type="text", text=text) - - @classmethod - def image_url_part(cls, url: str) -> "LLMContentPart": - """创建图片URL内容部分""" - return cls(type="image", image_source=url) - - @classmethod - def image_base64_part( - cls, data: str, mime_type: str = "image/png" - ) -> "LLMContentPart": - """创建Base64图片内容部分""" - data_url = f"data:{mime_type};base64,{data}" - return cls(type="image", image_source=data_url) - - @classmethod - def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart": - """创建音频URL内容部分""" - return cls(type="audio", audio_source=url, mime_type=mime_type) - - @classmethod - def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart": - """创建视频URL内容部分""" - return cls(type="video", video_source=url, mime_type=mime_type) - - @classmethod - def video_base64_part( - cls, data: str, mime_type: str = "video/mp4" - ) -> "LLMContentPart": - """创建Base64视频内容部分""" - data_url = f"data:{mime_type};base64,{data}" - return cls(type="video", video_source=data_url, mime_type=mime_type) - - @classmethod - def audio_base64_part( - cls, data: str, mime_type: str = "audio/wav" - ) -> "LLMContentPart": - """创建Base64音频内容部分""" - data_url = f"data:{mime_type};base64,{data}" - return cls(type="audio", audio_source=data_url, mime_type=mime_type) - - @classmethod - def file_uri_part( - cls, - file_uri: str, - mime_type: str | None = None, - metadata: dict[str, Any] | None = None, - ) -> "LLMContentPart": - """创建Gemini File API URI内容部分""" - return cls( - type="file", - file_uri=file_uri, - mime_type=mime_type, - metadata=metadata or {}, - ) - - @classmethod - async def from_path( - cls, path_like: str | Path, target_api: str | None = None - ) -> "LLMContentPart | None": - """ - 从本地文件路径创建 LLMContentPart。 - 自动检测MIME类型,并根据类型(如图片)可能加载为Base64。 - target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64) - """ - try: - path = 
Path(path_like) - if not path.exists() or not path.is_file(): - logger.warning(f"文件不存在或不是一个文件: {path}") - return None - - mime_type, _ = mimetypes.guess_type(path.resolve().as_uri()) - - if not mime_type: - logger.warning( - f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。" - ) - try: - async with aiofiles.open(path, encoding="utf-8") as f: - text_content = await f.read() - return cls.text_part(text_content) - except Exception as e: - logger.error(f"读取文本文件 {path.name} 失败: {e}") - return None - - if mime_type.startswith("image/"): - if target_api == "gemini" or not path.is_absolute(): - try: - async with aiofiles.open(path, "rb") as f: - img_bytes = await f.read() - base64_data = base64.b64encode(img_bytes).decode("utf-8") - return cls.image_base64_part( - data=base64_data, mime_type=mime_type - ) - except Exception as e: - logger.error(f"读取或编码图片文件 {path.name} 失败: {e}") - return None - else: - logger.warning( - f"为本地图片路径 {path.name} 生成 image_url_part。" - "实际API可能不支持 file:// URI。考虑使用Base64或公网URL。" - ) - return cls.image_url_part(url=path.resolve().as_uri()) - elif mime_type.startswith("audio/"): - return cls.audio_url_part( - url=path.resolve().as_uri(), mime_type=mime_type - ) - elif mime_type.startswith("video/"): - if target_api == "gemini": - # 对于 Gemini API,将视频转换为 base64 - try: - async with aiofiles.open(path, "rb") as f: - video_bytes = await f.read() - base64_data = base64.b64encode(video_bytes).decode("utf-8") - return cls.video_base64_part( - data=base64_data, mime_type=mime_type - ) - except Exception as e: - logger.error(f"读取或编码视频文件 {path.name} 失败: {e}") - return None - else: - return cls.video_url_part( - url=path.resolve().as_uri(), mime_type=mime_type - ) - elif ( - mime_type.startswith("text/") - or mime_type == "application/json" - or mime_type == "application/xml" - ): - try: - async with aiofiles.open(path, encoding="utf-8") as f: - text_content = await f.read() - return cls.text_part(text_content) - except Exception as e: - logger.error(f"读取文本类文件 {path.name} 失败: {e}") - return None - else: - logger.info( - f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。" - ) - return cls.file_uri_part( - file_uri=path.resolve().as_uri(), - mime_type=mime_type, - metadata={"name": path.name, "source": "local_path"}, - ) - - except Exception as e: - logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}") - return None - - def is_image_url(self) -> bool: - """检查图像源是否为URL""" - if not self.image_source: - return False - return self.image_source.startswith(("http://", "https://")) - - def is_image_base64(self) -> bool: - """检查图像源是否为Base64 Data URL""" - if not self.image_source: - return False - return self.image_source.startswith("data:") - - def get_base64_data(self) -> tuple[str, str] | None: - """从Data URL中提取Base64数据和MIME类型""" - if not self.is_image_base64() or not self.image_source: - return None - - try: - header, data = self.image_source.split(",", 1) - mime_part = header.split(";")[0].replace("data:", "") - return mime_part, data - except (ValueError, IndexError): - logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...") - return None - - async def convert_for_api_async(self, api_type: str) -> dict[str, Any]: - """根据API类型转换多模态内容格式""" - from zhenxun.utils.http_utils import AsyncHttpx - - if self.type == "text": - if api_type == "openai": - return {"type": "text", "text": self.text} - elif api_type == "gemini": - return {"text": self.text} - else: - return {"type": "text", "text": self.text} - - elif self.type == "image": - if not self.image_source: - raise 
ValueError("图像类型的内容必须包含image_source") - - if api_type == "openai": - return {"type": "image_url", "image_url": {"url": self.image_source}} - elif api_type == "gemini": - if self.is_image_base64(): - base64_info = self.get_base64_data() - if base64_info: - mime_type, data = base64_info - return {"inlineData": {"mimeType": mime_type, "data": data}} - else: - raise ValueError( - f"无法解析Base64图像数据: {self.image_source[:50]}..." - ) - elif self.is_image_url(): - logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}") - try: - image_bytes = await AsyncHttpx.get_content(self.image_source) - mime_type = self.mime_type or "image/jpeg" - base64_data = base64.b64encode(image_bytes).decode("utf-8") - return { - "inlineData": {"mimeType": mime_type, "data": base64_data} - } - except Exception as e: - logger.error(f"下载或编码URL图片失败: {e}", e=e) - raise ValueError(f"无法处理图片URL: {e}") - else: - raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...") - else: - return {"type": "image_url", "image_url": {"url": self.image_source}} - - elif self.type == "video": - if not self.video_source: - raise ValueError("视频类型的内容必须包含video_source") - - if api_type == "gemini": - # Gemini 支持视频,但需要通过 File API 上传 - if self.video_source.startswith("data:"): - # 处理 base64 视频数据 - try: - header, data = self.video_source.split(",", 1) - mime_type = header.split(";")[0].replace("data:", "") - return {"inlineData": {"mimeType": mime_type, "data": data}} - except (ValueError, IndexError): - raise ValueError( - f"无法解析Base64视频数据: {self.video_source[:50]}..." - ) - else: - # 对于 URL 或其他格式,暂时不支持直接内联 - raise ValueError( - "Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL" - ) - else: - # 其他 API 可能不支持视频 - raise ValueError(f"API类型 '{api_type}' 不支持视频内容") - - elif self.type == "audio": - if not self.audio_source: - raise ValueError("音频类型的内容必须包含audio_source") - - if api_type == "gemini": - # Gemini 支持音频,处理方式类似视频 - if self.audio_source.startswith("data:"): - try: - header, data = self.audio_source.split(",", 1) - mime_type = header.split(";")[0].replace("data:", "") - return {"inlineData": {"mimeType": mime_type, "data": data}} - except (ValueError, IndexError): - raise ValueError( - f"无法解析Base64音频数据: {self.audio_source[:50]}..." - ) - else: - raise ValueError( - "Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL" - ) - else: - raise ValueError(f"API类型 '{api_type}' 不支持音频内容") - - elif self.type == "file": - if api_type == "gemini" and self.file_uri: - return { - "fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri} - } - elif self.file_source: - file_name = ( - self.metadata.get("name", "file") if self.metadata else "file" - ) - if api_type == "gemini": - return {"text": f"[文件: {file_name}]\n{self.file_source}"} - else: - return { - "type": "text", - "text": f"[文件: {file_name}]\n{self.file_source}", - } - else: - raise ValueError("文件类型的内容必须包含file_uri或file_source") - - else: - raise ValueError(f"不支持的内容类型: {self.type}") - - -class LLMMessage(BaseModel): - """LLM 消息""" - - role: str - content: str | list[LLMContentPart] - name: str | None = None - tool_calls: list[Any] | None = None - tool_call_id: str | None = None - - def model_post_init(self, /, __context: Any) -> None: - """验证消息的有效性""" - _ = __context - if self.role == "tool": - if not self.tool_call_id: - raise ValueError("工具角色的消息必须包含 tool_call_id") - if not self.name: - raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)") - if self.role == "tool" and not isinstance(self.content, str): - logger.warning( - f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 
" - "将尝试转换为字符串。" - ) - try: - self.content = str(self.content) - except Exception as e: - raise ValueError(f"无法将工具角色的内容转换为字符串: {e}") - - @classmethod - def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage": - """创建用户消息""" - return cls(role="user", content=content) - - @classmethod - def assistant_tool_calls( - cls, - tool_calls: list[Any], - content: str | list[LLMContentPart] = "", - ) -> "LLMMessage": - """创建助手请求工具调用的消息""" - return cls(role="assistant", content=content, tool_calls=tool_calls) - - @classmethod - def assistant_text_response( - cls, content: str | list[LLMContentPart] - ) -> "LLMMessage": - """创建助手纯文本回复的消息""" - return cls(role="assistant", content=content, tool_calls=None) - - @classmethod - def tool_response( - cls, - tool_call_id: str, - function_name: str, - result: Any, - ) -> "LLMMessage": - """创建工具执行结果的消息""" - import json - - try: - content_str = json.dumps(result) - except TypeError as e: - logger.error( - f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}" - ) - content_str = json.dumps( - {"error": "工具结果无法JSON序列化", "details": str(e)} - ) - - return cls( - role="tool", - content=content_str, - tool_call_id=tool_call_id, - name=function_name, - ) - - @classmethod - def system(cls, content: str) -> "LLMMessage": - """创建系统消息""" - return cls(role="system", content=content) - - -class LLMResponse(BaseModel): - """LLM 响应""" - - text: str - images: list[bytes] | None = None - usage_info: dict[str, Any] | None = None - raw_response: dict[str, Any] | None = None - tool_calls: list[Any] | None = None - code_executions: list[Any] | None = None - grounding_metadata: Any | None = None - cache_info: Any | None = None diff --git a/zhenxun/services/llm/types/enums.py b/zhenxun/services/llm/types/enums.py deleted file mode 100644 index 82cb49b0..00000000 --- a/zhenxun/services/llm/types/enums.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -LLM 枚举类型定义 -""" - -from enum import Enum, auto - - -class ModelProvider(Enum): - """模型提供商枚举""" - - OPENAI = "openai" - GEMINI = "gemini" - ZHIXPU = "zhipu" - CUSTOM = "custom" - - -class ResponseFormat(Enum): - """响应格式枚举""" - - TEXT = "text" - JSON = "json" - MULTIMODAL = "multimodal" - - -class EmbeddingTaskType(str, Enum): - """文本嵌入任务类型 (主要用于Gemini)""" - - RETRIEVAL_QUERY = "RETRIEVAL_QUERY" - RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT" - SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY" - CLASSIFICATION = "CLASSIFICATION" - CLUSTERING = "CLUSTERING" - QUESTION_ANSWERING = "QUESTION_ANSWERING" - FACT_VERIFICATION = "FACT_VERIFICATION" - - -class ToolCategory(Enum): - """工具分类枚举""" - - FILE_SYSTEM = auto() - NETWORK = auto() - SYSTEM_INFO = auto() - CALCULATION = auto() - DATA_PROCESSING = auto() - CUSTOM = auto() - - -class TaskType(Enum): - """任务类型枚举""" - - CHAT = "chat" - CODE = "code" - SEARCH = "search" - ANALYSIS = "analysis" - GENERATION = "generation" - MULTIMODAL = "multimodal" - - -class LLMErrorCode(Enum): - """LLM 服务相关的错误代码枚举""" - - MODEL_INIT_FAILED = 2000 - MODEL_NOT_FOUND = 2001 - API_REQUEST_FAILED = 2002 - API_RESPONSE_INVALID = 2003 - API_KEY_INVALID = 2004 - API_QUOTA_EXCEEDED = 2005 - API_TIMEOUT = 2006 - API_RATE_LIMITED = 2007 - NO_AVAILABLE_KEYS = 2008 - UNKNOWN_API_TYPE = 2009 - CONFIGURATION_ERROR = 2010 - RESPONSE_PARSE_ERROR = 2011 - CONTEXT_LENGTH_EXCEEDED = 2012 - CONTENT_FILTERED = 2013 - USER_LOCATION_NOT_SUPPORTED = 2014 - GENERATION_FAILED = 2015 - EMBEDDING_FAILED = 2016 diff --git a/zhenxun/services/llm/types/exceptions.py b/zhenxun/services/llm/types/exceptions.py index 623d4c26..4168adf3 100644 --- 
a/zhenxun/services/llm/types/exceptions.py +++ b/zhenxun/services/llm/types/exceptions.py @@ -2,9 +2,31 @@ LLM 异常类型定义 """ +from enum import Enum from typing import Any -from .enums import LLMErrorCode + +class LLMErrorCode(Enum): + """LLM 服务相关的错误代码枚举""" + + MODEL_INIT_FAILED = 2000 + MODEL_NOT_FOUND = 2001 + API_REQUEST_FAILED = 2002 + API_RESPONSE_INVALID = 2003 + API_KEY_INVALID = 2004 + API_QUOTA_EXCEEDED = 2005 + API_TIMEOUT = 2006 + API_RATE_LIMITED = 2007 + NO_AVAILABLE_KEYS = 2008 + UNKNOWN_API_TYPE = 2009 + CONFIGURATION_ERROR = 2010 + RESPONSE_PARSE_ERROR = 2011 + CONTEXT_LENGTH_EXCEEDED = 2012 + CONTENT_FILTERED = 2013 + USER_LOCATION_NOT_SUPPORTED = 2014 + INVALID_PARAMETER = 2017 + GENERATION_FAILED = 2015 + EMBEDDING_FAILED = 2016 class LLMException(Exception): @@ -27,7 +49,11 @@ class LLMException(Exception): def __str__(self) -> str: if self.details: - return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})" + safe_details = {k: v for k, v in self.details.items() if k != "api_key"} + if safe_details: + return ( + f"{self.message} (错误码: {self.code.name}, 详情: {safe_details})" + ) return f"{self.message} (错误码: {self.code.name})" @property @@ -46,10 +72,13 @@ class LLMException(Exception): "当前所有API密钥均不可用,请稍后再试或联系管理员。" ), LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: ( - "当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。" + "当前网络环境不支持此 AI 模型 (如 Gemini/OpenAI)。\n" + "原因: 代理节点所在地区(如香港/国内/非支持区)被服务商屏蔽。\n" + "建议: 请尝试更换代理节点至支持的地区(如美国/日本/新加坡)。" ), LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败,请稍后再试。", LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常,请稍后再试。", + LLMErrorCode.INVALID_PARAMETER: "请求参数错误,请检查输入内容。", LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误,请联系管理员。", LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。", LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。", @@ -66,15 +95,19 @@ def get_user_friendly_error_message(error: Exception) -> str: error_str = str(error).lower() - if "timeout" in error_str or "超时" in error_str: - return "请求超时,请稍后再试。" - elif "connection" in error_str or "连接" in error_str: - return "网络连接失败,请检查网络后重试。" - elif "permission" in error_str or "权限" in error_str: - return "权限不足,请联系管理员。" - elif "not found" in error_str or "未找到" in error_str: - return "请求的资源未找到,请检查配置。" - elif "invalid" in error_str or "无效" in error_str: + if "timeout" in error_str or "timed out" in error_str: + return "网络请求超时,请检查服务器网络或代理连接。" + if "connect" in error_str and ("refused" in error_str or "error" in error_str): + return "无法连接到 AI 服务商,请检查网络连接或代理设置。" + if "proxy" in error_str: + return "代理连接失败,请检查代理服务器是否正常运行。" + if "ssl" in error_str or "certificate" in error_str: + return "SSL 证书验证失败,请检查网络环境。" + if "permission" in error_str or "forbidden" in error_str: + return "权限不足,可能是 API Key 权限受限。" + if "not found" in error_str: + return "请求的资源未找到 (404),请检查模型名称或端点配置。" + if "invalid" in error_str or "无效" in error_str: return "请求参数无效,请检查输入。" - else: - return "服务暂时不可用,请稍后再试。" + + return f"服务暂时不可用 ({type(error).__name__}),请稍后再试。" diff --git a/zhenxun/services/llm/types/models.py b/zhenxun/services/llm/types/models.py index 9ff56158..e8c79078 100644 --- a/zhenxun/services/llm/types/models.py +++ b/zhenxun/services/llm/types/models.py @@ -4,12 +4,459 @@ LLM 数据模型定义 包含模型信息、配置、工具定义和响应数据的模型类。 """ +import base64 from dataclasses import dataclass, field -from typing import Any +from enum import Enum, auto +import mimetypes +from pathlib import Path +import sys +from typing import Any, Literal +import aiofiles from pydantic import BaseModel, Field -from .enums import ModelProvider, ToolCategory +from zhenxun.services.log import logger 
+ +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + + +class ModelProvider(Enum): + """模型提供商枚举""" + + OPENAI = "openai" + GEMINI = "gemini" + ZHIXPU = "zhipu" + CUSTOM = "custom" + + +class ResponseFormat(Enum): + """响应格式枚举""" + + TEXT = "text" + JSON = "json" + MULTIMODAL = "multimodal" + + +class StructuredOutputStrategy(str, Enum): + """结构化输出策略""" + + NATIVE = "native" + """使用原生 API (如 OpenAI json_object/json_schema, Gemini mime_type)""" + TOOL_CALL = "tool_call" + """构造虚假工具调用来强制输出结构化数据 (适用于指令跟随弱但工具调用强的模型)""" + PROMPT = "prompt" + """仅在 Prompt 中追加 Schema 说明,依赖文本补全""" + + +class EmbeddingTaskType(str, Enum): + """文本嵌入任务类型 (主要用于Gemini)""" + + RETRIEVAL_QUERY = "RETRIEVAL_QUERY" + RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT" + SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY" + CLASSIFICATION = "CLASSIFICATION" + CLUSTERING = "CLUSTERING" + QUESTION_ANSWERING = "QUESTION_ANSWERING" + FACT_VERIFICATION = "FACT_VERIFICATION" + + +class ToolCategory(Enum): + """工具分类枚举""" + + FILE_SYSTEM = auto() + NETWORK = auto() + SYSTEM_INFO = auto() + CALCULATION = auto() + DATA_PROCESSING = auto() + CUSTOM = auto() + + +class CodeExecutionOutcome(StrEnum): + """代码执行结果状态枚举""" + + OUTCOME_OK = "OUTCOME_OK" + OUTCOME_FAILED = "OUTCOME_FAILED" + OUTCOME_DEADLINE_EXCEEDED = "OUTCOME_DEADLINE_EXCEEDED" + OUTCOME_COMPILATION_ERROR = "OUTCOME_COMPILATION_ERROR" + OUTCOME_RUNTIME_ERROR = "OUTCOME_RUNTIME_ERROR" + OUTCOME_UNKNOWN = "OUTCOME_UNKNOWN" + + +class TaskType(Enum): + """任务类型枚举""" + + CHAT = "chat" + CODE = "code" + SEARCH = "search" + ANALYSIS = "analysis" + GENERATION = "generation" + MULTIMODAL = "multimodal" + + +class LLMContentPart(BaseModel): + """ + LLM 消息内容部分 - 支持多模态内容。 + + 这是一个联合体模型,`type` 字段决定了哪些其他字段是有效的。 + 例如: + - type='text': 使用 `text` 字段。 + - type='image': 使用 `image_source` 字段。 + - type='executable_code': 使用 `code_language` 和 `code_content` 字段。 + """ + + type: str + text: str | None = None + image_source: str | None = None + audio_source: str | None = None + video_source: str | None = None + document_source: str | None = None + file_uri: str | None = None + file_source: str | None = None + url: str | None = None + mime_type: str | None = None + thought_text: str | None = None + media_resolution: str | None = None + code_language: str | None = None + code_content: str | None = None + execution_outcome: str | None = None + execution_output: str | None = None + metadata: dict[str, Any] | None = None + + def model_post_init(self, /, __context: Any) -> None: + """验证内容部分的有效性""" + _ = __context + validation_rules = { + "text": lambda: self.text is not None, + "image": lambda: self.image_source, + "audio": lambda: self.audio_source, + "video": lambda: self.video_source, + "document": lambda: self.document_source, + "file": lambda: self.file_uri or self.file_source, + "url": lambda: self.url, + "thought": lambda: self.thought_text, + "executable_code": lambda: self.code_content is not None, + "execution_result": lambda: self.execution_outcome is not None, + } + + if self.type in validation_rules: + if not validation_rules[self.type](): + raise ValueError(f"{self.type}类型的内容部分必须包含相应字段") + + @classmethod + def text_part(cls, text: str) -> "LLMContentPart": + """创建文本内容部分""" + return cls(type="text", text=text) + + @classmethod + def thought_part(cls, text: str) -> "LLMContentPart": + """创建思考过程内容部分""" + return cls(type="thought", thought_text=text) + + @classmethod + def image_url_part(cls, url: str) -> "LLMContentPart": + """创建图片URL内容部分""" + return 
cls(type="image", image_source=url) + + @classmethod + def image_base64_part( + cls, data: str, mime_type: str = "image/png" + ) -> "LLMContentPart": + """创建Base64图片内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="image", image_source=data_url) + + @classmethod + def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart": + """创建音频URL内容部分""" + return cls(type="audio", audio_source=url, mime_type=mime_type) + + @classmethod + def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart": + """创建视频URL内容部分""" + return cls(type="video", video_source=url, mime_type=mime_type) + + @classmethod + def video_base64_part( + cls, data: str, mime_type: str = "video/mp4" + ) -> "LLMContentPart": + """创建Base64视频内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="video", video_source=data_url, mime_type=mime_type) + + @classmethod + def audio_base64_part( + cls, data: str, mime_type: str = "audio/wav" + ) -> "LLMContentPart": + """创建Base64音频内容部分""" + data_url = f"data:{mime_type};base64,{data}" + return cls(type="audio", audio_source=data_url, mime_type=mime_type) + + @classmethod + def file_uri_part( + cls, + file_uri: str, + mime_type: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> "LLMContentPart": + """创建Gemini File API URI内容部分""" + return cls( + type="file", + file_uri=file_uri, + mime_type=mime_type, + metadata=metadata or {}, + ) + + @classmethod + def executable_code_part(cls, language: str, code: str) -> "LLMContentPart": + """创建可执行代码内容部分""" + return cls(type="executable_code", code_language=language, code_content=code) + + @classmethod + def execution_result_part( + cls, outcome: str, output: str | None + ) -> "LLMContentPart": + """创建代码执行结果部分""" + return cls( + type="execution_result", execution_outcome=outcome, execution_output=output + ) + + @classmethod + async def from_path( + cls, path_like: str | Path, target_api: str | None = None + ) -> "LLMContentPart | None": + """ + 从本地文件路径创建 LLMContentPart。 + 自动检测MIME类型,并根据类型(如图片)可能加载为Base64。 + target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64) + """ + try: + path = Path(path_like) + if not path.exists() or not path.is_file(): + logger.warning(f"文件不存在或不是一个文件: {path}") + return None + + mime_type, _ = mimetypes.guess_type(path.resolve().as_uri()) + + if not mime_type: + logger.warning( + f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。" + ) + try: + async with aiofiles.open(path, encoding="utf-8") as f: + text_content = await f.read() + return cls.text_part(text_content) + except Exception as e: + logger.error(f"读取文本文件 {path.name} 失败: {e}") + return None + + if mime_type.startswith("image/"): + if target_api == "gemini" or not path.is_absolute(): + try: + async with aiofiles.open(path, "rb") as f: + img_bytes = await f.read() + base64_data = base64.b64encode(img_bytes).decode("utf-8") + return cls.image_base64_part( + data=base64_data, mime_type=mime_type + ) + except Exception as e: + logger.error(f"读取或编码图片文件 {path.name} 失败: {e}") + return None + else: + logger.warning( + f"为本地图片路径 {path.name} 生成 image_url_part。" + "实际API可能不支持 file:// URI。考虑使用Base64或公网URL。" + ) + return cls.image_url_part(url=path.resolve().as_uri()) + elif mime_type.startswith("audio/"): + return cls.audio_url_part( + url=path.resolve().as_uri(), mime_type=mime_type + ) + elif mime_type.startswith("video/"): + if target_api == "gemini": + try: + async with aiofiles.open(path, "rb") as f: + video_bytes = await f.read() + base64_data = 
base64.b64encode(video_bytes).decode("utf-8") + return cls.video_base64_part( + data=base64_data, mime_type=mime_type + ) + except Exception as e: + logger.error(f"读取或编码视频文件 {path.name} 失败: {e}") + return None + else: + return cls.video_url_part( + url=path.resolve().as_uri(), mime_type=mime_type + ) + elif ( + mime_type.startswith("text/") + or mime_type == "application/json" + or mime_type == "application/xml" + ): + try: + async with aiofiles.open(path, encoding="utf-8") as f: + text_content = await f.read() + return cls.text_part(text_content) + except Exception as e: + logger.error(f"读取文本类文件 {path.name} 失败: {e}") + return None + else: + logger.info( + f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。" + ) + return cls.file_uri_part( + file_uri=path.resolve().as_uri(), + mime_type=mime_type, + metadata={"name": path.name, "source": "local_path"}, + ) + + except Exception as e: + logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}") + return None + + def is_image_url(self) -> bool: + """检查图像源是否为URL""" + if not self.image_source: + return False + return self.image_source.startswith(("http://", "https://")) + + def is_image_base64(self) -> bool: + """检查图像源是否为Base64 Data URL""" + if not self.image_source: + return False + return self.image_source.startswith("data:") + + def get_base64_data(self) -> tuple[str, str] | None: + """从Data URL中提取Base64数据和MIME类型""" + if not self.is_image_base64() or not self.image_source: + return None + + try: + header, data = self.image_source.split(",", 1) + mime_part = header.split(";")[0].replace("data:", "") + return mime_part, data + except (ValueError, IndexError): + logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...") + return None + + +class LLMMessage(BaseModel): + """ + LLM 消息对象,用于构建对话历史。 + + 核心字段说明: + - role: 消息角色,推荐值为 'user', 'assistant', 'system', 'tool'。 + - content: 消息内容,可以是纯文本字符串,也可以是 LLMContentPart 列表(用于多模态) + - tool_calls: (仅 assistant) 包含模型生成的工具调用请求。 + - tool_call_id: (仅 tool) 对应 tool 消息响应的调用 ID。 + - name: (仅 tool) 对应 tool 消息响应的函数名称。 + """ + + role: str + content: str | list[LLMContentPart] + name: str | None = None + tool_calls: list[Any] | None = None + tool_call_id: str | None = None + thought_signature: str | None = None + + def model_post_init(self, /, __context: Any) -> None: + """验证消息的有效性""" + _ = __context + if self.role == "tool": + if not self.tool_call_id: + raise ValueError("工具角色的消息必须包含 tool_call_id") + if not self.name: + raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)") + if self.role == "tool" and not isinstance(self.content, str): + logger.warning( + f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 
" + "将尝试转换为字符串。" + ) + try: + self.content = str(self.content) + except Exception as e: + raise ValueError(f"无法将工具角色的内容转换为字符串: {e}") + + @classmethod + def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage": + """创建用户消息""" + return cls(role="user", content=content) + + @classmethod + def assistant_tool_calls( + cls, + tool_calls: list[Any], + content: str | list[LLMContentPart] = "", + ) -> "LLMMessage": + """创建助手请求工具调用的消息""" + return cls(role="assistant", content=content, tool_calls=tool_calls) + + @classmethod + def assistant_text_response( + cls, content: str | list[LLMContentPart] + ) -> "LLMMessage": + """创建助手纯文本回复的消息""" + return cls(role="assistant", content=content, tool_calls=None) + + @classmethod + def tool_response( + cls, + tool_call_id: str, + function_name: str, + result: Any, + ) -> "LLMMessage": + """创建工具执行结果的消息""" + import json + + try: + content_str = json.dumps(result) + except TypeError as e: + logger.error( + f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}" + ) + content_str = json.dumps( + {"error": "工具结果无法JSON序列化", "details": str(e)} + ) + + return cls( + role="tool", + content=content_str, + tool_call_id=tool_call_id, + name=function_name, + ) + + @classmethod + def system(cls, content: str) -> "LLMMessage": + """创建系统消息""" + return cls(role="system", content=content) + + +class LLMResponse(BaseModel): + """ + LLM 响应对象,封装了模型生成的全部信息。 + + 核心字段说明: + - text: 模型生成的文本内容。如果是纯文本回复,此字段即为结果。 + - tool_calls: 如果模型决定调用工具,此列表包含调用详情。 + - content_parts: 包含多模态或结构化内容的原始部分列表(如思维链、代码块)。 + - raw_response: 原始的第三方 API 响应字典(用于调试)。 + - images: 如果请求涉及生图,此处包含生成的图片数据。 + """ + + text: str + content_parts: list[Any] | None = None + images: list[bytes | Path] | None = None + usage_info: dict[str, Any] | None = None + raw_response: dict[str, Any] | None = None + tool_calls: list[Any] | None = None + code_executions: list[Any] | None = None + grounding_metadata: Any | None = None + cache_info: Any | None = None + thought_text: str | None = None + thought_signature: str | None = None + ModelName = str | None @@ -26,6 +473,64 @@ class ToolDefinition(BaseModel): ) +class ToolChoice(BaseModel): + """统一的工具选择配置""" + + mode: Literal["auto", "none", "any", "required"] = Field( + default="auto", description="工具调用模式" + ) + allowed_function_names: list[str] | None = Field( + default=None, description="允许调用的函数名称列表" + ) + + +class BasePlatformTool(BaseModel): + """平台原生工具基类""" + + class Config: + extra = "forbid" + + def get_tool_declaration(self) -> dict[str, Any]: + """获取放入 'tools' 列表中的声明对象 (Snake Case)""" + raise NotImplementedError + + def get_tool_config(self) -> dict[str, Any] | None: + """获取放入 'toolConfig' 中的配置对象 (Snake Case)""" + return None + + +class GeminiCodeExecution(BasePlatformTool): + """Gemini 代码执行工具""" + + def get_tool_declaration(self) -> dict[str, Any]: + return {"code_execution": {}} + + +class GeminiGoogleSearch(BasePlatformTool): + """Gemini 谷歌搜索 (Grounding) 工具""" + + mode: Literal["MODE_DYNAMIC"] = "MODE_DYNAMIC" + dynamic_threshold: float | None = Field(default=None) + + def get_tool_declaration(self) -> dict[str, Any]: + return {"google_search": {}} + + def get_tool_config(self) -> dict[str, Any] | None: + return None + + +class GeminiUrlContext(BasePlatformTool): + """Gemini 网址上下文工具""" + + urls: list[str] = Field(..., description="作为上下文的 URL 列表", max_length=20) + + def get_tool_declaration(self) -> dict[str, Any]: + return {"google_search": {}, "url_context": {}} + + def get_tool_config(self) -> dict[str, Any] | None: + return None + + class ToolResult(BaseModel): """ 
一个结构化的工具执行结果模型。 @@ -87,6 +592,8 @@ class ModelDetail(BaseModel): is_embedding_model: bool = False temperature: float | None = None max_tokens: int | None = None + api_type: str | None = None + endpoint: str | None = None class ProviderConfig(BaseModel): @@ -116,6 +623,7 @@ class LLMToolCall(BaseModel): id: str function: LLMToolFunction + thought_signature: str | None = None class LLMCodeExecution(BaseModel): @@ -143,6 +651,12 @@ class LLMGroundingMetadata(BaseModel): web_search_queries: list[str] | None = None grounding_attributions: list[LLMGroundingAttribution] | None = None search_suggestions: list[dict[str, Any]] | None = None + search_entry_point: str | None = Field( + default=None, description="Google搜索建议的HTML片段(renderedContent)" + ) + map_widget_token: str | None = Field( + default=None, description="Google Maps 前端组件令牌" + ) class LLMCacheInfo(BaseModel): diff --git a/zhenxun/services/llm/types/protocols.py b/zhenxun/services/llm/types/protocols.py index d7e295fc..d63aedcd 100644 --- a/zhenxun/services/llm/types/protocols.py +++ b/zhenxun/services/llm/types/protocols.py @@ -2,10 +2,97 @@ LLM 模块的协议定义 """ -from typing import Any, Protocol +from abc import ABC +from typing import TYPE_CHECKING, Any, Protocol, Union + +from pydantic import BaseModel from .models import ToolDefinition, ToolResult +if TYPE_CHECKING: + from .models import LLMMessage, LLMResponse, LLMToolCall + + +class ToolCallData(BaseModel): + """传递给 on_tool_start 的数据模型""" + + tool_name: str + tool_args: dict[str, Any] + + +class ToolCallCompleteData(BaseModel): + """传递给 on_tool_call_complete 的数据模型""" + + id: str + name: str + arguments: str + result: "ToolResult" + + +class BaseCallbackHandler(ABC): + """ + Agent/LLM 生命周期回调处理器的基类。 + 下沉至 LLM 层以允许 ToolInvoker 直接调用。 + """ + + async def on_agent_start(self, messages: list["LLMMessage"], **kwargs: Any) -> None: + """在 AgentExecutor 开始运行时调用。""" + pass + + async def on_model_start( + self, model_name: str, messages: list["LLMMessage"], **kwargs: Any + ) -> None: + """在向LLM发起请求之前调用。""" + pass + + async def on_model_end( + self, response: "LLMResponse", duration: float, **kwargs: Any + ) -> None: + """在收到LLM响应之后调用。""" + pass + + async def on_tool_start( + self, tool_call: "LLMToolCall", data: ToolCallData, **kwargs: Any + ) -> Union[ToolCallData, "ToolResult", None]: + """ + 在单个工具即将被执行时调用。 + + 返回: + ToolCallData: 修改参数并继续执行 + ToolResult: 拦截执行并直接返回给模型 + None: 正常继续 + """ + pass + + async def on_tool_end( + self, + result: Union["ToolResult", None], + error: Exception | None, + tool_call: "LLMToolCall", + duration: float, + **kwargs: Any, + ) -> None: + """在单个工具执行完毕后调用,无论成功或失败。""" + pass + + async def on_tool_call_complete( + self, data: ToolCallCompleteData, **kwargs: Any + ) -> None: + """在工具调用完成并准备创建响应消息时调用。""" + pass + + async def on_human_input_request(self, query: str, **kwargs: Any) -> str | None: + """ + 当 Agent 需要人类输入时调用。 + """ + return None + + async def on_agent_end( + self, final_history: list["LLMMessage"], duration: float, **kwargs: Any + ) -> None: + """在 AgentExecutor 运行结束时调用。""" + pass + class ToolExecutable(Protocol): """ @@ -19,10 +106,14 @@ class ToolExecutable(Protocol): """ ... - async def execute(self, **kwargs: Any) -> ToolResult: + async def execute(self, context: Any | None = None, **kwargs: Any) -> ToolResult: """ 异步执行工具并返回一个结构化的结果。 参数由LLM根据工具定义生成。 + + Args: + context: 运行时上下文 (RunContext),可选注入 + **kwargs: 工具参数 """ ... 
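The new `BaseCallbackHandler.on_tool_start` contract above gives a plugin three distinct control paths: return a (possibly modified) `ToolCallData` to continue with rewritten arguments, return a `ToolResult` to intercept the call outright, or return `None` to proceed unchanged. A minimal sketch of a handler using that contract — the class name and the tool names `delete_file` / `web_search` are illustrative only, and the `ToolResult` fields follow the `output` / `display_content` usage seen elsewhere in this patch:

from typing import Any

from zhenxun.services.llm.types.models import LLMToolCall, ToolResult
from zhenxun.services.llm.types.protocols import BaseCallbackHandler, ToolCallData


class AuditHandler(BaseCallbackHandler):
    """Hypothetical audit hook for ToolInvoker-managed tool calls."""

    async def on_tool_start(
        self, tool_call: LLMToolCall, data: ToolCallData, **kwargs: Any
    ) -> ToolCallData | ToolResult | None:
        if data.tool_name == "delete_file":
            # Intercept: the tool never runs; the model sees this result instead.
            return ToolResult(
                output={"error": "blocked by audit policy"},
                display_content="tool call blocked",
            )
        if data.tool_name == "web_search":
            # Continue, but with adjusted arguments.
            data.tool_args.setdefault("max_results", 5)
            return data
        return None  # Continue unchanged.
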
diff --git a/zhenxun/services/llm/utils.py b/zhenxun/services/llm/utils.py index 624d1557..aed1dc39 100644 --- a/zhenxun/services/llm/utils.py +++ b/zhenxun/services/llm/utils.py @@ -3,26 +3,176 @@ LLM 模块的工具和转换函数 """ import base64 -import copy +from collections.abc import Awaitable, Callable +import io from pathlib import Path -from typing import Any +from typing import Any, TypeVar +import aiofiles +import json_repair from nonebot.adapters import Message as PlatformMessage +from nonebot.compat import type_validate_json from nonebot_plugin_alconna.uniseg import ( At, File, Image, Reply, + Segment, Text, UniMessage, Video, Voice, ) +from PIL.Image import Image as PILImageType +from pydantic import BaseModel, Field, ValidationError, create_model from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.pydantic_compat import model_validate -from .types import LLMContentPart, LLMMessage +from .types import LLMContentPart, LLMErrorCode, LLMException, LLMMessage +from .types.capabilities import ReasoningMode, get_model_capabilities + +T = TypeVar("T", bound=BaseModel) + + +S = TypeVar("S", bound=Segment) +_SEGMENT_HANDLERS: dict[ + type[Segment], Callable[[Any], Awaitable[LLMContentPart | None]] +] = {} + + +def register_segment_handler(seg_type: type[S]): + """装饰器:注册 Uniseg 消息段的处理器""" + + def decorator(func: Callable[[S], Awaitable[LLMContentPart | None]]): + _SEGMENT_HANDLERS[seg_type] = func + return func + + return decorator + + +async def _process_media_data(seg: Any, default_mime: str) -> tuple[str, str] | None: + """ + [内部复用] 通用媒体数据处理:获取 Base64 数据和 MIME 类型。 + 优先顺序:Raw -> Path -> URL (下载) + """ + mime_type = getattr(seg, "mimetype", None) or default_mime + b64_data = None + + if hasattr(seg, "raw") and seg.raw: + if isinstance(seg.raw, bytes): + b64_data = base64.b64encode(seg.raw).decode("utf-8") + + elif getattr(seg, "path", None): + try: + path = Path(seg.path) + if path.exists(): + async with aiofiles.open(path, "rb") as f: + content = await f.read() + b64_data = base64.b64encode(content).decode("utf-8") + except Exception as e: + logger.error(f"读取媒体文件失败: {seg.path}, 错误: {e}") + + elif getattr(seg, "url", None): + try: + logger.debug(f"检测到媒体URL,开始下载: {seg.url}") + media_bytes = await AsyncHttpx.get_content(seg.url) + b64_data = base64.b64encode(media_bytes).decode("utf-8") + logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes") + except Exception as e: + logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}") + return None + + if b64_data: + return mime_type, b64_data + return None + + +@register_segment_handler(Text) +async def _handle_text(seg: Text) -> LLMContentPart | None: + if seg.text.strip(): + return LLMContentPart.text_part(seg.text) + return None + + +@register_segment_handler(Image) +async def _handle_image(seg: Image) -> LLMContentPart | None: + media_info = await _process_media_data(seg, "image/png") + if media_info: + mime, data = media_info + return LLMContentPart.image_base64_part(data, mime) + return None + + +@register_segment_handler(Voice) +async def _handle_voice(seg: Voice) -> LLMContentPart | None: + media_info = await _process_media_data(seg, "audio/wav") + if media_info: + mime, data = media_info + return LLMContentPart.audio_base64_part(data, mime) + return LLMContentPart.text_part(f"[语音消息: {seg.id or 'unknown'}]") + + +@register_segment_handler(Video) +async def _handle_video(seg: Video) -> LLMContentPart | None: + media_info = await _process_media_data(seg, "video/mp4") + if media_info: + mime, data = media_info + 
return LLMContentPart.video_base64_part(data, mime) + return LLMContentPart.text_part(f"[视频消息: {seg.id or 'unknown'}]") + + +@register_segment_handler(File) +async def _handle_file(seg: File) -> LLMContentPart | None: + if seg.path: + return await LLMContentPart.from_path(seg.path) + return LLMContentPart.text_part(f"[文件: {seg.name} (ID: {seg.id})]") + + +@register_segment_handler(At) +async def _handle_at(seg: At) -> LLMContentPart | None: + if seg.flag == "all": + return LLMContentPart.text_part("[提及所有人]") + return LLMContentPart.text_part(f"[提及用户: {seg.target}]") + + +@register_segment_handler(Reply) +async def _handle_reply(seg: Reply) -> LLMContentPart | None: + text = str(seg.msg) if seg.msg else "" + if text: + return LLMContentPart.text_part(f'[回复消息: "{text[:50]}..."]') + return LLMContentPart.text_part("[回复了一条消息]") + + +async def _transform_to_content_part(item: Any) -> LLMContentPart: + """ + 将混合输入转换为统一的 LLMContentPart,便于 normalize_to_llm_messages 使用。 + """ + if isinstance(item, LLMContentPart): + return item + + if isinstance(item, str): + return LLMContentPart.text_part(item) + + if isinstance(item, Path): + part = await LLMContentPart.from_path(item) + if part is None: + raise ValueError(f"无法从路径加载内容: {item}") + return part + + if isinstance(item, dict): + return LLMContentPart(**item) + + if PILImageType and isinstance(item, PILImageType): + buffer = io.BytesIO() + fmt = item.format or "PNG" + item.save(buffer, format=fmt) + b64_data = base64.b64encode(buffer.getvalue()).decode("utf-8") + mime_type = f"image/{fmt.lower()}" + return LLMContentPart.image_base64_part(b64_data, mime_type) + + raise TypeError(f"不支持的输入类型用于构建 ContentPart: {type(item)}") async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]: @@ -36,110 +186,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]: 返回: list[LLMContentPart]: 转换后的内容部分列表。 """ + if not _SEGMENT_HANDLERS: + pass + parts: list[LLMContentPart] = [] for seg in message: - part = None - if isinstance(seg, Text): - if seg.text.strip(): - part = LLMContentPart.text_part(seg.text) - elif isinstance(seg, Image): - if seg.path: - part = await LLMContentPart.from_path(seg.path, target_api="gemini") - elif seg.url: - part = LLMContentPart.image_url_part(seg.url) - elif hasattr(seg, "raw") and seg.raw: - mime_type = ( - getattr(seg, "mimetype", "image/png") - if hasattr(seg, "mimetype") - else "image/png" - ) - if isinstance(seg.raw, bytes): - b64_data = base64.b64encode(seg.raw).decode("utf-8") - part = LLMContentPart.image_base64_part(b64_data, mime_type) - - elif isinstance(seg, File | Voice | Video): - if seg.path: - part = await LLMContentPart.from_path(seg.path) - elif seg.url: - try: - logger.debug(f"检测到媒体URL,开始下载: {seg.url}") - media_bytes = await AsyncHttpx.get_content(seg.url) - - new_seg = copy.copy(seg) - new_seg.raw = media_bytes - seg = new_seg - logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes") - except Exception as e: - logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}") - part = LLMContentPart.text_part( - f"[下载媒体失败: {seg.name or seg.url}]" - ) - + handler = _SEGMENT_HANDLERS.get(type(seg)) + if handler: + try: + part = await handler(seg) if part: parts.append(part) - continue - - if hasattr(seg, "raw") and seg.raw: - mime_type = getattr(seg, "mimetype", None) - if isinstance(seg.raw, bytes): - b64_data = base64.b64encode(seg.raw).decode("utf-8") - - if isinstance(seg, Video): - if not mime_type: - mime_type = "video/mp4" - part = LLMContentPart.video_base64_part( - data=b64_data, 
mime_type=mime_type - ) - logger.debug( - f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes" - ) - elif isinstance(seg, Voice): - if not mime_type: - mime_type = "audio/wav" - part = LLMContentPart.audio_base64_part( - data=b64_data, mime_type=mime_type - ) - logger.debug( - f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes" - ) - else: - part = LLMContentPart.text_part( - f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]" - ) - logger.debug( - f"处理其他文件字节数据: {mime_type}, " - f"大小: {len(seg.raw)} bytes" - ) - - elif isinstance(seg, At): - if seg.flag == "all": - part = LLMContentPart.text_part("[提及所有人]") - else: - part = LLMContentPart.text_part(f"[提及用户: {seg.target}]") - - elif isinstance(seg, Reply): - if seg.msg: - try: - extract_method = getattr(seg.msg, "extract_plain_text", None) - if extract_method and callable(extract_method): - reply_text = str(extract_method()).strip() - else: - reply_text = str(seg.msg).strip() - if reply_text: - part = LLMContentPart.text_part( - f'[回复消息: "{reply_text[:50]}..."]' - ) - except Exception: - part = LLMContentPart.text_part("[回复了一条消息]") - - if part: - parts.append(part) + except Exception as e: + logger.warning(f"处理消息段 {seg} 失败: {e}", "LLMUtils") return parts async def normalize_to_llm_messages( - message: str | UniMessage | LLMMessage | list[LLMContentPart] | list[LLMMessage], + message: str | UniMessage | LLMMessage | list[Any], instruction: str | None = None, ) -> list[LLMMessage]: """ @@ -167,7 +232,10 @@ async def normalize_to_llm_messages( content_parts = await unimsg_to_llm_parts(message) messages.append(LLMMessage.user(content_parts)) elif isinstance(message, list): - messages.append(LLMMessage.user(message)) # type: ignore + parts = [] + for item in message: + parts.append(await _transform_to_content_part(item)) + messages.append(LLMMessage.user(parts)) else: raise TypeError(f"不支持的消息类型: {type(message)}") @@ -255,53 +323,271 @@ def message_to_unimessage(message: PlatformMessage) -> UniMessage: 返回: UniMessage: 转换后的通用消息对象。 """ - uni_segments = [] - for seg in message: - if seg.type == "text": - uni_segments.append(Text(seg.data.get("text", ""))) - elif seg.type == "image": - uni_segments.append(Image(url=seg.data.get("url"))) - elif seg.type == "record": - uni_segments.append(Voice(url=seg.data.get("url"))) - elif seg.type == "video": - uni_segments.append(Video(url=seg.data.get("url"))) - elif seg.type == "at": - uni_segments.append(At("user", str(seg.data.get("qq", "")))) - else: - logger.debug(f"跳过不支持的平台消息段类型: {seg.type}") + return UniMessage.of(message) - return UniMessage(uni_segments) + +def resolve_json_schema_refs(schema: dict) -> dict: + """ + 递归解析 JSON Schema 中的 $ref,将其替换为 $defs/definitions 中的定义。 + 用于兼容不支持 $ref 的 Gemini API。 + """ + definitions = schema.get("$defs") or schema.get("definitions") or {} + + def _resolve(node: Any) -> Any: + if isinstance(node, dict): + if "$ref" in node: + ref_name = node["$ref"].split("/")[-1] + if ref_name in definitions: + return _resolve(definitions[ref_name]) + + return { + key: _resolve(value) + for key, value in node.items() + if key not in ("$defs", "definitions") + } + + if isinstance(node, list): + return [_resolve(item) for item in node] + + return node + + return _resolve(schema) def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any: """ 递归地净化 JSON Schema,移除特定 LLM API 不支持的关键字。 - - 参数: - schema: 要净化的 JSON Schema (可以是字典、列表或其它类型)。 - api_type: 目标 API 的类型,例如 'gemini'。 - - 返回: - Any: 净化后的 JSON Schema。 """ - if isinstance(schema, dict): - schema_copy = {} - for key, value in 
schema.items(): - if api_type == "gemini": - unsupported_keys = ["exclusiveMinimum", "exclusiveMaximum", "default"] - if key in unsupported_keys: - continue - - if key == "format" and isinstance(value, str): - supported_formats = ["enum", "date-time"] - if value not in supported_formats: - continue - - schema_copy[key] = sanitize_schema_for_llm(value, api_type) - return schema_copy - - elif isinstance(schema, list): + if isinstance(schema, list): return [sanitize_schema_for_llm(item, api_type) for item in schema] + if isinstance(schema, dict): + schema_copy = schema.copy() + if api_type == "gemini": + if "const" in schema_copy: + schema_copy["enum"] = [schema_copy.pop("const")] + + if "type" in schema_copy and isinstance(schema_copy["type"], list): + types_list = schema_copy["type"] + if "null" in types_list: + schema_copy["nullable"] = True + types_list = [t for t in types_list if t != "null"] + if len(types_list) == 1: + schema_copy["type"] = types_list[0] + else: + schema_copy["type"] = types_list + + if "anyOf" in schema_copy: + any_of = schema_copy["anyOf"] + has_null = any( + isinstance(x, dict) and x.get("type") == "null" for x in any_of + ) + if has_null: + schema_copy["nullable"] = True + new_any_of = [ + x + for x in any_of + if not (isinstance(x, dict) and x.get("type") == "null") + ] + if len(new_any_of) == 1: + schema_copy.update(new_any_of[0]) + schema_copy.pop("anyOf", None) + else: + schema_copy["anyOf"] = new_any_of + + unsupported_keys = [ + "exclusiveMinimum", + "exclusiveMaximum", + "default", + "title", + "additionalProperties", + "$schema", + "$id", + ] + for key in unsupported_keys: + schema_copy.pop(key, None) + + if schema_copy.get("format") and schema_copy["format"] not in [ + "enum", + "date-time", + ]: + schema_copy.pop("format", None) + + elif api_type == "openai": + unsupported_keys = [ + "default", + "minLength", + "maxLength", + "pattern", + "format", + "minimum", + "maximum", + "multipleOf", + "patternProperties", + "minItems", + "maxItems", + "uniqueItems", + "$schema", + "title", + ] + for key in unsupported_keys: + schema_copy.pop(key, None) + + if "$ref" in schema_copy: + ref_key = schema_copy["$ref"].split("/")[-1] + defs = schema_copy.get("$defs") or schema_copy.get("definitions") + if defs and ref_key in defs: + schema_copy.pop("$ref", None) + schema_copy.update(defs[ref_key]) + else: + return {"$ref": schema_copy["$ref"]} + + is_object = ( + schema_copy.get("type") == "object" or "properties" in schema_copy + ) + if is_object: + schema_copy["type"] = "object" + schema_copy["additionalProperties"] = False + + properties = schema_copy.get("properties", {}) + required = schema_copy.get("required", []) + if properties: + existing_req = set(required) + for prop in properties.keys(): + if prop not in existing_req: + required.append(prop) + schema_copy["required"] = required + + for def_key in ["$defs", "definitions"]: + if def_key in schema_copy and isinstance(schema_copy[def_key], dict): + schema_copy[def_key] = { + k: sanitize_schema_for_llm(v, api_type) + for k, v in schema_copy[def_key].items() + } + + recursive_keys = ["properties", "items", "allOf", "anyOf", "oneOf"] + for key in recursive_keys: + if key in schema_copy: + if key == "properties" and isinstance(schema_copy[key], dict): + schema_copy[key] = { + k: sanitize_schema_for_llm(v, api_type) + for k, v in schema_copy[key].items() + } + else: + schema_copy[key] = sanitize_schema_for_llm( + schema_copy[key], api_type + ) + + return schema_copy else: return schema + + +def 
extract_text_from_content(
+    content: str | list[LLMContentPart] | None,
+) -> str:
+    """
+    从消息内容中提取纯文本,自动过滤非文本部分,防止污染 Prompt。
+    """
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        return " ".join(
+            part.text for part in content if part.type == "text" and part.text
+        )
+    return str(content)
+
+
+def parse_and_validate_json(text: str, response_model: type[T]) -> T:
+    """
+    通用工具:尝试将文本解析为指定的 Pydantic 模型,并统一处理异常。
+    """
+    try:
+        return type_validate_json(response_model, text)
+    except (ValidationError, ValueError) as e:
+        try:
+            logger.warning(f"标准JSON解析失败,尝试使用json_repair修复: {e}")
+            repaired_obj = json_repair.loads(text, skip_json_loads=True)
+            return model_validate(response_model, repaired_obj)
+        except Exception as repair_error:
+            logger.error(
+                f"LLM结构化输出校验最终失败: {repair_error}",
+                e=repair_error,
+            )
+            raise LLMException(
+                "LLM返回的JSON未能通过结构验证。",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                details={
+                    "raw_response": text,
+                    "validation_error": str(repair_error),
+                    "original_error": repair_error,
+                },
+                cause=repair_error,
+            )
+    except Exception as e:
+        logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
+        raise LLMException(
+            "解析LLM的JSON输出时失败。",
+            code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+            details={"raw_response": text},
+            cause=e,
+        )
+
+
+def create_cot_wrapper(inner_model: type[BaseModel]) -> type[BaseModel]:
+    """
+    [动态运行时封装]
+    创建一个包含思维链 (Chain of Thought) 的包装模型。
+    强制模型在生成最终 JSON 结构前,先输出一个 reasoning 字段进行思考。
+    """
+    wrapper_name = f"CoT_{inner_model.__name__}"
+
+    return create_model(
+        wrapper_name,
+        reasoning=(
+            str,
+            Field(
+                ...,
+                min_length=10,
+                description=(
+                    "在生成最终结果之前,请务必在此字段中详细描述你的推理步骤、计算过程或思考逻辑。禁止留空。"
+                ),
+            ),
+        ),
+        result=(
+            inner_model,
+            Field(
+                ...,
+            ),
+        ),
+    )
+
+
+def should_apply_autocot(
+    requested: bool,
+    model_name: str | None,
+    config: Any,
+) -> bool:
+    """
+    [智能决策管道]
+    判断是否应该应用 AutoCoT (显式思维链包装)。
+    防止在模型已有原生思维能力时进行“双重思考”。
+    """
+    if not requested:
+        return False
+
+    if config:
+        thinking_budget = getattr(config, "thinking_budget", 0) or 0
+        if thinking_budget > 0:
+            return False
+        if getattr(config, "thinking_level", None) is not None:
+            return False
+
+    if model_name:
+        caps = get_model_capabilities(model_name)
+        if caps.reasoning_mode != ReasoningMode.NONE:
+            return False
+
+    return True
diff --git a/zhenxun/utils/log_sanitizer.py b/zhenxun/utils/log_sanitizer.py
index 2884da5c..9d8a5c2b 100644
--- a/zhenxun/utils/log_sanitizer.py
+++ b/zhenxun/utils/log_sanitizer.py
@@ -14,9 +14,34 @@ def _truncate_base64_string(value: str, threshold: int = 256) -> str:
     if value.startswith(prefixes) and len(value) > threshold:
         prefix = next((p for p in prefixes if value.startswith(p)), "base64")
         return f"[{prefix}_data_omitted_len={len(value)}]"
+
+    if len(value) > 2000:
+        return f"[long_string_omitted_len={len(value)}] {value[:50]}...{value[-20:]}"
+
+    if len(value) > 1000:
+        return f"[long_string_omitted_len={len(value)}] {value[:20]}...{value[-20:]}"
+
     return value
 
 
+def _truncate_vector_list(vector: list, threshold: int = 10) -> list:
+    """如果列表过长(通常是embedding向量),则截断它用于日志显示。"""
+    if isinstance(vector, list) and len(vector) > threshold:
+        return [*vector[:3], f"...({len(vector)} floats omitted)...", *vector[-3:]]
+    return vector
+
+
+def _recursive_sanitize_any(obj: Any) -> Any:
+    """递归清洗任何对象中的长字符串"""
+    if isinstance(obj, dict):
+        return {k: _recursive_sanitize_any(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [_recursive_sanitize_any(v) for v in obj]
+    elif isinstance(obj, str):
+ return _truncate_base64_string(obj) + return obj + + def _sanitize_ui_html(html_string: str) -> str: """ 专门用于净化UI渲染调试HTML的函数。 @@ -64,6 +89,37 @@ def _sanitize_openai_response(response_json: dict) -> dict: message["images"][i]["image_url"]["url"] = ( _truncate_base64_string(url) ) + if "reasoning_details" in message and isinstance( + message["reasoning_details"], list + ): + for detail in message["reasoning_details"]: + if isinstance(detail, dict): + if "data" in detail and isinstance(detail["data"], str): + if len(detail["data"]) > 100: + detail["data"] = ( + f"[encrypted_data_omitted_len={len(detail['data'])}]" + ) + if "text" in detail and isinstance(detail["text"], str): + detail["text"] = _truncate_base64_string( + detail["text"], threshold=2000 + ) + if "data" in sanitized_json and isinstance(sanitized_json["data"], list): + for item in sanitized_json["data"]: + if "embedding" in item and isinstance(item["embedding"], list): + item["embedding"] = _truncate_vector_list(item["embedding"]) + if "b64_json" in item and isinstance(item["b64_json"], str): + if len(item["b64_json"]) > 256: + item["b64_json"] = ( + f"[base64_json_omitted_len={len(item['b64_json'])}]" + ) + if "input" in sanitized_json and isinstance(sanitized_json["input"], list): + for item in sanitized_json["input"]: + if "content" in item and isinstance(item["content"], list): + for part in item["content"]: + if isinstance(part, dict) and part.get("type") == "input_image": + image_url = part.get("image_url") + if isinstance(image_url, str): + part["image_url"] = _truncate_base64_string(image_url) return sanitized_json except Exception: return response_json @@ -71,22 +127,44 @@ def _sanitize_openai_response(response_json: dict) -> dict: def _sanitize_openai_request(body: dict) -> dict: """净化OpenAI兼容API的请求体,主要截断图片base64。""" + from zhenxun.services.llm.config.providers import ( + DebugLogOptions, + get_llm_config, + ) + + debug_conf = get_llm_config().debug_log + if isinstance(debug_conf, bool): + debug_conf = DebugLogOptions( + show_tools=debug_conf, show_schema=debug_conf, show_safety=debug_conf + ) + try: - sanitized_json = copy.deepcopy(body) - if "messages" in sanitized_json and isinstance( - sanitized_json["messages"], list - ): - for message in sanitized_json["messages"]: - if "content" in message and isinstance(message["content"], list): - for i, part in enumerate(message["content"]): - if part.get("type") == "image_url": - if "image_url" in part and isinstance( - part["image_url"], dict - ): - url = part["image_url"].get("url", "") - message["content"][i]["image_url"]["url"] = ( - _truncate_base64_string(url) - ) + sanitized_json = _recursive_sanitize_any(copy.deepcopy(body)) + if "tools" in sanitized_json and not debug_conf.show_tools: + tools = sanitized_json["tools"] + if isinstance(tools, list): + tool_names = [] + for t in tools: + if isinstance(t, dict): + name = None + if "function" in t and isinstance(t["function"], dict): + name = t["function"].get("name") + if not name and "name" in t: + name = t.get("name") + tool_names.append(name or "unknown") + sanitized_json["tools"] = ( + f"<{len(tool_names)} tools hidden: {', '.join(tool_names)}>" + ) + + if "response_format" in sanitized_json and not debug_conf.show_schema: + response_format = sanitized_json["response_format"] + if isinstance(response_format, dict): + if response_format.get("type") == "json_schema": + sanitized_json["response_format"] = { + "type": "json_schema", + "json_schema": "
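The request-body sanitizer above, together with `DebugLogOptions`, decides how much of an OpenAI-compatible payload survives into the debug log. A rough usage sketch, assuming `_sanitize_openai_request` as defined in this patch and a provider config with `debug_log.show_tools` disabled; the request body below is illustrative only:

body = {
    "messages": [{"role": "user", "content": "hi"}],
    "tools": [
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}}
    ],
}

sanitized = _sanitize_openai_request(body)
# _recursive_sanitize_any runs first and truncates any long strings in the body;
# with show_tools=False the full tool schemas then collapse to a summary string:
# sanitized["tools"] == "<1 tools hidden: get_weather>"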