# zhenxun_bot/zhenxun/services/llm/session.py

"""
LLM 服务 - 会话客户端
提供一个有状态的面向会话的 LLM 客户端用于进行多轮对话和复杂交互
"""

import copy
from dataclasses import dataclass
from typing import Any

from nonebot_plugin_alconna.uniseg import UniMessage

from zhenxun.services.log import logger

from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .types import (
    EmbeddingTaskType,
    LLMContentPart,
    LLMErrorCode,
    LLMException,
    LLMMessage,
    LLMResponse,
    LLMTool,
    ModelName,
)
from .utils import unimsg_to_llm_parts


@dataclass
class AIConfig:
    """AI configuration class - simplified version."""

    model: ModelName = None
    default_embedding_model: ModelName = None
    temperature: float | None = None
    max_tokens: int | None = None
    enable_cache: bool = False
    enable_code: bool = False
    enable_search: bool = False
    timeout: int | None = None
    enable_gemini_json_mode: bool = False
    enable_gemini_thinking: bool = False
    enable_gemini_safe_mode: bool = False
    enable_gemini_multimodal: bool = False
    enable_gemini_grounding: bool = False
    default_preserve_media_in_history: bool = False

    def __post_init__(self):
        """Read default values from the AI config after initialization."""
        ai_config = get_ai_config()
        if self.model is None:
            self.model = ai_config.get("default_model_name")
        if self.timeout is None:
            self.timeout = ai_config.get("timeout", 180)
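
# Usage sketch (illustrative): unset fields fall back to the global AI config
# at construction time, so a partially specified AIConfig is enough:
#     cfg = AIConfig(temperature=0.7)  # model/timeout resolved via get_ai_config()
#     ai = AI(config=cfg)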


class AI:
    """Unified AI service class - balanced design version.

    Provides three API tiers:
    1. Simple methods: ai.chat(), ai.code(), ai.search()
    2. Standard method: ai.analyze(), which supports complex parameters
    3. Advanced: direct access via get_model_instance()
    """

    def __init__(
        self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
    ):
        """
        Initialize the AI service.

        Args:
            config: AI configuration.
            history: Optional initial conversation history.
        """
        self.config = config or AIConfig()
        self.history = history or []

    def clear_history(self):
        """Clear the current session's history."""
        self.history = []
        logger.info("AI session history cleared.")

    def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
        """
        Sanitize a message before storing it in the history.

        Replaces non-text multimodal content parts with a text placeholder
        to avoid re-processing them on later turns.
        """
        if not isinstance(message.content, list):
            return message

        sanitized_message = copy.deepcopy(message)
        content_list = sanitized_message.content
        if not isinstance(content_list, list):
            return sanitized_message

        new_content_parts: list[LLMContentPart] = []
        has_multimodal_content = False
        for part in content_list:
            if isinstance(part, LLMContentPart) and part.type == "text":
                new_content_parts.append(part)
            else:
                has_multimodal_content = True

        if has_multimodal_content:
            placeholder = (
                "[The user sent media files; the content was processed"
                " during the initial analysis]"
            )
            text_part_found = False
            for part in new_content_parts:
                if part.type == "text":
                    part.text = f"{placeholder} {part.text or ''}".strip()
                    text_part_found = True
                    break
            if not text_part_found:
                new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
            sanitized_message.content = new_content_parts

        return sanitized_message
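
    # Illustrative effect: a user turn of [image_part, text_part("describe this")]
    # is stored as a single text part, "[The user sent media files; ...] describe
    # this", so subsequent turns do not resend the raw media.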

    async def chat(
        self,
        message: str | LLMMessage | list[LLMContentPart],
        *,
        model: ModelName = None,
        preserve_media_in_history: bool | None = None,
        tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Run one round of chat, with tool-calling support.

        This method automatically uses and updates the session history.

        Args:
            message: The user input message.
            model: The model to use for this call.
            preserve_media_in_history: Whether to keep the original multimodal
                content in the history.
                - True: keep it, for deep multi-turn media analysis.
                - False: replace it with a placeholder for efficiency.
                - None (default): use the AI instance's configured default.
            tools: Tools available for this call.
            tool_choice: Force the model to use a specific tool.
            **kwargs: Additional generation parameters passed to the model.

        Returns:
            LLMResponse: The model's full response, which may contain text or
                tool-call requests.
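
        Example (illustrative sketch; assumes an async context and a
        configured default model):
            ai = AI()
            resp = await ai.chat("Summarize our conversation so far.")
            print(resp.text)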
"""
current_message: LLMMessage
if isinstance(message, str):
current_message = LLMMessage.user(message)
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
current_message = LLMMessage.user(message)
elif isinstance(message, LLMMessage):
current_message = message
else:
raise LLMException(
f"AI.chat 不支持的消息类型: {type(message)}. "
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
code=LLMErrorCode.API_REQUEST_FAILED,
)
final_messages = [*self.history, current_message]
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="聊天失败",
config_overrides=kwargs,
llm_tools=tools,
tool_choice=tool_choice,
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(
LLMMessage(
role="assistant", content=response.text, tool_calls=response.tool_calls
)
)
return response

    async def code(
        self,
        prompt: str,
        *,
        model: ModelName = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Execute code via the model's code-execution capability.

        Args:
            prompt: The prompt describing the code to execute.
            model: The model name to use.
            timeout: Code-execution timeout.
            **kwargs: Additional parameters passed to the model.

        Returns:
            dict[str, Any]: The execution result, with "text",
                "code_executions", and "success" fields.
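
        Example (illustrative sketch; assumes code execution is available
        for the resolved Gemini model):
            result = await ai.code("Compute the 20th Fibonacci number in Python.")
            print(result["text"], result["code_executions"])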
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
messages = [LLMMessage.user(prompt)]
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="代码执行失败",
config_overrides=kwargs,
base_config=config,
)
return {
"text": response.text,
"code_executions": response.code_executions or [],
"success": True,
}

    async def search(
        self,
        query: str | UniMessage,
        *,
        model: ModelName = None,
        instruction: str = "",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Information search - supports multimodal input.

        Args:
            query: The search query, as text or a multimodal message.
            model: The model name to use.
            instruction: The search instruction.
            **kwargs: Additional parameters passed to the model.

        Returns:
            dict[str, Any]: The search result, with "text", "sources",
                "queries", and "success" fields.
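
        Example (illustrative sketch; relies on Gemini grounding):
            result = await ai.search("What is the latest Python release?")
            print(result["text"], result["sources"])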
"""
from nonebot_plugin_alconna.uniseg import UniMessage
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
if isinstance(query, str):
messages = [LLMMessage.user(query)]
elif isinstance(query, UniMessage):
content_parts = await unimsg_to_llm_parts(query)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
messages = final_messages
else:
raise LLMException(
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="信息搜索失败",
config_overrides=kwargs,
base_config=config,
)
result = {
"text": response.text,
"sources": [],
"queries": [],
"success": True,
}
if response.grounding_metadata:
result["sources"] = response.grounding_metadata.grounding_attributions or []
result["queries"] = response.grounding_metadata.web_search_queries or []
return result

    async def analyze(
        self,
        message: UniMessage | None,
        *,
        instruction: str = "",
        model: ModelName = None,
        use_tools: list[str] | None = None,
        tool_config: dict[str, Any] | None = None,
        activated_tools: list[LLMTool] | None = None,
        history: list[LLMMessage] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Content analysis - takes a UniMessage object for multimodal analysis
        and tool calling.

        Args:
            message: The message content to analyze; supports multimodal input.
            instruction: The analysis instruction.
            model: The model name to use.
            use_tools: Names of tools to use.
            tool_config: Tool configuration.
            activated_tools: Already-activated tools.
            history: Conversation history.
            **kwargs: Additional parameters passed to the model.

        Returns:
            LLMResponse: The model's full response.
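
        Example (illustrative sketch; "weather" is a hypothetical registered
        tool name):
            resp = await ai.analyze(
                UniMessage("How is the weather in Beijing?"),
                use_tools=["weather"],
            )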
"""
from nonebot_plugin_alconna.uniseg import UniMessage
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="内容分析失败",
config_overrides=kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
return response

    async def _execute_generation(
        self,
        messages: list[LLMMessage],
        model_name: ModelName,
        error_message: str,
        config_overrides: dict[str, Any],
        llm_tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        base_config: LLMGenerationConfig | None = None,
    ) -> LLMResponse:
        """Shared generation helper: resolves the model and performs a single API call."""
        try:
            resolved_model_name = self._resolve_model_name(
                model_name or self.config.model
            )
            final_config_dict = self._merge_config(
                config_overrides, base_config=base_config
            )
            async with await get_model_instance(
                resolved_model_name, override_config=final_config_dict
            ) as model_instance:
                return await model_instance.generate_response(
                    messages,
                    tools=llm_tools,
                    tool_choice=tool_choice,
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"{error_message}: {e}", e=e)
            raise LLMException(f"{error_message}: {e}", cause=e)

    def _resolve_model_name(self, model_name: ModelName) -> str:
        """Resolve the model name, falling back to the global default."""
        if model_name:
            return model_name
        default_model = get_global_default_model_name()
        if default_model:
            return default_model
        raise LLMException(
            "No model name specified and no global default model configured.",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    def _merge_config(
        self,
        user_config: dict[str, Any],
        base_config: LLMGenerationConfig | None = None,
    ) -> dict[str, Any]:
        """Merge generation configs: base preset, then AIConfig fields, then user overrides."""
        final_config = {}
        if base_config:
            final_config.update(base_config.to_dict())
        if self.config.temperature is not None:
            final_config["temperature"] = self.config.temperature
        if self.config.max_tokens is not None:
            final_config["max_tokens"] = self.config.max_tokens
        if self.config.enable_cache:
            final_config["enable_caching"] = True
        if self.config.enable_code:
            final_config["enable_code_execution"] = True
        if self.config.enable_search:
            final_config["enable_grounding"] = True
        if self.config.enable_gemini_json_mode:
            final_config["response_mime_type"] = "application/json"
        if self.config.enable_gemini_thinking:
            final_config["thinking_budget"] = 0.8
        if self.config.enable_gemini_safe_mode:
            final_config["safety_settings"] = (
                CommonOverrides.gemini_safe().safety_settings
            )
        if self.config.enable_gemini_multimodal:
            final_config.update(CommonOverrides.gemini_multimodal().to_dict())
        if self.config.enable_gemini_grounding:
            final_config["enable_grounding"] = True
        final_config.update(user_config)
        return final_config
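
    # Precedence sketch (illustrative): later dict updates win, so explicit
    # call kwargs override AIConfig-derived flags, which override the preset:
    #     merged = ai._merge_config({"temperature": 0.2}, base_config=preset)
    #     # preset values -> AIConfig fields -> {"temperature": 0.2} on top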

    async def embed(
        self,
        texts: list[str] | str,
        *,
        model: ModelName = None,
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """
        Generate text embedding vectors.

        Args:
            texts: The text or list of texts to embed.
            model: The embedding model name to use.
            task_type: The embedding task type.
            **kwargs: Additional parameters passed to the model.

        Returns:
            list[list[float]]: The embedding vectors for the texts.
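
        Example (illustrative sketch; assumes default_embedding_model is set
        in AIConfig):
            vectors = await ai.embed(["hello", "world"])
            assert len(vectors) == 2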
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
try:
resolved_model_str = (
model or self.config.default_embedding_model or self.config.model
)
if not resolved_model_str:
raise LLMException(
"使用 embed 功能时必须指定嵌入模型名称,"
"或在 AIConfig 中配置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
async with await get_model_instance(
resolved_model_str,
override_config=None,
) as embedding_model_instance:
return await embedding_model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)