feat(llm): 添加新模型并简化提供者配置加载

This commit is contained in:
webjoin111 2025-07-07 18:16:49 +08:00
parent a348c3d276
commit ae8ff7b824
2 changed files with 21 additions and 47 deletions

View File

@ -173,9 +173,10 @@ def get_default_providers() -> list[dict[str, Any]]:
"api_base": "https://ark.cn-beijing.volces.com",
"api_type": "ark",
"models": [
{
"model_name": "ep-xxxxxxxxxxxxxxxx-xxxxx",
},
{"model_name": "deepseek-r1-250528"},
{"model_name": "doubao-seed-1-6-250615"},
{"model_name": "doubao-seed-1-6-flash-250615"},
{"model_name": "doubao-seed-1-6-thinking-250615"},
],
},
{

View File

@ -116,57 +116,30 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
def get_configured_providers() -> list[ProviderConfig]:
"""从配置中获取Provider列表 - 简化版本"""
"""从配置中获取Provider列表 - 简化和修正版本"""
ai_config = get_ai_config()
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers_raw, list):
providers = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers, list):
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 的值不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
for i, item in enumerate(providers_raw):
if not isinstance(item, dict):
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
continue
try:
if not item.get("name"):
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
continue
if not item.get("api_key"):
logger.warning(
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
)
continue
if "api_type" not in item or not item["api_type"]:
provider_name = item.get("name", "").lower()
if "glm" in provider_name or "zhipu" in provider_name:
item["api_type"] = "zhipu"
elif "gemini" in provider_name or "google" in provider_name:
item["api_type"] = "gemini"
else:
item["api_type"] = "openai"
if "api_base" not in item or not item["api_base"]:
api_type = item.get("api_type")
if api_type:
default_api_base = get_default_api_base_for_type(api_type)
if default_api_base:
item["api_base"] = default_api_base
if "models" not in item:
item["models"] = [{"model_name": item.get("name", "default")}]
provider_conf = ProviderConfig(**item)
valid_providers.append(provider_conf)
except Exception as e:
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
for i, item in enumerate(providers):
if isinstance(item, ProviderConfig):
if not item.api_base:
default_api_base = get_default_api_base_for_type(item.api_type)
if default_api_base:
item.api_base = default_api_base
valid_providers.append(item)
else:
logger.warning(
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
f"已跳过。实际类型: {type(item)}"
)
return valid_providers