"""Utilities and conversion helpers for the LLM module."""

import base64
import copy
from pathlib import Path
from typing import Any

from nonebot.adapters import Message as PlatformMessage
from nonebot_plugin_alconna.uniseg import (
    At,
    File,
    Image,
    Reply,
    Text,
    UniMessage,
    Video,
    Voice,
)

from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx

from .types import LLMContentPart, LLMMessage

# Maximum number of characters of a replied-to message quoted in the prompt.
_REPLY_PREVIEW_LIMIT = 50

# JSON Schema keywords that are stripped when targeting the Gemini API.
_GEMINI_UNSUPPORTED_KEYS = frozenset(
    {"exclusiveMinimum", "exclusiveMaximum", "default"}
)
# The only string values of "format" that are kept when targeting Gemini.
_GEMINI_SUPPORTED_FORMATS = frozenset({"enum", "date-time"})
# Keywords whose value maps user-chosen names to sub-schemas. The names are
# data, not schema keywords, so they must never be filtered.
_SCHEMA_MAP_KEYWORDS = frozenset(
    {"properties", "patternProperties", "$defs", "definitions"}
)


async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
    """Convert a UniMessage into a list of LLMContentPart objects.

    Segment handling:

    * ``Text``: kept as a text part unless it is blank.
    * ``Image``: resolved from ``path``, then ``url``, then raw bytes
      (base64-encoded; MIME type falls back to ``image/png``).
    * ``File`` / ``Voice`` / ``Video``: resolved from ``path``; otherwise the
      ``url`` is downloaded and the bytes are handled like raw data. Video and
      audio bytes are base64-encoded, any other file is reduced to a textual
      placeholder. A failed download yields a textual placeholder.
    * ``At``: rendered as a textual mention marker.
    * ``Reply``: rendered as a short textual preview of the replied message.

    Segments of any other type, or segments that yield no data, are skipped.

    Args:
        message: The UniMessage instance to convert.

    Returns:
        The converted content parts, in segment order.
    """
    parts: list[LLMContentPart] = []

    for seg in message:
        part = None

        if isinstance(seg, Text):
            if seg.text.strip():
                part = LLMContentPart.text_part(seg.text)

        elif isinstance(seg, Image):
            if seg.path:
                part = await LLMContentPart.from_path(seg.path, target_api="gemini")
            elif seg.url:
                part = LLMContentPart.image_url_part(seg.url)
            elif hasattr(seg, "raw") and seg.raw:
                # The attribute may exist but be None, so fall back on any
                # falsy value rather than only when the attribute is missing.
                mime_type = getattr(seg, "mimetype", None) or "image/png"
                if isinstance(seg.raw, bytes):
                    b64_data = base64.b64encode(seg.raw).decode("utf-8")
                    part = LLMContentPart.image_base64_part(b64_data, mime_type)

        elif isinstance(seg, File | Voice | Video):
            if seg.path:
                part = await LLMContentPart.from_path(seg.path)
            elif seg.url:
                try:
                    logger.debug(f"检测到媒体URL,开始下载: {seg.url}")
                    media_bytes = await AsyncHttpx.get_content(seg.url)
                    # Work on a shallow copy so the caller's segment is not
                    # mutated with the downloaded payload.
                    new_seg = copy.copy(seg)
                    new_seg.raw = media_bytes
                    seg = new_seg
                    logger.debug(f"媒体文件下载成功,大小: {len(media_bytes)} bytes")
                except Exception as e:
                    # Best effort: a failed download must not abort the whole
                    # conversion, so substitute a textual placeholder.
                    logger.error(f"从URL下载媒体失败: {seg.url}, 错误: {e}")
                    part = LLMContentPart.text_part(
                        f"[下载媒体失败: {seg.name or seg.url}]"
                    )

            # Path-based parts and download-failure placeholders are final;
            # skip the raw-bytes handling below.
            if part:
                parts.append(part)
                continue

            if hasattr(seg, "raw") and seg.raw:
                mime_type = getattr(seg, "mimetype", None)
                if isinstance(seg.raw, bytes):
                    b64_data = base64.b64encode(seg.raw).decode("utf-8")

                    if isinstance(seg, Video):
                        if not mime_type:
                            mime_type = "video/mp4"
                        part = LLMContentPart.video_base64_part(
                            data=b64_data, mime_type=mime_type
                        )
                        logger.debug(
                            f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
                        )
                    elif isinstance(seg, Voice):
                        if not mime_type:
                            mime_type = "audio/wav"
                        part = LLMContentPart.audio_base64_part(
                            data=b64_data, mime_type=mime_type
                        )
                        logger.debug(
                            f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
                        )
                    else:
                        # Generic files are not forwarded as binary data; only
                        # a textual description is passed on.
                        part = LLMContentPart.text_part(
                            f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
                        )
                        logger.debug(
                            f"处理其他文件字节数据: {mime_type}, "
                            f"大小: {len(seg.raw)} bytes"
                        )

        elif isinstance(seg, At):
            if seg.flag == "all":
                part = LLMContentPart.text_part("[提及所有人]")
            else:
                part = LLMContentPart.text_part(f"[提及用户: {seg.target}]")

        elif isinstance(seg, Reply):
            if seg.msg:
                try:
                    extract_method = getattr(seg.msg, "extract_plain_text", None)
                    if callable(extract_method):
                        reply_text = str(extract_method()).strip()
                    else:
                        reply_text = str(seg.msg).strip()

                    if reply_text:
                        preview = reply_text[:_REPLY_PREVIEW_LIMIT]
                        # Only mark the quote as truncated when it really is.
                        suffix = "..." if len(reply_text) > len(preview) else ""
                        part = LLMContentPart.text_part(
                            f'[回复消息: "{preview}{suffix}"]'
                        )
                except Exception:
                    # Extraction failed; still tell the model a reply exists.
                    part = LLMContentPart.text_part("[回复了一条消息]")

        if part:
            parts.append(part)

    return parts


