zhenxun_bot/zhenxun/services/llm/manager.py
webjoin111 bba90e62db ♻️ refactor(llm): 重构 LLM 服务架构,引入中间件与组件化适配器
- 【重构】LLM 服务核心架构:
    - 引入中间件管道,统一处理请求生命周期(重试、密钥选择、日志、网络请求)。
    - 适配器重构为组件化设计,分离配置映射、消息转换、响应解析和工具序列化逻辑。
    - 移除 `with_smart_retry` 装饰器,其功能由中间件接管。
    - 移除 `LLMToolExecutor`,工具执行逻辑集成到 `ToolInvoker`。
- 【功能】增强配置系统:
    - `LLMGenerationConfig` 采用组件化结构(Core, Reasoning, Visual, Output, Safety, ToolConfig)。
    - 新增 `GenConfigBuilder` 提供语义化配置构建方式。
    - 新增 `LLMEmbeddingConfig` 用于嵌入专用配置。
    - `CommonOverrides` 迁移并更新至新配置结构。
- 【功能】强化工具系统:
    - 引入 `ToolInvoker` 实现更灵活的工具执行,支持回调与结构化错误。
    - `function_tool` 装饰器支持动态 Pydantic 模型创建和依赖注入 (`ToolParam`, `RunContext`)。
    - 平台原生工具支持 (`GeminiCodeExecution`, `GeminiGoogleSearch`, `GeminiUrlContext`)。
- 【功能】高级生成与嵌入:
    - `generate_structured` 方法支持 In-Context Validation and Repair (IVR) 循环和 AutoCoT (思维链) 包装。
    - 新增 `embed_query` 和 `embed_documents` 便捷嵌入 API。
    - `OpenAIImageAdapter` 支持 OpenAI 兼容的图像生成。
    - `SmartAdapter` 实现模型名称智能路由。
- 【重构】消息与类型系统:
    - `LLMContentPart` 扩展支持更多模态和代码执行相关内容。
    - `LLMMessage` 和 `LLMResponse` 结构更新,支持 `content_parts` 和思维链签名。
    - 统一 `LLMErrorCode` 和用户友好错误消息,提供更详细的网络/代理错误提示。
    - `pyproject.toml` 移除 `bilireq`,新增 `json_repair`。
- 【优化】日志与调试:
    - 引入 `DebugLogOptions`,提供细粒度日志脱敏控制。
    - 增强日志净化器,处理更多敏感数据和长字符串。
- 【清理】删除废弃模块:
    - `zhenxun/services/llm/memory.py`
    - `zhenxun/services/llm/executor.py`
    - `zhenxun/services/llm/config/presets.py`
    - `zhenxun/services/llm/types/content.py`
    - `zhenxun/services/llm/types/enums.py`
    - `zhenxun/services/llm/tools/__init__.py`
    - `zhenxun/services/llm/tools/manager.py`
2025-12-07 18:57:55 +08:00

