mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 06:12:53 +08:00
194 lines
6.1 KiB
Python
194 lines
6.1 KiB
Python
|
|
"""
|
|||
|
|
LLM 轻量级工具执行器
|
|||
|
|
|
|||
|
|
提供驱动 LLM 与本地函数工具之间交互的核心循环。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
from enum import Enum
|
|||
|
|
import json
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
|
|||
|
|
from zhenxun.services.log import logger
|
|||
|
|
from zhenxun.utils.decorator.retry import Retry
|
|||
|
|
from zhenxun.utils.pydantic_compat import model_dump
|
|||
|
|
|
|||
|
|
from .service import LLMModel
|
|||
|
|
from .types import (
|
|||
|
|
LLMErrorCode,
|
|||
|
|
LLMException,
|
|||
|
|
LLMMessage,
|
|||
|
|
ToolExecutable,
|
|||
|
|
ToolResult,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ExecutionConfig(BaseModel):
    """Settings for the lightweight tool executor."""

    # Upper bound on the number of tool-call cycles; defaults to 5.
    max_cycles: int = Field(
        default=5,
        description="工具调用循环的最大次数。",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ToolErrorType(str, Enum):
    """Categories used to label a structured tool error."""

    # The requested tool name has no registered executable.
    TOOL_NOT_FOUND = "ToolNotFound"
    # The arguments supplied with the tool call could not be parsed.
    INVALID_ARGUMENTS = "InvalidArguments"
    # The tool failed while it was being executed.
    EXECUTION_ERROR = "ExecutionError"
    # Not produced by any code in the visible part of this file.
    USER_CANCELLATION = "UserCancellation"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ToolErrorResult(BaseModel):
    """Structured description of a failed tool call, meant to be returned to the LLM."""

    # Category of the failure.
    error_type: ToolErrorType = Field(..., description="错误的类型。")
    # Detailed description of what went wrong.
    message: str = Field(..., description="对错误的详细描述。")
    # Whether a retry might resolve the error; defaults to False.
    is_retryable: bool = Field(
        default=False,
        description="指示这个错误是否可能通过重试解决。",
    )

    def model_dump(self, **kwargs):
        """
        Serialize this model through the project's pydantic compatibility helper.

        Inside this method the bare name ``model_dump`` resolves to the
        module-level function imported from ``zhenxun.utils.pydantic_compat``,
        not to this method, because class attributes are not part of a
        method's name lookup scope.

        Args:
            **kwargs: Forwarded unchanged to the compatibility helper.

        Returns:
            Whatever the compatibility helper returns for this instance.
        """
        # NOTE(review): if the compat helper itself calls `self.model_dump(...)`
        # under pydantic v2, this override would recurse — confirm against the
        # helper's implementation.
        dumped = model_dump(self, **kwargs)
        return dumped
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _is_exception_retryable(e: Exception) -> bool:
    """
    Decide whether an exception should be reported as retryable.

    Args:
        e: The exception raised while handling a tool call.

    Returns:
        True for any exception that is not an ``LLMException``. For an
        ``LLMException``, True only when its ``code`` is one of
        ``API_REQUEST_FAILED``, ``API_TIMEOUT`` or ``API_RATE_LIMITED``.
    """
    # Anything that is not an LLM-layer error is treated as retryable.
    if not isinstance(e, LLMException):
        return True

    transient_codes = {
        LLMErrorCode.API_REQUEST_FAILED,
        LLMErrorCode.API_TIMEOUT,
        LLMErrorCode.API_RATE_LIMITED,
    }
    return e.code in transient_codes
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMToolExecutor:
    """
    Generic executor that drives the multi-turn exchange between an LLM and tools.

    Each cycle asks the model for a response, records it, runs every tool call
    the model requested, and feeds the tool outputs back as messages. Failures
    of individual tool calls are converted into structured error payloads
    instead of aborting the whole loop.
    """

    def __init__(self, model: LLMModel):
        """
        Args:
            model: The model whose ``generate_response`` produces each assistant turn.
        """
        self.model = model

    async def run(
        self,
        messages: list[LLMMessage],
        tools: dict[str, ToolExecutable],
        config: ExecutionConfig | None = None,
    ) -> list[LLMMessage]:
        """
        Run the complete think/act loop.

        Args:
            messages: Initial conversation; it is copied, never mutated.
            tools: Mapping from tool name to the executable implementing it.
            config: Loop settings; a default ``ExecutionConfig`` is used when
                omitted.

        Returns:
            The initial messages followed by every assistant message and tool
            response produced, ending with the first assistant message that
            requests no tool calls.

        Raises:
            LLMException: With code ``GENERATION_FAILED`` when the model is
                still requesting tools after ``max_cycles`` cycles.
        """
        effective_config = config or ExecutionConfig()
        # Work on a copy so the caller's list is left untouched.
        execution_history = list(messages)

        for _ in range(effective_config.max_cycles):
            response = await self.model.generate_response(
                execution_history, tools=tools
            )

            assistant_message = LLMMessage(
                role="assistant",
                content=response.text,
                tool_calls=response.tool_calls,
            )
            execution_history.append(assistant_message)

            # No tool calls means the model produced its final answer.
            if not response.tool_calls:
                logger.info("✅ LLMToolExecutor:模型未请求工具调用,执行结束。")
                return execution_history

            logger.info(
                f"🛠️ LLMToolExecutor:模型请求并行调用 {len(response.tool_calls)} 个工具"
            )
            tool_results = await self._execute_tools_parallel_safely(
                response.tool_calls,
                tools,
            )
            execution_history.extend(tool_results)

        raise LLMException(
            f"超过最大工具调用循环次数 ({effective_config.max_cycles})。",
            code=LLMErrorCode.GENERATION_FAILED,
        )

    @staticmethod
    def _build_error_result(
        tool_call: Any,
        error_type: ToolErrorType,
        message: str,
        *,
        is_retryable: bool,
    ) -> tuple[Any, ToolResult]:
        """Pair a tool call with a ``ToolResult`` carrying a structured error."""
        error_result = ToolErrorResult(
            error_type=error_type, message=message, is_retryable=is_retryable
        )
        return tool_call, ToolResult(output=model_dump(error_result))

    async def _execute_single_tool_safely(
        self, tool_call: Any, available_tools: dict[str, ToolExecutable]
    ) -> tuple[Any, ToolResult]:
        """
        Execute one tool call, turning every failure into a structured error.

        Args:
            tool_call: The call requested by the model; ``function.name`` and
                ``function.arguments`` are read from it.
            available_tools: Mapping from tool name to executable.

        Returns:
            ``(tool_call, result)``. On failure ``result.output`` is a dumped
            ``ToolErrorResult``:

            * ``INVALID_ARGUMENTS`` when the arguments cannot be decoded as
              JSON or do not decode to a JSON object;
            * ``TOOL_NOT_FOUND`` when ``available_tools`` has no entry for the
              requested name;
            * ``EXECUTION_ERROR`` for anything raised while running the tool.
        """
        tool_name = tool_call.function.name
        arguments: dict[str, Any] = {}

        raw_arguments = tool_call.function.arguments
        if raw_arguments:
            try:
                parsed = json.loads(raw_arguments)
            except (TypeError, ValueError) as e:
                # ValueError covers json.JSONDecodeError; TypeError covers a
                # payload that is not str/bytes. Either way the error must stay
                # inside this call rather than abort the sibling calls.
                return self._build_error_result(
                    tool_call,
                    ToolErrorType.INVALID_ARGUMENTS,
                    f"参数解析失败: {e}",
                    is_retryable=False,
                )
            if not isinstance(parsed, dict):
                # Only a JSON object can be expanded into keyword arguments;
                # report this here instead of failing inside the retried call.
                return self._build_error_result(
                    tool_call,
                    ToolErrorType.INVALID_ARGUMENTS,
                    f"参数解析失败: 期望 JSON 对象,实际得到 {type(parsed).__name__}",
                    is_retryable=False,
                )
            arguments = parsed

        executable = available_tools.get(tool_name)
        if executable is None:
            # The exception is built only for its formatted text, so the
            # message sent to the model is the same as when it was raised.
            not_found = LLMException(
                f"Tool '{tool_name}' not found.",
                code=LLMErrorCode.CONFIGURATION_ERROR,
            )
            return self._build_error_result(
                tool_call,
                ToolErrorType.TOOL_NOT_FOUND,
                str(not_found),
                is_retryable=False,
            )

        try:
            # NOTE(review): return_on_failure=None presumably makes the
            # decorator return None instead of raising once attempts are
            # exhausted — confirm against Retry.simple.
            @Retry.simple(
                stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None
            )
            async def execute_with_retry():
                return await executable.execute(**arguments)

            execution_result = await execute_with_retry()
            if execution_result is None:
                raise LLMException("工具执行在多次重试后仍然失败。")

            return tool_call, execution_result
        except Exception as e:
            # Boundary of the "safe" wrapper: whatever the tool raised is
            # reported to the model rather than propagated.
            return self._build_error_result(
                tool_call,
                ToolErrorType.EXECUTION_ERROR,
                str(e),
                is_retryable=_is_exception_retryable(e),
            )

    async def _execute_tools_parallel_safely(
        self,
        tool_calls: list[Any],
        available_tools: dict[str, ToolExecutable],
    ) -> list[LLMMessage]:
        """
        Run all tool calls concurrently, isolating the errors of each call.

        Args:
            tool_calls: Calls requested by the model in one assistant turn.
            available_tools: Mapping from tool name to executable.

        Returns:
            One tool-response message per call, in the same order as
            ``tool_calls``; an empty list when there are no calls.
        """
        if not tool_calls:
            return []

        # asyncio.gather returns results in the order the awaitables were given.
        results = await asyncio.gather(
            *(
                self._execute_single_tool_safely(call, available_tools)
                for call in tool_calls
            )
        )

        return [
            LLMMessage.tool_response(
                tool_call_id=original_call.id,
                function_name=original_call.function.name,
                result=result.output,
            )
            for original_call, result in results
        ]
|