async def normalize_to_llm_messages(
    message: str | UniMessage | LLMMessage | list[LLMContentPart] | list[LLMMessage],
    instruction: str | None = None,
) -> list[LLMMessage]:
    """Normalize several input formats into a list of LLMMessage objects.

    Args:
        message: The input to normalize. A single ``LLMMessage`` or a list made
            up only of ``LLMMessage`` objects is used as-is; a ``str``, a
            ``UniMessage`` (converted via ``unimsg_to_llm_parts``) or any other
            list is wrapped in a single user message.
        instruction: Optional system instruction. When truthy it is placed
            first in the result as a system message.

    Returns:
        The normalized message list.

    Raises:
        TypeError: If ``message`` is none of the supported types.
    """
    messages: list[LLMMessage] = []

    if instruction:
        messages.append(LLMMessage.system(instruction))

    if isinstance(message, LLMMessage):
        messages.append(message)
    elif isinstance(message, list) and all(isinstance(m, LLMMessage) for m in message):
        # Note: an empty list also lands here and contributes nothing.
        messages.extend(message)
    elif isinstance(message, str):
        messages.append(LLMMessage.user(message))
    elif isinstance(message, UniMessage):
        content_parts = await unimsg_to_llm_parts(message)
        messages.append(LLMMessage.user(content_parts))
    elif isinstance(message, list):
        # Remaining lists are assumed to hold LLMContentPart items.
        messages.append(LLMMessage.user(message))  # type: ignore
    else:
        raise TypeError(f"不支持的消息类型: {type(message)}")

    return messages


def create_multimodal_message(
    text: str | None = None,
    images: list[str | Path | bytes] | str | Path | bytes | None = None,
    videos: list[str | Path | bytes] | str | Path | bytes | None = None,
    audios: list[str | Path | bytes] | str | Path | bytes | None = None,
    image_mimetypes: list[str] | str | None = None,
    video_mimetypes: list[str] | str | None = None,
    audio_mimetypes: list[str] | str | None = None,
) -> UniMessage:
    """Build a multimodal UniMessage from text and media inputs.

    Each media argument accepts a single item or a list of items. An item may
    be a filesystem path, an ``http://`` / ``https://`` URL, or raw bytes.

    Args:
        text: Text content; skipped when empty or ``None``.
        images: Image items, added as ``Image`` segments.
        videos: Video items, added as ``Video`` segments.
        audios: Audio items, added as ``Voice`` segments.
        image_mimetypes: MIME type(s) for image bytes; defaults to
            ``image/png`` when not supplied.
        video_mimetypes: MIME type(s) for video bytes; defaults to
            ``video/mp4`` when not supplied.
        audio_mimetypes: MIME type(s) for audio bytes; defaults to
            ``audio/wav`` when not supplied.

    Returns:
        The assembled message: text first, then images, videos and audios.
    """
    message = UniMessage()

    if text:
        message.append(Text(text))

    if images is not None:
        _add_media_to_message(message, images, image_mimetypes, Image, "image/png")
    if videos is not None:
        _add_media_to_message(message, videos, video_mimetypes, Video, "video/mp4")
    if audios is not None:
        _add_media_to_message(message, audios, audio_mimetypes, Voice, "audio/wav")

    return message


