mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
- 【重构】LLM 服务核心架构:
- 引入中间件管道,统一处理请求生命周期(重试、密钥选择、日志、网络请求)。
- 适配器重构为组件化设计,分离配置映射、消息转换、响应解析和工具序列化逻辑。
- 移除 `with_smart_retry` 装饰器,其功能由中间件接管。
- 移除 `LLMToolExecutor`,工具执行逻辑集成到 `ToolInvoker`。
- 【功能】增强配置系统:
- `LLMGenerationConfig` 采用组件化结构(Core, Reasoning, Visual, Output, Safety, ToolConfig)。
- 新增 `GenConfigBuilder` 提供语义化配置构建方式。
- 新增 `LLMEmbeddingConfig` 用于嵌入专用配置。
- `CommonOverrides` 迁移并更新至新配置结构。
- 【功能】强化工具系统:
- 引入 `ToolInvoker` 实现更灵活的工具执行,支持回调与结构化错误。
- `function_tool` 装饰器支持动态 Pydantic 模型创建和依赖注入 (`ToolParam`, `RunContext`)。
- 平台原生工具支持 (`GeminiCodeExecution`, `GeminiGoogleSearch`, `GeminiUrlContext`)。
- 【功能】高级生成与嵌入:
- `generate_structured` 方法支持 In-Context Validation and Repair (IVR) 循环和 AutoCoT (思维链) 包装。
- 新增 `embed_query` 和 `embed_documents` 便捷嵌入 API。
- `OpenAIImageAdapter` 支持 OpenAI 兼容的图像生成。
- `SmartAdapter` 实现模型名称智能路由。
- 【重构】消息与类型系统:
- `LLMContentPart` 扩展支持更多模态和代码执行相关内容。
- `LLMMessage` 和 `LLMResponse` 结构更新,支持 `content_parts` 和思维链签名。
- 统一 `LLMErrorCode` 和用户友好错误消息,提供更详细的网络/代理错误提示。
- `pyproject.toml` 移除 `bilireq`,新增 `json_repair`。
- 【优化】日志与调试:
- 引入 `DebugLogOptions`,提供细粒度日志脱敏控制。
- 增强日志净化器,处理更多敏感数据和长字符串。
- 【清理】删除废弃模块:
- `zhenxun/services/llm/memory.py`
- `zhenxun/services/llm/executor.py`
- `zhenxun/services/llm/config/presets.py`
- `zhenxun/services/llm/types/content.py`
- `zhenxun/services/llm/types/enums.py`
- `zhenxun/services/llm/tools/__init__.py`
- `zhenxun/services/llm/tools/manager.py`
840 lines
27 KiB
Python
840 lines
27 KiB
Python
"""
|
||
工具模块
|
||
|
||
整合了工具参数解析器、工具提供者管理器与工具执行逻辑,便于在 LLM 服务层统一调用。
|
||
"""
|
||
|
||
import asyncio
|
||
from collections.abc import Callable
|
||
from enum import Enum
|
||
import inspect
|
||
import json
|
||
import re
|
||
import time
|
||
from typing import (
|
||
Annotated,
|
||
Any,
|
||
Optional,
|
||
Union,
|
||
cast,
|
||
get_args,
|
||
get_origin,
|
||
get_type_hints,
|
||
)
|
||
from typing_extensions import override
|
||
|
||
from httpx import NetworkError, TimeoutException
|
||
|
||
try:
|
||
import ujson as fast_json
|
||
except ImportError:
|
||
fast_json = json
|
||
|
||
import nonebot
|
||
from nonebot.dependencies import Dependent, Param
|
||
from nonebot.internal.adapter import Bot, Event
|
||
from nonebot.internal.params import (
|
||
BotParam,
|
||
DefaultParam,
|
||
DependParam,
|
||
DependsInner,
|
||
EventParam,
|
||
StateParam,
|
||
)
|
||
from pydantic import BaseModel, Field, ValidationError, create_model
|
||
from pydantic.fields import FieldInfo
|
||
|
||
from zhenxun.services.log import logger
|
||
from zhenxun.utils.decorator.retry import Retry
|
||
from zhenxun.utils.pydantic_compat import model_dump, model_fields, model_json_schema
|
||
|
||
from .types import (
|
||
LLMErrorCode,
|
||
LLMException,
|
||
LLMMessage,
|
||
LLMToolCall,
|
||
ToolExecutable,
|
||
ToolProvider,
|
||
ToolResult,
|
||
)
|
||
from .types.models import ToolDefinition
|
||
from .types.protocols import BaseCallbackHandler, ToolCallData
|
||
|
||
|
||
class ToolParam(Param):
    """
    Tool-argument extractor.

    Used by custom function tools to pull a single named value out of the
    argument dictionary parsed from the LLM (stored under
    ``state["_tool_params"]``). Normally resolved through the dependency
    injection system rather than instantiated by hand.
    """

    def __init__(self, *args: Any, name: str, **kwargs: Any):
        # ``name`` is the key looked up in the tool-argument dictionary.
        super().__init__(*args, **kwargs)
        self.name = name

    def __repr__(self) -> str:
        return f"ToolParam(name={self.name})"

    @classmethod
    @override
    def _check_param(
        cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
    ) -> Optional["ToolParam"]:
        """Claim ``param`` unless it is a ``Depends`` parameter or a var-arg.

        Returns a ``ToolParam`` bound to the parameter name, or ``None`` when
        this resolver does not apply.
        """
        # Parameters declared as ``x = Depends(...)`` are not tool arguments.
        # ``inspect.Parameter.empty`` is never a DependsInner instance, so a
        # single isinstance check covers the "no default" case as well.
        if isinstance(param.default, DependsInner):
            return None

        # Same for ``Annotated[T, Depends(...)]``.
        annotation = param.annotation
        if get_origin(annotation) is Annotated and any(
            isinstance(meta, DependsInner) for meta in get_args(annotation)
        ):
            return None

        # ``*args`` / ``**kwargs`` cannot be filled from a named argument.
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        if param.kind in variadic_kinds:
            return None
        return cls(name=param.name)

    @override
    async def _solve(self, **kwargs: Any) -> Any:
        """Return the argument stored under ``self.name``, or ``None`` if absent."""
        state: dict[str, Any] = kwargs.get("state", {})
        return state.get("_tool_params", {}).get(self.name)
