mirror of https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
- 【Refactor】LLM service core architecture:
  - Introduced a middleware pipeline that unifies the request lifecycle (retries, key selection, logging, network requests).
  - Reworked adapters into a component-based design, separating config mapping, message conversion, response parsing, and tool serialization.
  - Removed the `with_smart_retry` decorator; its duties are taken over by the middleware.
  - Removed `LLMToolExecutor`; tool-execution logic is folded into `ToolInvoker`.
- 【Feature】Enhanced configuration system:
  - `LLMGenerationConfig` now uses a component-based structure (Core, Reasoning, Visual, Output, Safety, ToolConfig).
  - Added `GenConfigBuilder` for semantic configuration building.
  - Added `LLMEmbeddingConfig` for embedding-specific configuration.
  - Migrated `CommonOverrides` to the new configuration structure.
- 【Feature】Strengthened tool system:
  - Introduced `ToolInvoker` for more flexible tool execution, with callbacks and structured errors.
  - The `function_tool` decorator supports dynamic Pydantic model creation and dependency injection (`ToolParam`, `RunContext`).
  - Added platform-native tool support (`GeminiCodeExecution`, `GeminiGoogleSearch`, `GeminiUrlContext`).
- 【Feature】Advanced generation and embedding:
  - `generate_structured` supports an In-Context Validation and Repair (IVR) loop and AutoCoT (chain-of-thought) wrapping.
  - Added the `embed_query` and `embed_documents` convenience embedding APIs (see the sketch after this changelog).
  - `OpenAIImageAdapter` supports OpenAI-compatible image generation.
  - `SmartAdapter` routes requests intelligently by model name.
- 【Refactor】Message and type system:
  - `LLMContentPart` extended to support more modalities and code-execution-related content.
  - `LLMMessage` and `LLMResponse` structures updated to support `content_parts` and chain-of-thought signatures.
  - Unified `LLMErrorCode` and user-friendly error messages, with more detailed network/proxy error hints.
  - `pyproject.toml` drops `bilireq` and adds `json_repair`.
- 【Optimization】Logging and debugging:
  - Introduced `DebugLogOptions` for fine-grained control over log redaction.
  - Enhanced the log sanitizer to handle more kinds of sensitive data and long strings.
- 【Cleanup】Removed deprecated modules:
  - `zhenxun/services/llm/memory.py`
  - `zhenxun/services/llm/executor.py`
  - `zhenxun/services/llm/config/presets.py`
  - `zhenxun/services/llm/types/content.py`
  - `zhenxun/services/llm/types/enums.py`
  - `zhenxun/services/llm/tools/__init__.py`
  - `zhenxun/services/llm/tools/manager.py`
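A minimal usage sketch of the entry points named above. The import path and the `embed_query` signature are assumptions inferred from the changelog, not the confirmed public API; `get_model_instance` itself is defined in the file below:

```python
# Hypothetical sketch; the import path and embed_query signature are assumptions.
from zhenxun.services.llm import embed_query, get_model_instance


async def demo() -> None:
    # Resolve a configured "ProviderName/ModelName" model with an override
    # (get_model_instance is shown in manager.py below).
    model = await get_model_instance(
        "Gemini/gemini-2.0-flash", override_config={"temperature": 0.2}
    )
    # Convenience embedding API added in this change (keyword name assumed).
    vector = await embed_query("hello", model="Gemini/text-embedding-004")
```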
474 lines
16 KiB
Python
"""
|
||
LLM 模型管理器
|
||
|
||
负责模型实例的创建、缓存、配置管理和生命周期管理。
|
||
"""
|
||
|
||
import hashlib
|
||
import time
|
||
from typing import Any
|
||
|
||
from zhenxun.configs.config import Config
|
||
from zhenxun.services.log import logger
|
||
from zhenxun.utils.pydantic_compat import dump_json_safely
|
||
|
||
from .config import validate_override_params
|
||
from .config.generation import LLMGenerationConfig
|
||
from .config.providers import (
|
||
AI_CONFIG_GROUP,
|
||
PROVIDERS_CONFIG_KEY,
|
||
get_ai_config,
|
||
get_llm_config,
|
||
)
|
||
from .core import http_client_manager, key_store
|
||
from .service import LLMModel
|
||
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
|
||
from .types.capabilities import get_model_capabilities
|
||
|
||
DEFAULT_MODEL_NAME_KEY = "default_model_name"
|
||
|
||
_model_cache: dict[str, tuple[LLMModel, float]] = {}
|
||
_cache_ttl = 3600
|
||
_max_cache_size = 10
|
||
|
||
|
||
def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:
|
||
"""解析 'ProviderName/ModelName' 格式的字符串"""
|
||
if not name_str or "/" not in name_str:
|
||
return None, None
|
||
parts = name_str.split("/", 1)
|
||
if len(parts) == 2 and parts[0].strip() and parts[1].strip():
|
||
return parts[0].strip(), parts[1].strip()
|
||
return None, None
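
# Usage note for the parser above (behavior taken directly from the code):
# only the first "/" splits, so a model name may itself contain "/".
#
#     parse_provider_model_string("Gemini/gemini-2.0-flash")
#     # -> ("Gemini", "gemini-2.0-flash")
#     parse_provider_model_string("gemini-2.0-flash")
#     # -> (None, None)  (no separator)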


def _make_cache_key(
    provider_model_name: str | None,
    override_config: dict | LLMGenerationConfig | None,
) -> str:
    """Build the cache key."""
    config_str = (
        dump_json_safely(override_config, sort_keys=True) if override_config else "None"
    )
    key_data = f"{provider_model_name}:{config_str}"
    return hashlib.md5(key_data.encode()).hexdigest()
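
# The key is an MD5 over "<name>:<sorted-config-json>", so the same model
# requested with a different override config gets its own cache slot:
#
#     _make_cache_key("Gemini/gemini-2.0-flash", {"temperature": 0.2})
#     # -> a stable 32-char hex digest for exactly this (name, config) pair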


def _get_cached_model(cache_key: str) -> LLMModel | None:
    """Fetch a model from the cache."""
    if cache_key in _model_cache:
        model, created_time = _model_cache[cache_key]
        current_time = time.time()

        if current_time - created_time > _cache_ttl:
            del _model_cache[cache_key]
            logger.debug(f"Model cache entry expired: {cache_key}")
            return None

        if model._is_closed:
            logger.debug(
                f"Cached model {cache_key} ({model.provider_name}/{model.model_name}) "
                f"has _is_closed=True; resetting to False for reuse."
            )
            model._is_closed = False

        logger.debug(
            f"Using cached model: {cache_key} -> "
            f"{model.provider_name}/{model.model_name}"
        )
        return model
    return None


def _cache_model(cache_key: str, model: LLMModel):
    """Cache a model instance, evicting the oldest entry when full."""
    current_time = time.time()

    if len(_model_cache) >= _max_cache_size:
        oldest_key = min(_model_cache.keys(), key=lambda k: _model_cache[k][1])
        del _model_cache[oldest_key]

    _model_cache[cache_key] = (model, current_time)


def clear_model_cache():
    """
    Clear the model cache, releasing all cached model instances.

    Use this to free memory or to force model configuration to be reloaded.
    """
    global _model_cache
    _model_cache.clear()
    logger.info("Model cache cleared")


def get_cache_stats() -> dict[str, Any]:
    """
    Get statistics about the model cache.

    Returns:
        dict[str, Any]: Cache size, maximum capacity, TTL, and the list of
            cached model keys.
    """
    return {
        "cache_size": len(_model_cache),
        "max_cache_size": _max_cache_size,
        "cache_ttl": _cache_ttl,
        "cached_models": list(_model_cache.keys()),
    }
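
# Shape of the stats dict above, e.g. with one cached model (values
# illustrative):
#
#     {
#         "cache_size": 1,
#         "max_cache_size": 10,
#         "cache_ttl": 3600,
#         "cached_models": ["<md5 cache key>"],
#     }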


def get_default_api_base_for_type(api_type: str) -> str | None:
    """Return the default API base URL for an API type."""
    default_api_bases = {
        "openai": "https://api.openai.com",
        "deepseek": "https://api.deepseek.com/beta",
        "zhipu": "https://open.bigmodel.cn",
        "gemini": "https://generativelanguage.googleapis.com",
        "openrouter": "https://openrouter.ai/api",
        "smart": None,
        "openai_responses": None,
    }

    return default_api_bases.get(api_type)


def get_configured_providers() -> list[ProviderConfig]:
    """Read the provider list from the configuration (simplified and corrected)."""
    ai_config = get_ai_config()
    providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])

    if not isinstance(providers, list):
        logger.error(
            f"Config item {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} is not a list; "
            f"falling back to an empty list."
        )
        return []

    valid_providers = []
    for i, item in enumerate(providers):
        if isinstance(item, ProviderConfig):
            if not item.api_base:
                default_api_base = get_default_api_base_for_type(item.api_type)
                if default_api_base:
                    item.api_base = default_api_base
            valid_providers.append(item)
        else:
            logger.warning(
                f"Entry {i + 1} in the config file could not be parsed as a "
                f"ProviderConfig object and was skipped. Actual type: {type(item)}"
            )

    return valid_providers
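
# A sketch of what one parsed entry in the providers list carries. The field
# names are taken from the ProviderConfig/ModelDetail usage elsewhere in this
# file; the concrete values (and any omitted fields) are illustrative only:
#
#     ProviderConfig(
#         name="Gemini",
#         api_type="gemini",
#         api_key="AIza...",               # a single key or a list of keys
#         api_base=None,                   # filled in from the defaults above
#         models=[ModelDetail(model_name="gemini-2.0-flash")],
#     )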


def find_model_config(
    provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
    """
    Look up the given provider and model detail in the configuration.

    Args:
        provider_name: Provider name.
        model_name: Model name.

    Returns:
        tuple[ProviderConfig, ModelDetail] | None: The matching configuration
            tuple, or None if not found.
    """
    providers = get_configured_providers()

    for provider in providers:
        if provider.name.lower() == provider_name.lower():
            for model_detail in provider.models:
                if model_detail.model_name.lower() == model_name.lower():
                    return provider, model_detail

    return None
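
# Both comparisons above are case-insensitive, so for example
# find_model_config("gemini", "GEMINI-2.0-FLASH") matches a provider
# configured as "Gemini" with a model named "gemini-2.0-flash".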


def list_available_models() -> list[dict[str, Any]]:
    """
    List all configured models with their details.

    Returns:
        list[dict[str, Any]]: One dict per model, including provider name,
            model name, capability info, whether it is an embedding model,
            and other details.
    """
    providers = get_configured_providers()
    model_list = []
    for provider in providers:
        for model_detail in provider.models:
            model_info = {
                "provider_name": provider.name,
                "model_name": model_detail.model_name,
                "full_name": f"{provider.name}/{model_detail.model_name}",
                "api_type": provider.api_type or "auto-detect",
                "api_base": provider.api_base,
                "is_available": model_detail.is_available,
                "is_embedding_model": model_detail.is_embedding_model,
                "available_identifiers": _get_model_identifiers(
                    provider.name, model_detail
                ),
            }
            model_list.append(model_info)
    return model_list
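
# Each returned entry looks like this (values illustrative):
#
#     {
#         "provider_name": "Gemini",
#         "model_name": "gemini-2.0-flash",
#         "full_name": "Gemini/gemini-2.0-flash",
#         "api_type": "gemini",
#         "api_base": "https://generativelanguage.googleapis.com",
#         "is_available": True,
#         "is_embedding_model": False,
#         "available_identifiers": ["Gemini/gemini-2.0-flash"],
#     }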


def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]:
    """Return all usable identifiers for a model."""
    return [f"{provider_name}/{model_detail.model_name}"]


def list_model_identifiers() -> dict[str, list[str]]:
    """
    List the usable identifiers of every model.

    Returns:
        dict[str, list[str]]: Mapping from a model's full name to the list of
            all identifiers it can be addressed by.
    """
    providers = get_configured_providers()
    result = {}

    for provider in providers:
        for model_detail in provider.models:
            full_name = f"{provider.name}/{model_detail.model_name}"
            identifiers = _get_model_identifiers(provider.name, model_detail)
            result[full_name] = identifiers

    return result


def list_embedding_models() -> list[dict[str, Any]]:
    """
    List all configured embedding models.

    Returns:
        list[dict[str, Any]]: Info for the models that support embeddings,
            filtered from all available models.
    """
    all_models = list_available_models()
    return [model for model in all_models if model.get("is_embedding_model", False)]


async def get_model_instance(
    provider_model_name: str | None = None,
    override_config: dict[str, Any] | LLMGenerationConfig | None = None,
) -> LLMModel:
    """
    Resolve a 'ProviderName/ModelName' string and instantiate an LLMModel (async).

    Args:
        provider_model_name: Model name in 'ProviderName/ModelName' format.
        override_config: Configuration overrides.

    Returns:
        LLMModel: The model instance.
    """
    cache_key = _make_cache_key(provider_model_name, override_config)
    cached_model = _get_cached_model(cache_key)
    if cached_model:
        if override_config:
            validated_override = validate_override_params(override_config)
            if cached_model._generation_config != validated_override:
                cached_model._generation_config = validated_override
                logger.debug(
                    f"Applying new override config to cached model "
                    f"{provider_model_name}: {validated_override.to_dict()}"
                )
        return cached_model

    resolved_model_name_str = provider_model_name
    if resolved_model_name_str is None:
        resolved_model_name_str = get_global_default_model_name()
    if resolved_model_name_str is None:
        available_models_list = list_available_models()
        if not available_models_list:
            raise LLMException(
                "No AI model is configured", code=LLMErrorCode.CONFIGURATION_ERROR
            )
        resolved_model_name_str = available_models_list[0]["full_name"]
        logger.warning(
            f"No model specified; using the first available model: "
            f"{resolved_model_name_str}"
        )

    prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
    if not prov_name_str or not mod_name_str:
        raise LLMException(
            f"Invalid model name format: '{resolved_model_name_str}'",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    config_tuple_found = find_model_config(prov_name_str, mod_name_str)
    if not config_tuple_found:
        all_models = list_available_models()
        raise LLMException(
            f"Model not found: '{resolved_model_name_str}'. "
            f"Available: {[m['full_name'] for m in all_models]}",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    provider_config_found, model_detail_found = config_tuple_found

    capabilities = get_model_capabilities(model_detail_found.model_name)

    model_detail_found.is_embedding_model = capabilities.is_embedding_model

    llm_config = get_llm_config()
    client_settings = llm_config.client_settings
    default_timeout = (
        provider_config_found.timeout
        if provider_config_found.timeout is not None
        else client_settings.timeout
    )

    config_for_http_client = ProviderConfig(
        name=provider_config_found.name,
        api_key=provider_config_found.api_key,
        models=provider_config_found.models,
        timeout=default_timeout,
        proxy=client_settings.proxy,
        api_base=provider_config_found.api_base,
        api_type=provider_config_found.api_type,
        openai_compat=provider_config_found.openai_compat,
        temperature=provider_config_found.temperature,
        max_tokens=provider_config_found.max_tokens,
    )

    shared_http_client = await http_client_manager.get_client(config_for_http_client)

    try:
        model_instance = LLMModel(
            provider_config=config_for_http_client,
            model_detail=model_detail_found,
            key_store=key_store,
            http_client=shared_http_client,
            capabilities=capabilities,
        )

        if override_config:
            validated_override_params = validate_override_params(override_config)
            model_instance._generation_config = validated_override_params
            logger.debug(
                f"Applying config overrides to new model "
                f"{resolved_model_name_str}: {validated_override_params.to_dict()}"
            )

        _cache_model(cache_key, model_instance)
        logger.debug(
            f"Created and cached a new model: {cache_key} -> "
            f"{prov_name_str}/{mod_name_str}"
        )
        return model_instance
    except LLMException:
        raise
    except Exception as e:
        logger.error(
            f"Internal error while instantiating LLMModel "
            f"({resolved_model_name_str}): {e!s}",
            e=e,
        )
        raise LLMException(
            f"Failed to initialize model '{resolved_model_name_str}': {e!s}",
            code=LLMErrorCode.MODEL_INIT_FAILED,
            cause=e,
        )
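
# Typical call sites, grounded in the signature above; the override key is
# illustrative and must be a field accepted by validate_override_params:
#
#     model = await get_model_instance("Gemini/gemini-2.0-flash")
#     model = await get_model_instance(
#         "Gemini/gemini-2.0-flash", override_config={"temperature": 0.2}
#     )
#     model = await get_model_instance()  # falls back to the global default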


def get_global_default_model_name() -> str | None:
    """Get the global default model name."""
    ai_config = get_ai_config()
    return ai_config.get(DEFAULT_MODEL_NAME_KEY)


def set_global_default_model_name(provider_model_name: str | None) -> bool:
    """
    Set the global default model name.

    Args:
        provider_model_name: Model name in 'ProviderName/ModelName' format.

    Returns:
        bool: Whether the update succeeded.
    """
    if provider_model_name:
        prov_name, mod_name = parse_provider_model_string(provider_model_name)
        if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
            logger.error(
                f"The requested global default model '{provider_model_name}' "
                f"is invalid or not configured."
            )
            return False

    Config.set_config(
        AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
    )
    if provider_model_name:
        logger.info(
            f"LLM service global default model updated to: {provider_model_name}"
        )
    else:
        logger.info("LLM service global default model cleared.")
    return True
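
# Passing None clears the default instead of setting one, as the code above
# shows:
#
#     set_global_default_model_name("Gemini/gemini-2.0-flash")  # -> True
#     set_global_default_model_name(None)  # clears the default, -> True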


async def get_key_usage_stats() -> dict[str, Any]:
    """
    Get key-usage statistics for every provider.

    Returns:
        dict[str, Any]: Key-usage statistics for all providers.
    """
    providers = get_configured_providers()
    stats = {}

    for provider in providers:
        # Normalize a single-key string into a list before querying the store.
        provider_keys = (
            [provider.api_key]
            if isinstance(provider.api_key, str)
            else provider.api_key
        )
        provider_stats = await key_store.get_key_stats(provider_keys)
        stats[provider.name] = {
            "total_keys": len(provider_keys),
            "key_stats": provider_stats,
        }

    return stats


async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
    """
    Reset the key status for the given provider.

    Args:
        provider_name: Provider name.
        api_key: The specific API key to reset; if None, all keys are reset.

    Returns:
        bool: Whether the reset succeeded.
    """
    providers = get_configured_providers()
    target_provider = None

    for provider in providers:
        if provider.name.lower() == provider_name.lower():
            target_provider = provider
            break

    if not target_provider:
        logger.error(f"Provider not found: {provider_name}")
        return False

    provider_keys = (
        [target_provider.api_key]
        if isinstance(target_provider.api_key, str)
        else target_provider.api_key
    )

    if api_key:
        if api_key in provider_keys:
            await key_store.reset_key_status(api_key)
            logger.info(f"Reset the specified key for provider '{provider_name}'")
            return True
        else:
            logger.error(
                f"The given key does not belong to provider '{provider_name}'"
            )
            return False
    else:
        for key in provider_keys:
            await key_store.reset_key_status(key)
        logger.info(f"Reset all keys for provider '{provider_name}'")
        return True
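
# Administrative usage, grounded in the signatures above (the key value is
# illustrative):
#
#     stats = await get_key_usage_stats()
#     ok = await reset_key_status("Gemini")             # reset every key
#     ok = await reset_key_status("Gemini", "AIza...")  # reset a single key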