mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ feat(llmConfig): 引入 LLM 配置模型及管理功能
This commit is contained in:
parent
2ac3baf63d
commit
194448b08e
@ -12,14 +12,24 @@ from .generation import (
|
||||
validate_override_params,
|
||||
)
|
||||
from .presets import CommonOverrides
|
||||
from .providers import register_llm_configs
|
||||
from .providers import (
|
||||
LLMConfig,
|
||||
get_llm_config,
|
||||
register_llm_configs,
|
||||
set_default_model,
|
||||
validate_llm_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CommonOverrides",
|
||||
"LLMConfig",
|
||||
"LLMGenerationConfig",
|
||||
"ModelConfigOverride",
|
||||
"apply_api_specific_mappings",
|
||||
"create_generation_config_from_kwargs",
|
||||
"get_llm_config",
|
||||
"register_llm_configs",
|
||||
"set_default_model",
|
||||
"validate_llm_config",
|
||||
"validate_override_params",
|
||||
]
|
||||
|
||||
@ -4,104 +4,326 @@ LLM 提供商配置管理
|
||||
负责注册和管理 AI 服务提供商的配置项。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from ..types.models import ProviderConfig
|
||||
from ..types.models import ModelDetail, ProviderConfig
|
||||
|
||||
AI_CONFIG_GROUP = "AI"
|
||||
PROVIDERS_CONFIG_KEY = "PROVIDERS"
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""LLM 服务配置类"""
|
||||
|
||||
default_model_name: str | None = Field(
|
||||
default=None,
|
||||
description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
|
||||
)
|
||||
proxy: str | None = Field(
|
||||
default=None,
|
||||
description="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890",
|
||||
)
|
||||
timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)")
|
||||
max_retries_llm: int = Field(
|
||||
default=3, description="LLM服务请求失败时的最大重试次数"
|
||||
)
|
||||
retry_delay_llm: int = Field(
|
||||
default=2, description="LLM服务请求重试的基础延迟时间(秒)"
|
||||
)
|
||||
providers: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
|
||||
)
|
||||
|
||||
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
|
||||
"""根据名称获取提供商配置
|
||||
|
||||
参数:
|
||||
name: 提供商名称
|
||||
|
||||
返回:
|
||||
ProviderConfig | None: 提供商配置,如果未找到则返回 None
|
||||
"""
|
||||
for provider in self.providers:
|
||||
if provider.name == name:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_model_by_provider_and_name(
|
||||
self, provider_name: str, model_name: str
|
||||
) -> tuple[ProviderConfig, ModelDetail] | None:
|
||||
"""根据提供商名称和模型名称获取配置
|
||||
|
||||
参数:
|
||||
provider_name: 提供商名称
|
||||
model_name: 模型名称
|
||||
|
||||
返回:
|
||||
tuple[ProviderConfig, ModelDetail] | None: 提供商配置和模型详情的元组,
|
||||
如果未找到则返回 None
|
||||
"""
|
||||
provider = self.get_provider_by_name(provider_name)
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
for model in provider.models:
|
||||
if model.model_name == model_name:
|
||||
return provider, model
|
||||
return None
|
||||
|
||||
def list_available_models(self) -> list[dict[str, Any]]:
|
||||
"""列出所有可用的模型
|
||||
|
||||
返回:
|
||||
list[dict[str, Any]]: 模型信息列表
|
||||
"""
|
||||
models = []
|
||||
for provider in self.providers:
|
||||
for model in provider.models:
|
||||
models.append(
|
||||
{
|
||||
"provider_name": provider.name,
|
||||
"model_name": model.model_name,
|
||||
"full_name": f"{provider.name}/{model.model_name}",
|
||||
"is_available": model.is_available,
|
||||
"is_embedding_model": model.is_embedding_model,
|
||||
"api_type": provider.api_type,
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
def validate_model_name(self, provider_model_name: str) -> bool:
|
||||
"""验证模型名称格式是否正确
|
||||
|
||||
参数:
|
||||
provider_model_name: 格式为 "ProviderName/ModelName" 的字符串
|
||||
|
||||
返回:
|
||||
bool: 是否有效
|
||||
"""
|
||||
if not provider_model_name or "/" not in provider_model_name:
|
||||
return False
|
||||
|
||||
parts = provider_model_name.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
|
||||
provider_name, model_name = parts
|
||||
return (
|
||||
self.get_model_by_provider_and_name(provider_name, model_name) is not None
|
||||
)
|
||||
|
||||
|
||||
def get_ai_config():
|
||||
"""获取 AI 配置组"""
|
||||
return Config.get(AI_CONFIG_GROUP)
|
||||
|
||||
|
||||
def get_default_providers() -> list[dict[str, Any]]:
|
||||
"""获取默认的提供商配置
|
||||
|
||||
返回:
|
||||
list[dict[str, Any]]: 默认提供商配置列表
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"name": "DeepSeek",
|
||||
"api_key": "sk-******",
|
||||
"api_base": "https://api.deepseek.com",
|
||||
"api_type": "openai",
|
||||
"models": [
|
||||
{
|
||||
"model_name": "deepseek-chat",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
{
|
||||
"model_name": "deepseek-reasoner",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "GLM",
|
||||
"api_key": "",
|
||||
"api_base": "https://open.bigmodel.cn",
|
||||
"api_type": "zhipu",
|
||||
"models": [
|
||||
{"model_name": "glm-4-flash"},
|
||||
{"model_name": "glm-4-plus"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Gemini",
|
||||
"api_key": [
|
||||
"AIzaSy*****************************",
|
||||
"AIzaSy*****************************",
|
||||
"AIzaSy*****************************",
|
||||
],
|
||||
"api_base": "https://generativelanguage.googleapis.com",
|
||||
"api_type": "gemini",
|
||||
"models": [
|
||||
{"model_name": "gemini-2.0-flash"},
|
||||
{"model_name": "gemini-2.5-flash-preview-05-20"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def register_llm_configs():
|
||||
"""注册 LLM 服务的配置项"""
|
||||
logger.info("注册 LLM 服务的配置项")
|
||||
|
||||
llm_config = LLMConfig()
|
||||
model_fields = LLMConfig.model_fields
|
||||
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
"default_model_name",
|
||||
None,
|
||||
help="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
|
||||
llm_config.default_model_name,
|
||||
help=model_fields["default_model_name"].description,
|
||||
type=str,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
"proxy",
|
||||
None,
|
||||
help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890",
|
||||
llm_config.proxy,
|
||||
help=model_fields["proxy"].description,
|
||||
type=str,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
"timeout",
|
||||
180,
|
||||
help="LLM服务API请求超时时间(秒)",
|
||||
llm_config.timeout,
|
||||
help=model_fields["timeout"].description,
|
||||
type=int,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
"max_retries_llm",
|
||||
3,
|
||||
help="LLM服务请求失败时的最大重试次数",
|
||||
llm_config.max_retries_llm,
|
||||
help=model_fields["max_retries_llm"].description,
|
||||
type=int,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
"retry_delay_llm",
|
||||
2,
|
||||
help="LLM服务请求重试的基础延迟时间(秒)",
|
||||
llm_config.retry_delay_llm,
|
||||
help=model_fields["retry_delay_llm"].description,
|
||||
type=int,
|
||||
)
|
||||
|
||||
Config.add_plugin_config(
|
||||
AI_CONFIG_GROUP,
|
||||
PROVIDERS_CONFIG_KEY,
|
||||
[
|
||||
{
|
||||
"name": "DeepSeek",
|
||||
"api_key": "sk-******",
|
||||
"api_base": "https://api.deepseek.com",
|
||||
"api_type": "openai",
|
||||
"models": [
|
||||
{
|
||||
"model_name": "deepseek-chat",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
{
|
||||
"model_name": "deepseek-reasoner",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "GLM",
|
||||
"api_key": "",
|
||||
"api_base": "https://open.bigmodel.cn",
|
||||
"api_type": "zhipu",
|
||||
"models": [
|
||||
{"model_name": "glm-4-flash"},
|
||||
{"model_name": "glm-4-plus"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Gemini",
|
||||
"api_key": [
|
||||
"AIzaSy*****************************",
|
||||
"AIzaSy*****************************",
|
||||
"AIzaSy*****************************",
|
||||
],
|
||||
"api_base": "https://generativelanguage.googleapis.com",
|
||||
"api_type": "gemini",
|
||||
"models": [
|
||||
{"model_name": "gemini-2.0-flash"},
|
||||
{"model_name": "gemini-2.5-flash-preview-05-20"},
|
||||
],
|
||||
},
|
||||
],
|
||||
help="配置多个 AI 服务提供商及其模型信息 (列表)",
|
||||
get_default_providers(),
|
||||
help=model_fields["providers"].description,
|
||||
default_value=[],
|
||||
type=list[ProviderConfig],
|
||||
)
|
||||
|
||||
|
||||
def get_llm_config() -> LLMConfig:
|
||||
"""获取 LLM 配置实例
|
||||
|
||||
返回:
|
||||
LLMConfig: LLM 配置实例
|
||||
"""
|
||||
ai_config = get_ai_config()
|
||||
|
||||
config_data = {
|
||||
"default_model_name": ai_config.get("default_model_name"),
|
||||
"proxy": ai_config.get("proxy"),
|
||||
"timeout": ai_config.get("timeout", 180),
|
||||
"max_retries_llm": ai_config.get("max_retries_llm", 3),
|
||||
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
|
||||
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
|
||||
}
|
||||
|
||||
return LLMConfig(**config_data)
|
||||
|
||||
|
||||
def validate_llm_config() -> tuple[bool, list[str]]:
|
||||
"""验证 LLM 配置的有效性
|
||||
|
||||
返回:
|
||||
tuple[bool, list[str]]: (是否有效, 错误信息列表)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
try:
|
||||
llm_config = get_llm_config()
|
||||
|
||||
if llm_config.timeout <= 0:
|
||||
errors.append("timeout 必须大于 0")
|
||||
|
||||
if llm_config.max_retries_llm < 0:
|
||||
errors.append("max_retries_llm 不能小于 0")
|
||||
|
||||
if llm_config.retry_delay_llm <= 0:
|
||||
errors.append("retry_delay_llm 必须大于 0")
|
||||
|
||||
if not llm_config.providers:
|
||||
errors.append("至少需要配置一个 AI 服务提供商")
|
||||
else:
|
||||
provider_names = set()
|
||||
for provider in llm_config.providers:
|
||||
if provider.name in provider_names:
|
||||
errors.append(f"提供商名称重复: {provider.name}")
|
||||
provider_names.add(provider.name)
|
||||
|
||||
if not provider.api_key:
|
||||
errors.append(f"提供商 {provider.name} 缺少 API Key")
|
||||
|
||||
if not provider.models:
|
||||
errors.append(f"提供商 {provider.name} 没有配置任何模型")
|
||||
else:
|
||||
model_names = set()
|
||||
for model in provider.models:
|
||||
if model.model_name in model_names:
|
||||
errors.append(
|
||||
f"提供商 {provider.name} 中模型名称重复: "
|
||||
f"{model.model_name}"
|
||||
)
|
||||
model_names.add(model.model_name)
|
||||
|
||||
if llm_config.default_model_name:
|
||||
if not llm_config.validate_model_name(llm_config.default_model_name):
|
||||
errors.append(
|
||||
f"默认模型 {llm_config.default_model_name} 在配置中不存在"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"配置解析失败: {e!s}")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
def set_default_model(provider_model_name: str | None) -> bool:
|
||||
"""设置默认模型
|
||||
|
||||
参数:
|
||||
provider_model_name: 模型名称,格式为 "ProviderName/ModelName",None 表示清除
|
||||
|
||||
返回:
|
||||
bool: 是否设置成功
|
||||
"""
|
||||
if provider_model_name:
|
||||
llm_config = get_llm_config()
|
||||
if not llm_config.validate_model_name(provider_model_name):
|
||||
logger.error(f"模型 {provider_model_name} 在配置中不存在")
|
||||
return False
|
||||
|
||||
Config.set_config(
|
||||
AI_CONFIG_GROUP, "default_model_name", provider_model_name, auto_save=True
|
||||
)
|
||||
|
||||
if provider_model_name:
|
||||
logger.info(f"默认模型已设置为: {provider_model_name}")
|
||||
else:
|
||||
logger.info("默认模型已清除")
|
||||
|
||||
return True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user