|
||
|
||
|
||
class RunContext(BaseModel):
    """
    Dependency-injection container (DI container).

    Carries the per-call context information that is handed to tools while
    giving it a typed shape.
    """

    # Optional session identifier; not read anywhere in this module.
    session_id: str | None = None
    # Ambient objects for the call. FunctionExecutable.execute reads the
    # "bot" and "event" keys from this mapping.
    scope: dict[str, Any] = Field(default_factory=dict)
    # Free-form additional data; not read anywhere in this module.
    extra: dict[str, Any] = Field(default_factory=dict)

    class Config:
        # Pydantic option: permit field values of arbitrary (non-pydantic) types.
        arbitrary_types_allowed = True
|
||
|
||
|
||
class RunContextParam(Param):
    """Parameter resolver that injects the current ``RunContext`` into a tool.

    A parameter is claimed when it is annotated with ``RunContext`` itself or
    with a union containing it (``Optional[RunContext]``,
    ``RunContext | None``). The union forms are accepted because the dynamic
    parameter model already hides such parameters from the LLM schema; if
    this resolver did not claim them, they would be resolved as ordinary tool
    arguments and always come out as ``None``.
    """

    @classmethod
    @override
    def _check_param(
        cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
    ) -> Optional["RunContextParam"]:
        """Return a resolver when ``param`` asks for a ``RunContext``, else ``None``."""
        # Local import: only needed here to recognise PEP 604 (``X | Y``) unions.
        import types as _types

        annotation = param.annotation
        if annotation is RunContext:
            return cls()

        # ``typing.Union[...]`` and ``X | Y`` report different origins on
        # Python 3.10-3.13; fall back to ``Union`` where UnionType is missing.
        union_origins = (Union, getattr(_types, "UnionType", Union))
        if get_origin(annotation) in union_origins and RunContext in get_args(
            annotation
        ):
            return cls()
        return None

    @override
    async def _solve(self, **kwargs: Any) -> Any:
        """Return the context stored under ``state["_agent_context"]`` (or ``None``)."""
        state = kwargs.get("state", {})
        return state.get("_agent_context")
