mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
Some checks failed
检查bot是否运行正常 / bot check (push) Waiting to run
Sequential Lint and Type Check / ruff-call (push) Waiting to run
Sequential Lint and Type Check / pyright-call (push) Blocked by required conditions
Release Drafter / Update Release Draft (push) Waiting to run
Force Sync to Aliyun / sync (push) Waiting to run
Update Version / update-version (push) Waiting to run
CodeQL Code Security Analysis / Analyze (${{ matrix.language }}) (none, javascript-typescript) (push) Has been cancelled
CodeQL Code Security Analysis / Analyze (${{ matrix.language }}) (none, python) (push) Has been cancelled
* ♻️ refactor(pydantic): 提取 Pydantic 兼容函数到独立模块 * ♻️ refactor!(llm): 重构LLM服务,引入现代化工具和执行器架构 🏗️ **架构变更** - 引入ToolProvider/ToolExecutable协议,取代ToolRegistry - 新增LLMToolExecutor,分离工具调用逻辑 - 新增BaseMemory抽象,解耦会话状态管理 🔄 **API重构** - 移除:analyze, analyze_multimodal, pipeline_chat - 新增:generate_structured, run_with_tools - 重构:chat, search, code变为无状态调用 🛠️ **工具系统** - 新增@function_tool装饰器 - 统一工具定义到ToolExecutable协议 - 移除MCP工具系统和mcp_tools.json --------- Co-authored-by: webjoin111 <455457521@qq.com>
294 lines
10 KiB
Python
294 lines
10 KiB
Python
"""
|
||
工具提供者管理器
|
||
|
||
负责注册、生命周期管理(包括懒加载)和统一提供所有工具。
|
||
"""
|
||
|
||
import asyncio
|
||
from collections.abc import Callable
|
||
import inspect
|
||
from typing import Any
|
||
|
||
from pydantic import BaseModel
|
||
|
||
from zhenxun.services.log import logger
|
||
from zhenxun.utils.pydantic_compat import model_json_schema
|
||
|
||
from ..types import ToolExecutable, ToolProvider
|
||
from ..types.models import ToolDefinition, ToolResult
|
||
|
||
|
||
class FunctionExecutable(ToolExecutable):
    """A ``ToolExecutable`` implementation that wraps a plain Python function.

    The wrapped function may be sync or async. When ``params_model`` is given,
    incoming keyword arguments are validated by instantiating that model and the
    resulting instance is passed to the function as its single positional
    argument; otherwise the keyword arguments are forwarded unchanged.
    """

    def __init__(
        self,
        func: Callable,
        name: str,
        description: str,
        params_model: type[BaseModel] | None,
    ):
        self._func = func
        self._name = name
        self._description = description
        self._params_model = params_model

    async def get_definition(self) -> ToolDefinition:
        """Build the tool definition advertised to the LLM.

        Without a params model the tool accepts an empty object. With one, the
        model's JSON schema ``properties`` and ``required`` lists are used. Any
        shared sub-schema definitions (``$defs`` in Pydantic v2, ``definitions``
        in v1) are carried along, so ``$ref`` pointers produced for nested
        models still resolve against the ``parameters`` object.
        """
        parameters: dict[str, Any] = {"type": "object", "properties": {}}

        if self._params_model:
            schema = model_json_schema(self._params_model)
            parameters["properties"] = schema.get("properties", {})
            parameters["required"] = schema.get("required", [])
            # Without these, refs such as "#/$defs/Inner" would dangle.
            for defs_key in ("$defs", "definitions"):
                if defs_key in schema:
                    parameters[defs_key] = schema[defs_key]

        return ToolDefinition(
            name=self._name,
            description=self._description,
            parameters=parameters,
        )

    async def execute(self, **kwargs: Any) -> ToolResult:
        """Run the wrapped function and wrap its return value in a ``ToolResult``.

        Raises:
            Exception: whatever ``params_model`` raises when the arguments fail
                validation (logged, then re-raised), or whatever the wrapped
                function itself raises (propagated unchanged).
        """
        if self._params_model:
            # Only the validation step is covered here, so that errors raised
            # by the tool itself are not mis-reported as validation failures.
            try:
                params_instance = self._params_model(**kwargs)
            except Exception as e:
                logger.error(
                    f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
                )
                raise
            raw_result = await self._invoke(params_instance)
        else:
            raw_result = await self._invoke(**kwargs)

        return ToolResult(output=raw_result, display_content=str(raw_result))

    async def _invoke(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function; sync functions run in the default executor."""
        if inspect.iscoroutinefunction(self._func):
            return await self._func(*args, **kwargs)

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, lambda: self._func(*args, **kwargs))
        # A sync-looking callable (e.g. a partial around an async function) may
        # still hand back an awaitable; finish it on the event loop.
        if inspect.isawaitable(result):
            result = await result
        return result
|
||
|
||
|
||
class BuiltinFunctionToolProvider(ToolProvider):
    """Built-in ``ToolProvider`` holding functions registered via decorator.

    Registrations are kept in memory as name -> info mappings, and a fresh
    ``FunctionExecutable`` is created every time a tool is looked up.
    """

    def __init__(self):
        # name -> {"func": ..., "description": ..., "params_model": ...}
        self._functions: dict[str, dict[str, Any]] = {}

    def register(
        self,
        name: str,
        func: Callable,
        description: str,
        params_model: type[BaseModel] | None,
    ):
        """Store (or overwrite) the function registered under ``name``."""
        entry: dict[str, Any] = {
            "func": func,
            "description": description,
            "params_model": params_model,
        }
        self._functions[name] = entry

    async def initialize(self) -> None:
        """No setup is required; registrations already live in memory."""
        return None

    def _make_executable(self, tool_name: str) -> FunctionExecutable:
        """Wrap the registration stored under ``tool_name`` in an executable."""
        entry = self._functions[tool_name]
        return FunctionExecutable(
            func=entry["func"],
            name=tool_name,
            description=entry["description"],
            params_model=entry["params_model"],
        )

    async def discover_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        """Return an executable for every registered function.

        ``allowed_servers`` / ``excluded_servers`` are accepted for signature
        compatibility with other providers and are not used here.
        """
        return {
            tool_name: self._make_executable(tool_name)
            for tool_name in self._functions
        }

    async def get_tool_executable(
        self, name: str, config: dict[str, Any]
    ) -> ToolExecutable | None:
        """Return the executable for ``name`` if ``config`` marks it a function tool."""
        is_function_config = config.get("type") == "function"
        if not is_function_config or name not in self._functions:
            return None
        return self._make_executable(name)
|
||
|
||
|
||
class ToolProviderManager:
    """Centralised manager for tool providers (singleton).

    Every ``ToolProviderManager()`` call returns the same instance. The manager
    keeps the registered ``ToolProvider`` objects, lazily awaits their
    ``initialize()`` hooks, and merges the tools each provider discovers. The
    unfiltered merge result is cached until another provider or function tool
    is registered.
    """

    _instance: "ToolProviderManager | None" = None

    def __new__(cls) -> "ToolProviderManager":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __new__ returns the cached instance, but Python still calls __init__
        # on every construction; only set up state the first time.
        if hasattr(self, "_initialized") and self._initialized:
            return

        self._providers: list[ToolProvider] = []
        # Providers whose initialize() has already been scheduled (whether it
        # later succeeded or failed); failures are logged and not retried.
        self._initialized_providers: list[ToolProvider] = []
        self._resolved_tools: dict[str, ToolExecutable] | None = None
        self._init_lock = asyncio.Lock()
        self._init_promise: asyncio.Task | None = None

        self._builtin_function_provider = BuiltinFunctionToolProvider()
        self.register(self._builtin_function_provider)

        self._initialized = True

    def _invalidate_tool_cache(self) -> None:
        """Drop the cached unfiltered tool map so the next lookup rediscovers."""
        self._resolved_tools = None

    def register(self, provider: ToolProvider):
        """Register a new ``ToolProvider``; already-registered ones are ignored.

        The global tool cache is dropped so the new provider's tools become
        visible, and the provider is initialised on the next ``initialize()``.
        """
        if provider not in self._providers:
            self._providers.append(provider)
            self._invalidate_tool_cache()
            logger.info(f"已注册工具提供者: {provider.__class__.__name__}")

    def function_tool(
        self,
        name: str,
        description: str,
        params_model: type[BaseModel] | None = None,
    ):
        """Decorator: register the decorated function as a built-in tool.

        Args:
            name: tool name; re-using an existing name overwrites it (with a
                warning).
            description: description advertised to the LLM.
            params_model: optional Pydantic model used to validate arguments.

        Returns:
            A decorator that registers the function and returns it unchanged.
        """

        def decorator(func: Callable):
            if name in self._builtin_function_provider._functions:
                logger.warning(f"正在覆盖已注册的函数工具: {name}")

            self._builtin_function_provider.register(
                name=name,
                func=func,
                description=description,
                params_model=params_model,
            )
            # Without this, a tool registered after the first unfiltered lookup
            # would never show up in get_resolved_tools().
            self._invalidate_tool_cache()
            logger.info(f"已注册函数工具: '{name}'")
            return func

        return decorator

    def _pending_providers(self) -> list[ToolProvider]:
        """Registered providers whose ``initialize()`` has not been scheduled."""
        return [p for p in self._providers if p not in self._initialized_providers]

    async def initialize(self) -> None:
        """Lazily initialise every registered ToolProvider.

        Concurrent callers share a single initialisation task. It is awaited
        through ``asyncio.shield`` so that cancelling one caller does not cancel
        the shared task for everyone else. Providers registered while or after
        a task runs are handled by a follow-up task.
        """
        while True:
            async with self._init_lock:
                if self._init_promise is None or (
                    self._init_promise.done() and self._pending_providers()
                ):
                    self._init_promise = asyncio.create_task(
                        self._initialize_providers()
                    )
                init_task = self._init_promise
            await asyncio.shield(init_task)
            if not self._pending_providers():
                return

    async def _initialize_providers(self) -> None:
        """Run ``initialize()`` concurrently on every pending provider.

        Each failure is logged with the provider's name; one failing provider
        does not prevent the others from initialising.
        """
        pending = self._pending_providers()
        # Mark before awaiting so a later run never initialises them again.
        self._initialized_providers.extend(pending)

        logger.info(f"开始初始化 {len(pending)} 个工具提供者...")
        results = await asyncio.gather(
            *(provider.initialize() for provider in pending),
            return_exceptions=True,
        )
        for provider, result in zip(pending, results):
            if isinstance(result, BaseException):
                logger.error(
                    f"工具提供者 '{provider.__class__.__name__}' 初始化失败: {result}"
                )
        logger.info("所有工具提供者初始化完成。")

    @staticmethod
    def _discover_kwargs(
        provider: ToolProvider,
        allowed_servers: list[str] | None,
        excluded_servers: list[str] | None,
    ) -> dict[str, Any]:
        """Pass server filters only to providers whose discover_tools accepts them."""
        sig = inspect.signature(provider.discover_tools)
        kwargs: dict[str, Any] = {}
        if "allowed_servers" in sig.parameters:
            kwargs["allowed_servers"] = allowed_servers
        if "excluded_servers" in sig.parameters:
            kwargs["excluded_servers"] = excluded_servers
        return kwargs

    async def get_resolved_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        """
        Return all discovered tools, keyed by tool name.

        Triggers lazy initialisation first. Unfiltered results are cached
        globally. Calls that pass ``allowed_servers`` or ``excluded_servers``
        always run a fresh discovery and are never cached. A new dict is
        returned on every call, so callers cannot mutate the cache. When two
        providers expose the same name, the later provider wins (with a
        warning).
        """
        await self.initialize()

        has_filters = allowed_servers is not None or excluded_servers is not None

        if not has_filters and self._resolved_tools is not None:
            logger.debug("使用全局工具缓存。")
            return dict(self._resolved_tools)

        if has_filters:
            logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
            logger.debug(
                f"过滤器详情: allowed_servers={allowed_servers}, "
                f"excluded_servers={excluded_servers}"
            )
        else:
            logger.info("未应用过滤器,开始全局工具发现...")

        # Snapshot so each gathered result is paired with the provider that
        # produced it, even if register() runs while discovery is awaited.
        providers = list(self._providers)
        results = await asyncio.gather(
            *(
                provider.discover_tools(
                    **self._discover_kwargs(
                        provider, allowed_servers, excluded_servers
                    )
                )
                for provider in providers
            ),
            return_exceptions=True,
        )

        all_tools: dict[str, ToolExecutable] = {}
        for provider, provider_result in zip(providers, results):
            provider_name = provider.__class__.__name__
            if isinstance(provider_result, dict):
                logger.debug(
                    f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
                )
                for name, executable in provider_result.items():
                    if name in all_tools:
                        logger.warning(
                            f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
                        )
                    all_tools[name] = executable
            elif isinstance(provider_result, BaseException):
                logger.error(
                    f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
                )

        if not has_filters:
            # Cache a private copy; the returned dict belongs to the caller.
            self._resolved_tools = dict(all_tools)
            logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
        else:
            logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")

        return all_tools

    async def get_function_tools(
        self, names: list[str] | None = None
    ) -> dict[str, ToolExecutable]:
        """
        Resolve tools from the built-in function provider only.

        Args:
            names: tool names to resolve; ``None`` returns every registered
                function tool. Unknown names are skipped with a warning.
        """
        all_function_tools = await self._builtin_function_provider.discover_tools()
        if names is None:
            return all_function_tools

        resolved_tools: dict[str, ToolExecutable] = {}
        for name in names:
            if name in all_function_tools:
                resolved_tools[name] = all_function_tools[name]
            else:
                logger.warning(
                    f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
                )
        return resolved_tools
|
||
|
||
|
||
# Module-level manager instance. ToolProviderManager.__new__ caches its
# instance, so any later ToolProviderManager() call returns this same object.
tool_provider_manager: ToolProviderManager = ToolProviderManager()
|