Mirror of https://github.com/zhenxun-org/zhenxun_bot.git (synced 2025-12-14 21:52:56 +08:00)
✨ feat(llm): Add LLM model management plugin and enhance API key management (#1972)

🔧 New features:
- LLM model management plugin (builtin_plugins/llm_manager/)
  • llm list - list available models (rendered as an image)
  • llm info - show detailed model information (Markdown image)
  • llm default - manage the global default model
  • llm test - test model connectivity
  • llm keys - show API key status (table image with health, success rate, latency)
  • llm reset-key - reset the failure state of API keys

🏗️ Architecture refactoring:
- Session management: AI/AIConfig classes moved to a standalone session.py
- Type definitions: TaskType enum moved to types/enums.py
- API enhancements:
  • chat() now returns a full LLMResponse and supports tool calls
  • new generate() function for one-shot response generation
  • unified API-call core method _perform_api_call, which returns the API key used

🚀 Key management enhancements:
- Detailed status tracking: health, success rate, average latency, error info, suggested action
- State persistence: key status is loaded on startup and saved automatically on shutdown
- Smart cooldown policy: cooldown duration depends on the error type
- Latency monitoring: with_smart_retry records API call latency and updates statistics

Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
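The new public API described above is easiest to see in a short usage sketch. This is an illustrative example only, not part of the commit; it assumes the import paths and the model identifier that appear elsewhere in this diff (`zhenxun.services.llm`, `Gemini/gemini-2.0-flash`).

```python
from zhenxun.services.llm import LLMMessage, chat, generate


async def demo() -> None:
    # chat() now returns a full LLMResponse instead of a plain string,
    # so tool calls and metadata are available to the caller.
    response = await chat("你好", model="Gemini/gemini-2.0-flash")
    print(response.text)

    # generate() is the new stateless helper: it takes a complete message
    # list (including system instructions) and never touches session history.
    messages = [
        LLMMessage.system("You are a concise assistant."),
        LLMMessage.user("Summarize the API key changes."),
    ]
    result = await generate(messages, model="Gemini/gemini-2.0-flash")
    print(result.text)
```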
This commit is contained in:
parent 8649aaaa54
commit 46a0768a45
171 zhenxun/builtin_plugins/llm_manager/__init__.py (Normal file)
@ -0,0 +1,171 @@
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from .data_source import DataSource
|
||||
from .presenters import Presenters
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="LLM模型管理",
|
||||
description="查看和管理大语言模型服务。",
|
||||
usage="""
|
||||
LLM模型管理 (SUPERUSER)
|
||||
|
||||
llm list [--all]
|
||||
- 查看可用模型列表。
|
||||
- --all: 显示包括不可用在内的所有模型。
|
||||
|
||||
llm info <Provider/ModelName>
|
||||
- 查看指定模型的详细信息和能力。
|
||||
|
||||
llm default [Provider/ModelName]
|
||||
- 查看或设置全局默认模型。
|
||||
- 不带参数: 查看当前默认模型。
|
||||
- 带参数: 设置新的默认模型。
|
||||
- 例子: llm default Gemini/gemini-2.0-flash
|
||||
|
||||
llm test <Provider/ModelName>
|
||||
- 测试指定模型的连通性和API Key有效性。
|
||||
|
||||
llm keys <ProviderName>
|
||||
- 查看指定提供商的所有API Key状态。
|
||||
|
||||
llm reset-key <ProviderName> [--key <api_key>]
|
||||
- 重置提供商的所有或指定API Key的失败状态。
|
||||
""",
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="1.0.0",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
llm_cmd = on_alconna(
|
||||
Alconna(
|
||||
"llm",
|
||||
Subcommand("list", alias=["ls"], help_text="查看模型列表"),
|
||||
Subcommand("info", Args["model_name", str], help_text="查看模型详情"),
|
||||
Subcommand("default", Args["model_name?", str], help_text="查看或设置默认模型"),
|
||||
Subcommand(
|
||||
"test", Args["model_name", str], alias=["ping"], help_text="测试模型连通性"
|
||||
),
|
||||
Subcommand("keys", Args["provider_name", str], help_text="查看API密钥状态"),
|
||||
Subcommand(
|
||||
"reset-key",
|
||||
Args["provider_name", str],
|
||||
Option("--key", Args["api_key", str], help_text="指定要重置的API Key"),
|
||||
help_text="重置API Key状态",
|
||||
),
|
||||
Option("--all", action=store_true, help_text="显示所有条目"),
|
||||
),
|
||||
permission=SUPERUSER,
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
|
||||
@llm_cmd.assign("list")
|
||||
async def handle_list(arp: Arparma, show_all: Query[bool] = Query("all")):
|
||||
"""处理 'llm list' 命令"""
|
||||
logger.info("获取LLM模型列表", command="LLM Manage", session=arp.header_result)
|
||||
models = await DataSource.get_model_list(show_all=show_all.result)
|
||||
|
||||
image = await Presenters.format_model_list_as_image(models, show_all.result)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image))
|
||||
|
||||
|
||||
@llm_cmd.assign("info")
|
||||
async def handle_info(arp: Arparma, model_name: Match[str]):
|
||||
"""处理 'llm info' 命令"""
|
||||
logger.info(
|
||||
f"获取模型详情: {model_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
details = await DataSource.get_model_details(model_name.result)
|
||||
if not details:
|
||||
await llm_cmd.finish(f"未找到模型: {model_name.result}")
|
||||
|
||||
image_bytes = await Presenters.format_model_details_as_markdown_image(details)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image_bytes))
|
||||
|
||||
|
||||
@llm_cmd.assign("default")
|
||||
async def handle_default(arp: Arparma, model_name: Match[str]):
|
||||
"""处理 'llm default' 命令"""
|
||||
if model_name.available:
|
||||
logger.info(
|
||||
f"设置默认模型为: {model_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
success, message = await DataSource.set_default_model(model_name.result)
|
||||
await llm_cmd.finish(message)
|
||||
else:
|
||||
logger.info("查看默认模型", command="LLM Manage", session=arp.header_result)
|
||||
current_default = await DataSource.get_default_model()
|
||||
await llm_cmd.finish(f"当前全局默认模型为: {current_default or '未设置'}")
|
||||
|
||||
|
||||
@llm_cmd.assign("test")
|
||||
async def handle_test(arp: Arparma, model_name: Match[str]):
|
||||
"""处理 'llm test' 命令"""
|
||||
logger.info(
|
||||
f"测试模型连通性: {model_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
await llm_cmd.send(f"正在测试模型 '{model_name.result}',请稍候...")
|
||||
|
||||
success, message = await DataSource.test_model_connectivity(model_name.result)
|
||||
await llm_cmd.finish(message)
|
||||
|
||||
|
||||
@llm_cmd.assign("keys")
|
||||
async def handle_keys(arp: Arparma, provider_name: Match[str]):
|
||||
"""处理 'llm keys' 命令"""
|
||||
logger.info(
|
||||
f"查看提供商API Key状态: {provider_name.result}",
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
sorted_stats = await DataSource.get_key_status(provider_name.result)
|
||||
if not sorted_stats:
|
||||
await llm_cmd.finish(
|
||||
f"未找到提供商 '{provider_name.result}' 或其没有配置API Keys。"
|
||||
)
|
||||
|
||||
image = await Presenters.format_key_status_as_image(
|
||||
provider_name.result, sorted_stats
|
||||
)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image))
|
||||
|
||||
|
||||
@llm_cmd.assign("reset-key")
|
||||
async def handle_reset_key(
|
||||
arp: Arparma, provider_name: Match[str], api_key: Match[str]
|
||||
):
|
||||
"""处理 'llm reset-key' 命令"""
|
||||
key_to_reset = api_key.result if api_key.available else None
|
||||
log_msg = f"重置 {provider_name.result} 的 " + (
|
||||
"指定API Key" if key_to_reset else "所有API Keys"
|
||||
)
|
||||
logger.info(log_msg, command="LLM Manage", session=arp.header_result)
|
||||
|
||||
success, message = await DataSource.reset_key(provider_name.result, key_to_reset)
|
||||
await llm_cmd.finish(message)
|
||||
120 zhenxun/builtin_plugins/llm_manager/data_source.py (Normal file)
@ -0,0 +1,120 @@
import time
from typing import Any

from zhenxun.services.llm import (
    LLMException,
    get_global_default_model_name,
    get_model_instance,
    list_available_models,
    set_global_default_model_name,
)
from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.manager import (
    reset_key_status,
)


class DataSource:
    """LLM管理插件的数据源和业务逻辑"""

    @staticmethod
    async def get_model_list(show_all: bool = False) -> list[dict[str, Any]]:
        """获取模型列表"""
        models = list_available_models()
        if show_all:
            return models
        return [m for m in models if m.get("is_available", True)]

    @staticmethod
    async def get_model_details(model_name_str: str) -> dict[str, Any] | None:
        """获取指定模型的详细信息"""
        try:
            model = await get_model_instance(model_name_str)
            return {
                "provider_config": model.provider_config,
                "model_detail": model.model_detail,
                "capabilities": model.capabilities,
            }
        except LLMException:
            return None

    @staticmethod
    async def get_default_model() -> str | None:
        """获取全局默认模型"""
        return get_global_default_model_name()

    @staticmethod
    async def set_default_model(model_name_str: str) -> tuple[bool, str]:
        """设置全局默认模型"""
        success = set_global_default_model_name(model_name_str)
        if success:
            return True, f"✅ 成功将默认模型设置为: {model_name_str}"
        else:
            return False, f"❌ 设置失败,模型 '{model_name_str}' 不存在或无效。"

    @staticmethod
    async def test_model_connectivity(model_name_str: str) -> tuple[bool, str]:
        """测试模型连通性"""
        start_time = time.monotonic()
        try:
            async with await get_model_instance(model_name_str) as model:
                await model.generate_text("你好")
            end_time = time.monotonic()
            latency = (end_time - start_time) * 1000
            return (
                True,
                f"✅ 模型 '{model_name_str}' 连接成功!\n响应延迟: {latency:.2f} ms",
            )
        except LLMException as e:
            return (
                False,
                f"❌ 模型 '{model_name_str}' 连接测试失败:\n"
                f"{e.user_friendly_message}\n错误码: {e.code.name}",
            )
        except Exception as e:
            return False, f"❌ 测试时发生未知错误: {e!s}"

    @staticmethod
    async def get_key_status(provider_name: str) -> list[dict[str, Any]] | None:
        """获取并排序指定提供商的API Key状态"""
        from zhenxun.services.llm.manager import get_key_usage_stats

        all_stats = await get_key_usage_stats()
        provider_stats = all_stats.get(provider_name)

        if not provider_stats or not provider_stats.get("key_stats"):
            return None

        key_stats_dict = provider_stats["key_stats"]

        stats_list = [
            {"key_id": key_id, **stats} for key_id, stats in key_stats_dict.items()
        ]

        def sort_key(item: dict[str, Any]):
            status_priority = item.get("status_enum", KeyStatus.UNUSED).value
            return (
                status_priority,
                100 - item.get("success_rate", 100.0),
                -item.get("total_calls", 0),
            )

        sorted_stats_list = sorted(stats_list, key=sort_key)

        return sorted_stats_list

    @staticmethod
    async def reset_key(provider_name: str, api_key: str | None) -> tuple[bool, str]:
        """重置API Key状态"""
        success = await reset_key_status(provider_name, api_key)
        if success:
            if api_key:
                if len(api_key) > 8:
                    target = f"API Key '{api_key[:4]}...{api_key[-4:]}'"
                else:
                    target = f"API Key '{api_key}'"
            else:
                target = "所有API Keys"
            return True, f"✅ 成功重置提供商 '{provider_name}' 的 {target} 的状态。"
        else:
            return False, "❌ 重置失败,请检查提供商名称或API Key是否正确。"
204 zhenxun/builtin_plugins/llm_manager/presenters.py (Normal file)
@ -0,0 +1,204 @@
|
||||
from typing import Any
|
||||
|
||||
from zhenxun.services.llm.core import KeyStatus
|
||||
from zhenxun.services.llm.types import ModelModality
|
||||
from zhenxun.utils._build_image import BuildImage
|
||||
from zhenxun.utils._image_template import ImageTemplate, Markdown, RowStyle
|
||||
|
||||
|
||||
def _format_seconds(seconds: int) -> str:
|
||||
"""将秒数格式化为 'Xm Ys' 或 'Xh Ym' 的形式"""
|
||||
if seconds <= 0:
|
||||
return "0s"
|
||||
if seconds < 60:
|
||||
return f"{seconds}s"
|
||||
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
if minutes < 60:
|
||||
return f"{minutes}m {seconds}s"
|
||||
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return f"{hours}h {minutes}m"
|
||||
|
||||
|
||||
class Presenters:
|
||||
"""格式化LLM管理插件的输出 (图片格式)"""
|
||||
|
||||
@staticmethod
|
||||
async def format_model_list_as_image(
|
||||
models: list[dict[str, Any]], show_all: bool
|
||||
) -> BuildImage:
|
||||
"""将模型列表格式化为表格图片"""
|
||||
title = "📋 LLM模型列表" + (" (所有已配置模型)" if show_all else " (仅可用)")
|
||||
|
||||
if not models:
|
||||
return await BuildImage.build_text_image(
|
||||
f"{title}\n\n当前没有配置任何LLM模型。"
|
||||
)
|
||||
|
||||
column_name = ["提供商", "模型名称", "API类型", "状态"]
|
||||
data_list = []
|
||||
for model in models:
|
||||
status_text = "✅ 可用" if model.get("is_available", True) else "❌ 不可用"
|
||||
embed_tag = " (Embed)" if model.get("is_embedding_model", False) else ""
|
||||
data_list.append(
|
||||
[
|
||||
model.get("provider_name", "N/A"),
|
||||
f"{model.get('model_name', 'N/A')}{embed_tag}",
|
||||
model.get("api_type", "N/A"),
|
||||
status_text,
|
||||
]
|
||||
)
|
||||
|
||||
return await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text="使用 `llm info <Provider/ModelName>` 查看详情",
|
||||
column_name=column_name,
|
||||
data_list=data_list,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def format_model_details_as_markdown_image(details: dict[str, Any]) -> bytes:
|
||||
"""将模型详情格式化为Markdown图片"""
|
||||
provider = details["provider_config"]
|
||||
model = details["model_detail"]
|
||||
caps = details["capabilities"]
|
||||
|
||||
cap_list = []
|
||||
if ModelModality.IMAGE in caps.input_modalities:
|
||||
cap_list.append("视觉")
|
||||
if ModelModality.VIDEO in caps.input_modalities:
|
||||
cap_list.append("视频")
|
||||
if ModelModality.AUDIO in caps.input_modalities:
|
||||
cap_list.append("音频")
|
||||
if caps.supports_tool_calling:
|
||||
cap_list.append("工具调用")
|
||||
if caps.is_embedding_model:
|
||||
cap_list.append("文本嵌入")
|
||||
|
||||
md = Markdown()
|
||||
md.head(f"🔎 模型详情: {provider.name}/{model.model_name}", level=1)
|
||||
md.text("---")
|
||||
md.head("提供商信息", level=2)
|
||||
md.list(
|
||||
[
|
||||
f"**名称**: {provider.name}",
|
||||
f"**API 类型**: {provider.api_type}",
|
||||
f"**API Base**: {provider.api_base or '默认'}",
|
||||
]
|
||||
)
|
||||
md.head("模型详情", level=2)
|
||||
|
||||
temp_value = model.temperature or provider.temperature or "未设置"
|
||||
token_value = model.max_tokens or provider.max_tokens or "未设置"
|
||||
|
||||
md.list(
|
||||
[
|
||||
f"**名称**: {model.model_name}",
|
||||
f"**默认温度**: {temp_value}",
|
||||
f"**最大Token**: {token_value}",
|
||||
f"**核心能力**: {', '.join(cap_list) or '纯文本'}",
|
||||
]
|
||||
)
|
||||
|
||||
return await md.build()
|
||||
|
||||
@staticmethod
|
||||
async def format_key_status_as_image(
|
||||
provider_name: str, sorted_stats: list[dict[str, Any]]
|
||||
) -> BuildImage:
|
||||
"""将已排序的、详细的API Key状态格式化为表格图片"""
|
||||
title = f"🔑 '{provider_name}' API Key 状态"
|
||||
|
||||
if not sorted_stats:
|
||||
return await BuildImage.build_text_image(
|
||||
f"{title}\n\n该提供商没有配置API Keys。"
|
||||
)
|
||||
|
||||
def _status_row_style(column: str, text: str) -> RowStyle:
|
||||
style = RowStyle()
|
||||
if column == "状态":
|
||||
if "✅ 健康" in text:
|
||||
style.font_color = "#67C23A"
|
||||
elif "⚠️ 告警" in text:
|
||||
style.font_color = "#E6A23C"
|
||||
elif "❌ 错误" in text or "🚫" in text:
|
||||
style.font_color = "#F56C6C"
|
||||
elif "❄️ 冷却中" in text:
|
||||
style.font_color = "#409EFF"
|
||||
elif column == "成功率":
|
||||
try:
|
||||
if text != "N/A":
|
||||
rate = float(text.replace("%", ""))
|
||||
if rate < 80:
|
||||
style.font_color = "#F56C6C"
|
||||
elif rate < 95:
|
||||
style.font_color = "#E6A23C"
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return style
|
||||
|
||||
column_name = [
|
||||
"Key (部分)",
|
||||
"状态",
|
||||
"总调用",
|
||||
"成功率",
|
||||
"平均延迟(s)",
|
||||
"上次错误",
|
||||
"建议操作",
|
||||
]
|
||||
data_list = []
|
||||
|
||||
for key_info in sorted_stats:
|
||||
status_enum: KeyStatus = key_info["status_enum"]
|
||||
|
||||
if status_enum == KeyStatus.COOLDOWN:
|
||||
cooldown_seconds = int(key_info["cooldown_seconds_left"])
|
||||
formatted_time = _format_seconds(cooldown_seconds)
|
||||
status_text = f"❄️ 冷却中({formatted_time})"
|
||||
else:
|
||||
status_text = {
|
||||
KeyStatus.DISABLED: "🚫 永久禁用",
|
||||
KeyStatus.ERROR: "❌ 错误",
|
||||
KeyStatus.WARNING: "⚠️ 告警",
|
||||
KeyStatus.HEALTHY: "✅ 健康",
|
||||
KeyStatus.UNUSED: "⚪️ 未使用",
|
||||
}.get(status_enum, "❔ 未知")
|
||||
|
||||
total_calls = key_info["total_calls"]
|
||||
total_calls_text = (
|
||||
f"{key_info['success_count']}/{total_calls}"
|
||||
if total_calls > 0
|
||||
else "0/0"
|
||||
)
|
||||
|
||||
success_rate = key_info["success_rate"]
|
||||
success_rate_text = f"{success_rate:.1f}%" if total_calls > 0 else "N/A"
|
||||
|
||||
avg_latency = key_info["avg_latency"]
|
||||
avg_latency_text = f"{avg_latency / 1000:.2f}" if avg_latency > 0 else "N/A"
|
||||
|
||||
last_error = key_info.get("last_error") or "-"
|
||||
if len(last_error) > 25:
|
||||
last_error = last_error[:22] + "..."
|
||||
|
||||
data_list.append(
|
||||
[
|
||||
key_info["key_id"],
|
||||
status_text,
|
||||
total_calls_text,
|
||||
success_rate_text,
|
||||
avg_latency_text,
|
||||
last_error,
|
||||
key_info["suggested_action"],
|
||||
]
|
||||
)
|
||||
|
||||
return await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text="使用 `llm reset-key <Provider>` 重置Key状态",
|
||||
column_name=column_name,
|
||||
data_list=data_list,
|
||||
text_style=_status_row_style,
|
||||
column_space=15,
|
||||
)
|
||||
@ -21,11 +21,28 @@ require("nonebot_plugin_waiter")
|
||||
from .db_context import Model, disconnect
|
||||
from .llm import (
|
||||
AI,
|
||||
AIConfig,
|
||||
CommonOverrides,
|
||||
LLMContentPart,
|
||||
LLMException,
|
||||
LLMGenerationConfig,
|
||||
LLMMessage,
|
||||
analyze,
|
||||
analyze_multimodal,
|
||||
chat,
|
||||
clear_model_cache,
|
||||
code,
|
||||
create_multimodal_message,
|
||||
embed,
|
||||
generate,
|
||||
get_cache_stats,
|
||||
get_model_instance,
|
||||
list_available_models,
|
||||
list_embedding_models,
|
||||
pipeline_chat,
|
||||
search,
|
||||
search_multimodal,
|
||||
set_global_default_model_name,
|
||||
tool_registry,
|
||||
)
|
||||
from .log import logger
|
||||
@ -34,16 +51,33 @@ from .scheduler import scheduler_manager
|
||||
|
||||
__all__ = [
|
||||
"AI",
|
||||
"AIConfig",
|
||||
"CommonOverrides",
|
||||
"LLMContentPart",
|
||||
"LLMException",
|
||||
"LLMGenerationConfig",
|
||||
"LLMMessage",
|
||||
"Model",
|
||||
"PluginInit",
|
||||
"PluginInitManager",
|
||||
"analyze",
|
||||
"analyze_multimodal",
|
||||
"chat",
|
||||
"clear_model_cache",
|
||||
"code",
|
||||
"create_multimodal_message",
|
||||
"disconnect",
|
||||
"embed",
|
||||
"generate",
|
||||
"get_cache_stats",
|
||||
"get_model_instance",
|
||||
"list_available_models",
|
||||
"list_embedding_models",
|
||||
"logger",
|
||||
"pipeline_chat",
|
||||
"scheduler_manager",
|
||||
"search",
|
||||
"search_multimodal",
|
||||
"set_global_default_model_name",
|
||||
"tool_registry",
|
||||
]
|
||||
|
||||
@ -198,7 +198,7 @@ print(search_result['text'])
|
||||
当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。
|
||||
|
||||
```python
|
||||
from zhenxun.services.llm.api import AI, AIConfig
|
||||
from zhenxun.services.llm import AI, AIConfig
|
||||
|
||||
# 初始化一个AI会话,可以传入自定义配置
|
||||
ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7)
|
||||
@ -395,7 +395,7 @@ async def my_tool_factory(config: MyToolConfig):
|
||||
在 `analyze` 或 `generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。
|
||||
|
||||
```python
|
||||
from zhenxun.services.llm.api import analyze
|
||||
from zhenxun.services.llm import analyze
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
response = await analyze(
|
||||
@ -442,7 +442,6 @@ from zhenxun.services.llm.manager import (
|
||||
get_key_usage_stats,
|
||||
reset_key_status
|
||||
)
|
||||
from zhenxun.services.llm import clear_model_cache, get_cache_stats
|
||||
|
||||
# 列出所有在config.yaml中配置的可用模型
|
||||
models = list_available_models()
|
||||
|
||||
@ -5,14 +5,12 @@ LLM 服务模块 - 公共 API 入口
|
||||
"""
|
||||
|
||||
from .api import (
|
||||
AI,
|
||||
AIConfig,
|
||||
TaskType,
|
||||
analyze,
|
||||
analyze_multimodal,
|
||||
chat,
|
||||
code,
|
||||
embed,
|
||||
generate,
|
||||
pipeline_chat,
|
||||
search,
|
||||
search_multimodal,
|
||||
@ -35,6 +33,7 @@ from .manager import (
|
||||
list_model_identifiers,
|
||||
set_global_default_model_name,
|
||||
)
|
||||
from .session import AI, AIConfig
|
||||
from .tools import tool_registry
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
@ -49,6 +48,7 @@ from .types import (
|
||||
ModelInfo,
|
||||
ModelProvider,
|
||||
ResponseFormat,
|
||||
TaskType,
|
||||
ToolCategory,
|
||||
ToolMetadata,
|
||||
UsageInfo,
|
||||
@ -84,6 +84,7 @@ __all__ = [
|
||||
"code",
|
||||
"create_multimodal_message",
|
||||
"embed",
|
||||
"generate",
|
||||
"get_cache_stats",
|
||||
"get_global_default_model_name",
|
||||
"get_model_instance",
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
"""
|
||||
LLM 服务的高级 API 接口
|
||||
LLM 服务的高级 API 接口 - 便捷函数入口
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -12,10 +9,8 @@ from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import CommonOverrides, LLMGenerationConfig
|
||||
from .config.providers import get_ai_config
|
||||
from .manager import get_global_default_model_name, get_model_instance
|
||||
from .tools import tool_registry
|
||||
from .manager import get_model_instance
|
||||
from .session import AI
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
LLMContentPart,
|
||||
@ -29,514 +24,31 @@ from .types import (
|
||||
from .utils import create_multimodal_message, unimsg_to_llm_parts
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""任务类型枚举"""
|
||||
|
||||
CHAT = "chat"
|
||||
CODE = "code"
|
||||
SEARCH = "search"
|
||||
ANALYSIS = "analysis"
|
||||
GENERATION = "generation"
|
||||
MULTIMODAL = "multimodal"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""AI配置类 - 简化版本"""
|
||||
|
||||
model: ModelName = None
|
||||
default_embedding_model: ModelName = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
enable_cache: bool = False
|
||||
enable_code: bool = False
|
||||
enable_search: bool = False
|
||||
timeout: int | None = None
|
||||
|
||||
enable_gemini_json_mode: bool = False
|
||||
enable_gemini_thinking: bool = False
|
||||
enable_gemini_safe_mode: bool = False
|
||||
enable_gemini_multimodal: bool = False
|
||||
enable_gemini_grounding: bool = False
|
||||
default_preserve_media_in_history: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后从配置中读取默认值"""
|
||||
ai_config = get_ai_config()
|
||||
if self.model is None:
|
||||
self.model = ai_config.get("default_model_name")
|
||||
if self.timeout is None:
|
||||
self.timeout = ai_config.get("timeout", 180)
|
||||
|
||||
|
||||
class AI:
|
||||
"""统一的AI服务类 - 平衡设计版本
|
||||
|
||||
提供三层API:
|
||||
1. 简单方法:ai.chat(), ai.code(), ai.search()
|
||||
2. 标准方法:ai.analyze() 支持复杂参数
|
||||
3. 高级方法:通过get_model_instance()直接访问
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
|
||||
):
|
||||
"""
|
||||
初始化AI服务
|
||||
|
||||
参数:
|
||||
config: AI 配置.
|
||||
history: 可选的初始对话历史.
|
||||
"""
|
||||
self.config = config or AIConfig()
|
||||
self.history = history or []
|
||||
|
||||
def clear_history(self):
|
||||
"""清空当前会话的历史记录"""
|
||||
self.history = []
|
||||
logger.info("AI session history cleared.")
|
||||
|
||||
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||
"""
|
||||
净化用于存入历史记录的消息。
|
||||
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||
"""
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
sanitized_message = copy.deepcopy(message)
|
||||
content_list = sanitized_message.content
|
||||
if not isinstance(content_list, list):
|
||||
return sanitized_message
|
||||
|
||||
new_content_parts: list[LLMContentPart] = []
|
||||
has_multimodal_content = False
|
||||
|
||||
for part in content_list:
|
||||
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||
new_content_parts.append(part)
|
||||
else:
|
||||
has_multimodal_content = True
|
||||
|
||||
if has_multimodal_content:
|
||||
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||
text_part_found = False
|
||||
for part in new_content_parts:
|
||||
if part.type == "text":
|
||||
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||
text_part_found = True
|
||||
break
|
||||
if not text_part_found:
|
||||
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||
|
||||
sanitized_message.content = new_content_parts
|
||||
return sanitized_message
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str | LLMMessage | list[LLMContentPart],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
preserve_media_in_history: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
进行一次聊天对话。
|
||||
此方法会自动使用和更新会话内的历史记录。
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 本次对话要使用的模型。
|
||||
preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
|
||||
- True: 保留,用于深度多轮媒体分析。
|
||||
- False: 不保留,替换为占位符,提高效率。
|
||||
- None (默认): 使用AI实例配置的默认值。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str: 模型的文本响应。
|
||||
"""
|
||||
current_message: LLMMessage
|
||||
if isinstance(message, str):
|
||||
current_message = LLMMessage.user(message)
|
||||
elif isinstance(message, list) and all(
|
||||
isinstance(part, LLMContentPart) for part in message
|
||||
):
|
||||
current_message = LLMMessage.user(message)
|
||||
elif isinstance(message, LLMMessage):
|
||||
current_message = message
|
||||
else:
|
||||
raise LLMException(
|
||||
f"AI.chat 不支持的消息类型: {type(message)}. "
|
||||
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
|
||||
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
|
||||
code=LLMErrorCode.API_REQUEST_FAILED,
|
||||
)
|
||||
|
||||
final_messages = [*self.history, current_message]
|
||||
|
||||
response = await self._execute_generation(
|
||||
final_messages, model, "聊天失败", kwargs
|
||||
)
|
||||
|
||||
should_preserve = (
|
||||
preserve_media_in_history
|
||||
if preserve_media_in_history is not None
|
||||
else self.config.default_preserve_media_in_history
|
||||
)
|
||||
|
||||
if should_preserve:
|
||||
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
|
||||
self.history.append(current_message)
|
||||
else:
|
||||
logger.debug("高效模式:净化历史记录中的多模态消息。")
|
||||
sanitized_user_message = self._sanitize_message_for_history(current_message)
|
||||
self.history.append(sanitized_user_message)
|
||||
|
||||
self.history.append(LLMMessage.assistant_text_response(response.text))
|
||||
|
||||
return response.text
|
||||
|
||||
async def code(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
代码执行
|
||||
|
||||
参数:
|
||||
prompt: 代码执行的提示词。
|
||||
model: 要使用的模型名称。
|
||||
timeout: 代码执行超时时间(秒)。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
|
||||
config = CommonOverrides.gemini_code_execution()
|
||||
if timeout:
|
||||
config.custom_params = config.custom_params or {}
|
||||
config.custom_params["code_execution_timeout"] = timeout
|
||||
|
||||
messages = [LLMMessage.user(prompt)]
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages, resolved_model, "代码执行失败", kwargs, base_config=config
|
||||
)
|
||||
|
||||
return {
|
||||
"text": response.text,
|
||||
"code_executions": response.code_executions or [],
|
||||
"success": True,
|
||||
}
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str | UniMessage,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
instruction: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
信息搜索 - 支持多模态输入
|
||||
|
||||
参数:
|
||||
query: 搜索查询内容,支持文本或多模态消息。
|
||||
model: 要使用的模型名称。
|
||||
instruction: 搜索指令。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
|
||||
"""
|
||||
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
|
||||
config = CommonOverrides.gemini_grounding()
|
||||
|
||||
if isinstance(query, str):
|
||||
messages = [LLMMessage.user(query)]
|
||||
elif isinstance(query, UniMessage):
|
||||
content_parts = await unimsg_to_llm_parts(query)
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
else:
|
||||
raise LLMException(
|
||||
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
messages = final_messages
|
||||
else:
|
||||
raise LLMException(
|
||||
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
|
||||
code=LLMErrorCode.API_REQUEST_FAILED,
|
||||
)
|
||||
|
||||
response = await self._execute_generation(
|
||||
messages, resolved_model, "信息搜索失败", kwargs, base_config=config
|
||||
)
|
||||
|
||||
result = {
|
||||
"text": response.text,
|
||||
"sources": [],
|
||||
"queries": [],
|
||||
"success": True,
|
||||
}
|
||||
|
||||
if response.grounding_metadata:
|
||||
result["sources"] = response.grounding_metadata.grounding_attributions or []
|
||||
result["queries"] = response.grounding_metadata.web_search_queries or []
|
||||
|
||||
return result
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
message: UniMessage | None,
|
||||
*,
|
||||
instruction: str = "",
|
||||
model: ModelName = None,
|
||||
use_tools: list[str] | None = None,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
activated_tools: list[LLMTool] | None = None,
|
||||
history: list[LLMMessage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。
|
||||
|
||||
参数:
|
||||
message: 要分析的消息内容(支持多模态)。
|
||||
instruction: 分析指令。
|
||||
model: 要使用的模型名称。
|
||||
use_tools: 要使用的工具名称列表。
|
||||
tool_config: 工具配置。
|
||||
activated_tools: 已激活的工具列表。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应结果。
|
||||
"""
|
||||
content_parts = await unimsg_to_llm_parts(message or UniMessage())
|
||||
|
||||
final_messages: list[LLMMessage] = []
|
||||
if history:
|
||||
final_messages.extend(history)
|
||||
|
||||
if instruction:
|
||||
if not any(msg.role == "system" for msg in final_messages):
|
||||
final_messages.insert(0, LLMMessage.system(instruction))
|
||||
|
||||
if not content_parts:
|
||||
if instruction and not history:
|
||||
final_messages.append(LLMMessage.user(instruction))
|
||||
elif not history:
|
||||
raise LLMException(
|
||||
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
|
||||
)
|
||||
else:
|
||||
final_messages.append(LLMMessage.user(content_parts))
|
||||
|
||||
llm_tools: list[LLMTool] | None = activated_tools
|
||||
if not llm_tools and use_tools:
|
||||
try:
|
||||
llm_tools = tool_registry.get_tools(use_tools)
|
||||
logger.debug(f"已从注册表加载工具定义: {use_tools}")
|
||||
except ValueError as e:
|
||||
raise LLMException(
|
||||
f"加载工具定义失败: {e}",
|
||||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
tool_choice = None
|
||||
if tool_config:
|
||||
mode = tool_config.get("mode", "auto")
|
||||
if mode in ["auto", "any", "none"]:
|
||||
tool_choice = mode
|
||||
|
||||
response = await self._execute_generation(
|
||||
final_messages,
|
||||
model,
|
||||
"内容分析失败",
|
||||
kwargs,
|
||||
llm_tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _execute_generation(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
model_name: ModelName,
|
||||
error_message: str,
|
||||
config_overrides: dict[str, Any],
|
||||
llm_tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> LLMResponse:
|
||||
"""通用的生成执行方法,封装模型获取和单次API调用"""
|
||||
try:
|
||||
resolved_model_name = self._resolve_model_name(
|
||||
model_name or self.config.model
|
||||
)
|
||||
final_config_dict = self._merge_config(
|
||||
config_overrides, base_config=base_config
|
||||
)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_name, override_config=final_config_dict
|
||||
) as model_instance:
|
||||
return await model_instance.generate_response(
|
||||
messages,
|
||||
tools=llm_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{error_message}: {e}", e=e)
|
||||
raise LLMException(f"{error_message}: {e}", cause=e)
|
||||
|
||||
def _resolve_model_name(self, model_name: ModelName) -> str:
|
||||
"""解析模型名称"""
|
||||
if model_name:
|
||||
return model_name
|
||||
|
||||
default_model = get_global_default_model_name()
|
||||
if default_model:
|
||||
return default_model
|
||||
|
||||
raise LLMException(
|
||||
"未指定模型名称且未设置全局默认模型",
|
||||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||||
)
|
||||
|
||||
def _merge_config(
|
||||
self,
|
||||
user_config: dict[str, Any],
|
||||
base_config: LLMGenerationConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""合并配置"""
|
||||
final_config = {}
|
||||
if base_config:
|
||||
final_config.update(base_config.to_dict())
|
||||
|
||||
if self.config.temperature is not None:
|
||||
final_config["temperature"] = self.config.temperature
|
||||
if self.config.max_tokens is not None:
|
||||
final_config["max_tokens"] = self.config.max_tokens
|
||||
|
||||
if self.config.enable_cache:
|
||||
final_config["enable_caching"] = True
|
||||
if self.config.enable_code:
|
||||
final_config["enable_code_execution"] = True
|
||||
if self.config.enable_search:
|
||||
final_config["enable_grounding"] = True
|
||||
|
||||
if self.config.enable_gemini_json_mode:
|
||||
final_config["response_mime_type"] = "application/json"
|
||||
if self.config.enable_gemini_thinking:
|
||||
final_config["thinking_budget"] = 0.8
|
||||
if self.config.enable_gemini_safe_mode:
|
||||
final_config["safety_settings"] = (
|
||||
CommonOverrides.gemini_safe().safety_settings
|
||||
)
|
||||
if self.config.enable_gemini_multimodal:
|
||||
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
|
||||
if self.config.enable_gemini_grounding:
|
||||
final_config["enable_grounding"] = True
|
||||
|
||||
final_config.update(user_config)
|
||||
|
||||
return final_config
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
texts: list[str] | str,
|
||||
*,
|
||||
model: ModelName = None,
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 要生成嵌入向量的文本或文本列表。
|
||||
model: 要使用的嵌入模型名称。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 文本的嵌入向量列表。
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
resolved_model_str = (
|
||||
model or self.config.default_embedding_model or self.config.model
|
||||
)
|
||||
if not resolved_model_str:
|
||||
raise LLMException(
|
||||
"使用 embed 功能时必须指定嵌入模型名称,"
|
||||
"或在 AIConfig 中配置 default_embedding_model。",
|
||||
code=LLMErrorCode.MODEL_NOT_FOUND,
|
||||
)
|
||||
resolved_model_str = self._resolve_model_name(resolved_model_str)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_str,
|
||||
override_config=None,
|
||||
) as embedding_model_instance:
|
||||
return await embedding_model_instance.generate_embeddings(
|
||||
texts, task_type=task_type, **kwargs
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"文本嵌入失败: {e}", e=e)
|
||||
raise LLMException(
|
||||
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
|
||||
)
|
||||
|
||||
|
||||
async def chat(
|
||||
message: str | LLMMessage | list[LLMContentPart],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
聊天对话便捷函数
|
||||
|
||||
参数:
|
||||
message: 用户输入的消息。
|
||||
model: 要使用的模型名称。
|
||||
tools: 本次对话可用的工具列表。
|
||||
tool_choice: 强制模型使用的工具。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
str: 模型的文本响应。
|
||||
LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。
|
||||
"""
|
||||
ai = AI()
|
||||
return await ai.chat(message, model=model, **kwargs)
|
||||
return await ai.chat(
|
||||
message, model=model, tools=tools, tool_choice=tool_choice, **kwargs
|
||||
)
|
||||
|
||||
|
||||
async def code(
|
||||
@ -730,12 +242,14 @@ async def pipeline_chat(
|
||||
raise ValueError("模型链`model_chain`不能为空。")
|
||||
|
||||
current_content: str | list[LLMContentPart]
|
||||
if isinstance(message, str):
|
||||
if isinstance(message, UniMessage):
|
||||
current_content = await unimsg_to_llm_parts(message)
|
||||
elif isinstance(message, str):
|
||||
current_content = message
|
||||
elif isinstance(message, list):
|
||||
current_content = message
|
||||
else:
|
||||
current_content = await unimsg_to_llm_parts(message)
|
||||
raise TypeError(f"不支持的消息类型: {type(message)}")
|
||||
|
||||
final_response: LLMResponse | None = None
|
||||
|
||||
@ -787,3 +301,45 @@ async def pipeline_chat(
|
||||
)
|
||||
|
||||
return final_response
|
||||
|
||||
|
||||
async def generate(
|
||||
messages: list[LLMMessage],
|
||||
*,
|
||||
model: ModelName = None,
|
||||
tools: list[LLMTool] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
根据完整的消息列表(包括系统指令)生成一次性响应。
|
||||
这是一个便捷的函数,不使用或修改任何会话历史。
|
||||
|
||||
参数:
|
||||
messages: 用于生成响应的完整消息列表。
|
||||
model: 要使用的模型名称。
|
||||
tools: 可用的工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 传递给模型的其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型的完整响应对象。
|
||||
"""
|
||||
try:
|
||||
ai_instance = AI()
|
||||
resolved_model_name = ai_instance._resolve_model_name(model)
|
||||
final_config_dict = ai_instance._merge_config(kwargs)
|
||||
|
||||
async with await get_model_instance(
|
||||
resolved_model_name, override_config=final_config_dict
|
||||
) as model_instance:
|
||||
return await model_instance.generate_response(
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"生成响应失败: {e}", e=e)
|
||||
raise LLMException(f"生成响应失败: {e}", cause=e)
|
||||
|
||||
@ -17,6 +17,7 @@ from zhenxun.configs.utils import parse_as
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from ..core import key_store
|
||||
from ..types.models import ModelDetail, ProviderConfig
|
||||
|
||||
|
||||
@ -502,12 +503,13 @@ def set_default_model(provider_model_name: str | None) -> bool:
|
||||
@PriorityLifecycle.on_startup(priority=10)
|
||||
async def _init_llm_config_on_startup():
|
||||
"""
|
||||
在服务启动时主动调用一次 get_llm_config,
|
||||
以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。
|
||||
在服务启动时主动调用一次 get_llm_config 和 key_store.initialize,
|
||||
以触发必要的初始化操作。
|
||||
"""
|
||||
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
|
||||
logger.info("正在初始化 LLM 配置并加载密钥状态...")
|
||||
try:
|
||||
get_llm_config()
|
||||
logger.info("LLM 配置初始化完成。")
|
||||
await key_store.initialize()
|
||||
logger.info("LLM 配置和密钥状态初始化完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
|
||||
logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e)
|
||||
|
||||
@ -5,17 +5,27 @@ LLM 核心基础设施模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import IntEnum
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
import nonebot
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.user_agent import get_user_agent
|
||||
|
||||
from .types import ProviderConfig
|
||||
from .types.exceptions import LLMErrorCode, LLMException
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
class HttpClientConfig(BaseModel):
|
||||
"""HTTP客户端配置"""
|
||||
@ -194,6 +204,82 @@ async def create_llm_http_client(
|
||||
return LLMHttpClient(config)
|
||||
|
||||
|
||||
class KeyStatus(IntEnum):
|
||||
"""用于排序和展示的密钥状态枚举"""
|
||||
|
||||
DISABLED = 0
|
||||
ERROR = 1
|
||||
COOLDOWN = 2
|
||||
WARNING = 3
|
||||
HEALTHY = 4
|
||||
UNUSED = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyStats:
|
||||
"""单个API Key的详细状态和统计信息"""
|
||||
|
||||
cooldown_until: float = 0.0
|
||||
success_count: int = 0
|
||||
failure_count: int = 0
|
||||
total_latency: float = 0.0
|
||||
last_error_info: str | None = None
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""检查Key当前是否可用"""
|
||||
return time.time() >= self.cooldown_until
|
||||
|
||||
@property
|
||||
def avg_latency(self) -> float:
|
||||
"""计算平均延迟"""
|
||||
return (
|
||||
self.total_latency / self.success_count if self.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""计算成功率"""
|
||||
total = self.success_count + self.failure_count
|
||||
return self.success_count / total * 100 if total > 0 else 100.0
|
||||
|
||||
@property
|
||||
def status(self) -> KeyStatus:
|
||||
"""根据当前统计数据动态计算状态"""
|
||||
now = time.time()
|
||||
cooldown_left = max(0, self.cooldown_until - now)
|
||||
|
||||
if cooldown_left > 31536000 - 60:
|
||||
return KeyStatus.DISABLED
|
||||
if cooldown_left > 0:
|
||||
return KeyStatus.COOLDOWN
|
||||
|
||||
total_calls = self.success_count + self.failure_count
|
||||
if total_calls == 0:
|
||||
return KeyStatus.UNUSED
|
||||
|
||||
if self.success_rate < 80:
|
||||
return KeyStatus.ERROR
|
||||
|
||||
if total_calls >= 5 and self.avg_latency > 15000:
|
||||
return KeyStatus.WARNING
|
||||
|
||||
return KeyStatus.HEALTHY
|
||||
|
||||
@property
|
||||
def suggested_action(self) -> str:
|
||||
"""根据状态给出建议操作"""
|
||||
status_actions = {
|
||||
KeyStatus.DISABLED: "更换Key",
|
||||
KeyStatus.ERROR: "检查网络/重置",
|
||||
KeyStatus.COOLDOWN: "等待/重置",
|
||||
KeyStatus.WARNING: "观察",
|
||||
KeyStatus.HEALTHY: "-",
|
||||
KeyStatus.UNUSED: "-",
|
||||
}
|
||||
return status_actions.get(self.status, "未知")
|
||||
|
||||
|
||||
class RetryConfig:
|
||||
"""重试配置"""
|
||||
|
||||
@ -236,26 +322,38 @@ async def with_smart_retry(
|
||||
last_exception: Exception | None = None
|
||||
failed_keys: set[str] = set()
|
||||
|
||||
model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None)
|
||||
all_provider_keys = model_instance.api_keys if model_instance else []
|
||||
|
||||
for attempt in range(config.max_retries + 1):
|
||||
try:
|
||||
if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
|
||||
kwargs["failed_keys"] = failed_keys
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
start_time = time.monotonic()
|
||||
result = await func(*args, **kwargs)
|
||||
latency = (time.monotonic() - start_time) * 1000
|
||||
|
||||
if key_store and isinstance(result, tuple) and len(result) == 2:
|
||||
final_result, api_key_used = result
|
||||
if api_key_used:
|
||||
await key_store.record_success(api_key_used, latency)
|
||||
return final_result
|
||||
else:
|
||||
return result
|
||||
|
||||
except LLMException as e:
|
||||
last_exception = e
|
||||
api_key_in_use = e.details.get("api_key")
|
||||
|
||||
if e.code in [
|
||||
LLMErrorCode.API_KEY_INVALID,
|
||||
LLMErrorCode.API_QUOTA_EXCEEDED,
|
||||
]:
|
||||
if hasattr(e, "details") and e.details and "api_key" in e.details:
|
||||
failed_keys.add(e.details["api_key"])
|
||||
if key_store and provider_name:
|
||||
await key_store.record_failure(
|
||||
e.details["api_key"], e.details.get("status_code")
|
||||
)
|
||||
if api_key_in_use:
|
||||
failed_keys.add(api_key_in_use)
|
||||
if key_store and provider_name and len(all_provider_keys) > 1:
|
||||
status_code = e.details.get("status_code")
|
||||
error_message = f"({e.code.name}) {e.message}"
|
||||
await key_store.record_failure(
|
||||
api_key_in_use, status_code, error_message
|
||||
)
|
||||
|
||||
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
|
||||
if not should_retry:
|
||||
@ -267,7 +365,7 @@ async def with_smart_retry(
|
||||
if config.exponential_backoff:
|
||||
wait_time *= 2**attempt
|
||||
logger.warning(
|
||||
f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
|
||||
f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
@ -325,14 +423,66 @@ def _should_retry_llm_error(
|
||||
|
||||
|
||||
class KeyStatusStore:
|
||||
"""API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""
|
||||
"""API Key 状态管理存储 - 支持持久化"""
|
||||
|
||||
def __init__(self):
|
||||
self._key_status: dict[str, bool] = {}
|
||||
self._key_usage_count: dict[str, int] = {}
|
||||
self._key_last_used: dict[str, float] = {}
|
||||
self._key_stats: dict[str, KeyStats] = {}
|
||||
self._provider_key_index: dict[str, int] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._file_path = DATA_PATH / "llm" / "key_status.json"
|
||||
|
||||
async def initialize(self):
|
||||
"""从文件异步加载密钥状态,在应用启动时调用"""
|
||||
async with self._lock:
|
||||
if not self._file_path.exists():
|
||||
logger.info("未找到密钥状态文件,将使用内存状态启动。")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"正在从 {self._file_path} 加载密钥状态...")
|
||||
async with aiofiles.open(self._file_path, encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
if not content:
|
||||
logger.warning("密钥状态文件为空。")
|
||||
return
|
||||
data = json.loads(content)
|
||||
|
||||
for key, stats_dict in data.items():
|
||||
self._key_stats[key] = KeyStats(**stats_dict)
|
||||
|
||||
logger.info(f"成功加载 {len(self._key_stats)} 个密钥的状态。")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"密钥状态文件 {self._file_path} 格式错误,无法解析。")
|
||||
except Exception as e:
|
||||
logger.error(f"加载密钥状态文件时发生错误: {e}", e=e)
|
||||
|
||||
async def _save_to_file_internal(self):
|
||||
"""
|
||||
[内部方法] 将当前密钥状态安全地写入JSON文件。
|
||||
假定调用方已持有锁。
|
||||
"""
|
||||
data_to_save = {key: asdict(stats) for key, stats in self._key_stats.items()}
|
||||
|
||||
try:
|
||||
self._file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = self._file_path.with_suffix(".json.tmp")
|
||||
|
||||
async with aiofiles.open(temp_path, "w", encoding="utf-8") as f:
|
||||
await f.write(json.dumps(data_to_save, ensure_ascii=False, indent=2))
|
||||
|
||||
if self._file_path.exists():
|
||||
self._file_path.unlink()
|
||||
os.rename(temp_path, self._file_path)
|
||||
logger.debug("密钥状态已成功持久化到文件。")
|
||||
except Exception as e:
|
||||
logger.error(f"保存密钥状态到文件失败: {e}", e=e)
|
||||
|
||||
async def shutdown(self):
|
||||
"""在应用关闭时安全地保存状态"""
|
||||
async with self._lock:
|
||||
await self._save_to_file_internal()
|
||||
logger.info("KeyStatusStore 已在关闭前保存状态。")
|
||||
|
||||
async def get_next_available_key(
|
||||
self,
|
||||
@ -355,88 +505,122 @@ class KeyStatusStore:
|
||||
return None
|
||||
|
||||
exclude_keys = exclude_keys or set()
|
||||
available_keys = [
|
||||
key
|
||||
for key in api_keys
|
||||
if key not in exclude_keys and self._key_status.get(key, True)
|
||||
]
|
||||
|
||||
if not available_keys:
|
||||
return api_keys[0] if api_keys else None
|
||||
|
||||
async with self._lock:
|
||||
for key in api_keys:
|
||||
if key not in self._key_stats:
|
||||
self._key_stats[key] = KeyStats()
|
||||
|
||||
available_keys = [
|
||||
key
|
||||
for key in api_keys
|
||||
if key not in exclude_keys and self._key_stats[key].is_available
|
||||
]
|
||||
|
||||
if not available_keys:
|
||||
return api_keys[0]
|
||||
|
||||
current_index = self._provider_key_index.get(provider_name, 0)
|
||||
|
||||
selected_key = available_keys[current_index % len(available_keys)]
|
||||
self._provider_key_index[provider_name] = current_index + 1
|
||||
|
||||
self._provider_key_index[provider_name] = (current_index + 1) % len(
|
||||
available_keys
|
||||
total_usage = (
|
||||
self._key_stats[selected_key].success_count
|
||||
+ self._key_stats[selected_key].failure_count
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
self._key_usage_count[selected_key] = (
|
||||
self._key_usage_count.get(selected_key, 0) + 1
|
||||
)
|
||||
self._key_last_used[selected_key] = time.time()
|
||||
|
||||
logger.debug(
|
||||
f"轮询选择API密钥: {self._get_key_id(selected_key)} "
|
||||
f"(使用次数: {self._key_usage_count[selected_key]})"
|
||||
f"(使用次数: {total_usage})"
|
||||
)
|
||||
|
||||
return selected_key
|
||||
|
||||
async def record_success(self, api_key: str):
|
||||
"""记录成功使用"""
|
||||
async def record_success(self, api_key: str, latency: float):
|
||||
"""记录成功使用,并持久化"""
|
||||
async with self._lock:
|
||||
self._key_status[api_key] = True
|
||||
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
|
||||
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||
stats.cooldown_until = 0.0
|
||||
stats.success_count += 1
|
||||
stats.total_latency += latency
|
||||
stats.last_error_info = None
|
||||
await self._save_to_file_internal()
|
||||
logger.debug(
|
||||
f"记录API密钥成功使用: {self._get_key_id(api_key)}, 延迟: {latency:.2f}ms"
|
||||
)
|
||||
|
||||
async def record_failure(self, api_key: str, status_code: int | None):
|
||||
async def record_failure(
|
||||
self, api_key: str, status_code: int | None, error_message: str
|
||||
):
|
||||
"""
|
||||
记录失败使用
|
||||
记录失败使用,并设置冷却时间
|
||||
|
||||
参数:
|
||||
api_key: API密钥。
|
||||
status_code: HTTP状态码。
|
||||
error_message: 错误信息。
|
||||
"""
|
||||
key_id = self._get_key_id(api_key)
|
||||
now = time.time()
|
||||
cooldown_duration = 300
|
||||
|
||||
if status_code in [401, 403, 404]:
|
||||
cooldown_duration = 31536000
|
||||
log_level = "error"
|
||||
log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
|
||||
elif status_code == 429:
|
||||
cooldown_duration = 60
|
||||
log_level = "warning"
|
||||
log_message = f"API密钥被限流,冷却60秒: {key_id}"
|
||||
else:
|
||||
log_level = "warning"
|
||||
log_message = f"API密钥遇到临时性错误,冷却{cooldown_duration}秒: {key_id}"
|
||||
|
||||
async with self._lock:
|
||||
if status_code in [401, 403]:
|
||||
self._key_status[api_key] = False
|
||||
logger.warning(
|
||||
f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
|
||||
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||
stats.cooldown_until = now + cooldown_duration
|
||||
stats.failure_count += 1
|
||||
stats.last_error_info = error_message[:256]
|
||||
await self._save_to_file_internal()
|
||||
|
||||
getattr(logger, log_level)(log_message)
|
||||
|
||||
async def reset_key_status(self, api_key: str):
|
||||
"""重置密钥状态(用于恢复机制)"""
|
||||
"""重置密钥状态,并持久化"""
|
||||
async with self._lock:
|
||||
self._key_status[api_key] = True
|
||||
stats = self._key_stats.setdefault(api_key, KeyStats())
|
||||
stats.cooldown_until = 0.0
|
||||
stats.last_error_info = None
|
||||
await self._save_to_file_internal()
|
||||
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
|
||||
|
||||
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
获取密钥使用统计
|
||||
获取密钥使用统计,并计算出用于展示的派生数据。
|
||||
|
||||
参数:
|
||||
api_keys: API密钥列表。
|
||||
|
||||
返回:
|
||||
dict[str, dict]: 密钥统计信息字典。
|
||||
dict[str, dict]: 包含丰富状态和统计信息的密钥字典。
|
||||
"""
|
||||
stats = {}
|
||||
stats_dict = {}
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
for key in api_keys:
|
||||
key_id = self._get_key_id(key)
|
||||
stats[key_id] = {
|
||||
"available": self._key_status.get(key, True),
|
||||
"usage_count": self._key_usage_count.get(key, 0),
|
||||
"last_used": self._key_last_used.get(key, 0),
|
||||
stats = self._key_stats.get(key, KeyStats())
|
||||
|
||||
stats_dict[key_id] = {
|
||||
"status_enum": stats.status,
|
||||
"cooldown_seconds_left": max(0, stats.cooldown_until - now),
|
||||
"total_calls": stats.success_count + stats.failure_count,
|
||||
"success_count": stats.success_count,
|
||||
"failure_count": stats.failure_count,
|
||||
"success_rate": stats.success_rate,
|
||||
"avg_latency": stats.avg_latency,
|
||||
"last_error": stats.last_error_info,
|
||||
"suggested_action": stats.suggested_action,
|
||||
}
|
||||
return stats
|
||||
return stats_dict
|
||||
|
||||
def _get_key_id(self, api_key: str) -> str:
|
||||
"""获取API密钥的标识符(用于日志)"""
|
||||
@ -446,3 +630,8 @@ class KeyStatusStore:
|
||||
|
||||
|
||||
key_store = KeyStatusStore()
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _shutdown_key_store():
|
||||
await key_store.shutdown()
|
||||
|
||||
@ -137,8 +137,8 @@ def get_configured_providers() -> list[ProviderConfig]:
|
||||
valid_providers.append(item)
|
||||
else:
|
||||
logger.warning(
|
||||
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
|
||||
f"已跳过。实际类型: {type(item)}"
|
||||
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,已跳过。"
|
||||
f"实际类型: {type(item)}"
|
||||
)
|
||||
|
||||
return valid_providers
|
||||
|
||||
@ -46,17 +46,7 @@ class LLMModelBase(ABC):
|
||||
history: list[dict[str, str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
生成文本
|
||||
|
||||
参数:
|
||||
prompt: 输入提示词。
|
||||
history: 对话历史记录。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
str: 生成的文本。
|
||||
"""
|
||||
"""生成文本"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -68,19 +58,7 @@ class LLMModelBase(ABC):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
生成高级响应
|
||||
|
||||
参数:
|
||||
messages: 消息列表。
|
||||
config: 生成配置。
|
||||
tools: 工具列表。
|
||||
tool_choice: 工具选择策略。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
LLMResponse: 模型响应。
|
||||
"""
|
||||
"""生成高级响应"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -90,17 +68,7 @@ class LLMModelBase(ABC):
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
生成文本嵌入向量
|
||||
|
||||
参数:
|
||||
texts: 文本列表。
|
||||
task_type: 嵌入任务类型。
|
||||
**kwargs: 其他参数。
|
||||
|
||||
返回:
|
||||
list[list[float]]: 嵌入向量列表。
|
||||
"""
|
||||
"""生成文本嵌入向量"""
|
||||
pass
|
||||
|
||||
|
||||
@ -208,28 +176,8 @@ class LLMModel(LLMModelBase):
|
||||
http_client: "LLMHttpClient",
|
||||
failed_keys: set[str] | None = None,
|
||||
log_context: str = "API",
|
||||
) -> Any:
|
||||
"""
|
||||
执行API调用的通用核心方法。
|
||||
|
||||
该方法封装了以下通用逻辑:
|
||||
1. 选择API密钥。
|
||||
2. 准备和记录请求。
|
||||
3. 发送HTTP POST请求。
|
||||
4. 处理HTTP错误和API特定错误。
|
||||
5. 记录密钥使用状态。
|
||||
6. 解析成功的响应。
|
||||
|
||||
参数:
|
||||
prepare_request_func: 准备请求的函数。
|
||||
parse_response_func: 解析响应的函数。
|
||||
http_client: HTTP客户端。
|
||||
failed_keys: 失败的密钥集合。
|
||||
log_context: 日志上下文。
|
||||
|
||||
返回:
|
||||
Any: 解析后的响应数据。
|
||||
"""
|
||||
) -> tuple[Any, str]:
|
||||
"""执行API调用的通用核心方法"""
|
||||
api_key = await self._select_api_key(failed_keys)
|
||||
|
||||
try:
|
||||
@ -267,7 +215,9 @@ class LLMModel(LLMModelBase):
|
||||
)
|
||||
logger.debug(f"💥 完整错误响应: {error_text}")
|
||||
|
||||
await self.key_store.record_failure(api_key, http_response.status_code)
|
||||
await self.key_store.record_failure(
|
||||
api_key, http_response.status_code, error_text
|
||||
)
|
||||
|
||||
if http_response.status_code in [401, 403]:
|
||||
error_code = LLMErrorCode.API_KEY_INVALID
|
||||
@ -298,7 +248,7 @@ class LLMModel(LLMModelBase):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
await self.key_store.record_failure(api_key, None, str(e))
|
||||
if isinstance(e, LLMException):
|
||||
raise
|
||||
else:
|
||||
@ -308,17 +258,15 @@ class LLMModel(LLMModelBase):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
await self.key_store.record_success(api_key)
|
||||
logger.debug(f"✅ API密钥使用成功: {masked_key}")
|
||||
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
|
||||
return parsed_data
|
||||
return parsed_data, api_key
|
||||
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
|
||||
logger.error(error_log_msg, e=e)
|
||||
await self.key_store.record_failure(api_key, None)
|
||||
await self.key_store.record_failure(api_key, None, str(e))
|
||||
raise LLMException(
|
||||
error_log_msg,
|
||||
code=LLMErrorCode.GENERATION_FAILED
|
||||
@@ -349,13 +297,14 @@ class LLMModel(LLMModelBase):
            adapter.validate_embedding_response(response_json)
            return adapter.parse_embedding_response(response_json)

        return await self._perform_api_call(
        parsed_data, api_key_used = await self._perform_api_call(
            prepare_request_func=prepare_request,
            parse_response_func=parse_response,
            http_client=http_client,
            failed_keys=failed_keys,
            log_context="Embedding",
        )
        return parsed_data

    async def _execute_with_smart_retry(
        self,
@@ -394,8 +343,8 @@ class LLMModel(LLMModelBase):
        tool_choice: str | dict[str, Any] | None,
        http_client: LLMHttpClient,
        failed_keys: set[str] | None = None,
    ) -> LLMResponse:
        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
    ) -> tuple[LLMResponse, str]:
        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse 和使用的 key"""

        async def prepare_request(api_key: str) -> RequestData:
            return await adapter.prepare_advanced_request(
@@ -441,19 +390,17 @@ class LLMModel(LLMModelBase):
                cache_info=response_data.cache_info,
            )

        return await self._perform_api_call(
        parsed_data, api_key_used = await self._perform_api_call(
            prepare_request_func=prepare_request,
            parse_response_func=parse_response,
            http_client=http_client,
            failed_keys=failed_keys,
            log_context="Generation",
        )
        return parsed_data, api_key_used

    async def close(self):
        """
        标记模型实例的当前使用周期结束。
        共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。
        """
        """标记模型实例的当前使用周期结束"""
        if self._is_closed:
            return
        self._is_closed = True
@@ -487,17 +434,7 @@ class LLMModel(LLMModelBase):
        history: list[dict[str, str]] | None = None,
        **kwargs: Any,
    ) -> str:
        """
        生成文本 - 通过 generate_response 实现

        参数:
            prompt: 输入提示词。
            history: 对话历史记录。
            **kwargs: 其他参数。

        返回:
            str: 生成的文本。
        """
        """生成文本"""
        self._check_not_closed()

        messages: list[LLMMessage] = []
@@ -538,19 +475,7 @@ class LLMModel(LLMModelBase):
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        生成高级响应

        参数:
            messages: 消息列表。
            config: 生成配置。
            tools: 工具列表。
            tool_choice: 工具选择策略。
            **kwargs: 其他参数。

        返回:
            LLMResponse: 模型响应。
        """
        """生成高级响应"""
        self._check_not_closed()

        from .adapters import get_adapter_for_api_type
@@ -619,17 +544,7 @@ class LLMModel(LLMModelBase):
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """
        生成文本嵌入向量

        参数:
            texts: 文本列表。
            task_type: 嵌入任务类型。
            **kwargs: 其他参数。

        返回:
            list[list[float]]: 嵌入向量列表。
        """
        """生成文本嵌入向量"""
        self._check_not_closed()
        if not texts:
            return []

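上述改动将 _perform_api_call 与 _execute_single_request 的返回值从单一结果扩展为"(解析结果, 实际使用的API密钥)"元组,便于上层把延迟与成功率归属到具体密钥上。下面是一个示意性的调用草图(仅供理解,其中 timed_request 与 record_latency 均为假设名称,并非本提交中的接口,实际以仓库代码为准):

    import time

    async def timed_request(model, **call_kwargs):
        """示意:测量一次请求耗时,并把耗时归属到实际使用的密钥上。"""
        start = time.monotonic()
        # _execute_single_request 现在返回 (LLMResponse, 使用的 api_key)
        response, api_key_used = await model._execute_single_request(**call_kwargs)
        latency_ms = (time.monotonic() - start) * 1000
        # 假设 key_store 提供类似 record_latency 的统计接口;若没有,可退化为日志记录
        await model.key_store.record_latency(api_key_used, latency_ms)
        return response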
532
zhenxun/services/llm/session.py
Normal file
@@ -0,0 +1,532 @@
"""
|
||||
LLM 服务 - 会话客户端
|
||||
|
||||
提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import CommonOverrides, LLMGenerationConfig
|
||||
from .config.providers import get_ai_config
|
||||
from .manager import get_global_default_model_name, get_model_instance
|
||||
from .tools import tool_registry
|
||||
from .types import (
|
||||
EmbeddingTaskType,
|
||||
LLMContentPart,
|
||||
LLMErrorCode,
|
||||
LLMException,
|
||||
LLMMessage,
|
||||
LLMResponse,
|
||||
LLMTool,
|
||||
ModelName,
|
||||
)
|
||||
from .utils import unimsg_to_llm_parts
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""AI配置类 - 简化版本"""
|
||||
|
||||
model: ModelName = None
|
||||
default_embedding_model: ModelName = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
enable_cache: bool = False
|
||||
enable_code: bool = False
|
||||
enable_search: bool = False
|
||||
timeout: int | None = None
|
||||
|
||||
enable_gemini_json_mode: bool = False
|
||||
enable_gemini_thinking: bool = False
|
||||
enable_gemini_safe_mode: bool = False
|
||||
enable_gemini_multimodal: bool = False
|
||||
enable_gemini_grounding: bool = False
|
||||
default_preserve_media_in_history: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后从配置中读取默认值"""
|
||||
ai_config = get_ai_config()
|
||||
if self.model is None:
|
||||
self.model = ai_config.get("default_model_name")
|
||||
if self.timeout is None:
|
||||
self.timeout = ai_config.get("timeout", 180)
|
||||
|
||||
|
||||
class AI:
|
||||
"""统一的AI服务类 - 平衡设计版本
|
||||
|
||||
提供三层API:
|
||||
1. 简单方法:ai.chat(), ai.code(), ai.search()
|
||||
2. 标准方法:ai.analyze() 支持复杂参数
|
||||
3. 高级方法:通过get_model_instance()直接访问
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
|
||||
):
|
||||
"""
|
||||
初始化AI服务
|
||||
|
||||
参数:
|
||||
config: AI 配置.
|
||||
history: 可选的初始对话历史.
|
||||
"""
|
||||
self.config = config or AIConfig()
|
||||
self.history = history or []
|
||||
|
||||
def clear_history(self):
|
||||
"""清空当前会话的历史记录"""
|
||||
self.history = []
|
||||
logger.info("AI session history cleared.")
|
||||
|
||||
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
|
||||
"""
|
||||
净化用于存入历史记录的消息。
|
||||
将非文本的多模态内容部分替换为文本占位符,以避免重复处理。
|
||||
"""
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
sanitized_message = copy.deepcopy(message)
|
||||
content_list = sanitized_message.content
|
||||
if not isinstance(content_list, list):
|
||||
return sanitized_message
|
||||
|
||||
new_content_parts: list[LLMContentPart] = []
|
||||
has_multimodal_content = False
|
||||
|
||||
for part in content_list:
|
||||
if isinstance(part, LLMContentPart) and part.type == "text":
|
||||
new_content_parts.append(part)
|
||||
else:
|
||||
has_multimodal_content = True
|
||||
|
||||
if has_multimodal_content:
|
||||
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
|
||||
text_part_found = False
|
||||
for part in new_content_parts:
|
||||
if part.type == "text":
|
||||
part.text = f"{placeholder} {part.text or ''}".strip()
|
||||
text_part_found = True
|
||||
break
|
||||
if not text_part_found:
|
||||
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
|
||||
|
||||
sanitized_message.content = new_content_parts
|
||||
return sanitized_message
|
||||
|
||||
    async def chat(
        self,
        message: str | LLMMessage | list[LLMContentPart],
        *,
        model: ModelName = None,
        preserve_media_in_history: bool | None = None,
        tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        进行一次聊天对话,支持工具调用。
        此方法会自动使用和更新会话内的历史记录。

        参数:
            message: 用户输入的消息。
            model: 本次对话要使用的模型。
            preserve_media_in_history: 是否在历史记录中保留原始多模态信息。
                - True: 保留,用于深度多轮媒体分析。
                - False: 不保留,替换为占位符,提高效率。
                - None (默认): 使用AI实例配置的默认值。
            tools: 本次对话可用的工具列表。
            tool_choice: 强制模型使用的工具。
            **kwargs: 传递给模型的其他生成参数。

        返回:
            LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。
        """
        current_message: LLMMessage
        if isinstance(message, str):
            current_message = LLMMessage.user(message)
        elif isinstance(message, list) and all(
            isinstance(part, LLMContentPart) for part in message
        ):
            current_message = LLMMessage.user(message)
        elif isinstance(message, LLMMessage):
            current_message = message
        else:
            raise LLMException(
                f"AI.chat 不支持的消息类型: {type(message)}. "
                "请使用 str, LLMMessage, 或 list[LLMContentPart]. "
                "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
                code=LLMErrorCode.API_REQUEST_FAILED,
            )

        final_messages = [*self.history, current_message]

        response = await self._execute_generation(
            messages=final_messages,
            model_name=model,
            error_message="聊天失败",
            config_overrides=kwargs,
            llm_tools=tools,
            tool_choice=tool_choice,
        )

        should_preserve = (
            preserve_media_in_history
            if preserve_media_in_history is not None
            else self.config.default_preserve_media_in_history
        )

        if should_preserve:
            logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
            self.history.append(current_message)
        else:
            logger.debug("高效模式:净化历史记录中的多模态消息。")
            sanitized_user_message = self._sanitize_message_for_history(current_message)
            self.history.append(sanitized_user_message)

        self.history.append(
            LLMMessage(
                role="assistant", content=response.text, tool_calls=response.tool_calls
            )
        )

        return response

    async def code(
        self,
        prompt: str,
        *,
        model: ModelName = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        代码执行

        参数:
            prompt: 代码执行的提示词。
            model: 要使用的模型名称。
            timeout: 代码执行超时时间(秒)。
            **kwargs: 传递给模型的其他参数。

        返回:
            dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。
        """
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"

        config = CommonOverrides.gemini_code_execution()
        if timeout:
            config.custom_params = config.custom_params or {}
            config.custom_params["code_execution_timeout"] = timeout

        messages = [LLMMessage.user(prompt)]

        response = await self._execute_generation(
            messages=messages,
            model_name=resolved_model,
            error_message="代码执行失败",
            config_overrides=kwargs,
            base_config=config,
        )

        return {
            "text": response.text,
            "code_executions": response.code_executions or [],
            "success": True,
        }

    async def search(
        self,
        query: str | UniMessage,
        *,
        model: ModelName = None,
        instruction: str = "",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        信息搜索 - 支持多模态输入

        参数:
            query: 搜索查询内容,支持文本或多模态消息。
            model: 要使用的模型名称。
            instruction: 搜索指令。
            **kwargs: 传递给模型的其他参数。

        返回:
            dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
        """
        from nonebot_plugin_alconna.uniseg import UniMessage

        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
        config = CommonOverrides.gemini_grounding()

        if isinstance(query, str):
            messages = [LLMMessage.user(query)]
        elif isinstance(query, UniMessage):
            content_parts = await unimsg_to_llm_parts(query)

            final_messages: list[LLMMessage] = []
            if instruction:
                final_messages.append(LLMMessage.system(instruction))

            if not content_parts:
                if instruction:
                    final_messages.append(LLMMessage.user(instruction))
                else:
                    raise LLMException(
                        "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
                    )
            else:
                final_messages.append(LLMMessage.user(content_parts))

            messages = final_messages
        else:
            raise LLMException(
                f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
                code=LLMErrorCode.API_REQUEST_FAILED,
            )

        response = await self._execute_generation(
            messages=messages,
            model_name=resolved_model,
            error_message="信息搜索失败",
            config_overrides=kwargs,
            base_config=config,
        )

        result = {
            "text": response.text,
            "sources": [],
            "queries": [],
            "success": True,
        }

        if response.grounding_metadata:
            result["sources"] = response.grounding_metadata.grounding_attributions or []
            result["queries"] = response.grounding_metadata.web_search_queries or []

        return result

    async def analyze(
        self,
        message: UniMessage | None,
        *,
        instruction: str = "",
        model: ModelName = None,
        use_tools: list[str] | None = None,
        tool_config: dict[str, Any] | None = None,
        activated_tools: list[LLMTool] | None = None,
        history: list[LLMMessage] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。

        参数:
            message: 要分析的消息内容(支持多模态)。
            instruction: 分析指令。
            model: 要使用的模型名称。
            use_tools: 要使用的工具名称列表。
            tool_config: 工具配置。
            activated_tools: 已激活的工具列表。
            history: 对话历史记录。
            **kwargs: 传递给模型的其他参数。

        返回:
            LLMResponse: 模型的完整响应结果。
        """
        from nonebot_plugin_alconna.uniseg import UniMessage

        content_parts = await unimsg_to_llm_parts(message or UniMessage())

        final_messages: list[LLMMessage] = []
        if history:
            final_messages.extend(history)

        if instruction:
            if not any(msg.role == "system" for msg in final_messages):
                final_messages.insert(0, LLMMessage.system(instruction))

        if not content_parts:
            if instruction and not history:
                final_messages.append(LLMMessage.user(instruction))
            elif not history:
                raise LLMException(
                    "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
                )
        else:
            final_messages.append(LLMMessage.user(content_parts))

        llm_tools: list[LLMTool] | None = activated_tools
        if not llm_tools and use_tools:
            try:
                llm_tools = tool_registry.get_tools(use_tools)
                logger.debug(f"已从注册表加载工具定义: {use_tools}")
            except ValueError as e:
                raise LLMException(
                    f"加载工具定义失败: {e}",
                    code=LLMErrorCode.CONFIGURATION_ERROR,
                    cause=e,
                )

        tool_choice = None
        if tool_config:
            mode = tool_config.get("mode", "auto")
            if mode in ["auto", "any", "none"]:
                tool_choice = mode

        response = await self._execute_generation(
            messages=final_messages,
            model_name=model,
            error_message="内容分析失败",
            config_overrides=kwargs,
            llm_tools=llm_tools,
            tool_choice=tool_choice,
        )

        return response

    async def _execute_generation(
        self,
        messages: list[LLMMessage],
        model_name: ModelName,
        error_message: str,
        config_overrides: dict[str, Any],
        llm_tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        base_config: LLMGenerationConfig | None = None,
    ) -> LLMResponse:
        """通用的生成执行方法,封装模型获取和单次API调用"""
        try:
            resolved_model_name = self._resolve_model_name(
                model_name or self.config.model
            )
            final_config_dict = self._merge_config(
                config_overrides, base_config=base_config
            )

            async with await get_model_instance(
                resolved_model_name, override_config=final_config_dict
            ) as model_instance:
                return await model_instance.generate_response(
                    messages,
                    tools=llm_tools,
                    tool_choice=tool_choice,
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"{error_message}: {e}", e=e)
            raise LLMException(f"{error_message}: {e}", cause=e)

    def _resolve_model_name(self, model_name: ModelName) -> str:
        """解析模型名称"""
        if model_name:
            return model_name

        default_model = get_global_default_model_name()
        if default_model:
            return default_model

        raise LLMException(
            "未指定模型名称且未设置全局默认模型",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    def _merge_config(
        self,
        user_config: dict[str, Any],
        base_config: LLMGenerationConfig | None = None,
    ) -> dict[str, Any]:
        """合并配置"""
        final_config = {}
        if base_config:
            final_config.update(base_config.to_dict())

        if self.config.temperature is not None:
            final_config["temperature"] = self.config.temperature
        if self.config.max_tokens is not None:
            final_config["max_tokens"] = self.config.max_tokens

        if self.config.enable_cache:
            final_config["enable_caching"] = True
        if self.config.enable_code:
            final_config["enable_code_execution"] = True
        if self.config.enable_search:
            final_config["enable_grounding"] = True

        if self.config.enable_gemini_json_mode:
            final_config["response_mime_type"] = "application/json"
        if self.config.enable_gemini_thinking:
            final_config["thinking_budget"] = 0.8
        if self.config.enable_gemini_safe_mode:
            final_config["safety_settings"] = (
                CommonOverrides.gemini_safe().safety_settings
            )
        if self.config.enable_gemini_multimodal:
            final_config.update(CommonOverrides.gemini_multimodal().to_dict())
        if self.config.enable_gemini_grounding:
            final_config["enable_grounding"] = True

        final_config.update(user_config)

        return final_config

    async def embed(
        self,
        texts: list[str] | str,
        *,
        model: ModelName = None,
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """
        生成文本嵌入向量

        参数:
            texts: 要生成嵌入向量的文本或文本列表。
            model: 要使用的嵌入模型名称。
            task_type: 嵌入任务类型。
            **kwargs: 传递给模型的其他参数。

        返回:
            list[list[float]]: 文本的嵌入向量列表。
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return []

        try:
            resolved_model_str = (
                model or self.config.default_embedding_model or self.config.model
            )
            if not resolved_model_str:
                raise LLMException(
                    "使用 embed 功能时必须指定嵌入模型名称,"
                    "或在 AIConfig 中配置 default_embedding_model。",
                    code=LLMErrorCode.MODEL_NOT_FOUND,
                )
            resolved_model_str = self._resolve_model_name(resolved_model_str)

            async with await get_model_instance(
                resolved_model_str,
                override_config=None,
            ) as embedding_model_instance:
                return await embedding_model_instance.generate_embeddings(
                    texts, task_type=task_type, **kwargs
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"文本嵌入失败: {e}", e=e)
            raise LLMException(
                f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
            )
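上面的 session.py 提供了有状态的 AI 会话客户端,下面是一个示意性的使用草图(仅供理解:假设可从 zhenxun.services.llm.session 直接导入,实际导入路径以包的 __init__ 导出为准,模型名称也仅为示例):

    from zhenxun.services.llm.session import AI, AIConfig

    async def demo():
        # 创建带默认模型与温度配置的会话
        ai = AI(AIConfig(model="Gemini/gemini-2.0-flash", temperature=0.7))

        # 多轮对话:历史记录由会话自动维护
        first = await ai.chat("用一句话介绍你自己")
        follow_up = await ai.chat("再详细一点")

        # 文本嵌入(需要指定或配置嵌入模型,模型名仅为示例)
        vectors = await ai.embed(["你好", "世界"], model="Gemini/text-embedding-004")

        # 结束一个话题后清空历史
        ai.clear_history()
        return first.text, follow_up.text, len(vectors)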
@@ -10,7 +10,13 @@ from .content import (
    LLMMessage,
    LLMResponse,
)
from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory
from .enums import (
    EmbeddingTaskType,
    ModelProvider,
    ResponseFormat,
    TaskType,
    ToolCategory,
)
from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
from .models import (
    LLMCacheInfo,
@@ -52,6 +58,7 @@ __all__ = [
    "ModelProvider",
    "ProviderConfig",
    "ResponseFormat",
    "TaskType",
    "ToolCategory",
    "ToolMetadata",
    "UsageInfo",

@@ -45,6 +45,17 @@ class ToolCategory(Enum):
    CUSTOM = auto()


class TaskType(Enum):
    """任务类型枚举"""

    CHAT = "chat"
    CODE = "code"
    SEARCH = "search"
    ANALYSIS = "analysis"
    GENERATION = "generation"
    MULTIMODAL = "multimodal"


class LLMErrorCode(Enum):
    """LLM 服务相关的错误代码枚举"""

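新增的 TaskType 枚举可用于按任务类别选择提示词或生成参数,下面是一个示意性的用法草图(其中 PRESET_CONFIGS 与 pick_config 为假设示例,并非仓库中已有定义;导入路径依据上文 types/__init__.py 对 TaskType 的导出):

    from zhenxun.services.llm.types import TaskType

    # 假设:按任务类型挑选不同的生成参数
    PRESET_CONFIGS = {
        TaskType.CHAT: {"temperature": 0.7},
        TaskType.CODE: {"temperature": 0.2},
        TaskType.SEARCH: {"temperature": 0.4},
    }

    def pick_config(task: TaskType | str) -> dict:
        """示意:既接受枚举成员,也接受 'chat' 这类字符串值。"""
        task = TaskType(task) if isinstance(task, str) else task
        return PRESET_CONFIGS.get(task, {})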