|
||
|
||
|
||
def _parse_docstring_params(docstring: str | None) -> dict[str, str]:
|
||
"""
|
||
解析文档字符串,提取参数描述。
|
||
支持 Google Style (Args:), ReST Style (:param:), 和中文风格 (参数:)。
|
||
"""
|
||
if not docstring:
|
||
return {}
|
||
|
||
params: dict[str, str] = {}
|
||
lines = docstring.splitlines()
|
||
|
||
rest_pattern = re.compile(r"[:@]param\s+(\w+)\s*:?\s*(.*)")
|
||
found_rest = False
|
||
for line in lines:
|
||
match = rest_pattern.search(line)
|
||
if match:
|
||
params[match.group(1)] = match.group(2).strip()
|
||
found_rest = True
|
||
|
||
if found_rest:
|
||
return params
|
||
|
||
section_header_pattern = re.compile(
|
||
r"^\s*(?:Args|Arguments|Parameters|参数)\s*[::]\s*$"
|
||
)
|
||
|
||
param_section_active = False
|
||
google_pattern = re.compile(r"^\s*(\**\w+)(?:\s*\(.*?\))?\s*[::]\s*(.*)")
|
||
|
||
for line in lines:
|
||
stripped_line = line.strip()
|
||
if not stripped_line:
|
||
continue
|
||
|
||
if section_header_pattern.match(line):
|
||
param_section_active = True
|
||
continue
|
||
|
||
if param_section_active:
|
||
if (
|
||
stripped_line.endswith(":") or stripped_line.endswith(":")
|
||
) and not google_pattern.match(line):
|
||
param_section_active = False
|
||
continue
|
||
|
||
match = google_pattern.match(line)
|
||
if match:
|
||
name = match.group(1).lstrip("*")
|
||
desc = match.group(2).strip()
|
||
params[name] = desc
|
||
|
||
return params
|
||
|
||
|
||
def _create_dynamic_model(func: Callable) -> type[BaseModel]:
    """Build a Pydantic model describing the LLM-facing parameters of ``func``.

    Every parameter of ``func`` becomes a model field except:

    * ``self`` / ``cls``;
    * ``*args`` / ``**kwargs`` (they cannot be supplied as named tool
      arguments, and the tool-argument resolver rejects them as well);
    * parameters annotated with ``RunContext`` or a union containing it,
      including the PEP 604 ``RunContext | None`` spelling;
    * dependency-injected parameters (``Depends`` default or
      ``Annotated[..., Depends(...)]``).

    Field descriptions are taken from the function docstring when available.

    Args:
        func: The callable whose signature is inspected.

    Returns:
        A model class named ``"<func name>Params"``.
    """
    # Local import: only needed to recognise PEP 604 (``X | Y``) unions.
    import types as _types

    # ``typing.Union[...]`` and ``X | Y`` report different origins on
    # Python 3.10-3.13; fall back to ``Union`` where UnionType is missing.
    union_origins = (Union, getattr(_types, "UnionType", Union))
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    sig = inspect.signature(func)
    doc_params = _parse_docstring_params(func.__doc__)
    type_hints = get_type_hints(func, include_extras=True)

    fields: dict[str, Any] = {}
    for name, param in sig.parameters.items():
        if name in ("self", "cls") or param.kind in variadic_kinds:
            continue

        annotation = type_hints.get(name, Any)
        default = param.default

        # The run context is injected by the framework, never by the LLM.
        if annotation is RunContext or (
            get_origin(annotation) in union_origins
            and RunContext in get_args(annotation)
        ):
            continue

        # Dependency-injected parameters are resolved by the framework too.
        if isinstance(default, DependsInner):
            continue
        if get_origin(annotation) is Annotated and any(
            isinstance(arg, DependsInner) for arg in get_args(annotation)
        ):
            continue

        description = doc_params.get(name)
        if isinstance(default, FieldInfo):
            # Explicit ``Field(...)`` default: keep it, only filling in a
            # missing description. NOTE: this updates the FieldInfo object
            # that is the function's default, in place.
            if description and not getattr(default, "description", None):
                default.description = description
            fields[name] = (annotation, default)
        else:
            if default is inspect.Parameter.empty:
                default = ...  # Ellipsis marks the field as required.
            fields[name] = (annotation, Field(default, description=description))

    return create_model(f"{func.__name__}Params", **fields)