474 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 模型管理器
负责模型实例的创建、缓存、配置管理和生命周期管理。
"""
import hashlib
import time
from typing import Any
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import dump_json_safely
from .config import validate_override_params
from .config.generation import LLMGenerationConfig
from .config.providers import (
AI_CONFIG_GROUP,
PROVIDERS_CONFIG_KEY,
get_ai_config,
get_llm_config,
)
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
from .types.capabilities import get_model_capabilities
DEFAULT_MODEL_NAME_KEY = "default_model_name"
_model_cache: dict[str, tuple[LLMModel, float]] = {}
_cache_ttl = 3600
_max_cache_size = 10
def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:
"""解析 'ProviderName/ModelName' 格式的字符串"""
if not name_str or "/" not in name_str:
return None, None
parts = name_str.split("/", 1)
if len(parts) == 2 and parts[0].strip() and parts[1].strip():
return parts[0].strip(), parts[1].strip()
return None, None
def _make_cache_key(
provider_model_name: str | None,
override_config: dict | LLMGenerationConfig | None,
) -> str:
"""生成缓存键"""
config_str = (
dump_json_safely(override_config, sort_keys=True) if override_config else "None"
)
key_data = f"{provider_model_name}:{config_str}"
return hashlib.md5(key_data.encode()).hexdigest()
def _get_cached_model(cache_key: str) -> LLMModel | None:
"""从缓存获取模型"""
if cache_key in _model_cache:
model, created_time = _model_cache[cache_key]
current_time = time.time()
if current_time - created_time > _cache_ttl:
del _model_cache[cache_key]
logger.debug(f"模型缓存已过期: {cache_key}")
return None
if model._is_closed:
logger.debug(
f"缓存的模型 {cache_key} ({model.provider_name}/{model.model_name}) "
f"处于_is_closed=True状态重置为False以供复用。"
)
model._is_closed = False
logger.debug(
f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
)
return model
return None
def _cache_model(cache_key: str, model: LLMModel):
"""缓存模型实例"""
current_time = time.time()
if len(_model_cache) >= _max_cache_size:
oldest_key = min(_model_cache.keys(), key=lambda k: _model_cache[k][1])
del _model_cache[oldest_key]
_model_cache[cache_key] = (model, current_time)
def clear_model_cache():
"""
清空模型缓存,释放所有缓存的模型实例。
用于在内存不足或需要强制重新加载模型配置时清理缓存。
"""
global _model_cache
_model_cache.clear()
logger.info("已清空模型缓存")
def get_cache_stats() -> dict[str, Any]:
"""
获取模型缓存的统计信息。
返回:
dict[str, Any]: 包含缓存大小、最大容量、TTL和已缓存模型列表的统计信息。
"""
return {
"cache_size": len(_model_cache),
"max_cache_size": _max_cache_size,
"cache_ttl": _cache_ttl,
"cached_models": list(_model_cache.keys()),
}
def get_default_api_base_for_type(api_type: str) -> str | None:
"""根据API类型获取默认的API基础地址"""
default_api_bases = {
"openai": "https://api.openai.com",
"deepseek": "https://api.deepseek.com/beta",
"zhipu": "https://open.bigmodel.cn",
"gemini": "https://generativelanguage.googleapis.com",
"openrouter": "https://openrouter.ai/api",
"smart": None,
"openai_responses": None,
}
return default_api_bases.get(api_type)
def get_configured_providers() -> list[ProviderConfig]:
"""从配置中获取Provider列表 - 简化和修正版本"""
ai_config = get_ai_config()
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers, list):
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
for i, item in enumerate(providers):
if isinstance(item, ProviderConfig):
if not item.api_base:
default_api_base = get_default_api_base_for_type(item.api_type)
if default_api_base:
item.api_base = default_api_base
valid_providers.append(item)
else:
logger.warning(
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,已跳过。"
f"实际类型: {type(item)}"
)
return valid_providers
def find_model_config(
provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""
在配置中查找指定的 Provider 和 ModelDetail
参数:
provider_name: 提供商名称
model_name: 模型名称
返回:
tuple[ProviderConfig, ModelDetail] | None: 找到的配置元组,未找到则返回 None
"""
providers = get_configured_providers()
for provider in providers:
if provider.name.lower() == provider_name.lower():
for model_detail in provider.models:
if model_detail.model_name.lower() == model_name.lower():
return provider, model_detail
return None
def list_available_models() -> list[dict[str, Any]]:
"""
列出所有配置的可用模型及其详细信息。
返回:
list[dict[str, Any]]: 模型信息列表,每个字典包含提供商名称、模型名称、
能力信息、是否为嵌入模型等详细信息。
"""
providers = get_configured_providers()
model_list = []
for provider in providers:
for model_detail in provider.models:
model_info = {
"provider_name": provider.name,
"model_name": model_detail.model_name,
"full_name": f"{provider.name}/{model_detail.model_name}",
"api_type": provider.api_type or "auto-detect",
"api_base": provider.api_base,
"is_available": model_detail.is_available,
"is_embedding_model": model_detail.is_embedding_model,
"available_identifiers": _get_model_identifiers(
provider.name, model_detail
),
}
model_list.append(model_info)
return model_list
def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]:
"""获取模型的所有可用标识符"""
return [f"{provider_name}/{model_detail.model_name}"]
def list_model_identifiers() -> dict[str, list[str]]:
"""
列出所有模型的可用标识符
返回:
dict[str, list[str]]: 字典,键为模型的完整名称,值为该模型的所有可用标识符列表
"""
providers = get_configured_providers()
result = {}
for provider in providers:
for model_detail in provider.models:
full_name = f"{provider.name}/{model_detail.model_name}"
identifiers = _get_model_identifiers(provider.name, model_detail)
result[full_name] = identifiers
return result
def list_embedding_models() -> list[dict[str, Any]]:
"""
列出所有配置的嵌入模型。
返回:
list[dict[str, Any]]: 嵌入模型信息列表,从所有可用模型中筛选出
支持嵌入功能的模型。
"""
all_models = list_available_models()
return [model for model in all_models if model.get("is_embedding_model", False)]
async def get_model_instance(
provider_model_name: str | None = None,
override_config: dict[str, Any] | LLMGenerationConfig | None = None,
) -> LLMModel:
"""
根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)
参数:
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'
override_config: 覆盖配置字典。
返回:
LLMModel: 模型实例。
"""
cache_key = _make_cache_key(provider_model_name, override_config)
cached_model = _get_cached_model(cache_key)
if cached_model:
if override_config:
validated_override = validate_override_params(override_config)
if cached_model._generation_config != validated_override:
cached_model._generation_config = validated_override
logger.debug(
f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
f"{validated_override.to_dict()}"
)
return cached_model
resolved_model_name_str = provider_model_name
if resolved_model_name_str is None:
resolved_model_name_str = get_global_default_model_name()
if resolved_model_name_str is None:
available_models_list = list_available_models()
if not available_models_list:
raise LLMException(
"未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
)
resolved_model_name_str = available_models_list[0]["full_name"]
logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")
prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
if not prov_name_str or not mod_name_str:
raise LLMException(
f"无效的模型名称格式: '{resolved_model_name_str}'",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
config_tuple_found = find_model_config(prov_name_str, mod_name_str)
if not config_tuple_found:
all_models = list_available_models()
raise LLMException(
f"未找到模型: '{resolved_model_name_str}'. "
f"可用: {[m['full_name'] for m in all_models]}",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
provider_config_found, model_detail_found = config_tuple_found
capabilities = get_model_capabilities(model_detail_found.model_name)
model_detail_found.is_embedding_model = capabilities.is_embedding_model
llm_config = get_llm_config()
client_settings = llm_config.client_settings
default_timeout = (
provider_config_found.timeout
if provider_config_found.timeout is not None
else client_settings.timeout
)
config_for_http_client = ProviderConfig(
name=provider_config_found.name,
api_key=provider_config_found.api_key,
models=provider_config_found.models,
timeout=default_timeout,
proxy=client_settings.proxy,
api_base=provider_config_found.api_base,
api_type=provider_config_found.api_type,
openai_compat=provider_config_found.openai_compat,
temperature=provider_config_found.temperature,
max_tokens=provider_config_found.max_tokens,
)
shared_http_client = await http_client_manager.get_client(config_for_http_client)
try:
model_instance = LLMModel(
provider_config=config_for_http_client,
model_detail=model_detail_found,
key_store=key_store,
http_client=shared_http_client,
capabilities=capabilities,
)
if override_config:
validated_override_params = validate_override_params(override_config)
model_instance._generation_config = validated_override_params
logger.debug(
f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
f"{validated_override_params.to_dict()}"
)
_cache_model(cache_key, model_instance)
logger.debug(
f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
)
return model_instance
except LLMException:
raise
except Exception as e:
logger.error(
f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
)
raise LLMException(
f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
code=LLMErrorCode.MODEL_INIT_FAILED,
cause=e,
)
def get_global_default_model_name() -> str | None:
"""获取全局默认模型名称"""
ai_config = get_ai_config()
return ai_config.get(DEFAULT_MODEL_NAME_KEY)
def set_global_default_model_name(provider_model_name: str | None) -> bool:
"""
设置全局默认模型名称
参数:
provider_model_name: 模型名称,格式为 'ProviderName/ModelName'
返回:
bool: 设置是否成功。
"""
if provider_model_name:
prov_name, mod_name = parse_provider_model_string(provider_model_name)
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
logger.error(
f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
)
return False
Config.set_config(
AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
)
if provider_model_name:
logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
else:
logger.info("LLM 服务全局默认模型已清除。")
return True
async def get_key_usage_stats() -> dict[str, Any]:
"""
获取所有Provider的Key使用统计
返回:
dict[str, Any]: 包含所有Provider的Key使用统计信息。
"""
providers = get_configured_providers()
stats = {}
for provider in providers:
provider_stats = await key_store.get_key_stats(
[provider.api_key]
if isinstance(provider.api_key, str)
else provider.api_key
)
stats[provider.name] = {
"total_keys": len(
[provider.api_key]
if isinstance(provider.api_key, str)
else provider.api_key
),
"key_stats": provider_stats,
}
return stats
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
"""
重置指定Provider的Key状态
参数:
provider_name: 提供商名称。
api_key: 要重置的特定API密钥如果为None则重置所有密钥。
返回:
bool: 重置是否成功。
"""
providers = get_configured_providers()
target_provider = None
for provider in providers:
if provider.name.lower() == provider_name.lower():
target_provider = provider
break
if not target_provider:
logger.error(f"未找到Provider: {provider_name}")
return False
provider_keys = (
[target_provider.api_key]
if isinstance(target_provider.api_key, str)
else target_provider.api_key
)
if api_key:
if api_key in provider_keys:
await key_store.reset_key_status(api_key)
logger.info(f"已重置Provider '{provider_name}' 的指定Key状态")
return True
else:
logger.error(f"指定的Key不属于Provider '{provider_name}'")
return False
else:
for key in provider_keys:
await key_store.reset_key_status(key)
logger.info(f"已重置Provider '{provider_name}' 的所有Key状态")
return True