zhenxun_bot/zhenxun/services/llm/core.py

379 lines
13 KiB
Python
Raw Normal View History

2025-07-01 16:56:34 +08:00
"""
LLM 核心基础设施模块
包含执行 LLM 请求所需的底层组件 HTTP 客户端API Key 存储和智能重试逻辑
"""
import asyncio
from typing import Any
import httpx
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.user_agent import get_user_agent
from .types import ProviderConfig
from .types.exceptions import LLMErrorCode, LLMException
class HttpClientConfig(BaseModel):
    """Settings consumed by ``LLMHttpClient`` when it builds its httpx client."""

    # Value handed to ``httpx.Timeout`` for the client.
    timeout: int = 180

    # Forwarded to ``httpx.Limits(max_connections=...)``.
    max_connections: int = 100

    # Forwarded to ``httpx.Limits(max_keepalive_connections=...)``.
    max_keepalive_connections: int = 20

    # Proxy setting passed to the httpx client; ``None`` means no proxy value.
    proxy: str | None = None
class LLMHttpClient:
    """HTTP client dedicated to LLM service requests.

    Wraps an ``httpx.AsyncClient`` that is created on first use (and re-created
    after it has been closed) and counts the POST requests currently in flight.
    """

    def __init__(self, config: HttpClientConfig | None = None):
        """Store the configuration; no network client is created yet.

        Args:
            config: Client settings. A default ``HttpClientConfig`` is used
                when omitted.
        """
        self.config = config or HttpClientConfig()
        self._client: httpx.AsyncClient | None = None
        self._active_requests = 0
        self._lock = asyncio.Lock()

    @staticmethod
    def _proxy_argument_name() -> str:
        """Return the proxy keyword supported by the installed httpx.

        httpx 0.26 introduced ``proxy`` and deprecated ``proxies``; httpx 0.28
        removed ``proxies``. Older releases only know ``proxies``.
        """
        import inspect

        parameters = inspect.signature(httpx.AsyncClient.__init__).parameters
        return "proxy" if "proxy" in parameters else "proxies"

    async def _ensure_client_initialized(self) -> httpx.AsyncClient:
        """Return the underlying client, creating it if missing or closed.

        Returns:
            The usable ``httpx.AsyncClient``.

        Raises:
            LLMException: With ``CONFIGURATION_ERROR`` if no client exists
                after the initialization step.
        """
        if self._client is None or self._client.is_closed:
            async with self._lock:
                # Check again under the lock: another task may have created
                # the client while this one was waiting.
                if self._client is None or self._client.is_closed:
                    logger.debug(
                        f"LLMHttpClient: Initializing new httpx.AsyncClient "
                        f"with config: {self.config}"
                    )
                    client_kwargs: dict[str, Any] = {
                        "headers": get_user_agent(),
                        "limits": httpx.Limits(
                            max_connections=self.config.max_connections,
                            max_keepalive_connections=(
                                self.config.max_keepalive_connections
                            ),
                        ),
                        "timeout": httpx.Timeout(self.config.timeout),
                        "follow_redirects": True,
                    }
                    # Use whichever proxy keyword this httpx version accepts so
                    # the client works both before and after `proxies` removal.
                    client_kwargs[self._proxy_argument_name()] = self.config.proxy
                    self._client = httpx.AsyncClient(**client_kwargs)
        if self._client is None:
            raise LLMException(
                "HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
            )
        return self._client

    async def post(self, url: str, **kwargs: Any) -> httpx.Response:
        """Send a POST request and track it in the active-request counter.

        Args:
            url: Target URL.
            **kwargs: Forwarded unchanged to ``httpx.AsyncClient.post``.

        Returns:
            The ``httpx.Response`` produced by the request.
        """
        client = await self._ensure_client_initialized()
        async with self._lock:
            self._active_requests += 1
        try:
            return await client.post(url, **kwargs)
        finally:
            # Always decrement, whether the request succeeded or raised.
            async with self._lock:
                self._active_requests -= 1

    async def close(self):
        """Close the underlying client if it is open and drop the reference.

        A warning is logged when requests are still counted as active.
        """
        async with self._lock:
            if self._client and not self._client.is_closed:
                logger.debug(
                    f"LLMHttpClient: Closing with config: {self.config}. "
                    f"Active requests: {self._active_requests}"
                )
                if self._active_requests > 0:
                    logger.warning(
                        f"LLMHttpClient: Closing while {self._active_requests} "
                        f"requests are still active."
                    )
                await self._client.aclose()
            self._client = None
        logger.debug(f"LLMHttpClient for config {self.config} definitively closed.")

    @property
    def is_closed(self) -> bool:
        """True when no client exists yet or the existing client is closed."""
        return self._client is None or self._client.is_closed