|
||
|
||
|
||
class FunctionExecutable(ToolExecutable):
    """ToolExecutable implementation that wraps a plain Python callable."""

    def __init__(
        self,
        func: Callable,
        name: str,
        description: str,
        params_model: type[BaseModel] | None = None,
        unpack_args: bool = False,
    ):
        """Store the tool metadata and pre-parse ``func`` for dependency injection.

        Args:
            func: The callable executed when the tool is invoked.
            name: Tool name reported in the definition.
            description: Tool description reported in the definition.
            params_model: Optional Pydantic model used to validate arguments.
            unpack_args: When true, the validated model is dumped to a dict
                and that dict replaces the raw arguments passed to ``func``.
        """
        self._func = func
        self._name = name
        self._description = description
        self._params_model = params_model
        self._unpack_args = unpack_args

        # Resolvers are tried in this order for every parameter of ``func``.
        resolver_types = (
            DependParam,
            BotParam,
            EventParam,
            StateParam,
            RunContextParam,
            ToolParam,
            DefaultParam,
        )
        self.dependent = Dependent[Any].parse(call=func, allow_types=resolver_types)

    async def get_definition(self) -> ToolDefinition:
        """Return the tool definition, deriving the schema from the params model."""
        parameters: dict[str, Any]
        if self._params_model:
            schema = model_json_schema(self._params_model)
            parameters = {
                "type": "object",
                "properties": schema.get("properties", {}),
                "required": schema.get("required", []),
            }
        else:
            # No model: advertise an object without any properties.
            parameters = {"type": "object", "properties": {}}

        return ToolDefinition(
            name=self._name,
            description=self._description,
            parameters=parameters,
        )

    async def execute(
        self, context: RunContext | None = None, **kwargs: Any
    ) -> ToolResult:
        """Validate the arguments, resolve dependencies and call the function.

        Args:
            context: Run context made available to the tool; a fresh empty
                one is created when omitted.
            **kwargs: Raw tool arguments.

        Returns:
            A ``ToolResult`` wrapping the function's return value, or a
            JSON-encoded "InvalidArguments" payload if validation fails.

        Raises:
            Exception: Any non-validation error raised while building the
                params model is logged and re-raised.
        """
        context = context or RunContext()
        tool_arguments = kwargs

        if self._params_model:
            try:
                known_fields = model_fields(self._params_model)
                # Unknown keys are dropped before validation.
                candidate = {
                    key: value for key, value in kwargs.items() if key in known_fields
                }
                validated = self._params_model(**candidate)
                if self._unpack_args:
                    tool_arguments = model_dump(validated)
            except ValidationError as e:
                details = []
                for issue in e.errors():
                    location = ".".join(str(part) for part in issue["loc"])
                    details.append(f"Parameter '{location}': {issue['msg']}")
                formatted_error = "; ".join(details)

                error_payload = {
                    "error_type": "InvalidArguments",
                    "message": f"Parameter validation failed: {formatted_error}",
                    "is_retryable": True,
                }
                return ToolResult(
                    output=json.dumps(error_payload, ensure_ascii=False),
                    display_content=f"Validation Error: {formatted_error}",
                )
            except Exception as e:
                logger.error(
                    f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
                )
                raise

        state = {
            "_tool_params": tool_arguments,
            "_agent_context": context,
        }

        # Prefer the bot from the context scope; otherwise fall back to
        # whatever bot nonebot can hand out (none at all is tolerated).
        bot: Bot | None = (context.scope.get("bot") or None) if context else None
        if not bot:
            try:
                bot = nonebot.get_bot()
            except ValueError:
                pass

        event: Event | None = (
            (context.scope.get("event") or None) if context else None
        )

        raw_result = await self.dependent(bot=bot, event=event, state=state)
        return ToolResult(output=raw_result, display_content=str(raw_result))
