zhenxun_bot/zhenxun/services/llm/manager.py
Rumio a020ea5c87
feat(llm): 实现LLM服务模块,支持多提供商统一接口和高级功能 (#1923)
*  feat(llm): 实现LLM服务模块,支持多提供商统一接口和高级功能

* 🎨 Ruff

*  Config配置类支持BaseModel存储

* 🎨 代码格式化

* 🎨 代码格式化

* 🎨 格式化代码

*  feat(llm): 添加 AI 对话历史管理

*  feat(llmConfig): 引入 LLM 配置模型及管理功能

* 🎨 Ruff

---------

Co-authored-by: fccckaug <xxxmio123123@gmail.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
Co-authored-by: HibiKier <775757368@qq.com>
Co-authored-by: fccckaug <xxxmcsmiomio3@gmail.com>
Co-authored-by: webjoin111 <455457521@qq.com>
2025-06-21 16:33:21 +08:00

435 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 模型管理器
负责模型实例的创建、缓存、配置管理和生命周期管理。
"""
import hashlib
import json
import time
from typing import Any
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
from .config import validate_override_params
from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
# Keys looked up inside the AI configuration group.
DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
TIMEOUT_KEY = "timeout"

# Process-wide model cache: cache key -> (model instance, creation timestamp).
_model_cache: dict[str, tuple[LLMModel, float]] = {}
# Seconds a cached model stays valid before being evicted on lookup.
_cache_ttl = 3600
# Maximum number of cached instances; the oldest entry is evicted when full.
_max_cache_size = 10
def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:
"""解析 'ProviderName/ModelName' 格式的字符串"""
if not name_str or "/" not in name_str:
return None, None
parts = name_str.split("/", 1)
if len(parts) == 2 and parts[0].strip() and parts[1].strip():
return parts[0].strip(), parts[1].strip()
return None, None
def _make_cache_key(
provider_model_name: str | None, override_config: dict | None
) -> str:
"""生成缓存键"""
config_str = (
json.dumps(override_config, sort_keys=True) if override_config else "None"
)
key_data = f"{provider_model_name}:{config_str}"
return hashlib.md5(key_data.encode()).hexdigest()
def _get_cached_model(cache_key: str) -> LLMModel | None:
    """Look up a cached model, evicting it when its TTL has elapsed.

    A cached instance found in the closed state is reopened
    (``_is_closed`` reset to False) so it can be reused.
    """
    entry = _model_cache.get(cache_key)
    if entry is None:
        return None
    model, created_time = entry
    if time.time() - created_time > _cache_ttl:
        # Stale entry: drop it and report a miss.
        del _model_cache[cache_key]
        logger.debug(f"模型缓存已过期: {cache_key}")
        return None
    if model._is_closed:
        logger.debug(
            f"缓存的模型 {cache_key} ({model.provider_name}/{model.model_name}) "
            f"处于_is_closed=True状态重置为False以供复用。"
        )
        model._is_closed = False
    logger.debug(
        f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
    )
    return model
def _cache_model(cache_key: str, model: LLMModel):
    """Insert a model into the cache, evicting the oldest entry when full.

    Eviction picks the entry with the smallest creation timestamp.
    """
    if len(_model_cache) >= _max_cache_size:
        oldest_key = min(_model_cache, key=lambda k: _model_cache[k][1])
        _model_cache.pop(oldest_key)
    _model_cache[cache_key] = (model, time.time())
def clear_model_cache():
    """Empty the module-level model cache.

    The original declared ``global _model_cache`` unnecessarily: the dict
    is mutated in place via ``clear()``, never rebound, so no global
    statement is needed. Clearing in place also keeps any other references
    to the cache dict valid.
    """
    _model_cache.clear()
    logger.info("已清空模型缓存")
def get_cache_stats() -> dict[str, Any]:
    """Report current model-cache occupancy and configuration."""
    stats: dict[str, Any] = {}
    stats["cache_size"] = len(_model_cache)
    stats["max_cache_size"] = _max_cache_size
    stats["cache_ttl"] = _cache_ttl
    stats["cached_models"] = list(_model_cache)
    return stats
def get_default_api_base_for_type(api_type: str) -> str | None:
"""根据API类型获取默认的API基础地址"""
default_api_bases = {
"openai": "https://api.openai.com",
"deepseek": "https://api.deepseek.com",
"zhipu": "https://open.bigmodel.cn",
"gemini": "https://generativelanguage.googleapis.com",
"general_openai_compat": None,
}
return default_api_bases.get(api_type)
def get_configured_providers() -> list[ProviderConfig]:
    """Read, normalize, and validate the provider list from the AI config.

    Entries that are not dicts, lack a name or api_key, or fail
    ProviderConfig construction are logged and skipped. Missing
    ``api_type``, ``api_base``, and ``models`` fields are filled with
    sensible defaults before construction.
    """
    raw_list = get_ai_config().get(PROVIDERS_CONFIG_KEY, [])
    if not isinstance(raw_list, list):
        logger.error(
            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
            f"将使用空列表。"
        )
        return []
    parsed: list[ProviderConfig] = []
    for index, entry in enumerate(raw_list):
        if not isinstance(entry, dict):
            logger.warning(f"配置文件中第 {index + 1} 项不是字典格式,已跳过。")
            continue
        try:
            if not entry.get("name"):
                logger.warning(f"Provider {index + 1} 缺少 'name' 字段,已跳过。")
                continue
            if not entry.get("api_key"):
                logger.warning(
                    f"Provider '{entry['name']}' 缺少 'api_key' 字段,已跳过。"
                )
                continue
            # Missing/empty api_type: infer it from the provider name.
            if not entry.get("api_type"):
                lowered_name = entry.get("name", "").lower()
                if "glm" in lowered_name or "zhipu" in lowered_name:
                    entry["api_type"] = "zhipu"
                elif "gemini" in lowered_name or "google" in lowered_name:
                    entry["api_type"] = "gemini"
                else:
                    entry["api_type"] = "openai"
            # Missing/empty api_base: fall back to the type's default URL.
            if not entry.get("api_base"):
                resolved_type = entry.get("api_type")
                if resolved_type:
                    default_base = get_default_api_base_for_type(resolved_type)
                    if default_base:
                        entry["api_base"] = default_base
            # No model list at all: expose a single model named after the provider.
            if "models" not in entry:
                entry["models"] = [{"model_name": entry.get("name", "default")}]
            parsed.append(ProviderConfig(**entry))
        except Exception as e:
            logger.warning(f"解析配置文件中 Provider {index + 1} 时出错: {e},已跳过。")
    return parsed
def find_model_config(
    provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
    """Locate the (ProviderConfig, ModelDetail) pair for the given names.

    Matching is case-insensitive on both provider and model name.

    Args:
        provider_name: name of the provider to look up.
        model_name: name of the model within that provider.

    Returns:
        The matching (ProviderConfig, ModelDetail) tuple, or None.
    """
    wanted_provider = provider_name.lower()
    wanted_model = model_name.lower()
    for provider in get_configured_providers():
        if provider.name.lower() != wanted_provider:
            continue
        for detail in provider.models:
            if detail.model_name.lower() == wanted_model:
                return provider, detail
    return None
def list_available_models() -> list[dict[str, Any]]:
    """Flatten the configured providers into a list of model descriptors."""
    descriptors: list[dict[str, Any]] = []
    for provider in get_configured_providers():
        for detail in provider.models:
            descriptors.append(
                {
                    "provider_name": provider.name,
                    "model_name": detail.model_name,
                    "full_name": f"{provider.name}/{detail.model_name}",
                    "api_type": provider.api_type or "auto-detect",
                    "api_base": provider.api_base,
                    "is_available": detail.is_available,
                    "is_embedding_model": detail.is_embedding_model,
                    "available_identifiers": _get_model_identifiers(
                        provider.name, detail
                    ),
                }
            )
    return descriptors
def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]:
    """Return every identifier usable to reference this model.

    Currently only the canonical 'Provider/Model' form is produced.
    """
    canonical = f"{provider_name}/{model_detail.model_name}"
    return [canonical]
def list_model_identifiers() -> dict[str, list[str]]:
    """Map each configured model's full name to its identifier list.

    Returns:
        Dict keyed by 'Provider/Model' full name; each value is that
        model's list of usable identifiers.
    """
    return {
        f"{provider.name}/{detail.model_name}": _get_model_identifiers(
            provider.name, detail
        )
        for provider in get_configured_providers()
        for detail in provider.models
    }
def list_embedding_models() -> list[dict[str, Any]]:
    """Return only the configured models flagged as embedding models."""
    embedding_only: list[dict[str, Any]] = []
    for info in list_available_models():
        if info.get("is_embedding_model", False):
            embedding_only.append(info)
    return embedding_only
async def get_model_instance(
    provider_model_name: str | None = None,
    override_config: dict[str, Any] | None = None,
) -> LLMModel:
    """Resolve a 'ProviderName/ModelName' string to an LLMModel (async).

    Resolution order when no name is given: the global default model name,
    then the first entry of the configured model list. Instances are
    cached per (name, overrides) key; on a cache hit a differing override
    config is re-applied in place to the cached instance.

    Args:
        provider_model_name: 'Provider/Model' identifier, or None to use
            the resolution order above.
        override_config: generation parameters overriding the model's
            defaults; validated via validate_override_params before use.

    Returns:
        A ready LLMModel sharing the provider's pooled HTTP client.

    Raises:
        LLMException: no model configured, malformed name, model not
            found, or model instantiation failure.
    """
    cache_key = _make_cache_key(provider_model_name, override_config)
    cached_model = _get_cached_model(cache_key)
    if cached_model:
        if override_config:
            # Keep the cached instance's generation config in sync with
            # the caller's (validated) overrides before handing it back.
            validated_override = validate_override_params(override_config)
            if cached_model._generation_config != validated_override:
                cached_model._generation_config = validated_override
                logger.debug(
                    f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
                    f"{validated_override.to_dict()}"
                )
        return cached_model
    # Cache miss: resolve which model name to instantiate.
    resolved_model_name_str = provider_model_name
    if resolved_model_name_str is None:
        resolved_model_name_str = get_global_default_model_name()
    if resolved_model_name_str is None:
        # No explicit name and no global default: fall back to the first
        # configured model, or fail if nothing is configured at all.
        available_models_list = list_available_models()
        if not available_models_list:
            raise LLMException(
                "未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
            )
        resolved_model_name_str = available_models_list[0]["full_name"]
        logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")
    prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
    if not prov_name_str or not mod_name_str:
        raise LLMException(
            f"无效的模型名称格式: '{resolved_model_name_str}'",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )
    config_tuple_found = find_model_config(prov_name_str, mod_name_str)
    if not config_tuple_found:
        all_models = list_available_models()
        raise LLMException(
            f"未找到模型: '{resolved_model_name_str}'. "
            f"可用: {[m['full_name'] for m in all_models]}",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )
    provider_config_found, model_detail_found = config_tuple_found
    # Global proxy/timeout settings take precedence over the provider's
    # own timeout (which only serves as the fallback default, else 180s).
    ai_config = get_ai_config()
    global_proxy_setting = ai_config.get(PROXY_KEY)
    default_timeout = (
        provider_config_found.timeout
        if provider_config_found.timeout is not None
        else 180
    )
    global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)
    # Rebuild the provider config with the effective proxy/timeout so the
    # shared HTTP client is created with the merged settings.
    config_for_http_client = ProviderConfig(
        name=provider_config_found.name,
        api_key=provider_config_found.api_key,
        models=provider_config_found.models,
        timeout=global_timeout_setting,
        proxy=global_proxy_setting,
        api_base=provider_config_found.api_base,
        api_type=provider_config_found.api_type,
        openai_compat=provider_config_found.openai_compat,
        temperature=provider_config_found.temperature,
        max_tokens=provider_config_found.max_tokens,
    )
    shared_http_client = await http_client_manager.get_client(config_for_http_client)
    try:
        model_instance = LLMModel(
            provider_config=config_for_http_client,
            model_detail=model_detail_found,
            key_store=key_store,
            http_client=shared_http_client,
        )
        if override_config:
            validated_override_params = validate_override_params(override_config)
            model_instance._generation_config = validated_override_params
            logger.debug(
                f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
                f"{validated_override_params.to_dict()}"
            )
        _cache_model(cache_key, model_instance)
        logger.debug(
            f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
        )
        return model_instance
    except LLMException:
        # Domain errors pass through untouched.
        raise
    except Exception as e:
        logger.error(
            f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
        )
        raise LLMException(
            f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
            code=LLMErrorCode.MODEL_INIT_FAILED,
            cause=e,
        )
def get_global_default_model_name() -> str | None:
    """Fetch the globally configured default model name, if any."""
    return get_ai_config().get(DEFAULT_MODEL_NAME_KEY)
def set_global_default_model_name(provider_model_name: str | None) -> bool:
    """Persist a new global default model name, or clear it with None/''.

    Returns False without touching the configuration when the supplied
    name is malformed or does not match any configured provider/model.
    """
    if provider_model_name:
        prov_name, mod_name = parse_provider_model_string(provider_model_name)
        is_valid = bool(prov_name and mod_name) and (
            find_model_config(prov_name, mod_name) is not None
        )
        if not is_valid:
            logger.error(
                f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
            )
            return False
    Config.set_config(
        AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
    )
    if provider_model_name:
        logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
    else:
        logger.info("LLM 服务全局默认模型已清除。")
    return True
async def get_key_usage_stats() -> dict[str, Any]:
    """Collect per-provider API-key usage statistics.

    Each provider's ``api_key`` may be a single string or a list of
    strings; it is normalized to a list exactly once per provider
    (the original evaluated the same conditional expression twice).

    Returns:
        Mapping of provider name to
        {"total_keys": <key count>, "key_stats": <key_store stats>}.
    """
    stats: dict[str, Any] = {}
    for provider in get_configured_providers():
        # Normalize once; the same list feeds both the count and the lookup.
        keys = (
            [provider.api_key]
            if isinstance(provider.api_key, str)
            else provider.api_key
        )
        stats[provider.name] = {
            "total_keys": len(keys),
            "key_stats": await key_store.get_key_stats(keys),
        }
    return stats
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
    """Reset failure state for a provider's API key(s).

    When ``api_key`` is truthy, only that key is reset and it must belong
    to the provider; otherwise every key of the provider is reset.

    Returns:
        True on success, False when the provider or key is not found.
    """
    target_provider = next(
        (
            provider
            for provider in get_configured_providers()
            if provider.name.lower() == provider_name.lower()
        ),
        None,
    )
    if target_provider is None:
        logger.error(f"未找到Provider: {provider_name}")
        return False
    if isinstance(target_provider.api_key, str):
        provider_keys = [target_provider.api_key]
    else:
        provider_keys = target_provider.api_key
    if not api_key:
        # No specific key requested: reset every key of this provider.
        for key in provider_keys:
            await key_store.reset_key_status(key)
        logger.info(f"已重置Provider '{provider_name}' 的所有Key状态")
        return True
    if api_key not in provider_keys:
        logger.error(f"指定的Key不属于Provider '{provider_name}'")
        return False
    await key_store.reset_key_status(api_key)
    logger.info(f"已重置Provider '{provider_name}' 的指定Key状态")
    return True