# zhenxun_bot/zhenxun/services/llm/session.py
"""
LLM service - session client.
Provides a stateful, session-oriented LLM client for multi-turn conversations
and complex interactions.
"""
from collections.abc import Awaitable, Callable
import copy
import json
from typing import Any, TypeVar, cast
import uuid
from jinja2 import Template
from nonebot.utils import is_coroutine_callable
from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_json_schema
from .config import (
CommonOverrides,
GenConfigBuilder,
LLMEmbeddingConfig,
LLMGenerationConfig,
)
from .config.generation import OutputConfig
from .config.providers import get_llm_config
from .manager import get_global_default_model_name, get_model_instance
from .memory import (
AIConfig,
BaseMemory,
MemoryProcessor,
_get_default_memory,
)
from .tools import tool_provider_manager
from .types import (
LLMContentPart,
LLMErrorCode,
LLMException,
LLMMessage,
LLMResponse,
ModelName,
ResponseFormat,
StructuredOutputStrategy,
ToolChoice,
ToolExecutable,
)
from .types.models import (
GeminiCodeExecution,
GeminiGoogleSearch,
)
from .utils import (
create_cot_wrapper,
normalize_to_llm_messages,
parse_and_validate_json,
should_apply_autocot,
)
T = TypeVar("T", bound=BaseModel)
DEFAULT_IVR_TEMPLATE = (
"你的响应未能通过结构校验。\n"
"错误详情: {error_msg}\n\n"
"请执行以下步骤进行修正:\n"
"1. 反思:分析为什么会出现这个错误。\n"
"2. 修正:生成一个新的、符合 Schema 要求的 JSON 对象。\n"
"请直接输出修正后的 JSON不要包含 Markdown 标记或其他解释。"
)
class AI:
"""
    Unified AI service class providing a memory-backed session interface.
    It no longer runs an autonomous tool loop: when the LLM returns tool
    calls, the request is handed straight back to the caller.
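
    Example (a minimal sketch; assumes a configured default model and an
    async caller):

        ai = AI(session_id="demo-session")
        response = await ai.chat("介绍一下你自己")
        print(response.text)
        await ai.clear_history()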
"""
def __init__(
self,
session_id: str | None = None,
config: AIConfig | None = None,
memory: BaseMemory | None = None,
default_generation_config: LLMGenerationConfig | None = None,
processors: list[MemoryProcessor] | None = None,
):
"""
        Initialize the AI service.

        Args:
            session_id: Unique session ID used to isolate memory.
            config: AI configuration.
            memory: Optional custom memory backend. If None, the default
                ChatMemory (InMemoryMessageStore) is used.
            default_generation_config: Default generation config for this
                AI instance.
            processors: Memory processors triggered after memories are added.
"""
self.session_id = session_id or str(uuid.uuid4())
self.config = config or AIConfig()
self.memory = memory or _get_default_memory()
self.default_generation_config = (
default_generation_config or LLMGenerationConfig()
)
self.processors = processors or []
global_providers = tool_provider_manager._providers
config_providers = self.config.tool_providers
self._tool_providers = list(dict.fromkeys(global_providers + config_providers))
self.message_buffer: list[LLMMessage] = []
async def clear_history(self):
"""清空当前会话的历史记录。"""
await self.memory.clear_history(self.session_id)
logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})")
async def add_observation(
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
):
"""
        Append an observation message to the buffer without triggering an
        immediate model call.

        Returns:
            int: The number of messages currently in the buffer.
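
        Example (illustrative; buffered messages are flushed by a later
        chat(..., use_buffer=True) call):

            await ai.add_observation("用户发送了一张截图")
            await ai.add_observation("用户补充了文字描述")
            response = await ai.chat("请总结以上内容", use_buffer=True)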
"""
current_message = await self._normalize_input_to_message(message)
self.message_buffer.append(current_message)
content_preview = str(current_message.content)[:50]
logger.debug(
f"[放入观察] {content_preview} (缓冲区大小: {len(self.message_buffer)})",
"AI_MEMORY",
)
return len(self.message_buffer)
async def add_user_message_to_history(
self, message: str | LLMMessage | list[LLMContentPart]
):
"""
        Normalize a user message and append it to the session history.

        Args:
            message: The user message content.
"""
user_message = await self._normalize_input_to_message(message)
await self.memory.add_message(self.session_id, user_message)
async def add_assistant_response_to_history(self, response_text: str):
"""
        Append an assistant text reply to the session history.

        Args:
            response_text: The assistant's reply text.
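
        Example (manually recording an exchange that happened outside chat()):

            await ai.add_user_message_to_history("今天天气怎么样?")
            await ai.add_assistant_response_to_history("今天晴,大约 25 度。")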
"""
assistant_message = LLMMessage.assistant_text_response(response_text)
await self.memory.add_message(self.session_id, assistant_message)
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
        Sanitize a message before storing it in history.
        Non-text multimodal content parts are replaced with a text
        placeholder so they are not processed again on later turns.
"""
if not isinstance(message.content, list):
return message
sanitized_message = copy.deepcopy(message)
content_list = sanitized_message.content
if not isinstance(content_list, list):
return sanitized_message
new_content_parts: list[LLMContentPart] = []
has_multimodal_content = False
for part in content_list:
if isinstance(part, LLMContentPart) and part.type == "text":
new_content_parts.append(part)
else:
has_multimodal_content = True
if has_multimodal_content:
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
text_part_found = False
for part in new_content_parts:
if part.type == "text":
part.text = f"{placeholder} {part.text or ''}".strip()
text_part_found = True
break
if not text_part_found:
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
sanitized_message.content = new_content_parts
return sanitized_message
async def _normalize_input_to_message(
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
) -> LLMMessage:
"""
        Internal helper that converts the supported input types into a single
        LLMMessage. It delegates to the shared utility function and extracts
        the last message (usually the user input).
"""
messages = await normalize_to_llm_messages(message)
if not messages:
raise LLMException(
"无法将输入标准化为有效的消息。", code=LLMErrorCode.CONFIGURATION_ERROR
)
return messages[-1]
async def generate_internal(
self,
messages: list[LLMMessage],
*,
model: ModelName = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
tools: list[Any] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
timeout: float | None = None,
model_instance: Any = None,
) -> LLMResponse:
"""
        Core generation method: merges configs, resolves tools and invokes
        the model. It does not store anything in history; it is meant to be
        called by AgentExecutor or the chat() method.
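
        Example (a stateless one-off call that bypasses session memory;
        assumes a configured default model):

            messages = [LLMMessage.user("给这段文字取一个标题")]
            response = await ai.generate_internal(messages)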
"""
final_config = self.default_generation_config
if isinstance(config, GenConfigBuilder):
config = config.build()
if config:
final_config = final_config.merge_with(config)
final_tools_list = []
if tools:
if isinstance(tools, dict):
final_tools_list = list(tools.values())
elif isinstance(tools, list):
to_resolve: list[Any] = []
for t in tools:
if isinstance(t, str | dict):
to_resolve.append(t)
else:
final_tools_list.append(t)
if to_resolve:
resolved_dict = await self._resolve_tools(to_resolve)
final_tools_list.extend(resolved_dict.values())
if model_instance:
return await model_instance.generate_response(
messages,
config=final_config,
tools=final_tools_list if final_tools_list else None,
tool_choice=tool_choice,
timeout=timeout,
)
resolved_model_name = self._resolve_model_name(model or self.config.model)
async with await get_model_instance(
resolved_model_name,
override_config=None,
) as instance:
return await instance.generate_response(
messages,
config=final_config,
tools=final_tools_list if final_tools_list else None,
tool_choice=tool_choice,
timeout=timeout,
)
async def chat(
self,
message: str | UniMessage | LLMMessage | list[LLMContentPart] | None,
*,
model: ModelName = None,
instruction: str | None = None,
template_vars: dict[str, Any] | None = None,
preserve_media_in_history: bool | None = None,
tools: list[Any] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
use_buffer: bool = False,
timeout: float | None = None,
) -> LLMResponse:
"""
        Core interaction method: manages session history and performs a
        single LLM call.

        Args:
            message: User input; supports text, UniMessage, LLMMessage or a
                list of content parts.
            model: Model name to use; falls back to the configured default
                when None.
            instruction: Call-specific system instruction, merged with the
                global instruction.
            template_vars: Template variables substituted into the
                instruction.
            preserve_media_in_history: Whether to keep media content in the
                history; uses the default config when None.
            tools: Available tools as a list or dict; supports ad-hoc and
                pre-configured tools.
            tool_choice: Tool-selection strategy controlling how the AI
                picks and uses tools.
            config: Generation config overriding the default parameters.
            use_buffer: Whether to flush the message buffer and submit its
                contents together in this call.
            timeout: HTTP request timeout in seconds.

        Returns:
            LLMResponse: The full response object containing the AI reply,
                tool call requests, usage info, etc.
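
        Example (a sketch; assumes a configured default model):

            response = await ai.chat(
                "翻译成英文:你好,世界",
                instruction="你是一名专业翻译。",
            )
            if response.tool_calls:
                ...  # tool calls are returned to the caller, not auto-executed
            else:
                print(response.text)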
"""
messages_to_add: list[LLMMessage] = []
if message:
current_message = await self._normalize_input_to_message(message)
messages_to_add.append(current_message)
if use_buffer and self.message_buffer:
messages_to_add = self.message_buffer + messages_to_add
self.message_buffer.clear()
messages_for_run = []
final_instruction = instruction
if final_instruction and template_vars:
try:
template = Template(final_instruction)
final_instruction = template.render(**template_vars)
logger.debug(f"渲染后的系统指令: {final_instruction}")
except Exception as e:
logger.error(f"渲染系统指令模板失败: {e}", e=e)
if final_instruction:
messages_for_run.append(LLMMessage.system(final_instruction))
current_history = await self.memory.get_history(self.session_id)
messages_for_run.extend(current_history)
messages_for_run.extend(messages_to_add)
try:
response = await self.generate_internal(
messages_for_run,
model=model,
config=config,
tools=tools,
tool_choice=tool_choice,
timeout=timeout,
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
msgs_to_store: list[LLMMessage] = []
for msg in messages_to_add:
store_msg = (
msg if should_preserve else self._sanitize_message_for_history(msg)
)
msgs_to_store.append(store_msg)
if response.content_parts:
assistant_response_msg = LLMMessage(
role="assistant",
content=response.content_parts,
tool_calls=response.tool_calls,
)
else:
assistant_response_msg = LLMMessage.assistant_text_response(
response.text
)
if response.tool_calls:
assistant_response_msg = LLMMessage.assistant_tool_calls(
response.tool_calls, response.text
)
await self.memory.add_messages(
self.session_id, [*msgs_to_store, assistant_response_msg]
)
if self.processors:
for processor in self.processors:
await processor.process(
self.session_id, [*msgs_to_store, assistant_response_msg]
)
return response
        except Exception as e:
            if isinstance(e, LLMException):
                raise
            raise LLMException(f"聊天执行失败: {e}", cause=e)
async def code(
self,
prompt: str,
*,
model: ModelName = None,
timeout: int | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
        Code execution.

        Args:
            prompt: Prompt for the code-execution task.
            model: Model name to use.
            timeout: Code execution timeout in seconds.
            config: Optional override for the default generation config.

        Returns:
            LLMResponse: The full response object with the execution result.
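
        Example (a sketch; the Gemini code-execution override is applied
        automatically):

            response = await ai.code("用 Python 计算 2 的 100 次方")
            print(response.text)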
"""
resolved_model = model or self.config.model
code_config = CommonOverrides.gemini_code_execution()
if timeout:
code_config.custom_params = code_config.custom_params or {}
code_config.custom_params["code_execution_timeout"] = timeout
if isinstance(config, GenConfigBuilder):
config = config.build()
if config:
code_config = code_config.merge_with(config)
return await self.chat(prompt, model=resolved_model, config=code_config)
async def search(
self,
query: UniMessage,
*,
model: ModelName = None,
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
template_vars: dict[str, Any] | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
) -> LLMResponse:
"""
        Convenience entry point for information search, with native support
        for multimodal queries.
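
        Example (a sketch; the query must be a UniMessage, and grounding is
        enabled via the Gemini Google Search tool):

            from nonebot_plugin_alconna.uniseg import UniMessage

            response = await ai.search(UniMessage("2024 年诺贝尔物理学奖得主是谁?"))
            print(response.text)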
"""
logger.info("执行 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
if isinstance(config, GenConfigBuilder):
config = config.build()
if config:
search_config = search_config.merge_with(config)
return await self.chat(
query,
model=model,
instruction=instruction,
template_vars=template_vars,
config=search_config,
tools=[GeminiGoogleSearch()],
)
async def generate_structured(
self,
message: str | UniMessage | LLMMessage | list[LLMContentPart] | None,
response_model: type[T],
*,
model: ModelName = None,
tools: list[Any] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | ToolChoice | None = None,
instruction: str | None = None,
timeout: float | None = None,
template_vars: dict[str, Any] | None = None,
config: LLMGenerationConfig | GenConfigBuilder | None = None,
max_validation_retries: int | None = None,
validation_callback: Callable[[T], Any | Awaitable[Any]] | None = None,
error_prompt_template: str | None = None,
auto_thinking: bool = False,
) -> T:
"""
        Generate a structured response and parse it into the given Pydantic
        model.

        Args:
            message: User input in any supported format; when None, only the
                history and buffer are used.
            response_model: Pydantic model class used to parse and validate
                the response.
            model: Model name to use; falls back to the configured default
                when None.
            tools: Available tools; ignored under the NATIVE output strategy.
            tool_choice: Tool-selection strategy.
            instruction: Call-specific system instruction, merged with the
                JSON Schema instruction.
            timeout: HTTP request timeout in seconds.
            template_vars: Template variables rendered into the instruction.
            config: Generation config overriding the default parameters.
            max_validation_retries: Number of IVR repair rounds after a
                failed validation; defaults to the global client setting.
            validation_callback: Optional (async) callback for extra business
                validation; raising inside it triggers another IVR round.
            error_prompt_template: Template of the repair prompt sent back to
                the model; defaults to DEFAULT_IVR_TEMPLATE.
            auto_thinking: Whether to wrap the model with an AutoCoT
                reasoning field.

        Returns:
            T: The parsed Pydantic model instance of the type given by
                response_model.

        Raises:
            LLMException: If the model does not return valid JSON or
                validation fails.
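
        Example (a sketch; UserProfile is a hypothetical model):

            class UserProfile(BaseModel):
                name: str
                age: int

            profile = await ai.generate_structured(
                "我叫张三,今年 28 岁", UserProfile
            )
            print(profile.name, profile.age)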
"""
if isinstance(config, GenConfigBuilder):
config = config.build()
final_config = self.default_generation_config.merge_with(config)
if final_config is None:
final_config = LLMGenerationConfig()
if max_validation_retries is None:
max_validation_retries = get_llm_config().client_settings.structured_retries
resolved_model_name = self._resolve_model_name(model or self.config.model)
        # Note: this always evaluates to True; whether AutoCoT actually
        # applies is decided by should_apply_autocot() below.
        request_autocot = True if auto_thinking is False else auto_thinking
effective_auto_thinking = should_apply_autocot(
request_autocot, resolved_model_name, final_config
)
target_model: type[T] = response_model
if effective_auto_thinking:
target_model = cast(type[T], create_cot_wrapper(response_model))
response_model = target_model
cot_instruction = (
"请务必先在 `reasoning` 字段中进行详细的一步步推理,确保逻辑正确,"
"然后再填充 `result` 字段。"
)
if instruction:
instruction = f"{instruction}\n\n{cot_instruction}"
else:
instruction = cot_instruction
final_instruction = instruction
if final_instruction and template_vars:
try:
template = Template(final_instruction)
final_instruction = template.render(**template_vars)
except Exception as e:
logger.error(f"渲染结构化指令模板失败: {e}", e=e)
try:
json_schema = model_json_schema(response_model)
except AttributeError:
json_schema = response_model.schema()
schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
prompt_prefix = f"{final_instruction}\n\n" if final_instruction else ""
        structured_strategy = (
            final_config.output.structured_output_strategy
            if final_config.output
            else StructuredOutputStrategy.NATIVE
        )
        if structured_strategy == StructuredOutputStrategy.TOOL_CALL:
            system_prompt = prompt_prefix + "请调用提供的工具提交结构化数据。"
        else:
            system_prompt = (
                prompt_prefix
                + "请严格按照以下 JSON Schema 格式进行响应。不应包含任何额外的解释、"
                "注释或代码块标记,只返回一个合法的 JSON 对象。\n\n"
            )
            system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
final_tools_list: list[ToolExecutable] | None = None
if structured_strategy != StructuredOutputStrategy.NATIVE:
if tools:
final_tools_list = []
if isinstance(tools, dict):
final_tools_list = list(tools.values())
elif isinstance(tools, list):
to_resolve: list[Any] = []
for t in tools:
if isinstance(t, str | dict):
to_resolve.append(t)
else:
final_tools_list.append(t)
if to_resolve:
resolved_dict = await self._resolve_tools(to_resolve)
final_tools_list.extend(resolved_dict.values())
elif tools:
logger.warning(
"检测到在 generate_structured (NATIVE 策略) 中传入了 tools。"
"为了避免 API 冲突(Gemini)及输出歧义(OpenAI),这些"
"tools 将被本次请求忽略。"
"若需使用工具,请使用 chat() 方法或 Agent 流程。"
)
if final_config.output is None:
final_config.output = OutputConfig()
final_config.output.response_format = ResponseFormat.JSON
final_config.output.response_schema = json_schema
messages_for_run = [LLMMessage.system(system_prompt)]
current_history = await self.memory.get_history(self.session_id)
messages_for_run.extend(current_history)
messages_for_run.extend(self.message_buffer)
if message:
normalized_message = await self._normalize_input_to_message(message)
messages_for_run.append(normalized_message)
ivr_messages = list(messages_for_run)
last_exception: Exception | None = None
for attempt in range(max_validation_retries + 1):
current_response_text: str = ""
async with await get_model_instance(
resolved_model_name,
override_config=None,
) as model_instance:
response = await model_instance.generate_response(
ivr_messages,
config=final_config,
tools=final_tools_list if final_tools_list else None,
tool_choice=tool_choice,
timeout=timeout,
)
current_response_text = response.text
try:
parsed_obj = parse_and_validate_json(response.text, target_model)
final_obj: T = cast(T, parsed_obj)
if effective_auto_thinking:
logger.debug(
f"AutoCoT 思考过程: {getattr(parsed_obj, 'reasoning', '')}"
)
final_obj = cast(T, getattr(parsed_obj, "result"))
if validation_callback:
if is_coroutine_callable(validation_callback):
await validation_callback(final_obj)
else:
validation_callback(final_obj)
return final_obj
except Exception as e:
is_llm_error = isinstance(e, LLMException)
llm_error: LLMException | None = (
cast(LLMException, e) if is_llm_error else None
)
last_exception = e
if attempt < max_validation_retries:
error_msg = (
llm_error.details.get("validation_error", str(e))
if llm_error
else str(e)
)
raw_response = current_response_text or (
llm_error.details.get("raw_response", "") if llm_error else ""
)
logger.warning(
f"结构化校验失败 (尝试 {attempt + 1}/"
f"{max_validation_retries + 1})。正在尝试 IVR 修复... 错误:"
f"{error_msg}"
)
if raw_response:
ivr_messages.append(
LLMMessage.assistant_text_response(raw_response)
)
else:
logger.warning(
"IVR 警告: 无法获取上一轮生成的原始文本,"
"模型将在无上下文情况下尝试修复。"
)
template = error_prompt_template or DEFAULT_IVR_TEMPLATE
feedback_prompt = template.format(error_msg=error_msg)
ivr_messages.append(LLMMessage.user(feedback_prompt))
continue
if llm_error and not llm_error.recoverable:
raise llm_error
if last_exception:
raise last_exception
raise LLMException(
"IVR 循环异常结束,未能生成有效结果。", code=LLMErrorCode.GENERATION_FAILED
)
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
if model_name:
return model_name
default_model = get_global_default_model_name()
if default_model:
return default_model
raise LLMException(
"未指定模型名称且未设置全局默认模型",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
async def embed(
self,
texts: list[str] | str,
*,
model: ModelName = None,
config: LLMEmbeddingConfig | None = None,
) -> list[list[float]]:
"""
        Generate text embeddings, converting text into numeric vector
        representations.

        Args:
            texts: Text to embed; a single string or a list of strings.
            model: Embedding model name; falls back to the configured
                default embedding model when None.
            config: Embedding configuration.

        Returns:
            list[list[float]]: One embedding vector (a list of floats) per
                input text.

        Raises:
            LLMException: If embedding generation fails or the model is
                misconfigured.
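
        Example (a sketch; assumes an embedding model is configured via
        AIConfig.default_embedding_model or passed explicitly):

            vectors = await ai.embed(["你好", "世界"])
            assert len(vectors) == 2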
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
try:
resolved_model_str = (
model or self.config.default_embedding_model or self.config.model
)
if not resolved_model_str:
raise LLMException(
"使用 embed 方法时未指定嵌入模型名称,"
"且 AIConfig 未设置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
final_config = config or LLMEmbeddingConfig()
async with await get_model_instance(
resolved_model_str,
override_config=None,
) as embedding_model_instance:
return await embedding_model_instance.generate_embeddings(
texts, config=final_config
)
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
async def _resolve_tools(
self,
tool_configs: list[Any],
) -> dict[str, ToolExecutable]:
"""
        Asynchronously resolve ad-hoc (temporary) tool configs via the
        injected ToolProviders.
        Returns a dict mapping tool names to executable objects.
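
        Example (illustrative config values; "my_tool" is a hypothetical
        provider-registered tool):

            resolved = await self._resolve_tools(
                ["google_search", {"name": "my_tool", "type": "function"}]
            )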
"""
resolved: dict[str, ToolExecutable] = {}
for config in tool_configs:
if isinstance(config, str):
if config == "google_search":
resolved[config] = GeminiGoogleSearch() # type: ignore[arg-type]
continue
elif config == "code_execution":
resolved[config] = GeminiCodeExecution() # type: ignore[arg-type]
continue
                elif config == "url_context":
                    # No builtin shortcut for url_context; it falls through
                    # to the provider-based resolution below.
                    pass
name = config if isinstance(config, str) else config.get("name")
if not name:
raise LLMException(
"工具配置字典必须包含 'name' 字段。",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
if isinstance(config, str):
config_dict = {"name": name, "type": "function"}
elif isinstance(config, dict):
config_dict = config
else:
raise TypeError(f"不支持的工具配置类型: {type(config)}")
executable = None
for provider in self._tool_providers:
executable = await provider.get_tool_executable(name, config_dict)
if executable:
break
if not executable:
raise LLMException(
f"没有为 ad-hoc 工具 '{name}' 找到合适的提供者。",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
resolved[name] = executable
return resolved