|
||
|
||
|
||
class BuiltinFunctionToolProvider(ToolProvider):
    """Built-in ToolProvider serving functions registered through the decorator."""

    def __init__(self):
        # Tool name -> registration record (func, description, params_model,
        # unpack_args).
        self._functions: dict[str, dict[str, Any]] = {}

    def register(
        self,
        name: str,
        func: Callable,
        description: str,
        params_model: type[BaseModel] | None = None,
        unpack_args: bool = False,
    ):
        """Record ``func`` under ``name``, replacing any earlier registration."""
        record: dict[str, Any] = {
            "func": func,
            "description": description,
            "params_model": params_model,
            "unpack_args": unpack_args,
        }
        self._functions[name] = record

    async def initialize(self) -> None:
        """Nothing to set up for in-process functions."""
        return None

    def _make_executable(self, name: str) -> FunctionExecutable:
        """Wrap the registration stored under ``name`` in a FunctionExecutable."""
        record = self._functions[name]
        return FunctionExecutable(
            func=record["func"],
            name=name,
            description=record["description"],
            params_model=record["params_model"],
            unpack_args=record.get("unpack_args", False),
        )

    async def discover_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        """Return an executable for every registered function.

        The server filters are accepted but not used by this provider.
        """
        return {name: self._make_executable(name) for name in self._functions}

    async def get_tool_executable(
        self, name: str, config: dict[str, Any]
    ) -> ToolExecutable | None:
        """Return the executable for ``name``, or ``None`` if not served here.

        Only configs whose ``"type"`` is ``"function"`` (the default when the
        key is missing) are handled.
        """
        if config.get("type", "function") != "function":
            return None
        if name not in self._functions:
            return None
        return self._make_executable(name)
