Mirror of https://github.com/zhenxun-org/zhenxun_bot.git (synced 2025-12-14 21:52:56 +08:00)
Merge db53afbabc into 07be73c1b7 (commit a82bf95da3)
@@ -1,12 +1,12 @@
from typing import Any

from nonebot.adapters import Bot, Message
from nonebot.adapters.onebot.v11 import MessageSegment

from zhenxun.configs.config import Config
from zhenxun.models.bot_message_store import BotMessageStore
from zhenxun.services.log import logger
from zhenxun.utils.enum import BotSentType
from zhenxun.utils.log_sanitizer import sanitize_for_logging
from zhenxun.utils.manager.message_manager import MessageManager
from zhenxun.utils.platform import PlatformUtils

@@ -41,35 +41,6 @@ def replace_message(message: Message) -> str:
    return result


def format_message_for_log(message: Message) -> str:
    """
    将消息对象转换为适合日志记录的字符串,对base64等长内容进行摘要处理。
    """
    if not isinstance(message, Message):
        return str(message)

    log_parts = []
    for seg in message:
        seg: MessageSegment
        if seg.type == "text":
            log_parts.append(seg.data.get("text", ""))
        elif seg.type in ("image", "record", "video"):
            file_info = seg.data.get("file", "")
            if isinstance(file_info, str) and file_info.startswith("base64://"):
                b64_data = file_info[9:]
                data_size_bytes = (len(b64_data) * 3) / 4 - b64_data.count("=", -2)
                log_parts.append(
                    f"[{seg.type}: base64, size={data_size_bytes / 1024:.2f}KB]"
                )
            else:
                log_parts.append(f"[{seg.type}]")
        elif seg.type == "at":
            log_parts.append(f"[@{seg.data.get('qq', 'unknown')}]")
        else:
            log_parts.append(f"[{seg.type}]")
    return "".join(log_parts)


@Bot.on_called_api
async def handle_api_result(
    bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any

@@ -82,7 +53,6 @@ async def handle_api_result(
    message: Message = data.get("message", "")
    message_type = data.get("message_type")
    try:
        # 记录消息id
        if user_id and message_id:
            MessageManager.add(str(user_id), str(message_id))
            logger.debug(

@@ -108,7 +78,8 @@ async def handle_api_result(
            else replace_message(message),
            platform=PlatformUtils.get_platform(bot),
        )
        logger.debug(f"消息发送记录,message: {format_message_for_log(message)}")
        sanitized_message = sanitize_for_logging(message, context="nonebot_message")
        logger.debug(f"消息发送记录,message: {sanitized_message}")
    except Exception as e:
        logger.warning(
            f"消息发送记录发生错误...data: {data}, result: {result}",
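
The hook now delegates log formatting to the shared sanitizer instead of the removed format_message_for_log. A minimal sketch of that call, assuming an OneBot v11 message carrying a base64 image (the message construction here is illustrative, not from the commit):

from nonebot.adapters.onebot.v11 import Message, MessageSegment

from zhenxun.utils.log_sanitizer import sanitize_for_logging

# Build a message whose image segment holds a long base64 payload.
msg = Message(
    [
        MessageSegment.text("hello"),
        MessageSegment.image("base64://" + "A" * 10_000),
    ]
)

# The hook logs a sanitized deep copy; the original message is untouched.
sanitized = sanitize_for_logging(msg, context="nonebot_message")
# The image segment's "file" becomes "[base64://_data_omitted_len=10009]".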

@@ -7,6 +7,7 @@ LLM 服务模块 - 公共 API 入口
from .api import (
    chat,
    code,
    create_image,
    embed,
    generate,
    generate_structured,

@@ -74,6 +75,7 @@ __all__ = [
    "chat",
    "clear_model_cache",
    "code",
    "create_image",
    "create_multimodal_message",
    "embed",
    "function_tool",
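
Since create_image is now part of the public surface, plugins can import it straight from the package; a one-line sketch (the package path zhenxun.services.llm is inferred from the import style used elsewhere in this commit):

from zhenxun.services.llm import create_image  # exported via __all__ above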

@@ -3,6 +3,9 @@ LLM 适配器基类和通用数据结构
"""

from abc import ABC, abstractmethod
import base64
import binascii
import json
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

@@ -32,6 +35,7 @@ class ResponseData(BaseModel):
    """响应数据封装 - 支持所有高级功能"""

    text: str
    image_bytes: bytes | None = None
    usage_info: dict[str, Any] | None = None
    raw_response: dict[str, Any] | None = None
    tool_calls: list[LLMToolCall] | None = None

@@ -242,6 +246,38 @@ class BaseAdapter(ABC):
        if content:
            content = content.strip()

        image_bytes: bytes | None = None
        if content and content.startswith("{") and content.endswith("}"):
            try:
                content_json = json.loads(content)
                if "b64_json" in content_json:
                    image_bytes = base64.b64decode(content_json["b64_json"])
                    content = "[图片已生成]"
                elif "data" in content_json and isinstance(
                    content_json["data"], str
                ):
                    image_bytes = base64.b64decode(content_json["data"])
                    content = "[图片已生成]"

            except (json.JSONDecodeError, KeyError, binascii.Error):
                pass
        elif (
            "images" in message
            and isinstance(message["images"], list)
            and message["images"]
        ):
            image_info = message["images"][0]
            if image_info.get("type") == "image_url":
                image_url_obj = image_info.get("image_url", {})
                url_str = image_url_obj.get("url", "")
                if url_str.startswith("data:image/png;base64,"):
                    try:
                        b64_data = url_str.split(",", 1)[1]
                        image_bytes = base64.b64decode(b64_data)
                        content = content if content else "[图片已生成]"
                    except (IndexError, binascii.Error) as e:
                        logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")

        parsed_tool_calls: list[LLMToolCall] | None = None
        if message_tool_calls := message.get("tool_calls"):
            from ..types.models import LLMToolFunction

@@ -280,6 +316,7 @@ class BaseAdapter(ABC):
            text=final_text,
            tool_calls=parsed_tool_calls,
            usage_info=usage_info,
            image_bytes=image_bytes,
            raw_response=response_json,
        )

@@ -450,6 +487,13 @@ class OpenAICompatAdapter(BaseAdapter):
        """准备高级请求 - OpenAI兼容格式"""
        url = self.get_api_url(model, self.get_chat_endpoint(model))
        headers = self.get_base_headers(api_key)
        if model.api_type == "openrouter":
            headers.update(
                {
                    "HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot",
                    "X-Title": "Zhenxun Bot",
                }
            )
        openai_messages = self.convert_messages_to_openai_format(messages)

        body = {
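
The hunk above adds two extraction paths for image data in OpenAI-compatible responses. A standalone sketch of the same decision tree, simplified (the PNG-only data URI check mirrors the diff; this is a sketch, not the adapter itself):

import base64
import binascii
import json


def extract_image_bytes(message: dict) -> bytes | None:
    """Condensed mirror of the adapter's two image-extraction branches."""
    content = (message.get("content") or "").strip()
    # Branch 1: the content field itself is a JSON object with base64 image data.
    if content.startswith("{") and content.endswith("}"):
        try:
            payload = json.loads(content)
            raw = payload.get("b64_json") or payload.get("data")
            if isinstance(raw, str):
                return base64.b64decode(raw)
        except (json.JSONDecodeError, binascii.Error):
            pass
    # Branch 2: OpenRouter-style message["images"] holding a data URI.
    images = message.get("images")
    if isinstance(images, list) and images and images[0].get("type") == "image_url":
        url = images[0].get("image_url", {}).get("url", "")
        if url.startswith("data:image/png;base64,"):
            try:
                return base64.b64decode(url.split(",", 1)[1])
            except binascii.Error:
                pass
    return None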

@@ -2,6 +2,7 @@
Gemini API 适配器
"""

import base64
from typing import TYPE_CHECKING, Any

from zhenxun.services.log import logger

@@ -373,7 +374,16 @@ class GeminiAdapter(BaseAdapter):
        self.validate_response(response_json)

        try:
            candidates = response_json.get("candidates", [])
            if "image_generation" in response_json and isinstance(
                response_json["image_generation"], dict
            ):
                candidates_source = response_json["image_generation"]
            else:
                candidates_source = response_json

            candidates = candidates_source.get("candidates", [])
            usage_info = response_json.get("usageMetadata")

            if not candidates:
                logger.debug("Gemini响应中没有candidates。")
                return ResponseData(text="", raw_response=response_json)

@@ -398,6 +408,7 @@ class GeminiAdapter(BaseAdapter):
            parts = content_data.get("parts", [])

            text_content = ""
            image_bytes: bytes | None = None
            parsed_tool_calls: list["LLMToolCall"] | None = None
            thought_summary_parts = []
            answer_parts = []

@@ -409,6 +420,14 @@ class GeminiAdapter(BaseAdapter):
                    thought_summary_parts.append(part["thought"])
                elif "thoughtSummary" in part:
                    thought_summary_parts.append(part["thoughtSummary"])
                elif "inlineData" in part:
                    inline_data = part["inlineData"]
                    if "data" in inline_data:
                        image_bytes = base64.b64decode(inline_data["data"])
                        answer_parts.append(
                            f"[图片已生成: {inline_data.get('mimeType', 'image')}]"
                        )

                elif "functionCall" in part:
                    if parsed_tool_calls is None:
                        parsed_tool_calls = []

@@ -475,6 +494,7 @@ class GeminiAdapter(BaseAdapter):
        return ResponseData(
            text=text_content,
            tool_calls=parsed_tool_calls,
            image_bytes=image_bytes,
            usage_info=usage_info,
            raw_response=response_json,
            grounding_metadata=grounding_metadata_obj,
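
Reduced to its core, the new Gemini parsing walks candidates[].content.parts[] and decodes the first inlineData blob, preferring the image_generation sub-object when present. A sketch without the adapter's error handling:

import base64


def first_inline_image(response_json: dict) -> bytes | None:
    """Return the first inlineData payload found in a Gemini response, if any."""
    ig = response_json.get("image_generation")
    source = ig if isinstance(ig, dict) else response_json
    for candidate in source.get("candidates", []):
        for part in candidate.get("content", {}).get("parts", []):
            inline = part.get("inlineData")
            if isinstance(inline, dict) and "data" in inline:
                return base64.b64decode(inline["data"])
    return None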

@@ -21,7 +21,14 @@ class OpenAIAdapter(OpenAICompatAdapter):

    @property
    def supported_api_types(self) -> list[str]:
        return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
        return [
            "openai",
            "deepseek",
            "zhipu",
            "general_openai_compat",
            "ark",
            "openrouter",
        ]

    def get_chat_endpoint(self, model: "LLMModel") -> str:
        """返回聊天完成端点"""
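
Only the expanded supported_api_types list comes from the diff; how a registry might use it to pick an adapter is an assumption, sketched here:

def find_adapter(api_type: str, adapters: list) -> object | None:
    # First adapter that declares the requested api_type wins (hypothetical lookup).
    return next((a for a in adapters if api_type in a.supported_api_types), None)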

@@ -2,7 +2,8 @@
LLM 服务的高级 API 接口 - 便捷函数入口 (无状态)
"""

from typing import Any, TypeVar
from pathlib import Path
from typing import Any, TypeVar, overload

from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel

@@ -10,7 +11,7 @@ from pydantic import BaseModel
from zhenxun.services.log import logger

from .config import CommonOverrides
from .config.generation import create_generation_config_from_kwargs
from .config.generation import LLMGenerationConfig, create_generation_config_from_kwargs
from .manager import get_model_instance
from .session import AI
from .tools.manager import tool_provider_manager

@@ -23,6 +24,7 @@ from .types import (
    LLMResponse,
    ModelName,
)
from .utils import create_multimodal_message

T = TypeVar("T", bound=BaseModel)

@@ -303,3 +305,99 @@ async def run_with_tools(
    raise LLMException(
        "带工具的执行循环未能产生有效的助手回复。", code=LLMErrorCode.GENERATION_FAILED
    )


async def _generate_image_from_message(
    message: UniMessage,
    model: ModelName = None,
    **kwargs: Any,
) -> LLMResponse:
    """
    [内部] 从 UniMessage 生成图片的核心辅助函数。
    """
    from .utils import normalize_to_llm_messages

    config = (
        create_generation_config_from_kwargs(**kwargs)
        if kwargs
        else LLMGenerationConfig()
    )

    config.validation_policy = {"require_image": True}
    config.response_modalities = ["IMAGE", "TEXT"]

    try:
        messages = await normalize_to_llm_messages(message)

        async with await get_model_instance(model) as model_instance:
            if not model_instance.can_generate_images():
                raise LLMException(
                    f"模型 '{model_instance.provider_name}/{model_instance.model_name}'"
                    f"不支持图片生成",
                    code=LLMErrorCode.CONFIGURATION_ERROR,
                )

            response = await model_instance.generate_response(messages, config=config)

            if not response.image_bytes:
                error_text = response.text or "模型未返回图片数据。"
                logger.warning(f"图片生成调用未返回图片,返回文本内容: {error_text}")

            return response
    except LLMException:
        raise
    except Exception as e:
        logger.error(f"执行图片生成时发生未知错误: {e}", e=e)
        raise LLMException(f"图片生成失败: {e}", cause=e)


@overload
async def create_image(
    prompt: str | UniMessage,
    *,
    images: None = None,
    model: ModelName = None,
    **kwargs: Any,
) -> LLMResponse:
    """根据文本提示生成一张新图片。"""
    ...


@overload
async def create_image(
    prompt: str | UniMessage,
    *,
    images: list[Path | bytes | str] | Path | bytes | str,
    model: ModelName = None,
    **kwargs: Any,
) -> LLMResponse:
    """在给定图片的基础上,根据文本提示进行编辑或重新生成。"""
    ...


async def create_image(
    prompt: str | UniMessage,
    *,
    images: list[Path | bytes | str] | Path | bytes | str | None = None,
    model: ModelName = None,
    **kwargs: Any,
) -> LLMResponse:
    """
    智能图片生成/编辑函数。
    - 如果 `images` 为 None,执行文生图。
    - 如果提供了 `images`,执行图+文生图,支持多张图片输入。
    """
    text_prompt = (
        prompt.extract_plain_text() if isinstance(prompt, UniMessage) else str(prompt)
    )

    image_list = []
    if images:
        if isinstance(images, list):
            image_list.extend(images)
        else:
            image_list.append(images)

    message = create_multimodal_message(text=text_prompt, images=image_list)

    return await _generate_image_from_message(message, model=model, **kwargs)
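
Usage sketches for the two overloads, to be run inside an async handler; the model name and file paths are placeholders, not values from the commit:

from pathlib import Path

from zhenxun.services.llm import create_image

# Text-to-image: no images argument, so the first overload applies.
resp = await create_image("画一只戴帽子的猫", model="Gemini/gemini-2.5-flash-image-preview")

# Image editing: pass the source image(s), so the second overload applies.
resp = await create_image("把帽子换成红色", images=Path("cat.png"))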

@@ -2,13 +2,15 @@
LLM 生成配置相关类和函数
"""

from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump

from ..types import LLMResponse
from ..types.enums import ResponseFormat
from ..types.exceptions import LLMErrorCode, LLMException

@@ -64,6 +66,15 @@ class ModelConfigOverride(BaseModel):

    custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")

    validation_policy: dict[str, Any] | None = Field(
        default=None, description="声明式的响应验证策略 (例如: {'require_image': True})"
    )
    response_validator: Callable[[LLMResponse], None] | None = Field(
        default=None, description="一个高级回调函数,用于验证响应,验证失败时应抛出异常"
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def to_dict(self) -> dict[str, Any]:
        """转换为字典,排除None值"""
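
A sketch of how the two new fields are meant to be configured (setting them directly on an LLMGenerationConfig is assumed to work the same way _generate_image_from_message does in this commit):

def require_citation(resp: LLMResponse) -> None:
    # Custom validator: raise to have the call rejected with API_RESPONSE_INVALID.
    if "来源" not in resp.text:
        raise ValueError("响应缺少来源说明")


config = LLMGenerationConfig()
config.validation_policy = {"require_image": True}  # declarative policy
config.response_validator = require_citation        # imperative callback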

@@ -50,8 +50,8 @@ class LLMHttpClient:
        async with self._lock:
            if self._client is None or self._client.is_closed:
                logger.debug(
                    f"LLMHttpClient: Initializing new httpx.AsyncClient "
                    f"with config: {self.config}"
                    f"LLMHttpClient: 正在初始化新的 httpx.AsyncClient "
                    f"配置: {self.config}"
                )
                headers = get_user_agent()
                limits = httpx.Limits(

@@ -92,7 +92,7 @@ class LLMHttpClient:
            )
        if self._client is None:
            raise LLMException(
                "HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
                "HTTP 客户端初始化失败。", LLMErrorCode.CONFIGURATION_ERROR
            )
        return self._client

@@ -110,17 +110,17 @@ class LLMHttpClient:
        async with self._lock:
            if self._client and not self._client.is_closed:
                logger.debug(
                    f"LLMHttpClient: Closing with config: {self.config}. "
                    f"Active requests: {self._active_requests}"
                    f"LLMHttpClient: 正在关闭,配置: {self.config}. "
                    f"活跃请求数: {self._active_requests}"
                )
                if self._active_requests > 0:
                    logger.warning(
                        f"LLMHttpClient: Closing while {self._active_requests} "
                        f"requests are still active."
                        f"LLMHttpClient: 关闭时仍有 {self._active_requests} "
                        f"个请求处于活跃状态。"
                    )
                await self._client.aclose()
                self._client = None
        logger.debug(f"LLMHttpClient for config {self.config} definitively closed.")
        logger.debug(f"配置为 {self.config} 的 LLMHttpClient 已完全关闭。")

    @property
    def is_closed(self) -> bool:

@@ -145,20 +145,17 @@ class LLMHttpClientManager:
        client = self._clients.get(key)
        if client and not client.is_closed:
            logger.debug(
                f"LLMHttpClientManager: Reusing existing LLMHttpClient "
                f"for key: {key}"
                f"LLMHttpClientManager: 复用现有的 LLMHttpClient 密钥: {key}"
            )
            return client

        if client and client.is_closed:
            logger.debug(
                f"LLMHttpClientManager: Found a closed client for key {key}. "
                f"Creating a new one."
                f"LLMHttpClientManager: 发现密钥 {key} 对应的客户端已关闭。"
                f"正在创建新的客户端。"
            )

        logger.debug(
            f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
        )
        logger.debug(f"LLMHttpClientManager: 为密钥 {key} 创建新的 LLMHttpClient")
        http_client_config = HttpClientConfig(
            timeout=provider_config.timeout, proxy=provider_config.proxy
        )

@@ -169,8 +166,7 @@ class LLMHttpClientManager:
    async def shutdown(self):
        async with self._lock:
            logger.info(
                f"LLMHttpClientManager: Shutting down. "
                f"Closing {len(self._clients)} client(s)."
                f"LLMHttpClientManager: 正在关闭。关闭 {len(self._clients)} 个客户端。"
            )
            close_tasks = [
                client.close()

@@ -180,7 +176,7 @@ class LLMHttpClientManager:
            if close_tasks:
                await asyncio.gather(*close_tasks, return_exceptions=True)
            self._clients.clear()
            logger.info("LLMHttpClientManager: Shutdown complete.")
            logger.info("LLMHttpClientManager: 关闭完成。")


http_client_manager = LLMHttpClientManager()

@@ -118,6 +118,7 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
        "deepseek": "https://api.deepseek.com",
        "zhipu": "https://open.bigmodel.cn",
        "gemini": "https://generativelanguage.googleapis.com",
        "openrouter": "https://openrouter.ai/api",
        "general_openai_compat": None,
    }

@@ -12,6 +12,7 @@ from typing import Any, TypeVar
from pydantic import BaseModel

from zhenxun.services.log import logger
from zhenxun.utils.log_sanitizer import sanitize_for_logging

from .adapters.base import RequestData
from .config import LLMGenerationConfig

@@ -34,7 +35,6 @@ from .types import (
    ToolExecutable,
)
from .types.capabilities import ModelCapabilities, ModelModality
from .utils import _sanitize_request_body_for_logging

T = TypeVar("T", bound=BaseModel)

@@ -187,7 +187,13 @@ class LLMModel(LLMModelBase):
        logger.debug(f"🔑 API密钥: {masked_key}")
        logger.debug(f"📋 请求头: {dict(request_data.headers)}")

        sanitized_body = _sanitize_request_body_for_logging(request_data.body)
        sanitizer_req_context_map = {"gemini": "gemini_request"}
        sanitizer_req_context = sanitizer_req_context_map.get(
            self.api_type, "openai_request"
        )
        sanitized_body = sanitize_for_logging(
            request_data.body, context=sanitizer_req_context
        )
        request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
        logger.debug(f"📦 请求体: {request_body_str}")

@@ -200,8 +206,11 @@ class LLMModel(LLMModelBase):
        logger.debug(f"📥 响应状态码: {http_response.status_code}")
        logger.debug(f"📄 响应头: {dict(http_response.headers)}")

        response_bytes = await http_response.aread()
        logger.debug(f"📦 响应体已完整读取 ({len(response_bytes)} bytes)")

        if http_response.status_code != 200:
            error_text = http_response.text
            error_text = response_bytes.decode("utf-8", errors="ignore")
            logger.error(
                f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
                f"[{log_context}]"

@@ -232,13 +241,22 @@ class LLMModel(LLMModelBase):
        )

        try:
            response_json = http_response.json()
            response_json = json.loads(response_bytes)

            sanitizer_context_map = {"gemini": "gemini_response"}
            sanitizer_context = sanitizer_context_map.get(
                self.api_type, "openai_response"
            )

            sanitized_for_log = sanitize_for_logging(
                response_json, context=sanitizer_context
            )

            response_json_str = json.dumps(
                response_json, ensure_ascii=False, indent=2
                sanitized_for_log, ensure_ascii=False, indent=2
            )
            logger.debug(f"📋 响应JSON: {response_json_str}")
            parsed_data = parse_response_func(response_json)

        except Exception as e:
            logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
            await self.key_store.record_failure(api_key, None, str(e))

@@ -290,7 +308,7 @@ class LLMModel(LLMModelBase):
            adapter.validate_embedding_response(response_json)
            return adapter.parse_embedding_response(response_json)

        parsed_data, api_key_used = await self._perform_api_call(
        parsed_data, _api_key_used = await self._perform_api_call(
            prepare_request_func=prepare_request,
            parse_response_func=parse_response,
            http_client=http_client,

@@ -376,6 +394,7 @@ class LLMModel(LLMModelBase):
        return LLMResponse(
            text=response_data.text,
            usage_info=response_data.usage_info,
            image_bytes=response_data.image_bytes,
            raw_response=response_data.raw_response,
            tool_calls=response_tool_calls if response_tool_calls else None,
            code_executions=response_data.code_executions,

@@ -390,6 +409,56 @@ class LLMModel(LLMModelBase):
            failed_keys=failed_keys,
            log_context="Generation",
        )

        if config:
            if config.response_validator:
                try:
                    config.response_validator(parsed_data)
                except Exception as e:
                    raise LLMException(
                        f"响应内容未通过自定义验证器: {e}",
                        code=LLMErrorCode.API_RESPONSE_INVALID,
                        details={"validator_error": str(e)},
                        cause=e,
                    ) from e

            policy = config.validation_policy
            if policy:
                if policy.get("require_image") and not parsed_data.image_bytes:
                    if self.api_type == "gemini" and parsed_data.raw_response:
                        usage_metadata = parsed_data.raw_response.get(
                            "usageMetadata", {}
                        )
                        prompt_token_details = usage_metadata.get(
                            "promptTokensDetails", []
                        )
                        prompt_had_image = any(
                            detail.get("modality") == "IMAGE"
                            for detail in prompt_token_details
                        )

                        if prompt_had_image:
                            raise LLMException(
                                "响应验证失败:模型接收了图片输入但未生成图片。",
                                code=LLMErrorCode.API_RESPONSE_INVALID,
                                details={
                                    "policy": policy,
                                    "text_response": parsed_data.text,
                                    "raw_response": parsed_data.raw_response,
                                },
                            )
                        else:
                            logger.debug("Gemini提示词中未包含图片,跳过图片要求重试。")
                    else:
                        raise LLMException(
                            "响应验证失败:要求返回图片但未找到图片数据。",
                            code=LLMErrorCode.API_RESPONSE_INVALID,
                            details={
                                "policy": policy,
                                "text_response": parsed_data.text,
                            },
                        )

        return parsed_data, api_key_used

    async def close(self):
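
Condensing the hunks above: the body is now read exactly once via aread(), parsed from the buffered bytes, and only a sanitized copy reaches the debug log. A self-contained sketch (the httpx response object is assumed; print stands in for logger.debug):

import json

from zhenxun.utils.log_sanitizer import sanitize_for_logging


async def read_parse_and_log(http_response, api_type: str) -> dict:
    response_bytes = await http_response.aread()    # single buffered read
    response_json = json.loads(response_bytes)      # replaces http_response.json()
    context = {"gemini": "gemini_response"}.get(api_type, "openai_response")
    sanitized = sanitize_for_logging(response_json, context=context)
    print(json.dumps(sanitized, ensure_ascii=False, indent=2))
    return response_json                            # callers still get the raw JSON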

@@ -44,6 +44,13 @@ GEMINI_CAPABILITIES = ModelCapabilities(
    supports_tool_calling=True,
)

GEMINI_IMAGE_GEN_CAPABILITIES = ModelCapabilities(
    input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
    output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
    supports_tool_calling=True,
)


DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
    input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
    output_modalities={ModelModality.TEXT},

@@ -83,6 +90,7 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
        output_modalities={ModelModality.EMBEDDING},
        is_embedding_model=True,
    ),
    "*gemini-*-image-preview*": GEMINI_IMAGE_GEN_CAPABILITIES,
    "gemini-2.5-pro*": GEMINI_CAPABILITIES,
    "gemini-1.5-pro*": GEMINI_CAPABILITIES,
    "gemini-2.5-flash*": GEMINI_CAPABILITIES,
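
The registry keys read as glob patterns ("*gemini-*-image-preview*"); a lookup like the following would resolve a model name to its capabilities (fnmatch-based matching is an assumption, the diff does not show the resolver):

from fnmatch import fnmatch


def resolve_capabilities(model_name: str) -> "ModelCapabilities | None":
    for pattern, caps in MODEL_CAPABILITIES_REGISTRY.items():
        if fnmatch(model_name, pattern):
            return caps
    return None

# "gemini-2.5-flash-image-preview" matches "*gemini-*-image-preview*",
# so it gets IMAGE in both input and output modalities.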

@@ -425,6 +425,7 @@ class LLMResponse(BaseModel):
    """LLM 响应"""

    text: str
    image_bytes: bytes | None = None
    usage_info: dict[str, Any] | None = None
    raw_response: dict[str, Any] | None = None
    tool_calls: list[Any] | None = None
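
With image_bytes now on LLMResponse, a caller can persist a generated image directly (sketch; `response` stands for any LLMResponse obtained from this service):

from pathlib import Path

if response.image_bytes is not None:
    Path("output.png").write_bytes(response.image_bytes)
else:
    print(f"模型只返回了文本: {response.text}")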

@@ -273,54 +273,6 @@ def message_to_unimessage(message: PlatformMessage) -> UniMessage:
    return UniMessage(uni_segments)


def _sanitize_request_body_for_logging(body: dict) -> dict:
    """
    净化请求体用于日志记录,移除大数据字段并添加摘要信息

    参数:
        body: 原始请求体字典。

    返回:
        dict: 净化后的请求体字典。
    """
    try:
        sanitized_body = copy.deepcopy(body)

        if "contents" in sanitized_body and isinstance(
            sanitized_body["contents"], list
        ):
            for content_item in sanitized_body["contents"]:
                if "parts" in content_item and isinstance(content_item["parts"], list):
                    media_summary = []
                    new_parts = []
                    for part in content_item["parts"]:
                        if "inlineData" in part and isinstance(
                            part["inlineData"], dict
                        ):
                            data = part["inlineData"].get("data")
                            if isinstance(data, str):
                                mime_type = part["inlineData"].get(
                                    "mimeType", "unknown"
                                )
                                media_summary.append(f"{mime_type} ({len(data)} chars)")
                                continue
                        new_parts.append(part)

                    if media_summary:
                        summary_text = (
                            f"[多模态内容: {len(media_summary)}个文件 - "
                            f"{', '.join(media_summary)}]"
                        )
                        new_parts.insert(0, {"text": summary_text})

                    content_item["parts"] = new_parts

        return sanitized_body
    except Exception as e:
        logger.warning(f"日志净化失败: {e},将记录原始请求体。")
        return body


def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
    """
    递归地净化 JSON Schema,移除特定 LLM API 不支持的关键字。

zhenxun/utils/log_sanitizer.py (new file, 180 lines)
@@ -0,0 +1,180 @@
import copy
from typing import Any

from nonebot.adapters import Message, MessageSegment


def _truncate_base64_string(value: str, threshold: int = 256) -> str:
    """如果字符串是超长的base64或data URI,则截断它。"""
    if not isinstance(value, str):
        return value

    prefixes = ("base64://", "data:image", "data:video", "data:audio")
    if value.startswith(prefixes) and len(value) > threshold:
        prefix = next((p for p in prefixes if value.startswith(p)), "base64")
        return f"[{prefix}_data_omitted_len={len(value)}]"
    return value


def _sanitize_nonebot_message(message: Message) -> Message:
    """净化nonebot.adapter.Message对象,用于日志记录。"""
    sanitized_message = copy.deepcopy(message)
    for seg in sanitized_message:
        seg: MessageSegment
        if seg.type in ("image", "record", "video"):
            file_info = seg.data.get("file", "")
            if isinstance(file_info, str):
                seg.data["file"] = _truncate_base64_string(file_info)
    return sanitized_message


def _sanitize_openai_response(response_json: dict) -> dict:
    """净化OpenAI兼容API的响应体。"""
    try:
        sanitized_json = copy.deepcopy(response_json)
        if "choices" in sanitized_json and isinstance(sanitized_json["choices"], list):
            for choice in sanitized_json["choices"]:
                if "message" in choice and isinstance(choice["message"], dict):
                    message = choice["message"]
                    if "images" in message and isinstance(message["images"], list):
                        for i, image_info in enumerate(message["images"]):
                            if "image_url" in image_info and isinstance(
                                image_info["image_url"], dict
                            ):
                                url = image_info["image_url"].get("url", "")
                                message["images"][i]["image_url"]["url"] = (
                                    _truncate_base64_string(url)
                                )
        return sanitized_json
    except Exception:
        return response_json


def _sanitize_openai_request(body: dict) -> dict:
    """净化OpenAI兼容API的请求体,主要截断图片base64。"""
    try:
        sanitized_json = copy.deepcopy(body)
        if "messages" in sanitized_json and isinstance(
            sanitized_json["messages"], list
        ):
            for message in sanitized_json["messages"]:
                if "content" in message and isinstance(message["content"], list):
                    for i, part in enumerate(message["content"]):
                        if part.get("type") == "image_url":
                            if "image_url" in part and isinstance(
                                part["image_url"], dict
                            ):
                                url = part["image_url"].get("url", "")
                                message["content"][i]["image_url"]["url"] = (
                                    _truncate_base64_string(url)
                                )
        return sanitized_json
    except Exception:
        return body


def _sanitize_gemini_response(response_json: dict) -> dict:
    """净化Gemini API的响应体,处理文本和图片生成两种格式。"""
    try:
        sanitized_json = copy.deepcopy(response_json)

        def _process_candidates(candidates_list: list):
            """辅助函数,用于处理任何 candidates 列表。"""
            if not isinstance(candidates_list, list):
                return
            for candidate in candidates_list:
                if "content" in candidate and isinstance(candidate["content"], dict):
                    content = candidate["content"]
                    if "parts" in content and isinstance(content["parts"], list):
                        for i, part in enumerate(content["parts"]):
                            if "inlineData" in part and isinstance(
                                part["inlineData"], dict
                            ):
                                data = part["inlineData"].get("data", "")
                                if isinstance(data, str) and len(data) > 256:
                                    content["parts"][i]["inlineData"]["data"] = (
                                        f"[base64_data_omitted_len={len(data)}]"
                                    )

        if "candidates" in sanitized_json:
            _process_candidates(sanitized_json["candidates"])

        if "image_generation" in sanitized_json and isinstance(
            sanitized_json["image_generation"], dict
        ):
            if "candidates" in sanitized_json["image_generation"]:
                _process_candidates(sanitized_json["image_generation"]["candidates"])

        return sanitized_json
    except Exception:
        return response_json


def _sanitize_gemini_request(body: dict) -> dict:
    """净化Gemini API的请求体,进行结构转换和总结。"""
    try:
        sanitized_body = copy.deepcopy(body)
        if "contents" in sanitized_body and isinstance(
            sanitized_body["contents"], list
        ):
            for content_item in sanitized_body["contents"]:
                if "parts" in content_item and isinstance(content_item["parts"], list):
                    media_summary = []
                    new_parts = []
                    for part in content_item["parts"]:
                        if "inlineData" in part and isinstance(
                            part["inlineData"], dict
                        ):
                            data = part["inlineData"].get("data")
                            if isinstance(data, str):
                                mime_type = part["inlineData"].get(
                                    "mimeType", "unknown"
                                )
                                media_summary.append(f"{mime_type} ({len(data)} chars)")
                                continue
                        new_parts.append(part)

                    if media_summary:
                        summary_text = (
                            f"[多模态内容: {len(media_summary)}个文件 - "
                            f"{', '.join(media_summary)}]"
                        )
                        new_parts.insert(0, {"text": summary_text})

                    content_item["parts"] = new_parts
        return sanitized_body
    except Exception:
        return body


def sanitize_for_logging(data: Any, context: str | None = None) -> Any:
    """
    统一的日志净化入口。

    Args:
        data: 需要净化的数据 (dict, Message, etc.).
        context: 净化场景的上下文标识,例如 'gemini_request', 'openai_response'.

    Returns:
        净化后的数据。
    """
    if context == "nonebot_message":
        if isinstance(data, Message):
            return _sanitize_nonebot_message(data)
    elif context == "openai_response":
        if isinstance(data, dict):
            return _sanitize_openai_response(data)
    elif context == "gemini_response":
        if isinstance(data, dict):
            return _sanitize_gemini_response(data)
    elif context == "gemini_request":
        if isinstance(data, dict):
            return _sanitize_gemini_request(data)
    elif context == "openai_request":
        if isinstance(data, dict):
            return _sanitize_openai_request(data)
    else:
        if isinstance(data, str):
            return _truncate_base64_string(data)

    return data
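
A quick demonstration of the dispatcher on an OpenRouter-style payload (the dict shape follows _sanitize_openai_response; all values are fabricated placeholders):

from zhenxun.utils.log_sanitizer import sanitize_for_logging

fake_response = {
    "choices": [
        {
            "message": {
                "content": "done",
                "images": [
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/png;base64," + "A" * 5000},
                    }
                ],
            }
        }
    ]
}

sanitized = sanitize_for_logging(fake_response, context="openai_response")
print(sanitized["choices"][0]["message"]["images"][0]["image_url"]["url"])
# -> [data:image_data_omitted_len=5022]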