mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
476 lines
16 KiB
Python
"""
|
||
LLM 服务的高级 API 接口
|
||
"""
|
||
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||
|
||
from zhenxun.services.log import logger
|
||
|
||
from .config import CommonOverrides, LLMGenerationConfig
|
||
from .config.providers import get_ai_config
|
||
from .manager import get_global_default_model_name, get_model_instance
|
||
from .types import (
|
||
EmbeddingTaskType,
|
||
LLMContentPart,
|
||
LLMErrorCode,
|
||
LLMException,
|
||
LLMMessage,
|
||
LLMResponse,
|
||
LLMTool,
|
||
ModelName,
|
||
)
|
||
from .utils import create_multimodal_message, unimsg_to_llm_parts
|
||
|
||
|
||
class TaskType(Enum):
    """Task type enumeration"""

    CHAT = "chat"
    CODE = "code"
    SEARCH = "search"
    ANALYSIS = "analysis"
    GENERATION = "generation"
    MULTIMODAL = "multimodal"

@dataclass
class AIConfig:
    """AI configuration - simplified version"""

    model: ModelName = None
    default_embedding_model: ModelName = None
    temperature: float | None = None
    max_tokens: int | None = None
    enable_cache: bool = False
    enable_code: bool = False
    enable_search: bool = False
    timeout: int | None = None

    enable_gemini_json_mode: bool = False
    enable_gemini_thinking: bool = False
    enable_gemini_safe_mode: bool = False
    enable_gemini_multimodal: bool = False
    enable_gemini_grounding: bool = False

    def __post_init__(self):
        """Fill unset fields with defaults from the AI configuration"""
        ai_config = get_ai_config()
        if self.model is None:
            self.model = ai_config.get("default_model_name")
        if self.timeout is None:
            self.timeout = ai_config.get("timeout", 180)

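# Example (hypothetical values): fields left unset fall back to the defaults
# read from get_ai_config() in __post_init__.
#
#     config = AIConfig(model="Gemini/gemini-2.0-flash", temperature=0.7)
#     ai_config_timeout = config.timeout  # 180 unless configured otherwise
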
class AI:
    """Unified AI service class - balanced design

    Provides a three-tier API:
    1. Simple methods: ai.chat(), ai.code(), ai.search()
    2. Standard method: ai.analyze(), which supports complex parameters
    3. Advanced access: use get_model_instance() directly
    """

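    # Example (a minimal sketch, assuming a global default model is configured):
    #
    #     ai = AI()
    #     reply = await ai.chat("hello")                      # tier 1
    #     answer = await ai.analyze(msg, instruction="...")   # tier 2
    #     async with await get_model_instance("Gemini/gemini-2.0-flash") as m:
    #         ...                                             # tier 3
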
    def __init__(self, config: AIConfig | None = None):
        """Initialize the AI service"""
        self.config = config or AIConfig()

    async def chat(
        self,
        message: str | LLMMessage | list[LLMContentPart],
        *,
        model: ModelName = None,
        **kwargs: Any,
    ) -> str:
        """Chat - supports simple multimodal input"""
        llm_messages: list[LLMMessage]

        if isinstance(message, str):
            llm_messages = [LLMMessage.user(message)]
        elif isinstance(message, list) and all(
            isinstance(part, LLMContentPart) for part in message
        ):
            llm_messages = [LLMMessage.user(message)]
        elif isinstance(message, LLMMessage):
            llm_messages = [message]
        else:
            raise LLMException(
                f"Unsupported message type for AI.chat: {type(message)}. "
                "Use str, LLMMessage, or list[LLMContentPart]. "
                "For more complex multimodal input or file paths, use AI.analyze().",
                code=LLMErrorCode.API_REQUEST_FAILED,
            )

        response = await self._execute_generation(
            llm_messages, model, "Chat failed", kwargs
        )
        return response.text

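    # Example: a plain string becomes a single user message; an LLMMessage or
    # a list of LLMContentPart objects is forwarded as-is.
    #
    #     reply = await AI().chat("Summarize this paragraph: ...")
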
    async def code(
        self,
        prompt: str,
        *,
        model: ModelName = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Code execution"""
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"

        config = CommonOverrides.gemini_code_execution()
        if timeout:
            config.custom_params = config.custom_params or {}
            config.custom_params["code_execution_timeout"] = timeout

        messages = [LLMMessage.user(prompt)]

        response = await self._execute_generation(
            messages,
            resolved_model,
            "Code execution failed",
            kwargs,
            base_config=config,
        )

        return {
            "text": response.text,
            "code_executions": response.code_executions or [],
            "success": True,
        }

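    # Example (hypothetical prompt): the result dict carries the model's prose
    # plus any code-execution records the backend reports.
    #
    #     result = await AI().code("Compute the first 10 primes")
    #     print(result["text"], result["code_executions"])
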
    async def search(
        self,
        query: str | UniMessage,
        *,
        model: ModelName = None,
        instruction: str = "",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Information search - supports multimodal input"""
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
        config = CommonOverrides.gemini_grounding()

        if isinstance(query, str):
            messages = [LLMMessage.user(query)]
        elif isinstance(query, UniMessage):
            content_parts = await unimsg_to_llm_parts(query)

            final_messages: list[LLMMessage] = []
            if instruction:
                final_messages.append(LLMMessage.system(instruction))

            if not content_parts:
                if instruction:
                    final_messages.append(LLMMessage.user(instruction))
                else:
                    raise LLMException(
                        "Search content is empty or could not be processed.",
                        code=LLMErrorCode.API_REQUEST_FAILED,
                    )
            else:
                final_messages.append(LLMMessage.user(content_parts))

            messages = final_messages
        else:
            raise LLMException(
                f"Unsupported search input type: {type(query)}. "
                "Use str or UniMessage.",
                code=LLMErrorCode.API_REQUEST_FAILED,
            )

        response = await self._execute_generation(
            messages,
            resolved_model,
            "Information search failed",
            kwargs,
            base_config=config,
        )

        result = {
            "text": response.text,
            "sources": [],
            "queries": [],
            "success": True,
        }

        if response.grounding_metadata:
            result["sources"] = response.grounding_metadata.grounding_attributions or []
            result["queries"] = response.grounding_metadata.web_search_queries or []

        return result

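    # Example: grounded search; "sources" and "queries" stay empty unless the
    # backend returns grounding metadata.
    #
    #     result = await AI().search("What's new in NoneBot2?")
    #     print(result["text"], result["sources"])
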
    async def analyze(
        self,
        message: UniMessage,
        *,
        instruction: str = "",
        model: ModelName = None,
        tools: list[dict[str, Any]] | None = None,
        tool_config: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> str | LLMResponse:
        """
        Content analysis - accepts a UniMessage for multimodal analysis and
        tool calling. This is the primary method for complex interactions.
        """
        content_parts = await unimsg_to_llm_parts(message)

        final_messages: list[LLMMessage] = []
        if instruction:
            final_messages.append(LLMMessage.system(instruction))

        if not content_parts:
            if instruction:
                final_messages.append(LLMMessage.user(instruction))
            else:
                raise LLMException(
                    "Analysis content is empty or could not be processed.",
                    code=LLMErrorCode.API_REQUEST_FAILED,
                )
        else:
            final_messages.append(LLMMessage.user(content_parts))

        llm_tools = None
        if tools:
            llm_tools = []
            for tool_dict in tools:
                if isinstance(tool_dict, dict):
                    if "name" in tool_dict and "description" in tool_dict:
                        llm_tool = LLMTool(
                            type="function",
                            function={
                                "name": tool_dict["name"],
                                "description": tool_dict["description"],
                                "parameters": tool_dict.get("parameters", {}),
                            },
                        )
                        llm_tools.append(llm_tool)
                    else:
                        llm_tools.append(LLMTool(**tool_dict))
                else:
                    llm_tools.append(tool_dict)

        tool_choice = None
        if tool_config:
            mode = tool_config.get("mode", "auto")
            if mode == "auto":
                tool_choice = "auto"
            elif mode == "any":
                tool_choice = "any"
            elif mode == "none":
                tool_choice = "none"

        response = await self._execute_generation(
            final_messages,
            model,
            "Content analysis failed",
            kwargs,
            llm_tools=llm_tools,
            tool_choice=tool_choice,
        )

        if response.tool_calls:
            return response
        return response.text

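    # Example (hypothetical tool): plain dicts with "name"/"description" are
    # wrapped into LLMTool objects; if the model requests a tool call, the
    # full LLMResponse is returned so the caller can run the tool.
    #
    #     resp = await AI().analyze(
    #         msg,
    #         tools=[{"name": "get_weather", "description": "Query weather"}],
    #         tool_config={"mode": "auto"},
    #     )
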
    async def _execute_generation(
        self,
        messages: list[LLMMessage],
        model_name: ModelName,
        error_message: str,
        config_overrides: dict[str, Any],
        llm_tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        base_config: LLMGenerationConfig | None = None,
    ) -> LLMResponse:
        """Shared generation pipeline: resolves the model, merges the
        configuration, and wraps the error handling common to all public
        methods."""
        try:
            resolved_model_name = self._resolve_model_name(model_name or self.config.model)
            final_config_dict = self._merge_config(config_overrides, base_config=base_config)

            async with await get_model_instance(
                resolved_model_name, override_config=final_config_dict
            ) as model_instance:
                return await model_instance.generate_response(
                    messages, tools=llm_tools, tool_choice=tool_choice
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"{error_message}: {e}", e=e)
            raise LLMException(f"{error_message}: {e}", cause=e)

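    # Note: LLMException is re-raised unchanged; any other error is logged and
    # wrapped, so callers only need to catch LLMException.
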
    def _resolve_model_name(self, model_name: ModelName) -> str:
        """Resolve the model name, falling back to the global default"""
        if model_name:
            return model_name

        default_model = get_global_default_model_name()
        if default_model:
            return default_model

        raise LLMException(
            "No model name was specified and no global default model is set",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    def _merge_config(
        self,
        user_config: dict[str, Any],
        base_config: LLMGenerationConfig | None = None,
    ) -> dict[str, Any]:
        """Merge generation configuration"""
        final_config = {}
        if base_config:
            final_config.update(base_config.to_dict())

        if self.config.temperature is not None:
            final_config["temperature"] = self.config.temperature
        if self.config.max_tokens is not None:
            final_config["max_tokens"] = self.config.max_tokens

        if self.config.enable_cache:
            final_config["enable_caching"] = True
        if self.config.enable_code:
            final_config["enable_code_execution"] = True
        if self.config.enable_search:
            final_config["enable_grounding"] = True

        if self.config.enable_gemini_json_mode:
            final_config["response_mime_type"] = "application/json"
        if self.config.enable_gemini_thinking:
            final_config["thinking_budget"] = 0.8
        if self.config.enable_gemini_safe_mode:
            final_config["safety_settings"] = CommonOverrides.gemini_safe().safety_settings
        if self.config.enable_gemini_multimodal:
            final_config.update(CommonOverrides.gemini_multimodal().to_dict())
        if self.config.enable_gemini_grounding:
            final_config["enable_grounding"] = True

        final_config.update(user_config)

        return final_config

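    # Merge order (lowest to highest precedence): base_config presets, then
    # AIConfig fields, then per-call overrides; e.g. a temperature passed to
    # chat() wins over self.config.temperature.
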
    async def embed(
        self,
        texts: list[str] | str,
        *,
        model: ModelName = None,
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """Generate text embedding vectors"""
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return []

        try:
            resolved_model_str = (
                model or self.config.default_embedding_model or self.config.model
            )
            if not resolved_model_str:
                raise LLMException(
                    "embed requires an embedding model name, either passed in "
                    "or configured as default_embedding_model on AIConfig.",
                    code=LLMErrorCode.MODEL_NOT_FOUND,
                )
            resolved_model_str = self._resolve_model_name(resolved_model_str)

            async with await get_model_instance(
                resolved_model_str,
                override_config=None,
            ) as embedding_model_instance:
                return await embedding_model_instance.generate_embeddings(
                    texts, task_type=task_type, **kwargs
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"Text embedding failed: {e}", e=e)
            raise LLMException(
                f"Text embedding failed: {e}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                cause=e,
            )

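# Example: a single string yields one vector (wrapped in a list); an
# embedding-capable model must be configured or passed explicitly.
#
#     vectors = await AI().embed("hello world")
#     dim = len(vectors[0])
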
async def chat(
    message: str | LLMMessage | list[LLMContentPart],
    *,
    model: ModelName = None,
    **kwargs: Any,
) -> str:
    """Convenience function for chat"""
    ai = AI()
    return await ai.chat(message, model=model, **kwargs)


async def code(
    prompt: str,
    *,
    model: ModelName = None,
    timeout: int | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Convenience function for code execution"""
    ai = AI()
    return await ai.code(prompt, model=model, timeout=timeout, **kwargs)


async def search(
    query: str | UniMessage,
    *,
    model: ModelName = None,
    instruction: str = "",
    **kwargs: Any,
) -> dict[str, Any]:
    """Convenience function for information search"""
    ai = AI()
    return await ai.search(query, model=model, instruction=instruction, **kwargs)


async def analyze(
    message: UniMessage,
    *,
    instruction: str = "",
    model: ModelName = None,
    tools: list[dict[str, Any]] | None = None,
    tool_config: dict[str, Any] | None = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """Convenience function for content analysis"""
    ai = AI()
    return await ai.analyze(
        message,
        instruction=instruction,
        model=model,
        tools=tools,
        tool_config=tool_config,
        **kwargs,
    )


async def analyze_with_images(
    text: str,
    images: list[str | Path | bytes] | str | Path | bytes,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """Convenience function for image analysis"""
    message = create_multimodal_message(text=text, images=images)
    return await analyze(message, instruction=instruction, model=model, **kwargs)


async def analyze_multimodal(
    text: str | None = None,
    images: list[str | Path | bytes] | str | Path | bytes | None = None,
    videos: list[str | Path | bytes] | str | Path | bytes | None = None,
    audios: list[str | Path | bytes] | str | Path | bytes | None = None,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """Convenience function for multimodal analysis"""
    message = create_multimodal_message(
        text=text, images=images, videos=videos, audios=audios
    )
    return await analyze(message, instruction=instruction, model=model, **kwargs)

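# Example (hypothetical file path): the wrappers build the UniMessage
# internally, so paths or raw bytes are enough.
#
#     caption = await analyze_multimodal(
#         text="What is in this picture?", images=Path("cat.jpg")
#     )
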
async def search_multimodal(
    text: str | None = None,
    images: list[str | Path | bytes] | str | Path | bytes | None = None,
    videos: list[str | Path | bytes] | str | Path | bytes | None = None,
    audios: list[str | Path | bytes] | str | Path | bytes | None = None,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Convenience function for multimodal search"""
    message = create_multimodal_message(
        text=text, images=images, videos=videos, audios=audios
    )
    ai = AI()
    return await ai.search(message, model=model, instruction=instruction, **kwargs)


async def embed(
    texts: list[str] | str,
    *,
    model: ModelName = None,
    task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    **kwargs: Any,
) -> list[list[float]]:
    """Convenience function for text embedding"""
    ai = AI()
    return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
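# Example (assumes EmbeddingTaskType also defines RETRIEVAL_QUERY, as is
# common for Gemini-style embedding APIs): query-side vectors for retrieval.
#
#     qvec = await embed("query text", task_type=EmbeddingTaskType.RETRIEVAL_QUERY)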