|
||
|
||
|
||
class ToolProviderManager:
    """Central manager of tool providers; every construction yields one shared instance."""

    _instance: "ToolProviderManager | None" = None

    def __new__(cls) -> "ToolProviderManager":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # ``__init__`` runs on every ``ToolProviderManager()`` call; only the
        # first one may set up state.
        if getattr(self, "_initialized", False):
            return

        self._providers: list[ToolProvider] = []
        # Cache of the unfiltered discovery result; None means "not built".
        self._resolved_tools: dict[str, ToolExecutable] | None = None
        self._init_lock = asyncio.Lock()
        self._init_promise: asyncio.Task | None = None
        self._builtin_function_provider = BuiltinFunctionToolProvider()
        self.register(self._builtin_function_provider)
        self._initialized = True

    def register(self, provider: ToolProvider):
        """Register a new ToolProvider (no-op if it is already registered).

        NOTE(review): a provider registered after ``initialize()`` has run is
        not initialized by this manager.
        """
        if provider not in self._providers:
            self._providers.append(provider)
            # The cached global tool set no longer reflects all providers.
            self._resolved_tools = None
            logger.info(f"已注册工具提供者: {provider.__class__.__name__}")

    def function_tool(
        self,
        name: str,
        description: str,
        params_model: type[BaseModel] | None = None,
    ):
        """Decorator registering a function as a built-in tool.

        Args:
            name: Tool name; an existing registration with the same name is
                overwritten (a warning is logged).
            description: Tool description exposed in the definition.
            params_model: Optional explicit parameter model. When omitted, a
                model is generated from the function signature and the
                validated values are unpacked into the call.

        Returns:
            A decorator that registers the function and returns it unchanged.
        """

        def decorator(func: Callable):
            if name in self._builtin_function_provider._functions:
                logger.warning(f"正在覆盖已注册的函数工具: {name}")

            final_model = params_model
            unpack_args = False
            if final_model is None:
                final_model = _create_dynamic_model(func)
                unpack_args = True

            self._builtin_function_provider.register(
                name=name,
                func=func,
                description=description,
                params_model=final_model,
                unpack_args=unpack_args,
            )
            # A tool registered after the first global discovery must not be
            # hidden behind the stale cache.
            self._resolved_tools = None
            logger.info(f"已注册函数工具: '{name}'")
            return func

        return decorator

    async def initialize(self) -> None:
        """Lazily initialize all registered providers, at most once."""
        if not self._init_promise:
            async with self._init_lock:
                # Re-check under the lock: another caller may have created
                # the task while this one was waiting.
                if not self._init_promise:
                    self._init_promise = asyncio.create_task(
                        self._initialize_providers()
                    )
        await self._init_promise

    async def _initialize_providers(self) -> None:
        """Run ``initialize()`` on every provider concurrently.

        A failing provider does not stop the others; its error is logged.
        """
        providers = list(self._providers)
        logger.info(f"开始初始化 {len(providers)} 个工具提供者...")
        outcomes = await asyncio.gather(
            *(provider.initialize() for provider in providers),
            return_exceptions=True,
        )
        for provider, outcome in zip(providers, outcomes):
            if isinstance(outcome, BaseException):
                logger.error(
                    f"工具提供者 '{provider.__class__.__name__}' 初始化失败: {outcome}"
                )
        logger.info("所有工具提供者初始化完成。")

    async def get_resolved_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        """
        Return all discovered tools.

        Triggers lazy initialization. Without filters the result is cached
        and reused; with filters a fresh, uncached discovery is performed.

        Args:
            allowed_servers: Passed to providers whose ``discover_tools``
                accepts a parameter of that name.
            excluded_servers: Likewise.

        Returns:
            Mapping of tool name to executable. On a name clash the provider
            registered later wins (a warning is logged).
        """
        await self.initialize()

        has_filters = allowed_servers is not None or excluded_servers is not None

        if not has_filters and self._resolved_tools is not None:
            logger.debug("使用全局工具缓存。")
            return self._resolved_tools

        if has_filters:
            logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
            logger.debug(
                f"过滤器详情: allowed_servers={allowed_servers}, "
                f"excluded_servers={excluded_servers}"
            )
        else:
            logger.info("未应用过滤器,开始全局工具发现...")

        # Snapshot so results stay aligned with their provider even if the
        # provider list changes while discovery is awaited.
        providers = list(self._providers)

        discover_tasks = []
        for provider in providers:
            # Only pass the filters a provider's signature actually accepts.
            sig = inspect.signature(provider.discover_tools)
            params_to_pass = {}
            if "allowed_servers" in sig.parameters:
                params_to_pass["allowed_servers"] = allowed_servers
            if "excluded_servers" in sig.parameters:
                params_to_pass["excluded_servers"] = excluded_servers
            discover_tasks.append(provider.discover_tools(**params_to_pass))

        results = await asyncio.gather(*discover_tasks, return_exceptions=True)

        all_tools: dict[str, ToolExecutable] = {}
        for provider, provider_result in zip(providers, results):
            provider_name = provider.__class__.__name__
            if isinstance(provider_result, dict):
                logger.debug(
                    f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
                )
                for name, executable in provider_result.items():
                    if name in all_tools:
                        logger.warning(
                            f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
                        )
                    all_tools[name] = executable
            elif isinstance(provider_result, BaseException):
                logger.error(
                    f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
                )

        if not has_filters:
            self._resolved_tools = all_tools
            logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
        else:
            logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")

        return all_tools

    async def resolve_specific_tools(
        self, tool_names: list[str]
    ) -> dict[str, ToolExecutable]:
        """
        Resolve only the named tools, without a full discovery pass.

        Providers are asked in registration order and the first one that
        returns an executable wins. A provider that raises is logged and
        skipped; a name no provider resolves is logged and omitted.
        """
        resolved: dict[str, ToolExecutable] = {}
        if not tool_names:
            return resolved

        await self.initialize()

        for name in tool_names:
            config: dict[str, Any] = {"name": name}
            for provider in self._providers:
                try:
                    executable = await provider.get_tool_executable(name, config)
                except Exception as exc:
                    logger.error(
                        f"provider '{provider.__class__.__name__}' 在解析工具 '{name}'"
                        f"时出错: {exc}",
                        e=exc,
                    )
                    continue

                if executable:
                    resolved[name] = executable
                    break
            else:
                # Loop finished without ``break``: nobody provided the tool.
                logger.warning(f"没有找到名为 '{name}' 的工具,已跳过。")

        return resolved

    async def get_function_tools(
        self, names: list[str] | None = None
    ) -> dict[str, ToolExecutable]:
        """
        Resolve tools from the built-in function provider only.

        Args:
            names: Tool names to return; ``None`` returns every registered
                function tool. Unknown names are logged and ignored.
        """
        all_function_tools = await self._builtin_function_provider.discover_tools()
        if names is None:
            return all_function_tools

        resolved_tools = {}
        for name in names:
            if name in all_function_tools:
                resolved_tools[name] = all_function_tools[name]
            else:
                logger.warning(
                    f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
                )
        return resolved_tools
|
||
|
||
|
||
# Shared manager instance, created when this module is imported.
tool_provider_manager = ToolProviderManager()
# Module-level alias so tools can be declared with ``@function_tool(...)``.
function_tool = tool_provider_manager.function_tool
|
||
|
||
|
||
class ToolErrorType(str, Enum):
    """Kinds of structured tool errors (string-valued enum)."""

    # The requested tool is not among the available tools.
    TOOL_NOT_FOUND = "ToolNotFound"
    # The arguments could not be parsed or failed validation.
    INVALID_ARGUMENTS = "InvalidArguments"
    # The tool raised while running (including network/timeout failures).
    EXECUTION_ERROR = "ExecutionError"
    # Not produced anywhere in this module.
    USER_CANCELLATION = "UserCancellation"
