mirror of https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00

commit df1a2429b1: Merge 19f60d34c9 into e5b2a872d3
5578  envs/pydantic-v1/poetry.lock  generated
(file diff suppressed because it is too large)
@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
 imagehash = "^4.3.1"
 cn2an = "^0.5.22"
 dateparser = "^1.2.0"
-bilireq = ">=0.2.10"
 python-jose = { extras = ["cryptography"], version = "^3.3.0" }
 python-multipart = "^0.0.9"
 aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -47,10 +46,10 @@ nonebot-plugin-uninfo = ">=0.7.3"
 nonebot-plugin-waiter = "^0.8.1"
 multidict = ">=6.0.0,!=6.3.2"
 pydantic = ">=1.0.0, <2.0.0"

 redis = { version = ">=5", optional = true }
 asyncpg = { version = ">=0.20.0", optional = true }
 alibabacloud-devops20210625 = "^5.0.2"
+json_repair = "^0.54.0"

 [tool.poetry.group.dev.dependencies]
 nonebug = "^0.4"
5688  envs/pydantic-v2/poetry.lock  generated
(file diff suppressed because it is too large)
@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
 imagehash = "^4.3.1"
 cn2an = "^0.5.22"
 dateparser = "^1.2.0"
-bilireq = ">=0.2.10"
 python-jose = { extras = ["cryptography"], version = "^3.3.0" }
 python-multipart = "^0.0.9"
 aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -47,10 +46,10 @@ nonebot-plugin-uninfo = ">=0.7.3"
 nonebot-plugin-waiter = "^0.8.1"
 multidict = ">=6.0.0,!=6.3.2"
 pydantic = ">=2.0.0, <3.0.0"

 redis = { version = ">=5", optional = true }
 asyncpg = { version = ">=0.20.0", optional = true }
 alibabacloud-devops20210625 = "^5.0.2"
+json_repair = "^0.54.0"

 [tool.poetry.group.dev.dependencies]
 nonebug = "^0.4"
@@ -36,7 +36,6 @@ feedparser = "^6.0.11"
 imagehash = "^4.3.1"
 cn2an = "^0.5.22"
 dateparser = "^1.2.0"
-bilireq = ">=0.2.10"
 python-jose = { extras = ["cryptography"], version = "^3.3.0" }
 python-multipart = "^0.0.9"
 aiocache = {extras = ["redis"], version = "^0.12.3"}
@@ -46,6 +45,7 @@ tenacity = "^9.0.0"
 nonebot-plugin-uninfo = ">=0.7.3"
 nonebot-plugin-waiter = "^0.8.1"
 multidict = ">=6.0.0,!=6.3.2"
+json_repair = "^0.54.0"

 redis = { version = ">=5", optional = true }
 asyncpg = { version = ">=0.20.0", optional = true }
@@ -21,7 +21,6 @@ feedparser>=6.0.11,<7.0.0
 ImageHash>=4.3.1,<5.0.0
 cn2an>=0.5.22,<0.6.0
 dateparser>=1.2.0,<2.0.0
-bilireq>=0.2.10
 python-jose[cryptography]>=3.3.0,<4.0.0
 python-multipart>=0.0.9,<0.1.0
 aiocache[redis]>=0.12.3,<0.13.0
@@ -32,6 +31,6 @@ nonebot-plugin-uninfo>=0.7.3
 nonebot-plugin-waiter>=0.8.1,<0.9.0
 multidict>=6.0.0,<7.0.0,!=6.3.2
 alibabacloud-devops20210625>=5.0.2,<6.0.0
-
+json_repair>=0.54.0,<0.55.0
 redis>=5
 asyncpg>=0.20.0
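Note: json_repair is the only new runtime dependency this change introduces (bilireq is dropped). A minimal sketch of the kind of use it is presumably intended for — recovering structured output from an LLM reply — assuming the library's documented repair_json entry point; this example is illustrative and not taken from the commit:

    from json_repair import repair_json

    # Truncated, trailing-comma JSON, as LLMs often emit it.
    broken = '{"name": "zhenxun", "tags": ["bot", "llm",]'
    fixed = repair_json(broken)  # returns a best-effort valid JSON string
    print(fixed)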
@@ -9,13 +9,15 @@ from .api import (
     code,
     create_image,
     embed,
+    embed_documents,
+    embed_query,
     generate,
     generate_structured,
-    run_with_tools,
     search,
 )
 from .config import (
     CommonOverrides,
+    GenConfigBuilder,
     LLMGenerationConfig,
     register_llm_configs,
 )
@@ -32,8 +34,8 @@ from .manager import (
     list_model_identifiers,
     set_global_default_model_name,
 )
-from .session import AI, AIConfig
-from .tools import function_tool, tool_provider_manager
+from .session import AI, AIConfig, MemoryProcessor, set_default_memory_backend
+from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager
 from .types import (
     EmbeddingTaskType,
     LLMContentPart,
@@ -50,6 +52,11 @@ from .types import (
     ToolMetadata,
     UsageInfo,
 )
+from .types.models import (
+    GeminiCodeExecution,
+    GeminiGoogleSearch,
+    GeminiUrlContext,
+)
 from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_llm_parts

 __all__ = [
@@ -57,19 +64,26 @@ __all__ = [
     "AIConfig",
     "CommonOverrides",
     "EmbeddingTaskType",
+    "GeminiCodeExecution",
+    "GeminiGoogleSearch",
+    "GeminiUrlContext",
+    "GenConfigBuilder",
     "LLMContentPart",
     "LLMErrorCode",
     "LLMException",
     "LLMGenerationConfig",
     "LLMMessage",
     "LLMResponse",
+    "MemoryProcessor",
     "ModelDetail",
     "ModelInfo",
     "ModelName",
     "ModelProvider",
     "ResponseFormat",
+    "RunContext",
     "TaskType",
     "ToolCategory",
+    "ToolInvoker",
     "ToolMetadata",
     "UsageInfo",
     "chat",
@@ -78,6 +92,8 @@ __all__ = [
     "create_image",
     "create_multimodal_message",
     "embed",
+    "embed_documents",
+    "embed_query",
     "function_tool",
     "generate",
     "generate_structured",
@@ -89,8 +105,8 @@ __all__ = [
     "list_model_identifiers",
     "message_to_unimessage",
     "register_llm_configs",
-    "run_with_tools",
     "search",
+    "set_default_memory_backend",
     "set_global_default_model_name",
     "tool_provider_manager",
     "unimsg_to_llm_parts",
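Note: the public surface of zhenxun.services.llm gains embed_documents/embed_query, GenConfigBuilder, MemoryProcessor, RunContext, ToolInvoker and the Gemini built-in tool markers, while run_with_tools is removed. A hedged sketch of how the new embedding helpers might be called — only the names are confirmed by __all__ above; the argument shapes here are assumptions:

    from zhenxun.services.llm import embed_documents, embed_query

    # Hypothetical call shapes; signatures are not shown in this diff.
    doc_vectors = await embed_documents(["doc one", "doc two"])
    query_vector = await embed_query("what is zhenxun?")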
@@ -7,16 +7,18 @@ LLM adapter module
 from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
 from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
 from .gemini import GeminiAdapter
-from .openai import OpenAIAdapter
+from .openai import DeepSeekAdapter, OpenAIAdapter, OpenAIImageAdapter

 LLMAdapterFactory.initialize()

 __all__ = [
     "BaseAdapter",
+    "DeepSeekAdapter",
     "GeminiAdapter",
     "LLMAdapterFactory",
     "OpenAIAdapter",
     "OpenAICompatAdapter",
+    "OpenAIImageAdapter",
     "RequestData",
     "ResponseData",
     "get_adapter_for_api_type",
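Note: DeepSeekAdapter and OpenAIImageAdapter are now exported alongside OpenAIAdapter, and LLMAdapterFactory.initialize() registers the built-in adapters at import time. A sketch of looking one up, assuming get_adapter_for_api_type takes the api_type string that each adapter exposes (the exact registry keys are not shown in this diff):

    from zhenxun.services.llm.adapters import get_adapter_for_api_type

    adapter = get_adapter_for_api_type("openai")  # "openai" key is an assumption
    print(adapter.api_type)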
@@ -3,24 +3,26 @@ LLM adapter base classes and common data structures
 """

 from abc import ABC, abstractmethod
-import base64
-import binascii
 import json
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
+import uuid

+import httpx
 from pydantic import BaseModel

+from zhenxun.configs.path_config import TEMP_PATH
 from zhenxun.services.log import logger

+from ..types import LLMContentPart
 from ..types.exceptions import LLMErrorCode, LLMException
 from ..types.models import LLMToolCall

 if TYPE_CHECKING:
-    from ..config.generation import LLMGenerationConfig
+    from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
     from ..service import LLMModel
-    from ..types.content import LLMMessage
-    from ..types.enums import EmbeddingTaskType
-    from ..types.protocols import ToolExecutable
+    from ..types import LLMMessage
+    from ..types.models import ToolChoice


 class RequestData(BaseModel):
@@ -29,19 +31,23 @@ class RequestData(BaseModel):
     url: str
     headers: dict[str, str]
     body: dict[str, Any]
+    files: dict[str, Any] | list[tuple[str, Any]] | None = None


 class ResponseData(BaseModel):
     """Response data wrapper - supports all advanced features"""

     text: str
-    images: list[bytes] | None = None
+    content_parts: list[LLMContentPart] | None = None
+    images: list[bytes | Path] | None = None
     usage_info: dict[str, Any] | None = None
     raw_response: dict[str, Any] | None = None
     tool_calls: list[LLMToolCall] | None = None
     code_executions: list[Any] | None = None
     grounding_metadata: Any | None = None
     cache_info: Any | None = None
+    thought_text: str | None = None
+    thought_signature: str | None = None

     code_execution_results: list[dict[str, Any]] | None = None
     search_results: list[dict[str, Any]] | None = None
@@ -50,9 +56,33 @@
     citations: list[dict[str, Any]] | None = None


+def process_image_data(image_data: bytes) -> bytes | Path:
+    """
+    Process image data: if it exceeds 2 MB, save it to the temp directory to avoid holding it in memory.
+    """
+    max_inline_size = 2 * 1024 * 1024
+    if len(image_data) > max_inline_size:
+        save_dir = TEMP_PATH / "llm"
+        save_dir.mkdir(parents=True, exist_ok=True)
+        file_name = f"{uuid.uuid4()}.png"
+        file_path = save_dir / file_name
+        file_path.write_bytes(image_data)
+        logger.info(
+            f"Image data too large ({len(image_data)} bytes); saved to temp file: {file_path}",
+            "LLMAdapter",
+        )
+        return file_path.resolve()
+    return image_data


 class BaseAdapter(ABC):
     """Base class for LLM API adapters"""

+    @property
+    def log_sanitization_context(self) -> str:
+        """Context name for log sanitization; defaults to 'default'"""
+        return "default"
+
     @property
     @abstractmethod
     def api_type(self) -> str:
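Note: process_image_data is the new module-level helper shared by the response parsers in the components package below: payloads up to 2 MB stay in memory as bytes, while larger ones are written under TEMP_PATH / "llm" and returned as a resolved Path — which is why ResponseData.images is now list[bytes | Path]. Its behaviour follows directly from the definition above:

    small = process_image_data(b"\x89PNG...")        # <= 2 MB: returned unchanged as bytes
    large = process_image_data(b"\x00" * (3 << 20))  # > 2 MB: written to disk, Path returned
    assert isinstance(large, Path)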
@@ -77,7 +107,7 @@
         Default implementation: converts a simple request into the advanced request format.
         Subclasses may override this method to provide specific optimized implementations.
         """
-        from ..types.content import LLMMessage
+        from ..types import LLMMessage

         messages: list[LLMMessage] = []

@@ -107,8 +137,8 @@
         api_key: str,
         messages: list["LLMMessage"],
         config: "LLMGenerationConfig | None" = None,
-        tools: dict[str, "ToolExecutable"] | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[Any] | None = None,
+        tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
     ) -> RequestData:
         """Prepare an advanced request"""
         pass
@@ -129,8 +159,7 @@
         model: "LLMModel",
         api_key: str,
         texts: list[str],
-        task_type: "EmbeddingTaskType | str",
-        **kwargs: Any,
+        config: "LLMEmbeddingConfig",
     ) -> RequestData:
         """Prepare a text-embedding request"""
         pass
@@ -142,9 +171,16 @@
         """Parse a text-embedding response"""
         pass

+    @abstractmethod
+    def convert_generation_config(
+        self, config: "LLMGenerationConfig", model: "LLMModel"
+    ) -> dict[str, Any]:
+        """Convert the generic generation config into an API-specific parameter dict"""
+        pass
+
     def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
         """Validate an embedding API response"""
-        if "error" in response_json:
+        if response_json.get("error"):
             error_info = response_json["error"]
             msg = (
                 error_info.get("message", str(error_info))
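Note: convert_generation_config becomes part of the abstract adapter contract, replacing the old config.to_api_params(...) call sites (see the _apply_config hunk further down). A minimal sketch of a conforming subclass, with placeholder mapping logic — a real adapter such as GeminiConfigMapper below translates every config section:

    class MyAdapter(BaseAdapter):
        def convert_generation_config(
            self, config: "LLMGenerationConfig", model: "LLMModel"
        ) -> dict[str, Any]:
            # Placeholder mapping; field layout (config.core) is inferred
            # from the component code later in this commit.
            params: dict[str, Any] = {}
            if config.core and config.core.temperature is not None:
                params["temperature"] = config.core.temperature
            return params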
@@ -179,158 +215,9 @@
         )
         return headers

-    def convert_messages_to_openai_format(
-        self, messages: list["LLMMessage"]
-    ) -> list[dict[str, Any]]:
-        """Convert LLMMessage to OpenAI format - generic method"""
-        openai_messages: list[dict[str, Any]] = []
-        for msg in messages:
-            openai_msg: dict[str, Any] = {"role": msg.role}
-
-            if msg.role == "tool":
-                openai_msg["tool_call_id"] = msg.tool_call_id
-                openai_msg["name"] = msg.name
-                openai_msg["content"] = msg.content
-            else:
-                if isinstance(msg.content, str):
-                    openai_msg["content"] = msg.content
-                else:
-                    content_parts = []
-                    for part in msg.content:
-                        if part.type == "text":
-                            content_parts.append({"type": "text", "text": part.text})
-                        elif part.type == "image":
-                            content_parts.append(
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": part.image_source},
-                                }
-                            )
-                    openai_msg["content"] = content_parts
-
-            if msg.role == "assistant" and msg.tool_calls:
-                assistant_tool_calls = []
-                for call in msg.tool_calls:
-                    assistant_tool_calls.append(
-                        {
-                            "id": call.id,
-                            "type": "function",
-                            "function": {
-                                "name": call.function.name,
-                                "arguments": call.function.arguments,
-                            },
-                        }
-                    )
-                openai_msg["tool_calls"] = assistant_tool_calls
-
-            if msg.name and msg.role != "tool":
-                openai_msg["name"] = msg.name
-
-            openai_messages.append(openai_msg)
-        return openai_messages
-
-    def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData:
-        """Parse an OpenAI-format response - generic method"""
-        self.validate_response(response_json)
-
-        try:
-            choices = response_json.get("choices", [])
-            if not choices:
-                logger.debug("No choices in OpenAI response; possibly an empty reply or end of stream.")
-                return ResponseData(text="", raw_response=response_json)
-
-            choice = choices[0]
-            message = choice.get("message", {})
-            content = message.get("content", "")
-
-            if content:
-                content = content.strip()
-
-            images_bytes: list[bytes] = []
-            if content and content.startswith("{") and content.endswith("}"):
-                try:
-                    content_json = json.loads(content)
-                    if "b64_json" in content_json:
-                        images_bytes.append(base64.b64decode(content_json["b64_json"]))
-                        content = "[image generated]"
-                    elif "data" in content_json and isinstance(
-                        content_json["data"], str
-                    ):
-                        images_bytes.append(base64.b64decode(content_json["data"]))
-                        content = "[image generated]"
-
-                except (json.JSONDecodeError, KeyError, binascii.Error):
-                    pass
-            elif (
-                "images" in message
-                and isinstance(message["images"], list)
-                and message["images"]
-            ):
-                image_info = message["images"][0]
-                if image_info.get("type") == "image_url":
-                    image_url_obj = image_info.get("image_url", {})
-                    url_str = image_url_obj.get("url", "")
-                    if url_str.startswith("data:image/png;base64,"):
-                        try:
-                            b64_data = url_str.split(",", 1)[1]
-                            images_bytes.append(base64.b64decode(b64_data))
-                            content = content if content else "[image generated]"
-                        except (IndexError, binascii.Error) as e:
-                            logger.warning(f"Failed to parse OpenRouter Base64 image data: {e}")
-
-            parsed_tool_calls: list[LLMToolCall] | None = None
-            if message_tool_calls := message.get("tool_calls"):
-                from ..types.models import LLMToolFunction
-
-                parsed_tool_calls = []
-                for tc_data in message_tool_calls:
-                    try:
-                        if tc_data.get("type") == "function":
-                            parsed_tool_calls.append(
-                                LLMToolCall(
-                                    id=tc_data["id"],
-                                    function=LLMToolFunction(
-                                        name=tc_data["function"]["name"],
-                                        arguments=tc_data["function"]["arguments"],
-                                    ),
-                                )
-                            )
-                    except KeyError as e:
-                        logger.warning(
-                            f"Missing key while parsing OpenAI tool call data: {tc_data}, error: {e}"
-                        )
-                    except Exception as e:
-                        logger.warning(
-                            f"Error while parsing OpenAI tool call data: {tc_data}, error: {e}"
-                        )
-                if not parsed_tool_calls:
-                    parsed_tool_calls = None
-
-            final_text = content if content is not None else ""
-            if not final_text and parsed_tool_calls:
-                final_text = f"Requesting {len(parsed_tool_calls)} tool call(s)."
-
-            usage_info = response_json.get("usage")
-
-            return ResponseData(
-                text=final_text,
-                tool_calls=parsed_tool_calls,
-                usage_info=usage_info,
-                images=images_bytes if images_bytes else None,
-                raw_response=response_json,
-            )
-
-        except Exception as e:
-            logger.error(f"Failed to parse OpenAI-format response: {e}", e=e)
-            raise LLMException(
-                f"Failed to parse API response: {e}",
-                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                cause=e,
-            )
-
     def validate_response(self, response_json: dict[str, Any]) -> None:
         """Validate the API response, parsing the error structures of different APIs"""
-        if "error" in response_json:
+        if response_json.get("error"):
             error_info = response_json["error"]

             if isinstance(error_info, dict):
@@ -341,12 +228,15 @@
                 error_code_mapping = {
                     "invalid_api_key": LLMErrorCode.API_KEY_INVALID,
                     "authentication_failed": LLMErrorCode.API_KEY_INVALID,
+                    "insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED,
                     "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
                     "quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
                     "model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
                     "invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
                     "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                     "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
+                    "invalid_request_error": LLMErrorCode.INVALID_PARAMETER,
+                    "invalid_parameter": LLMErrorCode.INVALID_PARAMETER,
                 }

                 llm_error_code = error_code_mapping.get(
@@ -405,23 +295,12 @@
     ) -> dict[str, Any]:
         """Common config-application logic"""
         if config is not None:
-            return config.to_api_params(model.api_type, model.model_name)
+            return self.convert_generation_config(config, model)

-        if model._generation_config is not None:
-            return model._generation_config.to_api_params(
-                model.api_type, model.model_name
-            )
+        if model._generation_config:
+            return self.convert_generation_config(model._generation_config, model)

-        base_config = {}
-        if model.temperature is not None:
-            base_config["temperature"] = model.temperature
-        if model.max_tokens is not None:
-            if model.api_type == "gemini":
-                base_config["maxOutputTokens"] = model.max_tokens
-            else:
-                base_config["max_tokens"] = model.max_tokens
-
-        return base_config
+        return {}

     def apply_config_override(
         self,
@@ -434,12 +313,96 @@
         body.update(config_params)
         return body

+    def handle_http_error(self, response: httpx.Response) -> LLMException | None:
+        """
+        Handle an HTTP error response.
+        Returns None if the status code indicates success (200); otherwise builds an LLMException for the caller to catch.
+        """
+        if response.status_code == 200:
+            return None
+
+        error_text = response.content.decode("utf-8", errors="ignore")
+        error_status = ""
+        error_msg = error_text
+        try:
+            error_json = json.loads(error_text)
+            if isinstance(error_json, dict) and "error" in error_json:
+                error_info = error_json["error"]
+                if isinstance(error_info, dict):
+                    error_msg = error_info.get("message", error_msg)
+                    raw_status = error_info.get("status") or error_info.get("code")
+                    error_status = str(raw_status) if raw_status is not None else ""
+                elif error_info is not None:
+                    error_msg = str(error_info)
+                    error_status = error_msg
+        except Exception:
+            pass
+
+        status_upper = error_status.upper() if error_status else ""
+        text_upper = error_text.upper()
+
+        error_code = LLMErrorCode.API_REQUEST_FAILED
+        if response.status_code == 400:
+            if (
+                "FAILED_PRECONDITION" in status_upper
+                or "LOCATION IS NOT SUPPORTED" in text_upper
+            ):
+                error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED
+            elif "INVALID_ARGUMENT" in status_upper:
+                error_code = LLMErrorCode.INVALID_PARAMETER
+            elif "API_KEY_INVALID" in text_upper or "API KEY NOT VALID" in text_upper:
+                error_code = LLMErrorCode.API_KEY_INVALID
+            else:
+                error_code = LLMErrorCode.INVALID_PARAMETER
+        elif response.status_code in [401, 403]:
+            if error_msg and (
+                "country" in error_msg.lower()
+                or "region" in error_msg.lower()
+                or "unsupported" in error_msg.lower()
+            ):
+                error_code = LLMErrorCode.USER_LOCATION_NOT_SUPPORTED
+            elif "PERMISSION_DENIED" in status_upper:
+                error_code = LLMErrorCode.API_KEY_INVALID
+            else:
+                error_code = LLMErrorCode.API_KEY_INVALID
+        elif response.status_code == 404:
+            error_code = LLMErrorCode.MODEL_NOT_FOUND
+        elif response.status_code == 429:
+            if (
+                "RESOURCE_EXHAUSTED" in status_upper
+                or "INSUFFICIENT_QUOTA" in status_upper
+                or ("quota" in error_msg.lower() if error_msg else False)
+            ):
+                error_code = LLMErrorCode.API_QUOTA_EXCEEDED
+            else:
+                error_code = LLMErrorCode.API_RATE_LIMITED
+        elif response.status_code in [402, 413]:
+            error_code = LLMErrorCode.API_QUOTA_EXCEEDED
+        elif response.status_code == 422:
+            error_code = LLMErrorCode.GENERATION_FAILED
+        elif response.status_code >= 500:
+            error_code = LLMErrorCode.API_TIMEOUT
+
+        return LLMException(
+            f"HTTP request failed: {response.status_code} ({error_status or 'Unknown'})",
+            code=error_code,
+            details={
+                "status_code": response.status_code,
+                "api_status": error_status,
+                "response": error_text,
+            },
+        )
+

 class OpenAICompatAdapter(BaseAdapter):
     """
     Generic adapter that handles all OpenAI-compatible APIs.
     """

+    @property
+    def log_sanitization_context(self) -> str:
+        return "openai_request"
+
     @abstractmethod
     def get_chat_endpoint(self, model: "LLMModel") -> str:
         """Subclasses must implement this; returns the chat completions endpoint"""
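Note: handle_http_error centralizes the status-code-to-LLMErrorCode mapping for every adapter. It returns None on 200 and a ready-built LLMException otherwise, so callers raise rather than construct; a sketch of the intended call pattern, assuming an httpx client and names (client, adapter, model) that are hypothetical here:

    response = await client.post(request.url, headers=request.headers, json=request.body)
    if exc := adapter.handle_http_error(response):
        raise exc
    data = adapter.parse_response(model, response.json())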
@@ -481,8 +444,8 @@
         api_key: str,
         messages: list["LLMMessage"],
         config: "LLMGenerationConfig | None" = None,
-        tools: dict[str, "ToolExecutable"] | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[Any] | None = None,
+        tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
     ) -> RequestData:
         """Prepare an advanced request - OpenAI-compatible format"""
         url = self.get_api_url(model, self.get_chat_endpoint(model))
@@ -494,28 +457,44 @@
                 "X-Title": "Zhenxun Bot",
             }
         )
-        openai_messages = self.convert_messages_to_openai_format(messages)
+        from .components.openai_components import OpenAIMessageConverter
+
+        converter = OpenAIMessageConverter()
+        openai_messages = converter.convert_messages(messages)
+
         body = {
             "model": model.model_name,
             "messages": openai_messages,
         }

+        openai_tools: list[dict[str, Any]] | None = None
+        executables: list[Any] = []
         if tools:
+            for tool in tools:
+                if hasattr(tool, "get_definition"):
+                    executables.append(tool)
+
+        if executables:
             import asyncio

             from zhenxun.utils.pydantic_compat import model_dump

             definition_tasks = [
-                executable.get_definition() for executable in tools.values()
+                executable.get_definition() for executable in executables
             ]
-            openai_tools = await asyncio.gather(*definition_tasks)
-            if openai_tools:
-                body["tools"] = [
+            tool_defs = []
+            if definition_tasks:
+                tool_defs = await asyncio.gather(*definition_tasks)
+
+            if tool_defs:
+                openai_tools = [
                     {"type": "function", "function": model_dump(tool)}
-                    for tool in openai_tools
+                    for tool in tool_defs
                 ]

+        if openai_tools:
+            body["tools"] = openai_tools
+
         if tool_choice:
             body["tool_choice"] = tool_choice

@@ -528,20 +507,21 @@
         response_json: dict[str, Any],
         is_advanced: bool = False,
     ) -> ResponseData:
-        """Parse the response - uses the base class's OpenAI-format parser directly"""
+        """Parse the response - delegates to the componentized ResponseParser"""
         _ = model, is_advanced
-        return self.parse_openai_response(response_json)
+        from .components.openai_components import OpenAIResponseParser
+
+        parser = OpenAIResponseParser()
+        return parser.parse(response_json)

     def prepare_embedding_request(
         self,
         model: "LLMModel",
         api_key: str,
         texts: list[str],
-        task_type: "EmbeddingTaskType | str",
-        **kwargs: Any,
+        config: "LLMEmbeddingConfig",
     ) -> RequestData:
         """Prepare an embedding request - OpenAI-compatible format"""
-        _ = task_type
         url = self.get_api_url(model, self.get_embedding_endpoint(model))
         headers = self.get_base_headers(api_key)

@@ -550,8 +530,14 @@
             "input": texts,
         }

-        if kwargs:
-            body.update(kwargs)
+        if config.output_dimensionality:
+            body["dimensions"] = config.output_dimensionality
+
+        if config.task_type:
+            body["task"] = config.task_type
+
+        if config.encoding_format and config.encoding_format != "float":
+            body["encoding_format"] = config.encoding_format

         return RequestData(url=url, headers=headers, body=body)

zhenxun/services/llm/adapters/components/__init__.py
Normal file
1
zhenxun/services/llm/adapters/components/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
606
zhenxun/services/llm/adapters/components/gemini_components.py
Normal file
606
zhenxun/services/llm/adapters/components/gemini_components.py
Normal file
@ -0,0 +1,606 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from zhenxun.services.llm.adapters.base import ResponseData, process_image_data
|
||||||
|
from zhenxun.services.llm.adapters.components.interfaces import (
|
||||||
|
ConfigMapper,
|
||||||
|
MessageConverter,
|
||||||
|
ResponseParser,
|
||||||
|
ToolSerializer,
|
||||||
|
)
|
||||||
|
from zhenxun.services.llm.config.generation import (
|
||||||
|
ImageAspectRatio,
|
||||||
|
LLMGenerationConfig,
|
||||||
|
ReasoningEffort,
|
||||||
|
ResponseFormat,
|
||||||
|
)
|
||||||
|
from zhenxun.services.llm.config.providers import get_gemini_safety_threshold
|
||||||
|
from zhenxun.services.llm.types import (
|
||||||
|
CodeExecutionOutcome,
|
||||||
|
LLMContentPart,
|
||||||
|
LLMMessage,
|
||||||
|
)
|
||||||
|
from zhenxun.services.llm.types.capabilities import ModelCapabilities
|
||||||
|
from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
|
||||||
|
from zhenxun.services.llm.types.models import (
|
||||||
|
LLMGroundingAttribution,
|
||||||
|
LLMGroundingMetadata,
|
||||||
|
LLMToolCall,
|
||||||
|
LLMToolFunction,
|
||||||
|
ModelDetail,
|
||||||
|
ToolDefinition,
|
||||||
|
)
|
||||||
|
from zhenxun.services.llm.utils import (
|
||||||
|
resolve_json_schema_refs,
|
||||||
|
sanitize_schema_for_llm,
|
||||||
|
)
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.http_utils import AsyncHttpx
|
||||||
|
from zhenxun.utils.pydantic_compat import model_copy, model_dump
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiConfigMapper(ConfigMapper):
|
||||||
|
def map_config(
|
||||||
|
self,
|
||||||
|
config: LLMGenerationConfig,
|
||||||
|
model_detail: ModelDetail | None = None,
|
||||||
|
capabilities: ModelCapabilities | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
params: dict[str, Any] = {}
|
||||||
|
|
||||||
|
if config.core:
|
||||||
|
if config.core.temperature is not None:
|
||||||
|
params["temperature"] = config.core.temperature
|
||||||
|
if config.core.max_tokens is not None:
|
||||||
|
params["maxOutputTokens"] = config.core.max_tokens
|
||||||
|
if config.core.top_k is not None:
|
||||||
|
params["topK"] = config.core.top_k
|
||||||
|
if config.core.top_p is not None:
|
||||||
|
params["topP"] = config.core.top_p
|
||||||
|
|
||||||
|
if config.output:
|
||||||
|
if config.output.response_format == ResponseFormat.JSON:
|
||||||
|
params["responseMimeType"] = "application/json"
|
||||||
|
if config.output.response_schema:
|
||||||
|
params["responseJsonSchema"] = config.output.response_schema
|
||||||
|
elif config.output.response_mime_type is not None:
|
||||||
|
params["responseMimeType"] = config.output.response_mime_type
|
||||||
|
|
||||||
|
if (
|
||||||
|
config.output.response_schema is not None
|
||||||
|
and "responseJsonSchema" not in params
|
||||||
|
):
|
||||||
|
params["responseJsonSchema"] = config.output.response_schema
|
||||||
|
if config.output.response_modalities:
|
||||||
|
params["responseModalities"] = config.output.response_modalities
|
||||||
|
|
||||||
|
if config.tool_config:
|
||||||
|
fc_config: dict[str, Any] = {"mode": config.tool_config.mode}
|
||||||
|
if (
|
||||||
|
config.tool_config.allowed_function_names
|
||||||
|
and config.tool_config.mode == "ANY"
|
||||||
|
):
|
||||||
|
builtins = {"code_execution", "google_search", "google_map"}
|
||||||
|
user_funcs = [
|
||||||
|
name
|
||||||
|
for name in config.tool_config.allowed_function_names
|
||||||
|
if name not in builtins
|
||||||
|
]
|
||||||
|
if user_funcs:
|
||||||
|
fc_config["allowedFunctionNames"] = user_funcs
|
||||||
|
params["toolConfig"] = {"functionCallingConfig": fc_config}
|
||||||
|
|
||||||
|
if config.reasoning:
|
||||||
|
thinking_config = params.setdefault("thinkingConfig", {})
|
||||||
|
|
||||||
|
if config.reasoning.budget_tokens is not None:
|
||||||
|
if (
|
||||||
|
config.reasoning.budget_tokens <= 0
|
||||||
|
or config.reasoning.budget_tokens >= 1
|
||||||
|
):
|
||||||
|
budget_value = int(config.reasoning.budget_tokens)
|
||||||
|
else:
|
||||||
|
budget_value = int(config.reasoning.budget_tokens * 32768)
|
||||||
|
thinking_config["thinkingBudget"] = budget_value
|
||||||
|
elif config.reasoning.effort:
|
||||||
|
if config.reasoning.effort == ReasoningEffort.MEDIUM:
|
||||||
|
thinking_config["thinkingLevel"] = "HIGH"
|
||||||
|
else:
|
||||||
|
thinking_config["thinkingLevel"] = config.reasoning.effort.value
|
||||||
|
|
||||||
|
if config.reasoning.show_thoughts is not None:
|
||||||
|
thinking_config["includeThoughts"] = config.reasoning.show_thoughts
|
||||||
|
elif capabilities and capabilities.reasoning_visibility == "visible":
|
||||||
|
thinking_config["includeThoughts"] = True
|
||||||
|
|
||||||
|
if config.visual:
|
||||||
|
image_config: dict[str, Any] = {}
|
||||||
|
|
||||||
|
if config.visual.aspect_ratio is not None:
|
||||||
|
ar_value = (
|
||||||
|
config.visual.aspect_ratio.value
|
||||||
|
if isinstance(config.visual.aspect_ratio, ImageAspectRatio)
|
||||||
|
else config.visual.aspect_ratio
|
||||||
|
)
|
||||||
|
image_config["aspectRatio"] = ar_value
|
||||||
|
|
||||||
|
if config.visual.resolution:
|
||||||
|
image_config["imageSize"] = config.visual.resolution
|
||||||
|
|
||||||
|
if image_config:
|
||||||
|
params["imageConfig"] = image_config
|
||||||
|
|
||||||
|
if config.visual.media_resolution:
|
||||||
|
media_value = config.visual.media_resolution.upper()
|
||||||
|
if not media_value.startswith("MEDIA_RESOLUTION_"):
|
||||||
|
media_value = f"MEDIA_RESOLUTION_{media_value}"
|
||||||
|
params["mediaResolution"] = media_value
|
||||||
|
|
||||||
|
if config.custom_params:
|
||||||
|
mapped_custom = config.custom_params.copy()
|
||||||
|
if "max_tokens" in mapped_custom:
|
||||||
|
mapped_custom["maxOutputTokens"] = mapped_custom.pop("max_tokens")
|
||||||
|
if "top_k" in mapped_custom:
|
||||||
|
mapped_custom["topK"] = mapped_custom.pop("top_k")
|
||||||
|
if "top_p" in mapped_custom:
|
||||||
|
mapped_custom["topP"] = mapped_custom.pop("top_p")
|
||||||
|
|
||||||
|
for key in (
|
||||||
|
"code_execution_timeout",
|
||||||
|
"grounding_config",
|
||||||
|
"dynamic_threshold",
|
||||||
|
"user_location",
|
||||||
|
"reflexion_retries",
|
||||||
|
):
|
||||||
|
mapped_custom.pop(key, None)
|
||||||
|
|
||||||
|
for unsupported in [
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"repetition_penalty",
|
||||||
|
]:
|
||||||
|
if unsupported in mapped_custom:
|
||||||
|
mapped_custom.pop(unsupported)
|
||||||
|
|
||||||
|
params.update(mapped_custom)
|
||||||
|
|
||||||
|
safety_settings: list[dict[str, Any]] = []
|
||||||
|
if config.safety and config.safety.safety_settings:
|
||||||
|
for category, threshold in config.safety.safety_settings.items():
|
||||||
|
safety_settings.append({"category": category, "threshold": threshold})
|
||||||
|
else:
|
||||||
|
threshold = get_gemini_safety_threshold()
|
||||||
|
for category in [
|
||||||
|
"HARM_CATEGORY_HARASSMENT",
|
||||||
|
"HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
]:
|
||||||
|
safety_settings.append({"category": category, "threshold": threshold})
|
||||||
|
|
||||||
|
if safety_settings:
|
||||||
|
params["safetySettings"] = safety_settings
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiMessageConverter(MessageConverter):
|
||||||
|
async def convert_part(self, part: LLMContentPart) -> dict[str, Any]:
|
||||||
|
"""将单个内容部分转换为 Gemini API 格式"""
|
||||||
|
|
||||||
|
def _get_gemini_resolution_dict() -> dict[str, Any]:
|
||||||
|
if part.media_resolution:
|
||||||
|
value = part.media_resolution.upper()
|
||||||
|
if not value.startswith("MEDIA_RESOLUTION_"):
|
||||||
|
value = f"MEDIA_RESOLUTION_{value}"
|
||||||
|
return {"media_resolution": {"level": value}}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if part.type == "text":
|
||||||
|
return {"text": part.text}
|
||||||
|
|
||||||
|
if part.type == "thought":
|
||||||
|
return {"text": part.thought_text, "thought": True}
|
||||||
|
|
||||||
|
if part.type == "image":
|
||||||
|
if not part.image_source:
|
||||||
|
raise ValueError("图像类型的内容必须包含image_source")
|
||||||
|
|
||||||
|
if part.is_image_base64():
|
||||||
|
base64_info = part.get_base64_data()
|
||||||
|
if base64_info:
|
||||||
|
mime_type, data = base64_info
|
||||||
|
payload = {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||||
|
payload.update(_get_gemini_resolution_dict())
|
||||||
|
return payload
|
||||||
|
raise ValueError(f"无法解析Base64图像数据: {part.image_source[:50]}...")
|
||||||
|
if part.is_image_url():
|
||||||
|
logger.debug(f"正在为Gemini下载并编码URL图片: {part.image_source}")
|
||||||
|
try:
|
||||||
|
image_bytes = await AsyncHttpx.get_content(part.image_source)
|
||||||
|
mime_type = part.mime_type or "image/jpeg"
|
||||||
|
base64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
payload = {
|
||||||
|
"inlineData": {"mimeType": mime_type, "data": base64_data}
|
||||||
|
}
|
||||||
|
payload.update(_get_gemini_resolution_dict())
|
||||||
|
return payload
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载或编码URL图片失败: {e}", e=e)
|
||||||
|
raise ValueError(f"无法处理图片URL: {e}")
|
||||||
|
raise ValueError(f"不支持的图像源格式: {part.image_source[:50]}...")
|
||||||
|
|
||||||
|
if part.type == "video":
|
||||||
|
if not part.video_source:
|
||||||
|
raise ValueError("视频类型的内容必须包含video_source")
|
||||||
|
|
||||||
|
if part.video_source.startswith("data:"):
|
||||||
|
try:
|
||||||
|
header, data = part.video_source.split(",", 1)
|
||||||
|
mime_type = header.split(";")[0].replace("data:", "")
|
||||||
|
payload = {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||||
|
payload.update(_get_gemini_resolution_dict())
|
||||||
|
return payload
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
raise ValueError(
|
||||||
|
f"无法解析Base64视频数据: {part.video_source[:50]}..."
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
|
||||||
|
)
|
||||||
|
|
||||||
|
if part.type == "audio":
|
||||||
|
if not part.audio_source:
|
||||||
|
raise ValueError("音频类型的内容必须包含audio_source")
|
||||||
|
|
||||||
|
if part.audio_source.startswith("data:"):
|
||||||
|
try:
|
||||||
|
header, data = part.audio_source.split(",", 1)
|
||||||
|
mime_type = header.split(";")[0].replace("data:", "")
|
||||||
|
payload = {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||||
|
payload.update(_get_gemini_resolution_dict())
|
||||||
|
return payload
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
raise ValueError(
|
||||||
|
f"无法解析Base64音频数据: {part.audio_source[:50]}..."
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
|
||||||
|
)
|
||||||
|
|
||||||
|
if part.type == "file":
|
||||||
|
if part.file_uri:
|
||||||
|
payload = {
|
||||||
|
"fileData": {"mimeType": part.mime_type, "fileUri": part.file_uri}
|
||||||
|
}
|
||||||
|
payload.update(_get_gemini_resolution_dict())
|
||||||
|
return payload
|
||||||
|
if part.file_source:
|
||||||
|
file_name = (
|
||||||
|
part.metadata.get("name", "file") if part.metadata else "file"
|
||||||
|
)
|
||||||
|
return {"text": f"[文件: {file_name}]\n{part.file_source}"}
|
||||||
|
raise ValueError("文件类型的内容必须包含file_uri或file_source")
|
||||||
|
|
||||||
|
raise ValueError(f"不支持的内容类型: {part.type}")
|
||||||
|
|
||||||
|
async def convert_messages_async(
|
||||||
|
self, messages: list[LLMMessage]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
gemini_contents: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
current_parts: list[dict[str, Any]] = []
|
||||||
|
if msg.role == "system":
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif msg.role == "user":
|
||||||
|
if isinstance(msg.content, str):
|
||||||
|
current_parts.append({"text": msg.content})
|
||||||
|
elif isinstance(msg.content, list):
|
||||||
|
for part_obj in msg.content:
|
||||||
|
current_parts.append(await self.convert_part(part_obj))
|
||||||
|
gemini_contents.append({"role": "user", "parts": current_parts})
|
||||||
|
|
||||||
|
elif msg.role == "assistant" or msg.role == "model":
|
||||||
|
if isinstance(msg.content, str) and msg.content:
|
||||||
|
current_parts.append({"text": msg.content})
|
||||||
|
elif isinstance(msg.content, list):
|
||||||
|
for part_obj in msg.content:
|
||||||
|
part_dict = await self.convert_part(part_obj)
|
||||||
|
|
||||||
|
if "executableCode" in part_dict:
|
||||||
|
part_dict["executable_code"] = part_dict.pop(
|
||||||
|
"executableCode"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "codeExecutionResult" in part_dict:
|
||||||
|
part_dict["code_execution_result"] = part_dict.pop(
|
||||||
|
"codeExecutionResult"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
part_obj.metadata
|
||||||
|
and "thought_signature" in part_obj.metadata
|
||||||
|
):
|
||||||
|
part_dict["thoughtSignature"] = part_obj.metadata[
|
||||||
|
"thought_signature"
|
||||||
|
]
|
||||||
|
current_parts.append(part_dict)
|
||||||
|
|
||||||
|
if msg.tool_calls:
|
||||||
|
for call in msg.tool_calls:
|
||||||
|
fc_part = {
|
||||||
|
"functionCall": {
|
||||||
|
"name": call.function.name,
|
||||||
|
"args": json.loads(call.function.arguments),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if call.thought_signature:
|
||||||
|
fc_part["thoughtSignature"] = call.thought_signature
|
||||||
|
current_parts.append(fc_part)
|
||||||
|
if current_parts:
|
||||||
|
gemini_contents.append({"role": "model", "parts": current_parts})
|
||||||
|
|
||||||
|
elif msg.role == "tool":
|
||||||
|
if not msg.name:
|
||||||
|
raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。")
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_str = (
|
||||||
|
msg.content
|
||||||
|
if isinstance(msg.content, str)
|
||||||
|
else str(msg.content)
|
||||||
|
)
|
||||||
|
tool_result_obj = json.loads(content_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
content_str = (
|
||||||
|
msg.content
|
||||||
|
if isinstance(msg.content, str)
|
||||||
|
else str(msg.content)
|
||||||
|
)
|
||||||
|
tool_result_obj = {"raw_output": content_str}
|
||||||
|
|
||||||
|
if isinstance(tool_result_obj, list):
|
||||||
|
final_response_payload = {"result": tool_result_obj}
|
||||||
|
elif not isinstance(tool_result_obj, dict):
|
||||||
|
final_response_payload = {"result": tool_result_obj}
|
||||||
|
else:
|
||||||
|
final_response_payload = tool_result_obj
|
||||||
|
|
||||||
|
current_parts.append(
|
||||||
|
{
|
||||||
|
"functionResponse": {
|
||||||
|
"name": msg.name,
|
||||||
|
"response": final_response_payload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if gemini_contents and gemini_contents[-1]["role"] == "function":
|
||||||
|
gemini_contents[-1]["parts"].extend(current_parts)
|
||||||
|
else:
|
||||||
|
gemini_contents.append({"role": "function", "parts": current_parts})
|
||||||
|
|
||||||
|
return gemini_contents
|
||||||
|
|
||||||
|
def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
|
||||||
|
raise NotImplementedError("Use convert_messages_async for Gemini")
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiToolSerializer(ToolSerializer):
|
||||||
|
def serialize_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
|
||||||
|
function_declarations: list[dict[str, Any]] = []
|
||||||
|
for tool_def in tools:
|
||||||
|
tool_copy = model_copy(tool_def)
|
||||||
|
tool_copy.parameters = resolve_json_schema_refs(tool_copy.parameters)
|
||||||
|
tool_copy.parameters = sanitize_schema_for_llm(
|
||||||
|
tool_copy.parameters, api_type="gemini"
|
||||||
|
)
|
||||||
|
function_declarations.append(model_dump(tool_copy))
|
||||||
|
return function_declarations
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiResponseParser(ResponseParser):
|
||||||
|
def validate_response(self, response_json: dict[str, Any]) -> None:
|
||||||
|
if error := response_json.get("error"):
|
||||||
|
code = error.get("code")
|
||||||
|
message = error.get("message", "")
|
||||||
|
status = error.get("status")
|
||||||
|
details = error.get("details", [])
|
||||||
|
|
||||||
|
if code == 429 or status == "RESOURCE_EXHAUSTED":
|
||||||
|
is_quota = any(
|
||||||
|
d.get("reason") in ("QUOTA_EXCEEDED", "SERVICE_DISABLED")
|
||||||
|
for d in details
|
||||||
|
if isinstance(d, dict)
|
||||||
|
)
|
||||||
|
if is_quota or "quota" in message.lower():
|
||||||
|
raise LLMException(
|
||||||
|
f"Gemini配额耗尽: {message}",
|
||||||
|
code=LLMErrorCode.API_QUOTA_EXCEEDED,
|
||||||
|
details=error,
|
||||||
|
)
|
||||||
|
raise LLMException(
|
||||||
|
f"Gemini速率限制: {message}",
|
||||||
|
code=LLMErrorCode.API_RATE_LIMITED,
|
||||||
|
details=error,
|
||||||
|
)
|
||||||
|
|
||||||
|
if code == 400 or status in ("INVALID_ARGUMENT", "FAILED_PRECONDITION"):
|
||||||
|
raise LLMException(
|
||||||
|
f"Gemini参数错误: {message}",
|
||||||
|
code=LLMErrorCode.INVALID_PARAMETER,
|
||||||
|
details=error,
|
||||||
|
recoverable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_feedback := response_json.get("promptFeedback"):
|
||||||
|
if block_reason := prompt_feedback.get("blockReason"):
|
||||||
|
raise LLMException(
|
||||||
|
f"内容被安全过滤: {block_reason}",
|
||||||
|
code=LLMErrorCode.CONTENT_FILTERED,
|
||||||
|
details={
|
||||||
|
"block_reason": block_reason,
|
||||||
|
"safety_ratings": prompt_feedback.get("safetyRatings"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse(self, response_json: dict[str, Any]) -> ResponseData:
|
||||||
|
self.validate_response(response_json)
|
||||||
|
|
||||||
|
if "image_generation" in response_json and isinstance(
|
||||||
|
response_json["image_generation"], dict
|
||||||
|
):
|
||||||
|
candidates_source = response_json["image_generation"]
|
||||||
|
else:
|
||||||
|
candidates_source = response_json
|
||||||
|
|
||||||
|
candidates = candidates_source.get("candidates", [])
|
||||||
|
usage_info = response_json.get("usageMetadata")
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return ResponseData(text="", raw_response=response_json)
|
||||||
|
|
||||||
|
candidate = candidates[0]
|
||||||
|
thought_signature: str | None = None
|
||||||
|
|
||||||
|
content_data = candidate.get("content", {})
|
||||||
|
parts = content_data.get("parts", [])
|
||||||
|
|
||||||
|
text_content = ""
|
||||||
|
images_payload: list[bytes | Path] = []
|
||||||
|
parsed_tool_calls: list[LLMToolCall] | None = None
|
||||||
|
parsed_code_executions: list[dict[str, Any]] = []
|
||||||
|
content_parts: list[LLMContentPart] = []
|
||||||
|
thought_summary_parts: list[str] = []
|
||||||
|
answer_parts = []
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
part_signature = part.get("thoughtSignature")
|
||||||
|
if part_signature and thought_signature is None:
|
||||||
|
thought_signature = part_signature
|
||||||
|
part_metadata: dict[str, Any] | None = None
|
||||||
|
if part_signature:
|
||||||
|
part_metadata = {"thought_signature": part_signature}
|
||||||
|
|
||||||
|
if part.get("thought") is True:
|
||||||
|
t_text = part.get("text", "")
|
||||||
|
                    thought_summary_parts.append(t_text)
                    content_parts.append(LLMContentPart.thought_part(t_text))

            elif "text" in part:
                answer_parts.append(part["text"])
                c_part = LLMContentPart(
                    type="text", text=part["text"], metadata=part_metadata
                )
                content_parts.append(c_part)

            elif "thoughtSummary" in part:
                thought_summary_parts.append(part["thoughtSummary"])
                content_parts.append(
                    LLMContentPart.thought_part(part["thoughtSummary"])
                )

            elif "inlineData" in part:
                inline_data = part["inlineData"]
                if "data" in inline_data:
                    decoded = base64.b64decode(inline_data["data"])
                    images_payload.append(process_image_data(decoded))

            elif "functionCall" in part:
                if parsed_tool_calls is None:
                    parsed_tool_calls = []
                fc_data = part["functionCall"]
                fc_sig = part_signature
                try:
                    call_id = f"call_gemini_{len(parsed_tool_calls)}"
                    parsed_tool_calls.append(
                        LLMToolCall(
                            id=call_id,
                            thought_signature=fc_sig,
                            function=LLMToolFunction(
                                name=fc_data["name"],
                                arguments=json.dumps(fc_data["args"]),
                            ),
                        )
                    )
                except Exception as e:
                    logger.warning(
                        f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
                    )
            elif "executableCode" in part:
                exec_code = part["executableCode"]
                lang = exec_code.get("language", "PYTHON")
                code = exec_code.get("code", "")
                content_parts.append(LLMContentPart.executable_code_part(lang, code))
                answer_parts.append(f"\n[生成代码 ({lang})]:\n```python\n{code}\n```\n")

            elif "codeExecutionResult" in part:
                result = part["codeExecutionResult"]
                outcome = result.get("outcome", CodeExecutionOutcome.OUTCOME_UNKNOWN)
                output = result.get("output", "")

                content_parts.append(
                    LLMContentPart.execution_result_part(outcome, output)
                )

                parsed_code_executions.append(result)

                if outcome == CodeExecutionOutcome.OUTCOME_OK:
                    answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n")
                else:
                    answer_parts.append(f"\n[代码执行失败 ({outcome})]:\n{output}\n")

        full_answer = "".join(answer_parts).strip()
        text_content = full_answer
        final_thought_text = (
            "\n\n".join(thought_summary_parts).strip()
            if thought_summary_parts
            else None
        )

        grounding_metadata_obj = None
        if grounding_data := candidate.get("groundingMetadata"):
            try:
                sep_content = None
                sep_field = grounding_data.get("searchEntryPoint")
                if isinstance(sep_field, dict):
                    sep_content = sep_field.get("renderedContent")

                attributions = []
                if chunks := grounding_data.get("groundingChunks"):
                    for chunk in chunks:
                        if web := chunk.get("web"):
                            attributions.append(
                                LLMGroundingAttribution(
                                    title=web.get("title"),
                                    uri=web.get("uri"),
                                    snippet=web.get("snippet"),
                                    confidence_score=None,
                                )
                            )

                grounding_metadata_obj = LLMGroundingMetadata(
                    web_search_queries=grounding_data.get("webSearchQueries"),
                    grounding_attributions=attributions or None,
                    search_suggestions=grounding_data.get("searchSuggestions"),
                    search_entry_point=sep_content,
                    map_widget_token=grounding_data.get("googleMapsWidgetContextToken"),
                )
            except Exception as e:
                logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}")

        return ResponseData(
            text=text_content,
            tool_calls=parsed_tool_calls,
            code_executions=parsed_code_executions if parsed_code_executions else None,
            content_parts=content_parts if content_parts else None,
            images=images_payload if images_payload else None,
            usage_info=usage_info,
            raw_response=response_json,
            grounding_metadata=grounding_metadata_obj,
            thought_text=final_thought_text,
            thought_signature=thought_signature,
        )
43  zhenxun/services/llm/adapters/components/interfaces.py  Normal file
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Any

from zhenxun.services.llm.adapters.base import ResponseData
from zhenxun.services.llm.config.generation import LLMGenerationConfig
from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.llm.types.capabilities import ModelCapabilities
from zhenxun.services.llm.types.models import ModelDetail, ToolDefinition


class ConfigMapper(ABC):
    @abstractmethod
    def map_config(
        self,
        config: LLMGenerationConfig,
        model_detail: ModelDetail | None = None,
        capabilities: ModelCapabilities | None = None,
    ) -> dict[str, Any]:
        """将通用生成配置转换为特定 API 的参数字典"""
        ...


class MessageConverter(ABC):
    @abstractmethod
    def convert_messages(
        self, messages: list[LLMMessage]
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """将通用消息列表转换为特定 API 的消息格式"""
        ...


class ToolSerializer(ABC):
    @abstractmethod
    def serialize_tools(self, tools: list[ToolDefinition]) -> Any:
        """将通用工具定义转换为特定 API 的工具格式"""
        ...


class ResponseParser(ABC):
    @abstractmethod
    def parse(self, response_json: dict[str, Any]) -> ResponseData:
        """将特定 API 的响应解析为通用响应数据"""
        ...
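A minimal sketch (not part of the commit) of how these four interfaces are meant to compose inside a concrete adapter; the class name and wiring below are invented for illustration:

from typing import Any

from zhenxun.services.llm.adapters.components.interfaces import (
    ConfigMapper,
    MessageConverter,
    ResponseParser,
    ToolSerializer,
)


class ExampleAdapterSketch:
    """Hypothetical adapter: each component owns one concern of the cycle."""

    def __init__(
        self,
        mapper: ConfigMapper,
        converter: MessageConverter,
        serializer: ToolSerializer,
        parser: ResponseParser,
    ) -> None:
        self.mapper = mapper
        self.converter = converter
        self.serializer = serializer
        self.parser = parser

    def build_body(self, config, messages, tools) -> dict[str, Any]:
        # map_config / convert_messages / serialize_tools each produce the
        # provider-specific fragment of the request body.
        body: dict[str, Any] = {"messages": self.converter.convert_messages(messages)}
        body.update(self.mapper.map_config(config))
        if tools:
            body["tools"] = self.serializer.serialize_tools(tools)
        return body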
347  zhenxun/services/llm/adapters/components/openai_components.py  Normal file
@@ -0,0 +1,347 @@
import base64
import binascii
import json
from pathlib import Path
from typing import Any

from zhenxun.services.llm.adapters.base import ResponseData, process_image_data
from zhenxun.services.llm.adapters.components.interfaces import (
    ConfigMapper,
    MessageConverter,
    ResponseParser,
    ToolSerializer,
)
from zhenxun.services.llm.config.generation import (
    ImageAspectRatio,
    LLMGenerationConfig,
    ResponseFormat,
    StructuredOutputStrategy,
)
from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.llm.types.capabilities import ModelCapabilities
from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
from zhenxun.services.llm.types.models import (
    LLMToolCall,
    LLMToolFunction,
    ModelDetail,
    ToolDefinition,
)
from zhenxun.services.llm.utils import sanitize_schema_for_llm
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump


class OpenAIConfigMapper(ConfigMapper):
    def __init__(self, api_type: str = "openai"):
        self.api_type = api_type

    def map_config(
        self,
        config: LLMGenerationConfig,
        model_detail: ModelDetail | None = None,
        capabilities: ModelCapabilities | None = None,
    ) -> dict[str, Any]:
        params: dict[str, Any] = {}
        strategy = config.output.structured_output_strategy if config.output else None
        if strategy is None:
            strategy = (
                StructuredOutputStrategy.TOOL_CALL
                if self.api_type == "deepseek"
                else StructuredOutputStrategy.NATIVE
            )

        if config.core:
            if config.core.temperature is not None:
                params["temperature"] = config.core.temperature
            if config.core.max_tokens is not None:
                params["max_tokens"] = config.core.max_tokens
            if config.core.top_k is not None:
                params["top_k"] = config.core.top_k
            if config.core.top_p is not None:
                params["top_p"] = config.core.top_p
            if config.core.frequency_penalty is not None:
                params["frequency_penalty"] = config.core.frequency_penalty
            if config.core.presence_penalty is not None:
                params["presence_penalty"] = config.core.presence_penalty
            if config.core.stop is not None:
                params["stop"] = config.core.stop

            if config.core.repetition_penalty is not None:
                if self.api_type == "openai":
                    logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略")
                else:
                    params["repetition_penalty"] = config.core.repetition_penalty

        if config.reasoning and config.reasoning.effort:
            params["reasoning_effort"] = config.reasoning.effort.value.lower()

        if config.output:
            if isinstance(config.output.response_format, dict):
                params["response_format"] = config.output.response_format
            elif (
                config.output.response_format == ResponseFormat.JSON
                and strategy == StructuredOutputStrategy.NATIVE
            ):
                if config.output.response_schema:
                    sanitized = sanitize_schema_for_llm(
                        config.output.response_schema, api_type="openai"
                    )
                    params["response_format"] = {
                        "type": "json_schema",
                        "json_schema": {
                            "name": "structured_response",
                            "schema": sanitized,
                            "strict": True,
                        },
                    }
                else:
                    params["response_format"] = {"type": "json_object"}

        if config.tool_config:
            mode = config.tool_config.mode
            if mode == "NONE":
                params["tool_choice"] = "none"
            elif mode == "AUTO":
                params["tool_choice"] = "auto"
            elif mode == "ANY":
                params["tool_choice"] = "required"

        if config.visual and config.visual.aspect_ratio:
            size_map = {
                ImageAspectRatio.SQUARE: "1024x1024",
                ImageAspectRatio.LANDSCAPE_16_9: "1792x1024",
                ImageAspectRatio.PORTRAIT_9_16: "1024x1792",
            }
            ar = config.visual.aspect_ratio
            if isinstance(ar, ImageAspectRatio):
                mapped_size = size_map.get(ar)
                if mapped_size:
                    params["size"] = mapped_size
            elif isinstance(ar, str):
                params["size"] = ar

        if config.custom_params:
            mapped_custom = config.custom_params.copy()
            if "repetition_penalty" in mapped_custom and self.api_type == "openai":
                mapped_custom.pop("repetition_penalty")

            if "stop" in mapped_custom:
                stop_value = mapped_custom["stop"]
                if isinstance(stop_value, str):
                    mapped_custom["stop"] = [stop_value]

            params.update(mapped_custom)

        return params
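For reference, this is the shape of the response_format payload the NATIVE structured-output branch above emits; only the wrapper comes from the code, the schema contents here are a made-up example:

example_response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "structured_response",
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
        "strict": True,
    },
}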


class OpenAIMessageConverter(MessageConverter):
    def convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
        openai_messages: list[dict[str, Any]] = []
        for msg in messages:
            openai_msg: dict[str, Any] = {"role": msg.role}

            if msg.role == "tool":
                openai_msg["tool_call_id"] = msg.tool_call_id
                openai_msg["name"] = msg.name
                openai_msg["content"] = msg.content
            else:
                if isinstance(msg.content, str):
                    openai_msg["content"] = msg.content
                else:
                    content_parts = []
                    for part in msg.content:
                        if part.type == "text":
                            content_parts.append({"type": "text", "text": part.text})
                        elif part.type == "image":
                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {"url": part.image_source},
                                }
                            )
                    openai_msg["content"] = content_parts

            if msg.role == "assistant" and msg.tool_calls:
                assistant_tool_calls = []
                for call in msg.tool_calls:
                    assistant_tool_calls.append(
                        {
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.function.name,
                                "arguments": call.function.arguments,
                            },
                        }
                    )
                openai_msg["tool_calls"] = assistant_tool_calls

            if msg.name and msg.role != "tool":
                openai_msg["name"] = msg.name

            openai_messages.append(openai_msg)
        return openai_messages


class OpenAIToolSerializer(ToolSerializer):
    def serialize_tools(
        self, tools: list[ToolDefinition]
    ) -> list[dict[str, Any]] | None:
        if not tools:
            return None

        openai_tools = []
        for tool in tools:
            tool_dict = model_dump(tool)
            parameters = tool_dict.get("parameters")
            if parameters:
                tool_dict["parameters"] = sanitize_schema_for_llm(
                    parameters, api_type="openai"
                )
            tool_dict["strict"] = True
            openai_tools.append({"type": "function", "function": tool_dict})
        return openai_tools
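An illustration of the envelope serialize_tools produces; the get_weather definition is invented here, only the {"type": "function", ...} wrapper and the strict flag come from the code above:

example_openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "查询城市天气",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
        "strict": True,
    },
}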


class OpenAIResponseParser(ResponseParser):
    def validate_response(self, response_json: dict[str, Any]) -> None:
        if response_json.get("error"):
            error_info = response_json["error"]
            if isinstance(error_info, dict):
                error_message = error_info.get("message", "未知错误")
                error_code = error_info.get("code", "unknown")

                error_code_mapping = {
                    "invalid_api_key": LLMErrorCode.API_KEY_INVALID,
                    "authentication_failed": LLMErrorCode.API_KEY_INVALID,
                    "insufficient_quota": LLMErrorCode.API_QUOTA_EXCEEDED,
                    "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
                    "invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
                    "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "invalid_request_error": LLMErrorCode.INVALID_PARAMETER,
                    "invalid_parameter": LLMErrorCode.INVALID_PARAMETER,
                }

                llm_error_code = error_code_mapping.get(
                    error_code, LLMErrorCode.API_RESPONSE_INVALID
                )
            else:
                error_message = str(error_info)
                error_code = "unknown"
                llm_error_code = LLMErrorCode.API_RESPONSE_INVALID

            raise LLMException(
                f"API请求失败: {error_message}",
                code=llm_error_code,
                details={"api_error": error_info, "error_code": error_code},
            )

    def parse(self, response_json: dict[str, Any]) -> ResponseData:
        self.validate_response(response_json)

        choices = response_json.get("choices", [])
        if not choices:
            return ResponseData(text="", raw_response=response_json)

        choice = choices[0]
        message = choice.get("message", {})
        content = message.get("content", "")
        reasoning_content = message.get("reasoning_content", None)
        refusal = message.get("refusal")

        if refusal:
            raise LLMException(
                f"模型拒绝生成请求: {refusal}",
                code=LLMErrorCode.CONTENT_FILTERED,
                details={"refusal": refusal},
                recoverable=False,
            )

        if content:
            content = content.strip()

        images_payload: list[bytes | Path] = []
        if content and content.startswith("{") and content.endswith("}"):
            try:
                content_json = json.loads(content)
                if "b64_json" in content_json:
                    b64_str = content_json["b64_json"]
                    if isinstance(b64_str, str) and b64_str.startswith("data:"):
                        b64_str = b64_str.split(",", 1)[1]
                    decoded = base64.b64decode(b64_str)
                    images_payload.append(process_image_data(decoded))
                    content = "[图片已生成]"
                elif "data" in content_json and isinstance(content_json["data"], str):
                    b64_str = content_json["data"]
                    if b64_str.startswith("data:"):
                        b64_str = b64_str.split(",", 1)[1]
                    decoded = base64.b64decode(b64_str)
                    images_payload.append(process_image_data(decoded))
                    content = "[图片已生成]"

            except (json.JSONDecodeError, KeyError, binascii.Error):
                pass
        elif (
            "images" in message
            and isinstance(message["images"], list)
            and message["images"]
        ):
            for image_info in message["images"]:
                if image_info.get("type") == "image_url":
                    image_url_obj = image_info.get("image_url", {})
                    url_str = image_url_obj.get("url", "")
                    if url_str.startswith("data:image"):
                        try:
                            b64_data = url_str.split(",", 1)[1]
                            decoded = base64.b64decode(b64_data)
                            images_payload.append(process_image_data(decoded))
                        except (IndexError, binascii.Error) as e:
                            logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")

        if images_payload:
            content = content if content else "[图片已生成]"

        parsed_tool_calls: list[LLMToolCall] | None = None
        if message_tool_calls := message.get("tool_calls"):
            parsed_tool_calls = []
            for tc_data in message_tool_calls:
                try:
                    if tc_data.get("type") == "function":
                        parsed_tool_calls.append(
                            LLMToolCall(
                                id=tc_data["id"],
                                function=LLMToolFunction(
                                    name=tc_data["function"]["name"],
                                    arguments=tc_data["function"]["arguments"],
                                ),
                            )
                        )
                except KeyError as e:
                    logger.warning(
                        f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
                    )
                except Exception as e:
                    logger.warning(
                        f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
                    )
            if not parsed_tool_calls:
                parsed_tool_calls = None

        final_text = content if content is not None else ""
        if not final_text and parsed_tool_calls:
            final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"

        usage_info = response_json.get("usage")

        return ResponseData(
            text=final_text,
            tool_calls=parsed_tool_calls,
            usage_info=usage_info,
            images=images_payload if images_payload else None,
            raw_response=response_json,
            thought_text=reasoning_content,
        )
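A quick usage sketch (not part of the commit) of the parser above on a minimal chat-completions payload; the field values are invented:

parser = OpenAIResponseParser()
data = parser.parse(
    {
        "choices": [{"message": {"content": "你好!"}}],
        "usage": {"total_tokens": 12},
    }
)
assert data.text == "你好!"
assert data.usage_info == {"total_tokens": 12}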
zhenxun/services/llm/adapters/factory.py
@@ -2,10 +2,17 @@
 LLM 适配器工厂类
 """

-from typing import ClassVar
+import fnmatch
+from typing import TYPE_CHECKING, Any, ClassVar

 from ..types.exceptions import LLMErrorCode, LLMException
-from .base import BaseAdapter
+from ..types.models import ToolChoice
+from .base import BaseAdapter, RequestData, ResponseData
+
+if TYPE_CHECKING:
+    from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
+    from ..service import LLMModel
+    from ..types import LLMMessage


 class LLMAdapterFactory:
@@ -21,10 +28,13 @@ class LLMAdapterFactory:
             return

         from .gemini import GeminiAdapter
-        from .openai import OpenAIAdapter
+        from .openai import DeepSeekAdapter, OpenAIAdapter, OpenAIImageAdapter

         cls.register_adapter(OpenAIAdapter())
+        cls.register_adapter(DeepSeekAdapter())
         cls.register_adapter(GeminiAdapter())
+        cls.register_adapter(SmartAdapter())
+        cls.register_adapter(OpenAIImageAdapter())

     @classmethod
     def register_adapter(cls, adapter: BaseAdapter) -> None:
@@ -74,3 +84,100 @@ def get_adapter_for_api_type(api_type: str) -> BaseAdapter:
 def register_adapter(adapter: BaseAdapter) -> None:
     """注册新的适配器"""
     LLMAdapterFactory.register_adapter(adapter)
+
+
+class SmartAdapter(BaseAdapter):
+    """
+    智能路由适配器。
+    本身不处理序列化,而是根据规则委托给 OpenAIAdapter 或 GeminiAdapter。
+    """
+
+    @property
+    def log_sanitization_context(self) -> str:
+        return "openai_request"
+
+    _ROUTING_RULES: ClassVar[list[tuple[str, str]]] = [
+        ("*nano-banana*", "gemini"),
+        ("*gemini*", "gemini"),
+    ]
+    _DEFAULT_API_TYPE: ClassVar[str] = "openai"
+
+    def __init__(self):
+        self._adapter_cache: dict[str, BaseAdapter] = {}
+
+    @property
+    def api_type(self) -> str:
+        return "smart"
+
+    @property
+    def supported_api_types(self) -> list[str]:
+        return ["smart"]
+
+    def _get_delegate_adapter(self, model: "LLMModel") -> BaseAdapter:
+        """
+        核心路由逻辑:决定使用哪个适配器 (带缓存)
+        """
+        if model.model_detail.api_type:
+            return get_adapter_for_api_type(model.model_detail.api_type)
+
+        model_name = model.model_name
+        if model_name in self._adapter_cache:
+            return self._adapter_cache[model_name]
+
+        target_api_type = self._DEFAULT_API_TYPE
+        model_name_lower = model_name.lower()
+
+        for pattern, api_type in self._ROUTING_RULES:
+            if fnmatch.fnmatch(model_name_lower, pattern):
+                target_api_type = api_type
+                break
+
+        adapter = get_adapter_for_api_type(target_api_type)
+        self._adapter_cache[model_name] = adapter
+        return adapter
+
+    async def prepare_advanced_request(
+        self,
+        model: "LLMModel",
+        api_key: str,
+        messages: list["LLMMessage"],
+        config: "LLMGenerationConfig | None" = None,
+        tools: list[Any] | None = None,
+        tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
+    ) -> RequestData:
+        adapter = self._get_delegate_adapter(model)
+        return await adapter.prepare_advanced_request(
+            model, api_key, messages, config, tools, tool_choice
+        )
+
+    def parse_response(
+        self,
+        model: "LLMModel",
+        response_json: dict[str, Any],
+        is_advanced: bool = False,
+    ) -> ResponseData:
+        adapter = self._get_delegate_adapter(model)
+        return adapter.parse_response(model, response_json, is_advanced)
+
+    def prepare_embedding_request(
+        self,
+        model: "LLMModel",
+        api_key: str,
+        texts: list[str],
+        config: "LLMEmbeddingConfig",
+    ) -> RequestData:
+        adapter = self._get_delegate_adapter(model)
+        return adapter.prepare_embedding_request(model, api_key, texts, config)
+
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
+        return get_adapter_for_api_type("openai").parse_embedding_response(
+            response_json
+        )
+
+    def convert_generation_config(
+        self, config: "LLMGenerationConfig", model: "LLMModel"
+    ) -> dict[str, Any]:
+        adapter = self._get_delegate_adapter(model)
+        return adapter.convert_generation_config(config, model)
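The routing above is plain fnmatch globbing over the lowercased model name. A standalone sketch of the same matching logic (the rule table is copied from _ROUTING_RULES; the function name is invented):

import fnmatch

_RULES = [("*nano-banana*", "gemini"), ("*gemini*", "gemini")]


def resolve_api_type(model_name: str, default: str = "openai") -> str:
    # First glob that matches wins; otherwise fall back to the default.
    name = model_name.lower()
    for pattern, api_type in _RULES:
        if fnmatch.fnmatch(name, pattern):
            return api_type
    return default


assert resolve_api_type("Gemini-2.0-Flash") == "gemini"
assert resolve_api_type("gpt-4o") == "openai"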
zhenxun/services/llm/adapters/gemini.py
@@ -2,27 +2,35 @@
 Gemini API 适配器
 """

-import base64
 from typing import TYPE_CHECKING, Any

 from zhenxun.services.log import logger

+from ..config.generation import ResponseFormat
+from ..types import LLMContentPart
 from ..types.exceptions import LLMErrorCode, LLMException
-from ..utils import sanitize_schema_for_llm
+from ..types.models import BasePlatformTool, ToolChoice
 from .base import BaseAdapter, RequestData, ResponseData
+from .components.gemini_components import (
+    GeminiConfigMapper,
+    GeminiMessageConverter,
+    GeminiResponseParser,
+    GeminiToolSerializer,
+)

 if TYPE_CHECKING:
-    from ..config.generation import LLMGenerationConfig
+    from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
     from ..service import LLMModel
-    from ..types.content import LLMMessage
-    from ..types.enums import EmbeddingTaskType
-    from ..types.models import LLMToolCall
-    from ..types.protocols import ToolExecutable
+    from ..types import LLMMessage


 class GeminiAdapter(BaseAdapter):
     """Gemini API 适配器"""

+    @property
+    def log_sanitization_context(self) -> str:
+        return "gemini_request"
+
     @property
     def api_type(self) -> str:
         return "gemini"
@@ -47,110 +55,75 @@ class GeminiAdapter(BaseAdapter):
         api_key: str,
         messages: list["LLMMessage"],
         config: "LLMGenerationConfig | None" = None,
-        tools: dict[str, "ToolExecutable"] | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[Any] | None = None,
+        tool_choice: str | dict[str, Any] | ToolChoice | None = None,
     ) -> RequestData:
         """准备高级请求"""
         effective_config = config if config is not None else model._generation_config

+        if tools:
+            from ..types.models import GeminiUrlContext
+
+            context_urls: list[str] = []
+            for tool in tools:
+                if isinstance(tool, GeminiUrlContext):
+                    context_urls.extend(tool.urls)
+
+            if context_urls and messages:
+                last_msg = messages[-1]
+                if last_msg.role == "user":
+                    url_text = "\n\n[Context URLs]:\n" + "\n".join(context_urls)
+                    if isinstance(last_msg.content, str):
+                        last_msg.content += url_text
+                    elif isinstance(last_msg.content, list):
+                        last_msg.content.append(LLMContentPart.text_part(url_text))
+
+        has_function_tools = False
+        if tools:
+            has_function_tools = any(hasattr(tool, "get_definition") for tool in tools)
+
+        is_structured = False
+        if effective_config and effective_config.output:
+            if (
+                effective_config.output.response_schema
+                or effective_config.output.response_format == ResponseFormat.JSON
+                or effective_config.output.response_mime_type == "application/json"
+            ):
+                is_structured = True
+
+        if (has_function_tools or is_structured) and effective_config:
+            if effective_config.reasoning is None:
+                from ..config.generation import ReasoningConfig
+
+                effective_config.reasoning = ReasoningConfig()
+
+            if (
+                effective_config.reasoning.budget_tokens is None
+                and effective_config.reasoning.effort is None
+            ):
+                reason_desc = "工具调用" if has_function_tools else "结构化输出"
+                logger.debug(
+                    f"检测到{reason_desc},自动为模型 {model.model_name} 开启思维链增强"
+                )
+                effective_config.reasoning.budget_tokens = -1
+
         endpoint = self._get_gemini_endpoint(model, effective_config)
         url = self.get_api_url(model, endpoint)
         headers = self.get_base_headers(api_key)

-        gemini_contents: list[dict[str, Any]] = []
+        converter = GeminiMessageConverter()
         system_instruction_parts: list[dict[str, Any]] | None = None

         for msg in messages:
-            current_parts: list[dict[str, Any]] = []
             if msg.role == "system":
                 if isinstance(msg.content, str):
                     system_instruction_parts = [{"text": msg.content}]
                 elif isinstance(msg.content, list):
                     system_instruction_parts = [
-                        await part.convert_for_api_async("gemini")
-                        for part in msg.content
+                        await converter.convert_part(part) for part in msg.content
                     ]
                 continue

-            elif msg.role == "user":
-                if isinstance(msg.content, str):
-                    current_parts.append({"text": msg.content})
-                elif isinstance(msg.content, list):
-                    for part_obj in msg.content:
-                        current_parts.append(
-                            await part_obj.convert_for_api_async("gemini")
-                        )
-                gemini_contents.append({"role": "user", "parts": current_parts})
-
-            elif msg.role == "assistant" or msg.role == "model":
-                if isinstance(msg.content, str) and msg.content:
-                    current_parts.append({"text": msg.content})
-                elif isinstance(msg.content, list):
-                    for part_obj in msg.content:
-                        current_parts.append(
-                            await part_obj.convert_for_api_async("gemini")
-                        )
-
-                if msg.tool_calls:
-                    import json
-
-                    for call in msg.tool_calls:
-                        current_parts.append(
-                            {
-                                "functionCall": {
-                                    "name": call.function.name,
-                                    "args": json.loads(call.function.arguments),
-                                }
-                            }
-                        )
-                if current_parts:
-                    gemini_contents.append({"role": "model", "parts": current_parts})
-
-            elif msg.role == "tool":
-                if not msg.name:
-                    raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。")
-
-                import json
-
-                try:
-                    content_str = (
-                        msg.content
-                        if isinstance(msg.content, str)
-                        else str(msg.content)
-                    )
-                    tool_result_obj = json.loads(content_str)
-                except json.JSONDecodeError:
-                    content_str = (
-                        msg.content
-                        if isinstance(msg.content, str)
-                        else str(msg.content)
-                    )
-                    logger.warning(
-                        f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. "
-                        f"包装为原始字符串。"
-                    )
-                    tool_result_obj = {"raw_output": content_str}
-
-                if isinstance(tool_result_obj, list):
-                    logger.debug(
-                        f"工具 '{msg.name}' 的返回结果是列表,"
-                        f"正在为Gemini API包装为JSON对象。"
-                    )
-                    final_response_payload = {"result": tool_result_obj}
-                elif not isinstance(tool_result_obj, dict):
-                    final_response_payload = {"result": tool_result_obj}
-                else:
-                    final_response_payload = tool_result_obj
-
-                current_parts.append(
-                    {
-                        "functionResponse": {
-                            "name": msg.name,
-                            "response": final_response_payload,
-                        }
-                    }
-                )
-                gemini_contents.append({"role": "function", "parts": current_parts})
+        gemini_contents = await converter.convert_messages_async(messages)

         body: dict[str, Any] = {"contents": gemini_contents}

@@ -158,75 +131,78 @@ class GeminiAdapter(BaseAdapter):
             body["systemInstruction"] = {"parts": system_instruction_parts}

         all_tools_for_request = []
+        has_user_functions = False
         if tools:
-            import asyncio
-
-            from zhenxun.utils.pydantic_compat import model_dump
+            from ..types.protocols import ToolExecutable

-            definition_tasks = [
-                executable.get_definition() for executable in tools.values()
-            ]
-            tool_definitions = await asyncio.gather(*definition_tasks)
+            function_tools: list[ToolExecutable] = []
+            gemini_tools_dict: dict[str, Any] = {}

-            function_declarations = []
-            for tool_def in tool_definitions:
-                tool_def.parameters = sanitize_schema_for_llm(
-                    tool_def.parameters, api_type="gemini"
-                )
-                function_declarations.append(model_dump(tool_def))
+            for tool in tools:
+                if isinstance(tool, BasePlatformTool):
+                    declaration = tool.get_tool_declaration()
+                    if declaration:
+                        gemini_tools_dict.update(declaration)
+                elif hasattr(tool, "get_definition"):
+                    function_tools.append(tool)

-            if function_declarations:
-                all_tools_for_request.append(
-                    {"functionDeclarations": function_declarations}
-                )
+            if function_tools:
+                import asyncio
+
+                definition_tasks = [
+                    executable.get_definition() for executable in function_tools
+                ]
+                tool_definitions = await asyncio.gather(*definition_tasks)

-        if effective_config:
-            if getattr(effective_config, "enable_grounding", False):
-                has_explicit_gs_tool = any(
-                    "googleSearch" in tool_item for tool_item in all_tools_for_request
-                )
-                if not has_explicit_gs_tool:
-                    all_tools_for_request.append({"googleSearch": {}})
-                    logger.debug("隐式启用 Google Search 工具进行信息来源关联。")
-
-            if getattr(effective_config, "enable_code_execution", False):
-                has_explicit_ce_tool = any(
-                    "codeExecution" in tool_item for tool_item in all_tools_for_request
-                )
-                if not has_explicit_ce_tool:
-                    all_tools_for_request.append({"codeExecution": {}})
-                    logger.debug("隐式启用代码执行工具。")
+                serializer = GeminiToolSerializer()
+                function_declarations = serializer.serialize_tools(tool_definitions)
+
+                if function_declarations:
+                    gemini_tools_dict["functionDeclarations"] = function_declarations
+                    has_user_functions = True
+
+            if gemini_tools_dict:
+                all_tools_for_request.append(gemini_tools_dict)

         if all_tools_for_request:
             body["tools"] = all_tools_for_request

-        final_tool_choice = tool_choice
-        if final_tool_choice is None and effective_config:
-            final_tool_choice = getattr(effective_config, "tool_choice", None)
+        tool_config_updates: dict[str, Any] = {}
+        if (
+            effective_config
+            and effective_config.custom_params
+            and "user_location" in effective_config.custom_params
+        ):
+            tool_config_updates["retrievalConfig"] = {
+                "latLng": effective_config.custom_params["user_location"]
+            }

-        if final_tool_choice:
-            if isinstance(final_tool_choice, str):
-                mode_upper = final_tool_choice.upper()
-                if mode_upper in ["AUTO", "NONE", "ANY"]:
-                    body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
-                else:
-                    body["toolConfig"] = self._convert_tool_choice_to_gemini(
-                        final_tool_choice
-                    )
-            else:
-                body["toolConfig"] = self._convert_tool_choice_to_gemini(
-                    final_tool_choice
-                )
+        if tool_config_updates:
+            body.setdefault("toolConfig", {}).update(tool_config_updates)
+
+        converted_params: dict[str, Any] = {}
+        if effective_config:
+            converted_params = self.convert_generation_config(effective_config, model)
+
+        if converted_params:
+            if "toolConfig" in converted_params:
+                tool_config_payload = converted_params.pop("toolConfig")
+                fc_config = tool_config_payload.get("functionCallingConfig")
+                should_apply_fc = has_user_functions or (
+                    fc_config and fc_config.get("mode") == "NONE"
+                )
+                if should_apply_fc:
+                    body.setdefault("toolConfig", {}).update(tool_config_payload)
+                elif fc_config and fc_config.get("mode") != "AUTO":
+                    logger.debug(
+                        "Gemini: 忽略针对纯内置工具的 functionCallingConfig (API限制)"
+                    )

-        final_generation_config = self._build_gemini_generation_config(
-            model, effective_config
-        )
-        if final_generation_config:
-            body["generationConfig"] = final_generation_config
+            if "safetySettings" in converted_params:
+                body["safetySettings"] = converted_params.pop("safetySettings")

-        safety_settings = self._build_safety_settings(effective_config)
-        if safety_settings:
-            body["safetySettings"] = safety_settings
+            if converted_params:
+                body["generationConfig"] = converted_params

         return RequestData(url=url, headers=headers, body=body)

@@ -242,317 +218,56 @@ class GeminiAdapter(BaseAdapter):
     def _get_gemini_endpoint(
         self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
    ) -> str:
-        """根据配置选择Gemini API端点"""
-        if config:
-            if getattr(config, "enable_code_execution", False):
-                return f"/v1beta/models/{model.model_name}:generateContent"
-
-            if getattr(config, "enable_grounding", False):
-                return f"/v1beta/models/{model.model_name}:generateContent"
-
+        """返回Gemini generateContent 端点"""
         return f"/v1beta/models/{model.model_name}:generateContent"

-    def _convert_tool_choice_to_gemini(
-        self, tool_choice_value: str | dict[str, Any]
-    ) -> dict[str, Any]:
-        """转换工具选择策略为Gemini格式"""
-        if isinstance(tool_choice_value, str):
-            mode_upper = tool_choice_value.upper()
-            if mode_upper in ["AUTO", "NONE", "ANY"]:
-                return {"functionCallingConfig": {"mode": mode_upper}}
-            else:
-                logger.warning(
-                    f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。"
-                    f"回退到 AUTO。"
-                )
-                return {"functionCallingConfig": {"mode": "AUTO"}}
-
-        elif isinstance(tool_choice_value, dict):
-            if (
-                tool_choice_value.get("type") == "function"
-                and "function" in tool_choice_value
-            ):
-                func_name = tool_choice_value["function"].get("name")
-                if func_name:
-                    return {
-                        "functionCallingConfig": {
-                            "mode": "ANY",
-                            "allowedFunctionNames": [func_name],
-                        }
-                    }
-                else:
-                    logger.warning(
-                        f"tool_choice dict 中的函数名无效: {tool_choice_value}。"
-                        f"回退到 AUTO。"
-                    )
-                    return {"functionCallingConfig": {"mode": "AUTO"}}
-
-            elif "functionCallingConfig" in tool_choice_value:
-                return {
-                    "functionCallingConfig": tool_choice_value["functionCallingConfig"]
-                }
-
-            else:
-                logger.warning(
-                    f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。"
-                )
-                return {"functionCallingConfig": {"mode": "AUTO"}}
-
-        logger.warning(
-            f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。"
-        )
-        return {"functionCallingConfig": {"mode": "AUTO"}}
-
-    def _build_gemini_generation_config(
-        self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
-    ) -> dict[str, Any]:
-        """构建Gemini生成配置"""
-        effective_config = config if config is not None else model._generation_config
-
-        if not effective_config:
-            return {}
-
-        generation_config = effective_config.to_api_params(
-            api_type="gemini", model_name=model.model_name
-        )
-
-        if generation_config:
-            param_keys = list(generation_config.keys())
-            logger.debug(
-                f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: "
-                f"{param_keys}"
-            )
-
-        return generation_config
-
-    def _build_safety_settings(
-        self, config: "LLMGenerationConfig | None" = None
-    ) -> list[dict[str, Any]] | None:
-        """构建安全设置"""
-        if not config:
-            return None
-
-        safety_settings = []
-
-        safety_categories = [
-            "HARM_CATEGORY_HARASSMENT",
-            "HARM_CATEGORY_HATE_SPEECH",
-            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-            "HARM_CATEGORY_DANGEROUS_CONTENT",
-        ]
-
-        custom_safety_settings = getattr(config, "safety_settings", None)
-        if custom_safety_settings:
-            for category, threshold in custom_safety_settings.items():
-                safety_settings.append({"category": category, "threshold": threshold})
-        else:
-            from ..config.providers import get_gemini_safety_threshold
-
-            threshold = get_gemini_safety_threshold()
-            for category in safety_categories:
-                safety_settings.append({"category": category, "threshold": threshold})
-
-        return safety_settings if safety_settings else None
-
-    def validate_response(self, response_json: dict[str, Any]) -> None:
-        """验证 Gemini API 响应,增加对 promptFeedback 的检查"""
-        super().validate_response(response_json)
-
-        if prompt_feedback := response_json.get("promptFeedback"):
-            if block_reason := prompt_feedback.get("blockReason"):
-                logger.warning(
-                    f"Gemini 内容因 promptFeedback 被安全过滤: {block_reason}"
-                )
-                raise LLMException(
-                    f"内容被安全过滤: {block_reason}",
-                    code=LLMErrorCode.CONTENT_FILTERED,
-                    details={
-                        "block_reason": block_reason,
-                        "safety_ratings": prompt_feedback.get("safetyRatings"),
-                    },
-                )
-
     def parse_response(
         self,
         model: "LLMModel",
         response_json: dict[str, Any],
         is_advanced: bool = False,
     ) -> ResponseData:
-        """解析API响应"""
-        return self._parse_response(model, response_json, is_advanced)
-
-    def _parse_response(
-        self,
-        model: "LLMModel",
-        response_json: dict[str, Any],
-        is_advanced: bool = False,
-    ) -> ResponseData:
         """解析 Gemini API 响应"""
-        _ = is_advanced
-        self.validate_response(response_json)
-
-        try:
-            if "image_generation" in response_json and isinstance(
-                response_json["image_generation"], dict
-            ):
-                candidates_source = response_json["image_generation"]
-            else:
-                candidates_source = response_json
-
-            candidates = candidates_source.get("candidates", [])
-            usage_info = response_json.get("usageMetadata")
-
-            if not candidates:
-                logger.debug("Gemini响应中没有candidates。")
-                return ResponseData(text="", raw_response=response_json)
-
-            candidate = candidates[0]
-
-            if candidate.get("finishReason") in [
-                "RECITATION",
-                "OTHER",
-            ] and not candidate.get("content"):
-                logger.warning(
-                    f"Gemini candidate finished with reason "
-                    f"'{candidate.get('finishReason')}' and no content."
-                )
-                return ResponseData(
-                    text="",
-                    raw_response=response_json,
-                    usage_info=response_json.get("usageMetadata"),
-                )
-
-            content_data = candidate.get("content", {})
-            parts = content_data.get("parts", [])
-
-            text_content = ""
-            images_bytes: list[bytes] = []
-            parsed_tool_calls: list["LLMToolCall"] | None = None
-            thought_summary_parts = []
-            answer_parts = []
-
-            for part in parts:
-                if "text" in part:
-                    answer_parts.append(part["text"])
-                elif "thought" in part:
-                    thought_summary_parts.append(part["thought"])
-                elif "thoughtSummary" in part:
-                    thought_summary_parts.append(part["thoughtSummary"])
-                elif "inlineData" in part:
-                    inline_data = part["inlineData"]
-                    if "data" in inline_data:
-                        images_bytes.append(base64.b64decode(inline_data["data"]))
-
-                elif "functionCall" in part:
-                    if parsed_tool_calls is None:
-                        parsed_tool_calls = []
-                    fc_data = part["functionCall"]
-                    try:
-                        import json
-
-                        from ..types.models import LLMToolCall, LLMToolFunction
-
-                        call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
-                        parsed_tool_calls.append(
-                            LLMToolCall(
-                                id=call_id,
-                                function=LLMToolFunction(
-                                    name=fc_data["name"],
-                                    arguments=json.dumps(fc_data["args"]),
-                                ),
-                            )
-                        )
-                    except KeyError as e:
-                        logger.warning(
-                            f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}"
-                        )
-                    except Exception as e:
-                        logger.warning(
-                            f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
-                        )
-                elif "codeExecutionResult" in part:
-                    result = part["codeExecutionResult"]
-                    if result.get("outcome") == "OK":
-                        output = result.get("output", "")
-                        answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n")
-                    else:
-                        answer_parts.append(
-                            f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
-                        )
-
-            if thought_summary_parts:
-                full_thought_summary = "\n".join(thought_summary_parts).strip()
-                full_answer = "".join(answer_parts).strip()
-
-                formatted_parts = []
-                if full_thought_summary:
-                    formatted_parts.append(f"🤔 **思考过程**\n\n{full_thought_summary}")
-                if full_answer:
-                    separator = "\n\n---\n\n" if full_thought_summary else ""
-                    formatted_parts.append(f"{separator}✅ **回答**\n\n{full_answer}")
-
-                text_content = "".join(formatted_parts)
-            else:
-                text_content = "".join(answer_parts)
-
-            usage_info = response_json.get("usageMetadata")
-
-            grounding_metadata_obj = None
-            if grounding_data := candidate.get("groundingMetadata"):
-                try:
-                    from ..types.models import LLMGroundingMetadata
-
-                    grounding_metadata_obj = LLMGroundingMetadata(**grounding_data)
-                except Exception as e:
-                    logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}")
-
-            return ResponseData(
-                text=text_content,
-                tool_calls=parsed_tool_calls,
-                images=images_bytes if images_bytes else None,
-                usage_info=usage_info,
-                raw_response=response_json,
-                grounding_metadata=grounding_metadata_obj,
-            )
-
-        except Exception as e:
-            logger.error(f"解析 Gemini 响应失败: {e}", e=e)
-            raise LLMException(
-                f"解析API响应失败: {e}",
-                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                cause=e,
-            )
+        _ = model, is_advanced
+        parser = GeminiResponseParser()
+        return parser.parse(response_json)

     def prepare_embedding_request(
         self,
         model: "LLMModel",
         api_key: str,
         texts: list[str],
-        task_type: "EmbeddingTaskType | str",
-        **kwargs: Any,
+        config: "LLMEmbeddingConfig",
     ) -> RequestData:
         """准备文本嵌入请求"""
         api_model_name = model.model_name
         if not api_model_name.startswith("models/"):
             api_model_name = f"models/{api_model_name}"

-        url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents")
+        if not model.api_base:
+            raise LLMException(
+                f"模型 {model.model_name} 的 api_base 未设置",
+                code=LLMErrorCode.CONFIGURATION_ERROR,
+            )
+
+        base_url = model.api_base.rstrip("/")
+        url = f"{base_url}/v1beta/{api_model_name}:batchEmbedContents"
         headers = self.get_base_headers(api_key)

         requests_payload = []
         for text_content in texts:
+            safe_text = text_content if text_content else " "
             request_item: dict[str, Any] = {
-                "content": {"parts": [{"text": text_content}]},
+                "model": api_model_name,
+                "content": {"parts": [{"text": safe_text}]},
             }

-            from ..types.enums import EmbeddingTaskType
-
-            if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT:
-                request_item["task_type"] = str(task_type).upper()
-            if title := kwargs.get("title"):
-                request_item["title"] = title
-            if output_dimensionality := kwargs.get("output_dimensionality"):
-                request_item["output_dimensionality"] = output_dimensionality
+            if config.task_type:
+                request_item["task_type"] = str(config.task_type).upper()
+            if config.title:
+                request_item["title"] = config.title
+            if config.output_dimensionality:
+                request_item["output_dimensionality"] = config.output_dimensionality

             requests_payload.append(request_item)

@@ -601,3 +316,9 @@ class GeminiAdapter(BaseAdapter):
             code=LLMErrorCode.RESPONSE_PARSE_ERROR,
             details=response_json,
         )
+
+    def convert_generation_config(
+        self, config: "LLMGenerationConfig", model: "LLMModel"
+    ) -> dict[str, Any]:
+        mapper = GeminiConfigMapper()
+        return mapper.map_config(config, model.model_detail, model.capabilities)
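For orientation, one request item as prepare_embedding_request assembles it; the model name and field values below are invented, and the surrounding {"requests": [...]} wrapper of Gemini's batchEmbedContents endpoint is assumed rather than shown in this hunk:

example_request_item = {
    "model": "models/text-embedding-004",
    "content": {"parts": [{"text": "你好"}]},
    "task_type": "RETRIEVAL_QUERY",
    "title": "greeting",
    "output_dimensionality": 256,
}
example_body = {"requests": [example_request_item]}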
zhenxun/services/llm/adapters/openai.py
@@ -1,15 +1,181 @@
 """
 OpenAI API 适配器

-支持 OpenAI、DeepSeek、智谱AI 和其他 OpenAI 兼容的 API 服务。
+支持 OpenAI、智谱AI 等 OpenAI 兼容的 API 服务。
 """

-from typing import TYPE_CHECKING
+from abc import ABC, abstractmethod
+import base64
+from pathlib import Path
+from typing import TYPE_CHECKING, Any

-from .base import OpenAICompatAdapter
+import json_repair
+
+from zhenxun.services.llm.config.generation import ImageAspectRatio
+from zhenxun.services.llm.types.exceptions import LLMErrorCode, LLMException
+from zhenxun.services.log import logger
+from zhenxun.utils.http_utils import AsyncHttpx
+
+from ..types import StructuredOutputStrategy
+from ..types.models import ToolChoice
+from ..utils import sanitize_schema_for_llm
+from .base import (
+    BaseAdapter,
+    OpenAICompatAdapter,
+    RequestData,
+    ResponseData,
+    process_image_data,
+)
+from .components.openai_components import (
+    OpenAIConfigMapper,
+    OpenAIMessageConverter,
+    OpenAIResponseParser,
+    OpenAIToolSerializer,
+)

 if TYPE_CHECKING:
+    from ..config.generation import LLMEmbeddingConfig, LLMGenerationConfig
     from ..service import LLMModel
+    from ..types import LLMMessage
+
+
+class APIProtocol(ABC):
+    """API 协议策略基类"""
+
+    @abstractmethod
+    def build_request_body(
+        self,
+        model: "LLMModel",
+        messages: list["LLMMessage"],
+        tools: list[dict[str, Any]] | None,
+        tool_choice: Any,
+    ) -> dict[str, Any]:
+        """构建不同协议下的请求体"""
+        pass
+
+    @abstractmethod
+    def parse_response(self, response_json: dict[str, Any]) -> ResponseData:
+        """解析不同协议下的响应"""
+        pass
+
+
+class StandardProtocol(APIProtocol):
+    """标准 OpenAI 协议策略"""
+
+    def __init__(self, adapter: "OpenAICompatAdapter"):
+        self.adapter = adapter
+
+    def build_request_body(
+        self,
+        model: "LLMModel",
+        messages: list["LLMMessage"],
+        tools: list[dict[str, Any]] | None,
+        tool_choice: Any,
+    ) -> dict[str, Any]:
+        converter = OpenAIMessageConverter()
+        openai_messages = converter.convert_messages(messages)
+        body: dict[str, Any] = {
+            "model": model.model_name,
+            "messages": openai_messages,
+        }
+        if tools:
+            body["tools"] = tools
+        if tool_choice:
+            body["tool_choice"] = tool_choice
+        return body
+
+    def parse_response(self, response_json: dict[str, Any]) -> ResponseData:
+        parser = OpenAIResponseParser()
+        return parser.parse(response_json)
+
+
+class ResponsesProtocol(APIProtocol):
+    """/v1/responses 新版协议策略"""
+
+    def __init__(self, adapter: "OpenAICompatAdapter"):
+        self.adapter = adapter
+
+    def build_request_body(
+        self,
+        model: "LLMModel",
+        messages: list["LLMMessage"],
+        tools: list[dict[str, Any]] | None,
+        tool_choice: Any,
+    ) -> dict[str, Any]:
+        input_items: list[dict[str, Any]] = []
+
+        for msg in messages:
+            role = msg.role
+            content_list: list[dict[str, Any]] = []
+            raw_contents = (
+                msg.content if isinstance(msg.content, list) else [msg.content]
+            )
+
+            for part in raw_contents:
+                if part is None:
+                    continue
+                if isinstance(part, str):
+                    content_list.append({"type": "input_text", "text": part})
+                    continue
+
+                if hasattr(part, "type"):
+                    part_type = getattr(part, "type", None)
+                    if part_type == "text":
+                        content_list.append(
+                            {"type": "input_text", "text": getattr(part, "text", "")}
+                        )
+                    elif part_type == "image":
+                        content_list.append(
+                            {
+                                "type": "input_image",
+                                "image_url": getattr(part, "image_source", ""),
+                            }
+                        )
+                    continue
+
+                if isinstance(part, dict):
+                    part_type = part.get("type")
+                    if part_type == "text":
+                        content_list.append(
+                            {"type": "input_text", "text": part.get("text", "")}
+                        )
+                    elif part_type in {"image", "image_url"}:
+                        image_src = part.get("image_url") or part.get(
+                            "image_source", ""
+                        )
+                        content_list.append(
+                            {
+                                "type": "input_image",
+                                "image_url": image_src,
+                            }
+                        )
+
+            input_items.append({"role": role, "content": content_list})
+
+        body: dict[str, Any] = {
+            "model": model.model_name,
+            "input": input_items,
+        }
+        if tools:
+            body["tools"] = tools
+        if tool_choice:
+            body["tool_choice"] = tool_choice
+        return body
+
+    def parse_response(self, response_json: dict[str, Any]) -> ResponseData:
+        self.adapter.validate_response(response_json)
+        text_content = ""
+        for item in response_json.get("output", []):
+            if item.get("type") == "message" and item.get("role") == "assistant":
+                for content_item in item.get("content", []):
+                    if content_item.get("type") == "output_text":
+                        text_content += content_item.get("text", "")
+
+        return ResponseData(
+            text=text_content,
+            usage_info=response_json.get("usage"),
+            raw_response=response_json,
+        )
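A sample of the /v1/responses input shape ResponsesProtocol.build_request_body produces for one multimodal user message; the model name, text, and data URL below are placeholders:

example_responses_body = {
    "model": "gpt-4o",
    "input": [
        {
            "role": "user",
            "content": [
                {"type": "input_text", "text": "描述这张图片"},
                {"type": "input_image", "image_url": "data:image/png;base64,..."},
            ],
        }
    ],
}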
 class OpenAIAdapter(OpenAICompatAdapter):
@@ -23,23 +189,411 @@ class OpenAIAdapter(OpenAICompatAdapter):
     def supported_api_types(self) -> list[str]:
         return [
             "openai",
-            "deepseek",
             "zhipu",
-            "general_openai_compat",
             "ark",
             "openrouter",
+            "openai_responses",
         ]
 
     def get_chat_endpoint(self, model: "LLMModel") -> str:
         """Return the chat completion endpoint."""
-        if model.api_type == "ark":
+        if model.model_detail.endpoint:
+            return model.model_detail.endpoint
+
+        current_api_type = model.model_detail.api_type or model.api_type
+
+        if current_api_type == "openai_responses":
+            return "/v1/responses"
+        if current_api_type == "ark":
             return "/api/v3/chat/completions"
-        if model.api_type == "zhipu":
+        if current_api_type == "zhipu":
             return "/api/paas/v4/chat/completions"
         return "/v1/chat/completions"
 
+    def _get_protocol_strategy(self, model: "LLMModel") -> APIProtocol:
+        """Select the protocol strategy matching the model's API type."""
+        current_api_type = model.model_detail.api_type or model.api_type
+        if current_api_type == "openai_responses":
+            return ResponsesProtocol(self)
+        return StandardProtocol(self)
+
     def get_embedding_endpoint(self, model: "LLMModel") -> str:
         """Return the embedding endpoint for the API type."""
         if model.api_type == "zhipu":
             return "/v4/embeddings"
         return "/v1/embeddings"
 
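The endpoint resolution above prefers an explicit per-model endpoint, then the per-model api_type, then the provider-level api_type. A self-contained sketch of that order, using stand-in classes assumed for illustration (not the project's real LLMModel):

from dataclasses import dataclass, field

@dataclass
class FakeDetail:  # stand-in for model.model_detail, assumption only
    endpoint: str | None = None
    api_type: str | None = None

@dataclass
class FakeModel:  # stand-in for LLMModel, assumption only
    api_type: str = "openai"
    model_detail: FakeDetail = field(default_factory=FakeDetail)

def resolve_chat_endpoint(model: FakeModel) -> str:
    if model.model_detail.endpoint:  # 1. explicit override wins outright
        return model.model_detail.endpoint
    t = model.model_detail.api_type or model.api_type  # 2. per-model, 3. provider
    if t == "openai_responses":
        return "/v1/responses"
    if t == "ark":
        return "/api/v3/chat/completions"
    if t == "zhipu":
        return "/api/paas/v4/chat/completions"
    return "/v1/chat/completions"

assert resolve_chat_endpoint(FakeModel(model_detail=FakeDetail(api_type="openai_responses"))) == "/v1/responses"
assert resolve_chat_endpoint(FakeModel(model_detail=FakeDetail(endpoint="/custom"))) == "/custom"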
+    def convert_generation_config(
+        self, config: "LLMGenerationConfig", model: "LLMModel"
+    ) -> dict[str, Any]:
+        mapper = OpenAIConfigMapper(api_type=self.api_type)
+        return mapper.map_config(config, model.model_detail, model.capabilities)
+
+    async def prepare_advanced_request(
+        self,
+        model: "LLMModel",
+        api_key: str,
+        messages: list["LLMMessage"],
+        config: "LLMGenerationConfig | None" = None,
+        tools: list[Any] | None = None,
+        tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+    ) -> "RequestData":
+        """Build an advanced request via the protocol strategy for the model."""
+        url = self.get_api_url(model, self.get_chat_endpoint(model))
+        headers = self.get_base_headers(api_key)
+        if model.api_type == "openrouter":
+            headers.update(
+                {
+                    "HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot",
+                    "X-Title": "Zhenxun Bot",
+                }
+            )
+
+        default_config = getattr(model, "_generation_config", None)
+        effective_config = config if config is not None else default_config
+        structured_strategy = (
+            effective_config.output.structured_output_strategy
+            if effective_config and effective_config.output
+            else None
+        )
+        if structured_strategy is None:
+            structured_strategy = StructuredOutputStrategy.NATIVE
+
+        openai_tools: list[dict[str, Any]] | None = None
+        executables: list[Any] = []
+        if tools:
+            if isinstance(tools, dict):
+                executables = list(tools.values())
+            else:
+                for tool in tools:
+                    if hasattr(tool, "get_definition"):
+                        executables.append(tool)
+
+        definition_tasks = [executable.get_definition() for executable in executables]
+        tool_defs: list[Any] = []
+        if definition_tasks:
+            import asyncio
+
+            tool_defs = await asyncio.gather(*definition_tasks)
+
+        if tool_defs:
+            serializer = OpenAIToolSerializer()
+            openai_tools = serializer.serialize_tools(tool_defs)
+
+        final_tool_choice = tool_choice
+        if final_tool_choice is None:
+            if (
+                effective_config
+                and effective_config.tool_config
+                and effective_config.tool_config.mode == "ANY"
+            ):
+                allowed = effective_config.tool_config.allowed_function_names
+                if allowed:
+                    if len(allowed) == 1:
+                        final_tool_choice = {
+                            "type": "function",
+                            "function": {"name": allowed[0]},
+                        }
+                    else:
+                        logger.warning(
+                            "The OpenAI API does not support multiple"
+                            " allowed_function_names; falling back to 'required'."
+                        )
+                        final_tool_choice = "required"
+                else:
+                    final_tool_choice = "required"
+
+        if (
+            structured_strategy == StructuredOutputStrategy.TOOL_CALL
+            and effective_config
+            and effective_config.output
+            and effective_config.output.response_schema
+        ):
+            sanitized_schema = sanitize_schema_for_llm(
+                effective_config.output.response_schema, api_type="openai"
+            )
+            structured_tool = {
+                "type": "function",
+                "function": {
+                    "name": "return_structured_response",
+                    "description": "Return the final structured response.",
+                    "parameters": sanitized_schema,
+                    "strict": True if model.api_type != "deepseek" else False,
+                },
+            }
+            if openai_tools is None:
+                openai_tools = []
+            openai_tools.append(structured_tool)
+            final_tool_choice = {
+                "type": "function",
+                "function": {"name": "return_structured_response"},
+            }
+
+        protocol_strategy = self._get_protocol_strategy(model)
+        body = protocol_strategy.build_request_body(
+            model=model,
+            messages=messages,
+            tools=openai_tools,
+            tool_choice=final_tool_choice,
+        )
+
+        body = self.apply_config_override(model, body, config)
+
+        if final_tool_choice is not None:
+            body["tool_choice"] = final_tool_choice
+
+        response_format = body.get("response_format", {})
+        inject_prompt = (
+            structured_strategy == StructuredOutputStrategy.NATIVE
+            and isinstance(response_format, dict)
+            and response_format.get("type") == "json_object"
+        )
+
+        if inject_prompt:
+            messages_list = body.get("messages", [])
+            has_json_keyword = False
+            for msg in messages_list:
+                content = msg.get("content")
+                if isinstance(content, str) and "json" in content.lower():
+                    has_json_keyword = True
+                    break
+                if isinstance(content, list):
+                    for part in content:
+                        if (
+                            isinstance(part, dict)
+                            and part.get("type") == "text"
+                            and "json" in part.get("text", "").lower()
+                        ):
+                            has_json_keyword = True
+                            break
+                    if has_json_keyword:
+                        break
+
+            if not has_json_keyword:
+                injection_text = (
+                    "You must output valid JSON, with no extra text,"
+                    " Markdown, or explanations."
+                )
+                system_msg = next(
+                    (m for m in messages_list if m.get("role") == "system"), None
+                )
+                if system_msg:
+                    if isinstance(system_msg.get("content"), str):
+                        system_msg["content"] += " " + injection_text
+                    elif isinstance(system_msg.get("content"), list):
+                        system_msg["content"].append(
+                            {"type": "text", "text": injection_text}
+                        )
+                else:
+                    messages_list.insert(
+                        0, {"role": "system", "content": injection_text}
+                    )
+                body["messages"] = messages_list
+
+        return RequestData(url=url, headers=headers, body=body)
+
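Under the TOOL_CALL strategy above, the schema is not sent as response_format; instead a synthetic function tool is appended and then forced via tool_choice, so the structured answer comes back as tool-call arguments. A sketch of the injected pieces, with an assumed example schema:

sanitized_schema = {  # assumed example schema, illustration only
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
structured_tool = {
    "type": "function",
    "function": {
        "name": "return_structured_response",
        "description": "Return the final structured response.",
        "parameters": sanitized_schema,
        "strict": True,  # the code above disables strict mode for deepseek
    },
}
tool_choice = {"type": "function", "function": {"name": "return_structured_response"}}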
+    def parse_response(
+        self,
+        model: "LLMModel",
+        response_json: dict[str, Any],
+        is_advanced: bool = False,
+    ) -> ResponseData:
+        """Parse the response, delegating to the protocol strategy."""
+        _ = is_advanced
+        protocol_strategy = self._get_protocol_strategy(model)
+        response_data = protocol_strategy.parse_response(response_json)
+
+        if response_data.tool_calls:
+            target_tool = next(
+                (
+                    tc
+                    for tc in response_data.tool_calls
+                    if tc.function.name == "return_structured_response"
+                ),
+                None,
+            )
+            if target_tool:
+                response_data.text = json_repair.repair_json(
+                    target_tool.function.arguments
+                )
+                remaining = [
+                    tc
+                    for tc in response_data.tool_calls
+                    if tc.function.name != "return_structured_response"
+                ]
+                response_data.tool_calls = remaining or None
+
+        return response_data
+
+
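The parse path above pulls the forced tool call back out and runs its arguments through json_repair, which tolerates common model-output defects such as truncation or trailing commas. A minimal sketch of that repair step (the malformed payload is an assumed example):

import json

import json_repair

raw_args = '{"answer": "42",'  # assumed malformed arguments string
fixed = json_repair.repair_json(raw_args)  # returns a valid JSON string
print(json.loads(fixed))  # {'answer': '42'}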
+class DeepSeekAdapter(OpenAIAdapter):
+    """DeepSeek-specific adapter (based on the OpenAI protocol)."""
+
+    @property
+    def api_type(self) -> str:
+        return "deepseek"
+
+    @property
+    def supported_api_types(self) -> list[str]:
+        return ["deepseek"]
+
+
+class OpenAIImageAdapter(BaseAdapter):
+    """OpenAI image generation/editing adapter."""
+
+    @property
+    def api_type(self) -> str:
+        return "openai_image"
+
+    @property
+    def log_sanitization_context(self) -> str:
+        return "openai_request"
+
+    @property
+    def supported_api_types(self) -> list[str]:
+        return ["openai_image", "nano_banana"]
+
+    async def prepare_advanced_request(
+        self,
+        model: "LLMModel",
+        api_key: str,
+        messages: list["LLMMessage"],
+        config: "LLMGenerationConfig | None" = None,
+        tools: list[Any] | None = None,
+        tool_choice: "str | dict[str, Any] | ToolChoice | None" = None,
+    ) -> RequestData:
+        _ = tools, tool_choice
+        effective_config = config if config is not None else model._generation_config
+        headers = self.get_base_headers(api_key)
+
+        prompt = ""
+        images_bytes_list: list[bytes] = []
+
+        for msg in reversed(messages):
+            if msg.role != "user":
+                continue
+            if isinstance(msg.content, str):
+                prompt = msg.content
+            elif isinstance(msg.content, list):
+                for part in msg.content:
+                    if part.type == "text" and not prompt:
+                        prompt = part.text
+                    elif part.type == "image":
+                        if part.is_image_base64():
+                            if b64_data := part.get_base64_data():
+                                _, b64_str = b64_data
+                                images_bytes_list.append(base64.b64decode(b64_str))
+                        elif part.is_image_url() and part.image_source:
+                            images_bytes_list.append(
+                                await AsyncHttpx.get_content(part.image_source)
+                            )
+            if prompt:
+                break
+
+        if not prompt and not images_bytes_list:
+            raise LLMException(
+                "Image generation requires a prompt",
+                code=LLMErrorCode.CONFIGURATION_ERROR,
+            )
+
+        body: dict[str, Any] = {
+            "model": model.model_name,
+            "prompt": prompt,
+            "response_format": "b64_json",
+        }
+
+        if effective_config:
+            if effective_config.visual:
+                if effective_config.visual.aspect_ratio:
+                    ar = effective_config.visual.aspect_ratio
+                    size_map = {
+                        ImageAspectRatio.SQUARE: "1024x1024",
+                        ImageAspectRatio.LANDSCAPE_16_9: "1792x1024",
+                        ImageAspectRatio.PORTRAIT_9_16: "1024x1792",
+                    }
+                    if isinstance(ar, ImageAspectRatio) and ar in size_map:
+                        body["size"] = size_map[ar]
+                        body["aspect_ratio"] = ar.value
+                    elif isinstance(ar, str):
+                        if "x" in ar:
+                            body["size"] = ar
+                        else:
+                            body["aspect_ratio"] = ar
+
+                if effective_config.visual.resolution:
+                    res_val = effective_config.visual.resolution
+                    if not isinstance(res_val, str):
+                        res_val = getattr(res_val, "value", res_val)
+                    body["image_size"] = res_val
+
+            if effective_config.custom_params:
+                body.update(effective_config.custom_params)
+
+        if images_bytes_list:
+            b64_images = []
+            for img_bytes in images_bytes_list:
+                b64_str = base64.b64encode(img_bytes).decode("utf-8")
+                b64_images.append(b64_str)
+            body["image"] = b64_images
+
+        endpoint = "/v1/images/generations"
+        url = self.get_api_url(model, endpoint)
+        return RequestData(url=url, headers=headers, body=body)
+
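A sketch of the payload this method assembles for /v1/images/generations, assuming one reference image and a 16:9 aspect ratio; the model name and prompt are illustrative:

body = {
    "model": "gpt-image-1",  # assumed model name
    "prompt": "a cat in a spacesuit",
    "response_format": "b64_json",
    "size": "1792x1024",  # mapped from ImageAspectRatio.LANDSCAPE_16_9
    "aspect_ratio": "16:9",
    "image": ["<base64 of the reference image>"],
}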
+    def parse_response(
+        self,
+        model: "LLMModel",
+        response_json: dict[str, Any],
+        is_advanced: bool = False,
+    ) -> ResponseData:
+        _ = model, is_advanced
+        self.validate_response(response_json)
+
+        images_data: list[bytes | Path] = []
+        data_list = response_json.get("data", [])
+
+        for item in data_list:
+            if "b64_json" in item:
+                try:
+                    b64_str = item["b64_json"]
+                    if b64_str.startswith("data:"):
+                        b64_str = b64_str.split(",", 1)[1]
+                    img = base64.b64decode(b64_str)
+                    images_data.append(process_image_data(img))
+                except Exception as exc:
+                    logger.error(f"Base64 decoding failed: {exc}")
+            elif "url" in item:
+                logger.warning(
+                    f"API returned a URL instead of Base64: {item.get('url', 'unknown')}"
+                )
+
+        text_summary = (
+            f"Generated {len(images_data)} image(s)."
+            if images_data
+            else "Image generation call succeeded, but no image data was parsed."
+        )
+
+        return ResponseData(
+            text=text_summary,
+            images=images_data if images_data else None,
+            raw_response=response_json,
+        )
+
+    def prepare_embedding_request(
+        self,
+        model: "LLMModel",
+        api_key: str,
+        texts: list[str],
+        config: "LLMEmbeddingConfig",
+    ) -> RequestData:
+        raise NotImplementedError("OpenAIImageAdapter does not support embeddings")
+
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
+        raise NotImplementedError("OpenAIImageAdapter does not support embeddings")
+
+    def convert_generation_config(
+        self, config: "LLMGenerationConfig", model: "LLMModel"
+    ) -> dict[str, Any]:
+        _ = config, model
+        return {}
@@ -2,6 +2,7 @@
 High-level API for the LLM service - convenience function entry points (stateless)
 """
 
+from collections.abc import Awaitable, Callable
 from pathlib import Path
 from typing import Any, TypeVar, overload
 
@@ -11,19 +12,24 @@ from pydantic import BaseModel
 from zhenxun.services.log import logger
 
 from .config import CommonOverrides
-from .config.generation import LLMGenerationConfig, create_generation_config_from_kwargs
+from .config.generation import (
+    GenConfigBuilder,
+    LLMEmbeddingConfig,
+    LLMGenerationConfig,
+    OutputConfig,
+)
 from .manager import get_model_instance
 from .session import AI
-from .tools.manager import tool_provider_manager
 from .types import (
-    EmbeddingTaskType,
     LLMContentPart,
    LLMErrorCode,
     LLMException,
     LLMMessage,
     LLMResponse,
     ModelName,
+    ToolChoice,
 )
+from .types.exceptions import get_user_friendly_error_message
 from .utils import create_multimodal_message
 
 T = TypeVar("T", bound=BaseModel)
@@ -34,9 +40,10 @@ async def chat(
     *,
     model: ModelName = None,
     instruction: str | None = None,
-    tools: list[dict[str, Any] | str] | None = None,
-    tool_choice: str | dict[str, Any] | None = None,
-    **kwargs: Any,
+    tools: list[Any] | None = None,
+    tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+    config: LLMGenerationConfig | GenConfigBuilder | None = None,
+    timeout: float | None = None,
 ) -> LLMResponse:
     """
     Stateless chat convenience function; talks to the LLM model through a temporary AI session instance.
@@ -47,14 +54,13 @@ async def chat(
         instruction: System instruction guiding the AI's behavior and reply style.
         tools: Available tools, as dict configs or string identifiers.
         tool_choice: Tool selection strategy controlling how the AI picks and uses tools.
-        **kwargs: Extra generation parameters, converted into an LLMGenerationConfig.
+        config: (optional) Generation config object, merged with the defaults before use.
+        timeout: (optional) HTTP request timeout in seconds.
 
     Returns:
         LLMResponse: Full response object with the AI reply, usage info, and tool calls.
     """
     try:
-        config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
-
         ai_session = AI()
 
         return await ai_session.chat(
@@ -64,12 +70,14 @@ async def chat(
             tools=tools,
             tool_choice=tool_choice,
             config=config,
+            timeout=timeout,
         )
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"chat() failed: {e}", e=e)
-        raise LLMException(f"Chat execution failed: {e}", cause=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"chat() failed: {e} | suggestion: {friendly_msg}", e=e)
+        raise LLMException(f"Chat execution failed: {friendly_msg}", cause=e)
 
 
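A minimal usage sketch for the new chat() signature; the import paths and the provider/model identifier are assumptions for illustration:

from zhenxun.services.llm.api import chat  # assumed import path
from zhenxun.services.llm.config.generation import GenConfigBuilder  # assumed import path

async def demo_chat():
    config = GenConfigBuilder().config_core(temperature=0.3, max_tokens=512)
    response = await chat(
        "Summarize the latest commits",
        model="openai/gpt-4o",  # assumed ModelName value
        config=config,  # builders are accepted alongside LLMGenerationConfig
        timeout=30.0,
    )
    print(response.text)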
 async def code(
@@ -77,7 +85,6 @@ async def code(
     *,
     model: ModelName = None,
     timeout: int | None = None,
-    **kwargs: Any,
 ) -> LLMResponse:
     """
     Stateless code-execution convenience function; runs code in a sandboxed environment.
@@ -86,66 +93,25 @@ async def code(
         prompt: Prompt describing the code task to execute.
         model: Model name to use; defaults to Gemini/gemini-2.0-flash.
         timeout: Code execution timeout in seconds, preventing long-running code from blocking.
-        **kwargs: Extra generation config parameters.
 
     Returns:
         LLMResponse: Full response object with the code execution result.
     """
-    resolved_model = model or "Gemini/gemini-2.0-flash"
+    resolved_model = model
 
     config = CommonOverrides.gemini_code_execution()
     if timeout:
         config.custom_params = config.custom_params or {}
         config.custom_params["code_execution_timeout"] = timeout
 
-    final_config = config.to_dict()
-    final_config.update(kwargs)
-
-    return await chat(prompt, model=resolved_model, **final_config)
+    return await chat(prompt, model=resolved_model, config=config)
 
-
-async def search(
-    query: str | UniMessage | LLMMessage | list[LLMContentPart],
-    *,
-    model: ModelName = None,
-    instruction: str = (
-        "You are a powerful information retrieval and synthesis expert. Use the"
-        " available search tools to find the most relevant information for the"
-        " user's query, then summarize and answer."
-    ),
-    **kwargs: Any,
-) -> LLMResponse:
-    """
-    Stateless search convenience function that uses search tools to fetch real-time information.
-
-    Parameters:
-        query: Search query, in any supported input format.
-        model: Model name to use; the default model is used when None.
-        instruction: System instruction for the search task, guiding how the AI handles results.
-        **kwargs: Extra generation config parameters.
-
-    Returns:
-        LLMResponse: Full response object with search results and the AI's synthesized reply.
-    """
-    logger.debug("Running stateless 'search' task...")
-    search_config = CommonOverrides.gemini_grounding()
-
-    final_config = search_config.to_dict()
-    final_config.update(kwargs)
-
-    return await chat(
-        query,
-        model=model,
-        instruction=instruction,
-        **final_config,
-    )
-
-
 async def embed(
     texts: list[str] | str,
     *,
     model: ModelName = None,
-    task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
-    **kwargs: Any,
+    config: LLMEmbeddingConfig | None = None,
 ) -> list[list[float]]:
     """
     Stateless text-embedding convenience function that converts text into vector representations.
@@ -153,8 +119,7 @@ async def embed(
     Parameters:
         texts: Text to embed; a single string or a list of strings.
         model: Embedding model name; the default model is used when None.
-        task_type: Embedding task type, affecting vector optimization (retrieval, classification, etc.).
-        **kwargs: Extra model config parameters.
+        config: Embedding config object.
 
     Returns:
         list[list[float]]: Embedding vectors for the texts, one list of floats per input.
@@ -164,27 +129,71 @@ async def embed(
     if not texts:
         return []
 
+    final_config = config or LLMEmbeddingConfig()
+
     try:
         async with await get_model_instance(model) as model_instance:
-            return await model_instance.generate_embeddings(
-                texts, task_type=task_type, **kwargs
-            )
+            return await model_instance.generate_embeddings(texts, config=final_config)
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"Text embedding failed: {e}", e=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"Text embedding failed: {e} | suggestion: {friendly_msg}", e=e)
         raise LLMException(
-            f"Text embedding failed: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
+            f"Text embedding failed: {friendly_msg}",
+            code=LLMErrorCode.EMBEDDING_FAILED,
+            cause=e,
         )
+
+
+async def embed_query(
+    text: str,
+    *,
+    model: ModelName = None,
+    dimensions: int | None = None,
+) -> list[float]:
+    """
+    Semantic convenience API: generate an embedding for a retrieval query.
+    """
+    config = LLMEmbeddingConfig(
+        task_type="RETRIEVAL_QUERY",
+        output_dimensionality=dimensions,
+    )
+    vectors = await embed([text], model=model, config=config)
+    return vectors[0] if vectors else []
+
+
+async def embed_documents(
+    texts: list[str],
+    *,
+    model: ModelName = None,
+    dimensions: int | None = None,
+    title: str | None = None,
+) -> list[list[float]]:
+    """
+    Semantic convenience API: generate embeddings for a document collection.
+    """
+    config = LLMEmbeddingConfig(
+        task_type="RETRIEVAL_DOCUMENT",
+        output_dimensionality=dimensions,
+        title=title,
+    )
+    return await embed(texts, model=model, config=config)
+
+
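A usage sketch for the semantic embedding helpers; the texts and dimension value are illustrative assumptions:

async def demo_embeddings():
    doc_vectors = await embed_documents(
        ["Zhenxun is a QQ bot", "It is built on NoneBot2"],
        dimensions=768,
        title="project intro",
    )
    query_vector = await embed_query("Which framework does Zhenxun use?", dimensions=768)
    # One vector per document, one vector for the query.
    print(len(doc_vectors), len(query_vector))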
 async def generate_structured(
     message: str | LLMMessage | list[LLMContentPart],
     response_model: type[T],
     *,
     model: ModelName = None,
+    tools: list[Any] | None = None,
+    tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+    max_validation_retries: int | None = None,
+    validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
+    error_prompt_template: str | None = None,
+    auto_thinking: bool = False,
     instruction: str | None = None,
-    **kwargs: Any,
+    timeout: float | None = None,
 ) -> T:
     """
     Statelessly generate a structured response and parse it into the given Pydantic model.
@@ -192,39 +201,48 @@ async def generate_structured(
     Parameters:
         message: User input message, in any supported format.
         response_model: Pydantic model class used to parse and validate the response.
+        max_validation_retries: Maximum retries on validation failure; None uses the global config.
+        validation_callback: Custom validation callback; raising an exception counts as a failure.
+        error_prompt_template: Custom error-feedback prompt template.
+        auto_thinking: Whether to automatically apply chain-of-thought (CoT) wrapping, for models without native thinking support.
         model: Model name to use; the default model is used when None.
         instruction: System instruction guiding the AI to produce conforming structured output.
-        **kwargs: Extra generation config parameters.
+        timeout: HTTP request timeout in seconds.
 
     Returns:
         T: Parsed Pydantic model instance of the type given by response_model.
     """
     try:
-        config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
-
         ai_session = AI()
 
         return await ai_session.generate_structured(
             message,
             response_model,
             model=model,
+            tools=tools,
+            tool_choice=tool_choice,
+            max_validation_retries=max_validation_retries,
+            validation_callback=validation_callback,
+            error_prompt_template=error_prompt_template,
+            auto_thinking=auto_thinking,
             instruction=instruction,
-            config=config,
+            timeout=timeout,
         )
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"Structured response generation failed: {e}", e=e)
-        raise LLMException(f"Structured response generation failed: {e}", cause=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"Structured response generation failed: {e} | suggestion: {friendly_msg}", e=e)
+        raise LLMException(f"Structured response generation failed: {friendly_msg}", cause=e)
 
 
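A usage sketch for the extended generate_structured(); the response model and prompt are assumed examples:

from pydantic import BaseModel

class WeatherReport(BaseModel):  # assumed example model
    city: str
    temperature_c: float

async def demo_structured():
    report = await generate_structured(
        "What is the weather in Hangzhou? Answer with city and temperature.",
        WeatherReport,
        max_validation_retries=2,  # re-prompt up to twice on validation failure
    )
    print(report.city, report.temperature_c)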
 async def generate(
     messages: list[LLMMessage],
     *,
     model: ModelName = None,
-    tools: list[dict[str, Any] | str] | None = None,
-    tool_choice: str | dict[str, Any] | None = None,
-    **kwargs: Any,
+    tools: list[Any] | None = None,
+    tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+    config: LLMGenerationConfig | GenConfigBuilder | None = None,
 ) -> LLMResponse:
     """
     Generate a one-shot response from a full message list; a stateless low-level function.
@@ -234,109 +252,56 @@ async def generate(
         model: Model name to use; the default model is used when None.
         tools: Available tools, as dict configs or string identifiers.
         tool_choice: Tool selection strategy controlling how the AI picks and uses tools.
-        **kwargs: Extra generation config parameters that override the defaults.
+        config: (optional) Generation config object, merged with the defaults before use.
 
     Returns:
         LLMResponse: Full response object with the AI reply, usage info, and tool calls.
     """
     try:
+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
         async with await get_model_instance(
-            model, override_config=kwargs
+            model, override_config=None
         ) as model_instance:
             return await model_instance.generate_response(
                 messages,
-                tools=tools,  # type: ignore
+                config=config,
+                tools=tools,  # type: ignore[arg-type]
                 tool_choice=tool_choice,
             )
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"Response generation failed: {e}", e=e)
-        raise LLMException(f"Response generation failed: {e}", cause=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"Response generation failed: {e} | suggestion: {friendly_msg}", e=e)
+        raise LLMException(f"Response generation failed: {friendly_msg}", cause=e)
 
 
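A usage sketch for generate() with a builder-produced config; the LLMMessage constructor shape shown here is an assumption about its interface:

async def demo_generate():
    messages = [LLMMessage(role="user", content="Hello")]  # assumed constructor shape
    config = LLMGenerationConfig.builder().as_json().build()
    response = await generate(messages, config=config)
    print(response.text)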
-async def run_with_tools(
-    message: str | UniMessage | LLMMessage | list[LLMContentPart],
-    *,
-    model: ModelName = None,
-    instruction: str | None = None,
-    tools: list[str],
-    max_cycles: int = 5,
-    **kwargs: Any,
-) -> LLMResponse:
-    """
-    Statelessly run an LLM call loop with local Python function tools.
-
-    Parameters:
-        message: User input.
-        model: Model to use.
-        instruction: System instruction.
-        tools: Names of local function tools to use (must be registered via @function_tool).
-        max_cycles: Maximum number of tool-call cycles.
-        **kwargs: Extra generation config parameters.
-
-    Returns:
-        LLMResponse: Response object with the final reply.
-    """
-    from .executor import ExecutionConfig, LLMToolExecutor
-    from .utils import normalize_to_llm_messages
-
-    messages = await normalize_to_llm_messages(message, instruction)
-
-    async with await get_model_instance(
-        model, override_config=kwargs
-    ) as model_instance:
-        resolved_tools = await tool_provider_manager.get_function_tools(tools)
-        if not resolved_tools:
-            logger.warning(
-                "run_with_tools found no usable local function tools; running as a plain chat."
-            )
-            return await model_instance.generate_response(messages, tools=None)
-
-        executor = LLMToolExecutor(model_instance)
-        config = ExecutionConfig(max_cycles=max_cycles)
-        final_history = await executor.run(messages, resolved_tools, config)
-
-        for msg in reversed(final_history):
-            if msg.role == "assistant":
-                text = msg.content if isinstance(msg.content, str) else str(msg.content)
-                return LLMResponse(text=text, tool_calls=msg.tool_calls)
-
-        raise LLMException(
-            "The tool execution loop did not produce a valid assistant reply.", code=LLMErrorCode.GENERATION_FAILED
-        )
-
-
 async def _generate_image_from_message(
     message: UniMessage,
     model: ModelName = None,
-    **kwargs: Any,
+    config: LLMGenerationConfig | GenConfigBuilder | None = None,
 ) -> LLMResponse:
     """
     [internal] Core helper that generates images from a UniMessage.
     """
     from .utils import normalize_to_llm_messages
 
-    config = (
-        create_generation_config_from_kwargs(**kwargs)
-        if kwargs
-        else LLMGenerationConfig()
-    )
+    if isinstance(config, GenConfigBuilder):
+        config = config.build()
+
+    config = config or LLMGenerationConfig()
 
     config.validation_policy = {"require_image": True}
-    config.response_modalities = ["IMAGE", "TEXT"]
+    if config.output is None:
+        config.output = OutputConfig()
+    config.output.response_modalities = ["IMAGE", "TEXT"]
 
     try:
         messages = await normalize_to_llm_messages(message)
 
         async with await get_model_instance(model) as model_instance:
-            if not model_instance.can_generate_images():
-                raise LLMException(
-                    f"Model '{model_instance.provider_name}/{model_instance.model_name}'"
-                    f" does not support image generation",
-                    code=LLMErrorCode.CONFIGURATION_ERROR,
-                )
-
             response = await model_instance.generate_response(messages, config=config)
 
             if not response.images:
@@ -347,8 +312,9 @@ async def _generate_image_from_message(
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"Unknown error during image generation: {e}", e=e)
-        raise LLMException(f"Image generation failed: {e}", cause=e)
+        friendly_msg = get_user_friendly_error_message(e)
+        logger.error(f"Unknown error during image generation: {e} | suggestion: {friendly_msg}", e=e)
+        raise LLMException(f"Image generation failed: {friendly_msg}", cause=e)
 
 
 @overload
@@ -357,7 +323,6 @@ async def create_image(
     *,
     images: None = None,
     model: ModelName = None,
-    **kwargs: Any,
 ) -> LLMResponse:
     """Generate a new image from a text prompt."""
     ...
@@ -369,7 +334,6 @@ async def create_image(
     *,
     images: list[Path | bytes | str] | Path | bytes | str,
     model: ModelName = None,
-    **kwargs: Any,
 ) -> LLMResponse:
     """Edit or regenerate an image from the given image(s) and a text prompt."""
     ...
@@ -380,7 +344,7 @@ async def create_image(
     *,
     images: list[Path | bytes | str] | Path | bytes | str | None = None,
     model: ModelName = None,
-    **kwargs: Any,
+    config: LLMGenerationConfig | GenConfigBuilder | None = None,
 ) -> LLMResponse:
     """
     Smart image generation/editing function.
@@ -400,4 +364,42 @@ async def create_image(
 
     message = create_multimodal_message(text=text_prompt, images=image_list)
 
-    return await _generate_image_from_message(message, model=model, **kwargs)
+    return await _generate_image_from_message(message, model=model, config=config)
+
+
+async def search(
+    query: str | UniMessage | LLMMessage | list[LLMContentPart],
+    *,
+    model: ModelName = None,
+    instruction: str = (
+        "You are a powerful information retrieval and synthesis expert. Use the"
+        " available search tools to find the most relevant information for the"
+        " user's query, then summarize and answer."
+    ),
+    config: LLMGenerationConfig | GenConfigBuilder | None = None,
+) -> LLMResponse:
+    """
+    Stateless search convenience function that uses search tools to fetch real-time information.
+
+    Parameters:
+        query: Search query, in any supported input format.
+        model: Model name to use; the default model is used when None.
+        config: (optional) Generation config object, merged with the preset config before use.
+        instruction: System instruction for the search task, guiding how the AI handles results.
+
+    Returns:
+        LLMResponse: Full response object with search results and the AI's synthesized reply.
+    """
+    logger.debug("Running stateless 'search' task...")
+    search_config = CommonOverrides.gemini_grounding()
+
+    if isinstance(config, GenConfigBuilder):
+        config = config.build()
+
+    final_config = search_config.merge_with(config)
+
+    return await chat(
+        query,
+        model=model,
+        instruction=instruction,
+        config=final_config,
+    )
@@ -5,13 +5,12 @@ LLM configuration module
 """
 
 from .generation import (
+    CommonOverrides,
+    GenConfigBuilder,
+    LLMEmbeddingConfig,
     LLMGenerationConfig,
-    ModelConfigOverride,
-    apply_api_specific_mappings,
-    create_generation_config_from_kwargs,
     validate_override_params,
 )
-from .presets import CommonOverrides
 from .providers import (
     LLMConfig,
     get_gemini_safety_threshold,
@@ -23,11 +22,10 @@ from .providers import (
 
 __all__ = [
     "CommonOverrides",
+    "GenConfigBuilder",
     "LLMConfig",
+    "LLMEmbeddingConfig",
     "LLMGenerationConfig",
-    "ModelConfigOverride",
-    "apply_api_specific_mappings",
-    "create_generation_config_from_kwargs",
     "get_gemini_safety_threshold",
     "get_llm_config",
     "register_llm_configs",
@@ -3,209 +3,397 @@ LLM generation config classes and functions
 """
 
 from collections.abc import Callable
-from typing import Any
+from enum import Enum
+from typing import Any, Literal
+
+from typing_extensions import Self
 
 from pydantic import BaseModel, ConfigDict, Field
 
 from zhenxun.services.log import logger
-from zhenxun.utils.pydantic_compat import model_dump
+from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_validate
 
-from ..types import LLMResponse
-from ..types.enums import ResponseFormat
+from ..types import LLMResponse, ResponseFormat, StructuredOutputStrategy
 from ..types.exceptions import LLMErrorCode, LLMException
+from .providers import get_gemini_safety_threshold
 
 
-class ModelConfigOverride(BaseModel):
-    """Model config override parameters"""
+class ReasoningEffort(str, Enum):
+    """Reasoning effort enum."""
+
+    LOW = "LOW"
+    MEDIUM = "MEDIUM"
+    HIGH = "HIGH"
+
+
+class ImageAspectRatio(str, Enum):
+    """Image aspect ratio enum."""
+
+    SQUARE = "1:1"
+    LANDSCAPE_16_9 = "16:9"
+    PORTRAIT_9_16 = "9:16"
+    LANDSCAPE_4_3 = "4:3"
+    PORTRAIT_3_4 = "3:4"
+    LANDSCAPE_3_2 = "3:2"
+    PORTRAIT_2_3 = "2:3"
+
+
+class ImageResolution(str, Enum):
+    """Image resolution/quality enum."""
+
+    STANDARD = "STANDARD"
+    HD = "HD"
+
+
+class CoreConfig(BaseModel):
+    """Core generation parameters."""
 
     temperature: float | None = Field(
         default=None, ge=0.0, le=2.0, description="Sampling temperature"
     )
+    """Sampling temperature"""
     max_tokens: int | None = Field(default=None, gt=0, description="Maximum output tokens")
+    """Maximum output tokens"""
     top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling parameter")
+    """Nucleus sampling parameter"""
     top_k: int | None = Field(default=None, gt=0, description="Top-K sampling parameter")
+    """Top-K sampling parameter"""
     frequency_penalty: float | None = Field(
         default=None, ge=-2.0, le=2.0, description="Frequency penalty"
     )
+    """Frequency penalty"""
     presence_penalty: float | None = Field(
         default=None, ge=-2.0, le=2.0, description="Presence penalty"
     )
+    """Presence penalty"""
     repetition_penalty: float | None = Field(
         default=None, ge=0.0, le=2.0, description="Repetition penalty"
     )
+    """Repetition penalty"""
     stop: list[str] | str | None = Field(default=None, description="Stop sequences")
+    """Stop sequences"""
+
+
+class ReasoningConfig(BaseModel):
+    """Reasoning capability configuration."""
+
+    effort: ReasoningEffort | None = Field(
+        default=None, description="Reasoning effort (for O1, Gemini 3)"
+    )
+    """Reasoning effort (for O1, Gemini 3)"""
+    budget_tokens: int | None = Field(
+        default=None, description="Explicit thinking token budget (for Gemini 2.5)"
+    )
+    """Explicit thinking token budget (for Gemini 2.5)"""
+    show_thoughts: bool | None = Field(
+        default=None, description="Whether to explicitly include chain-of-thought content in the response"
+    )
+    """Whether to explicitly include chain-of-thought content in the response"""
+
+
+class VisualConfig(BaseModel):
+    """Visual generation configuration."""
+
+    aspect_ratio: ImageAspectRatio | str | None = Field(
+        default=None, description="Aspect ratio"
+    )
+    """Aspect ratio"""
+    resolution: ImageResolution | str | None = Field(
+        default=None, description="Generation quality/resolution"
+    )
+    """Generation quality/resolution"""
+    media_resolution: str | None = Field(
+        default=None,
+        description="Input media resolution (Gemini 3+): 'LOW', 'MEDIUM', 'HIGH'",
+    )
+    """Input media resolution (Gemini 3+): 'LOW', 'MEDIUM', 'HIGH'"""
+    style: str | None = Field(
+        default=None, description="Image style (e.g. DALL-E 3 vivid/natural)"
+    )
+    """Image style (e.g. DALL-E 3 vivid/natural)"""
+
+
+class OutputConfig(BaseModel):
+    """Output format control."""
 
     response_format: ResponseFormat | dict[str, Any] | None = Field(
         default=None, description="Desired response format"
     )
+    """Desired response format"""
     response_mime_type: str | None = Field(
         default=None, description="Response MIME type (Gemini only)"
     )
+    """Response MIME type (Gemini only)"""
     response_schema: dict[str, Any] | None = Field(
         default=None, description="JSON response schema"
     )
-    thinking_budget: float | None = Field(
-        default=None, ge=0.0, le=1.0, description="Thinking budget"
-    )
-    include_thoughts: bool | None = Field(
-        default=None, description="Whether to include the thinking process in the response (Gemini only)"
-    )
-    safety_settings: dict[str, str] | None = Field(default=None, description="Safety settings")
+    """JSON response schema"""
     response_modalities: list[str] | None = Field(
-        default=None, description="Response modalities"
+        default=None, description="Response modalities (TEXT, IMAGE, AUDIO)"
     )
+    """Response modalities (TEXT, IMAGE, AUDIO)"""
+    structured_output_strategy: StructuredOutputStrategy | str | None = Field(
+        default=None, description="Structured output strategy (NATIVE/TOOL_CALL/PROMPT)"
+    )
+    """Structured output strategy (NATIVE/TOOL_CALL/PROMPT)"""
 
-    enable_code_execution: bool | None = Field(
-        default=None, description="Whether to enable code execution"
+
+class SafetyConfig(BaseModel):
+    """Safety settings."""
+
+    safety_settings: dict[str, str] | None = Field(default=None, description="Safety settings")
+    """Safety settings"""
+
+
+class ToolConfig(BaseModel):
+    """Tool-call control configuration."""
+
+    mode: Literal["AUTO", "ANY", "NONE"] = Field(
+        default="AUTO",
+        description="Tool-call mode: AUTO (automatic), ANY (forced), NONE (disabled)",
     )
-    enable_grounding: bool | None = Field(
-        default=None, description="Whether to enable grounding"
+    """Tool-call mode: AUTO (automatic), ANY (forced), NONE (disabled)"""
+    allowed_function_names: list[str] | None = Field(
+        default=None,
+        description="Whitelist of function names allowed when mode is ANY",
     )
+    """Whitelist of function names allowed when mode is ANY"""
+
+
+class LLMGenerationConfig(BaseModel):
+    """
+    LLM generation config.
+    Component-based design; parameters are no longer flattened.
+    """
+
+    core: CoreConfig | None = Field(default=None, description="Core generation parameters")
+    """Core generation parameters"""
+    reasoning: ReasoningConfig | None = Field(default=None, description="Reasoning configuration")
+    """Reasoning configuration"""
+    visual: VisualConfig | None = Field(default=None, description="Visual generation configuration")
+    """Visual generation configuration"""
+    output: OutputConfig | None = Field(default=None, description="Output format configuration")
+    """Output format configuration"""
+    safety: SafetyConfig | None = Field(default=None, description="Safety configuration")
+    """Safety configuration"""
+    tool_config: ToolConfig | None = Field(default=None, description="Tool-call strategy configuration")
+    """Tool-call strategy configuration"""
 
     enable_caching: bool | None = Field(default=None, description="Whether to enable response caching")
+    """Whether to enable response caching"""
 
     custom_params: dict[str, Any] | None = Field(default=None, description="Custom parameters")
+    """Custom parameters"""
 
     validation_policy: dict[str, Any] | None = Field(
         default=None, description="Declarative response validation policy (e.g. {'require_image': True})"
     )
+    """Declarative response validation policy (e.g. {'require_image': True})"""
     response_validator: Callable[[LLMResponse], None] | None = Field(
-        default=None, description="Advanced callback to validate the response; should raise on validation failure"
+        default=None,
+        description="Advanced callback to validate the response; should raise on validation failure",
     )
+    """Advanced callback to validate the response; should raise on validation failure"""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    @classmethod
+    def builder(cls) -> "GenConfigBuilder":
+        """Create a new config builder."""
+        return GenConfigBuilder()
+
     def to_dict(self) -> dict[str, Any]:
-        """Convert to a dict, excluding None values"""
+        """
+        Convert to a dict, excluding None values.
+        Note: this returns a nested dict; adapters must handle the nesting.
+        """
+        return model_dump(self, exclude_none=True)
 
-        model_data = model_dump(self, exclude_none=True)
+    def merge_with(self, other: "LLMGenerationConfig | None") -> "LLMGenerationConfig":
+        """
+        Deep-merge with another config object.
+        Non-None fields in `other` override the corresponding fields here.
+        Returns a new config object; the original is left unchanged.
+        """
+        if not other:
+            return model_copy(self, deep=True)
 
-        result = {}
-        for key, value in model_data.items():
-            if key == "custom_params" and isinstance(value, dict):
-                result.update(value)
-            else:
-                result[key] = value
-
-        return result
+        new_config = model_copy(self, deep=True)
 
-    def merge_with_base_config(
+        def _merge_component(base_comp, override_comp, comp_cls):
+            if override_comp is None:
+                return base_comp
+            if base_comp is None:
+                return override_comp
+            updates = model_dump(override_comp, exclude_none=True)
+            return model_copy(base_comp, update=updates)
+
+        new_config.core = _merge_component(new_config.core, other.core, CoreConfig)
+        new_config.reasoning = _merge_component(
+            new_config.reasoning, other.reasoning, ReasoningConfig
+        )
+        new_config.visual = _merge_component(
+            new_config.visual, other.visual, VisualConfig
+        )
+        new_config.output = _merge_component(
+            new_config.output, other.output, OutputConfig
+        )
+        new_config.safety = _merge_component(
+            new_config.safety, other.safety, SafetyConfig
+        )
+        new_config.tool_config = _merge_component(
+            new_config.tool_config, other.tool_config, ToolConfig
+        )
+
+        if other.enable_caching is not None:
+            new_config.enable_caching = other.enable_caching
+
+        if other.custom_params:
+            if new_config.custom_params is None:
+                new_config.custom_params = {}
+            new_config.custom_params.update(other.custom_params)
+
+        if other.validation_policy:
+            if new_config.validation_policy is None:
+                new_config.validation_policy = {}
+            new_config.validation_policy.update(other.validation_policy)
+
+        if other.response_validator:
+            new_config.response_validator = other.response_validator
+
+        return new_config
+
+
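A sketch of the merge semantics defined above: non-None fields in the other config win, while component models are merged field by field rather than replaced wholesale:

base = LLMGenerationConfig(core=CoreConfig(temperature=0.7, max_tokens=1024))
override = LLMGenerationConfig(core=CoreConfig(temperature=0.2))
merged = base.merge_with(override)
assert merged.core.temperature == 0.2   # overridden by `other`
assert merged.core.max_tokens == 1024   # preserved from the base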
+class LLMEmbeddingConfig(BaseModel):
+    """Embedding-specific configuration."""
+
+    task_type: str | None = Field(default=None, description="Task type (Gemini/Jina)")
+    """Task type (Gemini/Jina)"""
+    output_dimensionality: int | None = Field(
+        default=None, description="Output/compressed dimensionality (Gemini/Jina/OpenAI)"
+    )
+    """Output/compressed dimensionality (Gemini/Jina/OpenAI)"""
+    title: str | None = Field(
+        default=None, description="Title, used only for Gemini RETRIEVAL_DOCUMENT tasks"
+    )
+    """Title, used only for Gemini RETRIEVAL_DOCUMENT tasks"""
+    encoding_format: str | None = Field(
+        default="float", description="Encoding format (float/base64)"
+    )
+    """Encoding format (float/base64)"""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class GenConfigBuilder:
+    """
+    Semantic builder for LLM generation configs.
+    Design principle: high-frequency business scenarios first; low-frequency parameters are namespaced.
+    """
+
+    def __init__(self):
+        self._config = LLMGenerationConfig()
+
+    def _ensure_core(self) -> CoreConfig:
+        if self._config.core is None:
+            self._config.core = CoreConfig()
+        return self._config.core
+
+    def _ensure_output(self) -> OutputConfig:
+        if self._config.output is None:
+            self._config.output = OutputConfig()
+        return self._config.output
+
+    def _ensure_reasoning(self) -> ReasoningConfig:
+        if self._config.reasoning is None:
+            self._config.reasoning = ReasoningConfig()
+        return self._config.reasoning
+
+    def as_json(self, schema: dict[str, Any] | None = None) -> Self:
+        """
+        [high-frequency] Force the model to output JSON.
+        """
+        out = self._ensure_output()
+        out.response_format = ResponseFormat.JSON
+        if schema:
+            out.response_schema = schema
+        return self
+
+    def enable_thinking(
+        self, budget_tokens: int = -1, show_thoughts: bool = False
+    ) -> Self:
+        """
+        [high-frequency] Enable the model's thinking/reasoning ability
+        (e.g. Gemini 2.0 Flash Thinking, DeepSeek R1).
+        """
+        reasoning = self._ensure_reasoning()
+        reasoning.budget_tokens = budget_tokens
+        reasoning.show_thoughts = show_thoughts
+        return self
+
+    def config_core(
         self,
-        base_temperature: float | None = None,
-        base_max_tokens: int | None = None,
-    ) -> dict[str, Any]:
-        """Merge with the base config; override parameters take precedence"""
-        merged = {}
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        top_p: float | None = None,
+        top_k: int | None = None,
+        stop: list[str] | str | None = None,
+        frequency_penalty: float | None = None,
+        presence_penalty: float | None = None,
+    ) -> Self:
+        """
+        [low-frequency] Configure core generation parameters.
+        """
+        core = self._ensure_core()
+        if temperature is not None:
+            core.temperature = temperature
+        if max_tokens is not None:
+            core.max_tokens = max_tokens
+        if top_p is not None:
+            core.top_p = top_p
+        if top_k is not None:
+            core.top_k = top_k
+        if stop is not None:
+            core.stop = stop
+        if frequency_penalty is not None:
+            core.frequency_penalty = frequency_penalty
+        if presence_penalty is not None:
+            core.presence_penalty = presence_penalty
+        return self
 
-        if base_temperature is not None:
-            merged["temperature"] = base_temperature
-        if base_max_tokens is not None:
-            merged["max_tokens"] = base_max_tokens
+    def config_safety(self, settings: dict[str, str]) -> Self:
+        """
+        [low-frequency] Configure safety filtering.
+        """
+        if self._config.safety is None:
+            self._config.safety = SafetyConfig()
+        self._config.safety.safety_settings = settings
+        return self
 
-        override_dict = self.to_dict()
-        merged.update(override_dict)
+    def config_visual(
+        self,
+        aspect_ratio: ImageAspectRatio | str | None = None,
+        resolution: ImageResolution | str | None = None,
+    ) -> Self:
+        """
+        [low-frequency] Configure visual generation parameters (DALL-E 3 / Gemini Imagen).
+        """
+        if self._config.visual is None:
+            self._config.visual = VisualConfig()
+        if aspect_ratio:
+            self._config.visual.aspect_ratio = aspect_ratio
+        if resolution:
+            self._config.visual.resolution = resolution
+        return self
 
-        return merged
+    def set_custom_param(self, key: str, value: Any) -> Self:
+        """Set a vendor-specific custom parameter."""
+        if self._config.custom_params is None:
+            self._config.custom_params = {}
+        self._config.custom_params[key] = value
+        return self
 
+    def build(self) -> LLMGenerationConfig:
+        """Build the final config object."""
+        return self._config
 
-class LLMGenerationConfig(ModelConfigOverride):
-    """LLM generation config, inheriting the model config override parameters"""
-
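A usage sketch for the builder: semantic calls chain, and build() returns the assembled LLMGenerationConfig; the schema and seed are assumed examples:

config = (
    GenConfigBuilder()
    .as_json({"type": "object", "properties": {"answer": {"type": "string"}}})
    .enable_thinking(budget_tokens=2048)
    .config_core(temperature=0.5)
    .set_custom_param("seed", 42)
    .build()
)
assert config.output.response_format == ResponseFormat.JSON
assert config.custom_params == {"seed": 42}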
def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]:
|
|
||||||
"""转换为API参数,支持不同API类型的参数名映射"""
|
|
||||||
_ = model_name
|
|
||||||
params = {}
|
|
||||||
|
|
||||||
if self.temperature is not None:
|
|
||||||
params["temperature"] = self.temperature
|
|
||||||
|
|
||||||
if self.max_tokens is not None:
|
|
||||||
if api_type == "gemini":
|
|
||||||
params["maxOutputTokens"] = self.max_tokens
|
|
||||||
else:
|
|
||||||
params["max_tokens"] = self.max_tokens
|
|
||||||
|
|
||||||
if api_type == "gemini":
|
|
||||||
if self.top_k is not None:
|
|
||||||
params["topK"] = self.top_k
|
|
||||||
if self.top_p is not None:
|
|
||||||
params["topP"] = self.top_p
|
|
||||||
else:
|
|
||||||
if self.top_k is not None:
|
|
||||||
params["top_k"] = self.top_k
|
|
||||||
if self.top_p is not None:
|
|
||||||
params["top_p"] = self.top_p
|
|
||||||
|
|
||||||
if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
|
|
||||||
if self.frequency_penalty is not None:
|
|
||||||
params["frequency_penalty"] = self.frequency_penalty
|
|
||||||
if self.presence_penalty is not None:
|
|
||||||
params["presence_penalty"] = self.presence_penalty
|
|
||||||
|
|
||||||
if self.repetition_penalty is not None:
|
|
||||||
if api_type == "openai":
|
|
||||||
logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略")
|
|
||||||
else:
|
|
||||||
params["repetition_penalty"] = self.repetition_penalty
|
|
||||||
|
|
||||||
if self.response_format is not None:
|
|
||||||
if isinstance(self.response_format, dict):
|
|
||||||
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
|
||||||
params["response_format"] = self.response_format
|
|
||||||
logger.debug(
|
|
||||||
f"为 {api_type} 使用自定义 response_format: "
|
|
||||||
f"{self.response_format}"
|
|
||||||
)
|
|
||||||
elif self.response_format == ResponseFormat.JSON:
|
|
||||||
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
|
|
||||||
params["response_format"] = {"type": "json_object"}
|
|
||||||
logger.debug(f"为 {api_type} 启用 JSON 对象输出模式")
|
|
||||||
elif api_type == "gemini":
|
|
||||||
params["responseMimeType"] = "application/json"
|
|
||||||
if self.response_schema:
|
|
||||||
params["responseSchema"] = self.response_schema
|
|
||||||
logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
|
|
||||||
|
|
||||||
if self.custom_params:
|
|
||||||
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
|
|
||||||
params.update(custom_mapped)
|
|
||||||
|
|
||||||
if api_type == "gemini":
|
|
||||||
if (
|
|
||||||
self.response_format != ResponseFormat.JSON
|
|
||||||
and self.response_mime_type is not None
|
|
||||||
):
|
|
||||||
params["responseMimeType"] = self.response_mime_type
|
|
||||||
logger.debug(
|
|
||||||
f"使用显式设置的 responseMimeType: {self.response_mime_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.response_schema is not None and "responseSchema" not in params:
|
|
||||||
params["responseSchema"] = self.response_schema
|
|
||||||
|
|
||||||
if self.thinking_budget is not None or self.include_thoughts is not None:
|
|
||||||
thinking_config = params.setdefault("thinkingConfig", {})
|
|
||||||
|
|
||||||
if self.thinking_budget is not None:
|
|
||||||
max_budget = 24576
|
|
||||||
budget_value = int(self.thinking_budget * max_budget)
|
|
||||||
thinking_config["thinkingBudget"] = budget_value
|
|
||||||
logger.debug(
|
|
||||||
f"已将 thinking_budget (float: {self.thinking_budget}) "
|
|
||||||
f"转换为 Gemini API 的整数格式: {budget_value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.include_thoughts is not None:
|
|
||||||
thinking_config["includeThoughts"] = self.include_thoughts
|
|
||||||
logger.debug(f"已设置 includeThoughts: {self.include_thoughts}")
|
|
||||||
|
|
||||||
if self.safety_settings is not None:
|
|
||||||
params["safetySettings"] = self.safety_settings
|
|
||||||
if self.response_modalities is not None:
|
|
||||||
params["responseModalities"] = self.response_modalities
|
|
||||||
|
|
||||||
logger.debug(f"为{api_type}转换配置参数: {len(params)}个参数")
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
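Note on the removed to_api_params: its main job was the per-vendor parameter-name split (snake_case for OpenAI-compatible backends, camelCase for Gemini). A minimal, self-contained sketch of that mapping idea, with assumed key names rather than the project's actual schema:

# Illustrative sketch only: mirrors the casing split the removed method handled.
def map_params(params: dict, api_type: str) -> dict:
    gemini_names = {"max_tokens": "maxOutputTokens", "top_k": "topK", "top_p": "topP"}
    if api_type != "gemini":
        return dict(params)
    # Rename keys for Gemini's camelCase convention; leave the rest untouched.
    return {gemini_names.get(k, k): v for k, v in params.items()}

assert map_params({"max_tokens": 256, "temperature": 0.5}, "gemini") == {
    "maxOutputTokens": 256,
    "temperature": 0.5,
}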
@@ -215,12 +403,12 @@ def validate_override_params
 def validate_override_params(
     if override_config is None:
         return LLMGenerationConfig()

+    if isinstance(override_config, LLMGenerationConfig):
+        return override_config
+
     if isinstance(override_config, dict):
         try:
-            filtered_config = {
-                k: v for k, v in override_config.items() if v is not None
-            }
-            return LLMGenerationConfig(**filtered_config)
+            return model_validate(LLMGenerationConfig, override_config)
         except Exception as e:
             logger.warning(f"覆盖配置参数验证失败: {e}")
             raise LLMException(
@@ -229,56 +417,107 @@ def validate_override_params
                 cause=e,
             )

-    return override_config
+    raise LLMException(
+        f"不支持的配置类型: {type(override_config)}",
+        code=LLMErrorCode.CONFIGURATION_ERROR,
+    )
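The rewritten validator delegates whole-dict validation to the pydantic layer instead of hand-filtering None values, and now rejects unknown types loudly. A self-contained sketch of the same normalization contract, assuming Pydantic v2 semantics and a stand-in model:

from pydantic import BaseModel, ValidationError

class Cfg(BaseModel):  # stand-in for LLMGenerationConfig
    temperature: float | None = None
    max_tokens: int | None = None

def normalize(override: "Cfg | dict | None") -> Cfg:
    if override is None:
        return Cfg()
    if isinstance(override, Cfg):
        return override  # already validated, pass through
    if isinstance(override, dict):
        try:
            return Cfg.model_validate(override)
        except ValidationError as e:
            raise ValueError(f"覆盖配置参数验证失败: {e}") from e
    raise TypeError(f"不支持的配置类型: {type(override)}")

assert normalize({"temperature": 0.3}).temperature == 0.3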
-def apply_api_specific_mappings(
-    params: dict[str, Any], api_type: str
-) -> dict[str, Any]:
-    """应用API特定的参数映射"""
-    mapped_params = params.copy()
-
-    if api_type == "gemini":
-        if "max_tokens" in mapped_params:
-            mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
-        if "top_k" in mapped_params:
-            mapped_params["topK"] = mapped_params.pop("top_k")
-        if "top_p" in mapped_params:
-            mapped_params["topP"] = mapped_params.pop("top_p")
-
-        unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"]
-        for param in unsupported:
-            if param in mapped_params:
-                logger.warning(f"Gemini 原生API不支持参数 '{param}',已忽略")
-                mapped_params.pop(param)
-
-    elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
-        if "repetition_penalty" in mapped_params and api_type == "openai":
-            logger.warning("OpenAI官方API不支持repetition_penalty参数,已忽略")
-            mapped_params.pop("repetition_penalty")
-
-        if "stop" in mapped_params:
-            stop_value = mapped_params["stop"]
-            if isinstance(stop_value, str):
-                mapped_params["stop"] = [stop_value]
-
-    return mapped_params
-
-
-def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
-    """从关键字参数创建生成配置"""
-    model_fields = getattr(LLMGenerationConfig, "model_fields", {})
-    known_fields = set(model_fields.keys())
-    known_params = {}
-    custom_params = {}
-
-    for key, value in kwargs.items():
-        if key in known_fields:
-            known_params[key] = value
-        else:
-            custom_params[key] = value
-
-    if custom_params:
-        known_params["custom_params"] = custom_params
-
-    return LLMGenerationConfig(**known_params)
+class CommonOverrides:
+    """常用的配置覆盖预设"""
+
+    @staticmethod
+    def gemini_json() -> LLMGenerationConfig:
+        """Gemini JSON模式:强制JSON输出"""
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            output=OutputConfig(
+                response_format=ResponseFormat.JSON,
+                response_mime_type="application/json",
+            ),
+        )
+
+    @staticmethod
+    def gemini_2_5_thinking(tokens: int = -1) -> LLMGenerationConfig:
+        """Gemini 2.5 思考模式:默认 -1 (动态思考),0 为禁用,>=1024 为固定预算"""
+        return LLMGenerationConfig(
+            core=CoreConfig(temperature=1.0),
+            reasoning=ReasoningConfig(budget_tokens=tokens, show_thoughts=True),
+        )
+
+    @staticmethod
+    def gemini_3_thinking(level: str = "HIGH") -> LLMGenerationConfig:
+        """Gemini 3 深度思考模式:使用思考等级"""
+        try:
+            effort = ReasoningEffort(level.upper())
+        except ValueError:
+            effort = ReasoningEffort.HIGH
+
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            reasoning=ReasoningConfig(effort=effort, show_thoughts=True),
+        )
+
+    @staticmethod
+    def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig:
+        """Gemini 结构化输出:自定义JSON模式"""
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            output=OutputConfig(
+                response_mime_type="application/json", response_schema=schema
+            ),
+        )
+
+    @staticmethod
+    def gemini_safe() -> LLMGenerationConfig:
+        """Gemini 安全模式:使用配置的安全设置"""
+        threshold = get_gemini_safety_threshold()
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            safety=SafetyConfig(
+                safety_settings={
+                    "HARM_CATEGORY_HARASSMENT": threshold,
+                    "HARM_CATEGORY_HATE_SPEECH": threshold,
+                    "HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
+                    "HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
+                }
+            ),
+        )
+
+    @staticmethod
+    def gemini_code_execution() -> LLMGenerationConfig:
+        """Gemini 代码执行模式:启用代码执行功能"""
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            custom_params={"code_execution_timeout": 30},
+        )
+
+    @staticmethod
+    def gemini_grounding() -> LLMGenerationConfig:
+        """Gemini 信息来源关联模式:启用Google搜索"""
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            custom_params={
+                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
+            },
+        )
+
+    @staticmethod
+    def gemini_nano_banana(aspect_ratio: str = "16:9") -> LLMGenerationConfig:
+        """Gemini Nano Banana Pro:自定义比例生图"""
+        try:
+            ar = ImageAspectRatio(aspect_ratio)
+        except ValueError:
+            ar = ImageAspectRatio.LANDSCAPE_16_9
+
+        return LLMGenerationConfig(
+            core=CoreConfig(),
+            visual=VisualConfig(aspect_ratio=ar),
+        )
+
+    @staticmethod
+    def gemini_high_res() -> LLMGenerationConfig:
+        """Gemini 3: 强制使用高解析度处理输入媒体"""
+        return LLMGenerationConfig(
+            visual=VisualConfig(media_resolution="HIGH", resolution=ImageResolution.HD)
+        )
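Typical call sites for the presets added above would look like the following; the import path is an assumption based on this diff, so the sketch is left commented:

# from zhenxun.services.llm.config import CommonOverrides
#
# config = CommonOverrides.gemini_json()                     # force JSON output
# config = CommonOverrides.gemini_2_5_thinking()             # dynamic thinking budget (-1)
# config = CommonOverrides.gemini_structured({"type": "object"})
# config = CommonOverrides.gemini_nano_banana("1:1")         # square image generation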
@@ -1,172 +0,0 @@
-"""
-LLM 预设配置
-
-提供常用的配置预设,特别是针对 Gemini 的高级功能。
-"""
-
-from typing import Any
-
-from .generation import LLMGenerationConfig
-
-
-class CommonOverrides:
-    """常用的配置覆盖预设"""
-
-    @staticmethod
-    def creative() -> LLMGenerationConfig:
-        """创意模式:高温度,鼓励创新"""
-        return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1)
-
-    @staticmethod
-    def precise() -> LLMGenerationConfig:
-        """精确模式:低温度,确定性输出"""
-        return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0)
-
-    @staticmethod
-    def balanced() -> LLMGenerationConfig:
-        """平衡模式:中等温度"""
-        return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0)
-
-    @staticmethod
-    def concise(max_tokens: int = 100) -> LLMGenerationConfig:
-        """简洁模式:限制输出长度"""
-        return LLMGenerationConfig(
-            temperature=0.3,
-            max_tokens=max_tokens,
-            stop=["\n\n", "。", "!", "?"],
-        )
-
-    @staticmethod
-    def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
-        """详细模式:鼓励详细输出"""
-        return LLMGenerationConfig(
-            temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
-        )
-
-    @staticmethod
-    def gemini_json() -> LLMGenerationConfig:
-        """Gemini JSON模式:强制JSON输出"""
-        return LLMGenerationConfig(
-            temperature=0.3, response_mime_type="application/json"
-        )
-
-    @staticmethod
-    def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
-        """Gemini 思考模式:使用思考预算"""
-        return LLMGenerationConfig(temperature=0.7, thinking_budget=budget)
-
-    @staticmethod
-    def gemini_creative() -> LLMGenerationConfig:
-        """Gemini 创意模式:高温度创意输出"""
-        return LLMGenerationConfig(temperature=0.9, top_p=0.95)
-
-    @staticmethod
-    def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig:
-        """Gemini 结构化输出:自定义JSON模式"""
-        return LLMGenerationConfig(
-            temperature=0.3,
-            response_mime_type="application/json",
-            response_schema=schema,
-        )
-
-    @staticmethod
-    def gemini_safe() -> LLMGenerationConfig:
-        """Gemini 安全模式:使用配置的安全设置"""
-        from .providers import get_gemini_safety_threshold
-
-        threshold = get_gemini_safety_threshold()
-        return LLMGenerationConfig(
-            temperature=0.5,
-            safety_settings={
-                "HARM_CATEGORY_HARASSMENT": threshold,
-                "HARM_CATEGORY_HATE_SPEECH": threshold,
-                "HARM_CATEGORY_SEXUALLY_EXPLICIT": threshold,
-                "HARM_CATEGORY_DANGEROUS_CONTENT": threshold,
-            },
-        )
-
-    @staticmethod
-    def gemini_multimodal() -> LLMGenerationConfig:
-        """Gemini 多模态模式:优化多模态处理"""
-        return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8)
-
-    @staticmethod
-    def gemini_code_execution() -> LLMGenerationConfig:
-        """Gemini 代码执行模式:启用代码执行功能"""
-        return LLMGenerationConfig(
-            temperature=0.3,
-            max_tokens=4096,
-            enable_code_execution=True,
-            custom_params={"code_execution_timeout": 30},
-        )
-
-    @staticmethod
-    def gemini_grounding() -> LLMGenerationConfig:
-        """Gemini 信息来源关联模式:启用Google搜索"""
-        return LLMGenerationConfig(
-            temperature=0.5,
-            max_tokens=4096,
-            enable_grounding=True,
-            custom_params={
-                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
-            },
-        )
-
-    @staticmethod
-    def gemini_cached() -> LLMGenerationConfig:
-        """Gemini 缓存模式:启用响应缓存"""
-        return LLMGenerationConfig(
-            temperature=0.3,
-            max_tokens=2048,
-            enable_caching=True,
-        )
-
-    @staticmethod
-    def gemini_advanced() -> LLMGenerationConfig:
-        """Gemini 高级模式:启用所有高级功能"""
-        return LLMGenerationConfig(
-            temperature=0.5,
-            max_tokens=4096,
-            enable_code_execution=True,
-            enable_grounding=True,
-            enable_caching=True,
-            custom_params={
-                "code_execution_timeout": 30,
-                "grounding_config": {
-                    "dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
-                },
-            },
-        )
-
-    @staticmethod
-    def gemini_research() -> LLMGenerationConfig:
-        """Gemini 研究模式:思考+搜索+结构化输出"""
-        return LLMGenerationConfig(
-            temperature=0.6,
-            max_tokens=4096,
-            thinking_budget=0.8,
-            enable_grounding=True,
-            response_mime_type="application/json",
-            custom_params={
-                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
-            },
-        )
-
-    @staticmethod
-    def gemini_analysis() -> LLMGenerationConfig:
-        """Gemini 分析模式:深度思考+详细输出"""
-        return LLMGenerationConfig(
-            temperature=0.4,
-            max_tokens=6000,
-            thinking_budget=0.9,
-            top_p=0.8,
-        )
-
-    @staticmethod
-    def gemini_fast_response() -> LLMGenerationConfig:
-        """Gemini 快速响应模式:低延迟+简洁输出"""
-        return LLMGenerationConfig(
-            temperature=0.3,
-            max_tokens=512,
-            top_p=0.8,
-        )
@@ -13,6 +13,7 @@ from zhenxun.configs.config import Config
 from zhenxun.configs.utils import parse_as
 from zhenxun.services.log import logger
 from zhenxun.utils.manager.priority_manager import PriorityLifecycle
+from zhenxun.utils.pydantic_compat import model_dump

 from ..core import key_store
 from ..tools import tool_provider_manager
@@ -22,6 +23,39 @@ AI_CONFIG_GROUP = "AI"
 PROVIDERS_CONFIG_KEY = "PROVIDERS"


+class DebugLogOptions(BaseModel):
+    """调试日志细粒度控制"""
+
+    show_tools: bool = Field(
+        default=True, description="是否在日志中显示工具定义(JSON Schema)"
+    )
+    show_schema: bool = Field(
+        default=True, description="是否在日志中显示结构化输出Schema(response_format)"
+    )
+    show_safety: bool = Field(
+        default=True, description="是否在日志中显示安全设置(safetySettings)"
+    )
+
+    def __bool__(self) -> bool:
+        """支持 bool(debug_options) 的语法,方便兼容旧逻辑。"""
+        return self.show_tools or self.show_schema or self.show_safety
+
+
+class ClientSettings(BaseModel):
+    """LLM 客户端通用设置"""
+
+    timeout: int = Field(default=300, description="API请求超时时间(秒)")
+    max_retries: int = Field(default=3, description="请求失败时的最大重试次数")
+    retry_delay: int = Field(default=2, description="请求重试的基础延迟时间(秒)")
+    structured_retries: int = Field(
+        default=2, description="结构化生成校验失败时的最大重试次数 (IVR)"
+    )
+    proxy: str | None = Field(
+        default=None,
+        description="网络代理,例如 http://127.0.0.1:7890",
+    )
+
+
 class LLMConfig(BaseModel):
     """LLM 服务配置类"""
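The __bool__ override is what keeps existing `if debug_log:` call sites working after the bool-to-model migration. A self-contained check of that contract, assuming Pydantic v2:

from pydantic import BaseModel, Field

class DebugLogOptions(BaseModel):
    show_tools: bool = Field(default=True)
    show_schema: bool = Field(default=True)
    show_safety: bool = Field(default=True)

    def __bool__(self) -> bool:
        # Truthy as long as at least one detail switch is still on.
        return self.show_tools or self.show_schema or self.show_safety

assert bool(DebugLogOptions())  # defaults: everything on
assert not DebugLogOptions(show_tools=False, show_schema=False, show_safety=False)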
@@ -29,20 +63,16 @@ class LLMConfig(BaseModel):
         default=None,
         description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
     )
-    proxy: str | None = Field(
-        default=None,
-        description="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890",
-    )
-    timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)")
-    max_retries_llm: int = Field(
-        default=3, description="LLM服务请求失败时的最大重试次数"
-    )
-    retry_delay_llm: int = Field(
-        default=2, description="LLM服务请求重试的基础延迟时间(秒)"
+    client_settings: ClientSettings = Field(
+        default_factory=ClientSettings, description="客户端连接与重试配置"
     )
     providers: list[ProviderConfig] = Field(
         default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
     )
+    debug_log: DebugLogOptions | bool = Field(
+        default_factory=DebugLogOptions,
+        description="LLM请求日志详情开关。支持 bool (全开/全关) 或 dict (细粒度控制)。",
+    )

     def get_provider_by_name(self, name: str) -> ProviderConfig | None:
         """根据名称获取提供商配置
@@ -226,36 +256,29 @@ def register_llm_configs():
     )
     Config.add_plugin_config(
         AI_CONFIG_GROUP,
-        "proxy",
-        llm_config.proxy,
-        help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890",
-        type=str,
+        "client_settings",
+        model_dump(llm_config.client_settings),
+        help=(
+            "LLM客户端高级设置。\n"
+            "包含: timeout(超时秒数), max_retries(重试次数), "
+            "retry_delay(重试延迟), structured_retries(结构化生成重试), proxy(代理)"
+        ),
+        type=dict,
     )
     Config.add_plugin_config(
         AI_CONFIG_GROUP,
-        "timeout",
-        llm_config.timeout,
-        help="LLM服务API请求超时时间(秒)",
-        type=int,
-    )
-    Config.add_plugin_config(
-        AI_CONFIG_GROUP,
-        "max_retries_llm",
-        llm_config.max_retries_llm,
-        help="LLM服务请求失败时的最大重试次数",
-        type=int,
-    )
-    Config.add_plugin_config(
-        AI_CONFIG_GROUP,
-        "retry_delay_llm",
-        llm_config.retry_delay_llm,
-        help="LLM服务请求重试的基础延迟时间(秒)",
-        type=int,
+        "debug_log",
+        {"show_tools": True, "show_schema": True, "show_safety": True},
+        help=(
+            "LLM日志详情开关。示例: {'show_tools': True, 'show_schema': False, "
+            "'show_safety': False}"
+        ),
+        type=dict,
     )
     Config.add_plugin_config(
         AI_CONFIG_GROUP,
         "gemini_safety_threshold",
-        "BLOCK_MEDIUM_AND_ABOVE",
+        "BLOCK_NONE",
         help=(
             "Gemini 安全过滤阈值 "
             "(BLOCK_LOW_AND_ABOVE: 阻止低级别及以上, "
@@ -270,7 +293,20 @@ def register_llm_configs():
         AI_CONFIG_GROUP,
         PROVIDERS_CONFIG_KEY,
         get_default_providers(),
-        help="配置多个 AI 服务提供商及其模型信息",
+        help=(
+            "配置多个 AI 服务提供商及其模型信息。\n"
+            "注意:可以在特定模型配置下添加 'api_type' 以覆盖提供商的全局设置。\n"
+            "支持的 api_type 包括:\n"
+            "- 'openai': 标准 OpenAI 格式 (DeepSeek, SiliconFlow, Moonshot 等)\n"
+            "- 'gemini': Google Gemini API\n"
+            "- 'zhipu': 智谱 AI (GLM)\n"
+            "- 'ark': 字节跳动火山引擎 (Doubao)\n"
+            "- 'openrouter': OpenRouter 聚合平台\n"
+            "- 'openai_image': OpenAI 兼容的图像生成接口 (DALL-E)\n"
+            "- 'openai_responses': 支持新版 responses 格式的 OpenAI 兼容接口\n"
+            "- 'smart': 智能路由模式 (主要用于第三方中转场景,自动根据模型名"
+            "分发请求到 openai 或 gemini)"
+        ),
         default_value=[],
         type=list[ProviderConfig],
     )
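For reference, the dict shapes these two new config keys accept look roughly like this; values are illustrative defaults, not prescriptions:

AI_EXAMPLE = {
    "client_settings": {
        "timeout": 300,
        "max_retries": 3,
        "retry_delay": 2,
        "structured_retries": 2,
        "proxy": None,  # e.g. "http://127.0.0.1:7890"
    },
    "debug_log": {"show_tools": True, "show_schema": False, "show_safety": False},
}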
@@ -278,15 +314,21 @@ def register_llm_configs():

 @lru_cache(maxsize=1)
 def get_llm_config() -> LLMConfig:
-    """获取 LLM 配置实例,不再加载 MCP 工具配置"""
+    """获取 LLM 配置实例"""
     ai_config = get_ai_config()

+    raw_debug = ai_config.get("debug_log", False)
+    if isinstance(raw_debug, bool):
+        debug_log_val = DebugLogOptions(
+            show_tools=raw_debug, show_schema=raw_debug, show_safety=raw_debug
+        )
+    else:
+        debug_log_val = raw_debug
+
     config_data = {
         "default_model_name": ai_config.get("default_model_name"),
-        "proxy": ai_config.get("proxy"),
-        "timeout": ai_config.get("timeout", 180),
-        "max_retries_llm": ai_config.get("max_retries_llm", 3),
-        "retry_delay_llm": ai_config.get("retry_delay_llm", 2),
+        "client_settings": ai_config.get("client_settings", {}),
+        "debug_log": debug_log_val,
         PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
     }

@@ -314,14 +356,14 @@ def validate_llm_config() -> tuple[bool, list[str]]:
     try:
         llm_config = get_llm_config()

-        if llm_config.timeout <= 0:
+        if llm_config.client_settings.timeout <= 0:
             errors.append("timeout 必须大于 0")

-        if llm_config.max_retries_llm < 0:
-            errors.append("max_retries_llm 不能小于 0")
+        if llm_config.client_settings.max_retries < 0:
+            errors.append("max_retries 不能小于 0")

-        if llm_config.retry_delay_llm <= 0:
-            errors.append("retry_delay_llm 必须大于 0")
+        if llm_config.client_settings.retry_delay <= 0:
+            errors.append("retry_delay 必须大于 0")

         if not llm_config.providers:
             errors.append("至少需要配置一个 AI 服务提供商")
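A minimal sketch of the bool-or-dict coercion get_llm_config now performs on the raw debug_log value, so a legacy boolean fans out to all three switches:

def coerce_debug_log(raw):
    # raw comes from the "debug_log" config key: bool (legacy) or dict (new)
    if isinstance(raw, bool):
        return {"show_tools": raw, "show_schema": raw, "show_safety": raw}
    return raw  # assumed to already be a dict of fine-grained switches

assert coerce_debug_log(True) == {
    "show_tools": True, "show_schema": True, "show_safety": True
}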
@@ -254,7 +254,7 @@ class KeyStats:
         if total_calls == 0:
             return KeyStatus.UNUSED

-        if self.success_rate < 80:
+        if self.success_rate < 70:
             return KeyStatus.ERROR

         if total_calls >= 5 and self.avg_latency > 15000:
@@ -292,96 +292,6 @@ class RetryConfig:
         self.key_rotation = key_rotation


-async def with_smart_retry(
-    func,
-    *args,
-    retry_config: RetryConfig | None = None,
-    key_store: "KeyStatusStore | None" = None,
-    provider_name: str | None = None,
-    **kwargs: Any,
-) -> Any:
-    """
-    智能重试装饰器 - 支持Key轮询和错误分类
-
-    参数:
-        func: 要重试的异步函数。
-        *args: 传递给函数的位置参数。
-        retry_config: 重试配置。
-        key_store: API密钥状态存储。
-        provider_name: 提供商名称。
-        **kwargs: 传递给函数的关键字参数。
-
-    返回:
-        Any: 函数执行结果。
-    """
-    config = retry_config or RetryConfig()
-    last_exception: Exception | None = None
-    failed_keys: set[str] = set()
-
-    model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None)
-    all_provider_keys = model_instance.api_keys if model_instance else []
-
-    for attempt in range(config.max_retries + 1):
-        try:
-            if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
-                kwargs["failed_keys"] = failed_keys
-
-            start_time = time.monotonic()
-            result = await func(*args, **kwargs)
-            latency = (time.monotonic() - start_time) * 1000
-
-            if key_store and isinstance(result, tuple) and len(result) == 2:
-                _, api_key_used = result
-                if api_key_used:
-                    await key_store.record_success(api_key_used, latency)
-                return result
-            else:
-                return result
-
-        except LLMException as e:
-            last_exception = e
-            api_key_in_use = e.details.get("api_key")
-
-            if api_key_in_use:
-                failed_keys.add(api_key_in_use)
-                if key_store and provider_name and len(all_provider_keys) > 1:
-                    status_code = e.details.get("status_code")
-                    error_message = f"({e.code.name}) {e.message}"
-                    await key_store.record_failure(
-                        api_key_in_use, status_code, error_message
-                    )
-
-            should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
-            if not should_retry:
-                logger.error(f"不可重试的错误,停止重试: {e}")
-                raise
-
-            if attempt < config.max_retries:
-                wait_time = config.retry_delay
-                if config.exponential_backoff:
-                    wait_time *= 2**attempt
-                logger.warning(
-                    f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}"
-                )
-                await asyncio.sleep(wait_time)
-            else:
-                logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
-
-        except Exception as e:
-            last_exception = e
-            logger.error(f"非LLM异常,停止重试: {e}")
-            raise LLMException(
-                f"操作失败: {e}",
-                code=LLMErrorCode.GENERATION_FAILED,
-                cause=e,
-            )
-
-    if last_exception:
-        raise last_exception
-    else:
-        raise RuntimeError("重试函数未能正常执行且未捕获到异常")
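For reference, the removed with_smart_retry doubled a base delay per attempt when exponential_backoff was set. A stdlib sketch of that schedule:

def backoff_schedule(retry_delay: float, attempts: int, exponential: bool = True):
    # Delay before each retry: retry_delay * 2**attempt when exponential.
    return [retry_delay * (2**a if exponential else 1) for a in range(attempts)]

assert backoff_schedule(2, 4) == [2, 4, 8, 16]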
 def _should_retry_llm_error(
     error: LLMException, attempt: int, max_retries: int
 ) -> bool:
@@ -390,7 +300,9 @@ def _should_retry_llm_error(
         LLMErrorCode.MODEL_NOT_FOUND,
         LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
         LLMErrorCode.USER_LOCATION_NOT_SUPPORTED,
+        LLMErrorCode.INVALID_PARAMETER,
         LLMErrorCode.CONFIGURATION_ERROR,
+        LLMErrorCode.API_KEY_INVALID,
     }

     if error.code in non_retryable_errors:
@@ -404,15 +316,12 @@ def _should_retry_llm_error(
         LLMErrorCode.RESPONSE_PARSE_ERROR,
         LLMErrorCode.GENERATION_FAILED,
         LLMErrorCode.CONTENT_FILTERED,
-        LLMErrorCode.API_KEY_INVALID,
         LLMErrorCode.API_QUOTA_EXCEEDED,
     }

     if error.code in retryable_errors:
         if error.code == LLMErrorCode.API_QUOTA_EXCEEDED:
             return attempt < min(2, max_retries)
-        elif error.code == LLMErrorCode.CONTENT_FILTERED:
-            return attempt < min(1, max_retries)
         return True

     return False
@@ -558,14 +467,68 @@ class KeyStatusStore:
         now = time.time()
         cooldown_duration = 300

-        if status_code in [401, 403, 404]:
+        location_not_supported = error_message and (
+            "USER_LOCATION_NOT_SUPPORTED" in error_message
+            or "User location is not supported" in error_message
+        )
+        if location_not_supported:
+            logger.warning(
+                f"API Key {key_id} 请求失败,原因是地区不支持 (Gemini)。"
+                " 这通常是代理节点问题,Key 本身可能是正常的。跳过冷却。"
+            )
+            async with self._lock:
+                stats = self._key_stats.setdefault(api_key, KeyStats())
+                stats.failure_count += 1
+                stats.last_error_info = error_message[:256]
+                await self._save_to_file_internal()
+            return
+
+        if error_message and (
+            "API_QUOTA_EXCEEDED" in error_message
+            or "insufficient_quota" in error_message.lower()
+        ):
+            cooldown_duration = 3600
+            logger.warning(f"API Key {key_id} 额度耗尽,冷却 1 小时。")
+
+        is_key_invalid = status_code == 401 or (
+            status_code == 400
+            and error_message
+            and (
+                "API_KEY_INVALID" in error_message
+                or "API key not valid" in error_message
+            )
+        )
+
+        if is_key_invalid:
             cooldown_duration = 31536000
             log_level = "error"
             log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
+        elif status_code == 403:
+            cooldown_duration = 3600
+            log_level = "warning"
+            log_message = f"API密钥权限不足或地区不支持(403),冷却1小时: {key_id}"
+        elif status_code == 404:
+            log_level = "error"
+            log_message = (
+                "API请求返回 404 (未找到),可能是模型名称错误或接口地址"
+                f"错误,不冷却密钥: {key_id}"
+            )
+        elif status_code == 422:
+            cooldown_duration = 0
+            log_level = "warning"
+            log_message = f"API请求无法处理(422),可能是生成故障,不冷却密钥: {key_id}"
         elif status_code == 429:
             cooldown_duration = 60
             log_level = "warning"
             log_message = f"API密钥被限流,冷却60秒: {key_id}"
+        elif error_message and (
+            "ConnectError" in error_message
+            or "NetworkError" in error_message
+            or "Connection refused" in error_message
+            or "RemoteProtocolError" in error_message
+            or "ProxyError" in error_message
+        ):
+            cooldown_duration = 0
+            log_level = "warning"
+            log_message = f"网络连接层异常(代理/DNS),不冷却密钥: {key_id}"
         else:
             log_level = "warning"
             log_message = f"API密钥遇到临时性错误,冷却{cooldown_duration}秒: {key_id}"
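Condensed sketch of the new cooldown policy above, in seconds; the real code also keys off error_message substrings, and some branches log without cooling the key at all:

def cooldown_for(status_code: "int | None") -> int:
    if status_code == 401:
        return 31536000  # invalid key: effectively a one-year lockout
    if status_code == 403:
        return 3600      # permission/region problem: back off one hour
    if status_code == 429:
        return 60        # rate limited: one minute
    if status_code == 422:
        return 0         # generation fault: don't punish the key
    return 300           # generic transient error

assert cooldown_for(429) == 60 and cooldown_for(None) == 300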
@@ -1,193 +0,0 @@
-"""
-LLM 轻量级工具执行器
-
-提供驱动 LLM 与本地函数工具之间交互的核心循环。
-"""
-
-import asyncio
-from enum import Enum
-import json
-from typing import Any
-
-from pydantic import BaseModel, Field
-
-from zhenxun.services.log import logger
-from zhenxun.utils.decorator.retry import Retry
-from zhenxun.utils.pydantic_compat import model_dump
-
-from .service import LLMModel
-from .types import (
-    LLMErrorCode,
-    LLMException,
-    LLMMessage,
-    ToolExecutable,
-    ToolResult,
-)
-
-
-class ExecutionConfig(BaseModel):
-    """
-    轻量级执行器的配置。
-    """
-
-    max_cycles: int = Field(default=5, description="工具调用循环的最大次数。")
-
-
-class ToolErrorType(str, Enum):
-    """结构化工具错误的类型枚举。"""
-
-    TOOL_NOT_FOUND = "ToolNotFound"
-    INVALID_ARGUMENTS = "InvalidArguments"
-    EXECUTION_ERROR = "ExecutionError"
-    USER_CANCELLATION = "UserCancellation"
-
-
-class ToolErrorResult(BaseModel):
-    """一个结构化的工具执行错误模型,用于返回给 LLM。"""
-
-    error_type: ToolErrorType = Field(..., description="错误的类型。")
-    message: str = Field(..., description="对错误的详细描述。")
-    is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。")
-
-    def model_dump(self, **kwargs):
-        return model_dump(self, **kwargs)
-
-
-def _is_exception_retryable(e: Exception) -> bool:
-    """判断一个异常是否应该触发重试。"""
-    if isinstance(e, LLMException):
-        retryable_codes = {
-            LLMErrorCode.API_REQUEST_FAILED,
-            LLMErrorCode.API_TIMEOUT,
-            LLMErrorCode.API_RATE_LIMITED,
-        }
-        return e.code in retryable_codes
-    return True
-
-
-class LLMToolExecutor:
-    """
-    一个通用的执行器,负责驱动 LLM 与工具之间的多轮交互。
-    """
-
-    def __init__(self, model: LLMModel):
-        self.model = model
-
-    async def run(
-        self,
-        messages: list[LLMMessage],
-        tools: dict[str, ToolExecutable],
-        config: ExecutionConfig | None = None,
-    ) -> list[LLMMessage]:
-        """
-        执行完整的思考-行动循环。
-        """
-        effective_config = config or ExecutionConfig()
-        execution_history = list(messages)
-
-        for i in range(effective_config.max_cycles):
-            response = await self.model.generate_response(
-                execution_history, tools=tools
-            )
-
-            assistant_message = LLMMessage(
-                role="assistant",
-                content=response.text,
-                tool_calls=response.tool_calls,
-            )
-            execution_history.append(assistant_message)
-
-            if not response.tool_calls:
-                logger.info("✅ LLMToolExecutor:模型未请求工具调用,执行结束。")
-                return execution_history
-
-            logger.info(
-                f"🛠️ LLMToolExecutor:模型请求并行调用 {len(response.tool_calls)} 个工具"
-            )
-            tool_results = await self._execute_tools_parallel_safely(
-                response.tool_calls,
-                tools,
-            )
-            execution_history.extend(tool_results)
-
-        raise LLMException(
-            f"超过最大工具调用循环次数 ({effective_config.max_cycles})。",
-            code=LLMErrorCode.GENERATION_FAILED,
-        )
-
-    async def _execute_single_tool_safely(
-        self, tool_call: Any, available_tools: dict[str, ToolExecutable]
-    ) -> tuple[Any, ToolResult]:
-        """安全地执行单个工具调用。"""
-        tool_name = tool_call.function.name
-        arguments = {}
-
-        try:
-            if tool_call.function.arguments:
-                arguments = json.loads(tool_call.function.arguments)
-        except json.JSONDecodeError as e:
-            error_result = ToolErrorResult(
-                error_type=ToolErrorType.INVALID_ARGUMENTS,
-                message=f"参数解析失败: {e}",
-                is_retryable=False,
-            )
-            return tool_call, ToolResult(output=model_dump(error_result))
-
-        try:
-            executable = available_tools.get(tool_name)
-            if not executable:
-                raise LLMException(
-                    f"Tool '{tool_name}' not found.",
-                    code=LLMErrorCode.CONFIGURATION_ERROR,
-                )
-
-            @Retry.simple(
-                stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None
-            )
-            async def execute_with_retry():
-                return await executable.execute(**arguments)
-
-            execution_result = await execute_with_retry()
-            if execution_result is None:
-                raise LLMException("工具执行在多次重试后仍然失败。")
-
-            return tool_call, execution_result
-        except Exception as e:
-            error_type = ToolErrorType.EXECUTION_ERROR
-            is_retryable = _is_exception_retryable(e)
-            if (
-                isinstance(e, LLMException)
-                and e.code == LLMErrorCode.CONFIGURATION_ERROR
-            ):
-                error_type = ToolErrorType.TOOL_NOT_FOUND
-                is_retryable = False
-
-            error_result = ToolErrorResult(
-                error_type=error_type, message=str(e), is_retryable=is_retryable
-            )
-            return tool_call, ToolResult(output=model_dump(error_result))
-
-    async def _execute_tools_parallel_safely(
-        self,
-        tool_calls: list[Any],
-        available_tools: dict[str, ToolExecutable],
-    ) -> list[LLMMessage]:
-        """并行执行所有工具调用,并对每个调用的错误进行隔离。"""
-        if not tool_calls:
-            return []
-
-        tasks = [
-            self._execute_single_tool_safely(call, available_tools)
-            for call in tool_calls
-        ]
-        results = await asyncio.gather(*tasks)
-
-        tool_messages = [
-            LLMMessage.tool_response(
-                tool_call_id=original_call.id,
-                function_name=original_call.function.name,
-                result=result.output,
-            )
-            for original_call, result in results
-        ]
-        return tool_messages
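The deleted executor's core was a bounded think-act loop: call the model, run any requested tools, append the results, repeat. An abstract stdlib sketch with caller-supplied stand-ins for the model and tool runner:

async def tool_loop(generate, execute_tool, history: list, max_cycles: int = 5):
    # generate(history) -> a response object; execute_tool(call) -> a tool message.
    # Both are stand-ins; the real executor used LLMModel and ToolExecutable.
    for _ in range(max_cycles):
        response = await generate(history)        # model turn
        history.append(response)
        calls = getattr(response, "tool_calls", None)
        if not calls:
            return history                        # model answered directly
        for call in calls:                        # tool turn(s)
            history.append(await execute_tool(call))
    raise RuntimeError(f"exceeded {max_cycles} tool cycles")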
@@ -13,15 +13,19 @@ from zhenxun.services.log import logger
 from zhenxun.utils.pydantic_compat import dump_json_safely

 from .config import validate_override_params
-from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
+from .config.generation import LLMGenerationConfig
+from .config.providers import (
+    AI_CONFIG_GROUP,
+    PROVIDERS_CONFIG_KEY,
+    get_ai_config,
+    get_llm_config,
+)
 from .core import http_client_manager, key_store
 from .service import LLMModel
 from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
 from .types.capabilities import get_model_capabilities

 DEFAULT_MODEL_NAME_KEY = "default_model_name"
-PROXY_KEY = "proxy"
-TIMEOUT_KEY = "timeout"

 _model_cache: dict[str, tuple[LLMModel, float]] = {}
 _cache_ttl = 3600
@@ -39,7 +43,8 @@ def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:

 def _make_cache_key(
-    provider_model_name: str | None, override_config: dict | None
+    provider_model_name: str | None,
+    override_config: dict | LLMGenerationConfig | None,
 ) -> str:
     """生成缓存键"""
     config_str = (
@@ -115,11 +120,12 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
     """根据API类型获取默认的API基础地址"""
     default_api_bases = {
         "openai": "https://api.openai.com",
-        "deepseek": "https://api.deepseek.com",
+        "deepseek": "https://api.deepseek.com/beta",
         "zhipu": "https://open.bigmodel.cn",
         "gemini": "https://generativelanguage.googleapis.com",
         "openrouter": "https://openrouter.ai/api",
-        "general_openai_compat": None,
+        "smart": None,
+        "openai_responses": None,
     }

     return default_api_bases.get(api_type)
@@ -244,7 +250,7 @@ def list_embedding_models() -> list[dict[str, Any]]:

 async def get_model_instance(
     provider_model_name: str | None = None,
-    override_config: dict[str, Any] | None = None,
+    override_config: dict[str, Any] | LLMGenerationConfig | None = None,
 ) -> LLMModel:
     """
     根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
@@ -303,21 +309,20 @@ async def get_model_instance(

     model_detail_found.is_embedding_model = capabilities.is_embedding_model

-    ai_config = get_ai_config()
-    global_proxy_setting = ai_config.get(PROXY_KEY)
+    llm_config = get_llm_config()
+    client_settings = llm_config.client_settings
     default_timeout = (
         provider_config_found.timeout
         if provider_config_found.timeout is not None
-        else 180
+        else client_settings.timeout
     )
-    global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)

     config_for_http_client = ProviderConfig(
         name=provider_config_found.name,
         api_key=provider_config_found.api_key,
         models=provider_config_found.models,
-        timeout=global_timeout_setting,
-        proxy=global_proxy_setting,
+        timeout=default_timeout,
+        proxy=client_settings.proxy,
         api_base=provider_config_found.api_base,
         api_type=provider_config_found.api_type,
         openai_compat=provider_config_found.openai_compat,
@@ -1,55 +0,0 @@
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import Any
-
-from .types import LLMMessage
-
-
-class BaseMemory(ABC):
-    """
-    记忆系统的抽象基类。
-    定义了任何记忆后端都必须实现的接口。
-    """
-
-    @abstractmethod
-    async def get_history(self, session_id: str) -> list[LLMMessage]:
-        """根据会话ID获取历史记录。"""
-        raise NotImplementedError
-
-    @abstractmethod
-    async def add_message(self, session_id: str, message: LLMMessage) -> None:
-        """向指定会话添加一条消息。"""
-        raise NotImplementedError
-
-    @abstractmethod
-    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
-        """向指定会话添加多条消息。"""
-        raise NotImplementedError
-
-    @abstractmethod
-    async def clear_history(self, session_id: str) -> None:
-        """清空指定会话的历史记录。"""
-        raise NotImplementedError
-
-
-class InMemoryMemory(BaseMemory):
-    """
-    一个简单的、默认的内存记忆后端。
-    将历史记录存储在进程内存中的字典里。
-    """
-
-    def __init__(self, **kwargs: Any):
-        self._history: dict[str, list[LLMMessage]] = defaultdict(list)
-
-    async def get_history(self, session_id: str) -> list[LLMMessage]:
-        return self._history.get(session_id, []).copy()
-
-    async def add_message(self, session_id: str, message: LLMMessage) -> None:
-        self._history[session_id].append(message)
-
-    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
-        self._history[session_id].extend(messages)
-
-    async def clear_history(self, session_id: str) -> None:
-        if session_id in self._history:
-            del self._history[session_id]
File diff suppressed because it is too large
@@ -4,30 +4,34 @@ LLM 服务 - 会话客户端
 提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。
 """

+from abc import ABC, abstractmethod
+from collections import defaultdict
+from collections.abc import Awaitable, Callable
 import copy
 from dataclasses import dataclass, field
 import json
-from typing import Any, TypeVar
+from typing import Any, TypeVar, cast
 import uuid

-from jinja2 import Environment
-from nonebot.compat import type_validate_json
+from jinja2 import Template
+from nonebot.utils import is_coroutine_callable
 from nonebot_plugin_alconna.uniseg import UniMessage
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel

 from zhenxun.services.log import logger
-from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_json_schema
+from zhenxun.utils.pydantic_compat import model_json_schema

 from .config import (
     CommonOverrides,
+    GenConfigBuilder,
+    LLMEmbeddingConfig,
     LLMGenerationConfig,
 )
-from .config.providers import get_ai_config
+from .config.generation import OutputConfig
+from .config.providers import get_ai_config, get_llm_config
 from .manager import get_global_default_model_name, get_model_instance
-from .memory import BaseMemory, InMemoryMemory
-from .tools.manager import tool_provider_manager
+from .tools import tool_provider_manager
 from .types import (
-    EmbeddingTaskType,
     LLMContentPart,
     LLMErrorCode,
     LLMException,
@@ -35,19 +39,28 @@ from .types import (
     LLMResponse,
     ModelName,
     ResponseFormat,
+    StructuredOutputStrategy,
+    ToolChoice,
     ToolExecutable,
     ToolProvider,
 )
-from .utils import normalize_to_llm_messages
+from .types.models import (
+    GeminiCodeExecution,
+    GeminiGoogleSearch,
+)
+from .utils import (
+    create_cot_wrapper,
+    normalize_to_llm_messages,
+    parse_and_validate_json,
+    should_apply_autocot,
+)

 T = TypeVar("T", bound=BaseModel)

-jinja_env = Environment(autoescape=False)


 @dataclass
 class AIConfig:
-    """AI配置类 - [重构后] 简化版本"""
+    """AI配置类"""

     model: ModelName = None
     default_embedding_model: ModelName = None
@@ -61,6 +74,98 @@ class AIConfig:
         self.model = ai_config.get("default_model_name")


+class BaseMemory(ABC):
+    """记忆系统的抽象基类。"""
+
+    @abstractmethod
+    async def get_history(self, session_id: str) -> list[LLMMessage]:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def add_message(self, session_id: str, message: LLMMessage) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def clear_history(self, session_id: str) -> None:
+        raise NotImplementedError
+
+
+class InMemoryMemory(BaseMemory):
+    """一个简单的、默认的内存记忆后端。"""
+
+    def __init__(self, max_messages: int = 50, **kwargs: Any):
+        self._history: dict[str, list[LLMMessage]] = defaultdict(list)
+        self._max_messages = max_messages
+
+    def _trim_history(self, session_id: str) -> None:
+        """修剪历史记录,确保不超过最大长度,同时保留 System Prompt"""
+        history = self._history[session_id]
+        if len(history) <= self._max_messages:
+            return
+
+        has_system = history and history[0].role == "system"
+
+        if has_system:
+            keep_count = max(0, self._max_messages - 1)
+            self._history[session_id] = [history[0], *history[-keep_count:]]
+        else:
+            self._history[session_id] = history[-self._max_messages :]
+
+    async def get_history(self, session_id: str) -> list[LLMMessage]:
+        return self._history.get(session_id, []).copy()
+
+    async def add_message(self, session_id: str, message: LLMMessage) -> None:
+        self._history[session_id].append(message)
+        self._trim_history(session_id)
+
+    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
+        self._history[session_id].extend(messages)
+        self._trim_history(session_id)
+
+    async def clear_history(self, session_id: str) -> None:
+        if session_id in self._history:
+            del self._history[session_id]
+
+
+class MemoryProcessor(ABC):
+    """记忆处理器接口"""
+
+    @abstractmethod
+    async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
+        pass
+
+
+_default_memory_factory: Callable[[], BaseMemory] | None = None
+
+
+def set_default_memory_backend(factory: Callable[[], BaseMemory]):
+    """
+    设置全局默认记忆后端工厂,允许统一替换会话的记忆实现。
+    """
+    global _default_memory_factory
+    _default_memory_factory = factory
+
+
+def _get_default_memory() -> BaseMemory:
+    if _default_memory_factory:
+        return _default_memory_factory()
+    return InMemoryMemory()
+
+
+DEFAULT_IVR_TEMPLATE = (
+    "你的响应未能通过结构校验。\n"
+    "错误详情: {error_msg}\n\n"
+    "请执行以下步骤进行修正:\n"
+    "1. 反思:分析为什么会出现这个错误。\n"
+    "2. 修正:生成一个新的、符合 Schema 要求的 JSON 对象。\n"
+    "请直接输出修正后的 JSON,不要包含 Markdown 标记或其他解释。"
+)
+
+
 class AI:
     """
     统一的AI服务类 - 提供了带记忆的会话接口。
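A quick check of the system-prompt-preserving trim added above; plain (role, text) tuples stand in for LLMMessage:

def trim(history: list, max_messages: int):
    # Keep the first message when it is a system prompt, drop the oldest turns.
    if len(history) <= max_messages:
        return history
    if history and history[0][0] == "system":
        return [history[0], *history[-(max_messages - 1):]]
    return history[-max_messages:]

h = [("system", "s")] + [("user", str(i)) for i in range(10)]
assert trim(h, 4) == [("system", "s"), ("user", "7"), ("user", "8"), ("user", "9")]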
@@ -73,6 +178,7 @@ class AI:
         config: AIConfig | None = None,
         memory: BaseMemory | None = None,
         default_generation_config: LLMGenerationConfig | None = None,
+        processors: list[MemoryProcessor] | None = None,
     ):
         """
         初始化AI服务
@@ -81,24 +187,45 @@ class AI:
             session_id: 唯一的会话ID,用于隔离记忆。
             config: AI 配置.
             memory: 可选的自定义记忆后端。如果为None,则使用默认的InMemoryMemory。
-            default_generation_config: (新增) 此AI实例的默认生成配置。
+            default_generation_config: 此AI实例的默认生成配置。
+            processors: 记忆处理器列表,在添加记忆后触发。
         """
         self.session_id = session_id or str(uuid.uuid4())
         self.config = config or AIConfig()
-        self.memory = memory or InMemoryMemory()
+        self.memory = memory or _get_default_memory()
         self.default_generation_config = (
             default_generation_config or LLMGenerationConfig()
         )
+        self.processors = processors or []

         global_providers = tool_provider_manager._providers
         config_providers = self.config.tool_providers
         self._tool_providers = list(dict.fromkeys(global_providers + config_providers))
+        self.message_buffer: list[LLMMessage] = []

     async def clear_history(self):
         """清空当前会话的历史记录。"""
         await self.memory.clear_history(self.session_id)
         logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})")

+    async def add_observation(
+        self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
+    ):
+        """
+        将一条观察消息加入缓冲区,不立即触发模型调用。
+
+        返回:
+            int: 缓冲区中消息的数量。
+        """
+        current_message = await self._normalize_input_to_message(message)
+        self.message_buffer.append(current_message)
+        content_preview = str(current_message.content)[:50]
+        logger.debug(
+            f"[放入观察] {content_preview} (缓冲区大小: {len(self.message_buffer)})",
+            "AI_MEMORY",
+        )
+        return len(self.message_buffer)
+
     async def add_user_message_to_history(
         self, message: str | LLMMessage | list[LLMContentPart]
     ):
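A hypothetical call pattern for the new observation buffer (method names as introduced in this diff; session setup assumed), left commented since it needs a configured bot runtime:

# ai = AI(session_id="demo")
# await ai.add_observation("群友A: 今晚开黑吗?")
# await ai.add_observation("群友B: 等我十分钟")
# response = await ai.chat("总结一下大家的安排", use_buffer=True)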
@@ -161,7 +288,7 @@ class AI:
         self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
     ) -> LLMMessage:
         """
-        [重构后] 内部辅助方法,将各种输入类型统一转换为单个 LLMMessage 对象。
+        内部辅助方法,将各种输入类型统一转换为单个 LLMMessage 对象。
         它调用共享的工具函数并提取最后一条消息(通常是用户输入)。
         """
         messages = await normalize_to_llm_messages(message)
@@ -172,17 +299,79 @@ class AI:
         )
         return messages[-1]

+    async def generate_internal(
+        self,
+        messages: list[LLMMessage],
+        *,
+        model: ModelName = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
+        tools: list[Any] | dict[str, ToolExecutable] | None = None,
+        tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+        timeout: float | None = None,
+        model_instance: Any = None,
+    ) -> LLMResponse:
+        """
+        Core internal generation method: merges configs, resolves tools and calls the model.
+        It does not persist history; it is called by AgentExecutor or by chat().
+        """
+        final_config = self.default_generation_config
+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
+        if config:
+            final_config = final_config.merge_with(config)
+
+        final_tools_list = []
+        if tools:
+            if isinstance(tools, dict):
+                final_tools_list = list(tools.values())
+            elif isinstance(tools, list):
+                to_resolve: list[Any] = []
+                for t in tools:
+                    if isinstance(t, str | dict):
+                        to_resolve.append(t)
+                    else:
+                        final_tools_list.append(t)
+
+                if to_resolve:
+                    resolved_dict = await self._resolve_tools(to_resolve)
+                    final_tools_list.extend(resolved_dict.values())
+
+        if model_instance:
+            return await model_instance.generate_response(
+                messages,
+                config=final_config,
+                tools=final_tools_list if final_tools_list else None,
+                tool_choice=tool_choice,
+                timeout=timeout,
+            )
+
+        resolved_model_name = self._resolve_model_name(model or self.config.model)
+        async with await get_model_instance(
+            resolved_model_name,
+            override_config=None,
+        ) as instance:
+            return await instance.generate_response(
+                messages,
+                config=final_config,
+                tools=final_tools_list if final_tools_list else None,
+                tool_choice=tool_choice,
+                timeout=timeout,
+            )
+
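Note that `generate_internal` (and `chat`, which forwards to it) accepts a heterogeneous `tools` list: strings and dicts are resolved through `self._resolve_tools`, while ToolExecutable objects pass through unchanged. A sketch of the accepted shapes (the names below are illustrative):

# Sketch: the mixed `tools` argument accepted by generate_internal / chat.
tools = [
    "google_search",            # builtin name, resolved by the tool providers
    {"name": "weather_query"},  # dict config, also resolved by name
]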
     async def chat(
         self,
-        message: str | UniMessage | LLMMessage | list[LLMContentPart],
+        message: str | UniMessage | LLMMessage | list[LLMContentPart] | None,
         *,
         model: ModelName = None,
         instruction: str | None = None,
         template_vars: dict[str, Any] | None = None,
         preserve_media_in_history: bool | None = None,
-        tools: list[dict[str, Any] | str] | dict[str, ToolExecutable] | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        config: LLMGenerationConfig | None = None,
+        tools: list[Any] | dict[str, ToolExecutable] | None = None,
+        tool_choice: str | dict[str, Any] | ToolChoice | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
+        use_buffer: bool = False,
+        timeout: float | None = None,
     ) -> LLMResponse:
         """
         Core interaction method: manages session history and performs a single LLM call.
@@ -198,18 +387,27 @@ class AI:
             tools: available tools as a list or dict; supports ad-hoc and pre-configured tools.
             tool_choice: tool-selection strategy controlling how the AI picks and uses tools.
             config: generation config object overriding the default generation parameters.
+            use_buffer: whether to flush the message buffer and submit its contents with this call.
+            timeout: HTTP request timeout, in seconds.

         Returns:
             LLMResponse: the full response object with the AI reply, tool-call requests, usage info, etc.
         """
-        current_message = await self._normalize_input_to_message(message)
+        messages_to_add: list[LLMMessage] = []
+        if message:
+            current_message = await self._normalize_input_to_message(message)
+            messages_to_add.append(current_message)
+
+        if use_buffer and self.message_buffer:
+            messages_to_add = self.message_buffer + messages_to_add
+            self.message_buffer.clear()
+
         messages_for_run = []
         final_instruction = instruction

         if final_instruction and template_vars:
             try:
-                template = jinja_env.from_string(final_instruction)
+                template = Template(final_instruction)
                 final_instruction = template.render(**template_vars)
                 logger.debug(f"Rendered system instruction: {final_instruction}")
             except Exception as e:
@@ -220,51 +418,55 @@ class AI:

         current_history = await self.memory.get_history(self.session_id)
         messages_for_run.extend(current_history)
-        messages_for_run.append(current_message)
+        messages_for_run.extend(messages_to_add)

         try:
-            resolved_model_name = self._resolve_model_name(model or self.config.model)
-
-            final_config = model_copy(self.default_generation_config, deep=True)
-            if config:
-                update_dict = model_dump(config, exclude_unset=True)
-                final_config = model_copy(final_config, update=update_dict)
-
-            ad_hoc_tools = None
-            if tools:
-                if isinstance(tools, dict):
-                    ad_hoc_tools = tools
-                else:
-                    ad_hoc_tools = await self._resolve_tools(tools)
-
-            async with await get_model_instance(
-                resolved_model_name,
-                override_config=final_config.to_dict(),
-            ) as model_instance:
-                response = await model_instance.generate_response(
-                    messages_for_run, tools=ad_hoc_tools, tool_choice=tool_choice
-                )
+            response = await self.generate_internal(
+                messages_for_run,
+                model=model,
+                config=config,
+                tools=tools,
+                tool_choice=tool_choice,
+                timeout=timeout,
+            )

             should_preserve = (
                 preserve_media_in_history
                 if preserve_media_in_history is not None
                 else self.config.default_preserve_media_in_history
             )
-            user_msg_to_store = (
-                current_message
-                if should_preserve
-                else self._sanitize_message_for_history(current_message)
-            )
-            assistant_response_msg = LLMMessage.assistant_text_response(response.text)
-            if response.tool_calls:
-                assistant_response_msg = LLMMessage.assistant_tool_calls(
-                    response.tool_calls, response.text
-                )
+            msgs_to_store: list[LLMMessage] = []
+            for msg in messages_to_add:
+                store_msg = (
+                    msg if should_preserve else self._sanitize_message_for_history(msg)
+                )
+                msgs_to_store.append(store_msg)
+
+            if response.content_parts:
+                assistant_response_msg = LLMMessage(
+                    role="assistant",
+                    content=response.content_parts,
+                    tool_calls=response.tool_calls,
+                )
+            else:
+                assistant_response_msg = LLMMessage.assistant_text_response(
+                    response.text
+                )
+                if response.tool_calls:
+                    assistant_response_msg = LLMMessage.assistant_tool_calls(
+                        response.tool_calls, response.text
+                    )

             await self.memory.add_messages(
-                self.session_id, [user_msg_to_store, assistant_response_msg]
+                self.session_id, [*msgs_to_store, assistant_response_msg]
             )
+
+            if self.processors:
+                for processor in self.processors:
+                    await processor.process(
+                        self.session_id, [*msgs_to_store, assistant_response_msg]
+                    )

             return response

         except Exception as e:
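The `instruction` / `template_vars` pair above renders the per-call system instruction via `Template(...).render(**template_vars)`. A minimal sketch; Jinja-style {{ }} placeholders are an assumption based on that call, and the persona text is illustrative:

async def ask_with_persona(ai):
    # `ai` is an AI instance; instruction is rendered before the model call.
    return await ai.chat(
        "What should I eat tonight?",
        instruction="You are {{ persona }}. Answer in {{ language }}.",
        template_vars={"persona": "a friendly chef", "language": "English"},
        timeout=30.0,
    )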
@@ -280,7 +482,7 @@ class AI:
         *,
         model: ModelName = None,
         timeout: int | None = None,
-        config: LLMGenerationConfig | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
     ) -> LLMResponse:
         """
         Code execution.
@@ -294,16 +496,18 @@ class AI:
         Returns:
             LLMResponse: the full response object containing the execution result.
         """
-        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
+        resolved_model = model or self.config.model

         code_config = CommonOverrides.gemini_code_execution()
         if timeout:
             code_config.custom_params = code_config.custom_params or {}
             code_config.custom_params["code_execution_timeout"] = timeout

+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
         if config:
-            update_dict = model_dump(config, exclude_unset=True)
-            code_config = model_copy(code_config, update=update_dict)
+            code_config = code_config.merge_with(config)

         return await self.chat(prompt, model=resolved_model, config=code_config)

@@ -317,7 +521,7 @@ class AI:
             "Find the most relevant information for the user's query, then summarize and answer."
         ),
         template_vars: dict[str, Any] | None = None,
-        config: LLMGenerationConfig | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
     ) -> LLMResponse:
         """
         Convenience entry point for information search, with native multimodal query support.
@@ -325,9 +529,11 @@ class AI:
         logger.info("Running 'search' task...")
         search_config = CommonOverrides.gemini_grounding()

+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
         if config:
-            update_dict = model_dump(config, exclude_unset=True)
-            search_config = model_copy(search_config, update=update_dict)
+            search_config = search_config.merge_with(config)

         return await self.chat(
             query,
@@ -339,21 +545,31 @@ class AI:

     async def generate_structured(
         self,
-        message: str | LLMMessage | list[LLMContentPart],
+        message: str | LLMMessage | list[LLMContentPart] | None,
         response_model: type[T],
         *,
         model: ModelName = None,
+        tools: list[Any] | dict[str, ToolExecutable] | None = None,
+        tool_choice: str | dict[str, Any] | ToolChoice | None = None,
         instruction: str | None = None,
-        config: LLMGenerationConfig | None = None,
+        timeout: float | None = None,
+        template_vars: dict[str, Any] | None = None,
+        config: LLMGenerationConfig | GenConfigBuilder | None = None,
+        max_validation_retries: int | None = None,
+        validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
+        error_prompt_template: str | None = None,
+        auto_thinking: bool = False,
     ) -> T:
         """
         Generate a structured response and automatically parse it into the given Pydantic model.

         Args:
-            message: the user input, in any of the supported formats.
+            message: the user input, in any of the supported formats. If None, only history + buffer are used.
             response_model: the Pydantic model class used to parse and validate the response.
             model: the model name to use; falls back to the configured default when None.
             instruction: a call-specific system instruction, merged with the JSON Schema instruction.
+            timeout: HTTP request timeout, in seconds.
+            template_vars: template variables for dynamic rendering of the system instruction.
             config: generation config object overriding the default generation parameters.

         Returns:
@@ -362,6 +578,46 @@ class AI:
         Raises:
             LLMException: if the model does not return valid JSON or validation fails.
         """
+        if isinstance(config, GenConfigBuilder):
+            config = config.build()
+
+        final_config = self.default_generation_config.merge_with(config)
+
+        if final_config is None:
+            final_config = LLMGenerationConfig()
+
+        if max_validation_retries is None:
+            max_validation_retries = get_llm_config().client_settings.structured_retries
+
+        resolved_model_name = self._resolve_model_name(model or self.config.model)
+
+        request_autocot = True if auto_thinking is False else auto_thinking
+        effective_auto_thinking = should_apply_autocot(
+            request_autocot, resolved_model_name, final_config
+        )
+
+        target_model: type[T] = response_model
+        if effective_auto_thinking:
+            target_model = cast(type[T], create_cot_wrapper(response_model))
+            response_model = target_model
+
+            cot_instruction = (
+                "Always reason step by step in the `reasoning` field first, making sure the logic is sound, "
+                "and only then fill in the `result` field."
+            )
+            if instruction:
+                instruction = f"{instruction}\n\n{cot_instruction}"
+            else:
+                instruction = cot_instruction
+
+        final_instruction = instruction
+        if final_instruction and template_vars:
+            try:
+                template = Template(final_instruction)
+                final_instruction = template.render(**template_vars)
+            except Exception as e:
+                logger.error(f"Failed to render the structured-instruction template: {e}", e=e)
+
         try:
             json_schema = model_json_schema(response_model)
         except AttributeError:
@@ -369,41 +625,149 @@ class AI:

         schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)

-        system_prompt = (
-            (f"{instruction}\n\n" if instruction else "")
-            + "You must respond strictly in the following JSON Schema format. "
-            + "Do not include any extra explanations, comments or code-block markers; return only a pure JSON object.\n\n"
-        )
-        system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
+        prompt_prefix = f"{final_instruction}\n\n" if final_instruction else ""
+        structured_strategy = (
+            final_config.output.structured_output_strategy
+            if final_config.output
+            else None
+        )
+        if structured_strategy == StructuredOutputStrategy.TOOL_CALL:
+            system_prompt = prompt_prefix + "Please call the provided tool to submit the structured data."
+        else:
+            system_prompt = (
+                prompt_prefix
+                + "Respond strictly in the following JSON Schema format. Do not include any extra explanations, "
+                "comments or code-block markers; return only a single valid JSON object.\n\n"
+            )
+            system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"

-        final_config = model_copy(config) if config else LLMGenerationConfig()
-
-        final_config.response_format = ResponseFormat.JSON
-        final_config.response_schema = json_schema
-
-        response = await self.chat(
-            message, model=model, instruction=system_prompt, config=final_config
-        )
-
-        try:
-            return type_validate_json(response_model, response.text)
-        except ValidationError as e:
-            logger.error(f"LLM structured output failed validation: {e}", e=e)
-            raise LLMException(
-                "The JSON returned by the LLM failed schema validation.",
-                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                details={"raw_response": response.text, "validation_error": str(e)},
-                cause=e,
-            )
-        except Exception as e:
-            logger.error(f"Unexpected error while parsing LLM structured output: {e}", e=e)
-            raise LLMException(
-                "Failed to parse the LLM's JSON output.",
-                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
-                details={"raw_response": response.text},
-                cause=e,
-            )
+        structured_strategy = (
+            final_config.output.structured_output_strategy
+            if final_config.output
+            else StructuredOutputStrategy.NATIVE
+        )
+
+        final_tools_list: list[ToolExecutable] | None = None
+        if structured_strategy != StructuredOutputStrategy.NATIVE:
+            if tools:
+                final_tools_list = []
+                if isinstance(tools, dict):
+                    final_tools_list = list(tools.values())
+                elif isinstance(tools, list):
+                    to_resolve: list[Any] = []
+                    for t in tools:
+                        if isinstance(t, str | dict):
+                            to_resolve.append(t)
+                        else:
+                            final_tools_list.append(t)
+                    if to_resolve:
+                        resolved_dict = await self._resolve_tools(to_resolve)
+                        final_tools_list.extend(resolved_dict.values())
+        elif tools:
+            logger.warning(
+                "Tools were passed to generate_structured with the NATIVE strategy. "
+                "To avoid API conflicts (Gemini) and output ambiguity (OpenAI), these "
+                "tools will be ignored for this request. "
+                "Use chat() or the Agent flow if you need tool calls."
+            )
+
+        if final_config.output is None:
+            final_config.output = OutputConfig()
+
+        final_config.output.response_format = ResponseFormat.JSON
+        final_config.output.response_schema = json_schema
+
+        messages_for_run = [LLMMessage.system(system_prompt)]
+        current_history = await self.memory.get_history(self.session_id)
+        messages_for_run.extend(current_history)
+        messages_for_run.extend(self.message_buffer)
+        if message:
+            normalized_message = await self._normalize_input_to_message(message)
+            messages_for_run.append(normalized_message)
+
+        ivr_messages = list(messages_for_run)
+        last_exception: Exception | None = None
+
+        for attempt in range(max_validation_retries + 1):
+            current_response_text: str = ""
+
+            async with await get_model_instance(
+                resolved_model_name,
+                override_config=None,
+            ) as model_instance:
+                response = await model_instance.generate_response(
+                    ivr_messages,
+                    config=final_config,
+                    tools=final_tools_list if final_tools_list else None,
+                    tool_choice=tool_choice,
+                    timeout=timeout,
+                )
+                current_response_text = response.text
+
+            try:
+                parsed_obj = parse_and_validate_json(response.text, target_model)
+
+                final_obj: T = cast(T, parsed_obj)
+                if effective_auto_thinking:
+                    logger.debug(
+                        f"AutoCoT reasoning trace: {getattr(parsed_obj, 'reasoning', '')}"
+                    )
+                    final_obj = cast(T, getattr(parsed_obj, "result"))
+
+                if validation_callback:
+                    if is_coroutine_callable(validation_callback):
+                        await validation_callback(final_obj)
+                    else:
+                        validation_callback(final_obj)
+
+                return final_obj
+
+            except Exception as e:
+                is_llm_error = isinstance(e, LLMException)
+                llm_error: LLMException | None = (
+                    cast(LLMException, e) if is_llm_error else None
+                )
+                last_exception = e
+
+                if attempt < max_validation_retries:
+                    error_msg = (
+                        llm_error.details.get("validation_error", str(e))
+                        if llm_error
+                        else str(e)
+                    )
+                    raw_response = current_response_text or (
+                        llm_error.details.get("raw_response", "") if llm_error else ""
+                    )
+                    logger.warning(
+                        f"Structured validation failed (attempt {attempt + 1}/"
+                        f"{max_validation_retries + 1}). Attempting IVR repair... "
+                        f"error: {error_msg}"
+                    )
+
+                    if raw_response:
+                        ivr_messages.append(
+                            LLMMessage.assistant_text_response(raw_response)
+                        )
+                    else:
+                        logger.warning(
+                            "IVR warning: the raw text of the previous generation is unavailable; "
+                            "the model will attempt the repair without that context."
+                        )
+
+                    template = error_prompt_template or DEFAULT_IVR_TEMPLATE
+                    feedback_prompt = template.format(error_msg=error_msg)
+                    ivr_messages.append(LLMMessage.user(feedback_prompt))
+                    continue
+
+                if llm_error and not llm_error.recoverable:
+                    raise llm_error
+
+        if last_exception:
+            raise last_exception
+        raise LLMException(
+            "The IVR loop ended abnormally without producing a valid result.",
+            code=LLMErrorCode.GENERATION_FAILED,
+        )

     def _resolve_model_name(self, model_name: ModelName) -> str:
         """Resolve the model name."""
         if model_name:
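The IVR (iterative validation-repair) loop above retries when either schema parsing or the caller's `validation_callback` raises. A minimal sketch, assuming an `AI` instance named `ai`; the model and field names are illustrative:

from pydantic import BaseModel

class DishSuggestion(BaseModel):
    name: str
    calories: int

def check(dish: DishSuggestion) -> None:
    # Raising here triggers another IVR repair round, just like a schema failure.
    if dish.calories <= 0:
        raise ValueError("calories must be positive")

async def suggest(ai) -> DishSuggestion:
    return await ai.generate_structured(
        "Suggest one dinner dish",
        DishSuggestion,
        max_validation_retries=2,
        validation_callback=check,
    )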
@@ -423,8 +787,7 @@ class AI:
         texts: list[str] | str,
         *,
         model: ModelName = None,
-        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
-        **kwargs: Any,
+        config: LLMEmbeddingConfig | None = None,
     ) -> list[list[float]]:
         """
         Generate text embeddings, converting text into numeric vector representations.
@@ -432,14 +795,13 @@ class AI:
         Args:
             texts: the text(s) to embed; a single string or a list of strings.
             model: the embedding model name; falls back to the configured default when None.
-            task_type: the embedding task type, which affects how the vectors are optimized (retrieval, classification, etc.).
-            **kwargs: extra parameters passed through to the embedding model.
+            config: the embedding configuration.

         Returns:
             list[list[float]]: the embedding vectors, one list of floats per input text.

         Raises:
-            LLMException: if embedding generation fails or the model is misconfigured.
+            LLMException: raised when embedding generation fails or the model is misconfigured.
         """
         if isinstance(texts, str):
             texts = [texts]
@@ -452,18 +814,20 @@ class AI:
             )
             if not resolved_model_str:
                 raise LLMException(
-                    "An embedding model name must be provided when using embed, "
-                    "or default_embedding_model must be configured in AIConfig.",
+                    "No embedding model name was given to embed, "
+                    "and AIConfig has no default_embedding_model set.",
                     code=LLMErrorCode.MODEL_NOT_FOUND,
                 )
             resolved_model_str = self._resolve_model_name(resolved_model_str)

+            final_config = config or LLMEmbeddingConfig()
+
             async with await get_model_instance(
                 resolved_model_str,
                 override_config=None,
             ) as embedding_model_instance:
                 return await embedding_model_instance.generate_embeddings(
-                    texts, task_type=task_type, **kwargs
+                    texts, config=final_config
                 )
         except LLMException:
             raise
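With `task_type`/`**kwargs` folded into a single config object, an embed call looks roughly like the sketch below. The import path and any LLMEmbeddingConfig fields beyond the no-argument constructor are assumptions not shown in this diff:

from zhenxun.services.llm import LLMEmbeddingConfig  # import path assumed

async def embed_docs(ai) -> list[list[float]]:
    cfg = LLMEmbeddingConfig()  # defaults; customize fields as the config defines
    return await ai.embed(["hello world", "zhenxun bot"], config=cfg)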
@@ -484,6 +848,15 @@ class AI:
         resolved: dict[str, ToolExecutable] = {}

         for config in tool_configs:
+            if isinstance(config, str):
+                if config == "google_search":
+                    resolved[config] = GeminiGoogleSearch()  # type: ignore[arg-type]
+                    continue
+                elif config == "code_execution":
+                    resolved[config] = GeminiCodeExecution()  # type: ignore[arg-type]
+                    continue
+                elif config == "url_context":
+                    pass
             name = config if isinstance(config, str) else config.get("name")
             if not name:
                 raise LLMException(
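`_resolve_tools` now special-cases the builtin Gemini tool names shown above, so they can be requested directly by string; a sketch, with `ai` an AI instance and the query text illustrative:

async def grounded_answer(ai) -> str:
    response = await ai.chat(
        "What happened in tech news today?",
        tools=["google_search"],  # resolved to GeminiGoogleSearch() internally
    )
    return response.text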

zhenxun/services/llm/tools.py (new file, 839 lines)
@@ -0,0 +1,839 @@
+"""
+Tools module.
+
+Consolidates the tool-parameter resolver, the tool-provider manager and the
+tool-execution logic so the LLM service layer can call them from one place.
+"""
+
+import asyncio
+from collections.abc import Callable
+from enum import Enum
+import inspect
+import json
+import re
+import time
+from typing import (
+    Annotated,
+    Any,
+    Optional,
+    Union,
+    cast,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+from typing_extensions import override
+
+from httpx import NetworkError, TimeoutException
+
+try:
+    import ujson as fast_json
+except ImportError:
+    fast_json = json
+
+import nonebot
+from nonebot.dependencies import Dependent, Param
+from nonebot.internal.adapter import Bot, Event
+from nonebot.internal.params import (
+    BotParam,
+    DefaultParam,
+    DependParam,
+    DependsInner,
+    EventParam,
+    StateParam,
+)
+from pydantic import BaseModel, Field, ValidationError, create_model
+from pydantic.fields import FieldInfo
+
+from zhenxun.services.log import logger
+from zhenxun.utils.decorator.retry import Retry
+from zhenxun.utils.pydantic_compat import model_dump, model_fields, model_json_schema
+
+from .types import (
+    LLMErrorCode,
+    LLMException,
+    LLMMessage,
+    LLMToolCall,
+    ToolExecutable,
+    ToolProvider,
+    ToolResult,
+)
+from .types.models import ToolDefinition
+from .types.protocols import BaseCallbackHandler, ToolCallData
+
+
+class ToolParam(Param):
+    """
+    Tool-parameter extractor.
+
+    Used inside custom function tools to pull a specific value out of the
+    argument dict parsed from the LLM (`state["_tool_params"]`).
+    Typically used together with `Annotated` and the dependency-injection system.
+    """
+
+    def __init__(self, *args: Any, name: str, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.name = name
+
+    def __repr__(self) -> str:
+        return f"ToolParam(name={self.name})"
+
+    @classmethod
+    @override
+    def _check_param(
+        cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
+    ) -> Optional["ToolParam"]:
+        if param.default is not inspect.Parameter.empty and isinstance(
+            param.default, DependsInner
+        ):
+            return None
+
+        if get_origin(param.annotation) is Annotated:
+            for arg in get_args(param.annotation):
+                if isinstance(arg, DependsInner):
+                    return None
+
+        if param.kind not in (
+            inspect.Parameter.VAR_POSITIONAL,
+            inspect.Parameter.VAR_KEYWORD,
+        ):
+            return cls(name=param.name)
+        return None
+
+    @override
+    async def _solve(self, **kwargs: Any) -> Any:
+        state: dict[str, Any] = kwargs.get("state", {})
+        tool_params = state.get("_tool_params", {})
+        if self.name in tool_params:
+            return tool_params[self.name]
+        return None
+
+
+class RunContext(BaseModel):
+    """
+    Dependency-injection (DI) container: keeps the existing context information
+    while making typed lookups easier.
+    """
+
+    session_id: str | None = None
+    scope: dict[str, Any] = Field(default_factory=dict)
+    extra: dict[str, Any] = Field(default_factory=dict)
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class RunContextParam(Param):
+    """Parameter resolver that injects RunContext automatically."""
+
+    @classmethod
+    def _check_param(
+        cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
+    ) -> Optional["RunContextParam"]:
+        if param.annotation is RunContext:
+            return cls()
+        return None
+
+    async def _solve(self, **kwargs: Any) -> Any:
+        state = kwargs.get("state", {})
+        return state.get("_agent_context")
+
+
+def _parse_docstring_params(docstring: str | None) -> dict[str, str]:
+    """
+    Parse a docstring and extract parameter descriptions.
+    Supports Google style (Args:), ReST style (:param:) and the Chinese style (参数:).
+    """
+    if not docstring:
+        return {}
+
+    params: dict[str, str] = {}
+    lines = docstring.splitlines()
+
+    rest_pattern = re.compile(r"[:@]param\s+(\w+)\s*:?\s*(.*)")
+    found_rest = False
+    for line in lines:
+        match = rest_pattern.search(line)
+        if match:
+            params[match.group(1)] = match.group(2).strip()
+            found_rest = True
+
+    if found_rest:
+        return params
+
+    section_header_pattern = re.compile(
+        r"^\s*(?:Args|Arguments|Parameters|参数)\s*[::]\s*$"
+    )
+
+    param_section_active = False
+    google_pattern = re.compile(r"^\s*(\**\w+)(?:\s*\(.*?\))?\s*[::]\s*(.*)")
+
+    for line in lines:
+        stripped_line = line.strip()
+        if not stripped_line:
+            continue
+
+        if section_header_pattern.match(line):
+            param_section_active = True
+            continue
+
+        if param_section_active:
+            if (
+                stripped_line.endswith(":") or stripped_line.endswith(":")
+            ) and not google_pattern.match(line):
+                param_section_active = False
+                continue
+
+            match = google_pattern.match(line)
+            if match:
+                name = match.group(1).lstrip("*")
+                desc = match.group(2).strip()
+                params[name] = desc
+
+    return params
+
+
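The parser above turns several docstring styles into a name-to-description map; a small in-module sketch of what it returns (the docstring is illustrative):

doc = '''Query the weather.

Args:
    city: target city name
    days: forecast length in days
'''
# Expected: {"city": "target city name", "days": "forecast length in days"}
print(_parse_docstring_params(doc))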
+def _create_dynamic_model(func: Callable) -> type[BaseModel]:
+    """Dynamically create a Pydantic model from a function signature."""
+    sig = inspect.signature(func)
+    doc_params = _parse_docstring_params(func.__doc__)
+    type_hints = get_type_hints(func, include_extras=True)
+
+    fields = {}
+    for name, param in sig.parameters.items():
+        if name in ("self", "cls"):
+            continue
+
+        annotation = type_hints.get(name, Any)
+        default = param.default
+
+        is_run_context = False
+        if annotation is RunContext:
+            is_run_context = True
+        else:
+            origin = get_origin(annotation)
+            if origin is Union:
+                args = get_args(annotation)
+                if RunContext in args:
+                    is_run_context = True
+
+        if is_run_context:
+            continue
+
+        if default is not inspect.Parameter.empty and isinstance(default, DependsInner):
+            continue
+
+        if get_origin(annotation) is Annotated:
+            args = get_args(annotation)
+            if any(isinstance(arg, DependsInner) for arg in args):
+                continue
+
+        description = doc_params.get(name)
+        if isinstance(default, FieldInfo):
+            if description and not getattr(default, "description", None):
+                default.description = description
+            fields[name] = (annotation, default)
+        else:
+            if default is inspect.Parameter.empty:
+                default = ...
+            fields[name] = (annotation, Field(default, description=description))
+
+    return create_model(f"{func.__name__}Params", **fields)
+
+
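Combined with the docstring parser, the function above gives every plain function a parameter schema; a sketch of the model it produces (the function is illustrative):

def get_weather(city: str, days: int = 1):
    """Query the weather.

    Args:
        city: target city name
        days: forecast length in days
    """
    ...

WeatherParams = _create_dynamic_model(get_weather)
# Roughly equivalent to:
# class get_weatherParams(BaseModel):
#     city: str = Field(..., description="target city name")
#     days: int = Field(1, description="forecast length in days")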
+class FunctionExecutable(ToolExecutable):
+    """A ToolExecutable implementation that wraps a plain Python function."""
+
+    def __init__(
+        self,
+        func: Callable,
+        name: str,
+        description: str,
+        params_model: type[BaseModel] | None = None,
+        unpack_args: bool = False,
+    ):
+        self._func = func
+        self._name = name
+        self._description = description
+        self._params_model = params_model
+        self._unpack_args = unpack_args
+
+        self.dependent = Dependent[Any].parse(
+            call=func,
+            allow_types=(
+                DependParam,
+                BotParam,
+                EventParam,
+                StateParam,
+                RunContextParam,
+                ToolParam,
+                DefaultParam,
+            ),
+        )
+
+    async def get_definition(self) -> ToolDefinition:
+        if not self._params_model:
+            return ToolDefinition(
+                name=self._name,
+                description=self._description,
+                parameters={"type": "object", "properties": {}},
+            )
+
+        schema = model_json_schema(self._params_model)
+
+        return ToolDefinition(
+            name=self._name,
+            description=self._description,
+            parameters={
+                "type": "object",
+                "properties": schema.get("properties", {}),
+                "required": schema.get("required", []),
+            },
+        )
+
+    async def execute(
+        self, context: RunContext | None = None, **kwargs: Any
+    ) -> ToolResult:
+        context = context or RunContext()
+
+        tool_arguments = kwargs
+
+        if self._params_model:
+            try:
+                _fields = model_fields(self._params_model)
+                validation_input = {
+                    key: value for key, value in kwargs.items() if key in _fields
+                }
+
+                validated_params = self._params_model(**validation_input)
+
+                if not self._unpack_args:
+                    pass
+                else:
+                    validated_dict = model_dump(validated_params)
+                    tool_arguments = validated_dict
+
+            except ValidationError as e:
+                error_msgs = []
+                for err in e.errors():
+                    loc = ".".join(str(x) for x in err["loc"])
+                    msg = err["msg"]
+                    error_msgs.append(f"Parameter '{loc}': {msg}")
+
+                formatted_error = "; ".join(error_msgs)
+                error_payload = {
+                    "error_type": "InvalidArguments",
+                    "message": f"Parameter validation failed: {formatted_error}",
+                    "is_retryable": True,
+                }
+                return ToolResult(
+                    output=json.dumps(error_payload, ensure_ascii=False),
+                    display_content=f"Validation Error: {formatted_error}",
+                )
+            except Exception as e:
+                logger.error(
+                    f"Parameter validation or instantiation failed while executing tool '{self._name}': {e}", e=e
+                )
+                raise
+
+        state = {
+            "_tool_params": tool_arguments,
+            "_agent_context": context,
+        }
+
+        bot: Bot | None = None
+        if context and context.scope.get("bot"):
+            bot = context.scope.get("bot")
+        if not bot:
+            try:
+                bot = nonebot.get_bot()
+            except ValueError:
+                pass
+
+        event: Event | None = None
+        if context and context.scope.get("event"):
+            event = context.scope.get("event")
+
+        raw_result = await self.dependent(
+            bot=bot,
+            event=event,
+            state=state,
+        )
+
+        return ToolResult(output=raw_result, display_content=str(raw_result))
+
+
+class BuiltinFunctionToolProvider(ToolProvider):
+    """A built-in ToolProvider that handles functions registered via the decorator."""
+
+    def __init__(self):
+        self._functions: dict[str, dict[str, Any]] = {}
+
+    def register(
+        self,
+        name: str,
+        func: Callable,
+        description: str,
+        params_model: type[BaseModel] | None = None,
+        unpack_args: bool = False,
+    ):
+        self._functions[name] = {
+            "func": func,
+            "description": description,
+            "params_model": params_model,
+            "unpack_args": unpack_args,
+        }
+
+    async def initialize(self) -> None:
+        pass
+
+    async def discover_tools(
+        self,
+        allowed_servers: list[str] | None = None,
+        excluded_servers: list[str] | None = None,
+    ) -> dict[str, ToolExecutable]:
+        executables = {}
+        for name, info in self._functions.items():
+            executables[name] = FunctionExecutable(
+                func=info["func"],
+                name=name,
+                description=info["description"],
+                params_model=info["params_model"],
+                unpack_args=info.get("unpack_args", False),
+            )
+        return executables
+
+    async def get_tool_executable(
+        self, name: str, config: dict[str, Any]
+    ) -> ToolExecutable | None:
+        if config.get("type", "function") == "function" and name in self._functions:
+            info = self._functions[name]
+            return FunctionExecutable(
+                func=info["func"],
+                name=name,
+                description=info["description"],
+                params_model=info["params_model"],
+                unpack_args=info.get("unpack_args", False),
+            )
+        return None
+
+
+class ToolProviderManager:
+    """Central manager for tool providers, implemented as a singleton."""
+
+    _instance: "ToolProviderManager | None" = None
+
+    def __new__(cls) -> "ToolProviderManager":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        if hasattr(self, "_initialized") and self._initialized:
+            return
+
+        self._providers: list[ToolProvider] = []
+        self._resolved_tools: dict[str, ToolExecutable] | None = None
+        self._init_lock = asyncio.Lock()
+        self._init_promise: asyncio.Task | None = None
+        self._builtin_function_provider = BuiltinFunctionToolProvider()
+        self.register(self._builtin_function_provider)
+        self._initialized = True
+
+    def register(self, provider: ToolProvider):
+        """Register a new ToolProvider."""
+        if provider not in self._providers:
+            self._providers.append(provider)
+            logger.info(f"Registered tool provider: {provider.__class__.__name__}")
+
+    def function_tool(
+        self,
+        name: str,
+        description: str,
+        params_model: type[BaseModel] | None = None,
+    ):
+        """Decorator: register a function as a built-in tool."""
+
+        def decorator(func: Callable):
+            if name in self._builtin_function_provider._functions:
+                logger.warning(f"Overriding an already registered function tool: {name}")
+
+            final_model = params_model
+            unpack_args = False
+            if final_model is None:
+                final_model = _create_dynamic_model(func)
+                unpack_args = True
+
+            self._builtin_function_provider.register(
+                name=name,
+                func=func,
+                description=description,
+                params_model=final_model,
+                unpack_args=unpack_args,
+            )
+            logger.info(f"Registered function tool: '{name}'")
+            return func
+
+        return decorator
+
+    async def initialize(self) -> None:
+        """Lazily initialize all registered ToolProviders."""
+        if not self._init_promise:
+            async with self._init_lock:
+                if not self._init_promise:
+                    self._init_promise = asyncio.create_task(
+                        self._initialize_providers()
+                    )
+        await self._init_promise
+
+    async def _initialize_providers(self) -> None:
+        """Internal initialization logic."""
+        logger.info(f"Initializing {len(self._providers)} tool providers...")
+        init_tasks = [provider.initialize() for provider in self._providers]
+        await asyncio.gather(*init_tasks, return_exceptions=True)
+        logger.info("All tool providers initialized.")
+
+    async def get_resolved_tools(
+        self,
+        allowed_servers: list[str] | None = None,
+        excluded_servers: list[str] | None = None,
+    ) -> dict[str, ToolExecutable]:
+        """
+        Get every discovered and resolved tool.
+        Triggers lazy initialization; the global cache is used only when no filters are given.
+        """
+        await self.initialize()
+
+        has_filters = allowed_servers is not None or excluded_servers is not None
+
+        if not has_filters and self._resolved_tools is not None:
+            logger.debug("Using the global tool cache.")
+            return self._resolved_tools
+
+        if has_filters:
+            logger.info("Filters detected; running ad-hoc tool discovery (cache bypassed).")
+            logger.debug(
+                f"Filter details: allowed_servers={allowed_servers}, "
+                f"excluded_servers={excluded_servers}"
+            )
+        else:
+            logger.info("No filters applied; starting global tool discovery...")
+
+        all_tools: dict[str, ToolExecutable] = {}
+
+        discover_tasks = []
+        for provider in self._providers:
+            sig = inspect.signature(provider.discover_tools)
+            params_to_pass = {}
+            if "allowed_servers" in sig.parameters:
+                params_to_pass["allowed_servers"] = allowed_servers
+            if "excluded_servers" in sig.parameters:
+                params_to_pass["excluded_servers"] = excluded_servers
+
+            discover_tasks.append(provider.discover_tools(**params_to_pass))
+
+        results = await asyncio.gather(*discover_tasks, return_exceptions=True)
+
+        for i, provider_result in enumerate(results):
+            provider_name = self._providers[i].__class__.__name__
+            if isinstance(provider_result, dict):
+                logger.debug(
+                    f"Provider '{provider_name}' discovered {len(provider_result)} tools."
+                )
+                for name, executable in provider_result.items():
+                    if name in all_tools:
+                        logger.warning(
+                            f"Duplicate tool name '{name}'; the later discovery overrides the earlier one."
+                        )
+                    all_tools[name] = executable
+            elif isinstance(provider_result, Exception):
+                logger.error(
+                    f"Provider '{provider_name}' raised during tool discovery: {provider_result}"
+                )
+
+        if not has_filters:
+            self._resolved_tools = all_tools
+            logger.info(f"Global tool discovery finished; found and cached {len(all_tools)} tools.")
+        else:
+            logger.info(f"Filtered tool discovery finished; found {len(all_tools)} tools.")
+
+        return all_tools
+
+    async def resolve_specific_tools(
+        self, tool_names: list[str]
+    ) -> dict[str, ToolExecutable]:
+        """
+        Resolve only the named tools, avoiding a full discovery pass.
+        """
+        resolved: dict[str, ToolExecutable] = {}
+        if not tool_names:
+            return resolved
+
+        await self.initialize()
+
+        for name in tool_names:
+            config: dict[str, Any] = {"name": name}
+            for provider in self._providers:
+                try:
+                    executable = await provider.get_tool_executable(name, config)
+                except Exception as exc:
+                    logger.error(
+                        f"Provider '{provider.__class__.__name__}' raised while resolving tool '{name}': {exc}",
+                        e=exc,
+                    )
+                    continue
+
+                if executable:
+                    resolved[name] = executable
+                    break
+            else:
+                logger.warning(f"No tool named '{name}' was found; skipping it.")
+
+        return resolved
+
+    async def get_function_tools(
+        self, names: list[str] | None = None
+    ) -> dict[str, ToolExecutable]:
+        """
+        Resolve the named tools from the built-in function provider only.
+        """
+        all_function_tools = await self._builtin_function_provider.discover_tools()
+        if names is None:
+            return all_function_tools
+
+        resolved_tools = {}
+        for name in names:
+            if name in all_function_tools:
+                resolved_tools[name] = all_function_tools[name]
+            else:
+                logger.warning(
+                    f"Local function tool '{name}' was not registered via @function_tool and will be ignored."
+                )
+        return resolved_tools
+
+
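Putting the pieces above together, registering a tool is a single decorator: the dynamic params model and the docstring parser supply the schema automatically when no `params_model` is given. A sketch (the tool itself is illustrative):

@function_tool("get_weather", "Query the weather for a city")
async def get_weather(city: str, days: int = 1) -> str:
    """Query the weather.

    Args:
        city: target city name
        days: forecast length in days
    """
    return f"{city}: sunny for the next {days} day(s)"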
+tool_provider_manager = ToolProviderManager()
+function_tool = tool_provider_manager.function_tool
+
+
+class ToolErrorType(str, Enum):
+    """Enum of structured tool-error types."""
+
+    TOOL_NOT_FOUND = "ToolNotFound"
+    INVALID_ARGUMENTS = "InvalidArguments"
+    EXECUTION_ERROR = "ExecutionError"
+    USER_CANCELLATION = "UserCancellation"
+
+
+class ToolErrorResult(BaseModel):
+    """A structured model for tool-execution errors."""
+
+    error_type: ToolErrorType = Field(..., description="The error type.")
+    message: str = Field(..., description="A detailed description of the error.")
+    is_retryable: bool = Field(False, description="Whether a retry is likely to resolve the error.")
+
+
+class ToolInvoker:
+    """
+    All-in-one tool executor.
+    Receives tool-call requests, parses arguments, fires callbacks, runs the tool
+    and returns a normalized result.
+    """
+
+    def __init__(self, callbacks: list[BaseCallbackHandler] | None = None):
+        self.callbacks = callbacks or []
+
+    async def _trigger_callbacks(self, event_name: str, *args, **kwargs: Any) -> None:
+        if not self.callbacks:
+            return
+        tasks = [
+            getattr(handler, event_name)(*args, **kwargs)
+            for handler in self.callbacks
+            if hasattr(handler, event_name)
+        ]
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+    async def execute_tool_call(
+        self,
+        tool_call: LLMToolCall,
+        available_tools: dict[str, ToolExecutable],
+        context: Any | None = None,
+    ) -> tuple[LLMToolCall, ToolResult]:
+        tool_name = tool_call.function.name
+        arguments_str = tool_call.function.arguments
+        arguments: dict[str, Any] = {}
+
+        try:
+            if arguments_str:
+                arguments = json.loads(arguments_str)
+        except json.JSONDecodeError as e:
+            error_result = ToolErrorResult(
+                error_type=ToolErrorType.INVALID_ARGUMENTS,
+                message=f"Failed to parse arguments: {e}",
+                is_retryable=False,
+            )
+            return tool_call, ToolResult(output=model_dump(error_result))
+
+        tool_data = ToolCallData(tool_name=tool_name, tool_args=arguments)
+        pre_calculated_result: ToolResult | None = None
+        for handler in self.callbacks:
+            res = await handler.on_tool_start(tool_call, tool_data)
+            if isinstance(res, ToolCallData):
+                tool_data = res
+                arguments = tool_data.tool_args
+                tool_call.function.arguments = json.dumps(arguments, ensure_ascii=False)
+            elif isinstance(res, ToolResult):
+                pre_calculated_result = res
+                break
+
+        if pre_calculated_result:
+            return tool_call, pre_calculated_result
+
+        executable = available_tools.get(tool_name)
+        if not executable:
+            error_result = ToolErrorResult(
+                error_type=ToolErrorType.TOOL_NOT_FOUND,
+                message=f"Tool '{tool_name}' not found.",
+                is_retryable=False,
+            )
+            return tool_call, ToolResult(output=model_dump(error_result))
+
+        from .config.providers import get_llm_config
+
+        if not get_llm_config().debug_log:
+            try:
+                definition = await executable.get_definition()
+                schema_payload = getattr(definition, "parameters", {})
+                schema_json = fast_json.dumps(
+                    schema_payload,
+                    ensure_ascii=False,
+                )
+                logger.debug(
+                    f"🔍 [JIT Schema] {tool_name}: {schema_json}",
+                    "ToolInvoker",
+                )
+            except Exception as e:
+                logger.trace(f"JIT Schema logging failed: {e}")
+
+        start_t = time.monotonic()
+        result: ToolResult | None = None
+        error: Exception | None = None
+
+        try:
+
+            @Retry.simple(stop_max_attempt=2, wait_fixed_seconds=1)
+            async def execute_with_retry():
+                return await executable.execute(context=context, **arguments)
+
+            result = await execute_with_retry()
+        except ValidationError as e:
+            error = e
+            error_msgs = []
+            for err in e.errors():
+                loc = ".".join(str(x) for x in err["loc"])
+                msg = err["msg"]
+                error_msgs.append(f"Parameter '{loc}': {msg}")
+
+            formatted_error = "; ".join(error_msgs)
+            error_result = ToolErrorResult(
+                error_type=ToolErrorType.INVALID_ARGUMENTS,
+                message=f"Parameter validation failed. Fix your input based on the errors: {formatted_error}",
+                is_retryable=True,
+            )
+            result = ToolResult(output=model_dump(error_result))
+        except (TimeoutException, NetworkError) as e:
+            error = e
+            error_result = ToolErrorResult(
+                error_type=ToolErrorType.EXECUTION_ERROR,
+                message=f"Tool execution timed out or the network connection failed: {e!s}",
+                is_retryable=False,
+            )
+            result = ToolResult(output=model_dump(error_result))
+        except Exception as e:
+            error = e
+            error_type = ToolErrorType.EXECUTION_ERROR
+            if (
+                isinstance(e, LLMException)
+                and e.code == LLMErrorCode.CONFIGURATION_ERROR
+            ):
+                error_type = ToolErrorType.TOOL_NOT_FOUND
+                is_retryable = False
+
+            is_retryable = False
+
+            error_result = ToolErrorResult(
+                error_type=error_type, message=str(e), is_retryable=is_retryable
+            )
+            result = ToolResult(output=model_dump(error_result))
+
+        duration = time.monotonic() - start_t
+
+        await self._trigger_callbacks(
+            "on_tool_end",
+            result=result,
+            error=error,
+            tool_call=tool_call,
+            duration=duration,
+        )
+
+        if result is None:
+            raise LLMException("Tool execution returned no result.")
+
+        return tool_call, result
+
+    async def execute_batch(
+        self,
+        tool_calls: list[LLMToolCall],
+        available_tools: dict[str, ToolExecutable],
+        context: Any | None = None,
+    ) -> list[LLMMessage]:
+        if not tool_calls:
+            return []
+
+        tasks = [
+            self.execute_tool_call(call, available_tools, context)
+            for call in tool_calls
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        tool_messages: list[LLMMessage] = []
+        for index, result_pair in enumerate(results):
+            original_call = tool_calls[index]
+
+            if isinstance(result_pair, Exception):
+                logger.error(
+                    f"Uncaught exception during tool execution: {original_call.function.name}, "
+                    f"error: {result_pair}"
+                )
+                tool_messages.append(
+                    LLMMessage.tool_response(
+                        tool_call_id=original_call.id,
+                        function_name=original_call.function.name,
+                        result={
+                            "error": f"System Execution Error: {result_pair}",
+                            "status": "failed",
+                        },
+                    )
+                )
+                continue
+
+            tool_call_result = cast(tuple[LLMToolCall, ToolResult], result_pair)
+            _, tool_result = tool_call_result
+            tool_messages.append(
+                LLMMessage.tool_response(
+                    tool_call_id=original_call.id,
+                    function_name=original_call.function.name,
+                    result=tool_result.output,
+                )
+            )
+        return tool_messages
+
+
+__all__ = [
+    "RunContext",
+    "RunContextParam",
+    "ToolErrorResult",
+    "ToolErrorType",
+    "ToolInvoker",
+    "ToolParam",
+    "function_tool",
+    "tool_provider_manager",
+]
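A sketch of driving one tool round with the ToolInvoker above: `response` is an LLMResponse produced by `chat()`, and the resulting tool messages would be appended to the conversation for the next model turn:

async def run_tool_round(ai, response):
    invoker = ToolInvoker()
    available = await tool_provider_manager.get_resolved_tools()
    tool_messages = await invoker.execute_batch(response.tool_calls, available)
    return tool_messages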
@@ -1,13 +0,0 @@
-"""
-Tool module exports.
-"""
-
-from .manager import tool_provider_manager
-
-function_tool = tool_provider_manager.function_tool
-
-
-__all__ = [
-    "function_tool",
-    "tool_provider_manager",
-]
@@ -1,293 +0,0 @@
-"""
-Tool provider manager.
-
-Responsible for registration, lifecycle management (including lazy loading)
-and unified provisioning of all tools.
-"""
-
-import asyncio
-from collections.abc import Callable
-import inspect
-from typing import Any
-
-from pydantic import BaseModel
-
-from zhenxun.services.log import logger
-from zhenxun.utils.pydantic_compat import model_json_schema
-
-from ..types import ToolExecutable, ToolProvider
-from ..types.models import ToolDefinition, ToolResult
-
-
-class FunctionExecutable(ToolExecutable):
-    """A ToolExecutable implementation that wraps a plain Python function."""
-
-    def __init__(
-        self,
-        func: Callable,
-        name: str,
-        description: str,
-        params_model: type[BaseModel] | None,
-    ):
-        self._func = func
-        self._name = name
-        self._description = description
-        self._params_model = params_model
-
-    async def get_definition(self) -> ToolDefinition:
-        if not self._params_model:
-            return ToolDefinition(
-                name=self._name,
-                description=self._description,
-                parameters={"type": "object", "properties": {}},
-            )
-
-        schema = model_json_schema(self._params_model)
-
-        return ToolDefinition(
-            name=self._name,
-            description=self._description,
-            parameters={
-                "type": "object",
-                "properties": schema.get("properties", {}),
-                "required": schema.get("required", []),
-            },
-        )
-
-    async def execute(self, **kwargs: Any) -> ToolResult:
-        raw_result: Any
-
-        if self._params_model:
-            try:
-                params_instance = self._params_model(**kwargs)
-
-                if inspect.iscoroutinefunction(self._func):
-                    raw_result = await self._func(params_instance)
-                else:
-                    loop = asyncio.get_event_loop()
-                    raw_result = await loop.run_in_executor(
-                        None, lambda: self._func(params_instance)
-                    )
-            except Exception as e:
-                logger.error(
-                    f"Parameter validation or instantiation failed while executing tool '{self._name}': {e}", e=e
-                )
-                raise
-        else:
-            if inspect.iscoroutinefunction(self._func):
-                raw_result = await self._func(**kwargs)
-            else:
-                loop = asyncio.get_event_loop()
-                raw_result = await loop.run_in_executor(
-                    None, lambda: self._func(**kwargs)
-                )
-
-        return ToolResult(output=raw_result, display_content=str(raw_result))
-
-
-class BuiltinFunctionToolProvider(ToolProvider):
-    """A built-in ToolProvider that handles functions registered via the decorator."""
-
-    def __init__(self):
-        self._functions: dict[str, dict[str, Any]] = {}
-
-    def register(
-        self,
-        name: str,
-        func: Callable,
-        description: str,
-        params_model: type[BaseModel] | None,
-    ):
-        self._functions[name] = {
-            "func": func,
-            "description": description,
-            "params_model": params_model,
-        }
-
-    async def initialize(self) -> None:
-        pass
-
-    async def discover_tools(
-        self,
-        allowed_servers: list[str] | None = None,
-        excluded_servers: list[str] | None = None,
-    ) -> dict[str, ToolExecutable]:
-        executables = {}
-        for name, info in self._functions.items():
-            executables[name] = FunctionExecutable(
-                func=info["func"],
-                name=name,
-                description=info["description"],
-                params_model=info["params_model"],
-            )
-        return executables
-
-    async def get_tool_executable(
-        self, name: str, config: dict[str, Any]
-    ) -> ToolExecutable | None:
-        if config.get("type") == "function" and name in self._functions:
-            info = self._functions[name]
-            return FunctionExecutable(
-                func=info["func"],
-                name=name,
-                description=info["description"],
-                params_model=info["params_model"],
-            )
-        return None
-
-
-class ToolProviderManager:
-    """Central manager for tool providers, implemented as a singleton."""
-
-    _instance: "ToolProviderManager | None" = None
-
-    def __new__(cls) -> "ToolProviderManager":
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def __init__(self):
-        if hasattr(self, "_initialized") and self._initialized:
-            return
-
-        self._providers: list[ToolProvider] = []
-        self._resolved_tools: dict[str, ToolExecutable] | None = None
-        self._init_lock = asyncio.Lock()
-        self._init_promise: asyncio.Task | None = None
-        self._builtin_function_provider = BuiltinFunctionToolProvider()
-        self.register(self._builtin_function_provider)
-        self._initialized = True
-
-    def register(self, provider: ToolProvider):
-        """Register a new ToolProvider."""
-        if provider not in self._providers:
-            self._providers.append(provider)
-            logger.info(f"Registered tool provider: {provider.__class__.__name__}")
-
-    def function_tool(
-        self,
-        name: str,
-        description: str,
-        params_model: type[BaseModel] | None = None,
-    ):
-        """Decorator: register a function as a built-in tool."""
-
-        def decorator(func: Callable):
-            if name in self._builtin_function_provider._functions:
-                logger.warning(f"Overriding an already registered function tool: {name}")
-
-            self._builtin_function_provider.register(
-                name=name,
-                func=func,
-                description=description,
-                params_model=params_model,
-            )
-            logger.info(f"Registered function tool: '{name}'")
-            return func
-
-        return decorator
-
-    async def initialize(self) -> None:
-        """Lazily initialize all registered ToolProviders."""
-        if not self._init_promise:
-            async with self._init_lock:
-                if not self._init_promise:
-                    self._init_promise = asyncio.create_task(
-                        self._initialize_providers()
-                    )
-        await self._init_promise
-
-    async def _initialize_providers(self) -> None:
-        """Internal initialization logic."""
-        logger.info(f"Initializing {len(self._providers)} tool providers...")
-        init_tasks = [provider.initialize() for provider in self._providers]
-        await asyncio.gather(*init_tasks, return_exceptions=True)
-        logger.info("All tool providers initialized.")
-
-    async def get_resolved_tools(
-        self,
-        allowed_servers: list[str] | None = None,
-        excluded_servers: list[str] | None = None,
-    ) -> dict[str, ToolExecutable]:
-        """
-        Get every discovered and resolved tool.
-        Triggers lazy initialization; the global cache is used only when no filters are given.
-        """
-        await self.initialize()
-
-        has_filters = allowed_servers is not None or excluded_servers is not None
-
-        if not has_filters and self._resolved_tools is not None:
-            logger.debug("Using the global tool cache.")
-            return self._resolved_tools
-
-        if has_filters:
-            logger.info("Filters detected; running ad-hoc tool discovery (cache bypassed).")
-            logger.debug(
-                f"Filter details: allowed_servers={allowed_servers}, "
-                f"excluded_servers={excluded_servers}"
-            )
-        else:
-            logger.info("No filters applied; starting global tool discovery...")
-
-        all_tools: dict[str, ToolExecutable] = {}
-
-        discover_tasks = []
-        for provider in self._providers:
-            sig = inspect.signature(provider.discover_tools)
-            params_to_pass = {}
-            if "allowed_servers" in sig.parameters:
-                params_to_pass["allowed_servers"] = allowed_servers
-            if "excluded_servers" in sig.parameters:
-                params_to_pass["excluded_servers"] = excluded_servers
-
-            discover_tasks.append(provider.discover_tools(**params_to_pass))
|
|
||||||
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
for i, provider_result in enumerate(results):
|
|
||||||
provider_name = self._providers[i].__class__.__name__
|
|
||||||
if isinstance(provider_result, dict):
|
|
||||||
logger.debug(
|
|
||||||
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
|
|
||||||
)
|
|
||||||
for name, executable in provider_result.items():
|
|
||||||
if name in all_tools:
|
|
||||||
logger.warning(
|
|
||||||
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
|
|
||||||
)
|
|
||||||
all_tools[name] = executable
|
|
||||||
elif isinstance(provider_result, Exception):
|
|
||||||
logger.error(
|
|
||||||
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not has_filters:
|
|
||||||
self._resolved_tools = all_tools
|
|
||||||
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
|
|
||||||
else:
|
|
||||||
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
|
|
||||||
|
|
||||||
return all_tools
|
|
||||||
|
|
||||||
async def get_function_tools(
|
|
||||||
self, names: list[str] | None = None
|
|
||||||
) -> dict[str, ToolExecutable]:
|
|
||||||
"""
|
|
||||||
仅从内置的函数提供者中解析指定的工具。
|
|
||||||
"""
|
|
||||||
all_function_tools = await self._builtin_function_provider.discover_tools()
|
|
||||||
if names is None:
|
|
||||||
return all_function_tools
|
|
||||||
|
|
||||||
resolved_tools = {}
|
|
||||||
for name in names:
|
|
||||||
if name in all_function_tools:
|
|
||||||
resolved_tools[name] = all_function_tools[name]
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
|
|
||||||
)
|
|
||||||
return resolved_tools
|
|
||||||
|
|
||||||
|
|
||||||
tool_provider_manager = ToolProviderManager()
|
|
||||||
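A minimal usage sketch of the manager above, not part of the commit itself: the tool name, params model, and city field are illustrative. The flow (decorator registration, resolution via get_function_tools, execution through ToolExecutable.execute) follows the code as written.

from pydantic import BaseModel, Field

class WeatherParams(BaseModel):
    city: str = Field(..., description="city name")

@tool_provider_manager.function_tool(
    name="get_weather",
    description="Look up the weather for a city",
    params_model=WeatherParams,
)
async def get_weather(params: WeatherParams) -> str:
    # A real implementation would call an HTTP API here.
    return f"{params.city}: sunny"

async def demo() -> None:
    # Resolves only builtin function tools, bypassing other providers.
    tools = await tool_provider_manager.get_function_tools(["get_weather"])
    executable = tools["get_weather"]
    # kwargs are validated into WeatherParams before the coroutine runs.
    result = await executable.execute(city="Hangzhou")
    print(result.display_content)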
@@ -5,30 +5,32 @@ LLM 类型定义模块
 """
 
 from .capabilities import ModelCapabilities, ModelModality, get_model_capabilities
-from .content import (
-    LLMContentPart,
-    LLMMessage,
-    LLMResponse,
-)
-from .enums import (
-    EmbeddingTaskType,
-    ModelProvider,
-    ResponseFormat,
-    TaskType,
-    ToolCategory,
-)
 from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
 from .models import (
+    CodeExecutionOutcome,
+    EmbeddingTaskType,
+    GeminiCodeExecution,
+    GeminiGoogleSearch,
+    GeminiUrlContext,
     LLMCacheInfo,
     LLMCodeExecution,
+    LLMContentPart,
     LLMGroundingAttribution,
     LLMGroundingMetadata,
+    LLMMessage,
+    LLMResponse,
     LLMToolCall,
     LLMToolFunction,
     ModelDetail,
     ModelInfo,
     ModelName,
+    ModelProvider,
     ProviderConfig,
+    ResponseFormat,
+    StructuredOutputStrategy,
+    TaskType,
+    ToolCategory,
+    ToolChoice,
     ToolMetadata,
     ToolResult,
     UsageInfo,
@@ -36,7 +38,11 @@ from .models import (
 from .protocols import ToolExecutable, ToolProvider
 
 __all__ = [
+    "CodeExecutionOutcome",
     "EmbeddingTaskType",
+    "GeminiCodeExecution",
+    "GeminiGoogleSearch",
+    "GeminiUrlContext",
     "LLMCacheInfo",
     "LLMCodeExecution",
     "LLMContentPart",
@@ -56,8 +62,10 @@ __all__ = [
     "ModelProvider",
     "ProviderConfig",
     "ResponseFormat",
+    "StructuredOutputStrategy",
     "TaskType",
     "ToolCategory",
+    "ToolChoice",
     "ToolExecutable",
     "ToolMetadata",
     "ToolProvider",
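After this hunk the re-exported surface is flat: everything that used to come from .content and .enums now resolves through .models. A sketch of downstream usage; the absolute package path is assumed from the project layout and is not shown in this diff.

# Hypothetical import path; only the package-relative files appear above.
from zhenxun.services.llm.types import StructuredOutputStrategy, ToolChoice

choice = ToolChoice(mode="required", allowed_function_names=["get_weather"])
strategy = StructuredOutputStrategy.TOOL_CALL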
@@ -6,6 +6,7 @@ LLM 模型能力定义模块
 
 from enum import Enum
 import fnmatch
+from typing import Literal
 
 from pydantic import BaseModel, Field
 
@@ -20,6 +21,35 @@ class ModelModality(str, Enum):
     EMBEDDING = "embedding"
 
 
+class ReasoningMode(str, Enum):
+    """推理/思考模式类型"""
+
+    NONE = "none"
+    BUDGET = "budget"
+    LEVEL = "level"
+    EFFORT = "effort"
+
+
+PATTERNS_GEMINI_2_5 = [
+    "gemini-2.5*",
+    "gemini-flash*",
+    "gemini*lite*",
+    "gemini-flash-latest",
+]
+
+PATTERNS_GEMINI_3 = [
+    "gemini-3*",
+    "gemini-exp*",
+]
+
+PATTERNS_OPENAI_REASONING = [
+    "o1-*",
+    "o3-*",
+    "deepseek-r1*",
+    "deepseek-reasoner",
+]
+
+
 class ModelCapabilities(BaseModel):
     """定义一个模型的核心、稳定能力。"""
 
@@ -27,6 +57,8 @@ class ModelCapabilities(BaseModel):
     output_modalities: set[ModelModality] = Field(default={ModelModality.TEXT})
     supports_tool_calling: bool = False
     is_embedding_model: bool = False
+    reasoning_mode: ReasoningMode = ReasoningMode.NONE
+    reasoning_visibility: Literal["visible", "hidden", "none"] = "none"
 
 
 STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
@@ -35,7 +67,7 @@ STANDARD_TEXT_TOOL_CAPABILITIES = ModelCapabilities(
     supports_tool_calling=True,
 )
 
-GEMINI_CAPABILITIES = ModelCapabilities(
+CAP_GEMINI_2_5 = ModelCapabilities(
     input_modalities={
         ModelModality.TEXT,
         ModelModality.IMAGE,
@@ -44,21 +76,44 @@ GEMINI_CAPABILITIES = ModelCapabilities(
     },
     output_modalities={ModelModality.TEXT},
     supports_tool_calling=True,
+    reasoning_mode=ReasoningMode.BUDGET,
+    reasoning_visibility="visible",
 )
 
-GEMINI_IMAGE_GEN_CAPABILITIES = ModelCapabilities(
+CAP_GEMINI_3 = ModelCapabilities(
+    input_modalities={
+        ModelModality.TEXT,
+        ModelModality.IMAGE,
+        ModelModality.AUDIO,
+        ModelModality.VIDEO,
+    },
+    output_modalities={ModelModality.TEXT},
+    supports_tool_calling=True,
+    reasoning_mode=ReasoningMode.LEVEL,
+    reasoning_visibility="visible",
+)
+
+CAP_GEMINI_IMAGE_GEN = ModelCapabilities(
     input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
     output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
     supports_tool_calling=True,
 )
 
-GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES = ModelCapabilities(
+CAP_OPENAI_REASONING = ModelCapabilities(
+    input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
+    output_modalities={ModelModality.TEXT},
+    supports_tool_calling=True,
+    reasoning_mode=ReasoningMode.EFFORT,
+    reasoning_visibility="hidden",
+)
+
+CAP_GPT_ADVANCED = ModelCapabilities(
     input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
     output_modalities={ModelModality.TEXT},
     supports_tool_calling=True,
 )
 
-GPT_MULTIMODAL_IO_CAPABILITIES = ModelCapabilities(
+CAP_GPT_MULTIMODAL_IO = ModelCapabilities(
     input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.IMAGE},
     output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
     supports_tool_calling=True,
@@ -76,6 +131,12 @@ GPT_VIDEO_GENERATION_CAPABILITIES = ModelCapabilities(
     supports_tool_calling=True,
 )
 
+EMBEDDING_CAPABILITIES = ModelCapabilities(
+    input_modalities={ModelModality.TEXT},
+    output_modalities={ModelModality.EMBEDDING},
+    is_embedding_model=True,
+)
+
 DEFAULT_PERMISSIVE_CAPABILITIES = ModelCapabilities(
     input_modalities={
         ModelModality.TEXT,
@@ -107,17 +168,33 @@ MODEL_ALIAS_MAPPING: dict[str, str] = {
 }
 
 
-MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
-    "gemini-*-tts": ModelCapabilities(
+def _build_registry() -> dict[str, ModelCapabilities]:
+    """构建模型能力注册表,展开模式列表以减少冗余"""
+    registry: dict[str, ModelCapabilities] = {}
+
+    def register_family(patterns: list[str], cap: ModelCapabilities) -> None:
+        for pattern in patterns:
+            registry[pattern] = cap
+
+    register_family(
+        ["*gemini-*-image-preview*", "gemini-*-image*"], CAP_GEMINI_IMAGE_GEN
+    )
+
+    register_family(PATTERNS_GEMINI_2_5, CAP_GEMINI_2_5)
+    register_family(PATTERNS_GEMINI_3, CAP_GEMINI_3)
+
+    register_family(PATTERNS_OPENAI_REASONING, CAP_OPENAI_REASONING)
+
+    registry["gemini-*-tts"] = ModelCapabilities(
         input_modalities={ModelModality.TEXT},
         output_modalities={ModelModality.AUDIO},
-    ),
-    "gemini-*-native-audio-*": ModelCapabilities(
+    )
+    registry["gemini-*-native-audio-*"] = ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.VIDEO},
        output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
        supports_tool_calling=True,
-    ),
-    "gemini-2.0-flash-preview-image-generation": ModelCapabilities(
+    )
+    registry["gemini-2.0-flash-preview-image-generation"] = ModelCapabilities(
         input_modalities={
             ModelModality.TEXT,
             ModelModality.IMAGE,
@@ -126,39 +203,39 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
         },
         output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
         supports_tool_calling=True,
-    ),
-    "gemini-embedding-exp": ModelCapabilities(
-        input_modalities={ModelModality.TEXT},
-        output_modalities={ModelModality.EMBEDDING},
-        is_embedding_model=True,
-    ),
-    "*gemini-*-image-preview*": GEMINI_IMAGE_GEN_CAPABILITIES,
-    "gemini-*-pro*": GEMINI_CAPABILITIES,
-    "gemini-*-flash*": GEMINI_CAPABILITIES,
-    "GLM-4V-Flash": ModelCapabilities(
+    )
+
+    registry["GLM-4V-Flash"] = ModelCapabilities(
         input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
         output_modalities={ModelModality.TEXT},
         supports_tool_calling=True,
-    ),
-    "GLM-4V-Plus*": ModelCapabilities(
+    )
+    registry["GLM-4V-Plus*"] = ModelCapabilities(
        input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
        output_modalities={ModelModality.TEXT},
        supports_tool_calling=True,
-    ),
-    "glm-4-*": STANDARD_TEXT_TOOL_CAPABILITIES,
-    "glm-z1-*": STANDARD_TEXT_TOOL_CAPABILITIES,
-    "doubao-seed-*": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
-    "doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
-    "deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
-    "deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
-    "gpt-5*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
-    "gpt-4.1*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
-    "gpt-4o*": GPT_MULTIMODAL_IO_CAPABILITIES,
-    "o3*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
-    "o4-mini*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
-    "gpt image*": GPT_IMAGE_GENERATION_CAPABILITIES,
-    "sora*": GPT_VIDEO_GENERATION_CAPABILITIES,
-}
+    )
+
+    register_family(
+        ["glm-4-*", "glm-z1-*", "deepseek-chat"], STANDARD_TEXT_TOOL_CAPABILITIES
+    )
+    register_family(
+        ["doubao-seed-*", "doubao-1-5-thinking-vision-pro"],
+        DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
+    )
+
+    register_family(["gpt-5*", "gpt-4.1*", "o4-mini*"], CAP_GPT_ADVANCED)
+    registry["gpt-4o*"] = CAP_GPT_MULTIMODAL_IO
+
+    registry["gpt image*"] = GPT_IMAGE_GENERATION_CAPABILITIES
+    registry["sora*"] = GPT_VIDEO_GENERATION_CAPABILITIES
+
+    registry["*embedding*"] = EMBEDDING_CAPABILITIES
+
+    return registry
+
+
+MODEL_CAPABILITIES_REGISTRY = _build_registry()
 
 
 def get_model_capabilities(model_name: str) -> ModelCapabilities:
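The body of get_model_capabilities falls outside this hunk. Assuming it resolves aliases through MODEL_ALIAS_MAPPING and then glob-matches against the registry (the module imports fnmatch, and every key above is a glob pattern), a lookup sketch might look like:

# Sketch only; the real lookup logic is not shown in this diff.
def _lookup(model_name: str) -> ModelCapabilities:
    name = MODEL_ALIAS_MAPPING.get(model_name, model_name)
    for pattern, cap in MODEL_CAPABILITIES_REGISTRY.items():
        if fnmatch.fnmatch(name, pattern):
            return cap
    return DEFAULT_PERMISSIVE_CAPABILITIES

# With the families registered above, an "o1-preview" model would resolve
# to CAP_OPENAI_REASONING (reasoning_mode=EFFORT, hidden visibility).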
@@ -1,434 +0,0 @@
"""
LLM 内容类型定义

包含多模态内容部分、消息和响应的数据模型。
"""

import base64
import mimetypes
from pathlib import Path
from typing import Any

import aiofiles
from pydantic import BaseModel

from zhenxun.services.log import logger


class LLMContentPart(BaseModel):
    """LLM 消息内容部分 - 支持多模态内容"""

    type: str
    text: str | None = None
    image_source: str | None = None
    audio_source: str | None = None
    video_source: str | None = None
    document_source: str | None = None
    file_uri: str | None = None
    file_source: str | None = None
    url: str | None = None
    mime_type: str | None = None
    metadata: dict[str, Any] | None = None

    def model_post_init(self, /, __context: Any) -> None:
        """验证内容部分的有效性"""
        _ = __context
        validation_rules = {
            "text": lambda: self.text,
            "image": lambda: self.image_source,
            "audio": lambda: self.audio_source,
            "video": lambda: self.video_source,
            "document": lambda: self.document_source,
            "file": lambda: self.file_uri or self.file_source,
            "url": lambda: self.url,
        }

        if self.type in validation_rules:
            if not validation_rules[self.type]():
                raise ValueError(f"{self.type}类型的内容部分必须包含相应字段")

    @classmethod
    def text_part(cls, text: str) -> "LLMContentPart":
        """创建文本内容部分"""
        return cls(type="text", text=text)

    @classmethod
    def image_url_part(cls, url: str) -> "LLMContentPart":
        """创建图片URL内容部分"""
        return cls(type="image", image_source=url)

    @classmethod
    def image_base64_part(
        cls, data: str, mime_type: str = "image/png"
    ) -> "LLMContentPart":
        """创建Base64图片内容部分"""
        data_url = f"data:{mime_type};base64,{data}"
        return cls(type="image", image_source=data_url)

    @classmethod
    def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart":
        """创建音频URL内容部分"""
        return cls(type="audio", audio_source=url, mime_type=mime_type)

    @classmethod
    def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart":
        """创建视频URL内容部分"""
        return cls(type="video", video_source=url, mime_type=mime_type)

    @classmethod
    def video_base64_part(
        cls, data: str, mime_type: str = "video/mp4"
    ) -> "LLMContentPart":
        """创建Base64视频内容部分"""
        data_url = f"data:{mime_type};base64,{data}"
        return cls(type="video", video_source=data_url, mime_type=mime_type)

    @classmethod
    def audio_base64_part(
        cls, data: str, mime_type: str = "audio/wav"
    ) -> "LLMContentPart":
        """创建Base64音频内容部分"""
        data_url = f"data:{mime_type};base64,{data}"
        return cls(type="audio", audio_source=data_url, mime_type=mime_type)

    @classmethod
    def file_uri_part(
        cls,
        file_uri: str,
        mime_type: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> "LLMContentPart":
        """创建Gemini File API URI内容部分"""
        return cls(
            type="file",
            file_uri=file_uri,
            mime_type=mime_type,
            metadata=metadata or {},
        )

    @classmethod
    async def from_path(
        cls, path_like: str | Path, target_api: str | None = None
    ) -> "LLMContentPart | None":
        """
        从本地文件路径创建 LLMContentPart。
        自动检测MIME类型,并根据类型(如图片)可能加载为Base64。
        target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64)
        """
        try:
            path = Path(path_like)
            if not path.exists() or not path.is_file():
                logger.warning(f"文件不存在或不是一个文件: {path}")
                return None

            mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())

            if not mime_type:
                logger.warning(
                    f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。"
                )
                try:
                    async with aiofiles.open(path, encoding="utf-8") as f:
                        text_content = await f.read()
                    return cls.text_part(text_content)
                except Exception as e:
                    logger.error(f"读取文本文件 {path.name} 失败: {e}")
                    return None

            if mime_type.startswith("image/"):
                if target_api == "gemini" or not path.is_absolute():
                    try:
                        async with aiofiles.open(path, "rb") as f:
                            img_bytes = await f.read()
                        base64_data = base64.b64encode(img_bytes).decode("utf-8")
                        return cls.image_base64_part(
                            data=base64_data, mime_type=mime_type
                        )
                    except Exception as e:
                        logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
                        return None
                else:
                    logger.warning(
                        f"为本地图片路径 {path.name} 生成 image_url_part。"
                        "实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
                    )
                    return cls.image_url_part(url=path.resolve().as_uri())
            elif mime_type.startswith("audio/"):
                return cls.audio_url_part(
                    url=path.resolve().as_uri(), mime_type=mime_type
                )
            elif mime_type.startswith("video/"):
                if target_api == "gemini":
                    # 对于 Gemini API,将视频转换为 base64
                    try:
                        async with aiofiles.open(path, "rb") as f:
                            video_bytes = await f.read()
                        base64_data = base64.b64encode(video_bytes).decode("utf-8")
                        return cls.video_base64_part(
                            data=base64_data, mime_type=mime_type
                        )
                    except Exception as e:
                        logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
                        return None
                else:
                    return cls.video_url_part(
                        url=path.resolve().as_uri(), mime_type=mime_type
                    )
            elif (
                mime_type.startswith("text/")
                or mime_type == "application/json"
                or mime_type == "application/xml"
            ):
                try:
                    async with aiofiles.open(path, encoding="utf-8") as f:
                        text_content = await f.read()
                    return cls.text_part(text_content)
                except Exception as e:
                    logger.error(f"读取文本类文件 {path.name} 失败: {e}")
                    return None
            else:
                logger.info(
                    f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
                )
                return cls.file_uri_part(
                    file_uri=path.resolve().as_uri(),
                    mime_type=mime_type,
                    metadata={"name": path.name, "source": "local_path"},
                )

        except Exception as e:
            logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}")
            return None

    def is_image_url(self) -> bool:
        """检查图像源是否为URL"""
        if not self.image_source:
            return False
        return self.image_source.startswith(("http://", "https://"))

    def is_image_base64(self) -> bool:
        """检查图像源是否为Base64 Data URL"""
        if not self.image_source:
            return False
        return self.image_source.startswith("data:")

    def get_base64_data(self) -> tuple[str, str] | None:
        """从Data URL中提取Base64数据和MIME类型"""
        if not self.is_image_base64() or not self.image_source:
            return None

        try:
            header, data = self.image_source.split(",", 1)
            mime_part = header.split(";")[0].replace("data:", "")
            return mime_part, data
        except (ValueError, IndexError):
            logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
            return None

    async def convert_for_api_async(self, api_type: str) -> dict[str, Any]:
        """根据API类型转换多模态内容格式"""
        from zhenxun.utils.http_utils import AsyncHttpx

        if self.type == "text":
            if api_type == "openai":
                return {"type": "text", "text": self.text}
            elif api_type == "gemini":
                return {"text": self.text}
            else:
                return {"type": "text", "text": self.text}

        elif self.type == "image":
            if not self.image_source:
                raise ValueError("图像类型的内容必须包含image_source")

            if api_type == "openai":
                return {"type": "image_url", "image_url": {"url": self.image_source}}
            elif api_type == "gemini":
                if self.is_image_base64():
                    base64_info = self.get_base64_data()
                    if base64_info:
                        mime_type, data = base64_info
                        return {"inlineData": {"mimeType": mime_type, "data": data}}
                    else:
                        raise ValueError(
                            f"无法解析Base64图像数据: {self.image_source[:50]}..."
                        )
                elif self.is_image_url():
                    logger.debug(f"正在为Gemini下载并编码URL图片: {self.image_source}")
                    try:
                        image_bytes = await AsyncHttpx.get_content(self.image_source)
                        mime_type = self.mime_type or "image/jpeg"
                        base64_data = base64.b64encode(image_bytes).decode("utf-8")
                        return {
                            "inlineData": {"mimeType": mime_type, "data": base64_data}
                        }
                    except Exception as e:
                        logger.error(f"下载或编码URL图片失败: {e}", e=e)
                        raise ValueError(f"无法处理图片URL: {e}")
                else:
                    raise ValueError(f"不支持的图像源格式: {self.image_source[:50]}...")
            else:
                return {"type": "image_url", "image_url": {"url": self.image_source}}

        elif self.type == "video":
            if not self.video_source:
                raise ValueError("视频类型的内容必须包含video_source")

            if api_type == "gemini":
                # Gemini 支持视频,但需要通过 File API 上传
                if self.video_source.startswith("data:"):
                    # 处理 base64 视频数据
                    try:
                        header, data = self.video_source.split(",", 1)
                        mime_type = header.split(";")[0].replace("data:", "")
                        return {"inlineData": {"mimeType": mime_type, "data": data}}
                    except (ValueError, IndexError):
                        raise ValueError(
                            f"无法解析Base64视频数据: {self.video_source[:50]}..."
                        )
                else:
                    # 对于 URL 或其他格式,暂时不支持直接内联
                    raise ValueError(
                        "Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
                    )
            else:
                # 其他 API 可能不支持视频
                raise ValueError(f"API类型 '{api_type}' 不支持视频内容")

        elif self.type == "audio":
            if not self.audio_source:
                raise ValueError("音频类型的内容必须包含audio_source")

            if api_type == "gemini":
                # Gemini 支持音频,处理方式类似视频
                if self.audio_source.startswith("data:"):
                    try:
                        header, data = self.audio_source.split(",", 1)
                        mime_type = header.split(";")[0].replace("data:", "")
                        return {"inlineData": {"mimeType": mime_type, "data": data}}
                    except (ValueError, IndexError):
                        raise ValueError(
                            f"无法解析Base64音频数据: {self.audio_source[:50]}..."
                        )
                else:
                    raise ValueError(
                        "Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
                    )
            else:
                raise ValueError(f"API类型 '{api_type}' 不支持音频内容")

        elif self.type == "file":
            if api_type == "gemini" and self.file_uri:
                return {
                    "fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
                }
            elif self.file_source:
                file_name = (
                    self.metadata.get("name", "file") if self.metadata else "file"
                )
                if api_type == "gemini":
                    return {"text": f"[文件: {file_name}]\n{self.file_source}"}
                else:
                    return {
                        "type": "text",
                        "text": f"[文件: {file_name}]\n{self.file_source}",
                    }
            else:
                raise ValueError("文件类型的内容必须包含file_uri或file_source")

        else:
            raise ValueError(f"不支持的内容类型: {self.type}")


class LLMMessage(BaseModel):
    """LLM 消息"""

    role: str
    content: str | list[LLMContentPart]
    name: str | None = None
    tool_calls: list[Any] | None = None
    tool_call_id: str | None = None

    def model_post_init(self, /, __context: Any) -> None:
        """验证消息的有效性"""
        _ = __context
        if self.role == "tool":
            if not self.tool_call_id:
                raise ValueError("工具角色的消息必须包含 tool_call_id")
            if not self.name:
                raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
        if self.role == "tool" and not isinstance(self.content, str):
            logger.warning(
                f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
                "将尝试转换为字符串。"
            )
            try:
                self.content = str(self.content)
            except Exception as e:
                raise ValueError(f"无法将工具角色的内容转换为字符串: {e}")

    @classmethod
    def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
        """创建用户消息"""
        return cls(role="user", content=content)

    @classmethod
    def assistant_tool_calls(
        cls,
        tool_calls: list[Any],
        content: str | list[LLMContentPart] = "",
    ) -> "LLMMessage":
        """创建助手请求工具调用的消息"""
        return cls(role="assistant", content=content, tool_calls=tool_calls)

    @classmethod
    def assistant_text_response(
        cls, content: str | list[LLMContentPart]
    ) -> "LLMMessage":
        """创建助手纯文本回复的消息"""
        return cls(role="assistant", content=content, tool_calls=None)

    @classmethod
    def tool_response(
        cls,
        tool_call_id: str,
        function_name: str,
        result: Any,
    ) -> "LLMMessage":
        """创建工具执行结果的消息"""
        import json

        try:
            content_str = json.dumps(result)
        except TypeError as e:
            logger.error(
                f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
            )
            content_str = json.dumps(
                {"error": "工具结果无法JSON序列化", "details": str(e)}
            )

        return cls(
            role="tool",
            content=content_str,
            tool_call_id=tool_call_id,
            name=function_name,
        )

    @classmethod
    def system(cls, content: str) -> "LLMMessage":
        """创建系统消息"""
        return cls(role="system", content=content)


class LLMResponse(BaseModel):
    """LLM 响应"""

    text: str
    images: list[bytes] | None = None
    usage_info: dict[str, Any] | None = None
    raw_response: dict[str, Any] | None = None
    tool_calls: list[Any] | None = None
    code_executions: list[Any] | None = None
    grounding_metadata: Any | None = None
    cache_info: Any | None = None
@@ -1,78 +0,0 @@
"""
LLM 枚举类型定义
"""

from enum import Enum, auto


class ModelProvider(Enum):
    """模型提供商枚举"""

    OPENAI = "openai"
    GEMINI = "gemini"
    ZHIXPU = "zhipu"
    CUSTOM = "custom"


class ResponseFormat(Enum):
    """响应格式枚举"""

    TEXT = "text"
    JSON = "json"
    MULTIMODAL = "multimodal"


class EmbeddingTaskType(str, Enum):
    """文本嵌入任务类型 (主要用于Gemini)"""

    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
    CLASSIFICATION = "CLASSIFICATION"
    CLUSTERING = "CLUSTERING"
    QUESTION_ANSWERING = "QUESTION_ANSWERING"
    FACT_VERIFICATION = "FACT_VERIFICATION"


class ToolCategory(Enum):
    """工具分类枚举"""

    FILE_SYSTEM = auto()
    NETWORK = auto()
    SYSTEM_INFO = auto()
    CALCULATION = auto()
    DATA_PROCESSING = auto()
    CUSTOM = auto()


class TaskType(Enum):
    """任务类型枚举"""

    CHAT = "chat"
    CODE = "code"
    SEARCH = "search"
    ANALYSIS = "analysis"
    GENERATION = "generation"
    MULTIMODAL = "multimodal"


class LLMErrorCode(Enum):
    """LLM 服务相关的错误代码枚举"""

    MODEL_INIT_FAILED = 2000
    MODEL_NOT_FOUND = 2001
    API_REQUEST_FAILED = 2002
    API_RESPONSE_INVALID = 2003
    API_KEY_INVALID = 2004
    API_QUOTA_EXCEEDED = 2005
    API_TIMEOUT = 2006
    API_RATE_LIMITED = 2007
    NO_AVAILABLE_KEYS = 2008
    UNKNOWN_API_TYPE = 2009
    CONFIGURATION_ERROR = 2010
    RESPONSE_PARSE_ERROR = 2011
    CONTEXT_LENGTH_EXCEEDED = 2012
    CONTENT_FILTERED = 2013
    USER_LOCATION_NOT_SUPPORTED = 2014
    GENERATION_FAILED = 2015
    EMBEDDING_FAILED = 2016
@@ -2,9 +2,31 @@
 LLM 异常类型定义
 """
 
+from enum import Enum
 from typing import Any
 
-from .enums import LLMErrorCode
+
+class LLMErrorCode(Enum):
+    """LLM 服务相关的错误代码枚举"""
+
+    MODEL_INIT_FAILED = 2000
+    MODEL_NOT_FOUND = 2001
+    API_REQUEST_FAILED = 2002
+    API_RESPONSE_INVALID = 2003
+    API_KEY_INVALID = 2004
+    API_QUOTA_EXCEEDED = 2005
+    API_TIMEOUT = 2006
+    API_RATE_LIMITED = 2007
+    NO_AVAILABLE_KEYS = 2008
+    UNKNOWN_API_TYPE = 2009
+    CONFIGURATION_ERROR = 2010
+    RESPONSE_PARSE_ERROR = 2011
+    CONTEXT_LENGTH_EXCEEDED = 2012
+    CONTENT_FILTERED = 2013
+    USER_LOCATION_NOT_SUPPORTED = 2014
+    INVALID_PARAMETER = 2017
+    GENERATION_FAILED = 2015
+    EMBEDDING_FAILED = 2016
 
 
 class LLMException(Exception):
@@ -27,7 +49,11 @@ class LLMException(Exception):
 
     def __str__(self) -> str:
         if self.details:
-            return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})"
+            safe_details = {k: v for k, v in self.details.items() if k != "api_key"}
+            if safe_details:
+                return (
+                    f"{self.message} (错误码: {self.code.name}, 详情: {safe_details})"
+                )
         return f"{self.message} (错误码: {self.code.name})"
 
     @property
@@ -46,10 +72,13 @@ class LLMException(Exception):
             "当前所有API密钥均不可用,请稍后再试或联系管理员。"
         ),
         LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
-            "当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。"
+            "当前网络环境不支持此 AI 模型 (如 Gemini/OpenAI)。\n"
+            "原因: 代理节点所在地区(如香港/国内/非支持区)被服务商屏蔽。\n"
+            "建议: 请尝试更换代理节点至支持的地区(如美国/日本/新加坡)。"
         ),
         LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败,请稍后再试。",
         LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常,请稍后再试。",
+        LLMErrorCode.INVALID_PARAMETER: "请求参数错误,请检查输入内容。",
         LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误,请联系管理员。",
         LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。",
         LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。",
@@ -66,15 +95,19 @@ def get_user_friendly_error_message(error: Exception) -> str:
 
     error_str = str(error).lower()
 
-    if "timeout" in error_str or "超时" in error_str:
-        return "请求超时,请稍后再试。"
-    elif "connection" in error_str or "连接" in error_str:
-        return "网络连接失败,请检查网络后重试。"
-    elif "permission" in error_str or "权限" in error_str:
-        return "权限不足,请联系管理员。"
-    elif "not found" in error_str or "未找到" in error_str:
-        return "请求的资源未找到,请检查配置。"
-    elif "invalid" in error_str or "无效" in error_str:
+    if "timeout" in error_str or "timed out" in error_str:
+        return "网络请求超时,请检查服务器网络或代理连接。"
+    if "connect" in error_str and ("refused" in error_str or "error" in error_str):
+        return "无法连接到 AI 服务商,请检查网络连接或代理设置。"
+    if "proxy" in error_str:
+        return "代理连接失败,请检查代理服务器是否正常运行。"
+    if "ssl" in error_str or "certificate" in error_str:
+        return "SSL 证书验证失败,请检查网络环境。"
+    if "permission" in error_str or "forbidden" in error_str:
+        return "权限不足,可能是 API Key 权限受限。"
+    if "not found" in error_str:
+        return "请求的资源未找到 (404),请检查模型名称或端点配置。"
+    if "invalid" in error_str or "无效" in error_str:
         return "请求参数无效,请检查输入。"
-    else:
-        return "服务暂时不可用,请稍后再试。"
+    return f"服务暂时不可用 ({type(error).__name__}),请稍后再试。"
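A sketch of the two behaviors these hunks change. The LLMException constructor signature is assumed from the attributes that __str__ uses (message, code, details) and is not itself part of this diff.

err = LLMException(
    "请求失败",
    code=LLMErrorCode.API_REQUEST_FAILED,
    details={"api_key": "sk-secret", "status": 502},
)
# The api_key entry is now stripped before formatting:
print(str(err))  # -> 请求失败 (错误码: API_REQUEST_FAILED, 详情: {'status': 502})

# Keyword matching is broader and English-oriented:
print(get_user_friendly_error_message(TimeoutError("request timed out")))
# -> 网络请求超时,请检查服务器网络或代理连接。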
@ -4,12 +4,459 @@ LLM 数据模型定义
|
|||||||
包含模型信息、配置、工具定义和响应数据的模型类。
|
包含模型信息、配置、工具定义和响应数据的模型类。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from enum import Enum, auto
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .enums import ModelProvider, ToolCategory
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from enum import StrEnum
|
||||||
|
else:
|
||||||
|
from strenum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(Enum):
|
||||||
|
"""模型提供商枚举"""
|
||||||
|
|
||||||
|
OPENAI = "openai"
|
||||||
|
GEMINI = "gemini"
|
||||||
|
ZHIXPU = "zhipu"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(Enum):
|
||||||
|
"""响应格式枚举"""
|
||||||
|
|
||||||
|
TEXT = "text"
|
||||||
|
JSON = "json"
|
||||||
|
MULTIMODAL = "multimodal"
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredOutputStrategy(str, Enum):
|
||||||
|
"""结构化输出策略"""
|
||||||
|
|
||||||
|
NATIVE = "native"
|
||||||
|
"""使用原生 API (如 OpenAI json_object/json_schema, Gemini mime_type)"""
|
||||||
|
TOOL_CALL = "tool_call"
|
||||||
|
"""构造虚假工具调用来强制输出结构化数据 (适用于指令跟随弱但工具调用强的模型)"""
|
||||||
|
PROMPT = "prompt"
|
||||||
|
"""仅在 Prompt 中追加 Schema 说明,依赖文本补全"""
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingTaskType(str, Enum):
|
||||||
|
"""文本嵌入任务类型 (主要用于Gemini)"""
|
||||||
|
|
||||||
|
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
|
||||||
|
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
|
||||||
|
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
|
||||||
|
CLASSIFICATION = "CLASSIFICATION"
|
||||||
|
CLUSTERING = "CLUSTERING"
|
||||||
|
QUESTION_ANSWERING = "QUESTION_ANSWERING"
|
||||||
|
FACT_VERIFICATION = "FACT_VERIFICATION"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCategory(Enum):
|
||||||
|
"""工具分类枚举"""
|
||||||
|
|
||||||
|
FILE_SYSTEM = auto()
|
||||||
|
NETWORK = auto()
|
||||||
|
SYSTEM_INFO = auto()
|
||||||
|
CALCULATION = auto()
|
||||||
|
DATA_PROCESSING = auto()
|
||||||
|
CUSTOM = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class CodeExecutionOutcome(StrEnum):
|
||||||
|
"""代码执行结果状态枚举"""
|
||||||
|
|
||||||
|
OUTCOME_OK = "OUTCOME_OK"
|
||||||
|
OUTCOME_FAILED = "OUTCOME_FAILED"
|
||||||
|
OUTCOME_DEADLINE_EXCEEDED = "OUTCOME_DEADLINE_EXCEEDED"
|
||||||
|
OUTCOME_COMPILATION_ERROR = "OUTCOME_COMPILATION_ERROR"
|
||||||
|
OUTCOME_RUNTIME_ERROR = "OUTCOME_RUNTIME_ERROR"
|
||||||
|
OUTCOME_UNKNOWN = "OUTCOME_UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskType(Enum):
|
||||||
|
"""任务类型枚举"""
|
||||||
|
|
||||||
|
CHAT = "chat"
|
||||||
|
CODE = "code"
|
||||||
|
SEARCH = "search"
|
||||||
|
ANALYSIS = "analysis"
|
||||||
|
GENERATION = "generation"
|
||||||
|
MULTIMODAL = "multimodal"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMContentPart(BaseModel):
|
||||||
|
"""
|
||||||
|
LLM 消息内容部分 - 支持多模态内容。
|
||||||
|
|
||||||
|
这是一个联合体模型,`type` 字段决定了哪些其他字段是有效的。
|
||||||
|
例如:
|
||||||
|
- type='text': 使用 `text` 字段。
|
||||||
|
- type='image': 使用 `image_source` 字段。
|
||||||
|
- type='executable_code': 使用 `code_language` 和 `code_content` 字段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
text: str | None = None
|
||||||
|
image_source: str | None = None
|
||||||
|
audio_source: str | None = None
|
||||||
|
video_source: str | None = None
|
||||||
|
document_source: str | None = None
|
||||||
|
file_uri: str | None = None
|
||||||
|
file_source: str | None = None
|
||||||
|
url: str | None = None
|
||||||
|
mime_type: str | None = None
|
||||||
|
thought_text: str | None = None
|
||||||
|
media_resolution: str | None = None
|
||||||
|
code_language: str | None = None
|
||||||
|
code_content: str | None = None
|
||||||
|
execution_outcome: str | None = None
|
||||||
|
execution_output: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
def model_post_init(self, /, __context: Any) -> None:
|
||||||
|
"""验证内容部分的有效性"""
|
||||||
|
_ = __context
|
||||||
|
validation_rules = {
|
||||||
|
"text": lambda: self.text is not None,
|
||||||
|
"image": lambda: self.image_source,
|
||||||
|
"audio": lambda: self.audio_source,
|
||||||
|
"video": lambda: self.video_source,
|
||||||
|
"document": lambda: self.document_source,
|
||||||
|
"file": lambda: self.file_uri or self.file_source,
|
||||||
|
"url": lambda: self.url,
|
||||||
|
"thought": lambda: self.thought_text,
|
||||||
|
"executable_code": lambda: self.code_content is not None,
|
||||||
|
"execution_result": lambda: self.execution_outcome is not None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.type in validation_rules:
|
||||||
|
if not validation_rules[self.type]():
|
||||||
|
raise ValueError(f"{self.type}类型的内容部分必须包含相应字段")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def text_part(cls, text: str) -> "LLMContentPart":
|
||||||
|
"""创建文本内容部分"""
|
||||||
|
return cls(type="text", text=text)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def thought_part(cls, text: str) -> "LLMContentPart":
|
||||||
|
"""创建思考过程内容部分"""
|
||||||
|
return cls(type="thought", thought_text=text)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def image_url_part(cls, url: str) -> "LLMContentPart":
|
||||||
|
"""创建图片URL内容部分"""
|
||||||
|
return cls(type="image", image_source=url)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def image_base64_part(
|
||||||
|
cls, data: str, mime_type: str = "image/png"
|
||||||
|
) -> "LLMContentPart":
|
||||||
|
"""创建Base64图片内容部分"""
|
||||||
|
data_url = f"data:{mime_type};base64,{data}"
|
||||||
|
return cls(type="image", image_source=data_url)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart":
|
||||||
|
"""创建音频URL内容部分"""
|
||||||
|
return cls(type="audio", audio_source=url, mime_type=mime_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart":
|
||||||
|
"""创建视频URL内容部分"""
|
||||||
|
return cls(type="video", video_source=url, mime_type=mime_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def video_base64_part(
|
||||||
|
cls, data: str, mime_type: str = "video/mp4"
|
||||||
|
) -> "LLMContentPart":
|
||||||
|
"""创建Base64视频内容部分"""
|
||||||
|
data_url = f"data:{mime_type};base64,{data}"
|
||||||
|
return cls(type="video", video_source=data_url, mime_type=mime_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def audio_base64_part(
|
||||||
|
cls, data: str, mime_type: str = "audio/wav"
|
||||||
|
) -> "LLMContentPart":
|
||||||
|
"""创建Base64音频内容部分"""
|
||||||
|
data_url = f"data:{mime_type};base64,{data}"
|
||||||
|
return cls(type="audio", audio_source=data_url, mime_type=mime_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def file_uri_part(
|
||||||
|
cls,
|
||||||
|
file_uri: str,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> "LLMContentPart":
|
||||||
|
"""创建Gemini File API URI内容部分"""
|
||||||
|
return cls(
|
||||||
|
type="file",
|
||||||
|
file_uri=file_uri,
|
||||||
|
mime_type=mime_type,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def executable_code_part(cls, language: str, code: str) -> "LLMContentPart":
|
||||||
|
"""创建可执行代码内容部分"""
|
||||||
|
return cls(type="executable_code", code_language=language, code_content=code)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execution_result_part(
|
||||||
|
cls, outcome: str, output: str | None
|
||||||
|
) -> "LLMContentPart":
|
||||||
|
"""创建代码执行结果部分"""
|
||||||
|
return cls(
|
||||||
|
type="execution_result", execution_outcome=outcome, execution_output=output
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_path(
|
||||||
|
cls, path_like: str | Path, target_api: str | None = None
|
||||||
|
) -> "LLMContentPart | None":
|
||||||
|
"""
|
||||||
|
从本地文件路径创建 LLMContentPart。
|
||||||
|
自动检测MIME类型,并根据类型(如图片)可能加载为Base64。
|
||||||
|
target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
path = Path(path_like)
|
||||||
|
if not path.exists() or not path.is_file():
|
||||||
|
logger.warning(f"文件不存在或不是一个文件: {path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
|
||||||
|
|
||||||
|
if not mime_type:
|
||||||
|
logger.warning(
|
||||||
|
f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(path, encoding="utf-8") as f:
|
||||||
|
text_content = await f.read()
|
||||||
|
return cls.text_part(text_content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取文本文件 {path.name} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if mime_type.startswith("image/"):
|
||||||
|
if target_api == "gemini" or not path.is_absolute():
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(path, "rb") as f:
|
||||||
|
img_bytes = await f.read()
|
||||||
|
base64_data = base64.b64encode(img_bytes).decode("utf-8")
|
||||||
|
return cls.image_base64_part(
|
||||||
|
data=base64_data, mime_type=mime_type
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"为本地图片路径 {path.name} 生成 image_url_part。"
|
||||||
|
"实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
|
||||||
|
)
|
||||||
|
return cls.image_url_part(url=path.resolve().as_uri())
|
||||||
|
elif mime_type.startswith("audio/"):
|
||||||
|
return cls.audio_url_part(
|
||||||
|
url=path.resolve().as_uri(), mime_type=mime_type
|
||||||
|
)
|
||||||
|
elif mime_type.startswith("video/"):
|
||||||
|
if target_api == "gemini":
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(path, "rb") as f:
|
||||||
|
video_bytes = await f.read()
|
||||||
|
base64_data = base64.b64encode(video_bytes).decode("utf-8")
|
||||||
|
return cls.video_base64_part(
|
||||||
|
data=base64_data, mime_type=mime_type
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return cls.video_url_part(
|
||||||
|
url=path.resolve().as_uri(), mime_type=mime_type
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
mime_type.startswith("text/")
|
||||||
|
or mime_type == "application/json"
|
||||||
|
or mime_type == "application/xml"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(path, encoding="utf-8") as f:
|
||||||
|
text_content = await f.read()
|
||||||
|
return cls.text_part(text_content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取文本类文件 {path.name} 失败: {e}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
|
||||||
|
)
|
||||||
|
return cls.file_uri_part(
|
||||||
|
file_uri=path.resolve().as_uri(),
|
||||||
|
mime_type=mime_type,
|
||||||
|
metadata={"name": path.name, "source": "local_path"},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_image_url(self) -> bool:
|
||||||
|
"""检查图像源是否为URL"""
|
||||||
|
if not self.image_source:
|
||||||
|
return False
|
||||||
|
return self.image_source.startswith(("http://", "https://"))
|
||||||
|
|
||||||
|
def is_image_base64(self) -> bool:
|
||||||
|
"""检查图像源是否为Base64 Data URL"""
|
||||||
|
if not self.image_source:
|
||||||
|
return False
|
||||||
|
return self.image_source.startswith("data:")
|
||||||
|
|
||||||
|
def get_base64_data(self) -> tuple[str, str] | None:
|
||||||
|
"""从Data URL中提取Base64数据和MIME类型"""
|
||||||
|
if not self.is_image_base64() or not self.image_source:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
header, data = self.image_source.split(",", 1)
|
||||||
|
mime_part = header.split(";")[0].replace("data:", "")
|
||||||
|
return mime_part, data
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class LLMMessage(BaseModel):
|
||||||
|
"""
|
||||||
|
LLM 消息对象,用于构建对话历史。
|
||||||
|
|
||||||
|
核心字段说明:
|
||||||
|
- role: 消息角色,推荐值为 'user', 'assistant', 'system', 'tool'。
|
||||||
|
+    - content: Message content; either a plain-text string or a list of LLMContentPart (for multimodal input).
+    - tool_calls: (assistant only) the tool-call requests generated by the model.
+    - tool_call_id: (tool only) the ID of the call this tool message responds to.
+    - name: (tool only) the name of the function this tool message responds to.
+    """
+
+    role: str
+    content: str | list[LLMContentPart]
+    name: str | None = None
+    tool_calls: list[Any] | None = None
+    tool_call_id: str | None = None
+    thought_signature: str | None = None
+
+    def model_post_init(self, /, __context: Any) -> None:
+        """Validate the message."""
+        _ = __context
+        if self.role == "tool":
+            if not self.tool_call_id:
+                raise ValueError("工具角色的消息必须包含 tool_call_id")
+            if not self.name:
+                raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
+        if self.role == "tool" and not isinstance(self.content, str):
+            logger.warning(
+                f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
+                "将尝试转换为字符串。"
+            )
+            try:
+                self.content = str(self.content)
+            except Exception as e:
+                raise ValueError(f"无法将工具角色的内容转换为字符串: {e}")
+
+    @classmethod
+    def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
+        """Create a user message."""
+        return cls(role="user", content=content)
+
+    @classmethod
+    def assistant_tool_calls(
+        cls,
+        tool_calls: list[Any],
+        content: str | list[LLMContentPart] = "",
+    ) -> "LLMMessage":
+        """Create an assistant message that requests tool calls."""
+        return cls(role="assistant", content=content, tool_calls=tool_calls)
+
+    @classmethod
+    def assistant_text_response(
+        cls, content: str | list[LLMContentPart]
+    ) -> "LLMMessage":
+        """Create an assistant plain-text reply."""
+        return cls(role="assistant", content=content, tool_calls=None)
+
+    @classmethod
+    def tool_response(
+        cls,
+        tool_call_id: str,
+        function_name: str,
+        result: Any,
+    ) -> "LLMMessage":
+        """Create a tool-execution result message."""
+        import json
+
+        try:
+            content_str = json.dumps(result)
+        except TypeError as e:
+            logger.error(
+                f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
+            )
+            content_str = json.dumps(
+                {"error": "工具结果无法JSON序列化", "details": str(e)}
+            )
+
+        return cls(
+            role="tool",
+            content=content_str,
+            tool_call_id=tool_call_id,
+            name=function_name,
+        )
+
+    @classmethod
+    def system(cls, content: str) -> "LLMMessage":
+        """Create a system message."""
+        return cls(role="system", content=content)
+
+
+class LLMResponse(BaseModel):
+    """
+    LLM response object wrapping everything the model generated.
+
+    Core fields:
+    - text: the generated text; for a plain-text reply this field is the result.
+    - tool_calls: call details when the model decides to invoke tools.
+    - content_parts: the raw multimodal or structured parts (e.g. chain of thought, code blocks).
+    - raw_response: the raw third-party API response dict (for debugging).
+    - images: generated image data when the request involved image generation.
+    """
+
+    text: str
+    content_parts: list[Any] | None = None
+    images: list[bytes | Path] | None = None
+    usage_info: dict[str, Any] | None = None
+    raw_response: dict[str, Any] | None = None
+    tool_calls: list[Any] | None = None
+    code_executions: list[Any] | None = None
+    grounding_metadata: Any | None = None
+    cache_info: Any | None = None
+    thought_text: str | None = None
+    thought_signature: str | None = None
+
+
 ModelName = str | None
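The tool_response factory above never lets a non-serializable tool result break the conversation: it degrades to a structured error payload instead of raising. A minimal stdlib-only sketch of that fallback (tool_content is a hypothetical stand-in name, not part of this commit):

import json
from typing import Any

def tool_content(result: Any) -> str:
    # Mirrors tool_response: try plain JSON first, then degrade to a
    # structured error payload instead of raising mid-conversation.
    try:
        return json.dumps(result)
    except TypeError as e:
        return json.dumps({"error": "tool result not JSON-serializable", "details": str(e)})

print(tool_content({"ok": True}))  # {"ok": true}
print(tool_content({1, 2, 3}))     # sets are not JSON-serializable -> error payload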
@@ -26,6 +473,64 @@ class ToolDefinition(BaseModel):
     )
+
+
+class ToolChoice(BaseModel):
+    """Unified tool-choice configuration."""
+
+    mode: Literal["auto", "none", "any", "required"] = Field(
+        default="auto", description="工具调用模式"
+    )
+    allowed_function_names: list[str] | None = Field(
+        default=None, description="允许调用的函数名称列表"
+    )
+
+
+class BasePlatformTool(BaseModel):
+    """Base class for platform-native tools."""
+
+    class Config:
+        extra = "forbid"
+
+    def get_tool_declaration(self) -> dict[str, Any]:
+        """Return the declaration object placed in the 'tools' list (snake case)."""
+        raise NotImplementedError
+
+    def get_tool_config(self) -> dict[str, Any] | None:
+        """Return the configuration object placed in 'toolConfig' (snake case)."""
+        return None
+
+
+class GeminiCodeExecution(BasePlatformTool):
+    """Gemini code-execution tool."""
+
+    def get_tool_declaration(self) -> dict[str, Any]:
+        return {"code_execution": {}}
+
+
+class GeminiGoogleSearch(BasePlatformTool):
+    """Gemini Google Search (grounding) tool."""
+
+    mode: Literal["MODE_DYNAMIC"] = "MODE_DYNAMIC"
+    dynamic_threshold: float | None = Field(default=None)
+
+    def get_tool_declaration(self) -> dict[str, Any]:
+        return {"google_search": {}}
+
+    def get_tool_config(self) -> dict[str, Any] | None:
+        return None
+
+
+class GeminiUrlContext(BasePlatformTool):
+    """Gemini URL-context tool."""
+
+    urls: list[str] = Field(..., description="作为上下文的 URL 列表", max_length=20)
+
+    def get_tool_declaration(self) -> dict[str, Any]:
+        return {"google_search": {}, "url_context": {}}
+
+    def get_tool_config(self) -> dict[str, Any] | None:
+        return None
+
+
 class ToolResult(BaseModel):
     """
     A structured tool execution result model.
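Each BasePlatformTool subclass contributes one declaration dict to the request's tools list via get_tool_declaration. A condensed, runnable sketch of that assembly (the class bodies below are simplified stand-ins for the ones in this diff):

from typing import Any

class GeminiCodeExecution:
    def get_tool_declaration(self) -> dict[str, Any]:
        return {"code_execution": {}}

class GeminiGoogleSearch:
    def get_tool_declaration(self) -> dict[str, Any]:
        return {"google_search": {}}

# One declaration per enabled platform tool, merged into the request body.
tools = [t.get_tool_declaration() for t in (GeminiCodeExecution(), GeminiGoogleSearch())]
print(tools)  # [{'code_execution': {}}, {'google_search': {}}]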
@@ -87,6 +592,8 @@ class ModelDetail(BaseModel):
     is_embedding_model: bool = False
     temperature: float | None = None
     max_tokens: int | None = None
+    api_type: str | None = None
+    endpoint: str | None = None
 
 
 class ProviderConfig(BaseModel):
@@ -116,6 +623,7 @@ class LLMToolCall(BaseModel):
 
     id: str
     function: LLMToolFunction
+    thought_signature: str | None = None
 
 
 class LLMCodeExecution(BaseModel):
@@ -143,6 +651,12 @@ class LLMGroundingMetadata(BaseModel):
     web_search_queries: list[str] | None = None
     grounding_attributions: list[LLMGroundingAttribution] | None = None
    search_suggestions: list[dict[str, Any]] | None = None
+    search_entry_point: str | None = Field(
+        default=None, description="Google搜索建议的HTML片段(renderedContent)"
+    )
+    map_widget_token: str | None = Field(
+        default=None, description="Google Maps 前端组件令牌"
+    )
 
 
 class LLMCacheInfo(BaseModel):
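ToolChoice.mode carries provider-neutral values ("auto"/"none"/"any"/"required"), so each backend has to map them onto its own wire format. A hedged sketch of one plausible mapping onto Gemini's toolConfig.functionCallingConfig (this mapping is an assumption for illustration, not code from this commit):

def to_gemini_tool_config(mode: str, allowed: list[str] | None) -> dict:
    # "auto"/"none" map directly; "any" and "required" both force a call.
    gemini_mode = {"auto": "AUTO", "none": "NONE", "any": "ANY", "required": "ANY"}[mode]
    config: dict = {"functionCallingConfig": {"mode": gemini_mode}}
    if allowed:
        config["functionCallingConfig"]["allowedFunctionNames"] = allowed
    return config

print(to_gemini_tool_config("any", ["get_weather"]))
# {'functionCallingConfig': {'mode': 'ANY', 'allowedFunctionNames': ['get_weather']}}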
@@ -2,10 +2,97 @@
 LLM module protocol definitions
 """
 
-from typing import Any, Protocol
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Protocol, Union
+
+from pydantic import BaseModel
 
 from .models import ToolDefinition, ToolResult
 
+if TYPE_CHECKING:
+    from .models import LLMMessage, LLMResponse, LLMToolCall
+
+
+class ToolCallData(BaseModel):
+    """Data model passed to on_tool_start."""
+
+    tool_name: str
+    tool_args: dict[str, Any]
+
+
+class ToolCallCompleteData(BaseModel):
+    """Data model passed to on_tool_call_complete."""
+
+    id: str
+    name: str
+    arguments: str
+    result: "ToolResult"
+
+
+class BaseCallbackHandler(ABC):
+    """
+    Base class for Agent/LLM lifecycle callback handlers.
+    Pushed down into the LLM layer so that ToolInvoker can call it directly.
+    """
+
+    async def on_agent_start(self, messages: list["LLMMessage"], **kwargs: Any) -> None:
+        """Called when the AgentExecutor starts running."""
+        pass
+
+    async def on_model_start(
+        self, model_name: str, messages: list["LLMMessage"], **kwargs: Any
+    ) -> None:
+        """Called right before a request is sent to the LLM."""
+        pass
+
+    async def on_model_end(
+        self, response: "LLMResponse", duration: float, **kwargs: Any
+    ) -> None:
+        """Called after the LLM response has been received."""
+        pass
+
+    async def on_tool_start(
+        self, tool_call: "LLMToolCall", data: ToolCallData, **kwargs: Any
+    ) -> Union[ToolCallData, "ToolResult", None]:
+        """
+        Called when a single tool is about to be executed.
+
+        Returns:
+            ToolCallData: modify the arguments and continue execution
+            ToolResult: intercept execution and return this result to the model directly
+            None: continue normally
+        """
+        pass
+
+    async def on_tool_end(
+        self,
+        result: Union["ToolResult", None],
+        error: Exception | None,
+        tool_call: "LLMToolCall",
+        duration: float,
+        **kwargs: Any,
+    ) -> None:
+        """Called after a single tool finishes executing, whether it succeeded or failed."""
+        pass
+
+    async def on_tool_call_complete(
+        self, data: ToolCallCompleteData, **kwargs: Any
+    ) -> None:
+        """Called when a tool call has completed and the response message is about to be created."""
+        pass
+
+    async def on_human_input_request(self, query: str, **kwargs: Any) -> str | None:
+        """
+        Called when the Agent needs human input.
+        """
+        return None
+
+    async def on_agent_end(
+        self, final_history: list["LLMMessage"], duration: float, **kwargs: Any
+    ) -> None:
+        """Called when the AgentExecutor finishes running."""
+        pass
+
+
 class ToolExecutable(Protocol):
     """
@@ -19,10 +106,14 @@ class ToolExecutable(Protocol):
         """
         ...
 
-    async def execute(self, **kwargs: Any) -> ToolResult:
+    async def execute(self, context: Any | None = None, **kwargs: Any) -> ToolResult:
         """
         Execute the tool asynchronously and return a structured result.
         Arguments are generated by the LLM from the tool definition.
+
+        Args:
+            context: the runtime context (RunContext), optionally injected
+            **kwargs: the tool arguments
         """
         ...
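Every BaseCallbackHandler hook is async and defaults to a no-op, so a concrete handler overrides only what it needs. A minimal self-contained sketch of a logging handler (structurally compatible with the base class above; the model name is purely illustrative):

import asyncio

class LoggingHandler:  # stand-in; in the project it would subclass BaseCallbackHandler
    async def on_model_start(self, model_name, messages, **kwargs):
        print(f"-> {model_name}: sending {len(messages)} messages")

    async def on_model_end(self, response, duration, **kwargs):
        print(f"<- response received in {duration:.2f}s")

async def demo():
    handler = LoggingHandler()
    await handler.on_model_start("gemini-2.0-flash", [])
    await handler.on_model_end(object(), 1.23)

asyncio.run(demo())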
@@ -3,26 +3,176 @@ LLM module tool and conversion functions
 """
 
 import base64
-import copy
+from collections.abc import Awaitable, Callable
+import io
 from pathlib import Path
-from typing import Any
+from typing import Any, TypeVar
 
+import aiofiles
+import json_repair
 from nonebot.adapters import Message as PlatformMessage
+from nonebot.compat import type_validate_json
 from nonebot_plugin_alconna.uniseg import (
     At,
     File,
     Image,
     Reply,
+    Segment,
     Text,
     UniMessage,
     Video,
     Voice,
 )
+from PIL.Image import Image as PILImageType
+from pydantic import BaseModel, Field, ValidationError, create_model
+
 from zhenxun.services.log import logger
 from zhenxun.utils.http_utils import AsyncHttpx
+from zhenxun.utils.pydantic_compat import model_validate
 
-from .types import LLMContentPart, LLMMessage
+from .types import LLMContentPart, LLMErrorCode, LLMException, LLMMessage
+from .types.capabilities import ReasoningMode, get_model_capabilities
+
+T = TypeVar("T", bound=BaseModel)
+
+
+S = TypeVar("S", bound=Segment)
+_SEGMENT_HANDLERS: dict[
+    type[Segment], Callable[[Any], Awaitable[LLMContentPart | None]]
+] = {}
+
+
+def register_segment_handler(seg_type: type[S]):
+    """Decorator: register a handler for a Uniseg message segment."""
+
+    def decorator(func: Callable[[S], Awaitable[LLMContentPart | None]]):
+        _SEGMENT_HANDLERS[seg_type] = func
+        return func
+
+    return decorator
+
+
+async def _process_media_data(seg: Any, default_mime: str) -> tuple[str, str] | None:
+    """
+    [internal helper] Generic media handling: obtain base64 data and a MIME type.
+    Priority order: raw -> path -> URL (download).
+    """
+    mime_type = getattr(seg, "mimetype", None) or default_mime
+    b64_data = None
+
+    if hasattr(seg, "raw") and seg.raw:
+        if isinstance(seg.raw, bytes):
+            b64_data = base64.b64encode(seg.raw).decode("utf-8")
+
+    elif getattr(seg, "path", None):
+        try:
+            path = Path(seg.path)
+            if path.exists():
+                async with aiofiles.open(path, "rb") as f:
+                    content = await f.read()
+                b64_data = base64.b64encode(content).decode("utf-8")
+        except Exception as e:
+            logger.error(f"读取媒体文件失败: {seg.path}, 错误: {e}")
+
+    elif getattr(seg, "url", None):
+        try:
+            logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
+            media_bytes = await AsyncHttpx.get_content(seg.url)
+            b64_data = base64.b64encode(media_bytes).decode("utf-8")
+            logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
+        except Exception as e:
+            logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
+            return None
+
+    if b64_data:
+        return mime_type, b64_data
+    return None
+
+
+@register_segment_handler(Text)
+async def _handle_text(seg: Text) -> LLMContentPart | None:
+    if seg.text.strip():
+        return LLMContentPart.text_part(seg.text)
+    return None
+
+
+@register_segment_handler(Image)
+async def _handle_image(seg: Image) -> LLMContentPart | None:
+    media_info = await _process_media_data(seg, "image/png")
+    if media_info:
+        mime, data = media_info
+        return LLMContentPart.image_base64_part(data, mime)
+    return None
+
+
+@register_segment_handler(Voice)
+async def _handle_voice(seg: Voice) -> LLMContentPart | None:
+    media_info = await _process_media_data(seg, "audio/wav")
+    if media_info:
+        mime, data = media_info
+        return LLMContentPart.audio_base64_part(data, mime)
+    return LLMContentPart.text_part(f"[语音消息: {seg.id or 'unknown'}]")
+
+
+@register_segment_handler(Video)
+async def _handle_video(seg: Video) -> LLMContentPart | None:
+    media_info = await _process_media_data(seg, "video/mp4")
+    if media_info:
+        mime, data = media_info
+        return LLMContentPart.video_base64_part(data, mime)
+    return LLMContentPart.text_part(f"[视频消息: {seg.id or 'unknown'}]")
+
+
+@register_segment_handler(File)
+async def _handle_file(seg: File) -> LLMContentPart | None:
+    if seg.path:
+        return await LLMContentPart.from_path(seg.path)
+    return LLMContentPart.text_part(f"[文件: {seg.name} (ID: {seg.id})]")
+
+
+@register_segment_handler(At)
+async def _handle_at(seg: At) -> LLMContentPart | None:
+    if seg.flag == "all":
+        return LLMContentPart.text_part("[提及所有人]")
+    return LLMContentPart.text_part(f"[提及用户: {seg.target}]")
+
+
+@register_segment_handler(Reply)
+async def _handle_reply(seg: Reply) -> LLMContentPart | None:
+    text = str(seg.msg) if seg.msg else ""
+    if text:
+        return LLMContentPart.text_part(f'[回复消息: "{text[:50]}..."]')
+    return LLMContentPart.text_part("[回复了一条消息]")
+
+
+async def _transform_to_content_part(item: Any) -> LLMContentPart:
+    """
+    Convert mixed input into a unified LLMContentPart for normalize_to_llm_messages.
+    """
+    if isinstance(item, LLMContentPart):
+        return item
+
+    if isinstance(item, str):
+        return LLMContentPart.text_part(item)
+
+    if isinstance(item, Path):
+        part = await LLMContentPart.from_path(item)
+        if part is None:
+            raise ValueError(f"无法从路径加载内容: {item}")
+        return part
+
+    if isinstance(item, dict):
+        return LLMContentPart(**item)
+
+    if PILImageType and isinstance(item, PILImageType):
+        buffer = io.BytesIO()
+        fmt = item.format or "PNG"
+        item.save(buffer, format=fmt)
+        b64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
+        mime_type = f"image/{fmt.lower()}"
+        return LLMContentPart.image_base64_part(b64_data, mime_type)
+
+    raise TypeError(f"不支持的输入类型用于构建 ContentPart: {type(item)}")
+
+
 async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
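register_segment_handler keys the dispatch table by the segment's concrete type, so supporting a new Uniseg segment means registering one coroutine. A self-contained sketch with a hypothetical Location segment (neither the segment nor this handler exists in the commit):

_SEGMENT_HANDLERS = {}

def register_segment_handler(seg_type):
    def decorator(func):
        _SEGMENT_HANDLERS[seg_type] = func
        return func
    return decorator

class Location:  # hypothetical segment type, for illustration only
    def __init__(self, lat: float, lon: float):
        self.lat, self.lon = lat, lon

@register_segment_handler(Location)
async def _handle_location(seg):
    # A real handler would return an LLMContentPart; a text placeholder suffices here.
    return f"[location: {seg.lat}, {seg.lon}]"

print(_SEGMENT_HANDLERS[Location].__name__)  # _handle_location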
@@ -36,110 +186,25 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
     Returns:
         list[LLMContentPart]: the converted list of content parts.
     """
+    if not _SEGMENT_HANDLERS:
+        pass
+
     parts: list[LLMContentPart] = []
     for seg in message:
-        part = None
-        if isinstance(seg, Text):
-            if seg.text.strip():
-                part = LLMContentPart.text_part(seg.text)
-        elif isinstance(seg, Image):
-            if seg.path:
-                part = await LLMContentPart.from_path(seg.path, target_api="gemini")
-            elif seg.url:
-                part = LLMContentPart.image_url_part(seg.url)
-            elif hasattr(seg, "raw") and seg.raw:
-                mime_type = (
-                    getattr(seg, "mimetype", "image/png")
-                    if hasattr(seg, "mimetype")
-                    else "image/png"
-                )
-                if isinstance(seg.raw, bytes):
-                    b64_data = base64.b64encode(seg.raw).decode("utf-8")
-                    part = LLMContentPart.image_base64_part(b64_data, mime_type)
-
-        elif isinstance(seg, File | Voice | Video):
-            if seg.path:
-                part = await LLMContentPart.from_path(seg.path)
-            elif seg.url:
-                try:
-                    logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
-                    media_bytes = await AsyncHttpx.get_content(seg.url)
-
-                    new_seg = copy.copy(seg)
-                    new_seg.raw = media_bytes
-                    seg = new_seg
-                    logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
-                except Exception as e:
-                    logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
-                    part = LLMContentPart.text_part(
-                        f"[下载媒体失败: {seg.name or seg.url}]"
-                    )
-            if part:
-                parts.append(part)
-            continue
-
-        if hasattr(seg, "raw") and seg.raw:
-            mime_type = getattr(seg, "mimetype", None)
-            if isinstance(seg.raw, bytes):
-                b64_data = base64.b64encode(seg.raw).decode("utf-8")
-
-                if isinstance(seg, Video):
-                    if not mime_type:
-                        mime_type = "video/mp4"
-                    part = LLMContentPart.video_base64_part(
-                        data=b64_data, mime_type=mime_type
-                    )
-                    logger.debug(
-                        f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
-                    )
-                elif isinstance(seg, Voice):
-                    if not mime_type:
-                        mime_type = "audio/wav"
-                    part = LLMContentPart.audio_base64_part(
-                        data=b64_data, mime_type=mime_type
-                    )
-                    logger.debug(
-                        f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
-                    )
-                else:
-                    part = LLMContentPart.text_part(
-                        f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
-                    )
-                    logger.debug(
-                        f"处理其他文件字节数据: {mime_type}, "
-                        f"大小: {len(seg.raw)} bytes"
-                    )
-
-        elif isinstance(seg, At):
-            if seg.flag == "all":
-                part = LLMContentPart.text_part("[提及所有人]")
-            else:
-                part = LLMContentPart.text_part(f"[提及用户: {seg.target}]")
-
-        elif isinstance(seg, Reply):
-            if seg.msg:
-                try:
-                    extract_method = getattr(seg.msg, "extract_plain_text", None)
-                    if extract_method and callable(extract_method):
-                        reply_text = str(extract_method()).strip()
-                    else:
-                        reply_text = str(seg.msg).strip()
-                    if reply_text:
-                        part = LLMContentPart.text_part(
-                            f'[回复消息: "{reply_text[:50]}..."]'
-                        )
-                except Exception:
-                    part = LLMContentPart.text_part("[回复了一条消息]")
-
-        if part:
-            parts.append(part)
-
+        handler = _SEGMENT_HANDLERS.get(type(seg))
+        if handler:
+            try:
+                part = await handler(seg)
+                if part:
+                    parts.append(part)
+            except Exception as e:
+                logger.warning(f"处理消息段 {seg} 失败: {e}", "LLMUtils")
+
     return parts
 
 
 async def normalize_to_llm_messages(
-    message: str | UniMessage | LLMMessage | list[LLMContentPart] | list[LLMMessage],
+    message: str | UniMessage | LLMMessage | list[Any],
     instruction: str | None = None,
 ) -> list[LLMMessage]:
     """
@@ -167,7 +232,10 @@
         content_parts = await unimsg_to_llm_parts(message)
         messages.append(LLMMessage.user(content_parts))
     elif isinstance(message, list):
-        messages.append(LLMMessage.user(message))  # type: ignore
+        parts = []
+        for item in message:
+            parts.append(await _transform_to_content_part(item))
+        messages.append(LLMMessage.user(parts))
     else:
         raise TypeError(f"不支持的消息类型: {type(message)}")
 
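With _transform_to_content_part in place, normalize_to_llm_messages accepts a heterogeneous list: strings, Path objects, raw part dicts, PIL images, and ready-made LLMContentPart instances all collapse into a single user message. A data-only sketch of the accepted shapes (the commented call assumes the project environment; the file name is illustrative):

from pathlib import Path

mixed_input = [
    "Describe this image:",          # str  -> text part
    Path("./screenshot.png"),        # Path -> loaded via LLMContentPart.from_path
    {"type": "text", "text": "hi"},  # dict -> LLMContentPart(**item)
]
# messages = await normalize_to_llm_messages(mixed_input, instruction="Be concise.")
# -> [system message, LLMMessage.user([three content parts])]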
@@ -255,53 +323,271 @@ def message_to_unimessage(message: PlatformMessage) -> UniMessage:
     Returns:
         UniMessage: the converted universal message object.
     """
-    uni_segments = []
-    for seg in message:
-        if seg.type == "text":
-            uni_segments.append(Text(seg.data.get("text", "")))
-        elif seg.type == "image":
-            uni_segments.append(Image(url=seg.data.get("url")))
-        elif seg.type == "record":
-            uni_segments.append(Voice(url=seg.data.get("url")))
-        elif seg.type == "video":
-            uni_segments.append(Video(url=seg.data.get("url")))
-        elif seg.type == "at":
-            uni_segments.append(At("user", str(seg.data.get("qq", ""))))
-        else:
-            logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")
-
-    return UniMessage(uni_segments)
+    return UniMessage.of(message)
+
+
+def resolve_json_schema_refs(schema: dict) -> dict:
+    """
+    Recursively resolve $ref in a JSON Schema, replacing each with its $defs/definitions entry.
+    Used for compatibility with the Gemini API, which does not support $ref.
+    """
+    definitions = schema.get("$defs") or schema.get("definitions") or {}
+
+    def _resolve(node: Any) -> Any:
+        if isinstance(node, dict):
+            if "$ref" in node:
+                ref_name = node["$ref"].split("/")[-1]
+                if ref_name in definitions:
+                    return _resolve(definitions[ref_name])
+
+            return {
+                key: _resolve(value)
+                for key, value in node.items()
+                if key not in ("$defs", "definitions")
+            }
+
+        if isinstance(node, list):
+            return [_resolve(item) for item in node]
+
+        return node
+
+    return _resolve(schema)
+
+
 def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
     """
     Recursively sanitize a JSON Schema, removing keywords a given LLM API does not support.
-
-    Parameters:
-        schema: the JSON Schema to sanitize (dict, list, or any other type).
-        api_type: the target API type, e.g. 'gemini'.
-
-    Returns:
-        Any: the sanitized JSON Schema.
     """
-    if isinstance(schema, dict):
-        schema_copy = {}
-        for key, value in schema.items():
-            if api_type == "gemini":
-                unsupported_keys = ["exclusiveMinimum", "exclusiveMaximum", "default"]
-                if key in unsupported_keys:
-                    continue
-
-                if key == "format" and isinstance(value, str):
-                    supported_formats = ["enum", "date-time"]
-                    if value not in supported_formats:
-                        continue
-
-            schema_copy[key] = sanitize_schema_for_llm(value, api_type)
-        return schema_copy
-
-    elif isinstance(schema, list):
+    if isinstance(schema, list):
         return [sanitize_schema_for_llm(item, api_type) for item in schema]
+    if isinstance(schema, dict):
+        schema_copy = schema.copy()
+
+        if api_type == "gemini":
+            if "const" in schema_copy:
+                schema_copy["enum"] = [schema_copy.pop("const")]
+
+            if "type" in schema_copy and isinstance(schema_copy["type"], list):
+                types_list = schema_copy["type"]
+                if "null" in types_list:
+                    schema_copy["nullable"] = True
+                    types_list = [t for t in types_list if t != "null"]
+                if len(types_list) == 1:
+                    schema_copy["type"] = types_list[0]
+                else:
+                    schema_copy["type"] = types_list
+
+            if "anyOf" in schema_copy:
+                any_of = schema_copy["anyOf"]
+                has_null = any(
+                    isinstance(x, dict) and x.get("type") == "null" for x in any_of
+                )
+                if has_null:
+                    schema_copy["nullable"] = True
+                new_any_of = [
+                    x
+                    for x in any_of
+                    if not (isinstance(x, dict) and x.get("type") == "null")
+                ]
+                if len(new_any_of) == 1:
+                    schema_copy.update(new_any_of[0])
+                    schema_copy.pop("anyOf", None)
+                else:
+                    schema_copy["anyOf"] = new_any_of
+
+            unsupported_keys = [
+                "exclusiveMinimum",
+                "exclusiveMaximum",
+                "default",
+                "title",
+                "additionalProperties",
+                "$schema",
+                "$id",
+            ]
+            for key in unsupported_keys:
+                schema_copy.pop(key, None)
+
+            if schema_copy.get("format") and schema_copy["format"] not in [
+                "enum",
+                "date-time",
+            ]:
+                schema_copy.pop("format", None)
+
+        elif api_type == "openai":
+            unsupported_keys = [
+                "default",
+                "minLength",
+                "maxLength",
+                "pattern",
+                "format",
+                "minimum",
+                "maximum",
+                "multipleOf",
+                "patternProperties",
+                "minItems",
+                "maxItems",
+                "uniqueItems",
+                "$schema",
+                "title",
+            ]
+            for key in unsupported_keys:
+                schema_copy.pop(key, None)
+
+            if "$ref" in schema_copy:
+                ref_key = schema_copy["$ref"].split("/")[-1]
+                defs = schema_copy.get("$defs") or schema_copy.get("definitions")
+                if defs and ref_key in defs:
+                    schema_copy.pop("$ref", None)
+                    schema_copy.update(defs[ref_key])
+                else:
+                    return {"$ref": schema_copy["$ref"]}
+
+            is_object = (
+                schema_copy.get("type") == "object" or "properties" in schema_copy
+            )
+            if is_object:
+                schema_copy["type"] = "object"
+                schema_copy["additionalProperties"] = False
+
+                properties = schema_copy.get("properties", {})
+                required = schema_copy.get("required", [])
+                if properties:
+                    existing_req = set(required)
+                    for prop in properties.keys():
+                        if prop not in existing_req:
+                            required.append(prop)
+                    schema_copy["required"] = required
+
+        for def_key in ["$defs", "definitions"]:
+            if def_key in schema_copy and isinstance(schema_copy[def_key], dict):
+                schema_copy[def_key] = {
+                    k: sanitize_schema_for_llm(v, api_type)
+                    for k, v in schema_copy[def_key].items()
+                }
+
+        recursive_keys = ["properties", "items", "allOf", "anyOf", "oneOf"]
+        for key in recursive_keys:
+            if key in schema_copy:
+                if key == "properties" and isinstance(schema_copy[key], dict):
+                    schema_copy[key] = {
+                        k: sanitize_schema_for_llm(v, api_type)
+                        for k, v in schema_copy[key].items()
+                    }
+                else:
+                    schema_copy[key] = sanitize_schema_for_llm(
+                        schema_copy[key], api_type
+                    )
+
+        return schema_copy
     else:
         return schema
+
+
+def extract_text_from_content(
+    content: str | list[LLMContentPart] | None,
+) -> str:
+    """
+    Extract the plain text from message content, filtering out non-text parts
+    so they do not pollute the prompt.
+    """
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        return " ".join(
+            part.text for part in content if part.type == "text" and part.text
+        )
+    return str(content)
+
+
+def parse_and_validate_json(text: str, response_model: type[T]) -> T:
+    """
+    Generic utility: try to parse text into the given Pydantic model, with
+    unified exception handling.
+    """
+    try:
+        return type_validate_json(response_model, text)
+    except (ValidationError, ValueError) as e:
+        try:
+            logger.warning(f"标准JSON解析失败,尝试使用json_repair修复: {e}")
+            repaired_obj = json_repair.loads(text, skip_json_loads=True)
+            return model_validate(response_model, repaired_obj)
+        except Exception as repair_error:
+            logger.error(
+                f"LLM结构化输出校验最终失败: {repair_error}",
+                e=repair_error,
+            )
+            raise LLMException(
+                "LLM返回的JSON未能通过结构验证。",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                details={
+                    "raw_response": text,
+                    "validation_error": str(repair_error),
+                    "original_error": repair_error,
+                },
+                cause=repair_error,
+            )
+    except Exception as e:
+        logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
+        raise LLMException(
+            "解析LLM的JSON输出时失败。",
+            code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+            details={"raw_response": text},
+            cause=e,
+        )
+
+
+def create_cot_wrapper(inner_model: type[BaseModel]) -> type[BaseModel]:
+    """
+    [dynamic runtime wrapper]
+    Create a wrapper model that carries a chain of thought.
+    Forces the model to emit a reasoning field before producing the final JSON structure.
+    """
+    wrapper_name = f"CoT_{inner_model.__name__}"
+
+    return create_model(
+        wrapper_name,
+        reasoning=(
+            str,
+            Field(
+                ...,
+                min_length=10,
+                description=(
+                    "在生成最终结果之前,请务必在此字段中详细描述你的推理步骤、计算过程或思考逻辑。禁止留空。"
+                ),
+            ),
+        ),
+        result=(
+            inner_model,
+            Field(
+                ...,
+            ),
+        ),
+    )
+
+
+def should_apply_autocot(
+    requested: bool,
+    model_name: str | None,
+    config: Any,
+) -> bool:
+    """
+    [decision pipeline]
+    Decide whether AutoCoT (explicit chain-of-thought wrapping) should be applied.
+    Avoids "double thinking" when the model already has native reasoning ability.
+    """
+    if not requested:
+        return False
+
+    if config:
+        thinking_budget = getattr(config, "thinking_budget", 0) or 0
+        if thinking_budget > 0:
+            return False
+        if getattr(config, "thinking_level", None) is not None:
+            return False
+
+    if model_name:
+        caps = get_model_capabilities(model_name)
+        if caps.reasoning_mode != ReasoningMode.NONE:
+            return False
+
+    return True
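parse_and_validate_json reaches for json_repair only after strict validation fails, which copes with the trailing commas and stray text LLMs often emit. A standalone sketch of the same two-step parse, assuming pydantic v2 and the json-repair package are installed:

import json_repair
from pydantic import BaseModel, ValidationError

class Verdict(BaseModel):
    score: int
    reason: str

raw = '{"score": 8, "reason": "solid",}'  # trailing comma: strict JSON rejects this

try:
    verdict = Verdict.model_validate_json(raw)
except ValidationError:
    # Repair the malformed text first, then validate the resulting object.
    verdict = Verdict.model_validate(json_repair.loads(raw))

print(verdict)  # score=8 reason='solid'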
@@ -14,9 +14,34 @@ def _truncate_base64_string(value: str, threshold: int = 256) -> str:
     if value.startswith(prefixes) and len(value) > threshold:
         prefix = next((p for p in prefixes if value.startswith(p)), "base64")
         return f"[{prefix}_data_omitted_len={len(value)}]"
+
+    if len(value) > 2000:
+        return f"[long_string_omitted_len={len(value)}] {value[:50]}...{value[-20:]}"
+
+    if len(value) > 1000:
+        return f"[long_string_omitted_len={len(value)}] {value[:20]}...{value[-20:]}"
+
     return value
+
+
+def _truncate_vector_list(vector: list, threshold: int = 10) -> list:
+    """Truncate an over-long list (usually an embedding vector) for log display."""
+    if isinstance(vector, list) and len(vector) > threshold:
+        return [*vector[:3], f"...({len(vector)} floats omitted)...", *vector[-3:]]
+    return vector
+
+
+def _recursive_sanitize_any(obj: Any) -> Any:
+    """Recursively scrub long strings inside any object."""
+    if isinstance(obj, dict):
+        return {k: _recursive_sanitize_any(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [_recursive_sanitize_any(v) for v in obj]
+    elif isinstance(obj, str):
+        return _truncate_base64_string(obj)
+    return obj
+
+
 def _sanitize_ui_html(html_string: str) -> str:
     """
     A function dedicated to sanitizing UI-render debug HTML.
@@ -64,6 +89,37 @@ def _sanitize_openai_response(response_json: dict) -> dict:
                         message["images"][i]["image_url"]["url"] = (
                             _truncate_base64_string(url)
                         )
+                if "reasoning_details" in message and isinstance(
+                    message["reasoning_details"], list
+                ):
+                    for detail in message["reasoning_details"]:
+                        if isinstance(detail, dict):
+                            if "data" in detail and isinstance(detail["data"], str):
+                                if len(detail["data"]) > 100:
+                                    detail["data"] = (
+                                        f"[encrypted_data_omitted_len={len(detail['data'])}]"
+                                    )
+                            if "text" in detail and isinstance(detail["text"], str):
+                                detail["text"] = _truncate_base64_string(
+                                    detail["text"], threshold=2000
+                                )
+        if "data" in sanitized_json and isinstance(sanitized_json["data"], list):
+            for item in sanitized_json["data"]:
+                if "embedding" in item and isinstance(item["embedding"], list):
+                    item["embedding"] = _truncate_vector_list(item["embedding"])
+                if "b64_json" in item and isinstance(item["b64_json"], str):
+                    if len(item["b64_json"]) > 256:
+                        item["b64_json"] = (
+                            f"[base64_json_omitted_len={len(item['b64_json'])}]"
+                        )
+        if "input" in sanitized_json and isinstance(sanitized_json["input"], list):
+            for item in sanitized_json["input"]:
+                if "content" in item and isinstance(item["content"], list):
+                    for part in item["content"]:
+                        if isinstance(part, dict) and part.get("type") == "input_image":
+                            image_url = part.get("image_url")
+                            if isinstance(image_url, str):
+                                part["image_url"] = _truncate_base64_string(image_url)
         return sanitized_json
     except Exception:
         return response_json
@@ -71,22 +127,44 @@ def _sanitize_openai_response(response_json: dict) -> dict:
 
 
 def _sanitize_openai_request(body: dict) -> dict:
     """Sanitize the request body of an OpenAI-compatible API, mainly truncating image base64."""
+    from zhenxun.services.llm.config.providers import (
+        DebugLogOptions,
+        get_llm_config,
+    )
+
+    debug_conf = get_llm_config().debug_log
+    if isinstance(debug_conf, bool):
+        debug_conf = DebugLogOptions(
+            show_tools=debug_conf, show_schema=debug_conf, show_safety=debug_conf
+        )
+
     try:
-        sanitized_json = copy.deepcopy(body)
-        if "messages" in sanitized_json and isinstance(
-            sanitized_json["messages"], list
-        ):
-            for message in sanitized_json["messages"]:
-                if "content" in message and isinstance(message["content"], list):
-                    for i, part in enumerate(message["content"]):
-                        if part.get("type") == "image_url":
-                            if "image_url" in part and isinstance(
-                                part["image_url"], dict
-                            ):
-                                url = part["image_url"].get("url", "")
-                                message["content"][i]["image_url"]["url"] = (
-                                    _truncate_base64_string(url)
-                                )
+        sanitized_json = _recursive_sanitize_any(copy.deepcopy(body))
+        if "tools" in sanitized_json and not debug_conf.show_tools:
+            tools = sanitized_json["tools"]
+            if isinstance(tools, list):
+                tool_names = []
+                for t in tools:
+                    if isinstance(t, dict):
+                        name = None
+                        if "function" in t and isinstance(t["function"], dict):
+                            name = t["function"].get("name")
+                        if not name and "name" in t:
+                            name = t.get("name")
+                        tool_names.append(name or "unknown")
+                sanitized_json["tools"] = (
+                    f"<{len(tool_names)} tools hidden: {', '.join(tool_names)}>"
+                )
+
+        if "response_format" in sanitized_json and not debug_conf.show_schema:
+            response_format = sanitized_json["response_format"]
+            if isinstance(response_format, dict):
+                if response_format.get("type") == "json_schema":
+                    sanitized_json["response_format"] = {
+                        "type": "json_schema",
+                        "json_schema": "<JSON Schema Hidden>",
+                    }
+
         return sanitized_json
     except Exception:
         return body
@@ -94,6 +172,9 @@ def _sanitize_openai_request(body: dict) -> dict:
 
 
 def _sanitize_gemini_response(response_json: dict) -> dict:
     """Sanitize the Gemini API response body, handling both the text and the image-generation formats."""
+    from zhenxun.services.llm.config.providers import get_llm_config
+
+    debug_mode = get_llm_config().debug_log
     try:
         sanitized_json = copy.deepcopy(response_json)
@@ -114,6 +195,15 @@ def _sanitize_gemini_response(response_json: dict) -> dict:
                         content["parts"][i]["inlineData"]["data"] = (
                             f"[base64_data_omitted_len={len(data)}]"
                         )
+                    if "thoughtSignature" in part:
+                        signature = part.get("thoughtSignature", "")
+                        if isinstance(signature, str) and len(signature) > 256:
+                            content["parts"][i]["thoughtSignature"] = (
+                                f"[signature_omitted_len={len(signature)}]"
+                            )
+            if not debug_mode and isinstance(candidate, dict):
+                if "safetyRatings" in candidate:
+                    candidate["safetyRatings"] = "<Safety Ratings Hidden>"
 
         if "candidates" in sanitized_json:
             _process_candidates(sanitized_json["candidates"])
@@ -124,6 +214,19 @@ def _sanitize_gemini_response(response_json: dict) -> dict:
             if "candidates" in sanitized_json["image_generation"]:
                 _process_candidates(sanitized_json["image_generation"]["candidates"])
+
+        if "embeddings" in sanitized_json and isinstance(
+            sanitized_json["embeddings"], list
+        ):
+            for embedding in sanitized_json["embeddings"]:
+                if "values" in embedding and isinstance(embedding["values"], list):
+                    embedding["values"] = _truncate_vector_list(embedding["values"])
+
+        if not debug_mode and "promptFeedback" in sanitized_json:
+            prompt_feedback = sanitized_json.get("promptFeedback") or {}
+            if isinstance(prompt_feedback, dict) and "safetyRatings" in prompt_feedback:
+                prompt_feedback["safetyRatings"] = "<Safety Ratings Hidden>"
+            sanitized_json["promptFeedback"] = prompt_feedback
+
         return sanitized_json
     except Exception:
         return response_json
@@ -131,8 +234,46 @@ def _sanitize_gemini_response(response_json: dict) -> dict:
 
 
 def _sanitize_gemini_request(body: dict) -> dict:
     """Sanitize the Gemini API request body, performing structural conversion and summarization."""
+    from zhenxun.services.llm.config.providers import (
+        DebugLogOptions,
+        get_llm_config,
+    )
+
+    debug_conf = get_llm_config().debug_log
+    if isinstance(debug_conf, bool):
+        debug_conf = DebugLogOptions(
+            show_tools=debug_conf, show_schema=debug_conf, show_safety=debug_conf
+        )
+
     try:
         sanitized_body = copy.deepcopy(body)
+        if "tools" in sanitized_body and not debug_conf.show_tools:
+            tool_summary = []
+            for tool_group in sanitized_body["tools"]:
+                if (
+                    isinstance(tool_group, dict)
+                    and "functionDeclarations" in tool_group
+                ):
+                    declarations = tool_group["functionDeclarations"]
+                    if isinstance(declarations, list):
+                        for func in declarations:
+                            if isinstance(func, dict):
+                                tool_summary.append(func.get("name", "unknown"))
+            sanitized_body["tools"] = (
+                f"<{len(tool_summary)} functions hidden: {', '.join(tool_summary)}>"
+            )
+
+        if not debug_conf.show_safety and "safetySettings" in sanitized_body:
+            sanitized_body["safetySettings"] = "<Safety Settings Hidden>"
+
+        if not debug_conf.show_schema and "generationConfig" in sanitized_body:
+            generation_config = sanitized_body["generationConfig"]
+            if (
+                isinstance(generation_config, dict)
+                and "responseJsonSchema" in generation_config
+            ):
+                generation_config["responseJsonSchema"] = "<JSON Schema Hidden>"
+
         if "contents" in sanitized_body and isinstance(
             sanitized_body["contents"], list
         ):
@@ -153,6 +294,13 @@ def _sanitize_gemini_request(body: dict) -> dict:
                         continue
                     new_parts.append(part)
+
+                    if "thoughtSignature" in part:
+                        sig = part["thoughtSignature"]
+                        if isinstance(sig, str) and len(sig) > 64:
+                            part["thoughtSignature"] = (
+                                f"[signature_omitted_len={len(sig)}]"
+                            )
+
                 if media_summary:
                     summary_text = (
                         f"[多模态内容: {len(media_summary)}个文件 - "
@@ -195,8 +343,5 @@ def sanitize_for_logging(data: Any, context: str | None = None) -> Any:
     elif context == "ui_html":
         if isinstance(data, str):
             return _sanitize_ui_html(data)
-    else:
-        if isinstance(data, str):
-            return _truncate_base64_string(data)
 
-    return data
+    return _recursive_sanitize_any(data)
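All of the sanitizers above funnel through _truncate_base64_string and _truncate_vector_list, so a debug log line never carries megabytes of base64 or a full embedding vector. A condensed, runnable restatement of the two truncation rules (the real prefix tuple is defined earlier in the file; "data:image" here is an assumed member):

def truncate_b64(value: str, threshold: int = 256) -> str:
    # Same shape as the diff's rule for data-URI-like payloads.
    if value.startswith("data:image") and len(value) > threshold:
        return f"[data:image_data_omitted_len={len(value)}]"
    return value

def truncate_vector(vector: list, threshold: int = 10) -> list:
    if len(vector) > threshold:
        return [*vector[:3], f"...({len(vector)} floats omitted)...", *vector[-3:]]
    return vector

print(truncate_b64("data:image/png;base64," + "A" * 500))  # [data:image_data_omitted_len=522]
print(truncate_vector(list(range(100))))                    # [0, 1, 2, '...(100 floats omitted)...', 97, 98, 99]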
@@ -10,8 +10,14 @@ from enum import Enum
 from pathlib import Path
 from typing import Any, TypeVar, get_args, get_origin
 
-from nonebot.compat import PYDANTIC_V2, model_dump
-from pydantic import VERSION, BaseModel
+from nonebot.compat import (
+    PYDANTIC_V2,
+    model_dump,
+    model_fields,
+    type_validate_json,
+    type_validate_python,
+)
+from pydantic import BaseModel
 import ujson as json
 
 T = TypeVar("T", bound=BaseModel)
@@ -27,9 +33,13 @@ __all__ = [
     "model_construct",
     "model_copy",
     "model_dump",
+    "model_dump_json",
+    "model_fields",
     "model_json_schema",
     "model_validate",
     "parse_as",
+    "type_validate_json",
+    "type_validate_python",
 ]
@@ -58,12 +68,18 @@ def model_construct(model_class: type[T], **kwargs: Any) -> T:
 
 
 def model_validate(model_class: type[T], obj: Any) -> T:
     """
-    Compatibility wrapper for Pydantic `model_validate` (v2) and `parse_obj` (v1).
+    Pydantic model-validation compatibility function.
     """
-    if PYDANTIC_V2:
-        return model_class.model_validate(obj)
-    else:
-        return model_class.parse_obj(obj)
+    return type_validate_python(model_class, obj)
+
+
+def model_dump_json(model: BaseModel, **kwargs: Any) -> str:
+    """
+    Compatibility wrapper for Pydantic `model.json()` (v1) and `model.model_dump_json()` (v2).
+    """
+    if PYDANTIC_V2:
+        return model.model_dump_json(**kwargs)
+    return model.json(**kwargs)
@@ -78,8 +94,7 @@ def model_json_schema(model_class: type[BaseModel], **kwargs: Any) -> dict[str,
     """
     if PYDANTIC_V2:
         return model_class.model_json_schema(**kwargs)
-    else:
-        return model_class.schema(by_alias=kwargs.get("by_alias", True))
+    return model_class.schema(by_alias=kwargs.get("by_alias", True))
 
 
 def _is_pydantic_type(t: Any) -> bool:
@@ -108,18 +123,7 @@ def _dump_pydantic_obj(obj: Any) -> Any:
     return obj
 
 
-def parse_as(type_: type[V], obj: Any) -> V:
-    """
-    A helper compatible with Pydantic v1's parse_obj_as and v2's TypeAdapter.validate_python.
-    """
-    if VERSION.startswith("1"):
-        from pydantic import parse_obj_as
-
-        return parse_obj_as(type_, obj)
-    else:
-        from pydantic import TypeAdapter  # type: ignore
-
-        return TypeAdapter(type_).validate_python(obj)
+parse_as = type_validate_python
 
 
 def dump_json_safely(obj: Any, **kwargs) -> str:
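After this change the compat layer is mostly thin aliases over nonebot.compat; parse_as in particular is now literally type_validate_python, which validates against an arbitrary type annotation on both pydantic major versions. A small usage sketch (assumes nonebot is installed):

from nonebot.compat import type_validate_python

parse_as = type_validate_python  # as in the diff

# Coerces and validates data against any type annotation, not just BaseModel.
print(parse_as(list[int], ["1", 2, "3"]))  # [1, 2, 3]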