zhenxun_bot/zhenxun/services/llm/executor.py

194 lines
6.1 KiB
Python
Raw Normal View History

"""
LLM 轻量级工具执行器
提供驱动 LLM 与本地函数工具之间交互的核心循环
"""
import asyncio
from enum import Enum
import json
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.pydantic_compat import model_dump
from .service import LLMModel
from .types import (
LLMErrorCode,
LLMException,
LLMMessage,
ToolExecutable,
ToolResult,
)
class ExecutionConfig(BaseModel):
    """
    Configuration for the lightweight tool executor.
    """

    # Maximum number of model-call / tool-execution cycles that
    # LLMToolExecutor.run performs before giving up. Must be >= 1: with zero
    # cycles the loop would raise without ever calling the model.
    max_cycles: int = Field(
        default=5, ge=1, description="工具调用循环的最大次数。"
    )
class ToolErrorType(str, Enum):
    """
    Categories of structured tool errors reported back to the LLM.

    Members also subclass ``str``, so each compares equal to its value.
    """

    # The requested tool name is not present in the available tools.
    TOOL_NOT_FOUND = "ToolNotFound"
    # The tool-call arguments could not be parsed.
    INVALID_ARGUMENTS = "InvalidArguments"
    # The tool failed while executing.
    EXECUTION_ERROR = "ExecutionError"
    # Not produced in this module; presumably signals a user-initiated
    # cancellation (NOTE(review): confirm with callers).
    USER_CANCELLATION = "UserCancellation"
class ToolErrorResult(BaseModel):
    """A structured tool-execution error, serialized and returned to the LLM."""

    error_type: ToolErrorType = Field(..., description="错误的类型。")
    message: str = Field(..., description="对错误的详细描述。")
    is_retryable: bool = Field(
        False, description="指示这个错误是否可能通过重试解决。"
    )

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """
        Serialize this model to a dict under both pydantic v1 and v2.

        When the parent ``BaseModel`` provides ``model_dump`` (pydantic v2),
        it is called directly. Delegating to the ``pydantic_compat`` helper
        in that case would recurse back into this override if the helper
        calls ``model.model_dump(...)`` (NOTE(review): the usual v2
        implementation; confirm against pydantic_compat). On pydantic v1,
        where ``BaseModel`` has no ``model_dump``, the compat helper is used.

        Args:
            **kwargs: Forwarded unchanged to the underlying dump function.
        """
        native_dump = getattr(super(), "model_dump", None)
        if native_dump is not None:
            return native_dump(**kwargs)
        return model_dump(self, **kwargs)
def _is_exception_retryable(e: Exception) -> bool:
    """
    Decide whether an exception describes a failure that might succeed on retry.

    An ``LLMException`` counts as retryable only when its code is one of the
    transient API failures (request failed, timeout, rate limited). Any other
    exception type is treated as retryable.
    """
    if not isinstance(e, LLMException):
        return True
    transient_codes = {
        LLMErrorCode.API_REQUEST_FAILED,
        LLMErrorCode.API_TIMEOUT,
        LLMErrorCode.API_RATE_LIMITED,
    }
    return e.code in transient_codes