class LLMHttpClientManager:
    """Factory and pool for ``LLMHttpClient`` instances.

    Clients are pooled by the ``(timeout, proxy)`` pair of the provider
    configuration.
    """

    def __init__(self):
        """Create an empty pool guarded by an asyncio lock."""
        self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
        self._lock = asyncio.Lock()

    def _get_client_key(
        self, provider_config: ProviderConfig
    ) -> tuple[int, str | None]:
        """Return the pool key ``(timeout, proxy)`` for a provider config."""
        return (provider_config.timeout, provider_config.proxy)

    async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
        """Return the pooled client for this provider config, creating it once.

        A pooled ``LLMHttpClient`` is always reused. It reports
        ``is_closed == True`` both before its first request and after
        ``close()``, and in either state it re-creates its httpx client on the
        next request. Replacing such an entry would orphan an instance that a
        caller may still be holding, and ``shutdown`` would never close it.

        Args:
            provider_config: Provider settings supplying ``timeout`` and
                ``proxy``.

        Returns:
            The ``LLMHttpClient`` associated with the config's pool key.
        """
        key = self._get_client_key(provider_config)
        async with self._lock:
            client = self._clients.get(key)
            if client is not None:
                logger.debug(
                    f"LLMHttpClientManager: Reusing existing LLMHttpClient "
                    f"for key: {key}"
                )
                return client

            logger.debug(
                f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
            )
            http_client_config = HttpClientConfig(
                timeout=provider_config.timeout, proxy=provider_config.proxy
            )
            new_client = LLMHttpClient(config=http_client_config)
            self._clients[key] = new_client
            return new_client

    async def shutdown(self):
        """Close every open pooled client and empty the pool.

        Closing is best-effort: a failure to close one client is logged and
        does not prevent the remaining clients from being closed.
        """
        async with self._lock:
            logger.info(
                f"LLMHttpClientManager: Shutting down. "
                f"Closing {len(self._clients)} client(s)."
            )
            close_tasks = [
                client.close()
                for client in self._clients.values()
                if client and not client.is_closed
            ]
            if close_tasks:
                results = await asyncio.gather(*close_tasks, return_exceptions=True)
                # Surface close failures instead of discarding them silently.
                for result in results:
                    if isinstance(result, BaseException):
                        logger.error(
                            f"LLMHttpClientManager: Failed to close a client: "
                            f"{result!r}"
                        )
            self._clients.clear()
            logger.info("LLMHttpClientManager: Shutdown complete.")
# Module-level LLMHttpClientManager instance, created when the module is imported.
http_client_manager = LLMHttpClientManager()
async def create_llm_http_client(
    timeout: int = 180,
    proxy: str | None = None,
) -> LLMHttpClient:
    """Build a new ``LLMHttpClient`` from a timeout and an optional proxy.

    Args:
        timeout: Timeout value stored in the client's ``HttpClientConfig``.
        proxy: Proxy value stored in the client's ``HttpClientConfig``.

    Returns:
        A freshly constructed ``LLMHttpClient``.
    """
    return LLMHttpClient(HttpClientConfig(timeout=timeout, proxy=proxy))
class RetryConfig:
    """Plain container for retry settings.

    Attributes:
        max_retries: Number of retries; defaults to 3.
        retry_delay: Base delay between attempts; defaults to 1.0.
        exponential_backoff: Flag for growing the delay per attempt.
        key_rotation: Flag for enabling API key rotation.
    """

    def __init__(
        self,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        exponential_backoff: bool = True,
        key_rotation: bool = True,
    ):
        # Retry budget and base delay.
        self.max_retries, self.retry_delay = max_retries, retry_delay
        # Behaviour switches.
        self.exponential_backoff, self.key_rotation = (
            exponential_backoff,
            key_rotation,
        )
