diff --git a/zhenxun/builtin_plugins/llm_manager/__init__.py b/zhenxun/builtin_plugins/llm_manager/__init__.py new file mode 100644 index 00000000..de0e0caf --- /dev/null +++ b/zhenxun/builtin_plugins/llm_manager/__init__.py @@ -0,0 +1,171 @@ +from nonebot.permission import SUPERUSER +from nonebot.plugin import PluginMetadata +from nonebot_plugin_alconna import ( + Alconna, + Args, + Arparma, + Match, + Option, + Query, + Subcommand, + on_alconna, + store_true, +) + +from zhenxun.configs.utils import PluginExtraData +from zhenxun.services.log import logger +from zhenxun.utils.enum import PluginType +from zhenxun.utils.message import MessageUtils + +from .data_source import DataSource +from .presenters import Presenters + +__plugin_meta__ = PluginMetadata( + name="LLM模型管理", + description="查看和管理大语言模型服务。", + usage=""" + LLM模型管理 (SUPERUSER) + + llm list [--all] + - 查看可用模型列表。 + - --all: 显示包括不可用在内的所有模型。 + + llm info + - 查看指定模型的详细信息和能力。 + + llm default [Provider/ModelName] + - 查看或设置全局默认模型。 + - 不带参数: 查看当前默认模型。 + - 带参数: 设置新的默认模型。 + - 例子: llm default Gemini/gemini-2.0-flash + + llm test + - 测试指定模型的连通性和API Key有效性。 + + llm keys + - 查看指定提供商的所有API Key状态。 + + llm reset-key [--key ] + - 重置提供商的所有或指定API Key的失败状态。 + """, + extra=PluginExtraData( + author="HibiKier", + version="1.0.0", + plugin_type=PluginType.SUPERUSER, + ).to_dict(), +) + +llm_cmd = on_alconna( + Alconna( + "llm", + Subcommand("list", alias=["ls"], help_text="查看模型列表"), + Subcommand("info", Args["model_name", str], help_text="查看模型详情"), + Subcommand("default", Args["model_name?", str], help_text="查看或设置默认模型"), + Subcommand( + "test", Args["model_name", str], alias=["ping"], help_text="测试模型连通性" + ), + Subcommand("keys", Args["provider_name", str], help_text="查看API密钥状态"), + Subcommand( + "reset-key", + Args["provider_name", str], + Option("--key", Args["api_key", str], help_text="指定要重置的API Key"), + help_text="重置API Key状态", + ), + Option("--all", action=store_true, help_text="显示所有条目"), + ), + permission=SUPERUSER, + priority=5, + block=True, +) + + +@llm_cmd.assign("list") +async def handle_list(arp: Arparma, show_all: Query[bool] = Query("all")): + """处理 'llm list' 命令""" + logger.info("获取LLM模型列表", command="LLM Manage", session=arp.header_result) + models = await DataSource.get_model_list(show_all=show_all.result) + + image = await Presenters.format_model_list_as_image(models, show_all.result) + await llm_cmd.finish(MessageUtils.build_message(image)) + + +@llm_cmd.assign("info") +async def handle_info(arp: Arparma, model_name: Match[str]): + """处理 'llm info' 命令""" + logger.info( + f"获取模型详情: {model_name.result}", + command="LLM Manage", + session=arp.header_result, + ) + details = await DataSource.get_model_details(model_name.result) + if not details: + await llm_cmd.finish(f"未找到模型: {model_name.result}") + + image_bytes = await Presenters.format_model_details_as_markdown_image(details) + await llm_cmd.finish(MessageUtils.build_message(image_bytes)) + + +@llm_cmd.assign("default") +async def handle_default(arp: Arparma, model_name: Match[str]): + """处理 'llm default' 命令""" + if model_name.available: + logger.info( + f"设置默认模型为: {model_name.result}", + command="LLM Manage", + session=arp.header_result, + ) + success, message = await DataSource.set_default_model(model_name.result) + await llm_cmd.finish(message) + else: + logger.info("查看默认模型", command="LLM Manage", session=arp.header_result) + current_default = await DataSource.get_default_model() + await llm_cmd.finish(f"当前全局默认模型为: {current_default or '未设置'}") + + +@llm_cmd.assign("test") +async def 
handle_test(arp: Arparma, model_name: Match[str]): + """处理 'llm test' 命令""" + logger.info( + f"测试模型连通性: {model_name.result}", + command="LLM Manage", + session=arp.header_result, + ) + await llm_cmd.send(f"正在测试模型 '{model_name.result}',请稍候...") + + success, message = await DataSource.test_model_connectivity(model_name.result) + await llm_cmd.finish(message) + + +@llm_cmd.assign("keys") +async def handle_keys(arp: Arparma, provider_name: Match[str]): + """处理 'llm keys' 命令""" + logger.info( + f"查看提供商API Key状态: {provider_name.result}", + command="LLM Manage", + session=arp.header_result, + ) + sorted_stats = await DataSource.get_key_status(provider_name.result) + if not sorted_stats: + await llm_cmd.finish( + f"未找到提供商 '{provider_name.result}' 或其没有配置API Keys。" + ) + + image = await Presenters.format_key_status_as_image( + provider_name.result, sorted_stats + ) + await llm_cmd.finish(MessageUtils.build_message(image)) + + +@llm_cmd.assign("reset-key") +async def handle_reset_key( + arp: Arparma, provider_name: Match[str], api_key: Match[str] +): + """处理 'llm reset-key' 命令""" + key_to_reset = api_key.result if api_key.available else None + log_msg = f"重置 {provider_name.result} 的 " + ( + "指定API Key" if key_to_reset else "所有API Keys" + ) + logger.info(log_msg, command="LLM Manage", session=arp.header_result) + + success, message = await DataSource.reset_key(provider_name.result, key_to_reset) + await llm_cmd.finish(message) diff --git a/zhenxun/builtin_plugins/llm_manager/data_source.py b/zhenxun/builtin_plugins/llm_manager/data_source.py new file mode 100644 index 00000000..0fc19e99 --- /dev/null +++ b/zhenxun/builtin_plugins/llm_manager/data_source.py @@ -0,0 +1,120 @@ +import time +from typing import Any + +from zhenxun.services.llm import ( + LLMException, + get_global_default_model_name, + get_model_instance, + list_available_models, + set_global_default_model_name, +) +from zhenxun.services.llm.core import KeyStatus +from zhenxun.services.llm.manager import ( + reset_key_status, +) + + +class DataSource: + """LLM管理插件的数据源和业务逻辑""" + + @staticmethod + async def get_model_list(show_all: bool = False) -> list[dict[str, Any]]: + """获取模型列表""" + models = list_available_models() + if show_all: + return models + return [m for m in models if m.get("is_available", True)] + + @staticmethod + async def get_model_details(model_name_str: str) -> dict[str, Any] | None: + """获取指定模型的详细信息""" + try: + model = await get_model_instance(model_name_str) + return { + "provider_config": model.provider_config, + "model_detail": model.model_detail, + "capabilities": model.capabilities, + } + except LLMException: + return None + + @staticmethod + async def get_default_model() -> str | None: + """获取全局默认模型""" + return get_global_default_model_name() + + @staticmethod + async def set_default_model(model_name_str: str) -> tuple[bool, str]: + """设置全局默认模型""" + success = set_global_default_model_name(model_name_str) + if success: + return True, f"✅ 成功将默认模型设置为: {model_name_str}" + else: + return False, f"❌ 设置失败,模型 '{model_name_str}' 不存在或无效。" + + @staticmethod + async def test_model_connectivity(model_name_str: str) -> tuple[bool, str]: + """测试模型连通性""" + start_time = time.monotonic() + try: + async with await get_model_instance(model_name_str) as model: + await model.generate_text("你好") + end_time = time.monotonic() + latency = (end_time - start_time) * 1000 + return ( + True, + f"✅ 模型 '{model_name_str}' 连接成功!\n响应延迟: {latency:.2f} ms", + ) + except LLMException as e: + return ( + False, + f"❌ 模型 '{model_name_str}' 连接测试失败:\n" + 
f"{e.user_friendly_message}\n错误码: {e.code.name}", + ) + except Exception as e: + return False, f"❌ 测试时发生未知错误: {e!s}" + + @staticmethod + async def get_key_status(provider_name: str) -> list[dict[str, Any]] | None: + """获取并排序指定提供商的API Key状态""" + from zhenxun.services.llm.manager import get_key_usage_stats + + all_stats = await get_key_usage_stats() + provider_stats = all_stats.get(provider_name) + + if not provider_stats or not provider_stats.get("key_stats"): + return None + + key_stats_dict = provider_stats["key_stats"] + + stats_list = [ + {"key_id": key_id, **stats} for key_id, stats in key_stats_dict.items() + ] + + def sort_key(item: dict[str, Any]): + status_priority = item.get("status_enum", KeyStatus.UNUSED).value + return ( + status_priority, + 100 - item.get("success_rate", 100.0), + -item.get("total_calls", 0), + ) + + sorted_stats_list = sorted(stats_list, key=sort_key) + + return sorted_stats_list + + @staticmethod + async def reset_key(provider_name: str, api_key: str | None) -> tuple[bool, str]: + """重置API Key状态""" + success = await reset_key_status(provider_name, api_key) + if success: + if api_key: + if len(api_key) > 8: + target = f"API Key '{api_key[:4]}...{api_key[-4:]}'" + else: + target = f"API Key '{api_key}'" + else: + target = "所有API Keys" + return True, f"✅ 成功重置提供商 '{provider_name}' 的 {target} 的状态。" + else: + return False, "❌ 重置失败,请检查提供商名称或API Key是否正确。" diff --git a/zhenxun/builtin_plugins/llm_manager/presenters.py b/zhenxun/builtin_plugins/llm_manager/presenters.py new file mode 100644 index 00000000..d745aaf7 --- /dev/null +++ b/zhenxun/builtin_plugins/llm_manager/presenters.py @@ -0,0 +1,204 @@ +from typing import Any + +from zhenxun.services.llm.core import KeyStatus +from zhenxun.services.llm.types import ModelModality +from zhenxun.utils._build_image import BuildImage +from zhenxun.utils._image_template import ImageTemplate, Markdown, RowStyle + + +def _format_seconds(seconds: int) -> str: + """将秒数格式化为 'Xm Ys' 或 'Xh Ym' 的形式""" + if seconds <= 0: + return "0s" + if seconds < 60: + return f"{seconds}s" + + minutes, seconds = divmod(seconds, 60) + if minutes < 60: + return f"{minutes}m {seconds}s" + + hours, minutes = divmod(minutes, 60) + return f"{hours}h {minutes}m" + + +class Presenters: + """格式化LLM管理插件的输出 (图片格式)""" + + @staticmethod + async def format_model_list_as_image( + models: list[dict[str, Any]], show_all: bool + ) -> BuildImage: + """将模型列表格式化为表格图片""" + title = "📋 LLM模型列表" + (" (所有已配置模型)" if show_all else " (仅可用)") + + if not models: + return await BuildImage.build_text_image( + f"{title}\n\n当前没有配置任何LLM模型。" + ) + + column_name = ["提供商", "模型名称", "API类型", "状态"] + data_list = [] + for model in models: + status_text = "✅ 可用" if model.get("is_available", True) else "❌ 不可用" + embed_tag = " (Embed)" if model.get("is_embedding_model", False) else "" + data_list.append( + [ + model.get("provider_name", "N/A"), + f"{model.get('model_name', 'N/A')}{embed_tag}", + model.get("api_type", "N/A"), + status_text, + ] + ) + + return await ImageTemplate.table_page( + head_text=title, + tip_text="使用 `llm info ` 查看详情", + column_name=column_name, + data_list=data_list, + ) + + @staticmethod + async def format_model_details_as_markdown_image(details: dict[str, Any]) -> bytes: + """将模型详情格式化为Markdown图片""" + provider = details["provider_config"] + model = details["model_detail"] + caps = details["capabilities"] + + cap_list = [] + if ModelModality.IMAGE in caps.input_modalities: + cap_list.append("视觉") + if ModelModality.VIDEO in caps.input_modalities: + 
cap_list.append("视频") + if ModelModality.AUDIO in caps.input_modalities: + cap_list.append("音频") + if caps.supports_tool_calling: + cap_list.append("工具调用") + if caps.is_embedding_model: + cap_list.append("文本嵌入") + + md = Markdown() + md.head(f"🔎 模型详情: {provider.name}/{model.model_name}", level=1) + md.text("---") + md.head("提供商信息", level=2) + md.list( + [ + f"**名称**: {provider.name}", + f"**API 类型**: {provider.api_type}", + f"**API Base**: {provider.api_base or '默认'}", + ] + ) + md.head("模型详情", level=2) + + temp_value = model.temperature or provider.temperature or "未设置" + token_value = model.max_tokens or provider.max_tokens or "未设置" + + md.list( + [ + f"**名称**: {model.model_name}", + f"**默认温度**: {temp_value}", + f"**最大Token**: {token_value}", + f"**核心能力**: {', '.join(cap_list) or '纯文本'}", + ] + ) + + return await md.build() + + @staticmethod + async def format_key_status_as_image( + provider_name: str, sorted_stats: list[dict[str, Any]] + ) -> BuildImage: + """将已排序的、详细的API Key状态格式化为表格图片""" + title = f"🔑 '{provider_name}' API Key 状态" + + if not sorted_stats: + return await BuildImage.build_text_image( + f"{title}\n\n该提供商没有配置API Keys。" + ) + + def _status_row_style(column: str, text: str) -> RowStyle: + style = RowStyle() + if column == "状态": + if "✅ 健康" in text: + style.font_color = "#67C23A" + elif "⚠️ 告警" in text: + style.font_color = "#E6A23C" + elif "❌ 错误" in text or "🚫" in text: + style.font_color = "#F56C6C" + elif "❄️ 冷却中" in text: + style.font_color = "#409EFF" + elif column == "成功率": + try: + if text != "N/A": + rate = float(text.replace("%", "")) + if rate < 80: + style.font_color = "#F56C6C" + elif rate < 95: + style.font_color = "#E6A23C" + except (ValueError, TypeError): + pass + return style + + column_name = [ + "Key (部分)", + "状态", + "总调用", + "成功率", + "平均延迟(s)", + "上次错误", + "建议操作", + ] + data_list = [] + + for key_info in sorted_stats: + status_enum: KeyStatus = key_info["status_enum"] + + if status_enum == KeyStatus.COOLDOWN: + cooldown_seconds = int(key_info["cooldown_seconds_left"]) + formatted_time = _format_seconds(cooldown_seconds) + status_text = f"❄️ 冷却中({formatted_time})" + else: + status_text = { + KeyStatus.DISABLED: "🚫 永久禁用", + KeyStatus.ERROR: "❌ 错误", + KeyStatus.WARNING: "⚠️ 告警", + KeyStatus.HEALTHY: "✅ 健康", + KeyStatus.UNUSED: "⚪️ 未使用", + }.get(status_enum, "❔ 未知") + + total_calls = key_info["total_calls"] + total_calls_text = ( + f"{key_info['success_count']}/{total_calls}" + if total_calls > 0 + else "0/0" + ) + + success_rate = key_info["success_rate"] + success_rate_text = f"{success_rate:.1f}%" if total_calls > 0 else "N/A" + + avg_latency = key_info["avg_latency"] + avg_latency_text = f"{avg_latency / 1000:.2f}" if avg_latency > 0 else "N/A" + + last_error = key_info.get("last_error") or "-" + if len(last_error) > 25: + last_error = last_error[:22] + "..." 
+ + data_list.append( + [ + key_info["key_id"], + status_text, + total_calls_text, + success_rate_text, + avg_latency_text, + last_error, + key_info["suggested_action"], + ] + ) + + return await ImageTemplate.table_page( + head_text=title, + tip_text="使用 `llm reset-key ` 重置Key状态", + column_name=column_name, + data_list=data_list, + text_style=_status_row_style, + column_space=15, + ) diff --git a/zhenxun/services/__init__.py b/zhenxun/services/__init__.py index 6af390a8..4c820b87 100644 --- a/zhenxun/services/__init__.py +++ b/zhenxun/services/__init__.py @@ -21,11 +21,28 @@ require("nonebot_plugin_waiter") from .db_context import Model, disconnect from .llm import ( AI, + AIConfig, + CommonOverrides, LLMContentPart, LLMException, + LLMGenerationConfig, LLMMessage, + analyze, + analyze_multimodal, + chat, + clear_model_cache, + code, + create_multimodal_message, + embed, + generate, + get_cache_stats, get_model_instance, list_available_models, + list_embedding_models, + pipeline_chat, + search, + search_multimodal, + set_global_default_model_name, tool_registry, ) from .log import logger @@ -34,16 +51,33 @@ from .scheduler import scheduler_manager __all__ = [ "AI", + "AIConfig", + "CommonOverrides", "LLMContentPart", "LLMException", + "LLMGenerationConfig", "LLMMessage", "Model", "PluginInit", "PluginInitManager", + "analyze", + "analyze_multimodal", + "chat", + "clear_model_cache", + "code", + "create_multimodal_message", "disconnect", + "embed", + "generate", + "get_cache_stats", "get_model_instance", "list_available_models", + "list_embedding_models", "logger", + "pipeline_chat", "scheduler_manager", + "search", + "search_multimodal", + "set_global_default_model_name", "tool_registry", ] diff --git a/zhenxun/services/llm/README.md b/zhenxun/services/llm/README.md index 93394fdf..c827f80a 100644 --- a/zhenxun/services/llm/README.md +++ b/zhenxun/services/llm/README.md @@ -198,7 +198,7 @@ print(search_result['text']) 当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。 ```python -from zhenxun.services.llm.api import AI, AIConfig +from zhenxun.services.llm import AI, AIConfig # 初始化一个AI会话,可以传入自定义配置 ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7) @@ -395,7 +395,7 @@ async def my_tool_factory(config: MyToolConfig): 在 `analyze` 或 `generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。 ```python -from zhenxun.services.llm.api import analyze +from zhenxun.services.llm import analyze from nonebot_plugin_alconna.uniseg import UniMessage response = await analyze( @@ -442,7 +442,6 @@ from zhenxun.services.llm.manager import ( get_key_usage_stats, reset_key_status ) -from zhenxun.services.llm import clear_model_cache, get_cache_stats # 列出所有在config.yaml中配置的可用模型 models = list_available_models() diff --git a/zhenxun/services/llm/__init__.py b/zhenxun/services/llm/__init__.py index 62a0003f..31e82d4d 100644 --- a/zhenxun/services/llm/__init__.py +++ b/zhenxun/services/llm/__init__.py @@ -5,14 +5,12 @@ LLM 服务模块 - 公共 API 入口 """ from .api import ( - AI, - AIConfig, - TaskType, analyze, analyze_multimodal, chat, code, embed, + generate, pipeline_chat, search, search_multimodal, @@ -35,6 +33,7 @@ from .manager import ( list_model_identifiers, set_global_default_model_name, ) +from .session import AI, AIConfig from .tools import tool_registry from .types import ( EmbeddingTaskType, @@ -49,6 +48,7 @@ from .types import ( ModelInfo, ModelProvider, ResponseFormat, + TaskType, ToolCategory, ToolMetadata, UsageInfo, @@ -84,6 +84,7 @@ __all__ = [ "code", "create_multimodal_message", "embed", + "generate", 
"get_cache_stats", "get_global_default_model_name", "get_model_instance", diff --git a/zhenxun/services/llm/api.py b/zhenxun/services/llm/api.py index d9606f80..5059bbe9 100644 --- a/zhenxun/services/llm/api.py +++ b/zhenxun/services/llm/api.py @@ -1,10 +1,7 @@ """ -LLM 服务的高级 API 接口 +LLM 服务的高级 API 接口 - 便捷函数入口 """ -import copy -from dataclasses import dataclass -from enum import Enum from pathlib import Path from typing import Any @@ -12,10 +9,8 @@ from nonebot_plugin_alconna.uniseg import UniMessage from zhenxun.services.log import logger -from .config import CommonOverrides, LLMGenerationConfig -from .config.providers import get_ai_config -from .manager import get_global_default_model_name, get_model_instance -from .tools import tool_registry +from .manager import get_model_instance +from .session import AI from .types import ( EmbeddingTaskType, LLMContentPart, @@ -29,514 +24,31 @@ from .types import ( from .utils import create_multimodal_message, unimsg_to_llm_parts -class TaskType(Enum): - """任务类型枚举""" - - CHAT = "chat" - CODE = "code" - SEARCH = "search" - ANALYSIS = "analysis" - GENERATION = "generation" - MULTIMODAL = "multimodal" - - -@dataclass -class AIConfig: - """AI配置类 - 简化版本""" - - model: ModelName = None - default_embedding_model: ModelName = None - temperature: float | None = None - max_tokens: int | None = None - enable_cache: bool = False - enable_code: bool = False - enable_search: bool = False - timeout: int | None = None - - enable_gemini_json_mode: bool = False - enable_gemini_thinking: bool = False - enable_gemini_safe_mode: bool = False - enable_gemini_multimodal: bool = False - enable_gemini_grounding: bool = False - default_preserve_media_in_history: bool = False - - def __post_init__(self): - """初始化后从配置中读取默认值""" - ai_config = get_ai_config() - if self.model is None: - self.model = ai_config.get("default_model_name") - if self.timeout is None: - self.timeout = ai_config.get("timeout", 180) - - -class AI: - """统一的AI服务类 - 平衡设计版本 - - 提供三层API: - 1. 简单方法:ai.chat(), ai.code(), ai.search() - 2. 标准方法:ai.analyze() 支持复杂参数 - 3. 高级方法:通过get_model_instance()直接访问 - """ - - def __init__( - self, config: AIConfig | None = None, history: list[LLMMessage] | None = None - ): - """ - 初始化AI服务 - - 参数: - config: AI 配置. - history: 可选的初始对话历史. 
- """ - self.config = config or AIConfig() - self.history = history or [] - - def clear_history(self): - """清空当前会话的历史记录""" - self.history = [] - logger.info("AI session history cleared.") - - def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage: - """ - 净化用于存入历史记录的消息。 - 将非文本的多模态内容部分替换为文本占位符,以避免重复处理。 - """ - if not isinstance(message.content, list): - return message - - sanitized_message = copy.deepcopy(message) - content_list = sanitized_message.content - if not isinstance(content_list, list): - return sanitized_message - - new_content_parts: list[LLMContentPart] = [] - has_multimodal_content = False - - for part in content_list: - if isinstance(part, LLMContentPart) and part.type == "text": - new_content_parts.append(part) - else: - has_multimodal_content = True - - if has_multimodal_content: - placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]" - text_part_found = False - for part in new_content_parts: - if part.type == "text": - part.text = f"{placeholder} {part.text or ''}".strip() - text_part_found = True - break - if not text_part_found: - new_content_parts.insert(0, LLMContentPart.text_part(placeholder)) - - sanitized_message.content = new_content_parts - return sanitized_message - - async def chat( - self, - message: str | LLMMessage | list[LLMContentPart], - *, - model: ModelName = None, - preserve_media_in_history: bool | None = None, - **kwargs: Any, - ) -> str: - """ - 进行一次聊天对话。 - 此方法会自动使用和更新会话内的历史记录。 - - 参数: - message: 用户输入的消息。 - model: 本次对话要使用的模型。 - preserve_media_in_history: 是否在历史记录中保留原始多模态信息。 - - True: 保留,用于深度多轮媒体分析。 - - False: 不保留,替换为占位符,提高效率。 - - None (默认): 使用AI实例配置的默认值。 - **kwargs: 传递给模型的其他参数。 - - 返回: - str: 模型的文本响应。 - """ - current_message: LLMMessage - if isinstance(message, str): - current_message = LLMMessage.user(message) - elif isinstance(message, list) and all( - isinstance(part, LLMContentPart) for part in message - ): - current_message = LLMMessage.user(message) - elif isinstance(message, LLMMessage): - current_message = message - else: - raise LLMException( - f"AI.chat 不支持的消息类型: {type(message)}. " - "请使用 str, LLMMessage, 或 list[LLMContentPart]. 
" - "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().", - code=LLMErrorCode.API_REQUEST_FAILED, - ) - - final_messages = [*self.history, current_message] - - response = await self._execute_generation( - final_messages, model, "聊天失败", kwargs - ) - - should_preserve = ( - preserve_media_in_history - if preserve_media_in_history is not None - else self.config.default_preserve_media_in_history - ) - - if should_preserve: - logger.debug("深度分析模式:在历史记录中保留原始多模态消息。") - self.history.append(current_message) - else: - logger.debug("高效模式:净化历史记录中的多模态消息。") - sanitized_user_message = self._sanitize_message_for_history(current_message) - self.history.append(sanitized_user_message) - - self.history.append(LLMMessage.assistant_text_response(response.text)) - - return response.text - - async def code( - self, - prompt: str, - *, - model: ModelName = None, - timeout: int | None = None, - **kwargs: Any, - ) -> dict[str, Any]: - """ - 代码执行 - - 参数: - prompt: 代码执行的提示词。 - model: 要使用的模型名称。 - timeout: 代码执行超时时间(秒)。 - **kwargs: 传递给模型的其他参数。 - - 返回: - dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。 - """ - resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" - - config = CommonOverrides.gemini_code_execution() - if timeout: - config.custom_params = config.custom_params or {} - config.custom_params["code_execution_timeout"] = timeout - - messages = [LLMMessage.user(prompt)] - - response = await self._execute_generation( - messages, resolved_model, "代码执行失败", kwargs, base_config=config - ) - - return { - "text": response.text, - "code_executions": response.code_executions or [], - "success": True, - } - - async def search( - self, - query: str | UniMessage, - *, - model: ModelName = None, - instruction: str = "", - **kwargs: Any, - ) -> dict[str, Any]: - """ - 信息搜索 - 支持多模态输入 - - 参数: - query: 搜索查询内容,支持文本或多模态消息。 - model: 要使用的模型名称。 - instruction: 搜索指令。 - **kwargs: 传递给模型的其他参数。 - - 返回: - dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段 - """ - resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" - config = CommonOverrides.gemini_grounding() - - if isinstance(query, str): - messages = [LLMMessage.user(query)] - elif isinstance(query, UniMessage): - content_parts = await unimsg_to_llm_parts(query) - - final_messages: list[LLMMessage] = [] - if instruction: - final_messages.append(LLMMessage.system(instruction)) - - if not content_parts: - if instruction: - final_messages.append(LLMMessage.user(instruction)) - else: - raise LLMException( - "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED - ) - else: - final_messages.append(LLMMessage.user(content_parts)) - - messages = final_messages - else: - raise LLMException( - f"不支持的搜索输入类型: {type(query)}. 
请使用 str 或 UniMessage.", - code=LLMErrorCode.API_REQUEST_FAILED, - ) - - response = await self._execute_generation( - messages, resolved_model, "信息搜索失败", kwargs, base_config=config - ) - - result = { - "text": response.text, - "sources": [], - "queries": [], - "success": True, - } - - if response.grounding_metadata: - result["sources"] = response.grounding_metadata.grounding_attributions or [] - result["queries"] = response.grounding_metadata.web_search_queries or [] - - return result - - async def analyze( - self, - message: UniMessage | None, - *, - instruction: str = "", - model: ModelName = None, - use_tools: list[str] | None = None, - tool_config: dict[str, Any] | None = None, - activated_tools: list[LLMTool] | None = None, - history: list[LLMMessage] | None = None, - **kwargs: Any, - ) -> LLMResponse: - """ - 内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。 - - 参数: - message: 要分析的消息内容(支持多模态)。 - instruction: 分析指令。 - model: 要使用的模型名称。 - use_tools: 要使用的工具名称列表。 - tool_config: 工具配置。 - activated_tools: 已激活的工具列表。 - history: 对话历史记录。 - **kwargs: 传递给模型的其他参数。 - - 返回: - LLMResponse: 模型的完整响应结果。 - """ - content_parts = await unimsg_to_llm_parts(message or UniMessage()) - - final_messages: list[LLMMessage] = [] - if history: - final_messages.extend(history) - - if instruction: - if not any(msg.role == "system" for msg in final_messages): - final_messages.insert(0, LLMMessage.system(instruction)) - - if not content_parts: - if instruction and not history: - final_messages.append(LLMMessage.user(instruction)) - elif not history: - raise LLMException( - "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED - ) - else: - final_messages.append(LLMMessage.user(content_parts)) - - llm_tools: list[LLMTool] | None = activated_tools - if not llm_tools and use_tools: - try: - llm_tools = tool_registry.get_tools(use_tools) - logger.debug(f"已从注册表加载工具定义: {use_tools}") - except ValueError as e: - raise LLMException( - f"加载工具定义失败: {e}", - code=LLMErrorCode.CONFIGURATION_ERROR, - cause=e, - ) - - tool_choice = None - if tool_config: - mode = tool_config.get("mode", "auto") - if mode in ["auto", "any", "none"]: - tool_choice = mode - - response = await self._execute_generation( - final_messages, - model, - "内容分析失败", - kwargs, - llm_tools=llm_tools, - tool_choice=tool_choice, - ) - - return response - - async def _execute_generation( - self, - messages: list[LLMMessage], - model_name: ModelName, - error_message: str, - config_overrides: dict[str, Any], - llm_tools: list[LLMTool] | None = None, - tool_choice: str | dict[str, Any] | None = None, - base_config: LLMGenerationConfig | None = None, - ) -> LLMResponse: - """通用的生成执行方法,封装模型获取和单次API调用""" - try: - resolved_model_name = self._resolve_model_name( - model_name or self.config.model - ) - final_config_dict = self._merge_config( - config_overrides, base_config=base_config - ) - - async with await get_model_instance( - resolved_model_name, override_config=final_config_dict - ) as model_instance: - return await model_instance.generate_response( - messages, - tools=llm_tools, - tool_choice=tool_choice, - ) - except LLMException: - raise - except Exception as e: - logger.error(f"{error_message}: {e}", e=e) - raise LLMException(f"{error_message}: {e}", cause=e) - - def _resolve_model_name(self, model_name: ModelName) -> str: - """解析模型名称""" - if model_name: - return model_name - - default_model = get_global_default_model_name() - if default_model: - return default_model - - raise LLMException( - "未指定模型名称且未设置全局默认模型", - code=LLMErrorCode.MODEL_NOT_FOUND, - ) - - def _merge_config( - self, 
- user_config: dict[str, Any], - base_config: LLMGenerationConfig | None = None, - ) -> dict[str, Any]: - """合并配置""" - final_config = {} - if base_config: - final_config.update(base_config.to_dict()) - - if self.config.temperature is not None: - final_config["temperature"] = self.config.temperature - if self.config.max_tokens is not None: - final_config["max_tokens"] = self.config.max_tokens - - if self.config.enable_cache: - final_config["enable_caching"] = True - if self.config.enable_code: - final_config["enable_code_execution"] = True - if self.config.enable_search: - final_config["enable_grounding"] = True - - if self.config.enable_gemini_json_mode: - final_config["response_mime_type"] = "application/json" - if self.config.enable_gemini_thinking: - final_config["thinking_budget"] = 0.8 - if self.config.enable_gemini_safe_mode: - final_config["safety_settings"] = ( - CommonOverrides.gemini_safe().safety_settings - ) - if self.config.enable_gemini_multimodal: - final_config.update(CommonOverrides.gemini_multimodal().to_dict()) - if self.config.enable_gemini_grounding: - final_config["enable_grounding"] = True - - final_config.update(user_config) - - return final_config - - async def embed( - self, - texts: list[str] | str, - *, - model: ModelName = None, - task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, - **kwargs: Any, - ) -> list[list[float]]: - """ - 生成文本嵌入向量 - - 参数: - texts: 要生成嵌入向量的文本或文本列表。 - model: 要使用的嵌入模型名称。 - task_type: 嵌入任务类型。 - **kwargs: 传递给模型的其他参数。 - - 返回: - list[list[float]]: 文本的嵌入向量列表。 - """ - if isinstance(texts, str): - texts = [texts] - if not texts: - return [] - - try: - resolved_model_str = ( - model or self.config.default_embedding_model or self.config.model - ) - if not resolved_model_str: - raise LLMException( - "使用 embed 功能时必须指定嵌入模型名称," - "或在 AIConfig 中配置 default_embedding_model。", - code=LLMErrorCode.MODEL_NOT_FOUND, - ) - resolved_model_str = self._resolve_model_name(resolved_model_str) - - async with await get_model_instance( - resolved_model_str, - override_config=None, - ) as embedding_model_instance: - return await embedding_model_instance.generate_embeddings( - texts, task_type=task_type, **kwargs - ) - except LLMException: - raise - except Exception as e: - logger.error(f"文本嵌入失败: {e}", e=e) - raise LLMException( - f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e - ) - - async def chat( message: str | LLMMessage | list[LLMContentPart], *, model: ModelName = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, **kwargs: Any, -) -> str: +) -> LLMResponse: """ 聊天对话便捷函数 参数: message: 用户输入的消息。 model: 要使用的模型名称。 + tools: 本次对话可用的工具列表。 + tool_choice: 强制模型使用的工具。 **kwargs: 传递给模型的其他参数。 返回: - str: 模型的文本响应。 + LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。 """ ai = AI() - return await ai.chat(message, model=model, **kwargs) + return await ai.chat( + message, model=model, tools=tools, tool_choice=tool_choice, **kwargs + ) async def code( @@ -730,12 +242,14 @@ async def pipeline_chat( raise ValueError("模型链`model_chain`不能为空。") current_content: str | list[LLMContentPart] - if isinstance(message, str): + if isinstance(message, UniMessage): + current_content = await unimsg_to_llm_parts(message) + elif isinstance(message, str): current_content = message elif isinstance(message, list): current_content = message else: - current_content = await unimsg_to_llm_parts(message) + raise TypeError(f"不支持的消息类型: {type(message)}") final_response: LLMResponse | None = None @@ -787,3 +301,45 @@ async def pipeline_chat( ) 
return final_response + + +async def generate( + messages: list[LLMMessage], + *, + model: ModelName = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + **kwargs: Any, +) -> LLMResponse: + """ + 根据完整的消息列表(包括系统指令)生成一次性响应。 + 这是一个便捷的函数,不使用或修改任何会话历史。 + + 参数: + messages: 用于生成响应的完整消息列表。 + model: 要使用的模型名称。 + tools: 可用的工具列表。 + tool_choice: 工具选择策略。 + **kwargs: 传递给模型的其他参数。 + + 返回: + LLMResponse: 模型的完整响应对象。 + """ + try: + ai_instance = AI() + resolved_model_name = ai_instance._resolve_model_name(model) + final_config_dict = ai_instance._merge_config(kwargs) + + async with await get_model_instance( + resolved_model_name, override_config=final_config_dict + ) as model_instance: + return await model_instance.generate_response( + messages, + tools=tools, + tool_choice=tool_choice, + ) + except LLMException: + raise + except Exception as e: + logger.error(f"生成响应失败: {e}", e=e) + raise LLMException(f"生成响应失败: {e}", cause=e) diff --git a/zhenxun/services/llm/config/providers.py b/zhenxun/services/llm/config/providers.py index a39e32c9..96d30cdf 100644 --- a/zhenxun/services/llm/config/providers.py +++ b/zhenxun/services/llm/config/providers.py @@ -17,6 +17,7 @@ from zhenxun.configs.utils import parse_as from zhenxun.services.log import logger from zhenxun.utils.manager.priority_manager import PriorityLifecycle +from ..core import key_store from ..types.models import ModelDetail, ProviderConfig @@ -502,12 +503,13 @@ def set_default_model(provider_model_name: str | None) -> bool: @PriorityLifecycle.on_startup(priority=10) async def _init_llm_config_on_startup(): """ - 在服务启动时主动调用一次 get_llm_config, - 以触发必要的初始化操作,例如创建默认的 mcp_tools.json 文件。 + 在服务启动时主动调用一次 get_llm_config 和 key_store.initialize, + 以触发必要的初始化操作。 """ - logger.info("正在初始化 LLM 配置并检查 MCP 工具文件...") + logger.info("正在初始化 LLM 配置并加载密钥状态...") try: get_llm_config() - logger.info("LLM 配置初始化完成。") + await key_store.initialize() + logger.info("LLM 配置和密钥状态初始化完成。") except Exception as e: - logger.error(f"LLM 配置初始化时发生错误: {e}", e=e) + logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e) diff --git a/zhenxun/services/llm/core.py b/zhenxun/services/llm/core.py index 56591701..6e5b5960 100644 --- a/zhenxun/services/llm/core.py +++ b/zhenxun/services/llm/core.py @@ -5,17 +5,27 @@ LLM 核心基础设施模块 """ import asyncio +from dataclasses import asdict, dataclass +from enum import IntEnum +import json +import os +import time from typing import Any +import aiofiles import httpx +import nonebot from pydantic import BaseModel +from zhenxun.configs.path_config import DATA_PATH from zhenxun.services.log import logger from zhenxun.utils.user_agent import get_user_agent from .types import ProviderConfig from .types.exceptions import LLMErrorCode, LLMException +driver = nonebot.get_driver() + class HttpClientConfig(BaseModel): """HTTP客户端配置""" @@ -194,6 +204,82 @@ async def create_llm_http_client( return LLMHttpClient(config) +class KeyStatus(IntEnum): + """用于排序和展示的密钥状态枚举""" + + DISABLED = 0 + ERROR = 1 + COOLDOWN = 2 + WARNING = 3 + HEALTHY = 4 + UNUSED = 5 + + +@dataclass +class KeyStats: + """单个API Key的详细状态和统计信息""" + + cooldown_until: float = 0.0 + success_count: int = 0 + failure_count: int = 0 + total_latency: float = 0.0 + last_error_info: str | None = None + + @property + def is_available(self) -> bool: + """检查Key当前是否可用""" + return time.time() >= self.cooldown_until + + @property + def avg_latency(self) -> float: + """计算平均延迟""" + return ( + self.total_latency / self.success_count if self.success_count > 0 else 0.0 + ) + + @property + def 
success_rate(self) -> float: + """计算成功率""" + total = self.success_count + self.failure_count + return self.success_count / total * 100 if total > 0 else 100.0 + + @property + def status(self) -> KeyStatus: + """根据当前统计数据动态计算状态""" + now = time.time() + cooldown_left = max(0, self.cooldown_until - now) + + if cooldown_left > 31536000 - 60: + return KeyStatus.DISABLED + if cooldown_left > 0: + return KeyStatus.COOLDOWN + + total_calls = self.success_count + self.failure_count + if total_calls == 0: + return KeyStatus.UNUSED + + if self.success_rate < 80: + return KeyStatus.ERROR + + if total_calls >= 5 and self.avg_latency > 15000: + return KeyStatus.WARNING + + return KeyStatus.HEALTHY + + @property + def suggested_action(self) -> str: + """根据状态给出建议操作""" + status_actions = { + KeyStatus.DISABLED: "更换Key", + KeyStatus.ERROR: "检查网络/重置", + KeyStatus.COOLDOWN: "等待/重置", + KeyStatus.WARNING: "观察", + KeyStatus.HEALTHY: "-", + KeyStatus.UNUSED: "-", + } + return status_actions.get(self.status, "未知") + + class RetryConfig: """重试配置""" @@ -236,26 +322,38 @@ async def with_smart_retry( last_exception: Exception | None = None failed_keys: set[str] = set() + model_instance = next((arg for arg in args if hasattr(arg, "api_keys")), None) + all_provider_keys = model_instance.api_keys if model_instance else [] + for attempt in range(config.max_retries + 1): try: if config.key_rotation and "failed_keys" in func.__code__.co_varnames: kwargs["failed_keys"] = failed_keys - return await func(*args, **kwargs) + start_time = time.monotonic() + result = await func(*args, **kwargs) + latency = (time.monotonic() - start_time) * 1000 + + if key_store and isinstance(result, tuple) and len(result) == 2: + final_result, api_key_used = result + if api_key_used: + await key_store.record_success(api_key_used, latency) + return final_result + else: + return result except LLMException as e: last_exception = e + api_key_in_use = e.details.get("api_key") - if e.code in [ - LLMErrorCode.API_KEY_INVALID, - LLMErrorCode.API_QUOTA_EXCEEDED, - ]: - if hasattr(e, "details") and e.details and "api_key" in e.details: - failed_keys.add(e.details["api_key"]) - if key_store and provider_name: - await key_store.record_failure( - e.details["api_key"], e.details.get("status_code") - ) + if api_key_in_use: + failed_keys.add(api_key_in_use) + if key_store and provider_name and len(all_provider_keys) > 1: + status_code = e.details.get("status_code") + error_message = f"({e.code.name}) {e.message}" + await key_store.record_failure( + api_key_in_use, status_code, error_message + ) should_retry = _should_retry_llm_error(e, attempt, config.max_retries) if not should_retry: @@ -267,7 +365,7 @@ async def with_smart_retry( if config.exponential_backoff: wait_time *= 2**attempt logger.warning( - f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}" + f"请求失败,{wait_time:.2f}秒后重试 (第{attempt + 1}次): {e}" ) await asyncio.sleep(wait_time) else: @@ -325,14 +423,66 @@ def _should_retry_llm_error( class KeyStatusStore: - """API Key 状态管理存储 - 优化版本,支持轮询和负载均衡""" + """API Key 状态管理存储 - 支持持久化""" def __init__(self): - self._key_status: dict[str, bool] = {} - self._key_usage_count: dict[str, int] = {} - self._key_last_used: dict[str, float] = {} + self._key_stats: dict[str, KeyStats] = {} self._provider_key_index: dict[str, int] = {} self._lock = asyncio.Lock() + self._file_path = DATA_PATH / "llm" / "key_status.json" + + async def initialize(self): + """从文件异步加载密钥状态,在应用启动时调用""" + async with self._lock: + if not self._file_path.exists(): + logger.info("未找到密钥状态文件,将使用内存状态启动。") + 
return + + try: + logger.info(f"正在从 {self._file_path} 加载密钥状态...") + async with aiofiles.open(self._file_path, encoding="utf-8") as f: + content = await f.read() + if not content: + logger.warning("密钥状态文件为空。") + return + data = json.loads(content) + + for key, stats_dict in data.items(): + self._key_stats[key] = KeyStats(**stats_dict) + + logger.info(f"成功加载 {len(self._key_stats)} 个密钥的状态。") + + except json.JSONDecodeError: + logger.error(f"密钥状态文件 {self._file_path} 格式错误,无法解析。") + except Exception as e: + logger.error(f"加载密钥状态文件时发生错误: {e}", e=e) + + async def _save_to_file_internal(self): + """ + [内部方法] 将当前密钥状态安全地写入JSON文件。 + 假定调用方已持有锁。 + """ + data_to_save = {key: asdict(stats) for key, stats in self._key_stats.items()} + + try: + self._file_path.parent.mkdir(parents=True, exist_ok=True) + temp_path = self._file_path.with_suffix(".json.tmp") + + async with aiofiles.open(temp_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(data_to_save, ensure_ascii=False, indent=2)) + + if self._file_path.exists(): + self._file_path.unlink() + os.rename(temp_path, self._file_path) + logger.debug("密钥状态已成功持久化到文件。") + except Exception as e: + logger.error(f"保存密钥状态到文件失败: {e}", e=e) + + async def shutdown(self): + """在应用关闭时安全地保存状态""" + async with self._lock: + await self._save_to_file_internal() + logger.info("KeyStatusStore 已在关闭前保存状态。") async def get_next_available_key( self, @@ -355,88 +505,122 @@ class KeyStatusStore: return None exclude_keys = exclude_keys or set() - available_keys = [ - key - for key in api_keys - if key not in exclude_keys and self._key_status.get(key, True) - ] - - if not available_keys: - return api_keys[0] if api_keys else None async with self._lock: + for key in api_keys: + if key not in self._key_stats: + self._key_stats[key] = KeyStats() + + available_keys = [ + key + for key in api_keys + if key not in exclude_keys and self._key_stats[key].is_available + ] + + if not available_keys: + return api_keys[0] + current_index = self._provider_key_index.get(provider_name, 0) - selected_key = available_keys[current_index % len(available_keys)] + self._provider_key_index[provider_name] = current_index + 1 - self._provider_key_index[provider_name] = (current_index + 1) % len( - available_keys + total_usage = ( + self._key_stats[selected_key].success_count + + self._key_stats[selected_key].failure_count ) - - import time - - self._key_usage_count[selected_key] = ( - self._key_usage_count.get(selected_key, 0) + 1 - ) - self._key_last_used[selected_key] = time.time() - logger.debug( f"轮询选择API密钥: {self._get_key_id(selected_key)} " - f"(使用次数: {self._key_usage_count[selected_key]})" + f"(使用次数: {total_usage})" ) - return selected_key - async def record_success(self, api_key: str): - """记录成功使用""" + async def record_success(self, api_key: str, latency: float): + """记录成功使用,并持久化""" async with self._lock: - self._key_status[api_key] = True - logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}") + stats = self._key_stats.setdefault(api_key, KeyStats()) + stats.cooldown_until = 0.0 + stats.success_count += 1 + stats.total_latency += latency + stats.last_error_info = None + await self._save_to_file_internal() + logger.debug( + f"记录API密钥成功使用: {self._get_key_id(api_key)}, 延迟: {latency:.2f}ms" + ) - async def record_failure(self, api_key: str, status_code: int | None): + async def record_failure( + self, api_key: str, status_code: int | None, error_message: str + ): """ - 记录失败使用 + 记录失败使用,并设置冷却时间 参数: api_key: API密钥。 status_code: HTTP状态码。 + error_message: 错误信息。 """ key_id = 
self._get_key_id(api_key) + now = time.time() + cooldown_duration = 300 + + if status_code in [401, 403, 404]: + cooldown_duration = 31536000 + log_level = "error" + log_message = f"API密钥认证/权限/路径错误,将永久禁用: {key_id}" + elif status_code == 429: + cooldown_duration = 60 + log_level = "warning" + log_message = f"API密钥被限流,冷却60秒: {key_id}" + else: + log_level = "warning" + log_message = f"API密钥遇到临时性错误,冷却{cooldown_duration}秒: {key_id}" + async with self._lock: - if status_code in [401, 403]: - self._key_status[api_key] = False - logger.warning( - f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})" - ) - else: - logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})") + stats = self._key_stats.setdefault(api_key, KeyStats()) + stats.cooldown_until = now + cooldown_duration + stats.failure_count += 1 + stats.last_error_info = error_message[:256] + await self._save_to_file_internal() + + getattr(logger, log_level)(log_message) async def reset_key_status(self, api_key: str): - """重置密钥状态(用于恢复机制)""" + """重置密钥状态,并持久化""" async with self._lock: - self._key_status[api_key] = True + stats = self._key_stats.setdefault(api_key, KeyStats()) + stats.cooldown_until = 0.0 + stats.last_error_info = None + await self._save_to_file_internal() logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}") async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]: """ - 获取密钥使用统计 + 获取密钥使用统计,并计算出用于展示的派生数据。 参数: api_keys: API密钥列表。 返回: - dict[str, dict]: 密钥统计信息字典。 + dict[str, dict]: 包含丰富状态和统计信息的密钥字典。 """ - stats = {} + stats_dict = {} + now = time.time() async with self._lock: for key in api_keys: key_id = self._get_key_id(key) - stats[key_id] = { - "available": self._key_status.get(key, True), - "usage_count": self._key_usage_count.get(key, 0), - "last_used": self._key_last_used.get(key, 0), + stats = self._key_stats.get(key, KeyStats()) + + stats_dict[key_id] = { + "status_enum": stats.status, + "cooldown_seconds_left": max(0, stats.cooldown_until - now), + "total_calls": stats.success_count + stats.failure_count, + "success_count": stats.success_count, + "failure_count": stats.failure_count, + "success_rate": stats.success_rate, + "avg_latency": stats.avg_latency, + "last_error": stats.last_error_info, + "suggested_action": stats.suggested_action, } - return stats + return stats_dict def _get_key_id(self, api_key: str) -> str: """获取API密钥的标识符(用于日志)""" @@ -446,3 +630,8 @@ class KeyStatusStore: key_store = KeyStatusStore() + + +@driver.on_shutdown +async def _shutdown_key_store(): + await key_store.shutdown() diff --git a/zhenxun/services/llm/manager.py b/zhenxun/services/llm/manager.py index f0e9c560..6bad91f1 100644 --- a/zhenxun/services/llm/manager.py +++ b/zhenxun/services/llm/manager.py @@ -137,8 +137,8 @@ def get_configured_providers() -> list[ProviderConfig]: valid_providers.append(item) else: logger.warning( - f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象," - f"已跳过。实际类型: {type(item)}" + f"配置文件中第 {i + 1} 项未能正确解析为 ProviderConfig 对象,已跳过。" + f"实际类型: {type(item)}" ) return valid_providers diff --git a/zhenxun/services/llm/service.py b/zhenxun/services/llm/service.py index 587b15cc..76d846ba 100644 --- a/zhenxun/services/llm/service.py +++ b/zhenxun/services/llm/service.py @@ -46,17 +46,7 @@ class LLMModelBase(ABC): history: list[dict[str, str]] | None = None, **kwargs: Any, ) -> str: - """ - 生成文本 - - 参数: - prompt: 输入提示词。 - history: 对话历史记录。 - **kwargs: 其他参数。 - - 返回: - str: 生成的文本。 - """ + """生成文本""" pass @abstractmethod @@ -68,19 +58,7 @@ class LLMModelBase(ABC): tool_choice: str | dict[str, Any] | None = None, 
**kwargs: Any, ) -> LLMResponse: - """ - 生成高级响应 - - 参数: - messages: 消息列表。 - config: 生成配置。 - tools: 工具列表。 - tool_choice: 工具选择策略。 - **kwargs: 其他参数。 - - 返回: - LLMResponse: 模型响应。 - """ + """生成高级响应""" pass @abstractmethod @@ -90,17 +68,7 @@ class LLMModelBase(ABC): task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, **kwargs: Any, ) -> list[list[float]]: - """ - 生成文本嵌入向量 - - 参数: - texts: 文本列表。 - task_type: 嵌入任务类型。 - **kwargs: 其他参数。 - - 返回: - list[list[float]]: 嵌入向量列表。 - """ + """生成文本嵌入向量""" pass @@ -208,28 +176,8 @@ class LLMModel(LLMModelBase): http_client: "LLMHttpClient", failed_keys: set[str] | None = None, log_context: str = "API", - ) -> Any: - """ - 执行API调用的通用核心方法。 - - 该方法封装了以下通用逻辑: - 1. 选择API密钥。 - 2. 准备和记录请求。 - 3. 发送HTTP POST请求。 - 4. 处理HTTP错误和API特定错误。 - 5. 记录密钥使用状态。 - 6. 解析成功的响应。 - - 参数: - prepare_request_func: 准备请求的函数。 - parse_response_func: 解析响应的函数。 - http_client: HTTP客户端。 - failed_keys: 失败的密钥集合。 - log_context: 日志上下文。 - - 返回: - Any: 解析后的响应数据。 - """ + ) -> tuple[Any, str]: + """执行API调用的通用核心方法""" api_key = await self._select_api_key(failed_keys) try: @@ -267,7 +215,9 @@ class LLMModel(LLMModelBase): ) logger.debug(f"💥 完整错误响应: {error_text}") - await self.key_store.record_failure(api_key, http_response.status_code) + await self.key_store.record_failure( + api_key, http_response.status_code, error_text + ) if http_response.status_code in [401, 403]: error_code = LLMErrorCode.API_KEY_INVALID @@ -298,7 +248,7 @@ class LLMModel(LLMModelBase): except Exception as e: logger.error(f"解析 {log_context} 响应失败: {e}", e=e) - await self.key_store.record_failure(api_key, None) + await self.key_store.record_failure(api_key, None, str(e)) if isinstance(e, LLMException): raise else: @@ -308,17 +258,15 @@ class LLMModel(LLMModelBase): cause=e, ) - await self.key_store.record_success(api_key) - logger.debug(f"✅ API密钥使用成功: {masked_key}") logger.info(f"🎯 LLM响应解析完成 [{log_context}]") - return parsed_data + return parsed_data, api_key except LLMException: raise except Exception as e: error_log_msg = f"生成 {log_context.lower()} 时发生未预期错误: {e}" logger.error(error_log_msg, e=e) - await self.key_store.record_failure(api_key, None) + await self.key_store.record_failure(api_key, None, str(e)) raise LLMException( error_log_msg, code=LLMErrorCode.GENERATION_FAILED @@ -349,13 +297,14 @@ class LLMModel(LLMModelBase): adapter.validate_embedding_response(response_json) return adapter.parse_embedding_response(response_json) - return await self._perform_api_call( + parsed_data, api_key_used = await self._perform_api_call( prepare_request_func=prepare_request, parse_response_func=parse_response, http_client=http_client, failed_keys=failed_keys, log_context="Embedding", ) + return parsed_data async def _execute_with_smart_retry( self, @@ -394,8 +343,8 @@ class LLMModel(LLMModelBase): tool_choice: str | dict[str, Any] | None, http_client: LLMHttpClient, failed_keys: set[str] | None = None, - ) -> LLMResponse: - """执行单次请求 - 供重试机制调用,直接返回 LLMResponse""" + ) -> tuple[LLMResponse, str]: + """执行单次请求 - 供重试机制调用,直接返回 LLMResponse 和使用的 key""" async def prepare_request(api_key: str) -> RequestData: return await adapter.prepare_advanced_request( @@ -441,19 +390,17 @@ class LLMModel(LLMModelBase): cache_info=response_data.cache_info, ) - return await self._perform_api_call( + parsed_data, api_key_used = await self._perform_api_call( prepare_request_func=prepare_request, parse_response_func=parse_response, http_client=http_client, failed_keys=failed_keys, log_context="Generation", ) + return parsed_data, api_key_used async 
def close(self): - """ - 标记模型实例的当前使用周期结束。 - 共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。 - """ + """标记模型实例的当前使用周期结束""" if self._is_closed: return self._is_closed = True @@ -487,17 +434,7 @@ class LLMModel(LLMModelBase): history: list[dict[str, str]] | None = None, **kwargs: Any, ) -> str: - """ - 生成文本 - 通过 generate_response 实现 - - 参数: - prompt: 输入提示词。 - history: 对话历史记录。 - **kwargs: 其他参数。 - - 返回: - str: 生成的文本。 - """ + """生成文本""" self._check_not_closed() messages: list[LLMMessage] = [] @@ -538,19 +475,7 @@ class LLMModel(LLMModelBase): tool_choice: str | dict[str, Any] | None = None, **kwargs: Any, ) -> LLMResponse: - """ - 生成高级响应 - - 参数: - messages: 消息列表。 - config: 生成配置。 - tools: 工具列表。 - tool_choice: 工具选择策略。 - **kwargs: 其他参数。 - - 返回: - LLMResponse: 模型响应。 - """ + """生成高级响应""" self._check_not_closed() from .adapters import get_adapter_for_api_type @@ -619,17 +544,7 @@ class LLMModel(LLMModelBase): task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, **kwargs: Any, ) -> list[list[float]]: - """ - 生成文本嵌入向量 - - 参数: - texts: 文本列表。 - task_type: 嵌入任务类型。 - **kwargs: 其他参数。 - - 返回: - list[list[float]]: 嵌入向量列表。 - """ + """生成文本嵌入向量""" self._check_not_closed() if not texts: return [] diff --git a/zhenxun/services/llm/session.py b/zhenxun/services/llm/session.py new file mode 100644 index 00000000..ed23eeca --- /dev/null +++ b/zhenxun/services/llm/session.py @@ -0,0 +1,532 @@ +""" +LLM 服务 - 会话客户端 + +提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。 +""" + +import copy +from dataclasses import dataclass +from typing import Any + +from nonebot_plugin_alconna.uniseg import UniMessage + +from zhenxun.services.log import logger + +from .config import CommonOverrides, LLMGenerationConfig +from .config.providers import get_ai_config +from .manager import get_global_default_model_name, get_model_instance +from .tools import tool_registry +from .types import ( + EmbeddingTaskType, + LLMContentPart, + LLMErrorCode, + LLMException, + LLMMessage, + LLMResponse, + LLMTool, + ModelName, +) +from .utils import unimsg_to_llm_parts + + +@dataclass +class AIConfig: + """AI配置类 - 简化版本""" + + model: ModelName = None + default_embedding_model: ModelName = None + temperature: float | None = None + max_tokens: int | None = None + enable_cache: bool = False + enable_code: bool = False + enable_search: bool = False + timeout: int | None = None + + enable_gemini_json_mode: bool = False + enable_gemini_thinking: bool = False + enable_gemini_safe_mode: bool = False + enable_gemini_multimodal: bool = False + enable_gemini_grounding: bool = False + default_preserve_media_in_history: bool = False + + def __post_init__(self): + """初始化后从配置中读取默认值""" + ai_config = get_ai_config() + if self.model is None: + self.model = ai_config.get("default_model_name") + if self.timeout is None: + self.timeout = ai_config.get("timeout", 180) + + +class AI: + """统一的AI服务类 - 平衡设计版本 + + 提供三层API: + 1. 简单方法:ai.chat(), ai.code(), ai.search() + 2. 标准方法:ai.analyze() 支持复杂参数 + 3. 高级方法:通过get_model_instance()直接访问 + """ + + def __init__( + self, config: AIConfig | None = None, history: list[LLMMessage] | None = None + ): + """ + 初始化AI服务 + + 参数: + config: AI 配置. + history: 可选的初始对话历史. 
+ """ + self.config = config or AIConfig() + self.history = history or [] + + def clear_history(self): + """清空当前会话的历史记录""" + self.history = [] + logger.info("AI session history cleared.") + + def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage: + """ + 净化用于存入历史记录的消息。 + 将非文本的多模态内容部分替换为文本占位符,以避免重复处理。 + """ + if not isinstance(message.content, list): + return message + + sanitized_message = copy.deepcopy(message) + content_list = sanitized_message.content + if not isinstance(content_list, list): + return sanitized_message + + new_content_parts: list[LLMContentPart] = [] + has_multimodal_content = False + + for part in content_list: + if isinstance(part, LLMContentPart) and part.type == "text": + new_content_parts.append(part) + else: + has_multimodal_content = True + + if has_multimodal_content: + placeholder = "[用户发送了媒体文件,内容已在首次分析时处理]" + text_part_found = False + for part in new_content_parts: + if part.type == "text": + part.text = f"{placeholder} {part.text or ''}".strip() + text_part_found = True + break + if not text_part_found: + new_content_parts.insert(0, LLMContentPart.text_part(placeholder)) + + sanitized_message.content = new_content_parts + return sanitized_message + + async def chat( + self, + message: str | LLMMessage | list[LLMContentPart], + *, + model: ModelName = None, + preserve_media_in_history: bool | None = None, + tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """ + 进行一次聊天对话,支持工具调用。 + 此方法会自动使用和更新会话内的历史记录。 + + 参数: + message: 用户输入的消息。 + model: 本次对话要使用的模型。 + preserve_media_in_history: 是否在历史记录中保留原始多模态信息。 + - True: 保留,用于深度多轮媒体分析。 + - False: 不保留,替换为占位符,提高效率。 + - None (默认): 使用AI实例配置的默认值。 + tools: 本次对话可用的工具列表。 + tool_choice: 强制模型使用的工具。 + **kwargs: 传递给模型的其他生成参数。 + + 返回: + LLMResponse: 模型的完整响应,可能包含文本或工具调用请求。 + """ + current_message: LLMMessage + if isinstance(message, str): + current_message = LLMMessage.user(message) + elif isinstance(message, list) and all( + isinstance(part, LLMContentPart) for part in message + ): + current_message = LLMMessage.user(message) + elif isinstance(message, LLMMessage): + current_message = message + else: + raise LLMException( + f"AI.chat 不支持的消息类型: {type(message)}. " + "请使用 str, LLMMessage, 或 list[LLMContentPart]. 
" + "对于更复杂的多模态输入或文件路径,请使用 AI.analyze().", + code=LLMErrorCode.API_REQUEST_FAILED, + ) + + final_messages = [*self.history, current_message] + + response = await self._execute_generation( + messages=final_messages, + model_name=model, + error_message="聊天失败", + config_overrides=kwargs, + llm_tools=tools, + tool_choice=tool_choice, + ) + + should_preserve = ( + preserve_media_in_history + if preserve_media_in_history is not None + else self.config.default_preserve_media_in_history + ) + + if should_preserve: + logger.debug("深度分析模式:在历史记录中保留原始多模态消息。") + self.history.append(current_message) + else: + logger.debug("高效模式:净化历史记录中的多模态消息。") + sanitized_user_message = self._sanitize_message_for_history(current_message) + self.history.append(sanitized_user_message) + + self.history.append( + LLMMessage( + role="assistant", content=response.text, tool_calls=response.tool_calls + ) + ) + + return response + + async def code( + self, + prompt: str, + *, + model: ModelName = None, + timeout: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + 代码执行 + + 参数: + prompt: 代码执行的提示词。 + model: 要使用的模型名称。 + timeout: 代码执行超时时间(秒)。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含执行结果的字典,包含text、code_executions和success字段。 + """ + resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" + + config = CommonOverrides.gemini_code_execution() + if timeout: + config.custom_params = config.custom_params or {} + config.custom_params["code_execution_timeout"] = timeout + + messages = [LLMMessage.user(prompt)] + + response = await self._execute_generation( + messages=messages, + model_name=resolved_model, + error_message="代码执行失败", + config_overrides=kwargs, + base_config=config, + ) + + return { + "text": response.text, + "code_executions": response.code_executions or [], + "success": True, + } + + async def search( + self, + query: str | UniMessage, + *, + model: ModelName = None, + instruction: str = "", + **kwargs: Any, + ) -> dict[str, Any]: + """ + 信息搜索 - 支持多模态输入 + + 参数: + query: 搜索查询内容,支持文本或多模态消息。 + model: 要使用的模型名称。 + instruction: 搜索指令。 + **kwargs: 传递给模型的其他参数。 + + 返回: + dict[str, Any]: 包含搜索结果的字典,包含text、sources、queries和success字段 + """ + from nonebot_plugin_alconna.uniseg import UniMessage + + resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash" + config = CommonOverrides.gemini_grounding() + + if isinstance(query, str): + messages = [LLMMessage.user(query)] + elif isinstance(query, UniMessage): + content_parts = await unimsg_to_llm_parts(query) + + final_messages: list[LLMMessage] = [] + if instruction: + final_messages.append(LLMMessage.system(instruction)) + + if not content_parts: + if instruction: + final_messages.append(LLMMessage.user(instruction)) + else: + raise LLMException( + "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED + ) + else: + final_messages.append(LLMMessage.user(content_parts)) + + messages = final_messages + else: + raise LLMException( + f"不支持的搜索输入类型: {type(query)}. 
请使用 str 或 UniMessage.", + code=LLMErrorCode.API_REQUEST_FAILED, + ) + + response = await self._execute_generation( + messages=messages, + model_name=resolved_model, + error_message="信息搜索失败", + config_overrides=kwargs, + base_config=config, + ) + + result = { + "text": response.text, + "sources": [], + "queries": [], + "success": True, + } + + if response.grounding_metadata: + result["sources"] = response.grounding_metadata.grounding_attributions or [] + result["queries"] = response.grounding_metadata.web_search_queries or [] + + return result + + async def analyze( + self, + message: UniMessage | None, + *, + instruction: str = "", + model: ModelName = None, + use_tools: list[str] | None = None, + tool_config: dict[str, Any] | None = None, + activated_tools: list[LLMTool] | None = None, + history: list[LLMMessage] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """ + 内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫。 + + 参数: + message: 要分析的消息内容(支持多模态)。 + instruction: 分析指令。 + model: 要使用的模型名称。 + use_tools: 要使用的工具名称列表。 + tool_config: 工具配置。 + activated_tools: 已激活的工具列表。 + history: 对话历史记录。 + **kwargs: 传递给模型的其他参数。 + + 返回: + LLMResponse: 模型的完整响应结果。 + """ + from nonebot_plugin_alconna.uniseg import UniMessage + + content_parts = await unimsg_to_llm_parts(message or UniMessage()) + + final_messages: list[LLMMessage] = [] + if history: + final_messages.extend(history) + + if instruction: + if not any(msg.role == "system" for msg in final_messages): + final_messages.insert(0, LLMMessage.system(instruction)) + + if not content_parts: + if instruction and not history: + final_messages.append(LLMMessage.user(instruction)) + elif not history: + raise LLMException( + "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED + ) + else: + final_messages.append(LLMMessage.user(content_parts)) + + llm_tools: list[LLMTool] | None = activated_tools + if not llm_tools and use_tools: + try: + llm_tools = tool_registry.get_tools(use_tools) + logger.debug(f"已从注册表加载工具定义: {use_tools}") + except ValueError as e: + raise LLMException( + f"加载工具定义失败: {e}", + code=LLMErrorCode.CONFIGURATION_ERROR, + cause=e, + ) + + tool_choice = None + if tool_config: + mode = tool_config.get("mode", "auto") + if mode in ["auto", "any", "none"]: + tool_choice = mode + + response = await self._execute_generation( + messages=final_messages, + model_name=model, + error_message="内容分析失败", + config_overrides=kwargs, + llm_tools=llm_tools, + tool_choice=tool_choice, + ) + + return response + + async def _execute_generation( + self, + messages: list[LLMMessage], + model_name: ModelName, + error_message: str, + config_overrides: dict[str, Any], + llm_tools: list[LLMTool] | None = None, + tool_choice: str | dict[str, Any] | None = None, + base_config: LLMGenerationConfig | None = None, + ) -> LLMResponse: + """通用的生成执行方法,封装模型获取和单次API调用""" + try: + resolved_model_name = self._resolve_model_name( + model_name or self.config.model + ) + final_config_dict = self._merge_config( + config_overrides, base_config=base_config + ) + + async with await get_model_instance( + resolved_model_name, override_config=final_config_dict + ) as model_instance: + return await model_instance.generate_response( + messages, + tools=llm_tools, + tool_choice=tool_choice, + ) + except LLMException: + raise + except Exception as e: + logger.error(f"{error_message}: {e}", e=e) + raise LLMException(f"{error_message}: {e}", cause=e) + + def _resolve_model_name(self, model_name: ModelName) -> str: + """解析模型名称""" + if model_name: + return model_name + + default_model = 
get_global_default_model_name() + if default_model: + return default_model + + raise LLMException( + "未指定模型名称且未设置全局默认模型", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + + def _merge_config( + self, + user_config: dict[str, Any], + base_config: LLMGenerationConfig | None = None, + ) -> dict[str, Any]: + """合并配置""" + final_config = {} + if base_config: + final_config.update(base_config.to_dict()) + + if self.config.temperature is not None: + final_config["temperature"] = self.config.temperature + if self.config.max_tokens is not None: + final_config["max_tokens"] = self.config.max_tokens + + if self.config.enable_cache: + final_config["enable_caching"] = True + if self.config.enable_code: + final_config["enable_code_execution"] = True + if self.config.enable_search: + final_config["enable_grounding"] = True + + if self.config.enable_gemini_json_mode: + final_config["response_mime_type"] = "application/json" + if self.config.enable_gemini_thinking: + final_config["thinking_budget"] = 0.8 + if self.config.enable_gemini_safe_mode: + final_config["safety_settings"] = ( + CommonOverrides.gemini_safe().safety_settings + ) + if self.config.enable_gemini_multimodal: + final_config.update(CommonOverrides.gemini_multimodal().to_dict()) + if self.config.enable_gemini_grounding: + final_config["enable_grounding"] = True + + final_config.update(user_config) + + return final_config + + async def embed( + self, + texts: list[str] | str, + *, + model: ModelName = None, + task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT, + **kwargs: Any, + ) -> list[list[float]]: + """ + 生成文本嵌入向量 + + 参数: + texts: 要生成嵌入向量的文本或文本列表。 + model: 要使用的嵌入模型名称。 + task_type: 嵌入任务类型。 + **kwargs: 传递给模型的其他参数。 + + 返回: + list[list[float]]: 文本的嵌入向量列表。 + """ + if isinstance(texts, str): + texts = [texts] + if not texts: + return [] + + try: + resolved_model_str = ( + model or self.config.default_embedding_model or self.config.model + ) + if not resolved_model_str: + raise LLMException( + "使用 embed 功能时必须指定嵌入模型名称," + "或在 AIConfig 中配置 default_embedding_model。", + code=LLMErrorCode.MODEL_NOT_FOUND, + ) + resolved_model_str = self._resolve_model_name(resolved_model_str) + + async with await get_model_instance( + resolved_model_str, + override_config=None, + ) as embedding_model_instance: + return await embedding_model_instance.generate_embeddings( + texts, task_type=task_type, **kwargs + ) + except LLMException: + raise + except Exception as e: + logger.error(f"文本嵌入失败: {e}", e=e) + raise LLMException( + f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e + ) diff --git a/zhenxun/services/llm/types/__init__.py b/zhenxun/services/llm/types/__init__.py index f01bc291..72920d06 100644 --- a/zhenxun/services/llm/types/__init__.py +++ b/zhenxun/services/llm/types/__init__.py @@ -10,7 +10,13 @@ from .content import ( LLMMessage, LLMResponse, ) -from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory +from .enums import ( + EmbeddingTaskType, + ModelProvider, + ResponseFormat, + TaskType, + ToolCategory, +) from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message from .models import ( LLMCacheInfo, @@ -52,6 +58,7 @@ __all__ = [ "ModelProvider", "ProviderConfig", "ResponseFormat", + "TaskType", "ToolCategory", "ToolMetadata", "UsageInfo", diff --git a/zhenxun/services/llm/types/enums.py b/zhenxun/services/llm/types/enums.py index 718a52ef..82cb49b0 100644 --- a/zhenxun/services/llm/types/enums.py +++ b/zhenxun/services/llm/types/enums.py @@ -45,6 +45,17 @@ class 
ToolCategory(Enum):
     CUSTOM = auto()
 
 
+class TaskType(Enum):
+    """任务类型枚举"""
+
+    CHAT = "chat"
+    CODE = "code"
+    SEARCH = "search"
+    ANALYSIS = "analysis"
+    GENERATION = "generation"
+    MULTIMODAL = "multimodal"
+
+
 class LLMErrorCode(Enum):
     """LLM 服务相关的错误代码枚举"""
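
A minimal usage sketch of the session API defined above, for orientation. It assumes the class is exported as AI together with AIConfig from zhenxun.services.llm (the export itself is not shown in this hunk), that AIConfig accepts the attributes read in _merge_config as keyword arguments, and that the embedding model name is a placeholder; adjust names and import paths to the actual module layout.

from zhenxun.services.llm import AI, AIConfig
from zhenxun.services.llm.types import EmbeddingTaskType


async def demo() -> None:
    # Each AI instance keeps its own conversation history across chat() calls.
    ai = AI(config=AIConfig(model="Gemini/gemini-2.0-flash", temperature=0.7))

    # Plain chat: the user message and the assistant reply are appended to
    # ai.history automatically (multimodal parts are replaced by a text
    # placeholder unless preserve_media_in_history is enabled).
    reply = await ai.chat("Summarize the difference between TCP and UDP.")
    print(reply.text)

    # Grounded search: returns a dict with the answer text plus citation data.
    found = await ai.search("What is the zhenxun bot?")
    print(found["text"], found["queries"], len(found["sources"]))

    # Embeddings: one vector per input text.
    vectors = await ai.embed(
        ["first passage", "second passage"],
        model="Gemini/text-embedding-004",  # placeholder embedding model name
        task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    )
    print(len(vectors), len(vectors[0]))

    # Reset accumulated context before switching topics.
    ai.clear_history()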