zhenxun_bot/zhenxun/services/llm/tools.py
webjoin111 bba90e62db ♻️ refactor(llm): 重构 LLM 服务架构,引入中间件与组件化适配器
- 【重构】LLM 服务核心架构:
    - 引入中间件管道,统一处理请求生命周期(重试、密钥选择、日志、网络请求)。
    - 适配器重构为组件化设计,分离配置映射、消息转换、响应解析和工具序列化逻辑。
    - 移除 `with_smart_retry` 装饰器,其功能由中间件接管。
    - 移除 `LLMToolExecutor`,工具执行逻辑集成到 `ToolInvoker`。
- 【功能】增强配置系统:
    - `LLMGenerationConfig` 采用组件化结构(Core, Reasoning, Visual, Output, Safety, ToolConfig)。
    - 新增 `GenConfigBuilder` 提供语义化配置构建方式。
    - 新增 `LLMEmbeddingConfig` 用于嵌入专用配置。
    - `CommonOverrides` 迁移并更新至新配置结构。
- 【功能】强化工具系统:
    - 引入 `ToolInvoker` 实现更灵活的工具执行,支持回调与结构化错误。
    - `function_tool` 装饰器支持动态 Pydantic 模型创建和依赖注入 (`ToolParam`, `RunContext`)。
    - 平台原生工具支持 (`GeminiCodeExecution`, `GeminiGoogleSearch`, `GeminiUrlContext`)。
- 【功能】高级生成与嵌入:
    - `generate_structured` 方法支持 In-Context Validation and Repair (IVR) 循环和 AutoCoT (思维链) 包装。
    - 新增 `embed_query` 和 `embed_documents` 便捷嵌入 API。
    - `OpenAIImageAdapter` 支持 OpenAI 兼容的图像生成。
    - `SmartAdapter` 实现模型名称智能路由。
- 【重构】消息与类型系统:
    - `LLMContentPart` 扩展支持更多模态和代码执行相关内容。
    - `LLMMessage` 和 `LLMResponse` 结构更新,支持 `content_parts` 和思维链签名。
    - 统一 `LLMErrorCode` 和用户友好错误消息,提供更详细的网络/代理错误提示。
    - `pyproject.toml` 移除 `bilireq`,新增 `json_repair`。
- 【优化】日志与调试:
    - 引入 `DebugLogOptions`,提供细粒度日志脱敏控制。
    - 增强日志净化器,处理更多敏感数据和长字符串。
- 【清理】删除废弃模块:
    - `zhenxun/services/llm/memory.py`
    - `zhenxun/services/llm/executor.py`
    - `zhenxun/services/llm/config/presets.py`
    - `zhenxun/services/llm/types/content.py`
    - `zhenxun/services/llm/types/enums.py`
    - `zhenxun/services/llm/tools/__init__.py`
    - `zhenxun/services/llm/tools/manager.py`
2025-12-07 18:57:55 +08:00

