zhenxun_bot/zhenxun/services/llm/tools/manager.py

294 lines
10 KiB
Python
Raw Normal View History

"""
工具提供者管理器
负责注册生命周期管理包括懒加载和统一提供所有工具
"""
import asyncio
from collections.abc import Callable
import inspect
from typing import Any
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_json_schema
from ..types import ToolExecutable, ToolProvider
from ..types.models import ToolDefinition, ToolResult
class FunctionExecutable(ToolExecutable):
"""一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。"""
def __init__(
self,
func: Callable,
name: str,
description: str,
params_model: type[BaseModel] | None,
):
self._func = func
self._name = name
self._description = description
self._params_model = params_model
async def get_definition(self) -> ToolDefinition:
if not self._params_model:
return ToolDefinition(
name=self._name,
description=self._description,
parameters={"type": "object", "properties": {}},
)
schema = model_json_schema(self._params_model)
return ToolDefinition(
name=self._name,
description=self._description,
parameters={
"type": "object",
"properties": schema.get("properties", {}),
"required": schema.get("required", []),
},
)
async def execute(self, **kwargs: Any) -> ToolResult:
raw_result: Any
if self._params_model:
try:
params_instance = self._params_model(**kwargs)
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(params_instance)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(params_instance)
)
except Exception as e:
logger.error(
f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
)
raise
else:
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(**kwargs)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(**kwargs)
)
return ToolResult(output=raw_result, display_content=str(raw_result))
class BuiltinFunctionToolProvider(ToolProvider):
"""一个内置的 ToolProvider用于处理通过装饰器注册的函数。"""
def __init__(self):
self._functions: dict[str, dict[str, Any]] = {}
def register(
self,
name: str,
func: Callable,
description: str,
params_model: type[BaseModel] | None,
):
self._functions[name] = {
"func": func,
"description": description,
"params_model": params_model,
}
async def initialize(self) -> None:
pass
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
executables = {}
for name, info in self._functions.items():
executables[name] = FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return executables
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
if config.get("type") == "function" and name in self._functions:
info = self._functions[name]
return FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return None
class ToolProviderManager:
"""工具提供者的中心化管理器,采用单例模式。"""
_instance: "ToolProviderManager | None" = None
def __new__(cls) -> "ToolProviderManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, "_initialized") and self._initialized:
return
self._providers: list[ToolProvider] = []
self._resolved_tools: dict[str, ToolExecutable] | None = None
self._init_lock = asyncio.Lock()
self._init_promise: asyncio.Task | None = None
self._builtin_function_provider = BuiltinFunctionToolProvider()
self.register(self._builtin_function_provider)
self._initialized = True
def register(self, provider: ToolProvider):
"""注册一个新的 ToolProvider。"""
if provider not in self._providers:
self._providers.append(provider)
logger.info(f"已注册工具提供者: {provider.__class__.__name__}")
def function_tool(
self,
name: str,
description: str,
params_model: type[BaseModel] | None = None,
):
"""装饰器:将一个函数注册为内置工具。"""
def decorator(func: Callable):
if name in self._builtin_function_provider._functions:
logger.warning(f"正在覆盖已注册的函数工具: {name}")
self._builtin_function_provider.register(
name=name,
func=func,
description=description,
params_model=params_model,
)
logger.info(f"已注册函数工具: '{name}'")
return func
return decorator
async def initialize(self) -> None:
"""懒加载初始化所有已注册的 ToolProvider。"""
if not self._init_promise:
async with self._init_lock:
if not self._init_promise:
self._init_promise = asyncio.create_task(
self._initialize_providers()
)
await self._init_promise
async def _initialize_providers(self) -> None:
"""内部初始化逻辑。"""
logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...")
init_tasks = [provider.initialize() for provider in self._providers]
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.info("所有工具提供者初始化完成。")
async def get_resolved_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
获取所有已发现和解析的工具
此方法会触发懒加载初始化并根据是否传入过滤器来决定是否使用全局缓存
"""
await self.initialize()
has_filters = allowed_servers is not None or excluded_servers is not None
if not has_filters and self._resolved_tools is not None:
logger.debug("使用全局工具缓存。")
return self._resolved_tools
if has_filters:
logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
logger.debug(
f"过滤器详情: allowed_servers={allowed_servers}, "
f"excluded_servers={excluded_servers}"
)
else:
logger.info("未应用过滤器,开始全局工具发现...")
all_tools: dict[str, ToolExecutable] = {}
discover_tasks = []
for provider in self._providers:
sig = inspect.signature(provider.discover_tools)
params_to_pass = {}
if "allowed_servers" in sig.parameters:
params_to_pass["allowed_servers"] = allowed_servers
if "excluded_servers" in sig.parameters:
params_to_pass["excluded_servers"] = excluded_servers
discover_tasks.append(provider.discover_tools(**params_to_pass))
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
for i, provider_result in enumerate(results):
provider_name = self._providers[i].__class__.__name__
if isinstance(provider_result, dict):
logger.debug(
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
)
for name, executable in provider_result.items():
if name in all_tools:
logger.warning(
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
)
all_tools[name] = executable
elif isinstance(provider_result, Exception):
logger.error(
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
)
if not has_filters:
self._resolved_tools = all_tools
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
else:
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
return all_tools
async def get_function_tools(
self, names: list[str] | None = None
) -> dict[str, ToolExecutable]:
"""
仅从内置的函数提供者中解析指定的工具
"""
all_function_tools = await self._builtin_function_provider.discover_tools()
if names is None:
return all_function_tools
resolved_tools = {}
for name in names:
if name in all_function_tools:
resolved_tools[name] = all_function_tools[name]
else:
logger.warning(
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
)
return resolved_tools
tool_provider_manager = ToolProviderManager()