diff --git a/zhenxun/services/llm/config/__init__.py b/zhenxun/services/llm/config/__init__.py index 778a04bd..09fd9599 100644 --- a/zhenxun/services/llm/config/__init__.py +++ b/zhenxun/services/llm/config/__init__.py @@ -12,14 +12,24 @@ from .generation import ( validate_override_params, ) from .presets import CommonOverrides -from .providers import register_llm_configs +from .providers import ( + LLMConfig, + get_llm_config, + register_llm_configs, + set_default_model, + validate_llm_config, +) __all__ = [ "CommonOverrides", + "LLMConfig", "LLMGenerationConfig", "ModelConfigOverride", "apply_api_specific_mappings", "create_generation_config_from_kwargs", + "get_llm_config", "register_llm_configs", + "set_default_model", + "validate_llm_config", "validate_override_params", ] diff --git a/zhenxun/services/llm/config/providers.py b/zhenxun/services/llm/config/providers.py index bdb1c584..d16050a3 100644 --- a/zhenxun/services/llm/config/providers.py +++ b/zhenxun/services/llm/config/providers.py @@ -4,104 +4,326 @@ LLM 提供商配置管理 负责注册和管理 AI 服务提供商的配置项。 """ +from typing import Any + +from pydantic import BaseModel, Field + from zhenxun.configs.config import Config from zhenxun.services.log import logger -from ..types.models import ProviderConfig +from ..types.models import ModelDetail, ProviderConfig AI_CONFIG_GROUP = "AI" PROVIDERS_CONFIG_KEY = "PROVIDERS" +class LLMConfig(BaseModel): + """LLM 服务配置类""" + + default_model_name: str | None = Field( + default=None, + description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)", + ) + proxy: str | None = Field( + default=None, + description="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", + ) + timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)") + max_retries_llm: int = Field( + default=3, description="LLM服务请求失败时的最大重试次数" + ) + retry_delay_llm: int = Field( + default=2, description="LLM服务请求重试的基础延迟时间(秒)" + ) + providers: list[ProviderConfig] = Field( + default_factory=list, description="配置多个 AI 服务提供商及其模型信息" + ) + + def get_provider_by_name(self, name: str) -> ProviderConfig | None: + """根据名称获取提供商配置 + + 参数: + name: 提供商名称 + + 返回: + ProviderConfig | None: 提供商配置,如果未找到则返回 None + """ + for provider in self.providers: + if provider.name == name: + return provider + return None + + def get_model_by_provider_and_name( + self, provider_name: str, model_name: str + ) -> tuple[ProviderConfig, ModelDetail] | None: + """根据提供商名称和模型名称获取配置 + + 参数: + provider_name: 提供商名称 + model_name: 模型名称 + + 返回: + tuple[ProviderConfig, ModelDetail] | None: 提供商配置和模型详情的元组, + 如果未找到则返回 None + """ + provider = self.get_provider_by_name(provider_name) + if not provider: + return None + + for model in provider.models: + if model.model_name == model_name: + return provider, model + return None + + def list_available_models(self) -> list[dict[str, Any]]: + """列出所有可用的模型 + + 返回: + list[dict[str, Any]]: 模型信息列表 + """ + models = [] + for provider in self.providers: + for model in provider.models: + models.append( + { + "provider_name": provider.name, + "model_name": model.model_name, + "full_name": f"{provider.name}/{model.model_name}", + "is_available": model.is_available, + "is_embedding_model": model.is_embedding_model, + "api_type": provider.api_type, + } + ) + return models + + def validate_model_name(self, provider_model_name: str) -> bool: + """验证模型名称格式是否正确 + + 参数: + provider_model_name: 格式为 "ProviderName/ModelName" 的字符串 + + 返回: + bool: 是否有效 + """ + if not provider_model_name or "/" not in provider_model_name: + return False + + parts = provider_model_name.split("/", 1) + if len(parts) != 2: + return False + + provider_name, model_name = parts + return ( + self.get_model_by_provider_and_name(provider_name, model_name) is not None + ) + + def get_ai_config(): """获取 AI 配置组""" return Config.get(AI_CONFIG_GROUP) +def get_default_providers() -> list[dict[str, Any]]: + """获取默认的提供商配置 + + 返回: + list[dict[str, Any]]: 默认提供商配置列表 + """ + return [ + { + "name": "DeepSeek", + "api_key": "sk-******", + "api_base": "https://api.deepseek.com", + "api_type": "openai", + "models": [ + { + "model_name": "deepseek-chat", + "max_tokens": 4096, + "temperature": 0.7, + }, + { + "model_name": "deepseek-reasoner", + }, + ], + }, + { + "name": "GLM", + "api_key": "", + "api_base": "https://open.bigmodel.cn", + "api_type": "zhipu", + "models": [ + {"model_name": "glm-4-flash"}, + {"model_name": "glm-4-plus"}, + ], + }, + { + "name": "Gemini", + "api_key": [ + "AIzaSy*****************************", + "AIzaSy*****************************", + "AIzaSy*****************************", + ], + "api_base": "https://generativelanguage.googleapis.com", + "api_type": "gemini", + "models": [ + {"model_name": "gemini-2.0-flash"}, + {"model_name": "gemini-2.5-flash-preview-05-20"}, + ], + }, + ] + + def register_llm_configs(): """注册 LLM 服务的配置项""" logger.info("注册 LLM 服务的配置项") + + llm_config = LLMConfig() + model_fields = LLMConfig.model_fields + Config.add_plugin_config( AI_CONFIG_GROUP, "default_model_name", - None, - help="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)", + llm_config.default_model_name, + help=model_fields["default_model_name"].description, type=str, ) Config.add_plugin_config( AI_CONFIG_GROUP, "proxy", - None, - help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890", + llm_config.proxy, + help=model_fields["proxy"].description, type=str, ) Config.add_plugin_config( AI_CONFIG_GROUP, "timeout", - 180, - help="LLM服务API请求超时时间(秒)", + llm_config.timeout, + help=model_fields["timeout"].description, type=int, ) Config.add_plugin_config( AI_CONFIG_GROUP, "max_retries_llm", - 3, - help="LLM服务请求失败时的最大重试次数", + llm_config.max_retries_llm, + help=model_fields["max_retries_llm"].description, type=int, ) Config.add_plugin_config( AI_CONFIG_GROUP, "retry_delay_llm", - 2, - help="LLM服务请求重试的基础延迟时间(秒)", + llm_config.retry_delay_llm, + help=model_fields["retry_delay_llm"].description, type=int, ) + Config.add_plugin_config( AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, - [ - { - "name": "DeepSeek", - "api_key": "sk-******", - "api_base": "https://api.deepseek.com", - "api_type": "openai", - "models": [ - { - "model_name": "deepseek-chat", - "max_tokens": 4096, - "temperature": 0.7, - }, - { - "model_name": "deepseek-reasoner", - }, - ], - }, - { - "name": "GLM", - "api_key": "", - "api_base": "https://open.bigmodel.cn", - "api_type": "zhipu", - "models": [ - {"model_name": "glm-4-flash"}, - {"model_name": "glm-4-plus"}, - ], - }, - { - "name": "Gemini", - "api_key": [ - "AIzaSy*****************************", - "AIzaSy*****************************", - "AIzaSy*****************************", - ], - "api_base": "https://generativelanguage.googleapis.com", - "api_type": "gemini", - "models": [ - {"model_name": "gemini-2.0-flash"}, - {"model_name": "gemini-2.5-flash-preview-05-20"}, - ], - }, - ], - help="配置多个 AI 服务提供商及其模型信息 (列表)", + get_default_providers(), + help=model_fields["providers"].description, default_value=[], type=list[ProviderConfig], ) + + +def get_llm_config() -> LLMConfig: + """获取 LLM 配置实例 + + 返回: + LLMConfig: LLM 配置实例 + """ + ai_config = get_ai_config() + + config_data = { + "default_model_name": ai_config.get("default_model_name"), + "proxy": ai_config.get("proxy"), + "timeout": ai_config.get("timeout", 180), + "max_retries_llm": ai_config.get("max_retries_llm", 3), + "retry_delay_llm": ai_config.get("retry_delay_llm", 2), + "providers": ai_config.get(PROVIDERS_CONFIG_KEY, []), + } + + return LLMConfig(**config_data) + + +def validate_llm_config() -> tuple[bool, list[str]]: + """验证 LLM 配置的有效性 + + 返回: + tuple[bool, list[str]]: (是否有效, 错误信息列表) + """ + errors = [] + + try: + llm_config = get_llm_config() + + if llm_config.timeout <= 0: + errors.append("timeout 必须大于 0") + + if llm_config.max_retries_llm < 0: + errors.append("max_retries_llm 不能小于 0") + + if llm_config.retry_delay_llm <= 0: + errors.append("retry_delay_llm 必须大于 0") + + if not llm_config.providers: + errors.append("至少需要配置一个 AI 服务提供商") + else: + provider_names = set() + for provider in llm_config.providers: + if provider.name in provider_names: + errors.append(f"提供商名称重复: {provider.name}") + provider_names.add(provider.name) + + if not provider.api_key: + errors.append(f"提供商 {provider.name} 缺少 API Key") + + if not provider.models: + errors.append(f"提供商 {provider.name} 没有配置任何模型") + else: + model_names = set() + for model in provider.models: + if model.model_name in model_names: + errors.append( + f"提供商 {provider.name} 中模型名称重复: " + f"{model.model_name}" + ) + model_names.add(model.model_name) + + if llm_config.default_model_name: + if not llm_config.validate_model_name(llm_config.default_model_name): + errors.append( + f"默认模型 {llm_config.default_model_name} 在配置中不存在" + ) + + except Exception as e: + errors.append(f"配置解析失败: {e!s}") + + return len(errors) == 0, errors + + +def set_default_model(provider_model_name: str | None) -> bool: + """设置默认模型 + + 参数: + provider_model_name: 模型名称,格式为 "ProviderName/ModelName",None 表示清除 + + 返回: + bool: 是否设置成功 + """ + if provider_model_name: + llm_config = get_llm_config() + if not llm_config.validate_model_name(provider_model_name): + logger.error(f"模型 {provider_model_name} 在配置中不存在") + return False + + Config.set_config( + AI_CONFIG_GROUP, "default_model_name", provider_model_name, auto_save=True + ) + + if provider_model_name: + logger.info(f"默认模型已设置为: {provider_model_name}") + else: + logger.info("默认模型已清除") + + return True