mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 06:12:53 +08:00
* ✨ feat(llm): 全面重构LLM服务模块,增强多模态与工具支持 🚀 核心功能增强 - 多模型链式调用:新增 `pipeline_chat` 支持复杂任务流处理 - 扩展提供商支持:新增 ARK(火山方舟)、SiliconFlow(硅基流动) 适配器 - 多模态处理增强:支持URL媒体文件下载转换,提升输入灵活性 - 历史对话支持:AI.analyze 方法支持历史消息上下文和可选 UniMessage 参数 - 文本嵌入功能:新增 `embed`、`analyze_multimodal`、`search_multimodal` 等API - 模型能力系统:新增 `ModelCapabilities` 统一管理模型特性(多模态、工具调用等) 🔧 架构重构与优化 - MCP工具系统重构:配置独立化至 `data/llm/mcp_tools.json`,预置常用工具 - API调用逻辑统一:提取通用 `_perform_api_call` 方法,消除代码重复 - 跨平台兼容:Windows平台MCP工具npx命令自动包装处理 - HTTP客户端增强:兼容不同版本httpx代理配置(0.28+版本适配) 🛠️ API与配置完善 - 统一返回类型:`AI.analyze` 统一返回 `LLMResponse` 类型 - 消息转换工具:新增 `message_to_unimessage` 转换函数 - Gemini适配器增强:URL图片下载编码、动态安全阈值配置 - 缓存管理:新增模型实例缓存和管理功能 - 配置预设:扩展 CommonOverrides 预设配置选项 - 历史管理优化:支持多模态内容占位符替换,提升效率 📚 文档与开发体验 - README全面重写:新增完整使用指南、API参考和架构概览 - 文档内容扩充:补充嵌入模型、缓存管理、工具注册等功能说明 - 日志记录增强:支持详细调试信息输出 - API简化:移除冗余函数,优化接口设计 * 🎨 feat(llm): 统一LLM服务函数文档格式 * ✨ feat(llm): 添加新模型并简化提供者配置加载 * 🚨 auto fix by pre-commit hooks --------- Co-authored-by: webjoin111 <455457521@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
182 lines
6.7 KiB
Python
182 lines
6.7 KiB
Python
"""
|
||
工具注册表
|
||
|
||
负责加载、管理和实例化来自配置的工具。
|
||
"""
|
||
|
||
from collections.abc import Callable
|
||
from contextlib import AbstractAsyncContextManager
|
||
from functools import partial
|
||
from typing import TYPE_CHECKING
|
||
|
||
from pydantic import BaseModel
|
||
|
||
from zhenxun.services.log import logger
|
||
|
||
from ..types import LLMTool
|
||
|
||
if TYPE_CHECKING:
|
||
from ..config.providers import ToolConfig
|
||
from ..types.protocols import MCPCompatible
|
||
|
||
|
||
class ToolRegistry:
|
||
"""工具注册表,用于管理和实例化配置的工具。"""
|
||
|
||
def __init__(self):
|
||
self._function_tools: dict[str, LLMTool] = {}
|
||
|
||
self._mcp_config_models: dict[str, type[BaseModel]] = {}
|
||
if TYPE_CHECKING:
|
||
self._mcp_factories: dict[
|
||
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
|
||
] = {}
|
||
else:
|
||
self._mcp_factories: dict[str, Callable] = {}
|
||
|
||
self._tool_configs: dict[str, "ToolConfig"] | None = None
|
||
self._tool_cache: dict[str, "LLMTool"] = {}
|
||
|
||
def _load_configs_if_needed(self):
|
||
"""如果尚未加载,则从主配置中加载MCP工具定义。"""
|
||
if self._tool_configs is None:
|
||
logger.debug("首次访问,正在加载MCP工具配置...")
|
||
from ..config.providers import get_llm_config
|
||
|
||
llm_config = get_llm_config()
|
||
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
|
||
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
|
||
|
||
def function_tool(
|
||
self,
|
||
name: str,
|
||
description: str,
|
||
parameters: dict,
|
||
required: list[str] | None = None,
|
||
):
|
||
"""
|
||
装饰器:在代码中注册一个简单的、无状态的函数工具。
|
||
|
||
参数:
|
||
name: 工具的唯一名称。
|
||
description: 工具功能的描述。
|
||
parameters: OpenAPI格式的函数参数schema的properties部分。
|
||
required: 必需的参数列表。
|
||
"""
|
||
|
||
def decorator(func: Callable):
|
||
if name in self._function_tools or name in self._mcp_factories:
|
||
logger.warning(f"正在覆盖已注册的工具: {name}")
|
||
|
||
tool_definition = LLMTool.create(
|
||
name=name,
|
||
description=description,
|
||
parameters=parameters,
|
||
required=required,
|
||
)
|
||
self._function_tools[name] = tool_definition
|
||
logger.info(f"已在代码中注册函数工具: '{name}'")
|
||
tool_definition.annotations = tool_definition.annotations or {}
|
||
tool_definition.annotations["executable"] = func
|
||
return func
|
||
|
||
return decorator
|
||
|
||
def mcp_tool(self, name: str, config_model: type[BaseModel]):
|
||
"""
|
||
装饰器:注册一个MCP工具及其配置模型。
|
||
|
||
参数:
|
||
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
||
config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。
|
||
"""
|
||
|
||
def decorator(factory_func: Callable):
|
||
if name in self._mcp_factories:
|
||
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
||
self._mcp_factories[name] = factory_func
|
||
self._mcp_config_models[name] = config_model
|
||
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
|
||
return factory_func
|
||
|
||
return decorator
|
||
|
||
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
|
||
"""根据名称获取MCP工具的配置模型。"""
|
||
return self._mcp_config_models.get(name)
|
||
|
||
def register_mcp_factory(
|
||
self,
|
||
name: str,
|
||
factory: Callable,
|
||
):
|
||
"""
|
||
在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。
|
||
|
||
参数:
|
||
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
||
factory: 一个返回异步生成器的可调用对象(会话工厂)。
|
||
"""
|
||
if name in self._mcp_factories:
|
||
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
||
self._mcp_factories[name] = factory
|
||
logger.info(f"已注册 MCP 会话工厂: '{name}'")
|
||
|
||
def get_tool(self, name: str) -> "LLMTool":
|
||
"""
|
||
根据名称获取一个 LLMTool 定义。
|
||
对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
|
||
而不是一个已激活的会话。
|
||
"""
|
||
logger.debug(f"🔍 请求获取工具定义: {name}")
|
||
|
||
if name in self._tool_cache:
|
||
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
|
||
return self._tool_cache[name]
|
||
|
||
if name in self._function_tools:
|
||
logger.debug(f"🛠️ 获取函数工具定义: {name}")
|
||
tool = self._function_tools[name]
|
||
self._tool_cache[name] = tool
|
||
return tool
|
||
|
||
self._load_configs_if_needed()
|
||
if self._tool_configs is None or name not in self._tool_configs:
|
||
known_tools = list(self._function_tools.keys()) + (
|
||
list(self._tool_configs.keys()) if self._tool_configs else []
|
||
)
|
||
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
|
||
logger.debug(f"📋 可用工具定义列表: {known_tools}")
|
||
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
|
||
|
||
config = self._tool_configs[name]
|
||
tool: "LLMTool"
|
||
|
||
if name not in self._mcp_factories:
|
||
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
|
||
available_factories = list(self._mcp_factories.keys())
|
||
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
|
||
raise ValueError(
|
||
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
|
||
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
|
||
)
|
||
|
||
logger.info(f"🔧 创建MCP工具定义: {name}")
|
||
factory = self._mcp_factories[name]
|
||
typed_mcp_config = config.mcp_config
|
||
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
|
||
|
||
configured_factory = partial(factory, config=typed_mcp_config)
|
||
tool = LLMTool.from_mcp_session(session=configured_factory)
|
||
|
||
self._tool_cache[name] = tool
|
||
logger.debug(f"💾 MCP工具定义已缓存: {name}")
|
||
return tool
|
||
|
||
def get_tools(self, names: list[str]) -> list["LLMTool"]:
|
||
"""根据名称列表获取多个 LLMTool 实例。"""
|
||
return [self.get_tool(name) for name in names]
|
||
|
||
|
||
tool_registry = ToolRegistry()
|