async def with_smart_retry(
func,
*args,
retry_config: RetryConfig | None = None,
key_store: "KeyStatusStore | None" = None,
provider_name: str | None = None,
**kwargs: Any,
) -> Any:
"""智能重试装饰器 - 支持Key轮询和错误分类"""
config = retry_config or RetryConfig()
last_exception: Exception | None = None
failed_keys: set[str] = set()
for attempt in range(config.max_retries + 1):
try:
if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
kwargs["failed_keys"] = failed_keys
return await func(*args, **kwargs)
except LLMException as e:
last_exception = e
if e.code in [
LLMErrorCode.API_KEY_INVALID,
LLMErrorCode.API_QUOTA_EXCEEDED,
]:
if hasattr(e, "details") and e.details and "api_key" in e.details:
failed_keys.add(e.details["api_key"])
if key_store and provider_name:
await key_store.record_failure(
e.details["api_key"], e.details.get("status_code")
)
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
if not should_retry:
logger.error(f"不可重试的错误,停止重试: {e}")
raise
if attempt < config.max_retries:
wait_time = config.retry_delay
if config.exponential_backoff:
wait_time *= 2**attempt
logger.warning(
f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
)
await asyncio.sleep(wait_time)
else:
logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
except Exception as e:
last_exception = e
logger.error(f"非LLM异常停止重试: {e}")
raise LLMException(
f"操作失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
if last_exception:
raise last_exception
else:
raise RuntimeError("重试函数未能正常执行且未捕获到异常")
def _should_retry_llm_error(
    error: LLMException, attempt: int, max_retries: int
) -> bool:
    """Decide whether an LLM error is worth another attempt.

    Args:
        error: The exception raised by the failed attempt.
        attempt: Zero-based index of the attempt that just failed.
        max_retries: Configured retry budget.

    Returns:
        False for non-retryable or unrecognized error codes. For quota errors,
        True only while ``attempt < min(2, max_retries)``; for content-filter
        errors, only while ``attempt < min(1, max_retries)``. True for every
        other retryable code.
    """
    code = error.code

    # Errors that another attempt cannot fix.
    if code in {
        LLMErrorCode.MODEL_NOT_FOUND,
        LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
        LLMErrorCode.USER_LOCATION_NOT_SUPPORTED,
        LLMErrorCode.CONFIGURATION_ERROR,
    }:
        return False

    # Retryable, but with a tighter cap than the general budget.
    if code == LLMErrorCode.API_QUOTA_EXCEEDED:
        return attempt < min(2, max_retries)
    if code == LLMErrorCode.CONTENT_FILTERED:
        return attempt < min(1, max_retries)

    # Remaining retryable codes; anything else is treated as non-retryable.
    return code in {
        LLMErrorCode.API_REQUEST_FAILED,
        LLMErrorCode.API_TIMEOUT,
        LLMErrorCode.API_RATE_LIMITED,
        LLMErrorCode.API_RESPONSE_INVALID,
        LLMErrorCode.RESPONSE_PARSE_ERROR,
        LLMErrorCode.GENERATION_FAILED,
        LLMErrorCode.API_KEY_INVALID,
    }
class KeyStatusStore:
"""API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""
def __init__(self):
self._key_status: dict[str, bool] = {}
self._key_usage_count: dict[str, int] = {}
self._key_last_used: dict[str, float] = {}
self._provider_key_index: dict[str, int] = {}
self._lock = asyncio.Lock()
async def get_next_available_key(
self,
provider_name: str,
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""获取下一个可用的API密钥轮询策略"""
if not api_keys:
return None
exclude_keys = exclude_keys or set()
available_keys = [
key
for key in api_keys
if key not in exclude_keys and self._key_status.get(key, True)
]
if not available_keys:
return api_keys[0] if api_keys else None
async with self._lock:
current_index = self._provider_key_index.get(provider_name, 0)
selected_key = available_keys[current_index % len(available_keys)]
self._provider_key_index[provider_name] = (current_index + 1) % len(
available_keys
)
import time
self._key_usage_count[selected_key] = (
self._key_usage_count.get(selected_key, 0) + 1
)
self._key_last_used[selected_key] = time.time()
logger.debug(
f"轮询选择API密钥: {self._get_key_id(selected_key)} "
f"(使用次数: {self._key_usage_count[selected_key]})"
)
return selected_key
async def record_success(self, api_key: str):
"""记录成功使用"""
async with self._lock:
self._key_status[api_key] = True
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
async def record_failure(self, api_key: str, status_code: int | None):
"""记录失败使用"""
key_id = self._get_key_id(api_key)
async with self._lock:
if status_code in [401, 403]:
self._key_status[api_key] = False
logger.warning(
f"API密钥认证失败标记为不可用: {key_id} (状态码: {status_code})"
)
else:
logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
async def reset_key_status(self, api_key: str):
"""重置密钥状态(用于恢复机制)"""
async with self._lock:
self._key_status[api_key] = True
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""获取密钥使用统计"""
stats = {}
async with self._lock:
for key in api_keys:
key_id = self._get_key_id(key)
stats[key_id] = {
"available": self._key_status.get(key, True),
"usage_count": self._key_usage_count.get(key, 0),
"last_used": self._key_last_used.get(key, 0),
}
return stats
def _get_key_id(self, api_key: str) -> str:
"""获取API密钥的标识符用于日志"""
if len(api_key) <= 8:
return api_key
return f"{api_key[:4]}...{api_key[-4:]}"
# Module-level KeyStatusStore instance, created when the module is imported.
key_store = KeyStatusStore()