840 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
工具模块
整合了工具参数解析器、工具提供者管理器与工具执行逻辑,便于在 LLM 服务层统一调用。
"""
import asyncio
from collections.abc import Callable
from enum import Enum
import inspect
import json
import re
import time
from typing import (
Annotated,
Any,
Optional,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)
from typing_extensions import override
from httpx import NetworkError, TimeoutException
try:
import ujson as fast_json
except ImportError:
fast_json = json
import nonebot
from nonebot.dependencies import Dependent, Param
from nonebot.internal.adapter import Bot, Event
from nonebot.internal.params import (
BotParam,
DefaultParam,
DependParam,
DependsInner,
EventParam,
StateParam,
)
from pydantic import BaseModel, Field, ValidationError, create_model
from pydantic.fields import FieldInfo
from zhenxun.services.log import logger
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.pydantic_compat import model_dump, model_fields, model_json_schema
from .types import (
LLMErrorCode,
LLMException,
LLMMessage,
LLMToolCall,
ToolExecutable,
ToolProvider,
ToolResult,
)
from .types.models import ToolDefinition
from .types.protocols import BaseCallbackHandler, ToolCallData
class ToolParam(Param):
    """
    Tool parameter extractor.

    Resolves one named argument of a function tool from the argument
    dictionary parsed out of the LLM tool call, which is stored under
    ``state["_tool_params"]``. Usually combined with ``Annotated`` and the
    dependency-injection system.
    """

    def __init__(self, *args: Any, name: str, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Key looked up in the tool-argument dictionary at resolve time.
        self.name = name

    def __repr__(self) -> str:
        return f"ToolParam(name={self.name})"

    @classmethod
    @override
    def _check_param(
        cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
    ) -> Optional["ToolParam"]:
        """
        Decide whether ``param`` should be resolved from the tool arguments.

        Returns a ``ToolParam`` bound to the parameter name, or ``None`` when
        the parameter is dependency-driven (``DependsInner`` as default or as
        ``Annotated`` metadata) or variadic (``*args`` / ``**kwargs``).
        """
        # A DependsInner default marks a dependency, not a tool argument.
        if isinstance(param.default, DependsInner):
            return None

        annotation = param.annotation
        if get_origin(annotation) is Annotated and any(
            isinstance(meta, DependsInner) for meta in get_args(annotation)
        ):
            return None

        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        if param.kind in variadic_kinds:
            return None
        return cls(name=param.name)

    @override
    async def _solve(self, **kwargs: Any) -> Any:
        """Return the argument stored under ``self.name``, or ``None`` if absent."""
        state: dict[str, Any] = kwargs.get("state", {})
        return state.get("_tool_params", {}).get(self.name)
class RunContext(BaseModel):
    """
    Dependency-injection container handed to tools at execution time.

    Carries the contextual information of a run so that tool functions can
    receive it through a typed parameter.
    """

    # Optional identifier of the session this run belongs to.
    session_id: str | None = None

    # Objects shared with the tool; FunctionExecutable.execute reads the
    # "bot" and "event" keys from this mapping.
    scope: dict[str, Any] = Field(default_factory=dict)

    # Free-form additional data (not read anywhere in this module).
    extra: dict[str, Any] = Field(default_factory=dict)

    class Config:
        # `scope` / `extra` may hold objects pydantic has no validator for.
        arbitrary_types_allowed = True
class RunContextParam(Param):
    """Parameter resolver that injects the current ``RunContext`` automatically."""

    @classmethod
    @override
    def _check_param(
        cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
    ) -> Optional["RunContextParam"]:
        """
        Claim parameters annotated with ``RunContext``.

        Besides the bare ``RunContext`` annotation, ``Optional[RunContext]``,
        ``Union[RunContext, ...]`` and ``RunContext | None`` are accepted.
        ``_create_dynamic_model`` already removes ``Union[RunContext, ...]``
        parameters from the LLM-facing schema, so without this such a
        parameter would fall through to ``ToolParam`` and always be ``None``.

        Returns:
            A ``RunContextParam`` when the annotation refers to
            ``RunContext``, otherwise ``None``.
        """
        # Local import: `types` is not among the module-level imports.
        import types

        annotation = param.annotation
        if annotation is RunContext:
            return cls()
        # typing.Union[...] and PEP 604 `X | Y` report different origins.
        if get_origin(annotation) in (
            Union,
            types.UnionType,
        ) and RunContext in get_args(annotation):
            return cls()
        return None

    @override
    async def _solve(self, **kwargs: Any) -> Any:
        """Return the object stored under ``state["_agent_context"]`` (or ``None``)."""
        state = kwargs.get("state", {})
        return state.get("_agent_context")
def _parse_docstring_params(docstring: str | None) -> dict[str, str]:
"""
解析文档字符串,提取参数描述。
支持 Google Style (Args:), ReST Style (:param:), 和中文风格 (参数:)。
"""
if not docstring:
return {}
params: dict[str, str] = {}
lines = docstring.splitlines()
rest_pattern = re.compile(r"[:@]param\s+(\w+)\s*:?\s*(.*)")
found_rest = False
for line in lines:
match = rest_pattern.search(line)
if match:
params[match.group(1)] = match.group(2).strip()
found_rest = True
if found_rest:
return params
section_header_pattern = re.compile(
r"^\s*(?:Args|Arguments|Parameters|参数)\s*[:]\s*$"
)
param_section_active = False
google_pattern = re.compile(r"^\s*(\**\w+)(?:\s*\(.*?\))?\s*[:]\s*(.*)")
for line in lines:
stripped_line = line.strip()
if not stripped_line:
continue
if section_header_pattern.match(line):
param_section_active = True
continue
if param_section_active:
if (
stripped_line.endswith(":") or stripped_line.endswith("")
) and not google_pattern.match(line):
param_section_active = False
continue
match = google_pattern.match(line)
if match:
name = match.group(1).lstrip("*")
desc = match.group(2).strip()
params[name] = desc
return params
def _create_dynamic_model(func: Callable) -> type[BaseModel]:
    """
    Build a Pydantic model describing the LLM-facing parameters of ``func``.

    Every parameter of the signature becomes a model field, except:

    * ``self`` / ``cls``;
    * variadic parameters (``*args`` / ``**kwargs``), which ``ToolParam``
      never resolves and which therefore must not appear as required fields;
    * parameters annotated with ``RunContext`` directly or inside a union
      (``Optional[RunContext]``, ``Union[...]`` or PEP 604 ``RunContext | None``);
    * dependency parameters (``DependsInner`` as default or as ``Annotated``
      metadata).

    Field descriptions are taken from the function's docstring when the
    parameter does not already carry one.

    Args:
        func: The function to introspect.

    Returns:
        A model class named ``"<func.__name__>Params"``.
    """
    # Local import: `types` is not among the module-level imports.
    import types

    sig = inspect.signature(func)
    doc_params = _parse_docstring_params(func.__doc__)
    type_hints = get_type_hints(func, include_extras=True)

    # typing.Union[...] and PEP 604 `X | Y` report different origins.
    union_origins = (Union, types.UnionType)
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    fields: dict[str, Any] = {}
    for name, param in sig.parameters.items():
        if name in ("self", "cls"):
            continue
        if param.kind in variadic_kinds:
            continue

        annotation = type_hints.get(name, Any)
        default = param.default

        is_run_context = annotation is RunContext or (
            get_origin(annotation) in union_origins
            and RunContext in get_args(annotation)
        )
        if is_run_context:
            continue

        # Dependency-injected parameters are not supplied by the LLM.
        if isinstance(default, DependsInner):
            continue
        if get_origin(annotation) is Annotated and any(
            isinstance(arg, DependsInner) for arg in get_args(annotation)
        ):
            continue

        description = doc_params.get(name)
        if isinstance(default, FieldInfo):
            # Keep an explicit Field(...) default; only fill in a missing
            # description from the docstring.
            if description and not getattr(default, "description", None):
                default.description = description
            fields[name] = (annotation, default)
        else:
            if default is inspect.Parameter.empty:
                default = ...  # no default -> required field
            fields[name] = (annotation, Field(default, description=description))

    return create_model(f"{func.__name__}Params", **fields)
class FunctionExecutable(ToolExecutable):
    """A ``ToolExecutable`` implementation wrapping a plain Python function."""

    def __init__(
        self,
        func: Callable,
        name: str,
        description: str,
        params_model: type[BaseModel] | None = None,
        unpack_args: bool = False,
    ):
        """
        Args:
            func: The wrapped callable.
            name: Tool name reported in the definition.
            description: Tool description reported in the definition.
            params_model: Optional model used for the schema and validation.
            unpack_args: When true, the validated (dumped) model values are
                handed to the function instead of the raw keyword arguments.
        """
        self._func = func
        self._name = name
        self._description = description
        self._params_model = params_model
        self._unpack_args = unpack_args
        # Resolver types are tried in this order for each parameter.
        self.dependent = Dependent[Any].parse(
            call=func,
            allow_types=(
                DependParam,
                BotParam,
                EventParam,
                StateParam,
                RunContextParam,
                ToolParam,
                DefaultParam,
            ),
        )

    async def get_definition(self) -> ToolDefinition:
        """Return the tool definition; parameters default to an empty object."""
        if self._params_model:
            schema = model_json_schema(self._params_model)
            parameters = {
                "type": "object",
                "properties": schema.get("properties", {}),
                "required": schema.get("required", []),
            }
        else:
            parameters = {"type": "object", "properties": {}}
        return ToolDefinition(
            name=self._name,
            description=self._description,
            parameters=parameters,
        )

    async def execute(
        self, context: RunContext | None = None, **kwargs: Any
    ) -> ToolResult:
        """
        Validate the arguments, then call the wrapped function.

        Args:
            context: Run context; a fresh ``RunContext`` is used when falsy.
            **kwargs: Raw tool arguments.

        Returns:
            The function's return value wrapped in a ``ToolResult``. If
            argument validation fails, a ``ToolResult`` carrying a JSON error
            payload is returned instead of calling the function.

        Raises:
            Exception: Any non-validation error raised while validating or
                instantiating the parameter model is logged and re-raised.
        """
        context = context or RunContext()
        tool_arguments = kwargs

        if self._params_model:
            try:
                known_fields = model_fields(self._params_model)
                # Only keys declared on the model take part in validation.
                candidate_input = {
                    key: value for key, value in kwargs.items() if key in known_fields
                }
                validated = self._params_model(**candidate_input)
                if self._unpack_args:
                    tool_arguments = model_dump(validated)
            except ValidationError as exc:
                messages = []
                for issue in exc.errors():
                    location = ".".join(str(part) for part in issue["loc"])
                    detail = issue["msg"]
                    messages.append(f"Parameter '{location}': {detail}")
                formatted_error = "; ".join(messages)
                error_payload = {
                    "error_type": "InvalidArguments",
                    "message": f"Parameter validation failed: {formatted_error}",
                    "is_retryable": True,
                }
                return ToolResult(
                    output=json.dumps(error_payload, ensure_ascii=False),
                    display_content=f"Validation Error: {formatted_error}",
                )
            except Exception as e:
                logger.error(
                    f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
                )
                raise

        state = {
            "_tool_params": tool_arguments,
            "_agent_context": context,
        }

        # Prefer the bot from the context scope, then any connected bot.
        bot: Bot | None = context.scope.get("bot") or None
        if not bot:
            try:
                bot = nonebot.get_bot()
            except ValueError:
                pass

        event: Event | None = context.scope.get("event") or None

        raw_result = await self.dependent(
            bot=bot,
            event=event,
            state=state,
        )
        return ToolResult(output=raw_result, display_content=str(raw_result))
class BuiltinFunctionToolProvider(ToolProvider):
    """Built-in ``ToolProvider`` serving functions registered via the decorator."""

    def __init__(self):
        # Tool name -> registration record (func, description, model, flag).
        self._functions: dict[str, dict[str, Any]] = {}

    def register(
        self,
        name: str,
        func: Callable,
        description: str,
        params_model: type[BaseModel] | None = None,
        unpack_args: bool = False,
    ):
        """Store (or overwrite) the registration record for ``name``."""
        record: dict[str, Any] = {
            "func": func,
            "description": description,
            "params_model": params_model,
            "unpack_args": unpack_args,
        }
        self._functions[name] = record

    async def initialize(self) -> None:
        """Nothing to initialize for in-process functions."""
        pass

    def _make_executable(self, name: str, info: dict[str, Any]) -> FunctionExecutable:
        """Build a ``FunctionExecutable`` from one registration record."""
        return FunctionExecutable(
            func=info["func"],
            name=name,
            description=info["description"],
            params_model=info["params_model"],
            unpack_args=info.get("unpack_args", False),
        )

    async def discover_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        """Return an executable for every registered function (filters unused)."""
        return {
            name: self._make_executable(name, info)
            for name, info in self._functions.items()
        }

    async def get_tool_executable(
        self, name: str, config: dict[str, Any]
    ) -> ToolExecutable | None:
        """Return the executable for ``name`` if it is a registered function tool."""
        if config.get("type", "function") != "function" or name not in self._functions:
            return None
        return self._make_executable(name, self._functions[name])
class ToolProviderManager:
    """Central manager for tool providers, implemented as a singleton."""

    _instance: "ToolProviderManager | None" = None

    def __new__(cls) -> "ToolProviderManager":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ToolProviderManager() call; only the first
        # call may set up state, later calls must not reset it.
        if getattr(self, "_initialized", False):
            return
        self._providers: list[ToolProvider] = []
        # Cache of the unfiltered discovery result; None means "not built".
        self._resolved_tools: dict[str, ToolExecutable] | None = None
        self._init_lock = asyncio.Lock()
        self._init_promise: asyncio.Task | None = None
        self._builtin_function_provider = BuiltinFunctionToolProvider()
        self.register(self._builtin_function_provider)
        self._initialized = True

    def register(self, provider: ToolProvider):
        """
        Register a new ``ToolProvider`` (no-op if already registered).

        The global tool cache is dropped so that the next unfiltered
        discovery includes the new provider's tools.
        NOTE(review): a provider registered after ``initialize()`` has run is
        not initialized by this class; callers must handle that themselves.
        """
        if provider in self._providers:
            return
        self._providers.append(provider)
        self._resolved_tools = None
        logger.info(f"已注册工具提供者: {provider.__class__.__name__}")

    def function_tool(
        self,
        name: str,
        description: str,
        params_model: type[BaseModel] | None = None,
    ):
        """
        Decorator that registers a function as a built-in tool.

        Args:
            name: Tool name; an existing registration with the same name is
                overwritten (a warning is logged).
            description: Tool description.
            params_model: Optional parameter model. When omitted, a model is
                derived from the function signature and the validated values
                are passed to the function.

        Returns:
            A decorator returning the original function unchanged.
        """

        def decorator(func: Callable):
            if name in self._builtin_function_provider._functions:
                logger.warning(f"正在覆盖已注册的函数工具: {name}")

            final_model = params_model
            unpack_args = False
            if final_model is None:
                final_model = _create_dynamic_model(func)
                unpack_args = True

            self._builtin_function_provider.register(
                name=name,
                func=func,
                description=description,
                params_model=final_model,
                unpack_args=unpack_args,
            )
            # A cached discovery result would miss (or hold a stale version
            # of) this tool, so force a rebuild on next access.
            self._resolved_tools = None
            logger.info(f"已注册函数工具: '{name}'")
            return func

        return decorator

    async def initialize(self) -> None:
        """Lazily initialize all registered providers (runs at most once)."""
        if not self._init_promise:
            async with self._init_lock:
                # Re-check under the lock: another caller may have created
                # the task while this one was waiting.
                if not self._init_promise:
                    self._init_promise = asyncio.create_task(
                        self._initialize_providers()
                    )
        await self._init_promise

    async def _initialize_providers(self) -> None:
        """Initialize every provider concurrently and report failures."""
        providers = list(self._providers)
        logger.info(f"开始初始化 {len(providers)} 个工具提供者...")
        outcomes = await asyncio.gather(
            *(provider.initialize() for provider in providers),
            return_exceptions=True,
        )
        # One failing provider must not block the others, but the failure
        # should be visible instead of being silently discarded.
        for provider, outcome in zip(providers, outcomes):
            if isinstance(outcome, BaseException):
                logger.error(
                    f"工具提供者 '{provider.__class__.__name__}' 初始化失败: "
                    f"{outcome!r}"
                )
        logger.info("所有工具提供者初始化完成。")

    async def get_resolved_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        """
        Return all discovered tools.

        Triggers lazy initialization. Without filters the result is served
        from / stored in the global cache; with any filter a one-off
        discovery is performed and the cache is neither read nor written.

        Args:
            allowed_servers: Forwarded to providers whose ``discover_tools``
                accepts a parameter of that name.
            excluded_servers: Forwarded likewise.

        Returns:
            Mapping of tool name to executable. On duplicate names the tool
            from the later provider wins.
        """
        await self.initialize()

        has_filters = allowed_servers is not None or excluded_servers is not None
        if not has_filters and self._resolved_tools is not None:
            logger.debug("使用全局工具缓存。")
            return self._resolved_tools

        if has_filters:
            logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
            logger.debug(
                f"过滤器详情: allowed_servers={allowed_servers}, "
                f"excluded_servers={excluded_servers}"
            )
        else:
            logger.info("未应用过滤器,开始全局工具发现...")

        # Snapshot so results stay paired with the provider that produced
        # them even if the provider list changes while awaiting.
        providers = list(self._providers)
        discover_tasks = []
        for provider in providers:
            # Pass only the filter arguments this provider's signature accepts.
            sig = inspect.signature(provider.discover_tools)
            params_to_pass = {}
            if "allowed_servers" in sig.parameters:
                params_to_pass["allowed_servers"] = allowed_servers
            if "excluded_servers" in sig.parameters:
                params_to_pass["excluded_servers"] = excluded_servers
            discover_tasks.append(provider.discover_tools(**params_to_pass))

        results = await asyncio.gather(*discover_tasks, return_exceptions=True)

        all_tools: dict[str, ToolExecutable] = {}
        for provider, provider_result in zip(providers, results):
            provider_name = provider.__class__.__name__
            if isinstance(provider_result, dict):
                logger.debug(
                    f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
                )
                for name, executable in provider_result.items():
                    if name in all_tools:
                        logger.warning(
                            f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
                        )
                    all_tools[name] = executable
            elif isinstance(provider_result, Exception):
                logger.error(
                    f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
                )

        if not has_filters:
            self._resolved_tools = all_tools
            logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
        else:
            logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
        return all_tools

    async def resolve_specific_tools(
        self, tool_names: list[str]
    ) -> dict[str, ToolExecutable]:
        """
        Resolve only the named tools, avoiding a full discovery.

        Providers are asked in registration order; the first one returning
        an executable wins. A provider raising is logged and skipped. Names
        no provider can resolve are logged and left out of the result.
        """
        resolved: dict[str, ToolExecutable] = {}
        if not tool_names:
            return resolved

        await self.initialize()

        for name in tool_names:
            config: dict[str, Any] = {"name": name}
            for provider in self._providers:
                try:
                    executable = await provider.get_tool_executable(name, config)
                except Exception as exc:
                    logger.error(
                        f"provider '{provider.__class__.__name__}' 在解析工具 '{name}'"
                        f"时出错: {exc}",
                        e=exc,
                    )
                    continue
                if executable:
                    resolved[name] = executable
                    break
            else:
                # Loop finished without `break`: no provider had the tool.
                logger.warning(f"没有找到名为 '{name}' 的工具,已跳过。")
        return resolved

    async def get_function_tools(
        self, names: list[str] | None = None
    ) -> dict[str, ToolExecutable]:
        """
        Resolve tools from the built-in function provider only.

        Args:
            names: Tool names to return; ``None`` returns every registered
                function tool. Unknown names are logged and ignored.
        """
        all_function_tools = await self._builtin_function_provider.discover_tools()
        if names is None:
            return all_function_tools

        resolved_tools = {}
        for name in names:
            if name in all_function_tools:
                resolved_tools[name] = all_function_tools[name]
            else:
                logger.warning(
                    f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
                )
        return resolved_tools
# Module-level singleton instance of the provider manager.
tool_provider_manager = ToolProviderManager()
# Decorator alias bound to the singleton, used to register function tools.
function_tool = tool_provider_manager.function_tool
class ToolErrorType(str, Enum):
    """Categories of structured tool errors (string-valued enum)."""

    # The requested tool could not be found.
    TOOL_NOT_FOUND = "ToolNotFound"
    # The supplied arguments could not be parsed or failed validation.
    INVALID_ARGUMENTS = "InvalidArguments"
    # The tool raised an error while running.
    EXECUTION_ERROR = "ExecutionError"
    # Name suggests the user cancelled the call (not used in this module).
    USER_CANCELLATION = "UserCancellation"
class ToolErrorResult(BaseModel):
    """Structured model describing a failed tool execution."""

    # Category of the failure.
    error_type: ToolErrorType = Field(..., description="错误的类型。")

    # Detailed description of what went wrong.
    message: str = Field(..., description="对错误的详细描述。")

    # Whether retrying might resolve the error; defaults to False.
    is_retryable: bool = Field(
        False,
        description="指示这个错误是否可能通过重试解决。",
    )
class ToolInvoker:
    """
    General-purpose tool executor.

    Receives tool-call requests, parses their arguments, fires callbacks,
    runs the tool and returns a normalized result.
    """

    def __init__(self, callbacks: list[BaseCallbackHandler] | None = None):
        self.callbacks = callbacks or []

    @staticmethod
    def _error_result(
        error_type: ToolErrorType, message: str, is_retryable: bool = False
    ) -> ToolResult:
        """Wrap a structured error description in a ``ToolResult``."""
        error = ToolErrorResult(
            error_type=error_type,
            message=message,
            is_retryable=is_retryable,
        )
        return ToolResult(output=model_dump(error))

    async def _trigger_callbacks(self, event_name: str, *args, **kwargs: Any) -> None:
        """
        Call ``event_name`` on every callback handler that defines it.

        Handlers run concurrently. A failing handler does not affect the
        others or the caller, but its exception is logged rather than
        silently discarded.
        """
        if not self.callbacks:
            return
        handlers = [
            handler for handler in self.callbacks if hasattr(handler, event_name)
        ]
        tasks = [getattr(handler, event_name)(*args, **kwargs) for handler in handlers]
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for handler, outcome in zip(handlers, outcomes):
            if isinstance(outcome, BaseException):
                logger.warning(
                    f"回调 '{handler.__class__.__name__}.{event_name}' "
                    f"执行失败: {outcome!r}"
                )

    async def execute_tool_call(
        self,
        tool_call: LLMToolCall,
        available_tools: dict[str, ToolExecutable],
        context: Any | None = None,
    ) -> tuple[LLMToolCall, ToolResult]:
        """
        Execute a single tool call.

        Steps: parse the JSON arguments; give each callback's
        ``on_tool_start`` the chance to rewrite the arguments (by returning
        ``ToolCallData``) or to short-circuit with a ready ``ToolResult``;
        look up the tool; run it through the ``Retry.simple`` wrapper; fire
        ``on_tool_end``.

        Args:
            tool_call: The call requested by the model. Its
                ``function.arguments`` is rewritten in place when a callback
                modifies the arguments.
            available_tools: Tool name -> executable.
            context: Passed through to ``executable.execute``.

        Returns:
            ``(tool_call, result)``. Failures (bad arguments, unknown tool,
            execution errors) are reported as a ``ToolResult`` whose output
            is a dumped ``ToolErrorResult`` rather than raised.

        Raises:
            LLMException: If the tool ran without error but produced no
                result object.
        """
        tool_name = tool_call.function.name
        arguments_str = tool_call.function.arguments

        parsed_arguments: Any = None
        if arguments_str:
            try:
                parsed_arguments = json.loads(arguments_str)
            except json.JSONDecodeError as e:
                return tool_call, self._error_result(
                    ToolErrorType.INVALID_ARGUMENTS, f"参数解析失败: {e}"
                )

        arguments: dict[str, Any]
        if parsed_arguments is None:
            # Empty string or JSON `null`: treat as "no arguments".
            arguments = {}
        elif isinstance(parsed_arguments, dict):
            arguments = parsed_arguments
        else:
            # Valid JSON but not an object; it cannot be expanded into
            # keyword arguments, so report it as a structured error.
            return tool_call, self._error_result(
                ToolErrorType.INVALID_ARGUMENTS,
                "参数解析失败: 期望 JSON 对象, 实际得到 "
                f"{type(parsed_arguments).__name__}",
            )

        tool_data = ToolCallData(tool_name=tool_name, tool_args=arguments)
        pre_calculated_result: ToolResult | None = None
        for handler in self.callbacks:
            res = await handler.on_tool_start(tool_call, tool_data)
            if isinstance(res, ToolCallData):
                # Callback rewrote the call; keep tool_call in sync.
                tool_data = res
                arguments = tool_data.tool_args
                tool_call.function.arguments = json.dumps(arguments, ensure_ascii=False)
            elif isinstance(res, ToolResult):
                # Callback supplied the result; skip the remaining handlers.
                pre_calculated_result = res
                break
        if pre_calculated_result:
            return tool_call, pre_calculated_result

        executable = available_tools.get(tool_name)
        if not executable:
            return tool_call, self._error_result(
                ToolErrorType.TOOL_NOT_FOUND, f"Tool '{tool_name}' not found."
            )

        # Imported at call time, as in the original code (presumably to
        # avoid an import cycle; verify before hoisting).
        from .config.providers import get_llm_config

        # NOTE(review): the schema is logged only when `debug_log` is falsy.
        # Confirm this is intentional and not an inverted condition.
        if not get_llm_config().debug_log:
            try:
                definition = await executable.get_definition()
                schema_payload = getattr(definition, "parameters", {})
                schema_json = fast_json.dumps(
                    schema_payload,
                    ensure_ascii=False,
                )
                logger.debug(
                    f"🔍 [JIT Schema] {tool_name}: {schema_json}",
                    "ToolInvoker",
                )
            except Exception as e:
                # Logging is best-effort and must never block execution.
                logger.trace(f"JIT Schema logging failed: {e}")

        start_t = time.monotonic()
        result: ToolResult | None = None
        error: Exception | None = None

        try:

            @Retry.simple(stop_max_attempt=2, wait_fixed_seconds=1)
            async def execute_with_retry():
                return await executable.execute(context=context, **arguments)

            result = await execute_with_retry()
        except ValidationError as e:
            error = e
            error_msgs = []
            for err in e.errors():
                loc = ".".join(str(x) for x in err["loc"])
                msg = err["msg"]
                error_msgs.append(f"参数 '{loc}': {msg}")
            formatted_error = "; ".join(error_msgs)
            result = self._error_result(
                ToolErrorType.INVALID_ARGUMENTS,
                f"参数验证失败。请根据错误修正你的输入: {formatted_error}",
                is_retryable=True,
            )
        except (TimeoutException, NetworkError) as e:
            error = e
            result = self._error_result(
                ToolErrorType.EXECUTION_ERROR,
                f"工具执行网络超时或连接失败: {e!s}",
            )
        except Exception as e:
            error = e
            # A configuration error is reported as "tool not found".
            is_configuration_error = (
                isinstance(e, LLMException)
                and e.code == LLMErrorCode.CONFIGURATION_ERROR
            )
            error_type = (
                ToolErrorType.TOOL_NOT_FOUND
                if is_configuration_error
                else ToolErrorType.EXECUTION_ERROR
            )
            result = self._error_result(error_type, str(e))

        duration = time.monotonic() - start_t

        await self._trigger_callbacks(
            "on_tool_end",
            result=result,
            error=error,
            tool_call=tool_call,
            duration=duration,
        )

        if result is None:
            raise LLMException("工具执行未返回任何结果。")
        return tool_call, result

    async def execute_batch(
        self,
        tool_calls: list[LLMToolCall],
        available_tools: dict[str, ToolExecutable],
        context: Any | None = None,
    ) -> list[LLMMessage]:
        """
        Execute several tool calls concurrently.

        Returns:
            One tool-response ``LLMMessage`` per call, in the order of
            ``tool_calls``. A call that raised is answered with an error
            payload instead of aborting the batch.
        """
        if not tool_calls:
            return []

        tasks = [
            self.execute_tool_call(call, available_tools, context)
            for call in tool_calls
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        tool_messages: list[LLMMessage] = []
        for original_call, result_pair in zip(tool_calls, results):
            # gather(return_exceptions=True) can yield BaseException
            # instances (e.g. CancelledError), which are not `Exception`
            # subclasses; they must not reach the tuple unpacking below.
            if isinstance(result_pair, BaseException):
                logger.error(
                    f"工具执行发生未捕获异常: {original_call.function.name}, "
                    f"错误: {result_pair}"
                )
                tool_messages.append(
                    LLMMessage.tool_response(
                        tool_call_id=original_call.id,
                        function_name=original_call.function.name,
                        result={
                            "error": f"System Execution Error: {result_pair}",
                            "status": "failed",
                        },
                    )
                )
                continue

            tool_call_result = cast(tuple[LLMToolCall, ToolResult], result_pair)
            _, tool_result = tool_call_result
            tool_messages.append(
                LLMMessage.tool_response(
                    tool_call_id=original_call.id,
                    function_name=original_call.function.name,
                    result=tool_result.output,
                )
            )
        return tool_messages
# Public API of this module: names exported by `from ... import *`.
__all__ = [
"RunContext",
"RunContextParam",
"ToolErrorResult",
"ToolErrorType",
"ToolInvoker",
"ToolParam",
"function_tool",
"tool_provider_manager",
]