mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
294 lines
10 KiB
Python
294 lines
10 KiB
Python
|
|
"""
|
|||
|
|
工具提供者管理器
|
|||
|
|
|
|||
|
|
负责注册、生命周期管理(包括懒加载)和统一提供所有工具。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
from collections.abc import Callable
|
|||
|
|
import inspect
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
|
|||
|
|
from zhenxun.services.log import logger
|
|||
|
|
from zhenxun.utils.pydantic_compat import model_json_schema
|
|||
|
|
|
|||
|
|
from ..types import ToolExecutable, ToolProvider
|
|||
|
|
from ..types.models import ToolDefinition, ToolResult
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FunctionExecutable(ToolExecutable):
|
|||
|
|
"""一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。"""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
func: Callable,
|
|||
|
|
name: str,
|
|||
|
|
description: str,
|
|||
|
|
params_model: type[BaseModel] | None,
|
|||
|
|
):
|
|||
|
|
self._func = func
|
|||
|
|
self._name = name
|
|||
|
|
self._description = description
|
|||
|
|
self._params_model = params_model
|
|||
|
|
|
|||
|
|
async def get_definition(self) -> ToolDefinition:
|
|||
|
|
if not self._params_model:
|
|||
|
|
return ToolDefinition(
|
|||
|
|
name=self._name,
|
|||
|
|
description=self._description,
|
|||
|
|
parameters={"type": "object", "properties": {}},
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
schema = model_json_schema(self._params_model)
|
|||
|
|
|
|||
|
|
return ToolDefinition(
|
|||
|
|
name=self._name,
|
|||
|
|
description=self._description,
|
|||
|
|
parameters={
|
|||
|
|
"type": "object",
|
|||
|
|
"properties": schema.get("properties", {}),
|
|||
|
|
"required": schema.get("required", []),
|
|||
|
|
},
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
async def execute(self, **kwargs: Any) -> ToolResult:
|
|||
|
|
raw_result: Any
|
|||
|
|
|
|||
|
|
if self._params_model:
|
|||
|
|
try:
|
|||
|
|
params_instance = self._params_model(**kwargs)
|
|||
|
|
|
|||
|
|
if inspect.iscoroutinefunction(self._func):
|
|||
|
|
raw_result = await self._func(params_instance)
|
|||
|
|
else:
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
raw_result = await loop.run_in_executor(
|
|||
|
|
None, lambda: self._func(params_instance)
|
|||
|
|
)
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(
|
|||
|
|
f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
|
|||
|
|
)
|
|||
|
|
raise
|
|||
|
|
else:
|
|||
|
|
if inspect.iscoroutinefunction(self._func):
|
|||
|
|
raw_result = await self._func(**kwargs)
|
|||
|
|
else:
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
raw_result = await loop.run_in_executor(
|
|||
|
|
None, lambda: self._func(**kwargs)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return ToolResult(output=raw_result, display_content=str(raw_result))
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BuiltinFunctionToolProvider(ToolProvider):
|
|||
|
|
"""一个内置的 ToolProvider,用于处理通过装饰器注册的函数。"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
self._functions: dict[str, dict[str, Any]] = {}
|
|||
|
|
|
|||
|
|
def register(
|
|||
|
|
self,
|
|||
|
|
name: str,
|
|||
|
|
func: Callable,
|
|||
|
|
description: str,
|
|||
|
|
params_model: type[BaseModel] | None,
|
|||
|
|
):
|
|||
|
|
self._functions[name] = {
|
|||
|
|
"func": func,
|
|||
|
|
"description": description,
|
|||
|
|
"params_model": params_model,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async def initialize(self) -> None:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
async def discover_tools(
|
|||
|
|
self,
|
|||
|
|
allowed_servers: list[str] | None = None,
|
|||
|
|
excluded_servers: list[str] | None = None,
|
|||
|
|
) -> dict[str, ToolExecutable]:
|
|||
|
|
executables = {}
|
|||
|
|
for name, info in self._functions.items():
|
|||
|
|
executables[name] = FunctionExecutable(
|
|||
|
|
func=info["func"],
|
|||
|
|
name=name,
|
|||
|
|
description=info["description"],
|
|||
|
|
params_model=info["params_model"],
|
|||
|
|
)
|
|||
|
|
return executables
|
|||
|
|
|
|||
|
|
async def get_tool_executable(
|
|||
|
|
self, name: str, config: dict[str, Any]
|
|||
|
|
) -> ToolExecutable | None:
|
|||
|
|
if config.get("type") == "function" and name in self._functions:
|
|||
|
|
info = self._functions[name]
|
|||
|
|
return FunctionExecutable(
|
|||
|
|
func=info["func"],
|
|||
|
|
name=name,
|
|||
|
|
description=info["description"],
|
|||
|
|
params_model=info["params_model"],
|
|||
|
|
)
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ToolProviderManager:
|
|||
|
|
"""工具提供者的中心化管理器,采用单例模式。"""
|
|||
|
|
|
|||
|
|
_instance: "ToolProviderManager | None" = None
|
|||
|
|
|
|||
|
|
def __new__(cls) -> "ToolProviderManager":
|
|||
|
|
if cls._instance is None:
|
|||
|
|
cls._instance = super().__new__(cls)
|
|||
|
|
return cls._instance
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
if hasattr(self, "_initialized") and self._initialized:
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
self._providers: list[ToolProvider] = []
|
|||
|
|
self._resolved_tools: dict[str, ToolExecutable] | None = None
|
|||
|
|
self._init_lock = asyncio.Lock()
|
|||
|
|
self._init_promise: asyncio.Task | None = None
|
|||
|
|
self._builtin_function_provider = BuiltinFunctionToolProvider()
|
|||
|
|
self.register(self._builtin_function_provider)
|
|||
|
|
self._initialized = True
|
|||
|
|
|
|||
|
|
def register(self, provider: ToolProvider):
|
|||
|
|
"""注册一个新的 ToolProvider。"""
|
|||
|
|
if provider not in self._providers:
|
|||
|
|
self._providers.append(provider)
|
|||
|
|
logger.info(f"已注册工具提供者: {provider.__class__.__name__}")
|
|||
|
|
|
|||
|
|
def function_tool(
|
|||
|
|
self,
|
|||
|
|
name: str,
|
|||
|
|
description: str,
|
|||
|
|
params_model: type[BaseModel] | None = None,
|
|||
|
|
):
|
|||
|
|
"""装饰器:将一个函数注册为内置工具。"""
|
|||
|
|
|
|||
|
|
def decorator(func: Callable):
|
|||
|
|
if name in self._builtin_function_provider._functions:
|
|||
|
|
logger.warning(f"正在覆盖已注册的函数工具: {name}")
|
|||
|
|
|
|||
|
|
self._builtin_function_provider.register(
|
|||
|
|
name=name,
|
|||
|
|
func=func,
|
|||
|
|
description=description,
|
|||
|
|
params_model=params_model,
|
|||
|
|
)
|
|||
|
|
logger.info(f"已注册函数工具: '{name}'")
|
|||
|
|
return func
|
|||
|
|
|
|||
|
|
return decorator
|
|||
|
|
|
|||
|
|
async def initialize(self) -> None:
|
|||
|
|
"""懒加载初始化所有已注册的 ToolProvider。"""
|
|||
|
|
if not self._init_promise:
|
|||
|
|
async with self._init_lock:
|
|||
|
|
if not self._init_promise:
|
|||
|
|
self._init_promise = asyncio.create_task(
|
|||
|
|
self._initialize_providers()
|
|||
|
|
)
|
|||
|
|
await self._init_promise
|
|||
|
|
|
|||
|
|
async def _initialize_providers(self) -> None:
|
|||
|
|
"""内部初始化逻辑。"""
|
|||
|
|
logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...")
|
|||
|
|
init_tasks = [provider.initialize() for provider in self._providers]
|
|||
|
|
await asyncio.gather(*init_tasks, return_exceptions=True)
|
|||
|
|
logger.info("所有工具提供者初始化完成。")
|
|||
|
|
|
|||
|
|
async def get_resolved_tools(
|
|||
|
|
self,
|
|||
|
|
allowed_servers: list[str] | None = None,
|
|||
|
|
excluded_servers: list[str] | None = None,
|
|||
|
|
) -> dict[str, ToolExecutable]:
|
|||
|
|
"""
|
|||
|
|
获取所有已发现和解析的工具。
|
|||
|
|
此方法会触发懒加载初始化,并根据是否传入过滤器来决定是否使用全局缓存。
|
|||
|
|
"""
|
|||
|
|
await self.initialize()
|
|||
|
|
|
|||
|
|
has_filters = allowed_servers is not None or excluded_servers is not None
|
|||
|
|
|
|||
|
|
if not has_filters and self._resolved_tools is not None:
|
|||
|
|
logger.debug("使用全局工具缓存。")
|
|||
|
|
return self._resolved_tools
|
|||
|
|
|
|||
|
|
if has_filters:
|
|||
|
|
logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
|
|||
|
|
logger.debug(
|
|||
|
|
f"过滤器详情: allowed_servers={allowed_servers}, "
|
|||
|
|
f"excluded_servers={excluded_servers}"
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
logger.info("未应用过滤器,开始全局工具发现...")
|
|||
|
|
|
|||
|
|
all_tools: dict[str, ToolExecutable] = {}
|
|||
|
|
|
|||
|
|
discover_tasks = []
|
|||
|
|
for provider in self._providers:
|
|||
|
|
sig = inspect.signature(provider.discover_tools)
|
|||
|
|
params_to_pass = {}
|
|||
|
|
if "allowed_servers" in sig.parameters:
|
|||
|
|
params_to_pass["allowed_servers"] = allowed_servers
|
|||
|
|
if "excluded_servers" in sig.parameters:
|
|||
|
|
params_to_pass["excluded_servers"] = excluded_servers
|
|||
|
|
|
|||
|
|
discover_tasks.append(provider.discover_tools(**params_to_pass))
|
|||
|
|
|
|||
|
|
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
|
|||
|
|
|
|||
|
|
for i, provider_result in enumerate(results):
|
|||
|
|
provider_name = self._providers[i].__class__.__name__
|
|||
|
|
if isinstance(provider_result, dict):
|
|||
|
|
logger.debug(
|
|||
|
|
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
|
|||
|
|
)
|
|||
|
|
for name, executable in provider_result.items():
|
|||
|
|
if name in all_tools:
|
|||
|
|
logger.warning(
|
|||
|
|
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
|
|||
|
|
)
|
|||
|
|
all_tools[name] = executable
|
|||
|
|
elif isinstance(provider_result, Exception):
|
|||
|
|
logger.error(
|
|||
|
|
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if not has_filters:
|
|||
|
|
self._resolved_tools = all_tools
|
|||
|
|
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
|
|||
|
|
else:
|
|||
|
|
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
|
|||
|
|
|
|||
|
|
return all_tools
|
|||
|
|
|
|||
|
|
async def get_function_tools(
|
|||
|
|
self, names: list[str] | None = None
|
|||
|
|
) -> dict[str, ToolExecutable]:
|
|||
|
|
"""
|
|||
|
|
仅从内置的函数提供者中解析指定的工具。
|
|||
|
|
"""
|
|||
|
|
all_function_tools = await self._builtin_function_provider.discover_tools()
|
|||
|
|
if names is None:
|
|||
|
|
return all_function_tools
|
|||
|
|
|
|||
|
|
resolved_tools = {}
|
|||
|
|
for name in names:
|
|||
|
|
if name in all_function_tools:
|
|||
|
|
resolved_tools[name] = all_function_tools[name]
|
|||
|
|
else:
|
|||
|
|
logger.warning(
|
|||
|
|
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
|
|||
|
|
)
|
|||
|
|
return resolved_tools
|
|||
|
|
|
|||
|
|
|
|||
|
|
tool_provider_manager = ToolProviderManager()
|