Mirror of https://github.com/zhenxun-org/zhenxun_bot.git (synced 2025-12-15 14:22:55 +08:00)
✨ feat(llm): add an LLM model management plugin and enhance API key management
🔧 New features:
- LLM model management plugin (builtin_plugins/llm_manager/)
  • llm list - view the available model list (rendered as an image)
  • llm info - view model details (Markdown image)
  • llm default - manage the global default model
  • llm test - test model connectivity
  • llm keys - view API key status (table image with health, success rate, and latency)
  • llm reset-key - reset the failure state of API keys

🏗️ Refactoring:
- Session management: the AI/AIConfig classes moved into a standalone session.py
- Type definitions: the TaskType enum moved to types/enums.py
- API enhancements:
  • chat() now returns a full LLMResponse and supports tool calls
  • new generate() function for one-shot response generation
  • unified core API-call method _perform_api_call, which returns the API key that was used

🚀 Key management enhancements:
- Detailed status tracking: health, success rate, average latency, error info, and suggested action
- State persistence: key status is loaded on startup and saved automatically on shutdown
- Smart cooldown policy: cooldown duration varies with the error type
- Latency monitoring: with_smart_retry records API call latency and updates the statistics
This commit is contained in:
parent 632ec3e46e
commit 36f36b3ac4
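For orientation, here is a minimal sketch of the reworked convenience API this commit describes; the model name and prompts are illustrative, not part of the commit. The actual signatures appear in the api.py diff further below.

```python
# Minimal sketch of the reworked convenience API (model name and prompts
# are illustrative; see the api.py diff below for the actual signatures).
from zhenxun.services.llm import LLMMessage, chat, generate


async def demo() -> None:
    # chat() now returns a full LLMResponse instead of a bare string,
    # so callers can inspect tool calls and other metadata.
    response = await chat("你好", model="Gemini/gemini-2.0-flash")
    print(response.text)

    # generate() is the new one-shot entry point: it takes a complete
    # message list and neither reads nor writes any session history.
    messages = [
        LLMMessage.system("你是一个乐于助人的助手。"),
        LLMMessage.user("用一句话介绍你自己。"),
    ]
    result = await generate(messages, model="Gemini/gemini-2.0-flash")
    print(result.text)
```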
171 zhenxun/builtin_plugins/llm_manager/__init__.py (new file)
@@ -0,0 +1,171 @@
from nonebot.permission import SUPERUSER
from nonebot.plugin import PluginMetadata
from nonebot_plugin_alconna import (
    Alconna,
    Args,
    Arparma,
    Match,
    Option,
    Query,
    Subcommand,
    on_alconna,
    store_true,
)

from zhenxun.configs.utils import PluginExtraData
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils

from .data_source import DataSource
from .presenters import Presenters

__plugin_meta__ = PluginMetadata(
    name="LLM模型管理",
    description="查看和管理大语言模型服务。",
    usage="""
    LLM模型管理 (SUPERUSER)

    llm list [--all]
    - 查看可用模型列表。
    - --all: 显示包括不可用在内的所有模型。

    llm info <Provider/ModelName>
    - 查看指定模型的详细信息和能力。

    llm default [Provider/ModelName]
    - 查看或设置全局默认模型。
    - 不带参数: 查看当前默认模型。
    - 带参数: 设置新的默认模型。
    - 例子: llm default Gemini/gemini-2.0-flash

    llm test <Provider/ModelName>
    - 测试指定模型的连通性和API Key有效性。

    llm keys <ProviderName>
    - 查看指定提供商的所有API Key状态。

    llm reset-key <ProviderName> [--key <api_key>]
    - 重置提供商的所有或指定API Key的失败状态。
    """,
    extra=PluginExtraData(
        author="HibiKier",
        version="1.0.0",
        plugin_type=PluginType.SUPERUSER,
    ).to_dict(),
)

llm_cmd = on_alconna(
    Alconna(
        "llm",
        Subcommand("list", alias=["ls"], help_text="查看模型列表"),
        Subcommand("info", Args["model_name", str], help_text="查看模型详情"),
        Subcommand("default", Args["model_name?", str], help_text="查看或设置默认模型"),
        Subcommand(
            "test", Args["model_name", str], alias=["ping"], help_text="测试模型连通性"
        ),
        Subcommand("keys", Args["provider_name", str], help_text="查看API密钥状态"),
        Subcommand(
            "reset-key",
            Args["provider_name", str],
            Option("--key", Args["api_key", str], help_text="指定要重置的API Key"),
            help_text="重置API Key状态",
        ),
        Option("--all", action=store_true, help_text="显示所有条目"),
    ),
    permission=SUPERUSER,
    priority=5,
    block=True,
)


@llm_cmd.assign("list")
async def handle_list(arp: Arparma, show_all: Query[bool] = Query("all")):
    """处理 'llm list' 命令"""
    logger.info("获取LLM模型列表", command="LLM Manage", session=arp.header_result)
    models = await DataSource.get_model_list(show_all=show_all.result)

    image = await Presenters.format_model_list_as_image(models, show_all.result)
    await llm_cmd.finish(MessageUtils.build_message(image))


@llm_cmd.assign("info")
async def handle_info(arp: Arparma, model_name: Match[str]):
    """处理 'llm info' 命令"""
    logger.info(
        f"获取模型详情: {model_name.result}",
        command="LLM Manage",
        session=arp.header_result,
    )
    details = await DataSource.get_model_details(model_name.result)
    if not details:
        await llm_cmd.finish(f"未找到模型: {model_name.result}")

    image_bytes = await Presenters.format_model_details_as_markdown_image(details)
    await llm_cmd.finish(MessageUtils.build_message(image_bytes))


@llm_cmd.assign("default")
async def handle_default(arp: Arparma, model_name: Match[str]):
    """处理 'llm default' 命令"""
    if model_name.available:
        logger.info(
            f"设置默认模型为: {model_name.result}",
            command="LLM Manage",
            session=arp.header_result,
        )
        success, message = await DataSource.set_default_model(model_name.result)
        await llm_cmd.finish(message)
    else:
        logger.info("查看默认模型", command="LLM Manage", session=arp.header_result)
        current_default = await DataSource.get_default_model()
        await llm_cmd.finish(f"当前全局默认模型为: {current_default or '未设置'}")


@llm_cmd.assign("test")
async def handle_test(arp: Arparma, model_name: Match[str]):
    """处理 'llm test' 命令"""
    logger.info(
        f"测试模型连通性: {model_name.result}",
        command="LLM Manage",
        session=arp.header_result,
    )
    await llm_cmd.send(f"正在测试模型 '{model_name.result}',请稍候...")

    success, message = await DataSource.test_model_connectivity(model_name.result)
    await llm_cmd.finish(message)


@llm_cmd.assign("keys")
async def handle_keys(arp: Arparma, provider_name: Match[str]):
    """处理 'llm keys' 命令"""
    logger.info(
        f"查看提供商API Key状态: {provider_name.result}",
        command="LLM Manage",
        session=arp.header_result,
    )
    sorted_stats = await DataSource.get_key_status(provider_name.result)
    if not sorted_stats:
        await llm_cmd.finish(
            f"未找到提供商 '{provider_name.result}' 或其没有配置API Keys。"
        )

    image = await Presenters.format_key_status_as_image(
        provider_name.result, sorted_stats
    )
    await llm_cmd.finish(MessageUtils.build_message(image))


@llm_cmd.assign("reset-key")
async def handle_reset_key(
    arp: Arparma, provider_name: Match[str], api_key: Match[str]
):
    """处理 'llm reset-key' 命令"""
    key_to_reset = api_key.result if api_key.available else None
    log_msg = f"重置 {provider_name.result} 的 " + (
        "指定API Key" if key_to_reset else "所有API Keys"
    )
    logger.info(log_msg, command="LLM Manage", session=arp.header_result)

    success, message = await DataSource.reset_key(provider_name.result, key_to_reset)
    await llm_cmd.finish(message)
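For reference, a typical superuser session with this plugin might look like the following. The provider name and key fragment are made up; the actual replies are rendered as images by the presenters defined later in this commit.

```
llm list --all
llm info Gemini/gemini-2.0-flash
llm default Gemini/gemini-2.0-flash
llm test Gemini/gemini-2.0-flash
llm keys Gemini
llm reset-key Gemini --key sk-xxxx
```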
120 zhenxun/builtin_plugins/llm_manager/data_source.py (new file)
@@ -0,0 +1,120 @@
import time
from typing import Any

from zhenxun.services.llm import (
    LLMException,
    get_global_default_model_name,
    get_model_instance,
    list_available_models,
    set_global_default_model_name,
)
from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.manager import (
    reset_key_status,
)


class DataSource:
    """LLM管理插件的数据源和业务逻辑"""

    @staticmethod
    async def get_model_list(show_all: bool = False) -> list[dict[str, Any]]:
        """获取模型列表"""
        models = list_available_models()
        if show_all:
            return models
        return [m for m in models if m.get("is_available", True)]

    @staticmethod
    async def get_model_details(model_name_str: str) -> dict[str, Any] | None:
        """获取指定模型的详细信息"""
        try:
            model = await get_model_instance(model_name_str)
            return {
                "provider_config": model.provider_config,
                "model_detail": model.model_detail,
                "capabilities": model.capabilities,
            }
        except LLMException:
            return None

    @staticmethod
    async def get_default_model() -> str | None:
        """获取全局默认模型"""
        return get_global_default_model_name()

    @staticmethod
    async def set_default_model(model_name_str: str) -> tuple[bool, str]:
        """设置全局默认模型"""
        success = set_global_default_model_name(model_name_str)
        if success:
            return True, f"✅ 成功将默认模型设置为: {model_name_str}"
        else:
            return False, f"❌ 设置失败,模型 '{model_name_str}' 不存在或无效。"

    @staticmethod
    async def test_model_connectivity(model_name_str: str) -> tuple[bool, str]:
        """测试模型连通性"""
        start_time = time.monotonic()
        try:
            async with await get_model_instance(model_name_str) as model:
                await model.generate_text("你好")
            end_time = time.monotonic()
            latency = (end_time - start_time) * 1000
            return (
                True,
                f"✅ 模型 '{model_name_str}' 连接成功!\n响应延迟: {latency:.2f} ms",
            )
        except LLMException as e:
            return (
                False,
                f"❌ 模型 '{model_name_str}' 连接测试失败:\n"
                f"{e.user_friendly_message}\n错误码: {e.code.name}",
            )
        except Exception as e:
            return False, f"❌ 测试时发生未知错误: {e!s}"

    @staticmethod
    async def get_key_status(provider_name: str) -> list[dict[str, Any]] | None:
        """获取并排序指定提供商的API Key状态"""
        from zhenxun.services.llm.manager import get_key_usage_stats

        all_stats = await get_key_usage_stats()
        provider_stats = all_stats.get(provider_name)

        if not provider_stats or not provider_stats.get("key_stats"):
            return None

        key_stats_dict = provider_stats["key_stats"]

        stats_list = [
            {"key_id": key_id, **stats} for key_id, stats in key_stats_dict.items()
        ]

        def sort_key(item: dict[str, Any]):
            status_priority = item.get("status_enum", KeyStatus.UNUSED).value
            return (
                status_priority,
                100 - item.get("success_rate", 100.0),
                -item.get("total_calls", 0),
            )

        sorted_stats_list = sorted(stats_list, key=sort_key)

        return sorted_stats_list

    @staticmethod
    async def reset_key(provider_name: str, api_key: str | None) -> tuple[bool, str]:
        """重置API Key状态"""
        success = await reset_key_status(provider_name, api_key)
        if success:
            if api_key:
                if len(api_key) > 8:
                    target = f"API Key '{api_key[:4]}...{api_key[-4:]}'"
                else:
                    target = f"API Key '{api_key}'"
            else:
                target = "所有API Keys"
            return True, f"✅ 成功重置提供商 '{provider_name}' 的 {target} 的状态。"
        else:
            return False, "❌ 重置失败,请检查提供商名称或API Key是否正确。"
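The tuple returned by `sort_key` above orders problem keys first. A self-contained sketch of that ordering, with fabricated stats and plain ints standing in for the `KeyStatus` IntEnum values:

```python
# Fabricated stats illustrating DataSource.get_key_status ordering:
# sort by status priority ascending (DISABLED=0 ... UNUSED=5),
# then by success rate descending, then by total calls descending.
stats_list = [
    {"key_id": "key-A", "status_enum": 4, "success_rate": 99.0, "total_calls": 120},
    {"key_id": "key-B", "status_enum": 1, "success_rate": 40.0, "total_calls": 30},
    {"key_id": "key-C", "status_enum": 4, "success_rate": 90.0, "total_calls": 500},
]


def sort_key(item: dict) -> tuple:
    return (item["status_enum"], 100 - item["success_rate"], -item["total_calls"])


print([i["key_id"] for i in sorted(stats_list, key=sort_key)])
# -> ['key-B', 'key-A', 'key-C']: the failing key surfaces first,
#    then healthy keys ordered by success rate.
```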
204 zhenxun/builtin_plugins/llm_manager/presenters.py (new file)
@@ -0,0 +1,204 @@
from typing import Any

from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.types import ModelModality
from zhenxun.utils._build_image import BuildImage
from zhenxun.utils._image_template import ImageTemplate, Markdown, RowStyle


def _format_seconds(seconds: int) -> str:
    """将秒数格式化为 'Xm Ys' 或 'Xh Ym' 的形式"""
    if seconds <= 0:
        return "0s"
    if seconds < 60:
        return f"{seconds}s"

    minutes, seconds = divmod(seconds, 60)
    if minutes < 60:
        return f"{minutes}m {seconds}s"

    hours, minutes = divmod(minutes, 60)
    return f"{hours}h {minutes}m"


class Presenters:
    """格式化LLM管理插件的输出 (图片格式)"""

    @staticmethod
    async def format_model_list_as_image(
        models: list[dict[str, Any]], show_all: bool
    ) -> BuildImage:
        """将模型列表格式化为表格图片"""
        title = "📋 LLM模型列表" + (" (所有已配置模型)" if show_all else " (仅可用)")

        if not models:
            return await BuildImage.build_text_image(
                f"{title}\n\n当前没有配置任何LLM模型。"
            )

        column_name = ["提供商", "模型名称", "API类型", "状态"]
        data_list = []
        for model in models:
            status_text = "✅ 可用" if model.get("is_available", True) else "❌ 不可用"
            embed_tag = " (Embed)" if model.get("is_embedding_model", False) else ""
            data_list.append(
                [
                    model.get("provider_name", "N/A"),
                    f"{model.get('model_name', 'N/A')}{embed_tag}",
                    model.get("api_type", "N/A"),
                    status_text,
                ]
            )

        return await ImageTemplate.table_page(
            head_text=title,
            tip_text="使用 `llm info <Provider/ModelName>` 查看详情",
            column_name=column_name,
            data_list=data_list,
        )

    @staticmethod
    async def format_model_details_as_markdown_image(details: dict[str, Any]) -> bytes:
        """将模型详情格式化为Markdown图片"""
        provider = details["provider_config"]
        model = details["model_detail"]
        caps = details["capabilities"]

        cap_list = []
        if ModelModality.IMAGE in caps.input_modalities:
            cap_list.append("视觉")
        if ModelModality.VIDEO in caps.input_modalities:
            cap_list.append("视频")
        if ModelModality.AUDIO in caps.input_modalities:
            cap_list.append("音频")
        if caps.supports_tool_calling:
            cap_list.append("工具调用")
        if caps.is_embedding_model:
            cap_list.append("文本嵌入")

        md = Markdown()
        md.head(f"🔎 模型详情: {provider.name}/{model.model_name}", level=1)
        md.text("---")
        md.head("提供商信息", level=2)
        md.list(
            [
                f"**名称**: {provider.name}",
                f"**API 类型**: {provider.api_type}",
                f"**API Base**: {provider.api_base or '默认'}",
            ]
        )
        md.head("模型详情", level=2)

        temp_value = model.temperature or provider.temperature or "未设置"
        token_value = model.max_tokens or provider.max_tokens or "未设置"

        md.list(
            [
                f"**名称**: {model.model_name}",
                f"**默认温度**: {temp_value}",
                f"**最大Token**: {token_value}",
                f"**核心能力**: {', '.join(cap_list) or '纯文本'}",
            ]
        )

        return await md.build()

    @staticmethod
    async def format_key_status_as_image(
        provider_name: str, sorted_stats: list[dict[str, Any]]
    ) -> BuildImage:
        """将已排序的、详细的API Key状态格式化为表格图片"""
        title = f"🔑 '{provider_name}' API Key 状态"

        if not sorted_stats:
            return await BuildImage.build_text_image(
                f"{title}\n\n该提供商没有配置API Keys。"
            )

        def _status_row_style(column: str, text: str) -> RowStyle:
            style = RowStyle()
            if column == "状态":
                if "✅ 健康" in text:
                    style.font_color = "#67C23A"
                elif "⚠️ 告警" in text:
                    style.font_color = "#E6A23C"
                elif "❌ 错误" in text or "🚫" in text:
                    style.font_color = "#F56C6C"
                elif "❄️ 冷却中" in text:
                    style.font_color = "#409EFF"
            elif column == "成功率":
                try:
                    if text != "N/A":
                        rate = float(text.replace("%", ""))
                        if rate < 80:
                            style.font_color = "#F56C6C"
                        elif rate < 95:
                            style.font_color = "#E6A23C"
                except (ValueError, TypeError):
                    pass
            return style

        column_name = [
            "Key (部分)",
            "状态",
            "总调用",
            "成功率",
            "平均延迟(s)",
            "上次错误",
            "建议操作",
        ]
        data_list = []

        for key_info in sorted_stats:
            status_enum: KeyStatus = key_info["status_enum"]

            if status_enum == KeyStatus.COOLDOWN:
                cooldown_seconds = int(key_info["cooldown_seconds_left"])
                formatted_time = _format_seconds(cooldown_seconds)
                status_text = f"❄️ 冷却中({formatted_time})"
            else:
                status_text = {
                    KeyStatus.DISABLED: "🚫 永久禁用",
                    KeyStatus.ERROR: "❌ 错误",
                    KeyStatus.WARNING: "⚠️ 告警",
                    KeyStatus.HEALTHY: "✅ 健康",
                    KeyStatus.UNUSED: "⚪️ 未使用",
                }.get(status_enum, "❔ 未知")

            total_calls = key_info["total_calls"]
            total_calls_text = (
                f"{key_info['success_count']}/{total_calls}"
                if total_calls > 0
                else "0/0"
            )

            success_rate = key_info["success_rate"]
            success_rate_text = f"{success_rate:.1f}%" if total_calls > 0 else "N/A"

            avg_latency = key_info["avg_latency"]
            avg_latency_text = f"{avg_latency / 1000:.2f}" if avg_latency > 0 else "N/A"

            last_error = key_info.get("last_error") or "-"
            if len(last_error) > 25:
                last_error = last_error[:22] + "..."

            data_list.append(
                [
                    key_info["key_id"],
                    status_text,
                    total_calls_text,
                    success_rate_text,
                    avg_latency_text,
                    last_error,
                    key_info["suggested_action"],
                ]
            )

        return await ImageTemplate.table_page(
            head_text=title,
            tip_text="使用 `llm reset-key <Provider>` 重置Key状态",
            column_name=column_name,
            data_list=data_list,
            text_style=_status_row_style,
            column_space=15,
        )
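For reference, the `_format_seconds` helper at the top of this file behaves as follows (assuming it is imported from the module above):

```python
# Expected outputs of presenters._format_seconds for a few sample inputs.
assert _format_seconds(0) == "0s"
assert _format_seconds(45) == "45s"
assert _format_seconds(125) == "2m 5s"    # divmod(125, 60) -> (2, 5)
assert _format_seconds(3725) == "1h 2m"   # 62 minutes -> 1h 2m (seconds dropped)
```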
@@ -21,11 +21,28 @@ require("nonebot_plugin_waiter")
 from .db_context import Model, disconnect
 from .llm import (
     AI,
+    AIConfig,
+    CommonOverrides,
     LLMContentPart,
     LLMException,
+    LLMGenerationConfig,
     LLMMessage,
+    analyze,
+    analyze_multimodal,
+    chat,
+    clear_model_cache,
+    code,
+    create_multimodal_message,
+    embed,
+    generate,
+    get_cache_stats,
     get_model_instance,
     list_available_models,
+    list_embedding_models,
+    pipeline_chat,
+    search,
+    search_multimodal,
+    set_global_default_model_name,
     tool_registry,
 )
 from .log import logger

@@ -34,16 +51,33 @@ from .scheduler import scheduler_manager

 __all__ = [
     "AI",
+    "AIConfig",
+    "CommonOverrides",
     "LLMContentPart",
     "LLMException",
+    "LLMGenerationConfig",
     "LLMMessage",
     "Model",
     "PluginInit",
     "PluginInitManager",
+    "analyze",
+    "analyze_multimodal",
+    "chat",
+    "clear_model_cache",
+    "code",
+    "create_multimodal_message",
     "disconnect",
+    "embed",
+    "generate",
+    "get_cache_stats",
     "get_model_instance",
     "list_available_models",
+    "list_embedding_models",
     "logger",
+    "pipeline_chat",
     "scheduler_manager",
+    "search",
+    "search_multimodal",
+    "set_global_default_model_name",
     "tool_registry",
 ]
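With the re-exports above, plugins can import the LLM helpers directly from the services package. A small sketch:

```python
# These names are now re-exported at the services level (see __all__ above).
from zhenxun.services import AI, AIConfig, chat, generate
```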
@@ -198,7 +198,7 @@ print(search_result['text'])
 When you need contextual, multi-turn conversation, the `AI` class is your best choice.

 ```python
-from zhenxun.services.llm.api import AI, AIConfig
+from zhenxun.services.llm import AI, AIConfig

 # Initialize an AI session; a custom configuration can be passed in
 ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7)

@@ -395,7 +395,7 @@ async def my_tool_factory(config: MyToolConfig):
 Use the `use_tools` parameter in `analyze` or `generate_response`; the framework automatically handles the entire call flow.

 ```python
-from zhenxun.services.llm.api import analyze
+from zhenxun.services.llm import analyze
 from nonebot_plugin_alconna.uniseg import UniMessage

 response = await analyze(

@@ -442,7 +442,6 @@ from zhenxun.services.llm.manager import (
     get_key_usage_stats,
     reset_key_status
 )
-from zhenxun.services.llm import clear_model_cache, get_cache_stats

 # List all available models configured in config.yaml
 models = list_available_models()
@@ -5,14 +5,12 @@ LLM 服务模块 - 公共 API 入口
 """

 from .api import (
-    AI,
-    AIConfig,
-    TaskType,
     analyze,
     analyze_multimodal,
     chat,
     code,
     embed,
+    generate,
     pipeline_chat,
     search,
     search_multimodal,

@@ -35,6 +33,7 @@ from .manager import (
     list_model_identifiers,
     set_global_default_model_name,
 )
+from .session import AI, AIConfig
 from .tools import tool_registry
 from .types import (
     EmbeddingTaskType,

@@ -49,6 +48,7 @@ from .types import (
     ModelInfo,
     ModelProvider,
     ResponseFormat,
+    TaskType,
     ToolCategory,
     ToolMetadata,
     UsageInfo,

@@ -84,6 +84,7 @@ __all__ = [
     "code",
     "create_multimodal_message",
     "embed",
+    "generate",
     "get_cache_stats",
     "get_global_default_model_name",
     "get_model_instance",
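The public import surface stays stable for callers even though `AI`/`AIConfig` now live in `session.py` and `TaskType` now comes from `types`. A sketch:

```python
# The same imports work before and after the refactor:
from zhenxun.services.llm import AI, AIConfig, TaskType, generate
```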
@@ -1,10 +1,7 @@
 """
-LLM 服务的高级 API 接口
+LLM 服务的高级 API 接口 - 便捷函数入口
 """

-import copy
-from dataclasses import dataclass
-from enum import Enum
 from pathlib import Path
 from typing import Any
@@ -12,10 +9,8 @@ from nonebot_plugin_alconna.uniseg import UniMessage

 from zhenxun.services.log import logger

-from .config import CommonOverrides, LLMGenerationConfig
-from .config.providers import get_ai_config
-from .manager import get_global_default_model_name, get_model_instance
-from .tools import tool_registry
+from .manager import get_model_instance
+from .session import AI
 from .types import (
     EmbeddingTaskType,
     LLMContentPart,
@@ -29,514 +24,31 @@ from .types import (
 from .utils import create_multimodal_message, unimsg_to_llm_parts


-class TaskType(Enum):
-    """任务类型枚举"""
-
-    CHAT = "chat"
-    CODE = "code"
-    SEARCH = "search"
-    ANALYSIS = "analysis"
-    GENERATION = "generation"
-    MULTIMODAL = "multimodal"
-
-
-@dataclass
-class AIConfig:
-    """AI配置类 - 简化版本"""
-
-    model: ModelName = None
-    default_embedding_model: ModelName = None
-    temperature: float | None = None
-    max_tokens: int | None = None
-    enable_cache: bool = False
-    enable_code: bool = False
-    enable_search: bool = False
-    timeout: int | None = None
-
-    enable_gemini_json_mode: bool = False
-    enable_gemini_thinking: bool = False
-    enable_gemini_safe_mode: bool = False
-    enable_gemini_multimodal: bool = False
-    enable_gemini_grounding: bool = False
-    default_preserve_media_in_history: bool = False
-
-    def __post_init__(self):
-        """初始化后从配置中读取默认值"""
-        ai_config = get_ai_config()
-        if self.model is None:
-            self.model = ai_config.get("default_model_name")
-        if self.timeout is None:
-            self.timeout = ai_config.get("timeout", 180)
-
-
-class AI:
-    """统一的AI服务类 - 平衡设计版本
-
-    提供三层API:
-    1. 简单方法:ai.chat(), ai.code(), ai.search()
-    2. 标准方法:ai.analyze() 支持复杂参数
-    3. 高级方法:通过get_model_instance()直接访问
-    """
-
-    def __init__(
-        self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
-    ):
-        """
-        初始化AI服务
-
-        参数:
-            config: AI 配置.
-            history: 可选的初始对话历史.
-        """
-        self.config = config or AIConfig()
-        self.history = history or []
-
-    def clear_history(self):
-        """清空当前会话的历史记录"""
-        self.history = []
-        logger.info("AI session history cleared.")
-
-    def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
-        """
-        净化用于存入历史记录的消息。
-        将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
-        """
-        if not isinstance(message.content, list):
-            return message
-
-        sanitized_message = copy.deepcopy(message)
-        content_list = sanitized_message.content
-        if not isinstance(content_list, list):
-            return sanitized_message
-
-        new_content_parts: list[LLMContentPart] = []
-        has_multimodal_content = False
-
-        for part in content_list:
-            if isinstance(part, LLMContentPart) and part.type == "text":
-                new_content_parts.append(part)
-            else:
-                has_multimodal_content = True
-
-        if has_multimodal_content:
-            placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
-            text_part_found = False
-            for part in new_content_parts:
-                if part.type == "text":
-                    part.text = f"{placeholder} {part.text or ''}".strip()
-                    text_part_found = True
-                    break
-            if not text_part_found:
-                new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
-
-        sanitized_message.content = new_content_parts
-        return sanitized_message
-
-    async def chat(
-        self,
-        message: str | LLMMessage | list[LLMContentPart],
-        *,
-        model: ModelName = None,
-        preserve_media_in_history: bool | None = None,
-        **kwargs: Any,
-    ) -> str:
-        """
-        进行一次聊天对话。
-        此方法会自动使用和更新会话内的历史记录。
-
-        参数:
-            message: 用户输入的消息。
-            model: 本次对话要使用的模型。
-            preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
-                - True: 保留,用于深度多轮媒体分析。
-                - False: 不保留,替换为占位符,提高效率。
-                - None (默认): 使用AI实例配置的默认值。
-            **kwargs: 传递给模型的其他参数。
-
-        返回:
-            str: 模型的文本响应。
-        """
-        current_message: LLMMessage
-        if isinstance(message, str):
-            current_message = LLMMessage.user(message)
-        elif isinstance(message, list) and all(
-            isinstance(part, LLMContentPart) for part in message
-        ):
-            current_message = LLMMessage.user(message)
-        elif isinstance(message, LLMMessage):
-            current_message = message
-        else:
-            raise LLMException(
-                f"AI.chat 不支持的消息类型: {type(message)}. "
-                "请使用 str, LLMMessage, 或 list[LLMContentPart]. "
-                "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
-                code=LLMErrorCode.API_REQUEST_FAILED,
-            )
-
-        final_messages = [*self.history, current_message]
-
-        response = await self._execute_generation(
-            final_messages, model, "聊天失败", kwargs
-        )
-
-        should_preserve = (
-            preserve_media_in_history
-            if preserve_media_in_history is not None
-            else self.config.default_preserve_media_in_history
-        )
-
-        if should_preserve:
-            logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
-            self.history.append(current_message)
-        else:
-            logger.debug("高效模式:净化历史记录中的多模态消息。")
-            sanitized_user_message = self._sanitize_message_for_history(current_message)
-            self.history.append(sanitized_user_message)
-
-        self.history.append(LLMMessage.assistant_text_response(response.text))
-
-        return response.text
-
-    async def code(
-        self,
-        prompt: str,
-        *,
-        model: ModelName = None,
-        timeout: int | None = None,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        """
-        代码执行
-
-        参数:
-            prompt: 代码执行的提示词。
-            model: 要使用的模型名称。
-            timeout: 代码执行超时时间(秒)。
-            **kwargs: 传递给模型的其他参数。
-
-        返回:
-            dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
-        """
-        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
-
-        config = CommonOverrides.gemini_code_execution()
-        if timeout:
-            config.custom_params = config.custom_params or {}
-            config.custom_params["code_execution_timeout"] = timeout
-
-        messages = [LLMMessage.user(prompt)]
-
-        response = await self._execute_generation(
-            messages, resolved_model, "代码执行失败", kwargs, base_config=config
-        )
-
-        return {
-            "text": response.text,
-            "code_executions": response.code_executions or [],
-            "success": True,
-        }
-
-    async def search(
-        self,
-        query: str | UniMessage,
-        *,
-        model: ModelName = None,
-        instruction: str = "",
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        """
-        信息搜索 - 支持多模态输入
-
-        参数:
-            query: 搜索查询内容,支持文本或多模态消息。
-            model: 要使用的模型名称。
-            instruction: 搜索指令。
-            **kwargs: 传递给模型的其他参数。
-
-        返回:
-            dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
-        """
-        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
-        config = CommonOverrides.gemini_grounding()
-
-        if isinstance(query, str):
-            messages = [LLMMessage.user(query)]
-        elif isinstance(query, UniMessage):
-            content_parts = await unimsg_to_llm_parts(query)
-
-            final_messages: list[LLMMessage] = []
-            if instruction:
-                final_messages.append(LLMMessage.system(instruction))
-
-            if not content_parts:
-                if instruction:
-                    final_messages.append(LLMMessage.user(instruction))
-                else:
-                    raise LLMException(
-                        "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
-                    )
-            else:
-                final_messages.append(LLMMessage.user(content_parts))
-
-            messages = final_messages
-        else:
-            raise LLMException(
-                f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
-                code=LLMErrorCode.API_REQUEST_FAILED,
-            )
-
-        response = await self._execute_generation(
-            messages, resolved_model, "信息搜索失败", kwargs, base_config=config
-        )
-
-        result = {
-            "text": response.text,
-            "sources": [],
-            "queries": [],
-            "success": True,
-        }
-
-        if response.grounding_metadata:
-            result["sources"] = response.grounding_metadata.grounding_attributions or []
-            result["queries"] = response.grounding_metadata.web_search_queries or []
-
-        return result
-
-    async def analyze(
-        self,
-        message: UniMessage | None,
-        *,
-        instruction: str = "",
-        model: ModelName = None,
-        use_tools: list[str] | None = None,
-        tool_config: dict[str, Any] | None = None,
-        activated_tools: list[LLMTool] | None = None,
-        history: list[LLMMessage] | None = None,
-        **kwargs: Any,
-    ) -> LLMResponse:
-        """
-        内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
-
-        参数:
-            message: 要分析的消息内容(支持多模态)。
-            instruction: 分析指令。
-            model: 要使用的模型名称。
-            use_tools: 要使用的工具名称列表。
-            tool_config: 工具配置。
-            activated_tools: 已激活的工具列表。
-            history: 对话历史记录。
-            **kwargs: 传递给模型的其他参数。
-
-        返回:
-            LLMResponse: 模型的完整响应结果。
-        """
-        content_parts = await unimsg_to_llm_parts(message or UniMessage())
-
-        final_messages: list[LLMMessage] = []
-        if history:
-            final_messages.extend(history)
-
-        if instruction:
-            if not any(msg.role == "system" for msg in final_messages):
-                final_messages.insert(0, LLMMessage.system(instruction))
-
-        if not content_parts:
-            if instruction and not history:
-                final_messages.append(LLMMessage.user(instruction))
-            elif not history:
-                raise LLMException(
-                    "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
-                )
-        else:
-            final_messages.append(LLMMessage.user(content_parts))
-
-        llm_tools: list[LLMTool] | None = activated_tools
-        if not llm_tools and use_tools:
-            try:
-                llm_tools = tool_registry.get_tools(use_tools)
-                logger.debug(f"已从注册表加载工具定义: {use_tools}")
-            except ValueError as e:
-                raise LLMException(
-                    f"加载工具定义失败: {e}",
-                    code=LLMErrorCode.CONFIGURATION_ERROR,
-                    cause=e,
-                )
-
-        tool_choice = None
-        if tool_config:
-            mode = tool_config.get("mode", "auto")
-            if mode in ["auto", "any", "none"]:
-                tool_choice = mode
-
-        response = await self._execute_generation(
-            final_messages,
-            model,
-            "内容分析失败",
-            kwargs,
-            llm_tools=llm_tools,
-            tool_choice=tool_choice,
-        )
-
-        return response
-
-    async def _execute_generation(
-        self,
-        messages: list[LLMMessage],
-        model_name: ModelName,
-        error_message: str,
-        config_overrides: dict[str, Any],
-        llm_tools: list[LLMTool] | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        base_config: LLMGenerationConfig | None = None,
-    ) -> LLMResponse:
-        """通用的生成执行方法,封装模型获取和单次API调用"""
-        try:
-            resolved_model_name = self._resolve_model_name(
-                model_name or self.config.model
-            )
-            final_config_dict = self._merge_config(
-                config_overrides, base_config=base_config
-            )
-
-            async with await get_model_instance(
-                resolved_model_name, override_config=final_config_dict
-            ) as model_instance:
-                return await model_instance.generate_response(
-                    messages,
-                    tools=llm_tools,
-                    tool_choice=tool_choice,
-                )
-        except LLMException:
-            raise
-        except Exception as e:
-            logger.error(f"{error_message}: {e}", e=e)
-            raise LLMException(f"{error_message}: {e}", cause=e)
-
-    def _resolve_model_name(self, model_name: ModelName) -> str:
-        """解析模型名称"""
-        if model_name:
-            return model_name
-
-        default_model = get_global_default_model_name()
-        if default_model:
-            return default_model
-
-        raise LLMException(
-            "未指定模型名称且未设置全局默认模型",
-            code=LLMErrorCode.MODEL_NOT_FOUND,
-        )
-
-    def _merge_config(
-        self,
-        user_config: dict[str, Any],
-        base_config: LLMGenerationConfig | None = None,
-    ) -> dict[str, Any]:
-        """合并配置"""
-        final_config = {}
-        if base_config:
-            final_config.update(base_config.to_dict())
-
-        if self.config.temperature is not None:
-            final_config["temperature"] = self.config.temperature
-        if self.config.max_tokens is not None:
-            final_config["max_tokens"] = self.config.max_tokens
-
-        if self.config.enable_cache:
-            final_config["enable_caching"] = True
-        if self.config.enable_code:
-            final_config["enable_code_execution"] = True
-        if self.config.enable_search:
-            final_config["enable_grounding"] = True
-
-        if self.config.enable_gemini_json_mode:
-            final_config["response_mime_type"] = "application/json"
-        if self.config.enable_gemini_thinking:
-            final_config["thinking_budget"] = 0.8
-        if self.config.enable_gemini_safe_mode:
-            final_config["safety_settings"] = (
-                CommonOverrides.gemini_safe().safety_settings
-            )
-        if self.config.enable_gemini_multimodal:
-            final_config.update(CommonOverrides.gemini_multimodal().to_dict())
-        if self.config.enable_gemini_grounding:
-            final_config["enable_grounding"] = True
-
-        final_config.update(user_config)
-
-        return final_config
-
-    async def embed(
-        self,
-        texts: list[str] | str,
-        *,
-        model: ModelName = None,
-        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
-        **kwargs: Any,
-    ) -> list[list[float]]:
-        """
-        生成文本嵌入向量
-
-        参数:
-            texts: 要生成嵌入向量的文本或文本列表。
-            model: 要使用的嵌入模型名称。
-            task_type: 嵌入任务类型。
-            **kwargs: 传递给模型的其他参数。
-
-        返回:
-            list[list[float]]: 文本的嵌入向量列表。
-        """
-        if isinstance(texts, str):
-            texts = [texts]
-        if not texts:
-            return []
-
-        try:
-            resolved_model_str = (
-                model or self.config.default_embedding_model or self.config.model
-            )
-            if not resolved_model_str:
-                raise LLMException(
-                    "使用 embed 功能时必须指定嵌入模型名称,"
-                    "或在 AIConfig 中配置 default_embedding_model。",
-                    code=LLMErrorCode.MODEL_NOT_FOUND,
-                )
-            resolved_model_str = self._resolve_model_name(resolved_model_str)
-
-            async with await get_model_instance(
-                resolved_model_str,
-                override_config=None,
-            ) as embedding_model_instance:
-                return await embedding_model_instance.generate_embeddings(
-                    texts, task_type=task_type, **kwargs
-                )
-        except LLMException:
-            raise
-        except Exception as e:
-            logger.error(f"文本嵌入失败: {e}", e=e)
-            raise LLMException(
-                f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
-            )
-
-
 async def chat(
     message: str | LLMMessage | list[LLMContentPart],
     *,
     model: ModelName = None,
+    tools: list[LLMTool] | None = None,
+    tool_choice: str | dict[str, Any] | None = None,
     **kwargs: Any,
-) -> str:
+) -> LLMResponse:
     """
     聊天对话便捷函数

     参数:
         message: 用户输入的消息。
         model: 要使用的模型名称。
+        tools: 本次对话可用的工具列表。
+        tool_choice: 强制模型使用的工具。
         **kwargs: 传递给模型的其他参数。

     返回:
-        str: 模型的文本响应。
+        LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。
     """
     ai = AI()
-    return await ai.chat(message, model=model, **kwargs)
+    return await ai.chat(
+        message, model=model, tools=tools, tool_choice=tool_choice, **kwargs
+    )


 async def code(
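With the new signature, the module-level `chat()` can participate in tool calling. A hedged sketch follows: the "weather" tool is hypothetical, and the attribute that carries tool-call requests on `LLMResponse` is not shown in this diff.

```python
# Hedged sketch of the new chat() contract. "weather" is a hypothetical
# registered tool; only .text is confirmed on LLMResponse in this diff.
from zhenxun.services.llm import chat, tool_registry


async def ask_with_tools() -> str:
    tools = tool_registry.get_tools(["weather"])
    response = await chat("北京今天适合出门吗?", tools=tools, tool_choice="auto")
    return response.text
```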
@@ -730,12 +242,14 @@ async def pipeline_chat(
         raise ValueError("模型链`model_chain`不能为空。")

     current_content: str | list[LLMContentPart]
-    if isinstance(message, str):
+    if isinstance(message, UniMessage):
+        current_content = await unimsg_to_llm_parts(message)
+    elif isinstance(message, str):
         current_content = message
     elif isinstance(message, list):
         current_content = message
     else:
-        current_content = await unimsg_to_llm_parts(message)
+        raise TypeError(f"不支持的消息类型: {type(message)}")

     final_response: LLMResponse | None = None
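`pipeline_chat` now rejects unsupported input types instead of silently coercing them. A hedged sketch of the accepted inputs; the `model_chain` keyword and its values are assumed from the surrounding context, not confirmed by this diff.

```python
# Assumption-based sketch of pipeline_chat input handling after this change.
from nonebot_plugin_alconna.uniseg import UniMessage

from zhenxun.services.llm import pipeline_chat


async def pipeline_demo() -> None:
    # UniMessage inputs are converted to LLM content parts first:
    await pipeline_chat(UniMessage("描述这张图片"), model_chain=["Gemini/gemini-2.0-flash"])
    # str and list[LLMContentPart] pass through unchanged:
    await pipeline_chat("你好", model_chain=["Gemini/gemini-2.0-flash"])
    # Any other type now raises TypeError.
```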
@@ -787,3 +301,45 @@ async def pipeline_chat(
     )

     return final_response
+
+
+async def generate(
+    messages: list[LLMMessage],
+    *,
+    model: ModelName = None,
+    tools: list[LLMTool] | None = None,
+    tool_choice: str | dict[str, Any] | None = None,
+    **kwargs: Any,
+) -> LLMResponse:
+    """
+    根据完整的消息列表(包括系统指令)生成一次性响应。
+    这是一个便捷的函数,不使用或修改任何会话历史。
+
+    参数:
+        messages: 用于生成响应的完整消息列表。
+        model: 要使用的模型名称。
+        tools: 可用的工具列表。
+        tool_choice: 工具选择策略。
+        **kwargs: 传递给模型的其他参数。
+
+    返回:
+        LLMResponse: 模型的完整响应对象。
+    """
+    try:
+        ai_instance = AI()
+        resolved_model_name = ai_instance._resolve_model_name(model)
+        final_config_dict = ai_instance._merge_config(kwargs)
+
+        async with await get_model_instance(
+            resolved_model_name, override_config=final_config_dict
+        ) as model_instance:
+            return await model_instance.generate_response(
+                messages,
+                tools=tools,
+                tool_choice=tool_choice,
+            )
+    except LLMException:
+        raise
+    except Exception as e:
+        logger.error(f"生成响应失败: {e}", e=e)
+        raise LLMException(f"生成响应失败: {e}", cause=e)
@@ -17,6 +17,7 @@ from zhenxun.configs.utils import parse_as
 from zhenxun.services.log import logger
 from zhenxun.utils.manager.priority_manager import PriorityLifecycle

+from ..core import key_store
 from ..types.models import ModelDetail, ProviderConfig
@@ -502,12 +503,13 @@ def set_default_model(provider_model_name: str | None) -> bool:
 @PriorityLifecycle.on_startup(priority=10)
 async def _init_llm_config_on_startup():
     """
-    在服务启动时主动调用一次 get_llm_config,
-    以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
+    在服务启动时主动调用一次 get_llm_config 和 key_store.initialize,
+    以触发必要的初始化操作。
     """
-    logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
+    logger.info("正在初始化 LLM 配置并加载密钥状态...")
     try:
         get_llm_config()
-        logger.info("LLM 配置初始化完成。")
+        await key_store.initialize()
+        logger.info("LLM 配置和密钥状态初始化完成。")
     except Exception as e:
-        logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
+        logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e)
@@ -5,17 +5,27 @@ LLM 核心基础设施模块
 """

 import asyncio
+from dataclasses import asdict, dataclass
+from enum import IntEnum
+import json
+import os
+import time
 from typing import Any

+import aiofiles
 import httpx
+import nonebot
 from pydantic import BaseModel

+from zhenxun.configs.path_config import DATA_PATH
 from zhenxun.services.log import logger
 from zhenxun.utils.user_agent import get_user_agent

 from .types import ProviderConfig
 from .types.exceptions import LLMErrorCode, LLMException

+driver = nonebot.get_driver()
+

 class HttpClientConfig(BaseModel):
     """HTTP客户端配置"""
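The new `driver` handle, together with the `key_store` persistence methods added later in this file, suggests a shutdown hook that saves key state (the commit message promises "saved automatically on shutdown"). The registration itself is not visible in this capture, so the following is an assumption-based sketch:

```python
# Assumption-based sketch: persist key state on shutdown via the
# KeyStatusStore.shutdown() method added in this commit.
@driver.on_shutdown
async def _save_key_status_on_shutdown() -> None:
    await key_store.shutdown()
```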
@@ -194,6 +204,82 @@ async def create_llm_http_client(
     return LLMHttpClient(config)


+class KeyStatus(IntEnum):
+    """用于排序和展示的密钥状态枚举"""
+
+    DISABLED = 0
+    ERROR = 1
+    COOLDOWN = 2
+    WARNING = 3
+    HEALTHY = 4
+    UNUSED = 5
+
+
+@dataclass
+class KeyStats:
+    """单个API Key的详细状态和统计信息"""
+
+    cooldown_until: float = 0.0
+    success_count: int = 0
+    failure_count: int = 0
+    total_latency: float = 0.0
+    last_error_info: str | None = None
+
+    @property
+    def is_available(self) -> bool:
+        """检查Key当前是否可用"""
+        return time.time() >= self.cooldown_until
+
+    @property
+    def avg_latency(self) -> float:
+        """计算平均延迟"""
+        return (
+            self.total_latency / self.success_count if self.success_count > 0 else 0.0
+        )
+
+    @property
+    def success_rate(self) -> float:
+        """计算成功率"""
+        total = self.success_count + self.failure_count
+        return self.success_count / total * 100 if total > 0 else 100.0
+
+    @property
+    def status(self) -> KeyStatus:
+        """根据当前统计数据动态计算状态"""
+        now = time.time()
+        cooldown_left = max(0, self.cooldown_until - now)
+
+        if cooldown_left > 31536000 - 60:
+            return KeyStatus.DISABLED
+        if cooldown_left > 0:
+            return KeyStatus.COOLDOWN
+
+        total_calls = self.success_count + self.failure_count
+        if total_calls == 0:
+            return KeyStatus.UNUSED
+
+        if self.success_rate < 80:
+            return KeyStatus.ERROR
+
+        if total_calls >= 5 and self.avg_latency > 15000:
+            return KeyStatus.WARNING
+
+        return KeyStatus.HEALTHY
+
+    @property
+    def suggested_action(self) -> str:
+        """根据状态给出建议操作"""
+        status_actions = {
+            KeyStatus.DISABLED: "更换Key",
+            KeyStatus.ERROR: "检查网络/重置",
+            KeyStatus.COOLDOWN: "等待/重置",
+            KeyStatus.WARNING: "观察",
+            KeyStatus.HEALTHY: "-",
+            KeyStatus.UNUSED: "-",
+        }
+        return status_actions.get(self.status, "未知")
+
+
 class RetryConfig:
     """重试配置"""
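The status thresholds above can be exercised directly. A small sketch with fabricated values, assuming `KeyStats` is importable alongside `KeyStatus` from the same module:

```python
import time

from zhenxun.services.llm.core import KeyStats, KeyStatus

stats = KeyStats(success_count=7, failure_count=3)          # 70% success rate
assert stats.status is KeyStatus.ERROR                      # success_rate < 80

stats = KeyStats(success_count=6, total_latency=120000.0)   # avg 20000 ms over 6 calls
assert stats.status is KeyStatus.WARNING                    # >= 5 calls and avg > 15000 ms

stats = KeyStats(cooldown_until=time.time() + 60)
assert stats.status is KeyStatus.COOLDOWN
```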
@@ -236,25 +322,37 @@ async def with_smart_retry(
     last_exception: Exception | None = None
     failed_keys: set[str] = set()

+    model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None)
+    all_provider_keys = model_instance.api_keys if model_instance else []
+
     for attempt in range(config.max_retries + 1):
         try:
             if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
                 kwargs["failed_keys"] = failed_keys

-            return await func(*args, **kwargs)
+            start_time = time.monotonic()
+            result = await func(*args, **kwargs)
+            latency = (time.monotonic() - start_time) * 1000
+
+            if key_store and isinstance(result, tuple) and len(result) == 2:
+                final_result, api_key_used = result
+                if api_key_used:
+                    await key_store.record_success(api_key_used, latency)
+                return final_result
+            else:
+                return result
+
         except LLMException as e:
             last_exception = e
+            api_key_in_use = e.details.get("api_key")

-            if e.code in [
-                LLMErrorCode.API_KEY_INVALID,
-                LLMErrorCode.API_QUOTA_EXCEEDED,
-            ]:
-                if hasattr(e, "details") and e.details and "api_key" in e.details:
-                    failed_keys.add(e.details["api_key"])
-                    if key_store and provider_name:
-                        await key_store.record_failure(
-                            e.details["api_key"], e.details.get("status_code")
-                        )
+            if api_key_in_use:
+                failed_keys.add(api_key_in_use)
+                if key_store and provider_name and len(all_provider_keys) > 1:
+                    status_code = e.details.get("status_code")
+                    error_message = f"({e.code.name}) {e.message}"
+
+                    await key_store.record_failure(
+                        api_key_in_use, status_code, error_message
+                    )

             should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
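The success path above assumes the wrapped call (the `_perform_api_call` named in the commit message) returns a `(result, api_key_used)` tuple so the wrapper can attribute latency to a key. A toy restatement of that flow, with fabricated names:

```python
import time


async def fake_api_call() -> tuple[str, str]:
    # Stand-in for the wrapped call: returns (result, api_key_used).
    return "ok", "sk-demo-key"


async def demo(key_store) -> str:
    start = time.monotonic()
    result, api_key_used = await fake_api_call()
    latency = (time.monotonic() - start) * 1000  # milliseconds
    if api_key_used:
        await key_store.record_success(api_key_used, latency)
    return result
```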
@@ -267,7 +365,7 @@ async def with_smart_retry(
             if config.exponential_backoff:
                 wait_time *= 2**attempt
             logger.warning(
-                f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
+                f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}"
             )
             await asyncio.sleep(wait_time)
         else:
@@ -325,14 +423,66 @@ def _should_retry_llm_error(


 class KeyStatusStore:
-    """API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""
+    """API Key 状态管理存储 - 支持持久化"""

     def __init__(self):
-        self._key_status: dict[str, bool] = {}
-        self._key_usage_count: dict[str, int] = {}
-        self._key_last_used: dict[str, float] = {}
+        self._key_stats: dict[str, KeyStats] = {}
         self._provider_key_index: dict[str, int] = {}
         self._lock = asyncio.Lock()
+        self._file_path = DATA_PATH / "llm" / "key_status.json"
+
+    async def initialize(self):
+        """从文件异步加载密钥状态,在应用启动时调用"""
+        async with self._lock:
+            if not self._file_path.exists():
+                logger.info("未找到密钥状态文件,将使用内存状态启动。")
+                return
+
+            try:
+                logger.info(f"正在从 {self._file_path} 加载密钥状态...")
+                async with aiofiles.open(self._file_path, encoding="utf-8") as f:
+                    content = await f.read()
+                    if not content:
+                        logger.warning("密钥状态文件为空。")
+                        return
+                    data = json.loads(content)
+
+                for key, stats_dict in data.items():
+                    self._key_stats[key] = KeyStats(**stats_dict)
+
+                logger.info(f"成功加载 {len(self._key_stats)} 个密钥的状态。")
+
+            except json.JSONDecodeError:
+                logger.error(f"密钥状态文件 {self._file_path} 格式错误,无法解析。")
+            except Exception as e:
+                logger.error(f"加载密钥状态文件时发生错误: {e}", e=e)
+
+    async def _save_to_file_internal(self):
+        """
+        [内部方法] 将当前密钥状态安全地写入JSON文件。
+        假定调用方已持有锁。
+        """
+        data_to_save = {key: asdict(stats) for key, stats in self._key_stats.items()}
+
+        try:
+            self._file_path.parent.mkdir(parents=True, exist_ok=True)
+            temp_path = self._file_path.with_suffix(".json.tmp")
+
+            async with aiofiles.open(temp_path, "w", encoding="utf-8") as f:
+                await f.write(json.dumps(data_to_save, ensure_ascii=False, indent=2))
+
+            if self._file_path.exists():
+                self._file_path.unlink()
+            os.rename(temp_path, self._file_path)
+            logger.debug("密钥状态已成功持久化到文件。")
+        except Exception as e:
+            logger.error(f"保存密钥状态到文件失败: {e}", e=e)
+
+    async def shutdown(self):
+        """在应用关闭时安全地保存状态"""
+        async with self._lock:
+            await self._save_to_file_internal()
+        logger.info("KeyStatusStore 已在关闭前保存状态。")

     async def get_next_available_key(
         self,
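Given that `_save_to_file_internal` serializes with `asdict(KeyStats)`, the persisted file at `DATA_PATH/llm/key_status.json` should look roughly like this; the key string is fabricated:

```json
{
  "sk-demo-key": {
    "cooldown_until": 0.0,
    "success_count": 42,
    "failure_count": 1,
    "total_latency": 51234.5,
    "last_error_info": null
  }
}
```

The write-to-temp-then-rename pattern keeps the file from being left half-written if the process dies mid-save.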
@@ -355,88 +505,122 @@ class KeyStatusStore:
             return None

         exclude_keys = exclude_keys or set()

+        async with self._lock:
+            for key in api_keys:
+                if key not in self._key_stats:
+                    self._key_stats[key] = KeyStats()
+
             available_keys = [
                 key
                 for key in api_keys
-                if key not in exclude_keys and self._key_status.get(key, True)
+                if key not in exclude_keys and self._key_stats[key].is_available
             ]

             if not available_keys:
-                return api_keys[0] if api_keys else None
+                return api_keys[0]

-        async with self._lock:
             current_index = self._provider_key_index.get(provider_name, 0)
             selected_key = available_keys[current_index % len(available_keys)]
+            self._provider_key_index[provider_name] = current_index + 1

-            self._provider_key_index[provider_name] = (current_index + 1) % len(
-                available_keys
-            )
-
-            import time
-
-            self._key_usage_count[selected_key] = (
-                self._key_usage_count.get(selected_key, 0) + 1
-            )
-            self._key_last_used[selected_key] = time.time()
+            total_usage = (
+                self._key_stats[selected_key].success_count
+                + self._key_stats[selected_key].failure_count
+            )

             logger.debug(
                 f"轮询选择API密钥: {self._get_key_id(selected_key)} "
-                f"(使用次数: {self._key_usage_count[selected_key]})"
+                f"(使用次数: {total_usage})"
             )

             return selected_key

-    async def record_success(self, api_key: str):
-        """记录成功使用"""
+    async def record_success(self, api_key: str, latency: float):
+        """记录成功使用,并持久化"""
         async with self._lock:
-            self._key_status[api_key] = True
-            logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
+            stats = self._key_stats.setdefault(api_key, KeyStats())
+            stats.cooldown_until = 0.0
+            stats.success_count += 1
+            stats.total_latency += latency
+            stats.last_error_info = None
+            await self._save_to_file_internal()
+            logger.debug(
+                f"记录API密钥成功使用: {self._get_key_id(api_key)}, 延迟: {latency:.2f}ms"
+            )

-    async def record_failure(self, api_key: str, status_code: int | None):
+    async def record_failure(
+        self, api_key: str, status_code: int | None, error_message: str
+    ):
         """
-        记录失败使用
+        记录失败使用,并设置冷却时间

         参数:
             api_key: API密钥。
             status_code: HTTP状态码。
+            error_message: 错误信息。
         """
         key_id = self._get_key_id(api_key)
-        async with self._lock:
-            if status_code in [401, 403]:
-                self._key_status[api_key] = False
-                logger.warning(
-                    f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})"
-                )
-            else:
-                logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
+        now = time.time()
+        cooldown_duration = 300
+
+        if status_code in [401, 403, 404]:
+            cooldown_duration = 31536000
+            log_level = "error"
+            log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
+        elif status_code == 429:
+            cooldown_duration = 60
+            log_level = "warning"
+            log_message = f"API密钥被限流,冷却60秒: {key_id}"
+        else:
+            log_level = "warning"
+            log_message = f"API密钥遇到临时性错误,冷却{cooldown_duration}秒: {key_id}"
+
+        async with self._lock:
+            stats = self._key_stats.setdefault(api_key, KeyStats())
+            stats.cooldown_until = now + cooldown_duration
+            stats.failure_count += 1
+            stats.last_error_info = error_message[:256]
+            await self._save_to_file_internal()
+
+        getattr(logger, log_level)(log_message)
||||||
async def reset_key_status(self, api_key: str):
|
async def reset_key_status(self, api_key: str):
|
||||||
"""重置密钥状态(用于恢复机制)"""
|
"""重置密钥状态,并持久化"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._key_status[api_key] = True
|
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||||
|
stats.cooldown_until = 0.0
|
||||||
|
stats.last_error_info = None
|
||||||
|
await self._save_to_file_internal()
|
||||||
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
||||||
|
|
||||||
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
获取密钥使用统计
|
获取密钥使用统计,并计算出用于展示的派生数据。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
api_keys: API密钥列表。
|
api_keys: API密钥列表。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
dict[str, dict]: 密钥统计信息字典。
|
dict[str, dict]: 包含丰富状态和统计信息的密钥字典。
|
||||||
"""
|
"""
|
||||||
stats = {}
|
stats_dict = {}
|
||||||
|
now = time.time()
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
for key in api_keys:
|
for key in api_keys:
|
||||||
key_id = self._get_key_id(key)
|
key_id = self._get_key_id(key)
|
||||||
stats[key_id] = {
|
stats = self._key_stats.get(key, KeyStats())
|
||||||
"available": self._key_status.get(key, True),
|
|
||||||
"usage_count": self._key_usage_count.get(key, 0),
|
stats_dict[key_id] = {
|
||||||
"last_used": self._key_last_used.get(key, 0),
|
"status_enum": stats.status,
|
||||||
|
"cooldown_seconds_left": max(0, stats.cooldown_until - now),
|
||||||
|
"total_calls": stats.success_count + stats.failure_count,
|
||||||
|
"success_count": stats.success_count,
|
||||||
|
"failure_count": stats.failure_count,
|
||||||
|
"success_rate": stats.success_rate,
|
||||||
|
"avg_latency": stats.avg_latency,
|
||||||
|
"last_error": stats.last_error_info,
|
||||||
|
"suggested_action": stats.suggested_action,
|
||||||
}
|
}
|
||||||
return stats
|
return stats_dict
|
||||||
|
|
||||||
def _get_key_id(self, api_key: str) -> str:
|
def _get_key_id(self, api_key: str) -> str:
|
||||||
"""获取API密钥的标识符(用于日志)"""
|
"""获取API密钥的标识符(用于日志)"""
|
||||||
@ -446,3 +630,8 @@ class KeyStatusStore:
|
|||||||
|
|
||||||
|
|
||||||
key_store = KeyStatusStore()
|
key_store = KeyStatusStore()
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
|
async def _shutdown_key_store():
|
||||||
|
await key_store.shutdown()
|
||||||
|
|||||||
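Editor's note: the hunks above read and persist a `KeyStats` record (`success_count`, `failure_count`, `total_latency`, `cooldown_until`, `last_error_info`) plus derived properties (`is_available`, `success_rate`, `avg_latency`, `status`, `suggested_action`), but its definition lies outside the displayed hunks. A minimal sketch of what such a record could look like — the field names come from the diff, while the property bodies are assumptions, and `status`/`suggested_action` are omitted because their logic is not shown:

```python
# Illustrative sketch only; not the commit's actual KeyStats definition.
import time
from dataclasses import dataclass


@dataclass
class KeyStats:
    success_count: int = 0
    failure_count: int = 0
    total_latency: float = 0.0  # cumulative latency in ms, added by record_success
    cooldown_until: float = 0.0  # unix timestamp; 0.0 means no active cooldown
    last_error_info: str | None = None

    @property
    def is_available(self) -> bool:
        # mirrors the filter in get_next_available_key: usable once cooldown expires
        return time.time() >= self.cooldown_until

    @property
    def success_rate(self) -> float:
        total = self.success_count + self.failure_count
        return self.success_count / total if total else 1.0

    @property
    def avg_latency(self) -> float:
        # one latency sample is accumulated per successful call
        return self.total_latency / self.success_count if self.success_count else 0.0
```

Keeping the derived values as properties (rather than stored fields) fits the persistence scheme in the diff: `asdict(stats)` and `KeyStats(**stats_dict)` round-trip only the plain fields through `key_status.json`.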
@@ -137,8 +137,8 @@ def get_configured_providers() -> list[ProviderConfig]:
             valid_providers.append(item)
         else:
             logger.warning(
-                f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
-                f"已跳过。实际类型: {type(item)}"
+                f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,已跳过。"
+                f"实际类型: {type(item)}"
             )

     return valid_providers
@@ -46,17 +46,7 @@ class LLMModelBase(ABC):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> str:
-        """
-        生成文本
-
-        参数:
-            prompt: 输入提示词。
-            history: 对话历史记录。
-            **kwargs: 其他参数。
-
-        返回:
-            str: 生成的文本。
-        """
+        """生成文本"""
         pass

     @abstractmethod
@@ -68,19 +58,7 @@ class LLMModelBase(ABC):
         tool_choice: str | dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> LLMResponse:
-        """
-        生成高级响应
-
-        参数:
-            messages: 消息列表。
-            config: 生成配置。
-            tools: 工具列表。
-            tool_choice: 工具选择策略。
-            **kwargs: 其他参数。
-
-        返回:
-            LLMResponse: 模型响应。
-        """
+        """生成高级响应"""
         pass

     @abstractmethod
@@ -90,17 +68,7 @@ class LLMModelBase(ABC):
         task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
         **kwargs: Any,
     ) -> list[list[float]]:
-        """
-        生成文本嵌入向量
-
-        参数:
-            texts: 文本列表。
-            task_type: 嵌入任务类型。
-            **kwargs: 其他参数。
-
-        返回:
-            list[list[float]]: 嵌入向量列表。
-        """
+        """生成文本嵌入向量"""
         pass


@@ -208,28 +176,8 @@ class LLMModel(LLMModelBase):
         http_client: "LLMHttpClient",
         failed_keys: set[str] | None = None,
         log_context: str = "API",
-    ) -> Any:
-        """
-        执行API调用的通用核心方法。
-
-        该方法封装了以下通用逻辑:
-        1. 选择API密钥。
-        2. 准备和记录请求。
-        3. 发送HTTP POST请求。
-        4. 处理HTTP错误和API特定错误。
-        5. 记录密钥使用状态。
-        6. 解析成功的响应。
-
-        参数:
-            prepare_request_func: 准备请求的函数。
-            parse_response_func: 解析响应的函数。
-            http_client: HTTP客户端。
-            failed_keys: 失败的密钥集合。
-            log_context: 日志上下文。
-
-        返回:
-            Any: 解析后的响应数据。
-        """
+    ) -> tuple[Any, str]:
+        """执行API调用的通用核心方法"""
         api_key = await self._select_api_key(failed_keys)

         try:
@@ -267,7 +215,9 @@ class LLMModel(LLMModelBase):
                 )
                 logger.debug(f"💥 完整错误响应: {error_text}")

-                await self.key_store.record_failure(api_key, http_response.status_code)
+                await self.key_store.record_failure(
+                    api_key, http_response.status_code, error_text
+                )

                 if http_response.status_code in [401, 403]:
                     error_code = LLMErrorCode.API_KEY_INVALID
@@ -298,7 +248,7 @@ class LLMModel(LLMModelBase):

             except Exception as e:
                 logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
-                await self.key_store.record_failure(api_key, None)
+                await self.key_store.record_failure(api_key, None, str(e))
                 if isinstance(e, LLMException):
                     raise
                 else:
@@ -308,17 +258,15 @@ class LLMModel(LLMModelBase):
                         cause=e,
                     )

-            await self.key_store.record_success(api_key)
-            logger.debug(f"✅ API密钥使用成功: {masked_key}")
             logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
-            return parsed_data
+            return parsed_data, api_key

         except LLMException:
             raise
         except Exception as e:
             error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
             logger.error(error_log_msg, e=e)
-            await self.key_store.record_failure(api_key, None)
+            await self.key_store.record_failure(api_key, None, str(e))
             raise LLMException(
                 error_log_msg,
                 code=LLMErrorCode.GENERATION_FAILED
@@ -349,13 +297,14 @@ class LLMModel(LLMModelBase):
             adapter.validate_embedding_response(response_json)
             return adapter.parse_embedding_response(response_json)

-        return await self._perform_api_call(
+        parsed_data, api_key_used = await self._perform_api_call(
             prepare_request_func=prepare_request,
             parse_response_func=parse_response,
             http_client=http_client,
             failed_keys=failed_keys,
             log_context="Embedding",
         )
+        return parsed_data

     async def _execute_with_smart_retry(
         self,
@@ -394,8 +343,8 @@ class LLMModel(LLMModelBase):
         tool_choice: str | dict[str, Any] | None,
         http_client: LLMHttpClient,
         failed_keys: set[str] | None = None,
-    ) -> LLMResponse:
-        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
+    ) -> tuple[LLMResponse, str]:
+        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse 和使用的 key"""

         async def prepare_request(api_key: str) -> RequestData:
             return await adapter.prepare_advanced_request(
@@ -441,19 +390,17 @@ class LLMModel(LLMModelBase):
                 cache_info=response_data.cache_info,
             )

-        return await self._perform_api_call(
+        parsed_data, api_key_used = await self._perform_api_call(
             prepare_request_func=prepare_request,
             parse_response_func=parse_response,
             http_client=http_client,
             failed_keys=failed_keys,
             log_context="Generation",
         )
+        return parsed_data, api_key_used

     async def close(self):
-        """
-        标记模型实例的当前使用周期结束。
-        共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。
-        """
+        """标记模型实例的当前使用周期结束"""
         if self._is_closed:
             return
         self._is_closed = True
@@ -487,17 +434,7 @@ class LLMModel(LLMModelBase):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> str:
-        """
-        生成文本 - 通过 generate_response 实现
-
-        参数:
-            prompt: 输入提示词。
-            history: 对话历史记录。
-            **kwargs: 其他参数。
-
-        返回:
-            str: 生成的文本。
-        """
+        """生成文本"""
         self._check_not_closed()

         messages: list[LLMMessage] = []
@@ -538,19 +475,7 @@ class LLMModel(LLMModelBase):
         tool_choice: str | dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> LLMResponse:
-        """
-        生成高级响应
-
-        参数:
-            messages: 消息列表。
-            config: 生成配置。
-            tools: 工具列表。
-            tool_choice: 工具选择策略。
-            **kwargs: 其他参数。
-
-        返回:
-            LLMResponse: 模型响应。
-        """
+        """生成高级响应"""
         self._check_not_closed()

         from .adapters import get_adapter_for_api_type
@@ -619,17 +544,7 @@ class LLMModel(LLMModelBase):
         task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
         **kwargs: Any,
     ) -> list[list[float]]:
-        """
-        生成文本嵌入向量
-
-        参数:
-            texts: 文本列表。
-            task_type: 嵌入任务类型。
-            **kwargs: 其他参数。
-
-        返回:
-            list[list[float]]: 嵌入向量列表。
-        """
+        """生成文本嵌入向量"""
         self._check_not_closed()
         if not texts:
             return []
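Editor's note: with `_perform_api_call` now returning `(parsed_data, api_key)` and no longer calling `record_success` itself, the retry layer becomes responsible for timing each attempt and reporting latency per key (the commit message's "with_smart_retry 记录API调用延迟并更新统计"). The body of `_execute_with_smart_retry` is not among the displayed hunks, so the following is only a hedged sketch of the timing pattern these signatures imply:

```python
# Sketch under assumptions: `attempt` stands in for the bound single-call
# coroutine (e.g. the _execute_once-style helper) and `key_store` for
# self.key_store; neither name is confirmed by the hunks above.
import time
from collections.abc import Awaitable, Callable


async def timed_attempt(
    attempt: Callable[[], Awaitable[tuple[object, str]]],
    key_store,
):
    start = time.monotonic()
    response, api_key = await attempt()  # (LLMResponse, key actually used)
    latency_ms = (time.monotonic() - start) * 1000
    # record_success's new signature carries the measured latency; failures
    # are already reported inside _perform_api_call with status code and
    # error text, so only the success path is recorded here.
    await key_store.record_success(api_key, latency_ms)
    return response
```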
532 zhenxun/services/llm/session.py Normal file
@@ -0,0 +1,532 @@
+"""
+LLM 服务 - 会话客户端
+
+提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。
+"""
+
+import copy
+from dataclasses import dataclass
+from typing import Any
+
+from nonebot_plugin_alconna.uniseg import UniMessage
+
+from zhenxun.services.log import logger
+
+from .config import CommonOverrides, LLMGenerationConfig
+from .config.providers import get_ai_config
+from .manager import get_global_default_model_name, get_model_instance
+from .tools import tool_registry
+from .types import (
+    EmbeddingTaskType,
+    LLMContentPart,
+    LLMErrorCode,
+    LLMException,
+    LLMMessage,
+    LLMResponse,
+    LLMTool,
+    ModelName,
+)
+from .utils import unimsg_to_llm_parts
+
+
+@dataclass
+class AIConfig:
+    """AI配置类 - 简化版本"""
+
+    model: ModelName = None
+    default_embedding_model: ModelName = None
+    temperature: float | None = None
+    max_tokens: int | None = None
+    enable_cache: bool = False
+    enable_code: bool = False
+    enable_search: bool = False
+    timeout: int | None = None
+
+    enable_gemini_json_mode: bool = False
+    enable_gemini_thinking: bool = False
+    enable_gemini_safe_mode: bool = False
+    enable_gemini_multimodal: bool = False
+    enable_gemini_grounding: bool = False
+    default_preserve_media_in_history: bool = False
+
+    def __post_init__(self):
+        """初始化后从配置中读取默认值"""
+        ai_config = get_ai_config()
+        if self.model is None:
+            self.model = ai_config.get("default_model_name")
+        if self.timeout is None:
+            self.timeout = ai_config.get("timeout", 180)
+
+
+class AI:
+    """统一的AI服务类 - 平衡设计版本
+
+    提供三层API:
+    1. 简单方法:ai.chat(), ai.code(), ai.search()
+    2. 标准方法:ai.analyze() 支持复杂参数
+    3. 高级方法:通过get_model_instance()直接访问
+    """
+
+    def __init__(
+        self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
+    ):
+        """
+        初始化AI服务
+
+        参数:
+            config: AI 配置.
+            history: 可选的初始对话历史.
+        """
+        self.config = config or AIConfig()
+        self.history = history or []
+
+    def clear_history(self):
+        """清空当前会话的历史记录"""
+        self.history = []
+        logger.info("AI session history cleared.")
+
+    def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
+        """
+        净化用于存入历史记录的消息。
+        将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
+        """
+        if not isinstance(message.content, list):
+            return message
+
+        sanitized_message = copy.deepcopy(message)
+        content_list = sanitized_message.content
+        if not isinstance(content_list, list):
+            return sanitized_message
+
+        new_content_parts: list[LLMContentPart] = []
+        has_multimodal_content = False
+
+        for part in content_list:
+            if isinstance(part, LLMContentPart) and part.type == "text":
+                new_content_parts.append(part)
+            else:
+                has_multimodal_content = True
+
+        if has_multimodal_content:
+            placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
+            text_part_found = False
+            for part in new_content_parts:
+                if part.type == "text":
+                    part.text = f"{placeholder} {part.text or ''}".strip()
+                    text_part_found = True
+                    break
+            if not text_part_found:
+                new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
+
+        sanitized_message.content = new_content_parts
+        return sanitized_message
+
+    async def chat(
+        self,
+        message: str | LLMMessage | list[LLMContentPart],
+        *,
+        model: ModelName = None,
+        preserve_media_in_history: bool | None = None,
+        tools: list[LLMTool] | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> LLMResponse:
+        """
+        进行一次聊天对话,支持工具调用。
+        此方法会自动使用和更新会话内的历史记录。
+
+        参数:
+            message: 用户输入的消息。
+            model: 本次对话要使用的模型。
+            preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
+                - True: 保留,用于深度多轮媒体分析。
+                - False: 不保留,替换为占位符,提高效率。
+                - None (默认): 使用AI实例配置的默认值。
+            tools: 本次对话可用的工具列表。
+            tool_choice: 强制模型使用的工具。
+            **kwargs: 传递给模型的其他生成参数。
+
+        返回:
+            LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。
+        """
+        current_message: LLMMessage
+        if isinstance(message, str):
+            current_message = LLMMessage.user(message)
+        elif isinstance(message, list) and all(
+            isinstance(part, LLMContentPart) for part in message
+        ):
+            current_message = LLMMessage.user(message)
+        elif isinstance(message, LLMMessage):
+            current_message = message
+        else:
+            raise LLMException(
+                f"AI.chat 不支持的消息类型: {type(message)}. "
+                "请使用 str, LLMMessage, 或 list[LLMContentPart]. "
+                "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
+                code=LLMErrorCode.API_REQUEST_FAILED,
+            )
+
+        final_messages = [*self.history, current_message]
+
+        response = await self._execute_generation(
+            messages=final_messages,
+            model_name=model,
+            error_message="聊天失败",
+            config_overrides=kwargs,
+            llm_tools=tools,
+            tool_choice=tool_choice,
+        )
+
+        should_preserve = (
+            preserve_media_in_history
+            if preserve_media_in_history is not None
+            else self.config.default_preserve_media_in_history
+        )
+
+        if should_preserve:
+            logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
+            self.history.append(current_message)
+        else:
+            logger.debug("高效模式:净化历史记录中的多模态消息。")
+            sanitized_user_message = self._sanitize_message_for_history(current_message)
+            self.history.append(sanitized_user_message)
+
+        self.history.append(
+            LLMMessage(
+                role="assistant", content=response.text, tool_calls=response.tool_calls
+            )
+        )
+
+        return response
+
+    async def code(
+        self,
+        prompt: str,
+        *,
+        model: ModelName = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        代码执行
+
+        参数:
+            prompt: 代码执行的提示词。
+            model: 要使用的模型名称。
+            timeout: 代码执行超时时间(秒)。
+            **kwargs: 传递给模型的其他参数。
+
+        返回:
+            dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
+        """
+        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
+
+        config = CommonOverrides.gemini_code_execution()
+        if timeout:
+            config.custom_params = config.custom_params or {}
+            config.custom_params["code_execution_timeout"] = timeout
+
+        messages = [LLMMessage.user(prompt)]
+
+        response = await self._execute_generation(
+            messages=messages,
+            model_name=resolved_model,
+            error_message="代码执行失败",
+            config_overrides=kwargs,
+            base_config=config,
+        )
+
+        return {
+            "text": response.text,
+            "code_executions": response.code_executions or [],
+            "success": True,
+        }
+
+    async def search(
+        self,
+        query: str | UniMessage,
+        *,
+        model: ModelName = None,
+        instruction: str = "",
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        信息搜索 - 支持多模态输入
+
+        参数:
+            query: 搜索查询内容,支持文本或多模态消息。
+            model: 要使用的模型名称。
+            instruction: 搜索指令。
+            **kwargs: 传递给模型的其他参数。
+
+        返回:
+            dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
+        """
+        from nonebot_plugin_alconna.uniseg import UniMessage
+
+        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
+        config = CommonOverrides.gemini_grounding()
+
+        if isinstance(query, str):
+            messages = [LLMMessage.user(query)]
+        elif isinstance(query, UniMessage):
+            content_parts = await unimsg_to_llm_parts(query)
+
+            final_messages: list[LLMMessage] = []
+            if instruction:
+                final_messages.append(LLMMessage.system(instruction))
+
+            if not content_parts:
+                if instruction:
+                    final_messages.append(LLMMessage.user(instruction))
+                else:
+                    raise LLMException(
+                        "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
+                    )
+            else:
+                final_messages.append(LLMMessage.user(content_parts))
+
+            messages = final_messages
+        else:
+            raise LLMException(
+                f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
+                code=LLMErrorCode.API_REQUEST_FAILED,
+            )
+
+        response = await self._execute_generation(
+            messages=messages,
+            model_name=resolved_model,
+            error_message="信息搜索失败",
+            config_overrides=kwargs,
+            base_config=config,
+        )
+
+        result = {
+            "text": response.text,
+            "sources": [],
+            "queries": [],
+            "success": True,
+        }
+
+        if response.grounding_metadata:
+            result["sources"] = response.grounding_metadata.grounding_attributions or []
+            result["queries"] = response.grounding_metadata.web_search_queries or []
+
+        return result
+
+    async def analyze(
+        self,
+        message: UniMessage | None,
+        *,
+        instruction: str = "",
+        model: ModelName = None,
+        use_tools: list[str] | None = None,
+        tool_config: dict[str, Any] | None = None,
+        activated_tools: list[LLMTool] | None = None,
+        history: list[LLMMessage] | None = None,
+        **kwargs: Any,
+    ) -> LLMResponse:
+        """
+        内容分析 - 接收 UniMessage 对象进行多模态分析和工具调用。
+
+        参数:
+            message: 要分析的消息内容(支持多模态)。
+            instruction: 分析指令。
+            model: 要使用的模型名称。
+            use_tools: 要使用的工具名称列表。
+            tool_config: 工具配置。
+            activated_tools: 已激活的工具列表。
+            history: 对话历史记录。
+            **kwargs: 传递给模型的其他参数。
+
+        返回:
+            LLMResponse: 模型的完整响应结果。
+        """
+        from nonebot_plugin_alconna.uniseg import UniMessage
+
+        content_parts = await unimsg_to_llm_parts(message or UniMessage())
+
+        final_messages: list[LLMMessage] = []
+        if history:
+            final_messages.extend(history)
+
+        if instruction:
+            if not any(msg.role == "system" for msg in final_messages):
+                final_messages.insert(0, LLMMessage.system(instruction))
+
+        if not content_parts:
+            if instruction and not history:
+                final_messages.append(LLMMessage.user(instruction))
+            elif not history:
+                raise LLMException(
+                    "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
+                )
+        else:
+            final_messages.append(LLMMessage.user(content_parts))
+
+        llm_tools: list[LLMTool] | None = activated_tools
+        if not llm_tools and use_tools:
+            try:
+                llm_tools = tool_registry.get_tools(use_tools)
+                logger.debug(f"已从注册表加载工具定义: {use_tools}")
+            except ValueError as e:
+                raise LLMException(
+                    f"加载工具定义失败: {e}",
+                    code=LLMErrorCode.CONFIGURATION_ERROR,
+                    cause=e,
+                )
+
+        tool_choice = None
+        if tool_config:
+            mode = tool_config.get("mode", "auto")
+            if mode in ["auto", "any", "none"]:
+                tool_choice = mode
+
+        response = await self._execute_generation(
+            messages=final_messages,
+            model_name=model,
+            error_message="内容分析失败",
+            config_overrides=kwargs,
+            llm_tools=llm_tools,
+            tool_choice=tool_choice,
+        )
+
+        return response
+
+    async def _execute_generation(
+        self,
+        messages: list[LLMMessage],
+        model_name: ModelName,
+        error_message: str,
+        config_overrides: dict[str, Any],
+        llm_tools: list[LLMTool] | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        base_config: LLMGenerationConfig | None = None,
+    ) -> LLMResponse:
+        """通用的生成执行方法,封装模型获取和单次API调用"""
+        try:
+            resolved_model_name = self._resolve_model_name(
+                model_name or self.config.model
+            )
+            final_config_dict = self._merge_config(
+                config_overrides, base_config=base_config
+            )
+
+            async with await get_model_instance(
+                resolved_model_name, override_config=final_config_dict
+            ) as model_instance:
+                return await model_instance.generate_response(
+                    messages,
+                    tools=llm_tools,
+                    tool_choice=tool_choice,
+                )
+        except LLMException:
+            raise
+        except Exception as e:
+            logger.error(f"{error_message}: {e}", e=e)
+            raise LLMException(f"{error_message}: {e}", cause=e)
+
+    def _resolve_model_name(self, model_name: ModelName) -> str:
+        """解析模型名称"""
+        if model_name:
+            return model_name
+
+        default_model = get_global_default_model_name()
+        if default_model:
+            return default_model
+
+        raise LLMException(
+            "未指定模型名称且未设置全局默认模型",
+            code=LLMErrorCode.MODEL_NOT_FOUND,
+        )
+
+    def _merge_config(
+        self,
+        user_config: dict[str, Any],
+        base_config: LLMGenerationConfig | None = None,
+    ) -> dict[str, Any]:
+        """合并配置"""
+        final_config = {}
+        if base_config:
+            final_config.update(base_config.to_dict())
+
+        if self.config.temperature is not None:
+            final_config["temperature"] = self.config.temperature
+        if self.config.max_tokens is not None:
+            final_config["max_tokens"] = self.config.max_tokens
+
+        if self.config.enable_cache:
+            final_config["enable_caching"] = True
+        if self.config.enable_code:
+            final_config["enable_code_execution"] = True
+        if self.config.enable_search:
+            final_config["enable_grounding"] = True
+
+        if self.config.enable_gemini_json_mode:
+            final_config["response_mime_type"] = "application/json"
+        if self.config.enable_gemini_thinking:
+            final_config["thinking_budget"] = 0.8
+        if self.config.enable_gemini_safe_mode:
+            final_config["safety_settings"] = (
+                CommonOverrides.gemini_safe().safety_settings
+            )
+        if self.config.enable_gemini_multimodal:
+            final_config.update(CommonOverrides.gemini_multimodal().to_dict())
+        if self.config.enable_gemini_grounding:
+            final_config["enable_grounding"] = True
+
+        final_config.update(user_config)
+
+        return final_config
+
+    async def embed(
+        self,
+        texts: list[str] | str,
+        *,
+        model: ModelName = None,
+        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
+        **kwargs: Any,
+    ) -> list[list[float]]:
+        """
+        生成文本嵌入向量
+
+        参数:
+            texts: 要生成嵌入向量的文本或文本列表。
+            model: 要使用的嵌入模型名称。
+            task_type: 嵌入任务类型。
+            **kwargs: 传递给模型的其他参数。
+
+        返回:
+            list[list[float]]: 文本的嵌入向量列表。
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+        if not texts:
+            return []
+
+        try:
+            resolved_model_str = (
+                model or self.config.default_embedding_model or self.config.model
+            )
+            if not resolved_model_str:
+                raise LLMException(
+                    "使用 embed 功能时必须指定嵌入模型名称,"
+                    "或在 AIConfig 中配置 default_embedding_model。",
+                    code=LLMErrorCode.MODEL_NOT_FOUND,
+                )
+            resolved_model_str = self._resolve_model_name(resolved_model_str)
+
+            async with await get_model_instance(
+                resolved_model_str,
+                override_config=None,
+            ) as embedding_model_instance:
+                return await embedding_model_instance.generate_embeddings(
+                    texts, task_type=task_type, **kwargs
+                )
+        except LLMException:
+            raise
+        except Exception as e:
+            logger.error(f"文本嵌入失败: {e}", e=e)
+            raise LLMException(
+                f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
+            )
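Editor's note: a usage sketch for the new session client. The model name and prompts below are illustrative and presume a matching provider is configured; `chat()` returning the full `LLMResponse` (rather than a bare string) is what makes the tool-call check possible:

```python
# Usage sketch only; model name and prompts are not from the commit.
from zhenxun.services.llm.session import AI, AIConfig


async def demo() -> None:
    ai = AI(AIConfig(model="Gemini/gemini-2.0-flash"))

    # chat() returns the full LLMResponse, so callers can inspect
    # tool-call requests alongside the generated text.
    response = await ai.chat("你好,请介绍一下你自己")
    print(response.text)
    if response.tool_calls:
        print("model requested tool calls:", response.tool_calls)

    # The session object carries history, so a follow-up reuses prior turns.
    follow_up = await ai.chat("再简短一点")
    print(follow_up.text)
```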
@@ -10,7 +10,13 @@ from .content import (
     LLMMessage,
     LLMResponse,
 )
-from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory
+from .enums import (
+    EmbeddingTaskType,
+    ModelProvider,
+    ResponseFormat,
+    TaskType,
+    ToolCategory,
+)
 from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
 from .models import (
     LLMCacheInfo,
@@ -52,6 +58,7 @@ __all__ = [
     "ModelProvider",
     "ProviderConfig",
     "ResponseFormat",
+    "TaskType",
     "ToolCategory",
     "ToolMetadata",
     "UsageInfo",
@@ -45,6 +45,17 @@ class ToolCategory(Enum):
     CUSTOM = auto()


+class TaskType(Enum):
+    """任务类型枚举"""
+
+    CHAT = "chat"
+    CODE = "code"
+    SEARCH = "search"
+    ANALYSIS = "analysis"
+    GENERATION = "generation"
+    MULTIMODAL = "multimodal"
+
+
 class LLMErrorCode(Enum):
     """LLM 服务相关的错误代码枚举"""

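Editor's note: since `TaskType` is re-exported from the `types` package (previous hunk), callers can resolve it by its string value without importing `.enums` directly — a minimal sketch:

```python
# Minimal sketch of resolving the new enum by value.
from zhenxun.services.llm.types import TaskType

task = TaskType("search")  # Enum lookup by value; ValueError on unknown strings
assert task is TaskType.SEARCH
print(task.name, task.value)  # SEARCH search
```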