mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
793 lines
30 KiB
Python
793 lines
30 KiB
Python
"""
|
||
LLM 服务 - 会话客户端
|
||
|
||
提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。
|
||
"""
|
||
|
||
from collections.abc import Awaitable, Callable
|
||
import copy
|
||
import json
|
||
from typing import Any, TypeVar, cast
|
||
import uuid
|
||
|
||
from jinja2 import Template
|
||
from nonebot.utils import is_coroutine_callable
|
||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||
from pydantic import BaseModel
|
||
|
||
from zhenxun.services.log import logger
|
||
from zhenxun.utils.pydantic_compat import model_json_schema
|
||
|
||
from .config import (
|
||
CommonOverrides,
|
||
GenConfigBuilder,
|
||
LLMEmbeddingConfig,
|
||
LLMGenerationConfig,
|
||
)
|
||
from .config.generation import OutputConfig
|
||
from .config.providers import get_llm_config
|
||
from .manager import get_global_default_model_name, get_model_instance
|
||
from .memory import (
|
||
AIConfig,
|
||
BaseMemory,
|
||
MemoryProcessor,
|
||
_get_default_memory,
|
||
)
|
||
from .tools import tool_provider_manager
|
||
from .types import (
|
||
LLMContentPart,
|
||
LLMErrorCode,
|
||
LLMException,
|
||
LLMMessage,
|
||
LLMResponse,
|
||
ModelName,
|
||
ResponseFormat,
|
||
StructuredOutputStrategy,
|
||
ToolChoice,
|
||
ToolExecutable,
|
||
)
|
||
from .types.models import (
|
||
GeminiCodeExecution,
|
||
GeminiGoogleSearch,
|
||
)
|
||
from .utils import (
|
||
create_cot_wrapper,
|
||
normalize_to_llm_messages,
|
||
parse_and_validate_json,
|
||
should_apply_autocot,
|
||
)
|
||
|
||
T = TypeVar("T", bound=BaseModel)
|
||
|
||
DEFAULT_IVR_TEMPLATE = (
|
||
"你的响应未能通过结构校验。\n"
|
||
"错误详情: {error_msg}\n\n"
|
||
"请执行以下步骤进行修正:\n"
|
||
"1. 反思:分析为什么会出现这个错误。\n"
|
||
"2. 修正:生成一个新的、符合 Schema 要求的 JSON 对象。\n"
|
||
"请直接输出修正后的 JSON,不要包含 Markdown 标记或其他解释。"
|
||
)
|
||
|
||
|
||
class AI:
|
||
"""
|
||
统一的AI服务类 - 提供了带记忆的会话接口。
|
||
不再执行自主工具循环,当LLM返回工具调用时,会直接将请求返回给调用者。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
session_id: str | None = None,
|
||
config: AIConfig | None = None,
|
||
memory: BaseMemory | None = None,
|
||
default_generation_config: LLMGenerationConfig | None = None,
|
||
processors: list[MemoryProcessor] | None = None,
|
||
):
|
||
"""
|
||
初始化AI服务
|
||
|
||
参数:
|
||
session_id: 唯一的会话ID,用于隔离记忆。
|
||
config: AI 配置.
|
||
memory: 可选的自定义记忆后端。如果为None,则使用默认的 ChatMemory
|
||
(InMemoryMessageStore)。
|
||
default_generation_config: 此AI实例的默认生成配置。
|
||
processors: 记忆处理器列表,在添加记忆后触发。
|
||
"""
|
||
self.session_id = session_id or str(uuid.uuid4())
|
||
self.config = config or AIConfig()
|
||
self.memory = memory or _get_default_memory()
|
||
self.default_generation_config = (
|
||
default_generation_config or LLMGenerationConfig()
|
||
)
|
||
self.processors = processors or []
|
||
|
||
global_providers = tool_provider_manager._providers
|
||
config_providers = self.config.tool_providers
|
||
self._tool_providers = list(dict.fromkeys(global_providers + config_providers))
|
||
self.message_buffer: list[LLMMessage] = []
|
||
|
||
async def clear_history(self):
|
||
"""清空当前会话的历史记录。"""
|
||
await self.memory.clear_history(self.session_id)
|
||
logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})")
|
||
|
||
async def add_observation(
|
||
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
|
||
):
|
||
"""
|
||
将一条观察消息加入缓冲区,不立即触发模型调用。
|
||
|
||
返回:
|
||
int: 缓冲区中消息的数量。
|
||
"""
|
||
current_message = await self._normalize_input_to_message(message)
|
||
self.message_buffer.append(current_message)
|
||
content_preview = str(current_message.content)[:50]
|
||
logger.debug(
|
||
f"[放入观察] {content_preview} (缓冲区大小: {len(self.message_buffer)})",
|
||
"AI_MEMORY",
|
||
)
|
||
return len(self.message_buffer)
|
||
|
||
async def add_user_message_to_history(
|
||
self, message: str | LLMMessage | list[LLMContentPart]
|
||
):
|
||
"""
|
||
将一条用户消息标准化并添加到会话历史中。
|
||
|
||
参数:
|
||
message: 用户消息内容。
|
||
"""
|
||
user_message = await self._normalize_input_to_message(message)
|
||
await self.memory.add_message(self.session_id, user_message)
|
||
|
||
async def add_assistant_response_to_history(self, response_text: str):
|
||
"""
|
||
将助手的文本回复添加到会话历史中。
|
||
|
||
参数:
|
||
response_text: 助手的回复文本。
|
||
"""
|
||
assistant_message = LLMMessage.assistant_text_response(response_text)
|
||
await self.memory.add_message(self.session_id, assistant_message)
|
||
|
||
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||
"""
|
||
净化用于存入历史记录的消息。
|
||
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||
"""
|
||
if not isinstance(message.content, list):
|
||
return message
|
||
|
||
sanitized_message = copy.deepcopy(message)
|
||
content_list = sanitized_message.content
|
||
if not isinstance(content_list, list):
|
||
return sanitized_message
|
||
|
||
new_content_parts: list[LLMContentPart] = []
|
||
has_multimodal_content = False
|
||
|
||
for part in content_list:
|
||
if isinstance(part, LLMContentPart) and part.type == "text":
|
||
new_content_parts.append(part)
|
||
else:
|
||
has_multimodal_content = True
|
||
|
||
if has_multimodal_content:
|
||
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||
text_part_found = False
|
||
for part in new_content_parts:
|
||
if part.type == "text":
|
||
part.text = f"{placeholder} {part.text or ''}".strip()
|
||
text_part_found = True
|
||
break
|
||
if not text_part_found:
|
||
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||
|
||
sanitized_message.content = new_content_parts
|
||
return sanitized_message
|
||
|
||
async def _normalize_input_to_message(
|
||
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
|
||
) -> LLMMessage:
|
||
"""
|
||
内部辅助方法,将各种输入类型统一转换为单个 LLMMessage 对象。
|
||
它调用共享的工具函数并提取最后一条消息(通常是用户输入)。
|
||
"""
|
||
messages = await normalize_to_llm_messages(message)
|
||
|
||
if not messages:
|
||
raise LLMException(
|
||
"无法将输入标准化为有效的消息。", code=LLMErrorCode.CONFIGURATION_ERROR
|
||
)
|
||
return messages[-1]
|
||
|
||
async def generate_internal(
|
||
self,
|
||
messages: list[LLMMessage],
|
||
*,
|
||
model: ModelName = None,
|
||
config: LLMGenerationConfig | GenConfigBuilder | None = None,
|
||
tools: list[Any] | dict[str, ToolExecutable] | None = None,
|
||
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
|
||
timeout: float | None = None,
|
||
model_instance: Any = None,
|
||
) -> LLMResponse:
|
||
"""
|
||
内部生成核心方法,负责配置合并、工具解析和模型调用。
|
||
此方法不处理历史记录的存储,供 AgentExecutor 或 chat 方法调用。
|
||
"""
|
||
final_config = self.default_generation_config
|
||
if isinstance(config, GenConfigBuilder):
|
||
config = config.build()
|
||
|
||
if config:
|
||
final_config = final_config.merge_with(config)
|
||
|
||
final_tools_list = []
|
||
if tools:
|
||
if isinstance(tools, dict):
|
||
final_tools_list = list(tools.values())
|
||
elif isinstance(tools, list):
|
||
to_resolve: list[Any] = []
|
||
for t in tools:
|
||
if isinstance(t, str | dict):
|
||
to_resolve.append(t)
|
||
else:
|
||
final_tools_list.append(t)
|
||
|
||
if to_resolve:
|
||
resolved_dict = await self._resolve_tools(to_resolve)
|
||
final_tools_list.extend(resolved_dict.values())
|
||
|
||
if model_instance:
|
||
return await model_instance.generate_response(
|
||
messages,
|
||
config=final_config,
|
||
tools=final_tools_list if final_tools_list else None,
|
||
tool_choice=tool_choice,
|
||
timeout=timeout,
|
||
)
|
||
|
||
resolved_model_name = self._resolve_model_name(model or self.config.model)
|
||
async with await get_model_instance(
|
||
resolved_model_name,
|
||
override_config=None,
|
||
) as instance:
|
||
return await instance.generate_response(
|
||
messages,
|
||
config=final_config,
|
||
tools=final_tools_list if final_tools_list else None,
|
||
tool_choice=tool_choice,
|
||
timeout=timeout,
|
||
)
|
||
|
||
async def chat(
|
||
self,
|
||
message: str | UniMessage | LLMMessage | list[LLMContentPart] | None,
|
||
*,
|
||
model: ModelName = None,
|
||
instruction: str | None = None,
|
||
template_vars: dict[str, Any] | None = None,
|
||
preserve_media_in_history: bool | None = None,
|
||
tools: list[Any] | dict[str, ToolExecutable] | None = None,
|
||
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
|
||
config: LLMGenerationConfig | GenConfigBuilder | None = None,
|
||
use_buffer: bool = False,
|
||
timeout: float | None = None,
|
||
) -> LLMResponse:
|
||
"""
|
||
核心交互方法,管理会话历史并执行单次LLM调用。
|
||
|
||
参数:
|
||
message: 用户输入的消息内容,支持文本、UniMessage、LLMMessage或
|
||
内容部分列表。
|
||
model: 要使用的模型名称,如果为None则使用配置中的默认模型。
|
||
instruction: 本次调用的特定系统指令,会与全局指令合并。
|
||
template_vars: 模板变量字典,用于在指令中进行变量替换。
|
||
preserve_media_in_history: 是否在历史记录中保留媒体内容,
|
||
None时使用默认配置。
|
||
tools: 可用的工具列表或工具字典,支持临时工具和预配置工具。
|
||
tool_choice: 工具选择策略,控制AI如何选择和使用工具。
|
||
config: 生成配置对象,用于覆盖默认的生成参数。
|
||
use_buffer: 是否刷新并包含消息缓冲区的内容,在此次对话中一次性提交。
|
||
timeout: HTTP 请求超时时间(秒)。
|
||
|
||
返回:
|
||
LLMResponse: 包含AI回复、工具调用请求、使用信息等的完整响应对象。
|
||
"""
|
||
messages_to_add: list[LLMMessage] = []
|
||
if message:
|
||
current_message = await self._normalize_input_to_message(message)
|
||
messages_to_add.append(current_message)
|
||
|
||
if use_buffer and self.message_buffer:
|
||
messages_to_add = self.message_buffer + messages_to_add
|
||
self.message_buffer.clear()
|
||
|
||
messages_for_run = []
|
||
final_instruction = instruction
|
||
|
||
if final_instruction and template_vars:
|
||
try:
|
||
template = Template(final_instruction)
|
||
final_instruction = template.render(**template_vars)
|
||
logger.debug(f"渲染后的系统指令: {final_instruction}")
|
||
except Exception as e:
|
||
logger.error(f"渲染系统指令模板失败: {e}", e=e)
|
||
|
||
if final_instruction:
|
||
messages_for_run.append(LLMMessage.system(final_instruction))
|
||
|
||
current_history = await self.memory.get_history(self.session_id)
|
||
messages_for_run.extend(current_history)
|
||
messages_for_run.extend(messages_to_add)
|
||
|
||
try:
|
||
response = await self.generate_internal(
|
||
messages_for_run,
|
||
model=model,
|
||
config=config,
|
||
tools=tools,
|
||
tool_choice=tool_choice,
|
||
timeout=timeout,
|
||
)
|
||
|
||
should_preserve = (
|
||
preserve_media_in_history
|
||
if preserve_media_in_history is not None
|
||
else self.config.default_preserve_media_in_history
|
||
)
|
||
msgs_to_store: list[LLMMessage] = []
|
||
for msg in messages_to_add:
|
||
store_msg = (
|
||
msg if should_preserve else self._sanitize_message_for_history(msg)
|
||
)
|
||
msgs_to_store.append(store_msg)
|
||
|
||
if response.content_parts:
|
||
assistant_response_msg = LLMMessage(
|
||
role="assistant",
|
||
content=response.content_parts,
|
||
tool_calls=response.tool_calls,
|
||
)
|
||
else:
|
||
assistant_response_msg = LLMMessage.assistant_text_response(
|
||
response.text
|
||
)
|
||
if response.tool_calls:
|
||
assistant_response_msg = LLMMessage.assistant_tool_calls(
|
||
response.tool_calls, response.text
|
||
)
|
||
|
||
await self.memory.add_messages(
|
||
self.session_id, [*msgs_to_store, assistant_response_msg]
|
||
)
|
||
|
||
if self.processors:
|
||
for processor in self.processors:
|
||
await processor.process(
|
||
self.session_id, [*msgs_to_store, assistant_response_msg]
|
||
)
|
||
|
||
return response
|
||
|
||
except Exception as e:
|
||
raise (
|
||
e
|
||
if isinstance(e, LLMException)
|
||
else LLMException(f"聊天执行失败: {e}", cause=e)
|
||
)
|
||
|
||
async def code(
|
||
self,
|
||
prompt: str,
|
||
*,
|
||
model: ModelName = None,
|
||
timeout: int | None = None,
|
||
config: LLMGenerationConfig | GenConfigBuilder | None = None,
|
||
) -> LLMResponse:
|
||
"""
|
||
代码执行
|
||
|
||
参数:
|
||
prompt: 代码执行的提示词。
|
||
model: 要使用的模型名称。
|
||
timeout: 代码执行超时时间(秒)。
|
||
config: (可选) 覆盖默认的生成配置。
|
||
|
||
返回:
|
||
LLMResponse: 包含执行结果的完整响应对象。
|
||
"""
|
||
resolved_model = model or self.config.model
|
||
|
||
code_config = CommonOverrides.gemini_code_execution()
|
||
if timeout:
|
||
code_config.custom_params = code_config.custom_params or {}
|
||
code_config.custom_params["code_execution_timeout"] = timeout
|
||
|
||
if isinstance(config, GenConfigBuilder):
|
||
config = config.build()
|
||
|
||
if config:
|
||
code_config = code_config.merge_with(config)
|
||
|
||
return await self.chat(prompt, model=resolved_model, config=code_config)
|
||
|
||
async def search(
|
||
self,
|
||
query: UniMessage,
|
||
*,
|
||
model: ModelName = None,
|
||
instruction: str = (
|
||
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
|
||
"根据用户的查询找到最相关的信息,并进行总结和回答。"
|
||
),
|
||
template_vars: dict[str, Any] | None = None,
|
||
config: LLMGenerationConfig | GenConfigBuilder | None = None,
|
||
) -> LLMResponse:
|
||
"""
|
||
信息搜索的便捷入口,原生支持多模态查询。
|
||
"""
|
||
logger.info("执行 'search' 任务...")
|
||
search_config = CommonOverrides.gemini_grounding()
|
||
|
||
if isinstance(config, GenConfigBuilder):
|
||
config = config.build()
|
||
|
||
if config:
|
||
search_config = search_config.merge_with(config)
|
||
|
||
return await self.chat(
|
||
query,
|
||
model=model,
|
||
instruction=instruction,
|
||
template_vars=template_vars,
|
||
config=search_config,
|
||
)
|
||
|
||
async def generate_structured(
|
||
self,
|
||
message: str | LLMMessage | list[LLMContentPart] | None,
|
||
response_model: type[T],
|
||
*,
|
||
model: ModelName = None,
|
||
tools: list[Any] | dict[str, ToolExecutable] | None = None,
|
||
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
|
||
instruction: str | None = None,
|
||
timeout: float | None = None,
|
||
template_vars: dict[str, Any] | None = None,
|
||
config: LLMGenerationConfig | GenConfigBuilder | None = None,
|
||
max_validation_retries: int | None = None,
|
||
validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
|
||
error_prompt_template: str | None = None,
|
||
auto_thinking: bool = False,
|
||
) -> T:
|
||
"""
|
||
生成结构化响应,并自动解析为指定的Pydantic模型。
|
||
|
||
参数:
|
||
message: 用户输入的消息内容,支持多种格式。为None时只使用历史+缓冲区。
|
||
response_model: 用于解析和验证响应的Pydantic模型类。
|
||
model: 要使用的模型名称,如果为None则使用配置中的默认模型。
|
||
instruction: 本次调用的特定系统指令,会与JSON Schema指令合并。
|
||
timeout: HTTP 请求超时时间(秒)。
|
||
template_vars: 系统指令中的模板变量,用于动态渲染。
|
||
config: 生成配置对象,用于覆盖默认的生成参数。
|
||
|
||
返回:
|
||
T: 解析后的Pydantic模型实例,类型为response_model指定的类型。
|
||
|
||
异常:
|
||
LLMException: 如果模型返回的不是有效的JSON或验证失败。
|
||
"""
|
||
if isinstance(config, GenConfigBuilder):
|
||
config = config.build()
|
||
|
||
final_config = self.default_generation_config.merge_with(config)
|
||
|
||
if final_config is None:
|
||
final_config = LLMGenerationConfig()
|
||
|
||
if max_validation_retries is None:
|
||
max_validation_retries = get_llm_config().client_settings.structured_retries
|
||
|
||
resolved_model_name = self._resolve_model_name(model or self.config.model)
|
||
|
||
request_autocot = True if auto_thinking is False else auto_thinking
|
||
effective_auto_thinking = should_apply_autocot(
|
||
request_autocot, resolved_model_name, final_config
|
||
)
|
||
|
||
target_model: type[T] = response_model
|
||
if effective_auto_thinking:
|
||
target_model = cast(type[T], create_cot_wrapper(response_model))
|
||
response_model = target_model
|
||
|
||
cot_instruction = (
|
||
"请务必先在 `reasoning` 字段中进行详细的一步步推理,确保逻辑正确,"
|
||
"然后再填充 `result` 字段。"
|
||
)
|
||
if instruction:
|
||
instruction = f"{instruction}\n\n{cot_instruction}"
|
||
else:
|
||
instruction = cot_instruction
|
||
|
||
final_instruction = instruction
|
||
if final_instruction and template_vars:
|
||
try:
|
||
template = Template(final_instruction)
|
||
final_instruction = template.render(**template_vars)
|
||
except Exception as e:
|
||
logger.error(f"渲染结构化指令模板失败: {e}", e=e)
|
||
|
||
try:
|
||
json_schema = model_json_schema(response_model)
|
||
except AttributeError:
|
||
json_schema = response_model.schema()
|
||
|
||
schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
|
||
|
||
prompt_prefix = f"{final_instruction}\n\n" if final_instruction else ""
|
||
structured_strategy = (
|
||
final_config.output.structured_output_strategy
|
||
if final_config.output
|
||
else None
|
||
)
|
||
if structured_strategy == StructuredOutputStrategy.TOOL_CALL:
|
||
system_prompt = prompt_prefix + "请调用提供的工具提交结构化数据。"
|
||
else:
|
||
system_prompt = (
|
||
prompt_prefix
|
||
+ "请严格按照以下 JSON Schema 格式进行响应。不应包含任何额外的解释、"
|
||
"注释或代码块标记,只返回一个合法的 JSON 对象。\n\n"
|
||
)
|
||
system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
|
||
|
||
structured_strategy = (
|
||
final_config.output.structured_output_strategy
|
||
if final_config.output
|
||
else StructuredOutputStrategy.NATIVE
|
||
)
|
||
|
||
final_tools_list: list[ToolExecutable] | None = None
|
||
if structured_strategy != StructuredOutputStrategy.NATIVE:
|
||
if tools:
|
||
final_tools_list = []
|
||
if isinstance(tools, dict):
|
||
final_tools_list = list(tools.values())
|
||
elif isinstance(tools, list):
|
||
to_resolve: list[Any] = []
|
||
for t in tools:
|
||
if isinstance(t, str | dict):
|
||
to_resolve.append(t)
|
||
else:
|
||
final_tools_list.append(t)
|
||
if to_resolve:
|
||
resolved_dict = await self._resolve_tools(to_resolve)
|
||
final_tools_list.extend(resolved_dict.values())
|
||
elif tools:
|
||
logger.warning(
|
||
"检测到在 generate_structured (NATIVE 策略) 中传入了 tools。"
|
||
"为了避免 API 冲突(Gemini)及输出歧义(OpenAI),这些"
|
||
"tools 将被本次请求忽略。"
|
||
"若需使用工具,请使用 chat() 方法或 Agent 流程。"
|
||
)
|
||
|
||
if final_config.output is None:
|
||
final_config.output = OutputConfig()
|
||
|
||
final_config.output.response_format = ResponseFormat.JSON
|
||
final_config.output.response_schema = json_schema
|
||
|
||
messages_for_run = [LLMMessage.system(system_prompt)]
|
||
current_history = await self.memory.get_history(self.session_id)
|
||
messages_for_run.extend(current_history)
|
||
messages_for_run.extend(self.message_buffer)
|
||
if message:
|
||
normalized_message = await self._normalize_input_to_message(message)
|
||
messages_for_run.append(normalized_message)
|
||
|
||
ivr_messages = list(messages_for_run)
|
||
last_exception: Exception | None = None
|
||
|
||
for attempt in range(max_validation_retries + 1):
|
||
current_response_text: str = ""
|
||
|
||
async with await get_model_instance(
|
||
resolved_model_name,
|
||
override_config=None,
|
||
) as model_instance:
|
||
response = await model_instance.generate_response(
|
||
ivr_messages,
|
||
config=final_config,
|
||
tools=final_tools_list if final_tools_list else None,
|
||
tool_choice=tool_choice,
|
||
timeout=timeout,
|
||
)
|
||
current_response_text = response.text
|
||
|
||
try:
|
||
parsed_obj = parse_and_validate_json(response.text, target_model)
|
||
|
||
final_obj: T = cast(T, parsed_obj)
|
||
if effective_auto_thinking:
|
||
logger.debug(
|
||
f"AutoCoT 思考过程: {getattr(parsed_obj, 'reasoning', '')}"
|
||
)
|
||
final_obj = cast(T, getattr(parsed_obj, "result"))
|
||
|
||
if validation_callback:
|
||
if is_coroutine_callable(validation_callback):
|
||
await validation_callback(final_obj)
|
||
else:
|
||
validation_callback(final_obj)
|
||
|
||
return final_obj
|
||
|
||
except Exception as e:
|
||
is_llm_error = isinstance(e, LLMException)
|
||
llm_error: LLMException | None = (
|
||
cast(LLMException, e) if is_llm_error else None
|
||
)
|
||
last_exception = e
|
||
|
||
if attempt < max_validation_retries:
|
||
error_msg = (
|
||
llm_error.details.get("validation_error", str(e))
|
||
if llm_error
|
||
else str(e)
|
||
)
|
||
raw_response = current_response_text or (
|
||
llm_error.details.get("raw_response", "") if llm_error else ""
|
||
)
|
||
logger.warning(
|
||
f"结构化校验失败 (尝试 {attempt + 1}/"
|
||
f"{max_validation_retries + 1})。正在尝试 IVR 修复... 错误:"
|
||
f"{error_msg}"
|
||
)
|
||
|
||
if raw_response:
|
||
ivr_messages.append(
|
||
LLMMessage.assistant_text_response(raw_response)
|
||
)
|
||
else:
|
||
logger.warning(
|
||
"IVR 警告: 无法获取上一轮生成的原始文本,"
|
||
"模型将在无上下文情况下尝试修复。"
|
||
)
|
||
|
||
template = error_prompt_template or DEFAULT_IVR_TEMPLATE
|
||
feedback_prompt = template.format(error_msg=error_msg)
|
||
ivr_messages.append(LLMMessage.user(feedback_prompt))
|
||
continue
|
||
|
||
if llm_error and not llm_error.recoverable:
|
||
raise llm_error
|
||
|
||
if last_exception:
|
||
raise last_exception
|
||
raise LLMException(
|
||
"IVR 循环异常结束,未能生成有效结果。", code=LLMErrorCode.GENERATION_FAILED
|
||
)
|
||
|
||
def _resolve_model_name(self, model_name: ModelName) -> str:
|
||
"""解析模型名称"""
|
||
if model_name:
|
||
return model_name
|
||
|
||
default_model = get_global_default_model_name()
|
||
if default_model:
|
||
return default_model
|
||
|
||
raise LLMException(
|
||
"未指定模型名称且未设置全局默认模型",
|
||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||
)
|
||
|
||
async def embed(
|
||
self,
|
||
texts: list[str] | str,
|
||
*,
|
||
model: ModelName = None,
|
||
config: LLMEmbeddingConfig | None = None,
|
||
) -> list[list[float]]:
|
||
"""
|
||
生成文本嵌入向量,将文本转换为数值向量表示。
|
||
|
||
参数:
|
||
texts: 要生成嵌入的文本内容,支持单个字符串或字符串列表。
|
||
model: 嵌入模型名称,如果为None则使用配置中的默认嵌入模型。
|
||
config: 嵌入配置
|
||
|
||
返回:
|
||
list[list[float]]: 文本对应的嵌入向量列表,每个向量为浮点数列表。
|
||
|
||
异常:
|
||
LLMException: 当嵌入生成失败或模型配置错误时抛出
|
||
"""
|
||
if isinstance(texts, str):
|
||
texts = [texts]
|
||
if not texts:
|
||
return []
|
||
|
||
try:
|
||
resolved_model_str = (
|
||
model or self.config.default_embedding_model or self.config.model
|
||
)
|
||
if not resolved_model_str:
|
||
raise LLMException(
|
||
"使用 embed 方法时未指定嵌入模型名称,"
|
||
"且 AIConfig 未设置 default_embedding_model。",
|
||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||
)
|
||
resolved_model_str = self._resolve_model_name(resolved_model_str)
|
||
|
||
final_config = config or LLMEmbeddingConfig()
|
||
|
||
async with await get_model_instance(
|
||
resolved_model_str,
|
||
override_config=None,
|
||
) as embedding_model_instance:
|
||
return await embedding_model_instance.generate_embeddings(
|
||
texts, config=final_config
|
||
)
|
||
except LLMException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"文本嵌入失败: {e}", e=e)
|
||
raise LLMException(
|
||
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
|
||
)
|
||
|
||
async def _resolve_tools(
|
||
self,
|
||
tool_configs: list[Any],
|
||
) -> dict[str, ToolExecutable]:
|
||
"""
|
||
使用注入的 ToolProvider 异步解析 ad-hoc(临时)工具配置。
|
||
返回一个从工具名称到可执行对象的字典。
|
||
"""
|
||
resolved: dict[str, ToolExecutable] = {}
|
||
|
||
for config in tool_configs:
|
||
if isinstance(config, str):
|
||
if config == "google_search":
|
||
resolved[config] = GeminiGoogleSearch() # type: ignore[arg-type]
|
||
continue
|
||
elif config == "code_execution":
|
||
resolved[config] = GeminiCodeExecution() # type: ignore[arg-type]
|
||
continue
|
||
elif config == "url_context":
|
||
pass
|
||
name = config if isinstance(config, str) else config.get("name")
|
||
if not name:
|
||
raise LLMException(
|
||
"工具配置字典必须包含 'name' 字段。",
|
||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||
)
|
||
|
||
if isinstance(config, str):
|
||
config_dict = {"name": name, "type": "function"}
|
||
elif isinstance(config, dict):
|
||
config_dict = config
|
||
else:
|
||
raise TypeError(f"不支持的工具配置类型: {type(config)}")
|
||
|
||
executable = None
|
||
for provider in self._tool_providers:
|
||
executable = await provider.get_tool_executable(name, config_dict)
|
||
if executable:
|
||
break
|
||
|
||
if not executable:
|
||
raise LLMException(
|
||
f"没有为 ad-hoc 工具 '{name}' 找到合适的提供者。",
|
||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||
)
|
||
|
||
resolved[name] = executable
|
||
|
||
return resolved
|