|
||
|
||
|
||
class ToolErrorResult(BaseModel):
    """Structured model describing a tool execution error."""

    # Category of the error.
    error_type: ToolErrorType = Field(..., description="错误的类型。")
    # Detailed description of what went wrong.
    message: str = Field(..., description="对错误的详细描述。")
    # Whether retrying might resolve the error; defaults to False.
    is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。")
|
||
|
||
|
||
class ToolInvoker:
    """
    General-purpose tool executor.

    Receives tool-call requests, parses their arguments, fires callbacks,
    runs the tool and returns a normalized result. Failures are reported as
    ``ToolErrorResult`` payloads inside the ``ToolResult`` rather than raised.
    """

    def __init__(self, callbacks: list[BaseCallbackHandler] | None = None):
        self.callbacks = callbacks or []

    @staticmethod
    def _error_result(
        error_type: ToolErrorType, message: str, is_retryable: bool
    ) -> ToolResult:
        """Wrap a structured error into a ``ToolResult``."""
        error = ToolErrorResult(
            error_type=error_type, message=message, is_retryable=is_retryable
        )
        return ToolResult(output=model_dump(error))

    async def _trigger_callbacks(self, event_name: str, *args, **kwargs: Any) -> None:
        """Call ``event_name`` on every handler that defines it, concurrently.

        Exceptions raised by handlers are collected by ``gather`` and not
        propagated.
        """
        if not self.callbacks:
            return
        tasks = [
            getattr(handler, event_name)(*args, **kwargs)
            for handler in self.callbacks
            if hasattr(handler, event_name)
        ]
        await asyncio.gather(*tasks, return_exceptions=True)

    async def execute_tool_call(
        self,
        tool_call: LLMToolCall,
        available_tools: dict[str, ToolExecutable],
        context: Any | None = None,
    ) -> tuple[LLMToolCall, ToolResult]:
        """Execute one tool call.

        Args:
            tool_call: The call requested by the model. Its
                ``function.arguments`` is rewritten when an ``on_tool_start``
                callback returns modified ``ToolCallData``.
            available_tools: Mapping of tool name to executable.
            context: Passed through to the executable as ``context``.

        Returns:
            ``(tool_call, result)``. Argument, lookup and execution failures
            are encoded as a ``ToolErrorResult`` dump in ``result.output``.

        Raises:
            LLMException: If the executable returned ``None``.
        """
        tool_name = tool_call.function.name
        arguments_str = tool_call.function.arguments

        try:
            parsed = json.loads(arguments_str) if arguments_str else None
        except json.JSONDecodeError as e:
            return tool_call, self._error_result(
                ToolErrorType.INVALID_ARGUMENTS, f"参数解析失败: {e}", False
            )

        arguments: dict[str, Any]
        if parsed is None:
            # Empty string or JSON ``null``: the call carries no arguments.
            arguments = {}
        elif isinstance(parsed, dict):
            arguments = parsed
        else:
            # Valid JSON but not an object (list, string, number, ...): it
            # cannot be expanded into keyword arguments.
            return tool_call, self._error_result(
                ToolErrorType.INVALID_ARGUMENTS,
                f"参数解析失败: 期望 JSON 对象, 实际为 {type(parsed).__name__}",
                False,
            )

        tool_data = ToolCallData(tool_name=tool_name, tool_args=arguments)
        pre_calculated_result: ToolResult | None = None
        for handler in self.callbacks:
            res = await handler.on_tool_start(tool_call, tool_data)
            if isinstance(res, ToolCallData):
                # A callback may rewrite the arguments before execution.
                tool_data = res
                arguments = tool_data.tool_args
                tool_call.function.arguments = json.dumps(arguments, ensure_ascii=False)
            elif isinstance(res, ToolResult):
                # A callback may short-circuit execution with its own result.
                pre_calculated_result = res
                break

        if pre_calculated_result:
            return tool_call, pre_calculated_result

        executable = available_tools.get(tool_name)
        if not executable:
            return tool_call, self._error_result(
                ToolErrorType.TOOL_NOT_FOUND, f"Tool '{tool_name}' not found.", False
            )

        from .config.providers import get_llm_config

        # NOTE(review): the schema is logged only when ``debug_log`` is
        # falsy, which looks inverted — confirm the intended condition.
        if not get_llm_config().debug_log:
            try:
                definition = await executable.get_definition()
                schema_payload = getattr(definition, "parameters", {})
                schema_json = fast_json.dumps(
                    schema_payload,
                    ensure_ascii=False,
                )
                logger.debug(
                    f"🔍 [JIT Schema] {tool_name}: {schema_json}",
                    "ToolInvoker",
                )
            except Exception as e:
                # Best-effort diagnostics only; never block the tool call.
                logger.trace(f"JIT Schema logging failed: {e}")

        start_t = time.monotonic()
        result: ToolResult | None = None
        error: Exception | None = None

        try:

            @Retry.simple(stop_max_attempt=2, wait_fixed_seconds=1)
            async def execute_with_retry():
                return await executable.execute(context=context, **arguments)

            result = await execute_with_retry()
        except ValidationError as e:
            error = e
            error_msgs = []
            for err in e.errors():
                loc = ".".join(str(x) for x in err["loc"])
                msg = err["msg"]
                error_msgs.append(f"参数 '{loc}': {msg}")

            formatted_error = "; ".join(error_msgs)
            result = self._error_result(
                ToolErrorType.INVALID_ARGUMENTS,
                f"参数验证失败。请根据错误修正你的输入: {formatted_error}",
                True,
            )
        except (TimeoutException, NetworkError) as e:
            error = e
            result = self._error_result(
                ToolErrorType.EXECUTION_ERROR,
                f"工具执行网络超时或连接失败: {e!s}",
                False,
            )
        except Exception as e:
            error = e
            error_type = ToolErrorType.EXECUTION_ERROR
            # A configuration error is reported as "tool not found".
            if (
                isinstance(e, LLMException)
                and e.code == LLMErrorCode.CONFIGURATION_ERROR
            ):
                error_type = ToolErrorType.TOOL_NOT_FOUND
            result = self._error_result(error_type, str(e), False)

        duration = time.monotonic() - start_t

        await self._trigger_callbacks(
            "on_tool_end",
            result=result,
            error=error,
            tool_call=tool_call,
            duration=duration,
        )

        if result is None:
            raise LLMException("工具执行未返回任何结果。")

        return tool_call, result

    async def execute_batch(
        self,
        tool_calls: list[LLMToolCall],
        available_tools: dict[str, ToolExecutable],
        context: Any | None = None,
    ) -> list[LLMMessage]:
        """Execute several tool calls concurrently.

        Returns:
            One tool-response message per call, in the order of
            ``tool_calls``. A call that raised is answered with a
            ``{"error": ..., "status": "failed"}`` payload.
        """
        if not tool_calls:
            return []

        tasks = [
            self.execute_tool_call(call, available_tools, context)
            for call in tool_calls
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        tool_messages: list[LLMMessage] = []
        for original_call, result_pair in zip(tool_calls, results):
            # BaseException, not Exception: ``gather`` can also hand back
            # asyncio.CancelledError, which would otherwise reach the tuple
            # unpacking below.
            if isinstance(result_pair, BaseException):
                logger.error(
                    f"工具执行发生未捕获异常: {original_call.function.name}, "
                    f"错误: {result_pair}"
                )
                tool_messages.append(
                    LLMMessage.tool_response(
                        tool_call_id=original_call.id,
                        function_name=original_call.function.name,
                        result={
                            "error": f"System Execution Error: {result_pair}",
                            "status": "failed",
                        },
                    )
                )
                continue

            tool_call_result = cast(tuple[LLMToolCall, ToolResult], result_pair)
            _, tool_result = tool_call_result
            tool_messages.append(
                LLMMessage.tool_response(
                    tool_call_id=original_call.id,
                    function_name=original_call.function.name,
                    result=tool_result.output,
                )
            )
        return tool_messages
|
||
|
||
|
||
# Names exported by ``from <this module> import *``.
__all__ = [
    "RunContext",
    "RunContextParam",
    "ToolErrorResult",
    "ToolErrorType",
    "ToolInvoker",
    "ToolParam",
    "function_tool",
    "tool_provider_manager",
]
|