class LLMToolExecutor:
    """
    A generic executor that drives the multi-turn exchange between an LLM
    and local function tools.

    Each cycle asks the model for a response. If the response requests tool
    calls, those tools run concurrently and their results are appended to the
    conversation before the next cycle.
    """

    def __init__(self, model: LLMModel):
        """
        Args:
            model: The model used to generate every assistant turn.
        """
        self.model = model

    async def run(
        self,
        messages: list[LLMMessage],
        tools: dict[str, ToolExecutable],
        config: ExecutionConfig | None = None,
    ) -> list[LLMMessage]:
        """
        Run the full think/act loop until the model stops requesting tools.

        Args:
            messages: Initial conversation. It is copied, not mutated.
            tools: Available tools, keyed by the name the model calls them by.
            config: Loop settings; ``ExecutionConfig()`` when omitted.

        Returns:
            The whole conversation: the input messages followed by every
            assistant turn and tool response produced by the loop.

        Raises:
            LLMException: With ``GENERATION_FAILED`` when all ``max_cycles``
                cycles end with the model still requesting tools (or
                immediately when ``max_cycles`` is not positive).
        """
        effective_config = config or ExecutionConfig()
        execution_history = list(messages)

        for _ in range(effective_config.max_cycles):
            response = await self.model.generate_response(
                execution_history, tools=tools
            )
            assistant_message = LLMMessage(
                role="assistant",
                content=response.text,
                tool_calls=response.tool_calls,
            )
            execution_history.append(assistant_message)

            if not response.tool_calls:
                logger.info("✅ LLMToolExecutor模型未请求工具调用执行结束。")
                return execution_history

            logger.info(
                f"🛠️ LLMToolExecutor模型请求并行调用 {len(response.tool_calls)} 个工具"
            )
            tool_results = await self._execute_tools_parallel_safely(
                response.tool_calls,
                tools,
            )
            execution_history.extend(tool_results)

        raise LLMException(
            f"超过最大工具调用循环次数 ({effective_config.max_cycles})。",
            code=LLMErrorCode.GENERATION_FAILED,
        )

    @staticmethod
    def _error_tool_result(
        error_type: ToolErrorType, message: str, is_retryable: bool = False
    ) -> ToolResult:
        """Wrap a structured tool error in a ``ToolResult`` for the LLM."""
        error_result = ToolErrorResult(
            error_type=error_type, message=message, is_retryable=is_retryable
        )
        return ToolResult(output=model_dump(error_result))

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
        """
        Decode a tool call's raw ``arguments`` payload into keyword arguments.

        An empty payload or a JSON ``null`` means "no arguments". A payload
        that is already a dict is used as-is.

        Raises:
            ValueError: If the payload is not valid JSON or does not decode
                to a JSON object. The message is meant for the LLM.
        """
        if not raw_arguments:
            return {}
        if isinstance(raw_arguments, dict):
            return raw_arguments
        try:
            decoded = json.loads(raw_arguments)
        except (TypeError, ValueError) as e:
            # JSONDecodeError is a ValueError; TypeError covers non-str payloads.
            raise ValueError(f"参数解析失败: {e}") from e
        if decoded is None:
            return {}
        if not isinstance(decoded, dict):
            raise ValueError(
                f"参数解析失败: 期望 JSON 对象, 实际得到 {type(decoded).__name__}"
            )
        return decoded

    async def _execute_single_tool_safely(
        self, tool_call: Any, available_tools: dict[str, ToolExecutable]
    ) -> tuple[Any, ToolResult]:
        """
        Execute one tool call and turn any ordinary failure into a structured error.

        Argument-parsing problems, unknown tool names and execution failures
        (``Exception`` subclasses) are all returned as a ``ToolResult`` that
        wraps a ``ToolErrorResult``, so the LLM can react to them.

        Returns:
            ``(tool_call, result)``, so results can be paired with their calls.
        """
        tool_name = tool_call.function.name

        try:
            arguments = self._parse_tool_arguments(tool_call.function.arguments)
        except ValueError as e:
            return tool_call, self._error_tool_result(
                ToolErrorType.INVALID_ARGUMENTS, str(e), is_retryable=False
            )

        executable = available_tools.get(tool_name)
        if executable is None:
            # Checked up front so that a CONFIGURATION_ERROR raised *inside* a
            # tool is not misreported as "tool not found". The message is
            # still produced by LLMException to keep its format unchanged.
            not_found = LLMException(
                f"Tool '{tool_name}' not found.",
                code=LLMErrorCode.CONFIGURATION_ERROR,
            )
            return tool_call, self._error_tool_result(
                ToolErrorType.TOOL_NOT_FOUND, str(not_found), is_retryable=False
            )

        last_error: Exception | None = None

        @Retry.simple(
            stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None
        )
        async def execute_with_retry():
            nonlocal last_error
            last_error = None
            try:
                return await executable.execute(**arguments)
            except Exception as exc:
                # With return_on_failure=None, exhausted retries come back as
                # None; remember the cause so it can be reported to the LLM.
                last_error = exc
                raise

        try:
            execution_result = await execute_with_retry()
            if execution_result is None:
                detail = (
                    f"最后一次错误: {type(last_error).__name__}: {last_error}"
                    if last_error is not None
                    else ""
                )
                raise LLMException(f"工具执行在多次重试后仍然失败。{detail}")
            return tool_call, execution_result
        except Exception as e:
            return tool_call, self._error_tool_result(
                ToolErrorType.EXECUTION_ERROR, str(e), _is_exception_retryable(e)
            )

    async def _execute_tools_parallel_safely(
        self,
        tool_calls: list[Any],
        available_tools: dict[str, ToolExecutable],
    ) -> list[LLMMessage]:
        """
        Execute all tool calls concurrently, isolating each call's errors.

        Returns:
            One tool-response message per call, in the order of ``tool_calls``.
        """
        if not tool_calls:
            return []
        # asyncio.gather returns results in the order the awaitables were given.
        results = await asyncio.gather(
            *(
                self._execute_single_tool_safely(call, available_tools)
                for call in tool_calls
            )
        )
        return [
            LLMMessage.tool_response(
                tool_call_id=original_call.id,
                function_name=original_call.function.name,
                result=result.output,
            )
            for original_call, result in results
        ]