def _add_media_to_message(
    message: UniMessage,
    media_items: list[str | Path | bytes] | str | Path | bytes,
    mimetypes: list[str] | str | None,
    media_class: type,
    default_mimetype: str,
) -> None:
    """Append media segments of ``media_class`` to ``message`` in place.

    Args:
        message: The message being built; mutated by this function.
        media_items: One item or a list of items (path, URL string or bytes).
        mimetypes: A single MIME type applied to every item, or a list matched
            to ``media_items`` by index. Only consulted for bytes items.
        media_class: Segment class to instantiate (``Image``, ``Video`` or
            ``Voice`` at the call sites in this module).
        default_mimetype: MIME type used for a bytes item that has no entry at
            its index in ``mimetypes``.
    """
    if not isinstance(media_items, list):
        media_items = [media_items]

    mime_list: list[str] = []
    if mimetypes is not None:
        if isinstance(mimetypes, str):
            mime_list = [mimetypes] * len(media_items)
        else:
            mime_list = list(mimetypes)

    for i, item in enumerate(media_items):
        if isinstance(item, str | Path):
            # Strings that look like web URLs are passed by URL; everything
            # else is treated as a local path.
            if str(item).startswith(("http://", "https://")):
                message.append(media_class(url=str(item)))
            else:
                message.append(media_class(path=Path(item)))
        elif isinstance(item, bytes):
            mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
            message.append(media_class(raw=item, mimetype=mimetype))
        # Items of any other type are ignored.


def message_to_unimessage(message: PlatformMessage) -> UniMessage:
    """Convert a platform-specific Message into a generic UniMessage.

    Only ``text``, ``image``, ``record``, ``video`` and ``at`` segments are
    converted; any other segment type is skipped with a debug log entry.

    Args:
        message: The platform-specific message object.

    Returns:
        A UniMessage holding the converted segments in their original order.
    """
    uni_segments = []
    for seg in message:
        if seg.type == "text":
            uni_segments.append(Text(seg.data.get("text", "")))
        elif seg.type == "image":
            uni_segments.append(Image(url=seg.data.get("url")))
        elif seg.type == "record":
            uni_segments.append(Voice(url=seg.data.get("url")))
        elif seg.type == "video":
            uni_segments.append(Video(url=seg.data.get("url")))
        elif seg.type == "at":
            # NOTE(review): the "qq" key looks OneBot-specific; confirm which
            # adapters are expected to reach this function.
            uni_segments.append(At("user", str(seg.data.get("qq", ""))))
        else:
            logger.debug(f"跳过不支持的平台消息段类型: {seg.type}")

    return UniMessage(uni_segments)


def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
    """Recursively strip JSON Schema keywords a given LLM API does not accept.

    For ``api_type == "gemini"`` the keywords ``exclusiveMinimum``,
    ``exclusiveMaximum`` and ``default`` are removed, as is any string
    ``format`` other than ``"enum"`` or ``"date-time"``. Other API types are
    deep-copied (dicts and lists only) without filtering.

    Names inside ``properties``, ``patternProperties``, ``$defs`` and
    ``definitions`` are user-chosen identifiers rather than schema keywords,
    so they are always kept (for example a property literally named
    ``default``); only the sub-schemas they map to are sanitized.

    Args:
        schema: The JSON Schema to sanitize (dict, list or any other value).
        api_type: Target API type, e.g. ``"gemini"``.

    Returns:
        A sanitized copy for dict and list inputs; any other value is
        returned unchanged.
    """
    if isinstance(schema, list):
        return [sanitize_schema_for_llm(item, api_type) for item in schema]
    if not isinstance(schema, dict):
        return schema

    is_gemini = api_type == "gemini"
    sanitized: dict[Any, Any] = {}

    for key, value in schema.items():
        if key in _SCHEMA_MAP_KEYWORDS and isinstance(value, dict):
            # Keep every entry name; recurse into the sub-schemas only.
            sanitized[key] = {
                name: sanitize_schema_for_llm(sub_schema, api_type)
                for name, sub_schema in value.items()
            }
            continue

        if is_gemini:
            if key in _GEMINI_UNSUPPORTED_KEYS:
                continue
            if (
                key == "format"
                and isinstance(value, str)
                and value not in _GEMINI_SUPPORTED_FORMATS
            ):
                continue

        sanitized[key] = sanitize_schema_for_llm(value, api_type)

    return sanitized