feat(llm): add LLM model management plugin and enhance API key management (#1972)

🔧 New features:
- LLM model management plugin (builtin_plugins/llm_manager/)
  • llm list - list available models (rendered as an image)
  • llm info - show detailed model info (Markdown image)
  • llm default - view or set the global default model
  • llm test - test model connectivity
  • llm keys - show API key status (table image with health / success rate / latency)
  • llm reset-key - reset the failure state of API keys

🏗️ Architectural refactoring:
- Session management: the AI/AIConfig classes moved to a standalone session.py
- Type definitions: the TaskType enum moved to types/enums.py
- API enhancements (usage sketch after the commit metadata below):
  • chat() now returns a full LLMResponse and supports tool calls
  • new generate() function for one-shot response generation
  • unified core API-call method _perform_api_call, which returns the API key it used

🚀 Key management enhancements:
- Detailed status tracking: health, success rate, average latency, error info, suggested action
- State persistence: key states are loaded on startup and saved automatically on shutdown
- Smart cooldown policy: cooldown duration varies with the error type
- Latency monitoring: with_smart_retry records API-call latency and updates the statistics

Co-authored-by: webjoin111 <455457521@qq.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
Rumio authored 2025-07-14 22:39:17 +08:00, committed by GitHub
parent 8649aaaa54
commit 46a0768a45
14 changed files with 1423 additions and 682 deletions
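
For reviewers: a minimal usage sketch of the reworked convenience API, based on the diffs below. The model names are placeholders, and the snippet assumes a configured provider.

```python
from zhenxun.services.llm import LLMMessage, chat, generate

async def demo():
    # chat() now returns the full LLMResponse instead of a bare string,
    # so tool-call requests become visible to the caller.
    response = await chat("你好", model="Gemini/gemini-2.0-flash")
    print(response.text, response.tool_calls)

    # generate() is the new one-shot entry point: it takes a complete
    # message list (including a system instruction) and never reads or
    # writes session history.
    messages = [
        LLMMessage.system("你是一个乐于助人的助手"),
        LLMMessage.user("用一句话介绍你自己"),
    ]
    result = await generate(messages, model="Gemini/gemini-2.0-flash")
    print(result.text)
```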

View File

@@ -0,0 +1,171 @@
from nonebot.permission import SUPERUSER
from nonebot.plugin import PluginMetadata
from nonebot_plugin_alconna import (
Alconna,
Args,
Arparma,
Match,
Option,
Query,
Subcommand,
on_alconna,
store_true,
)
from zhenxun.configs.utils import PluginExtraData
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
from .data_source import DataSource
from .presenters import Presenters
__plugin_meta__ = PluginMetadata(
name="LLM模型管理",
description="查看和管理大语言模型服务。",
usage="""
LLM模型管理 (SUPERUSER)
llm list [--all]
- 查看可用模型列表
- --all: 显示包括不可用在内的所有模型
llm info <Provider/ModelName>
- 查看指定模型的详细信息和能力
llm default [Provider/ModelName]
- 查看或设置全局默认模型
- 不带参数: 查看当前默认模型
- 带参数: 设置新的默认模型
- 例子: llm default Gemini/gemini-2.0-flash
llm test <Provider/ModelName>
- 测试指定模型的连通性和API Key有效性
llm keys <ProviderName>
- 查看指定提供商的所有API Key状态
llm reset-key <ProviderName> [--key <api_key>]
- 重置提供商的所有或指定API Key的失败状态
""",
extra=PluginExtraData(
author="HibiKier",
version="1.0.0",
plugin_type=PluginType.SUPERUSER,
).to_dict(),
)
llm_cmd = on_alconna(
Alconna(
"llm",
Subcommand("list", alias=["ls"], help_text="查看模型列表"),
Subcommand("info", Args["model_name", str], help_text="查看模型详情"),
Subcommand("default", Args["model_name?", str], help_text="查看或设置默认模型"),
Subcommand(
"test", Args["model_name", str], alias=["ping"], help_text="测试模型连通性"
),
Subcommand("keys", Args["provider_name", str], help_text="查看API密钥状态"),
Subcommand(
"reset-key",
Args["provider_name", str],
Option("--key", Args["api_key", str], help_text="指定要重置的API Key"),
help_text="重置API Key状态",
),
Option("--all", action=store_true, help_text="显示所有条目"),
),
permission=SUPERUSER,
priority=5,
block=True,
)
@llm_cmd.assign("list")
async def handle_list(arp: Arparma, show_all: Query[bool] = Query("all")):
"""处理 'llm list' 命令"""
logger.info("获取LLM模型列表", command="LLM Manage", session=arp.header_result)
models = await DataSource.get_model_list(show_all=show_all.result)
image = await Presenters.format_model_list_as_image(models, show_all.result)
await llm_cmd.finish(MessageUtils.build_message(image))
@llm_cmd.assign("info")
async def handle_info(arp: Arparma, model_name: Match[str]):
"""处理 'llm info' 命令"""
logger.info(
f"获取模型详情: {model_name.result}",
command="LLM Manage",
session=arp.header_result,
)
details = await DataSource.get_model_details(model_name.result)
if not details:
await llm_cmd.finish(f"未找到模型: {model_name.result}")
image_bytes = await Presenters.format_model_details_as_markdown_image(details)
await llm_cmd.finish(MessageUtils.build_message(image_bytes))
@llm_cmd.assign("default")
async def handle_default(arp: Arparma, model_name: Match[str]):
"""处理 'llm default' 命令"""
if model_name.available:
logger.info(
f"设置默认模型为: {model_name.result}",
command="LLM Manage",
session=arp.header_result,
)
success, message = await DataSource.set_default_model(model_name.result)
await llm_cmd.finish(message)
else:
logger.info("查看默认模型", command="LLM Manage", session=arp.header_result)
current_default = await DataSource.get_default_model()
await llm_cmd.finish(f"当前全局默认模型为: {current_default or '未设置'}")
@llm_cmd.assign("test")
async def handle_test(arp: Arparma, model_name: Match[str]):
"""处理 'llm test' 命令"""
logger.info(
f"测试模型连通性: {model_name.result}",
command="LLM Manage",
session=arp.header_result,
)
await llm_cmd.send(f"正在测试模型 '{model_name.result}',请稍候...")
success, message = await DataSource.test_model_connectivity(model_name.result)
await llm_cmd.finish(message)
@llm_cmd.assign("keys")
async def handle_keys(arp: Arparma, provider_name: Match[str]):
"""处理 'llm keys' 命令"""
logger.info(
f"查看提供商API Key状态: {provider_name.result}",
command="LLM Manage",
session=arp.header_result,
)
sorted_stats = await DataSource.get_key_status(provider_name.result)
if not sorted_stats:
await llm_cmd.finish(
f"未找到提供商 '{provider_name.result}' 或其没有配置API Keys。"
)
image = await Presenters.format_key_status_as_image(
provider_name.result, sorted_stats
)
await llm_cmd.finish(MessageUtils.build_message(image))
@llm_cmd.assign("reset-key")
async def handle_reset_key(
arp: Arparma, provider_name: Match[str], api_key: Match[str]
):
"""处理 'llm reset-key' 命令"""
key_to_reset = api_key.result if api_key.available else None
log_msg = f"重置 {provider_name.result} 的" + (
"指定API Key" if key_to_reset else "所有API Keys"
)
logger.info(log_msg, command="LLM Manage", session=arp.header_result)
success, message = await DataSource.reset_key(provider_name.result, key_to_reset)
await llm_cmd.finish(message)

View File

@@ -0,0 +1,120 @@
import time
from typing import Any
from zhenxun.services.llm import (
LLMException,
get_global_default_model_name,
get_model_instance,
list_available_models,
set_global_default_model_name,
)
from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.manager import (
reset_key_status,
)
class DataSource:
"""LLM管理插件的数据源和业务逻辑"""
@staticmethod
async def get_model_list(show_all: bool = False) -> list[dict[str, Any]]:
"""获取模型列表"""
models = list_available_models()
if show_all:
return models
return [m for m in models if m.get("is_available", True)]
@staticmethod
async def get_model_details(model_name_str: str) -> dict[str, Any] | None:
"""获取指定模型的详细信息"""
try:
model = await get_model_instance(model_name_str)
return {
"provider_config": model.provider_config,
"model_detail": model.model_detail,
"capabilities": model.capabilities,
}
except LLMException:
return None
@staticmethod
async def get_default_model() -> str | None:
"""获取全局默认模型"""
return get_global_default_model_name()
@staticmethod
async def set_default_model(model_name_str: str) -> tuple[bool, str]:
"""设置全局默认模型"""
success = set_global_default_model_name(model_name_str)
if success:
return True, f"✅ 成功将默认模型设置为: {model_name_str}"
else:
return False, f"❌ 设置失败,模型 '{model_name_str}' 不存在或无效。"
@staticmethod
async def test_model_connectivity(model_name_str: str) -> tuple[bool, str]:
"""测试模型连通性"""
start_time = time.monotonic()
try:
async with await get_model_instance(model_name_str) as model:
await model.generate_text("你好")
end_time = time.monotonic()
latency = (end_time - start_time) * 1000
return (
True,
f"✅ 模型 '{model_name_str}' 连接成功!\n响应延迟: {latency:.2f} ms",
)
except LLMException as e:
return (
False,
f"❌ 模型 '{model_name_str}' 连接测试失败:\n"
f"{e.user_friendly_message}\n错误码: {e.code.name}",
)
except Exception as e:
return False, f"❌ 测试时发生未知错误: {e!s}"
@staticmethod
async def get_key_status(provider_name: str) -> list[dict[str, Any]] | None:
"""获取并排序指定提供商的API Key状态"""
from zhenxun.services.llm.manager import get_key_usage_stats
all_stats = await get_key_usage_stats()
provider_stats = all_stats.get(provider_name)
if not provider_stats or not provider_stats.get("key_stats"):
return None
key_stats_dict = provider_stats["key_stats"]
stats_list = [
{"key_id": key_id, **stats} for key_id, stats in key_stats_dict.items()
]
def sort_key(item: dict[str, Any]):
status_priority = item.get("status_enum", KeyStatus.UNUSED).value
return (
status_priority,
100 - item.get("success_rate", 100.0),
-item.get("total_calls", 0),
)
sorted_stats_list = sorted(stats_list, key=sort_key)
return sorted_stats_list
@staticmethod
async def reset_key(provider_name: str, api_key: str | None) -> tuple[bool, str]:
"""重置API Key状态"""
success = await reset_key_status(provider_name, api_key)
if success:
if api_key:
if len(api_key) > 8:
target = f"API Key '{api_key[:4]}...{api_key[-4:]}'"
else:
target = f"API Key '{api_key}'"
else:
target = "所有API Keys"
return True, f"✅ 成功重置提供商 '{provider_name}' {target} 的状态。"
else:
return False, "❌ 重置失败,请检查提供商名称或API Key是否正确。"
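
Worth noting: `sort_key` sorts ascending on a tuple, so the most broken keys (lowest `KeyStatus` value) surface first, and within the same status a higher success rate and higher call volume rank earlier. A self-contained sketch of that ordering, with sample data invented for illustration:

```python
from enum import IntEnum

class KeyStatus(IntEnum):
    # Mirrors zhenxun.services.llm.core.KeyStatus: lower = more broken.
    DISABLED = 0
    ERROR = 1
    COOLDOWN = 2
    WARNING = 3
    HEALTHY = 4
    UNUSED = 5

stats = [
    {"key_id": "AIza...01", "status_enum": KeyStatus.HEALTHY, "success_rate": 99.0, "total_calls": 120},
    {"key_id": "AIza...02", "status_enum": KeyStatus.ERROR, "success_rate": 40.0, "total_calls": 10},
    {"key_id": "AIza...03", "status_enum": KeyStatus.HEALTHY, "success_rate": 95.0, "total_calls": 300},
]

def sort_key(item):
    return (item["status_enum"].value, 100 - item["success_rate"], -item["total_calls"])

print([s["key_id"] for s in sorted(stats, key=sort_key)])
# -> ['AIza...02', 'AIza...01', 'AIza...03']
```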

View File

@@ -0,0 +1,204 @@
from typing import Any
from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.types import ModelModality
from zhenxun.utils._build_image import BuildImage
from zhenxun.utils._image_template import ImageTemplate, Markdown, RowStyle
def _format_seconds(seconds: int) -> str:
"""将秒数格式化为 'Xm Ys''Xh Ym' 的形式"""
if seconds <= 0:
return "0s"
if seconds < 60:
return f"{seconds}s"
minutes, seconds = divmod(seconds, 60)
if minutes < 60:
return f"{minutes}m {seconds}s"
hours, minutes = divmod(minutes, 60)
return f"{hours}h {minutes}m"
class Presenters:
"""格式化LLM管理插件的输出 (图片格式)"""
@staticmethod
async def format_model_list_as_image(
models: list[dict[str, Any]], show_all: bool
) -> BuildImage:
"""将模型列表格式化为表格图片"""
title = "📋 LLM模型列表" + (" (所有已配置模型)" if show_all else " (仅可用)")
if not models:
return await BuildImage.build_text_image(
f"{title}\n\n当前没有配置任何LLM模型。"
)
column_name = ["提供商", "模型名称", "API类型", "状态"]
data_list = []
for model in models:
status_text = "✅ 可用" if model.get("is_available", True) else "❌ 不可用"
embed_tag = " (Embed)" if model.get("is_embedding_model", False) else ""
data_list.append(
[
model.get("provider_name", "N/A"),
f"{model.get('model_name', 'N/A')}{embed_tag}",
model.get("api_type", "N/A"),
status_text,
]
)
return await ImageTemplate.table_page(
head_text=title,
tip_text="使用 `llm info <Provider/ModelName>` 查看详情",
column_name=column_name,
data_list=data_list,
)
@staticmethod
async def format_model_details_as_markdown_image(details: dict[str, Any]) -> bytes:
"""将模型详情格式化为Markdown图片"""
provider = details["provider_config"]
model = details["model_detail"]
caps = details["capabilities"]
cap_list = []
if ModelModality.IMAGE in caps.input_modalities:
cap_list.append("视觉")
if ModelModality.VIDEO in caps.input_modalities:
cap_list.append("视频")
if ModelModality.AUDIO in caps.input_modalities:
cap_list.append("音频")
if caps.supports_tool_calling:
cap_list.append("工具调用")
if caps.is_embedding_model:
cap_list.append("文本嵌入")
md = Markdown()
md.head(f"🔎 模型详情: {provider.name}/{model.model_name}", level=1)
md.text("---")
md.head("提供商信息", level=2)
md.list(
[
f"**名称**: {provider.name}",
f"**API 类型**: {provider.api_type}",
f"**API Base**: {provider.api_base or '默认'}",
]
)
md.head("模型详情", level=2)
temp_value = model.temperature or provider.temperature or "未设置"
token_value = model.max_tokens or provider.max_tokens or "未设置"
md.list(
[
f"**名称**: {model.model_name}",
f"**默认温度**: {temp_value}",
f"**最大Token**: {token_value}",
f"**核心能力**: {', '.join(cap_list) or '纯文本'}",
]
)
return await md.build()
@staticmethod
async def format_key_status_as_image(
provider_name: str, sorted_stats: list[dict[str, Any]]
) -> BuildImage:
"""将已排序的、详细的API Key状态格式化为表格图片"""
title = f"🔑 '{provider_name}' API Key 状态"
if not sorted_stats:
return await BuildImage.build_text_image(
f"{title}\n\n该提供商没有配置API Keys。"
)
def _status_row_style(column: str, text: str) -> RowStyle:
style = RowStyle()
if column == "状态":
if "✅ 健康" in text:
style.font_color = "#67C23A"
elif "⚠️ 告警" in text:
style.font_color = "#E6A23C"
elif "❌ 错误" in text or "🚫" in text:
style.font_color = "#F56C6C"
elif "❄️ 冷却中" in text:
style.font_color = "#409EFF"
elif column == "成功率":
try:
if text != "N/A":
rate = float(text.replace("%", ""))
if rate < 80:
style.font_color = "#F56C6C"
elif rate < 95:
style.font_color = "#E6A23C"
except (ValueError, TypeError):
pass
return style
column_name = [
"Key (部分)",
"状态",
"总调用",
"成功率",
"平均延迟(s)",
"上次错误",
"建议操作",
]
data_list = []
for key_info in sorted_stats:
status_enum: KeyStatus = key_info["status_enum"]
if status_enum == KeyStatus.COOLDOWN:
cooldown_seconds = int(key_info["cooldown_seconds_left"])
formatted_time = _format_seconds(cooldown_seconds)
status_text = f"❄️ 冷却中({formatted_time})"
else:
status_text = {
KeyStatus.DISABLED: "🚫 永久禁用",
KeyStatus.ERROR: "❌ 错误",
KeyStatus.WARNING: "⚠️ 告警",
KeyStatus.HEALTHY: "✅ 健康",
KeyStatus.UNUSED: "⚪️ 未使用",
}.get(status_enum, "❔ 未知")
total_calls = key_info["total_calls"]
total_calls_text = (
f"{key_info['success_count']}/{total_calls}"
if total_calls > 0
else "0/0"
)
success_rate = key_info["success_rate"]
success_rate_text = f"{success_rate:.1f}%" if total_calls > 0 else "N/A"
avg_latency = key_info["avg_latency"]
avg_latency_text = f"{avg_latency / 1000:.2f}" if avg_latency > 0 else "N/A"
last_error = key_info.get("last_error") or "-"
if len(last_error) > 25:
last_error = last_error[:22] + "..."
data_list.append(
[
key_info["key_id"],
status_text,
total_calls_text,
success_rate_text,
avg_latency_text,
last_error,
key_info["suggested_action"],
]
)
return await ImageTemplate.table_page(
head_text=title,
tip_text="使用 `llm reset-key <Provider>` 重置Key状态",
column_name=column_name,
data_list=data_list,
text_style=_status_row_style,
column_space=15,
)

View File

@@ -21,11 +21,28 @@ require("nonebot_plugin_waiter")
from .db_context import Model, disconnect
from .llm import (
AI,
AIConfig,
CommonOverrides,
LLMContentPart,
LLMException,
LLMGenerationConfig,
LLMMessage,
analyze,
analyze_multimodal,
chat,
clear_model_cache,
code,
create_multimodal_message,
embed,
generate,
get_cache_stats,
get_model_instance,
list_available_models,
list_embedding_models,
pipeline_chat,
search,
search_multimodal,
set_global_default_model_name,
tool_registry,
)
from .log import logger
@@ -34,16 +51,33 @@ from .scheduler import scheduler_manager
__all__ = [
"AI",
"AIConfig",
"CommonOverrides",
"LLMContentPart",
"LLMException",
"LLMGenerationConfig",
"LLMMessage",
"Model",
"PluginInit",
"PluginInitManager",
"analyze",
"analyze_multimodal",
"chat",
"clear_model_cache",
"code",
"create_multimodal_message",
"disconnect",
"embed",
"generate",
"get_cache_stats",
"get_model_instance",
"list_available_models",
"list_embedding_models",
"logger",
"pipeline_chat",
"scheduler_manager",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
]

View File

@@ -198,7 +198,7 @@ print(search_result['text'])
当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。
```python
from zhenxun.services.llm.api import AI, AIConfig
from zhenxun.services.llm import AI, AIConfig
# 初始化一个AI会话,可以传入自定义配置
ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7)
@@ -395,7 +395,7 @@ async def my_tool_factory(config: MyToolConfig):
`analyze` 或 `generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。
```python
from zhenxun.services.llm.api import analyze
from zhenxun.services.llm import analyze
from nonebot_plugin_alconna.uniseg import UniMessage
response = await analyze(
@@ -442,7 +442,6 @@ from zhenxun.services.llm.manager import (
get_key_usage_stats,
reset_key_status
)
from zhenxun.services.llm import clear_model_cache, get_cache_stats
# 列出所有在config.yaml中配置的可用模型
models = list_available_models()

View File

@@ -5,14 +5,12 @@ LLM 服务模块 - 公共 API 入口
"""
from .api import (
AI,
AIConfig,
TaskType,
analyze,
analyze_multimodal,
chat,
code,
embed,
generate,
pipeline_chat,
search,
search_multimodal,
@@ -35,6 +33,7 @@ from .manager import (
list_model_identifiers,
set_global_default_model_name,
)
from .session import AI, AIConfig
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
@@ -49,6 +48,7 @@ from .types import (
ModelInfo,
ModelProvider,
ResponseFormat,
TaskType,
ToolCategory,
ToolMetadata,
UsageInfo,
@@ -84,6 +84,7 @@ __all__ = [
"code",
"create_multimodal_message",
"embed",
"generate",
"get_cache_stats",
"get_global_default_model_name",
"get_model_instance",

View File

@@ -1,10 +1,7 @@
"""
LLM 服务的高级 API 接口
LLM 服务的高级 API 接口 - 便捷函数入口
"""
import copy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
@@ -12,10 +9,8 @@ from nonebot_plugin_alconna.uniseg import UniMessage
from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .manager import get_model_instance
from .session import AI
from .types import (
EmbeddingTaskType,
LLMContentPart,
@@ -29,514 +24,31 @@ from .utils import create_multimodal_message, unimsg_to_llm_parts
from .utils import create_multimodal_message, unimsg_to_llm_parts
class TaskType(Enum):
"""任务类型枚举"""
CHAT = "chat"
CODE = "code"
SEARCH = "search"
ANALYSIS = "analysis"
GENERATION = "generation"
MULTIMODAL = "multimodal"
@dataclass
class AIConfig:
"""AI配置类 - 简化版本"""
model: ModelName = None
default_embedding_model: ModelName = None
temperature: float | None = None
max_tokens: int | None = None
enable_cache: bool = False
enable_code: bool = False
enable_search: bool = False
timeout: int | None = None
enable_gemini_json_mode: bool = False
enable_gemini_thinking: bool = False
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
def __post_init__(self):
"""初始化后从配置中读取默认值"""
ai_config = get_ai_config()
if self.model is None:
self.model = ai_config.get("default_model_name")
if self.timeout is None:
self.timeout = ai_config.get("timeout", 180)
class AI:
"""统一的AI服务类 - 平衡设计版本
提供三层API:
1. 简单方法:ai.chat(), ai.code(), ai.search()
2. 标准方法:ai.analyze(),支持复杂参数
3. 高级方法:通过get_model_instance()直接访问
"""
def __init__(
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
):
"""
初始化AI服务
参数:
config: AI 配置.
history: 可选的初始对话历史.
"""
self.config = config or AIConfig()
self.history = history or []
def clear_history(self):
"""清空当前会话的历史记录"""
self.history = []
logger.info("AI session history cleared.")
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
净化用于存入历史记录的消息
将非文本的多模态内容部分替换为文本占位符以避免重复处理
"""
if not isinstance(message.content, list):
return message
sanitized_message = copy.deepcopy(message)
content_list = sanitized_message.content
if not isinstance(content_list, list):
return sanitized_message
new_content_parts: list[LLMContentPart] = []
has_multimodal_content = False
for part in content_list:
if isinstance(part, LLMContentPart) and part.type == "text":
new_content_parts.append(part)
else:
has_multimodal_content = True
if has_multimodal_content:
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
text_part_found = False
for part in new_content_parts:
if part.type == "text":
part.text = f"{placeholder} {part.text or ''}".strip()
text_part_found = True
break
if not text_part_found:
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
sanitized_message.content = new_content_parts
return sanitized_message
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
preserve_media_in_history: bool | None = None,
**kwargs: Any,
) -> str:
"""
进行一次聊天对话
此方法会自动使用和更新会话内的历史记录
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留,用于深度多轮媒体分析
- False: 不保留,替换为占位符以提高效率
- None (默认): 使用AI实例配置的默认值
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
"""
current_message: LLMMessage
if isinstance(message, str):
current_message = LLMMessage.user(message)
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
current_message = LLMMessage.user(message)
elif isinstance(message, LLMMessage):
current_message = message
else:
raise LLMException(
f"AI.chat 不支持的消息类型: {type(message)}. "
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
code=LLMErrorCode.API_REQUEST_FAILED,
)
final_messages = [*self.history, current_message]
response = await self._execute_generation(
final_messages, model, "聊天失败", kwargs
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(LLMMessage.assistant_text_response(response.text))
return response.text
async def code(
self,
prompt: str,
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
代码执行
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
messages = [LLMMessage.user(prompt)]
response = await self._execute_generation(
messages, resolved_model, "代码执行失败", kwargs, base_config=config
)
return {
"text": response.text,
"code_executions": response.code_executions or [],
"success": True,
}
async def search(
self,
query: str | UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
if isinstance(query, str):
messages = [LLMMessage.user(query)]
elif isinstance(query, UniMessage):
content_parts = await unimsg_to_llm_parts(query)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
messages = final_messages
else:
raise LLMException(
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(
messages, resolved_model, "信息搜索失败", kwargs, base_config=config
)
result = {
"text": response.text,
"sources": [],
"queries": [],
"success": True,
}
if response.grounding_metadata:
result["sources"] = response.grounding_metadata.grounding_attributions or []
result["queries"] = response.grounding_metadata.web_search_queries or []
return result
async def analyze(
self,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫
参数:
message: 要分析的消息内容支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
返回:
LLMResponse: 模型的完整响应结果
"""
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
final_messages,
model,
"内容分析失败",
kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
return response
async def _execute_generation(
self,
messages: list[LLMMessage],
model_name: ModelName,
error_message: str,
config_overrides: dict[str, Any],
llm_tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
)
final_config_dict = self._merge_config(
config_overrides, base_config=base_config
)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
except Exception as e:
logger.error(f"{error_message}: {e}", e=e)
raise LLMException(f"{error_message}: {e}", cause=e)
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
if model_name:
return model_name
default_model = get_global_default_model_name()
if default_model:
return default_model
raise LLMException(
"未指定模型名称且未设置全局默认模型",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
def _merge_config(
self,
user_config: dict[str, Any],
base_config: LLMGenerationConfig | None = None,
) -> dict[str, Any]:
"""合并配置"""
final_config = {}
if base_config:
final_config.update(base_config.to_dict())
if self.config.temperature is not None:
final_config["temperature"] = self.config.temperature
if self.config.max_tokens is not None:
final_config["max_tokens"] = self.config.max_tokens
if self.config.enable_cache:
final_config["enable_caching"] = True
if self.config.enable_code:
final_config["enable_code_execution"] = True
if self.config.enable_search:
final_config["enable_grounding"] = True
if self.config.enable_gemini_json_mode:
final_config["response_mime_type"] = "application/json"
if self.config.enable_gemini_thinking:
final_config["thinking_budget"] = 0.8
if self.config.enable_gemini_safe_mode:
final_config["safety_settings"] = (
CommonOverrides.gemini_safe().safety_settings
)
if self.config.enable_gemini_multimodal:
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
if self.config.enable_gemini_grounding:
final_config["enable_grounding"] = True
final_config.update(user_config)
return final_config
async def embed(
self,
texts: list[str] | str,
*,
model: ModelName = None,
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""
生成文本嵌入向量
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
try:
resolved_model_str = (
model or self.config.default_embedding_model or self.config.model
)
if not resolved_model_str:
raise LLMException(
"使用 embed 功能时必须指定嵌入模型名称,"
"或在 AIConfig 中配置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
async with await get_model_instance(
resolved_model_str,
override_config=None,
) as embedding_model_instance:
return await embedding_model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
async def chat(
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> str:
) -> LLMResponse:
"""
聊天对话便捷函数
参数:
message: 用户输入的消息
model: 要使用的模型名称
tools: 本次对话可用的工具列表
tool_choice: 强制模型使用的工具
**kwargs: 传递给模型的其他参数
返回:
str: 模型的文本响应
LLMResponse: 模型的完整响应,可能包含文本或工具调用请求
"""
ai = AI()
return await ai.chat(message, model=model, **kwargs)
return await ai.chat(
message, model=model, tools=tools, tool_choice=tool_choice, **kwargs
)
async def code(
@@ -730,12 +242,14 @@ async def pipeline_chat(
raise ValueError("模型链`model_chain`不能为空。")
current_content: str | list[LLMContentPart]
if isinstance(message, str):
if isinstance(message, UniMessage):
current_content = await unimsg_to_llm_parts(message)
elif isinstance(message, str):
current_content = message
elif isinstance(message, list):
current_content = message
else:
current_content = await unimsg_to_llm_parts(message)
raise TypeError(f"不支持的消息类型: {type(message)}")
final_response: LLMResponse | None = None
@@ -787,3 +301,45 @@
)
return final_response
async def generate(
messages: list[LLMMessage],
*,
model: ModelName = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
根据完整的消息列表(包括系统指令)生成一次性响应。
这是一个便捷的函数,不使用或修改任何会话历史。
参数:
messages: 用于生成响应的完整消息列表
model: 要使用的模型名称
tools: 可用的工具列表
tool_choice: 工具选择策略
**kwargs: 传递给模型的其他参数
返回:
LLMResponse: 模型的完整响应对象
"""
try:
ai_instance = AI()
resolved_model_name = ai_instance._resolve_model_name(model)
final_config_dict = ai_instance._merge_config(kwargs)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages,
tools=tools,
tool_choice=tool_choice,
)
except LLMException:
raise
except Exception as e:
logger.error(f"生成响应失败: {e}", e=e)
raise LLMException(f"生成响应失败: {e}", cause=e)

View File

@@ -17,6 +17,7 @@ from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ..core import key_store
from ..types.models import ModelDetail, ProviderConfig
@@ -502,12 +503,13 @@ def set_default_model(provider_model_name: str | None) -> bool:
@PriorityLifecycle.on_startup(priority=10)
async def _init_llm_config_on_startup():
"""
在服务启动时主动调用一次 get_llm_config
以触发必要的初始化操作例如创建默认的 mcp_tools.json 文件
在服务启动时主动调用一次 get_llm_config 和 key_store.initialize
以触发必要的初始化操作
"""
logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...")
logger.info("正在初始化 LLM 配置并加载密钥状态...")
try:
get_llm_config()
logger.info("LLM 配置初始化完成。")
await key_store.initialize()
logger.info("LLM 配置和密钥状态初始化完成。")
except Exception as e:
logger.error(f"LLM 配置初始化时发生错误: {e}", e=e)
logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e)

View File

@@ -5,17 +5,27 @@ LLM 核心基础设施模块
"""
import asyncio
from dataclasses import asdict, dataclass
from enum import IntEnum
import json
import os
import time
from typing import Any
import aiofiles
import httpx
import nonebot
from pydantic import BaseModel
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.user_agent import get_user_agent
from .types import ProviderConfig
from .types.exceptions import LLMErrorCode, LLMException
driver = nonebot.get_driver()
class HttpClientConfig(BaseModel):
"""HTTP客户端配置"""
@@ -194,6 +204,82 @@ async def create_llm_http_client(
return LLMHttpClient(config)
class KeyStatus(IntEnum):
"""用于排序和展示的密钥状态枚举"""
DISABLED = 0
ERROR = 1
COOLDOWN = 2
WARNING = 3
HEALTHY = 4
UNUSED = 5
@dataclass
class KeyStats:
"""单个API Key的详细状态和统计信息"""
cooldown_until: float = 0.0
success_count: int = 0
failure_count: int = 0
total_latency: float = 0.0
last_error_info: str | None = None
@property
def is_available(self) -> bool:
"""检查Key当前是否可用"""
return time.time() >= self.cooldown_until
@property
def avg_latency(self) -> float:
"""计算平均延迟"""
return (
self.total_latency / self.success_count if self.success_count > 0 else 0.0
)
@property
def success_rate(self) -> float:
"""计算成功率"""
total = self.success_count + self.failure_count
return self.success_count / total * 100 if total > 0 else 100.0
@property
def status(self) -> KeyStatus:
"""根据当前统计数据动态计算状态"""
now = time.time()
cooldown_left = max(0, self.cooldown_until - now)
if cooldown_left > 31536000 - 60:
return KeyStatus.DISABLED
if cooldown_left > 0:
return KeyStatus.COOLDOWN
total_calls = self.success_count + self.failure_count
if total_calls == 0:
return KeyStatus.UNUSED
if self.success_rate < 80:
return KeyStatus.ERROR
if total_calls >= 5 and self.avg_latency > 15000:
return KeyStatus.WARNING
return KeyStatus.HEALTHY
@property
def suggested_action(self) -> str:
"""根据状态给出建议操作"""
status_actions = {
KeyStatus.DISABLED: "更换Key",
KeyStatus.ERROR: "检查网络/重置",
KeyStatus.COOLDOWN: "等待/重置",
KeyStatus.WARNING: "观察",
KeyStatus.HEALTHY: "-",
KeyStatus.UNUSED: "-",
}
return status_actions.get(self.status, "未知")
class RetryConfig:
"""重试配置"""
@@ -236,26 +322,38 @@ async def with_smart_retry(
last_exception: Exception | None = None
failed_keys: set[str] = set()
model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None)
all_provider_keys = model_instance.api_keys if model_instance else []
for attempt in range(config.max_retries + 1):
try:
if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
kwargs["failed_keys"] = failed_keys
return await func(*args, **kwargs)
start_time = time.monotonic()
result = await func(*args, **kwargs)
latency = (time.monotonic() - start_time) * 1000
if key_store and isinstance(result, tuple) and len(result) == 2:
final_result, api_key_used = result
if api_key_used:
await key_store.record_success(api_key_used, latency)
return final_result
else:
return result
except LLMException as e:
last_exception = e
api_key_in_use = e.details.get("api_key")
if e.code in [
LLMErrorCode.API_KEY_INVALID,
LLMErrorCode.API_QUOTA_EXCEEDED,
]:
if hasattr(e, "details") and e.details and "api_key" in e.details:
failed_keys.add(e.details["api_key"])
if key_store and provider_name:
await key_store.record_failure(
e.details["api_key"], e.details.get("status_code")
)
if api_key_in_use:
failed_keys.add(api_key_in_use)
if key_store and provider_name and len(all_provider_keys) > 1:
status_code = e.details.get("status_code")
error_message = f"({e.code.name}) {e.message}"
await key_store.record_failure(
api_key_in_use, status_code, error_message
)
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
if not should_retry:
@@ -267,7 +365,7 @@
if config.exponential_backoff:
wait_time *= 2**attempt
logger.warning(
f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}"
)
await asyncio.sleep(wait_time)
else:
@@ -325,14 +423,66 @@ def _should_retry_llm_error(
class KeyStatusStore:
"""API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""
"""API Key 状态管理存储 - 支持持久化"""
def __init__(self):
self._key_status: dict[str, bool] = {}
self._key_usage_count: dict[str, int] = {}
self._key_last_used: dict[str, float] = {}
self._key_stats: dict[str, KeyStats] = {}
self._provider_key_index: dict[str, int] = {}
self._lock = asyncio.Lock()
self._file_path = DATA_PATH / "llm" / "key_status.json"
async def initialize(self):
"""从文件异步加载密钥状态,在应用启动时调用"""
async with self._lock:
if not self._file_path.exists():
logger.info("未找到密钥状态文件,将使用内存状态启动。")
return
try:
logger.info(f"正在从 {self._file_path} 加载密钥状态...")
async with aiofiles.open(self._file_path, encoding="utf-8") as f:
content = await f.read()
if not content:
logger.warning("密钥状态文件为空。")
return
data = json.loads(content)
for key, stats_dict in data.items():
self._key_stats[key] = KeyStats(**stats_dict)
logger.info(f"成功加载 {len(self._key_stats)} 个密钥的状态。")
except json.JSONDecodeError:
logger.error(f"密钥状态文件 {self._file_path} 格式错误,无法解析。")
except Exception as e:
logger.error(f"加载密钥状态文件时发生错误: {e}", e=e)
async def _save_to_file_internal(self):
"""
[内部方法] 将当前密钥状态安全地写入JSON文件
假定调用方已持有锁
"""
data_to_save = {key: asdict(stats) for key, stats in self._key_stats.items()}
try:
self._file_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = self._file_path.with_suffix(".json.tmp")
async with aiofiles.open(temp_path, "w", encoding="utf-8") as f:
await f.write(json.dumps(data_to_save, ensure_ascii=False, indent=2))
if self._file_path.exists():
self._file_path.unlink()
os.rename(temp_path, self._file_path)
logger.debug("密钥状态已成功持久化到文件。")
except Exception as e:
logger.error(f"保存密钥状态到文件失败: {e}", e=e)
async def shutdown(self):
"""在应用关闭时安全地保存状态"""
async with self._lock:
await self._save_to_file_internal()
logger.info("KeyStatusStore 已在关闭前保存状态。")
async def get_next_available_key(
self,
@@ -355,88 +505,122 @@
return None
exclude_keys = exclude_keys or set()
available_keys = [
key
for key in api_keys
if key not in exclude_keys and self._key_status.get(key, True)
]
if not available_keys:
return api_keys[0] if api_keys else None
async with self._lock:
for key in api_keys:
if key not in self._key_stats:
self._key_stats[key] = KeyStats()
available_keys = [
key
for key in api_keys
if key not in exclude_keys and self._key_stats[key].is_available
]
if not available_keys:
return api_keys[0]
current_index = self._provider_key_index.get(provider_name, 0)
selected_key = available_keys[current_index % len(available_keys)]
self._provider_key_index[provider_name] = current_index + 1
self._provider_key_index[provider_name] = (current_index + 1) % len(
available_keys
)
total_usage = (
self._key_stats[selected_key].success_count
+ self._key_stats[selected_key].failure_count
)
import time
self._key_usage_count[selected_key] = (
self._key_usage_count.get(selected_key, 0) + 1
)
self._key_last_used[selected_key] = time.time()
logger.debug(
f"轮询选择API密钥: {self._get_key_id(selected_key)} "
f"(使用次数: {self._key_usage_count[selected_key]})"
f"(使用次数: {total_usage})"
)
return selected_key
async def record_success(self, api_key: str):
"""记录成功使用"""
async def record_success(self, api_key: str, latency: float):
"""记录成功使用,并持久化"""
async with self._lock:
self._key_status[api_key] = True
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.cooldown_until = 0.0
stats.success_count += 1
stats.total_latency += latency
stats.last_error_info = None
await self._save_to_file_internal()
logger.debug(
f"记录API密钥成功使用: {self._get_key_id(api_key)}, 延迟: {latency:.2f}ms"
)
async def record_failure(self, api_key: str, status_code: int | None):
async def record_failure(
self, api_key: str, status_code: int | None, error_message: str
):
"""
记录失败使用
记录失败使用,并设置冷却时间
参数:
api_key: API密钥
status_code: HTTP状态码
error_message: 错误信息
"""
key_id = self._get_key_id(api_key)
now = time.time()
cooldown_duration = 300
if status_code in [401, 403, 404]:
cooldown_duration = 31536000
log_level = "error"
log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}"
elif status_code == 429:
cooldown_duration = 60
log_level = "warning"
log_message = f"API密钥被限流冷却60秒: {key_id}"
else:
log_level = "warning"
log_message = f"API密钥遇到临时性错误冷却{cooldown_duration}秒: {key_id}"
async with self._lock:
if status_code in [401, 403]:
self._key_status[api_key] = False
logger.warning(
f"API密钥认证失败标记为不可用: {key_id} (状态码: {status_code})"
)
else:
logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.cooldown_until = now + cooldown_duration
stats.failure_count += 1
stats.last_error_info = error_message[:256]
await self._save_to_file_internal()
getattr(logger, log_level)(log_message)
async def reset_key_status(self, api_key: str):
"""重置密钥状态(用于恢复机制)"""
"""重置密钥状态,并持久化"""
async with self._lock:
self._key_status[api_key] = True
stats = self._key_stats.setdefault(api_key, KeyStats())
stats.cooldown_until = 0.0
stats.last_error_info = None
await self._save_to_file_internal()
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""
获取密钥使用统计
获取密钥使用统计并计算出用于展示的派生数据
参数:
api_keys: API密钥列表
返回:
dict[str, dict]: 密钥统计信息字典
dict[str, dict]: 包含丰富状态和统计信息的密钥字典
"""
stats = {}
stats_dict = {}
now = time.time()
async with self._lock:
for key in api_keys:
key_id = self._get_key_id(key)
stats[key_id] = {
"available": self._key_status.get(key, True),
"usage_count": self._key_usage_count.get(key, 0),
"last_used": self._key_last_used.get(key, 0),
stats = self._key_stats.get(key, KeyStats())
stats_dict[key_id] = {
"status_enum": stats.status,
"cooldown_seconds_left": max(0, stats.cooldown_until - now),
"total_calls": stats.success_count + stats.failure_count,
"success_count": stats.success_count,
"failure_count": stats.failure_count,
"success_rate": stats.success_rate,
"avg_latency": stats.avg_latency,
"last_error": stats.last_error_info,
"suggested_action": stats.suggested_action,
}
return stats
return stats_dict
def _get_key_id(self, api_key: str) -> str:
"""获取API密钥的标识符用于日志"""
@@ -446,3 +630,8 @@ class KeyStatusStore:
key_store = KeyStatusStore()
@driver.on_shutdown
async def _shutdown_key_store():
await key_store.shutdown()
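
The key change in `with_smart_retry` is its contract with `_perform_api_call`: the wrapped call now returns a `(parsed_result, api_key)` tuple, the wrapper times it, records the latency against that key, and hands only the parsed result back to the caller. A stripped-down, runnable sketch of that flow (the fake call and store are stand-ins, not real adapter code):

```python
import asyncio
import time

class FakeKeyStore:
    async def record_success(self, api_key: str, latency: float):
        print(f"key {api_key[:4]}****: success, {latency:.2f} ms")

async def fake_api_call() -> tuple[dict, str]:
    # Stand-in for LLMModel._perform_api_call, which now returns the
    # parsed response together with the API key that produced it.
    await asyncio.sleep(0.05)
    return {"text": "ok"}, "sk-demo-key"

async def retry_like(func, key_store):
    start = time.monotonic()
    result = await func()
    latency = (time.monotonic() - start) * 1000  # milliseconds
    if key_store and isinstance(result, tuple) and len(result) == 2:
        final_result, api_key_used = result
        if api_key_used:
            await key_store.record_success(api_key_used, latency)
        return final_result  # callers still see only the parsed response
    return result

asyncio.run(retry_like(fake_api_call, FakeKeyStore()))
```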

View File

@@ -137,8 +137,8 @@ def get_configured_providers() -> list[ProviderConfig]:
valid_providers.append(item)
else:
logger.warning(
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,"
f"已跳过。实际类型: {type(item)}"
f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,已跳过。"
f"实际类型: {type(item)}"
)
return valid_providers

View File

@@ -46,17 +46,7 @@ class LLMModelBase(ABC):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""
生成文本
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
"""生成文本"""
pass
@abstractmethod
@@ -68,19 +58,7 @@
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
"""生成高级响应"""
pass
@abstractmethod
@@ -90,17 +68,7 @@
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
"""生成文本嵌入向量"""
pass
@@ -208,28 +176,8 @@ class LLMModel(LLMModelBase):
http_client: "LLMHttpClient",
failed_keys: set[str] | None = None,
log_context: str = "API",
) -> Any:
"""
执行API调用的通用核心方法
该方法封装了以下通用逻辑:
1. 选择API密钥
2. 准备和记录请求
3. 发送HTTP POST请求
4. 处理HTTP错误和API特定错误
5. 记录密钥使用状态
6. 解析成功的响应
参数:
prepare_request_func: 准备请求的函数
parse_response_func: 解析响应的函数
http_client: HTTP客户端
failed_keys: 失败的密钥集合
log_context: 日志上下文
返回:
Any: 解析后的响应数据
"""
) -> tuple[Any, str]:
"""执行API调用的通用核心方法"""
api_key = await self._select_api_key(failed_keys)
try:
@@ -267,7 +215,9 @@
)
logger.debug(f"💥 完整错误响应: {error_text}")
await self.key_store.record_failure(api_key, http_response.status_code)
await self.key_store.record_failure(
api_key, http_response.status_code, error_text
)
if http_response.status_code in [401, 403]:
error_code = LLMErrorCode.API_KEY_INVALID
@@ -298,7 +248,7 @@
except Exception as e:
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
await self.key_store.record_failure(api_key, None, str(e))
if isinstance(e, LLMException):
raise
else:
@@ -308,17 +258,15 @@
cause=e,
)
await self.key_store.record_success(api_key)
logger.debug(f"✅ API密钥使用成功: {masked_key}")
logger.info(f"🎯 LLM响应解析完成 [{log_context}]")
return parsed_data
return parsed_data, api_key
except LLMException:
raise
except Exception as e:
error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}"
logger.error(error_log_msg, e=e)
await self.key_store.record_failure(api_key, None)
await self.key_store.record_failure(api_key, None, str(e))
raise LLMException(
error_log_msg,
code=LLMErrorCode.GENERATION_FAILED
@@ -349,13 +297,14 @@
adapter.validate_embedding_response(response_json)
return adapter.parse_embedding_response(response_json)
return await self._perform_api_call(
parsed_data, api_key_used = await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Embedding",
)
return parsed_data
async def _execute_with_smart_retry(
self,
@@ -394,8 +343,8 @@
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
) -> tuple[LLMResponse, str]:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse 和使用的 key"""
async def prepare_request(api_key: str) -> RequestData:
return await adapter.prepare_advanced_request(
@@ -441,19 +390,17 @@
cache_info=response_data.cache_info,
)
return await self._perform_api_call(
parsed_data, api_key_used = await self._perform_api_call(
prepare_request_func=prepare_request,
parse_response_func=parse_response,
http_client=http_client,
failed_keys=failed_keys,
log_context="Generation",
)
return parsed_data, api_key_used
async def close(self):
"""
标记模型实例的当前使用周期结束
共享的 HTTP 客户端由 LLMHttpClientManager 管理不由 LLMModel 关闭
"""
"""标记模型实例的当前使用周期结束"""
if self._is_closed:
return
self._is_closed = True
@@ -487,17 +434,7 @@
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""
生成文本 - 通过 generate_response 实现
参数:
prompt: 输入提示词
history: 对话历史记录
**kwargs: 其他参数
返回:
str: 生成的文本
"""
"""生成文本"""
self._check_not_closed()
messages: list[LLMMessage] = []
@@ -538,19 +475,7 @@
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
生成高级响应
参数:
messages: 消息列表
config: 生成配置
tools: 工具列表
tool_choice: 工具选择策略
**kwargs: 其他参数
返回:
LLMResponse: 模型响应
"""
"""生成高级响应"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
@@ -619,17 +544,7 @@
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""
生成文本嵌入向量
参数:
texts: 文本列表
task_type: 嵌入任务类型
**kwargs: 其他参数
返回:
list[list[float]]: 嵌入向量列表
"""
"""生成文本嵌入向量"""
self._check_not_closed()
if not texts:
return []

View File

@@ -0,0 +1,532 @@
"""
LLM 服务 - 会话客户端
提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互
"""
import copy
from dataclasses import dataclass
from typing import Any
from nonebot_plugin_alconna.uniseg import UniMessage
from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .types import (
EmbeddingTaskType,
LLMContentPart,
LLMErrorCode,
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelName,
)
from .utils import unimsg_to_llm_parts
@dataclass
class AIConfig:
"""AI配置类 - 简化版本"""
model: ModelName = None
default_embedding_model: ModelName = None
temperature: float | None = None
max_tokens: int | None = None
enable_cache: bool = False
enable_code: bool = False
enable_search: bool = False
timeout: int | None = None
enable_gemini_json_mode: bool = False
enable_gemini_thinking: bool = False
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
def __post_init__(self):
"""初始化后从配置中读取默认值"""
ai_config = get_ai_config()
if self.model is None:
self.model = ai_config.get("default_model_name")
if self.timeout is None:
self.timeout = ai_config.get("timeout", 180)
class AI:
"""统一的AI服务类 - 平衡设计版本
提供三层API:
1. 简单方法:ai.chat(), ai.code(), ai.search()
2. 标准方法:ai.analyze(),支持复杂参数
3. 高级方法:通过get_model_instance()直接访问
"""
def __init__(
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
):
"""
初始化AI服务
参数:
config: AI 配置.
history: 可选的初始对话历史.
"""
self.config = config or AIConfig()
self.history = history or []
def clear_history(self):
"""清空当前会话的历史记录"""
self.history = []
logger.info("AI session history cleared.")
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
净化用于存入历史记录的消息
将非文本的多模态内容部分替换为文本占位符以避免重复处理
"""
if not isinstance(message.content, list):
return message
sanitized_message = copy.deepcopy(message)
content_list = sanitized_message.content
if not isinstance(content_list, list):
return sanitized_message
new_content_parts: list[LLMContentPart] = []
has_multimodal_content = False
for part in content_list:
if isinstance(part, LLMContentPart) and part.type == "text":
new_content_parts.append(part)
else:
has_multimodal_content = True
if has_multimodal_content:
placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]"
text_part_found = False
for part in new_content_parts:
if part.type == "text":
part.text = f"{placeholder} {part.text or ''}".strip()
text_part_found = True
break
if not text_part_found:
new_content_parts.insert(0, LLMContentPart.text_part(placeholder))
sanitized_message.content = new_content_parts
return sanitized_message
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
preserve_media_in_history: bool | None = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
进行一次聊天对话,支持工具调用
此方法会自动使用和更新会话内的历史记录
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留,用于深度多轮媒体分析
- False: 不保留,替换为占位符以提高效率
- None (默认): 使用AI实例配置的默认值
tools: 本次对话可用的工具列表
tool_choice: 强制模型使用的工具
**kwargs: 传递给模型的其他生成参数
返回:
LLMResponse: 模型的完整响应,可能包含文本或工具调用请求
"""
current_message: LLMMessage
if isinstance(message, str):
current_message = LLMMessage.user(message)
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
current_message = LLMMessage.user(message)
elif isinstance(message, LLMMessage):
current_message = message
else:
raise LLMException(
f"AI.chat 不支持的消息类型: {type(message)}. "
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
code=LLMErrorCode.API_REQUEST_FAILED,
)
final_messages = [*self.history, current_message]
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="聊天失败",
config_overrides=kwargs,
llm_tools=tools,
tool_choice=tool_choice,
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(
LLMMessage(
role="assistant", content=response.text, tool_calls=response.tool_calls
)
)
return response
async def code(
self,
prompt: str,
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
代码执行
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
messages = [LLMMessage.user(prompt)]
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="代码执行失败",
config_overrides=kwargs,
base_config=config,
)
return {
"text": response.text,
"code_executions": response.code_executions or [],
"success": True,
}
async def search(
self,
query: str | UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段
"""
from nonebot_plugin_alconna.uniseg import UniMessage
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
if isinstance(query, str):
messages = [LLMMessage.user(query)]
elif isinstance(query, UniMessage):
content_parts = await unimsg_to_llm_parts(query)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
messages = final_messages
else:
raise LLMException(
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="信息搜索失败",
config_overrides=kwargs,
base_config=config,
)
result = {
"text": response.text,
"sources": [],
"queries": [],
"success": True,
}
if response.grounding_metadata:
result["sources"] = response.grounding_metadata.grounding_attributions or []
result["queries"] = response.grounding_metadata.web_search_queries or []
return result
async def analyze(
self,
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫
参数:
message: 要分析的消息内容支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
返回:
LLMResponse: 模型的完整响应结果
"""
from nonebot_plugin_alconna.uniseg import UniMessage
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="内容分析失败",
config_overrides=kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
return response
async def _execute_generation(
self,
messages: list[LLMMessage],
model_name: ModelName,
error_message: str,
config_overrides: dict[str, Any],
llm_tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
)
final_config_dict = self._merge_config(
config_overrides, base_config=base_config
)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
except Exception as e:
logger.error(f"{error_message}: {e}", e=e)
raise LLMException(f"{error_message}: {e}", cause=e)
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
if model_name:
return model_name
default_model = get_global_default_model_name()
if default_model:
return default_model
raise LLMException(
"未指定模型名称且未设置全局默认模型",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
def _merge_config(
self,
user_config: dict[str, Any],
base_config: LLMGenerationConfig | None = None,
) -> dict[str, Any]:
"""合并配置"""
final_config = {}
if base_config:
final_config.update(base_config.to_dict())
if self.config.temperature is not None:
final_config["temperature"] = self.config.temperature
if self.config.max_tokens is not None:
final_config["max_tokens"] = self.config.max_tokens
if self.config.enable_cache:
final_config["enable_caching"] = True
if self.config.enable_code:
final_config["enable_code_execution"] = True
if self.config.enable_search:
final_config["enable_grounding"] = True
if self.config.enable_gemini_json_mode:
final_config["response_mime_type"] = "application/json"
if self.config.enable_gemini_thinking:
final_config["thinking_budget"] = 0.8
if self.config.enable_gemini_safe_mode:
final_config["safety_settings"] = (
CommonOverrides.gemini_safe().safety_settings
)
if self.config.enable_gemini_multimodal:
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
if self.config.enable_gemini_grounding:
final_config["enable_grounding"] = True
final_config.update(user_config)
return final_config
async def embed(
self,
texts: list[str] | str,
*,
model: ModelName = None,
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""
生成文本嵌入向量
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
返回:
list[list[float]]: 文本的嵌入向量列表
"""
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
try:
resolved_model_str = (
model or self.config.default_embedding_model or self.config.model
)
if not resolved_model_str:
raise LLMException(
"使用 embed 功能时必须指定嵌入模型名称,"
"或在 AIConfig 中配置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
async with await get_model_instance(
resolved_model_str,
override_config=None,
) as embedding_model_instance:
return await embedding_model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
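
Finally, a minimal multi-turn sketch against the relocated AI session class; the model name and temperature mirror the documentation example above, and the prompts are placeholders:

```python
from zhenxun.services.llm import AI, AIConfig

async def session_demo():
    ai = AI(AIConfig(model="GLM/glm-4-flash", temperature=0.7))
    first = await ai.chat("我叫小明")            # LLMResponse; history grows
    second = await ai.chat("我刚才说我叫什么?")  # sees the previous turn
    print(first.text, second.text)
    ai.clear_history()  # drop the accumulated context
```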

View File

@@ -10,7 +10,13 @@ from .content import (
LLMMessage,
LLMResponse,
)
from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory
from .enums import (
EmbeddingTaskType,
ModelProvider,
ResponseFormat,
TaskType,
ToolCategory,
)
from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
from .models import (
LLMCacheInfo,
@@ -52,6 +58,7 @@ __all__ = [
"ModelProvider",
"ProviderConfig",
"ResponseFormat",
"TaskType",
"ToolCategory",
"ToolMetadata",
"UsageInfo",

View File

@@ -45,6 +45,17 @@ class ToolCategory(Enum):
CUSTOM = auto()
class TaskType(Enum):
"""任务类型枚举"""
CHAT = "chat"
CODE = "code"
SEARCH = "search"
ANALYSIS = "analysis"
GENERATION = "generation"
MULTIMODAL = "multimodal"
class LLMErrorCode(Enum):
"""LLM 服务相关的错误代码枚举"""