mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
182 lines
6.7 KiB
Python
182 lines
6.7 KiB
Python
|
|
"""
|
|||
|
|
工具注册表
|
|||
|
|
|
|||
|
|
负责加载、管理和实例化来自配置的工具。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from collections.abc import Callable
|
|||
|
|
from contextlib import AbstractAsyncContextManager
|
|||
|
|
from functools import partial
|
|||
|
|
from typing import TYPE_CHECKING
|
|||
|
|
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
|
|||
|
|
from zhenxun.services.log import logger
|
|||
|
|
|
|||
|
|
from ..types import LLMTool
|
|||
|
|
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
from ..config.providers import ToolConfig
|
|||
|
|
from ..types.protocols import MCPCompatible
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ToolRegistry:
|
|||
|
|
"""工具注册表,用于管理和实例化配置的工具。"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
self._function_tools: dict[str, LLMTool] = {}
|
|||
|
|
|
|||
|
|
self._mcp_config_models: dict[str, type[BaseModel]] = {}
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
self._mcp_factories: dict[
|
|||
|
|
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
|
|||
|
|
] = {}
|
|||
|
|
else:
|
|||
|
|
self._mcp_factories: dict[str, Callable] = {}
|
|||
|
|
|
|||
|
|
self._tool_configs: dict[str, "ToolConfig"] | None = None
|
|||
|
|
self._tool_cache: dict[str, "LLMTool"] = {}
|
|||
|
|
|
|||
|
|
def _load_configs_if_needed(self):
|
|||
|
|
"""如果尚未加载,则从主配置中加载MCP工具定义。"""
|
|||
|
|
if self._tool_configs is None:
|
|||
|
|
logger.debug("首次访问,正在加载MCP工具配置...")
|
|||
|
|
from ..config.providers import get_llm_config
|
|||
|
|
|
|||
|
|
llm_config = get_llm_config()
|
|||
|
|
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
|
|||
|
|
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
|
|||
|
|
|
|||
|
|
def function_tool(
|
|||
|
|
self,
|
|||
|
|
name: str,
|
|||
|
|
description: str,
|
|||
|
|
parameters: dict,
|
|||
|
|
required: list[str] | None = None,
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
装饰器:在代码中注册一个简单的、无状态的函数工具。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
name: 工具的唯一名称。
|
|||
|
|
description: 工具功能的描述。
|
|||
|
|
parameters: OpenAPI格式的函数参数schema的properties部分。
|
|||
|
|
required: 必需的参数列表。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def decorator(func: Callable):
|
|||
|
|
if name in self._function_tools or name in self._mcp_factories:
|
|||
|
|
logger.warning(f"正在覆盖已注册的工具: {name}")
|
|||
|
|
|
|||
|
|
tool_definition = LLMTool.create(
|
|||
|
|
name=name,
|
|||
|
|
description=description,
|
|||
|
|
parameters=parameters,
|
|||
|
|
required=required,
|
|||
|
|
)
|
|||
|
|
self._function_tools[name] = tool_definition
|
|||
|
|
logger.info(f"已在代码中注册函数工具: '{name}'")
|
|||
|
|
tool_definition.annotations = tool_definition.annotations or {}
|
|||
|
|
tool_definition.annotations["executable"] = func
|
|||
|
|
return func
|
|||
|
|
|
|||
|
|
return decorator
|
|||
|
|
|
|||
|
|
def mcp_tool(self, name: str, config_model: type[BaseModel]):
|
|||
|
|
"""
|
|||
|
|
装饰器:注册一个MCP工具及其配置模型。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
|||
|
|
config_model: 一个Pydantic模型,用于定义和验证该工具的 `mcp_config`。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def decorator(factory_func: Callable):
|
|||
|
|
if name in self._mcp_factories:
|
|||
|
|
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
|||
|
|
self._mcp_factories[name] = factory_func
|
|||
|
|
self._mcp_config_models[name] = config_model
|
|||
|
|
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
|
|||
|
|
return factory_func
|
|||
|
|
|
|||
|
|
return decorator
|
|||
|
|
|
|||
|
|
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
|
|||
|
|
"""根据名称获取MCP工具的配置模型。"""
|
|||
|
|
return self._mcp_config_models.get(name)
|
|||
|
|
|
|||
|
|
def register_mcp_factory(
|
|||
|
|
self,
|
|||
|
|
name: str,
|
|||
|
|
factory: Callable,
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
在代码中注册一个 MCP 会话工厂,将其与配置中的工具名称关联。
|
|||
|
|
|
|||
|
|
参数:
|
|||
|
|
name: 工具的唯一名称,必须与配置文件中的名称匹配。
|
|||
|
|
factory: 一个返回异步生成器的可调用对象(会话工厂)。
|
|||
|
|
"""
|
|||
|
|
if name in self._mcp_factories:
|
|||
|
|
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
|
|||
|
|
self._mcp_factories[name] = factory
|
|||
|
|
logger.info(f"已注册 MCP 会话工厂: '{name}'")
|
|||
|
|
|
|||
|
|
def get_tool(self, name: str) -> "LLMTool":
|
|||
|
|
"""
|
|||
|
|
根据名称获取一个 LLMTool 定义。
|
|||
|
|
对于MCP工具,返回的 LLMTool 实例包含一个可调用的会话工厂,
|
|||
|
|
而不是一个已激活的会话。
|
|||
|
|
"""
|
|||
|
|
logger.debug(f"🔍 请求获取工具定义: {name}")
|
|||
|
|
|
|||
|
|
if name in self._tool_cache:
|
|||
|
|
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
|
|||
|
|
return self._tool_cache[name]
|
|||
|
|
|
|||
|
|
if name in self._function_tools:
|
|||
|
|
logger.debug(f"🛠️ 获取函数工具定义: {name}")
|
|||
|
|
tool = self._function_tools[name]
|
|||
|
|
self._tool_cache[name] = tool
|
|||
|
|
return tool
|
|||
|
|
|
|||
|
|
self._load_configs_if_needed()
|
|||
|
|
if self._tool_configs is None or name not in self._tool_configs:
|
|||
|
|
known_tools = list(self._function_tools.keys()) + (
|
|||
|
|
list(self._tool_configs.keys()) if self._tool_configs else []
|
|||
|
|
)
|
|||
|
|
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
|
|||
|
|
logger.debug(f"📋 可用工具定义列表: {known_tools}")
|
|||
|
|
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
|
|||
|
|
|
|||
|
|
config = self._tool_configs[name]
|
|||
|
|
tool: "LLMTool"
|
|||
|
|
|
|||
|
|
if name not in self._mcp_factories:
|
|||
|
|
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
|
|||
|
|
available_factories = list(self._mcp_factories.keys())
|
|||
|
|
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
|
|||
|
|
raise ValueError(
|
|||
|
|
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
|
|||
|
|
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
logger.info(f"🔧 创建MCP工具定义: {name}")
|
|||
|
|
factory = self._mcp_factories[name]
|
|||
|
|
typed_mcp_config = config.mcp_config
|
|||
|
|
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
|
|||
|
|
|
|||
|
|
configured_factory = partial(factory, config=typed_mcp_config)
|
|||
|
|
tool = LLMTool.from_mcp_session(session=configured_factory)
|
|||
|
|
|
|||
|
|
self._tool_cache[name] = tool
|
|||
|
|
logger.debug(f"💾 MCP工具定义已缓存: {name}")
|
|||
|
|
return tool
|
|||
|
|
|
|||
|
|
def get_tools(self, names: list[str]) -> list["LLMTool"]:
|
|||
|
|
"""根据名称列表获取多个 LLMTool 实例。"""
|
|||
|
|
return [self.get_tool(name) for name in names]
|
|||
|
|
|
|||
|
|
|
|||
|
|
tool_registry = ToolRegistry()
|