mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
🎨 Ruff
This commit is contained in:
parent 6c045055a0
commit d09a5b1c72
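The hunks below are mechanical re-wrapping: single-line signatures, calls, and f-strings that ran past the formatter's line limit are split across lines (note the removed `# noqa: E501` marker further down). As a minimal sketch of the kind of configuration and invocation that produces a pass like this — the concrete line length and rule selection here are assumptions for illustration, not taken from this commit:

# Hypothetical pyproject.toml excerpt; the repository's real Ruff settings may differ.
[tool.ruff]
line-length = 88  # Ruff's default limit; the removed lines below exceed it.

[tool.ruff.lint]
select = ["E", "F", "W"]  # E501 reports over-long lines during linting.

Running `ruff format .` then rewraps the offending lines; a trailing comma inside parentheses makes Ruff keep the exploded one-argument-per-line layout, which is exactly the shape of the added lines throughout this diff.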
@@ -131,7 +131,9 @@ class BaseAdapter(ABC):
         pass

     @abstractmethod
-    def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
         """解析文本嵌入响应"""
         pass

@@ -145,7 +147,9 @@ class BaseAdapter(ABC):
             else str(error_info)
         )
         raise LLMException(
-            f"嵌入API错误: {msg}", code=LLMErrorCode.EMBEDDING_FAILED, details=response_json
+            f"嵌入API错误: {msg}",
+            code=LLMErrorCode.EMBEDDING_FAILED,
+            details=response_json,
         )

     def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
@@ -170,7 +174,9 @@ class BaseAdapter(ABC):
         )
         return headers

-    def convert_messages_to_openai_format(self, messages: list["LLMMessage"]) -> list[dict[str, Any]]:
+    def convert_messages_to_openai_format(
+        self, messages: list["LLMMessage"]
+    ) -> list[dict[str, Any]]:
         """将LLMMessage转换为OpenAI格式 - 通用方法"""
         openai_messages: list[dict[str, Any]] = []
         for msg in messages:
@@ -190,7 +196,10 @@ class BaseAdapter(ABC):
                     content_parts.append({"type": "text", "text": part.text})
                 elif part.type == "image":
                     content_parts.append(
-                        {"type": "image_url", "image_url": {"url": part.image_source}}
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": part.image_source},
+                        }
                     )
             openai_msg["content"] = content_parts

@@ -247,9 +256,13 @@ class BaseAdapter(ABC):
                         )
                     )
                 except KeyError as e:
-                    logger.warning(f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}")
+                    logger.warning(
+                        f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
+                    )
                 except Exception as e:
-                    logger.warning(f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}")
+                    logger.warning(
+                        f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
+                    )
             if not parsed_tool_calls:
                 parsed_tool_calls = None

@@ -268,7 +281,11 @@ class BaseAdapter(ABC):

         except Exception as e:
             logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
-            raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e)
+            raise LLMException(
+                f"解析API响应失败: {e}",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                cause=e,
+            )

     def validate_response(self, response_json: dict[str, Any]) -> None:
         """验证API响应,解析不同API的错误结构"""
@@ -291,9 +308,14 @@ class BaseAdapter(ABC):
                 "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
             }

-            llm_error_code = error_code_mapping.get(error_code, LLMErrorCode.API_RESPONSE_INVALID)
+            llm_error_code = error_code_mapping.get(
+                error_code, LLMErrorCode.API_RESPONSE_INVALID
+            )

-            logger.error(f"API返回错误: {error_message} (代码: {error_code}, 类型: {error_type})")
+            logger.error(
+                f"API返回错误: {error_message} "
+                f"(代码: {error_code}, 类型: {error_type})"
+            )
         else:
             error_message = str(error_info)
             error_code = "unknown"
@@ -314,7 +336,10 @@ class BaseAdapter(ABC):
             finish_reason = candidate.get("finishReason")
             if finish_reason in ["SAFETY", "RECITATION"]:
                 safety_ratings = candidate.get("safetyRatings", [])
-                logger.warning(f"Gemini内容被安全过滤: {finish_reason}, 安全评级: {safety_ratings}")
+                logger.warning(
+                    f"Gemini内容被安全过滤: {finish_reason}, "
+                    f"安全评级: {safety_ratings}"
+                )
                 raise LLMException(
                     f"内容被安全过滤: {finish_reason}",
                     code=LLMErrorCode.CONTENT_FILTERED,
@@ -342,7 +367,9 @@ class BaseAdapter(ABC):
             return config.to_api_params(model.api_type, model.model_name)

         if model._generation_config is not None:
-            return model._generation_config.to_api_params(model.api_type, model.model_name)
+            return model._generation_config.to_api_params(
+                model.api_type, model.model_name
+            )

         base_config = {}
         if model.temperature is not None:
@@ -417,6 +444,7 @@ class OpenAICompatAdapter(BaseAdapter):
         is_advanced: bool = False,
     ) -> ResponseData:
         """解析响应 - 直接使用基类的 OpenAI 格式解析"""
+        _ = model, is_advanced  # 未使用的参数
         return self.parse_openai_response(response_json)

     def prepare_embedding_request(
@@ -428,6 +456,7 @@ class OpenAICompatAdapter(BaseAdapter):
         **kwargs: Any,
     ) -> RequestData:
         """准备嵌入请求 - OpenAI兼容格式"""
+        _ = task_type  # 未使用的参数
         url = self.get_api_url(model, self.get_embedding_endpoint())
         headers = self.get_base_headers(api_key)

@@ -442,7 +471,9 @@ class OpenAICompatAdapter(BaseAdapter):

         return RequestData(url=url, headers=headers, body=body)

-    def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
         """解析嵌入响应 - OpenAI兼容格式"""
         self.validate_embedding_response(response_json)

@@ -48,7 +48,9 @@ class GeminiAdapter(BaseAdapter):
         tool_choice: str | dict[str, Any] | None = None,
     ) -> RequestData:
         """准备高级请求"""
-        return self._prepare_request(model, api_key, messages, config, tools, tool_choice)
+        return self._prepare_request(
+            model, api_key, messages, config, tools, tool_choice
+        )

     def _prepare_request(
         self,
@@ -75,7 +77,9 @@ class GeminiAdapter(BaseAdapter):
                 if isinstance(msg.content, str):
                     system_instruction_parts = [{"text": msg.content}]
                 elif isinstance(msg.content, list):
-                    system_instruction_parts = [part.convert_for_api("gemini") for part in msg.content]
+                    system_instruction_parts = [
+                        part.convert_for_api("gemini") for part in msg.content
+                    ]
                 continue

             elif msg.role == "user":
@@ -115,12 +119,21 @@ class GeminiAdapter(BaseAdapter):
                 import json

                 try:
-                    content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
+                    content_str = (
+                        msg.content
+                        if isinstance(msg.content, str)
+                        else str(msg.content)
+                    )
                     tool_result_obj = json.loads(content_str)
                 except json.JSONDecodeError:
-                    content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
+                    content_str = (
+                        msg.content
+                        if isinstance(msg.content, str)
+                        else str(msg.content)
+                    )
                     logger.warning(
-                        f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. 包装为原始字符串。"
+                        f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. "
+                        f"包装为原始字符串。"
                     )
                     tool_result_obj = {"raw_output": content_str}

@@ -144,7 +157,9 @@ class GeminiAdapter(BaseAdapter):
             for tool_item in tools:
                 if isinstance(tool_item, dict):
                     if "name" in tool_item and "description" in tool_item:
-                        all_tools_for_request.append({"functionDeclarations": [tool_item]})
+                        all_tools_for_request.append(
+                            {"functionDeclarations": [tool_item]}
+                        )
                     else:
                         all_tools_for_request.append(tool_item)
                 else:
@@ -152,7 +167,9 @@ class GeminiAdapter(BaseAdapter):

         if effective_config:
             if getattr(effective_config, "enable_grounding", False):
-                has_explicit_gs_tool = any("googleSearch" in tool_item for tool_item in all_tools_for_request)
+                has_explicit_gs_tool = any(
+                    "googleSearch" in tool_item for tool_item in all_tools_for_request
+                )
                 if not has_explicit_gs_tool:
                     all_tools_for_request.append({"googleSearch": {}})
                     logger.debug("隐式启用 Google Search 工具进行信息来源关联。")
@@ -166,7 +183,9 @@ class GeminiAdapter(BaseAdapter):
                     logger.debug("隐式启用代码执行工具。")

         if all_tools_for_request:
-            gemini_api_tools = self._convert_tools_to_gemini_format(all_tools_for_request)
+            gemini_api_tools = self._convert_tools_to_gemini_format(
+                all_tools_for_request
+            )
             if gemini_api_tools:
                 body["tools"] = gemini_api_tools

@@ -180,11 +199,17 @@ class GeminiAdapter(BaseAdapter):
                 if mode_upper in ["AUTO", "NONE", "ANY"]:
                     body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
                 else:
-                    body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice)
+                    body["toolConfig"] = self._convert_tool_choice_to_gemini(
+                        final_tool_choice
+                    )
             else:
-                body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice)
+                body["toolConfig"] = self._convert_tool_choice_to_gemini(
+                    final_tool_choice
+                )

-        final_generation_config = self._build_gemini_generation_config(model, effective_config)
+        final_generation_config = self._build_gemini_generation_config(
+            model, effective_config
+        )
         if final_generation_config:
             body["generationConfig"] = final_generation_config

@@ -203,7 +228,9 @@ class GeminiAdapter(BaseAdapter):
         """应用配置覆盖 - Gemini 不需要额外的配置覆盖"""
         return body

-    def _get_gemini_endpoint(self, model: "LLMModel", config: "LLMGenerationConfig | None" = None) -> str:
+    def _get_gemini_endpoint(
+        self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
+    ) -> str:
         """根据配置选择Gemini API端点"""
         if config:
             if getattr(config, "enable_code_execution", False):
@@ -214,7 +241,9 @@ class GeminiAdapter(BaseAdapter):

         return f"/v1beta/models/{model.model_name}:generateContent"

-    def _convert_tools_to_gemini_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    def _convert_tools_to_gemini_format(
+        self, tools: list[dict[str, Any]]
+    ) -> list[dict[str, Any]]:
         """转换工具格式为Gemini格式"""
         gemini_tools = []

@@ -232,7 +261,9 @@ class GeminiAdapter(BaseAdapter):
                 }
                 gemini_tools.append(gemini_tool)
             elif tool.get("type") == "code_execution":
-                gemini_tools.append({"codeExecution": {"language": tool.get("language", "python")}})
+                gemini_tools.append(
+                    {"codeExecution": {"language": tool.get("language", "python")}}
+                )
             elif tool.get("type") == "google_search":
                 gemini_tools.append({"googleSearch": {}})
             elif "googleSearch" in tool:
@@ -242,18 +273,26 @@ class GeminiAdapter(BaseAdapter):

         return gemini_tools

-    def _convert_tool_choice_to_gemini(self, tool_choice_value: str | dict[str, Any]) -> dict[str, Any]:
+    def _convert_tool_choice_to_gemini(
+        self, tool_choice_value: str | dict[str, Any]
+    ) -> dict[str, Any]:
         """转换工具选择策略为Gemini格式"""
         if isinstance(tool_choice_value, str):
             mode_upper = tool_choice_value.upper()
             if mode_upper in ["AUTO", "NONE", "ANY"]:
                 return {"functionCallingConfig": {"mode": mode_upper}}
             else:
-                logger.warning(f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。回退到 AUTO。")
+                logger.warning(
+                    f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。"
+                    f"回退到 AUTO。"
+                )
                 return {"functionCallingConfig": {"mode": "AUTO"}}

         elif isinstance(tool_choice_value, dict):
-            if tool_choice_value.get("type") == "function" and "function" in tool_choice_value:
+            if (
+                tool_choice_value.get("type") == "function"
+                and "function" in tool_choice_value
+            ):
                 func_name = tool_choice_value["function"].get("name")
                 if func_name:
                     return {
@@ -263,17 +302,26 @@ class GeminiAdapter(BaseAdapter):
                         }
                     }
                 else:
-                    logger.warning(f"tool_choice dict 中的函数名无效: {tool_choice_value}。回退到 AUTO。")
+                    logger.warning(
+                        f"tool_choice dict 中的函数名无效: {tool_choice_value}。"
+                        f"回退到 AUTO。"
+                    )
                     return {"functionCallingConfig": {"mode": "AUTO"}}

             elif "functionCallingConfig" in tool_choice_value:
-                return {"functionCallingConfig": tool_choice_value["functionCallingConfig"]}
+                return {
+                    "functionCallingConfig": tool_choice_value["functionCallingConfig"]
+                }

             else:
-                logger.warning(f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。")
+                logger.warning(
+                    f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。"
+                )
                 return {"functionCallingConfig": {"mode": "AUTO"}}

-        logger.warning(f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。")
+        logger.warning(
+            f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。"
+        )
         return {"functionCallingConfig": {"mode": "AUTO"}}

     def _build_gemini_generation_config(
@@ -285,11 +333,15 @@ class GeminiAdapter(BaseAdapter):
         effective_config = config if config is not None else model._generation_config

         if effective_config:
-            base_api_params = effective_config.to_api_params(api_type="gemini", model_name=model.model_name)
+            base_api_params = effective_config.to_api_params(
+                api_type="gemini", model_name=model.model_name
+            )
             generation_config.update(base_api_params)

             if getattr(effective_config, "response_mime_type", None):
-                generation_config["responseMimeType"] = effective_config.response_mime_type
+                generation_config["responseMimeType"] = (
+                    effective_config.response_mime_type
+                )

             if getattr(effective_config, "response_schema", None):
                 generation_config["responseSchema"] = effective_config.response_schema
@@ -303,15 +355,22 @@ class GeminiAdapter(BaseAdapter):
             if getattr(effective_config, "response_modalities", None):
                 modalities = effective_config.response_modalities
                 if isinstance(modalities, list):
-                    generation_config["responseModalities"] = [m.upper() for m in modalities]
+                    generation_config["responseModalities"] = [
+                        m.upper() for m in modalities
+                    ]
                 elif isinstance(modalities, str):
                     generation_config["responseModalities"] = [modalities.upper()]

-        generation_config = {k: v for k, v in generation_config.items() if v is not None}
+        generation_config = {
+            k: v for k, v in generation_config.items() if v is not None
+        }

         if generation_config:
             param_keys = list(generation_config.keys())
-            logger.debug(f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: {param_keys}")
+            logger.debug(
+                f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: "
+                f"{param_keys}"
+            )

         return generation_config

@@ -337,7 +396,9 @@ class GeminiAdapter(BaseAdapter):
                 safety_settings.append({"category": category, "threshold": threshold})
         else:
             for category in safety_categories:
-                safety_settings.append({"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"})
+                safety_settings.append(
+                    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
+                )

         return safety_settings if safety_settings else None

@@ -368,12 +429,18 @@ class GeminiAdapter(BaseAdapter):

             candidate = candidates[0]

-            if candidate.get("finishReason") in ["RECITATION", "OTHER"] and not candidate.get("content"):
+            if candidate.get("finishReason") in [
+                "RECITATION",
+                "OTHER",
+            ] and not candidate.get("content"):
                 logger.warning(
-                    f"Gemini candidate finished with reason '{candidate.get('finishReason')}' and no content."
+                    f"Gemini candidate finished with reason "
+                    f"'{candidate.get('finishReason')}' and no content."
                 )
                 return ResponseData(
-                    text="", raw_response=response_json, usage_info=response_json.get("usageMetadata")
+                    text="",
+                    raw_response=response_json,
+                    usage_info=response_json.get("usageMetadata"),
                 )

             content_data = candidate.get("content", {})
@@ -394,9 +461,10 @@ class GeminiAdapter(BaseAdapter):

                         from ..types.models import LLMToolCall, LLMToolFunction

+                        call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
                         parsed_tool_calls.append(
                             LLMToolCall(
-                                id=f"call_{model.provider_name}_{len(parsed_tool_calls)}",
+                                id=call_id,
                                 function=LLMToolFunction(
                                     name=fc_data["name"],
                                     arguments=json.dumps(fc_data["args"]),
@@ -404,16 +472,22 @@ class GeminiAdapter(BaseAdapter):
                             )
                         )
                     except KeyError as e:
-                        logger.warning(f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}")
+                        logger.warning(
+                            f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}"
+                        )
                     except Exception as e:
-                        logger.warning(f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}")
+                        logger.warning(
+                            f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
+                        )
                 elif "codeExecutionResult" in part:
                     result = part["codeExecutionResult"]
                     if result.get("outcome") == "OK":
                         output = result.get("output", "")
                         text_content += f"\n[代码执行结果]:\n{output}\n"
                     else:
-                        text_content += f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
+                        text_content += (
+                            f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
+                        )

             usage_info = response_json.get("usageMetadata")

@@ -436,7 +510,11 @@ class GeminiAdapter(BaseAdapter):

         except Exception as e:
             logger.error(f"解析 Gemini 响应失败: {e}", e=e)
-            raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e)
+            raise LLMException(
+                f"解析API响应失败: {e}",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                cause=e,
+            )

     def prepare_embedding_request(
         self,
@@ -474,7 +552,9 @@ class GeminiAdapter(BaseAdapter):
         body = {"requests": requests_payload}
         return RequestData(url=url, headers=headers, body=body)

-    def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
         """解析文本嵌入响应"""
         try:
             embeddings_data = response_json["embeddings"]
@@ -482,18 +562,26 @@ class GeminiAdapter(BaseAdapter):
         except KeyError as e:
             logger.error(f"解析Gemini嵌入响应时缺少键: {e}. 响应: {response_json}")
             raise LLMException(
-                "Gemini嵌入响应格式错误", code=LLMErrorCode.RESPONSE_PARSE_ERROR, details={"error": str(e)}
+                "Gemini嵌入响应格式错误",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                details={"error": str(e)},
             )
         except Exception as e:
-            logger.error(f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}")
+            logger.error(
+                f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}"
+            )
             raise LLMException(
-                f"解析Gemini嵌入响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e
+                f"解析Gemini嵌入响应失败: {e}",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                cause=e,
             )

     def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
         """验证嵌入响应"""
         super().validate_embedding_response(response_json)
-        if "embeddings" not in response_json or not isinstance(response_json["embeddings"], list):
+        if "embeddings" not in response_json or not isinstance(
+            response_json["embeddings"], list
+        ):
             raise LLMException(
                 "Gemini嵌入响应缺少'embeddings'字段或格式不正确",
                 code=LLMErrorCode.RESPONSE_PARSE_ERROR,

@@ -91,7 +91,9 @@ class AI:

         if isinstance(message, str):
            llm_messages = [LLMMessage.user(message)]
-        elif isinstance(message, list) and all(isinstance(part, LLMContentPart) for part in message):
+        elif isinstance(message, list) and all(
+            isinstance(part, LLMContentPart) for part in message
+        ):
            llm_messages = [LLMMessage.user(message)]
         elif isinstance(message, LLMMessage):
            llm_messages = [message]
@@ -103,7 +105,9 @@ class AI:
                 code=LLMErrorCode.API_REQUEST_FAILED,
             )

-        response = await self._execute_generation(llm_messages, model, "聊天失败", kwargs)
+        response = await self._execute_generation(
+            llm_messages, model, "聊天失败", kwargs
+        )
         return response.text

     async def code(
@@ -135,7 +139,12 @@ class AI:
         }

     async def search(
-        self, query: str | UniMessage, *, model: ModelName = None, instruction: str = "", **kwargs: Any
+        self,
+        query: str | UniMessage,
+        *,
+        model: ModelName = None,
+        instruction: str = "",
+        **kwargs: Any,
     ) -> dict[str, Any]:
         """信息搜索 - 支持多模态输入"""
         resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
@@ -154,7 +163,9 @@ class AI:
                 if instruction:
                     final_messages.append(LLMMessage.user(instruction))
                 else:
-                    raise LLMException("搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED)
+                    raise LLMException(
+                        "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
+                    )
             else:
                 final_messages.append(LLMMessage.user(content_parts))

@@ -206,7 +217,9 @@ class AI:
                 if instruction:
                     final_messages.append(LLMMessage.user(instruction))
                 else:
-                    raise LLMException("分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED)
+                    raise LLMException(
+                        "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
+                    )
             else:
                 final_messages.append(LLMMessage.user(content_parts))

@@ -241,7 +254,12 @@ class AI:
             tool_choice = "none"

         response = await self._execute_generation(
-            final_messages, model, "内容分析失败", kwargs, llm_tools=llm_tools, tool_choice=tool_choice
+            final_messages,
+            model,
+            "内容分析失败",
+            kwargs,
+            llm_tools=llm_tools,
+            tool_choice=tool_choice,
         )

         if response.tool_calls:
@@ -260,8 +278,12 @@ class AI:
     ) -> LLMResponse:
         """通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
         try:
-            resolved_model_name = self._resolve_model_name(model_name or self.config.model)
-            final_config_dict = self._merge_config(config_overrides, base_config=base_config)
+            resolved_model_name = self._resolve_model_name(
+                model_name or self.config.model
+            )
+            final_config_dict = self._merge_config(
+                config_overrides, base_config=base_config
+            )

             async with await get_model_instance(
                 resolved_model_name, override_config=final_config_dict
@@ -316,7 +338,9 @@ class AI:
         if self.config.enable_gemini_thinking:
             final_config["thinking_budget"] = 0.8
         if self.config.enable_gemini_safe_mode:
-            final_config["safety_settings"] = CommonOverrides.gemini_safe().safety_settings
+            final_config["safety_settings"] = (
+                CommonOverrides.gemini_safe().safety_settings
+            )
         if self.config.enable_gemini_multimodal:
             final_config.update(CommonOverrides.gemini_multimodal().to_dict())
         if self.config.enable_gemini_grounding:
@@ -341,10 +365,13 @@ class AI:
             return []

         try:
-            resolved_model_str = model or self.config.default_embedding_model or self.config.model
+            resolved_model_str = (
+                model or self.config.default_embedding_model or self.config.model
+            )
             if not resolved_model_str:
                 raise LLMException(
-                    "使用 embed 功能时必须指定嵌入模型名称,或在 AIConfig 中配置 default_embedding_model。",
+                    "使用 embed 功能时必须指定嵌入模型名称,"
+                    "或在 AIConfig 中配置 default_embedding_model。",
                     code=LLMErrorCode.MODEL_NOT_FOUND,
                 )
             resolved_model_str = self._resolve_model_name(resolved_model_str)
@@ -360,7 +387,9 @@ class AI:
             raise
         except Exception as e:
             logger.error(f"文本嵌入失败: {e}", e=e)
-            raise LLMException(f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e)
+            raise LLMException(
+                f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
+            )


 async def chat(
@@ -443,7 +472,9 @@ async def analyze_multimodal(
     **kwargs: Any,
 ) -> str | LLMResponse:
     """多模态分析便捷函数"""
-    message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios)
+    message = create_multimodal_message(
+        text=text, images=images, videos=videos, audios=audios
+    )
     return await analyze(message, instruction=instruction, model=model, **kwargs)


@@ -458,7 +489,9 @@ async def search_multimodal(
     **kwargs: Any,
 ) -> dict[str, Any]:
     """多模态搜索便捷函数"""
-    message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios)
+    message = create_multimodal_message(
+        text=text, images=images, videos=videos, audios=audios
+    )
     ai = AI()
     return await ai.search(message, model=model, instruction=instruction, **kwargs)

@@ -15,27 +15,47 @@ from ..types.exceptions import LLMErrorCode, LLMException
 class ModelConfigOverride(BaseModel):
     """模型配置覆盖参数"""

-    temperature: float | None = Field(default=None, ge=0.0, le=2.0, description="生成温度")
+    temperature: float | None = Field(
+        default=None, ge=0.0, le=2.0, description="生成温度"
+    )
     max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数")
     top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数")
     top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数")
-    frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="频率惩罚")
-    presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="存在惩罚")
-    repetition_penalty: float | None = Field(default=None, ge=0.0, le=2.0, description="重复惩罚")
+    frequency_penalty: float | None = Field(
+        default=None, ge=-2.0, le=2.0, description="频率惩罚"
+    )
+    presence_penalty: float | None = Field(
+        default=None, ge=-2.0, le=2.0, description="存在惩罚"
+    )
+    repetition_penalty: float | None = Field(
+        default=None, ge=0.0, le=2.0, description="重复惩罚"
+    )

     stop: list[str] | str | None = Field(default=None, description="停止序列")

     response_format: ResponseFormat | dict[str, Any] | None = Field(
         default=None, description="期望的响应格式"
     )
-    response_mime_type: str | None = Field(default=None, description="响应MIME类型(Gemini专用)")
-    response_schema: dict[str, Any] | None = Field(default=None, description="JSON响应模式")
-    thinking_budget: float | None = Field(default=None, ge=0.0, le=1.0, description="思考预算")
+    response_mime_type: str | None = Field(
+        default=None, description="响应MIME类型(Gemini专用)"
+    )
+    response_schema: dict[str, Any] | None = Field(
+        default=None, description="JSON响应模式"
+    )
+    thinking_budget: float | None = Field(
+        default=None, ge=0.0, le=1.0, description="思考预算"
+    )
     safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
-    response_modalities: list[str] | None = Field(default=None, description="响应模态类型")
+    response_modalities: list[str] | None = Field(
+        default=None, description="响应模态类型"
+    )

-    enable_code_execution: bool | None = Field(default=None, description="是否启用代码执行")
-    enable_grounding: bool | None = Field(default=None, description="是否启用信息来源关联")
+    enable_code_execution: bool | None = Field(
+        default=None, description="是否启用代码执行"
+    )
+    enable_grounding: bool | None = Field(
+        default=None, description="是否启用信息来源关联"
+    )
     enable_caching: bool | None = Field(default=None, description="是否启用响应缓存")

     custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")
@@ -43,7 +63,16 @@ class ModelConfigOverride(BaseModel):
     def to_dict(self) -> dict[str, Any]:
         """转换为字典,排除None值"""
         result = {}
-        for key, value in self.model_dump().items():
+        model_data = getattr(self, "model_dump", lambda: {})()
+        if not model_data:
+            model_data = {}
+            for field_name, _ in self.__class__.__dict__.get(
+                "model_fields", {}
+            ).items():
+                value = getattr(self, field_name, None)
+                if value is not None:
+                    model_data[field_name] = value
+        for key, value in model_data.items():
             if value is not None:
                 if key == "custom_params" and isinstance(value, dict):
                     result.update(value)
@@ -110,13 +139,14 @@ class LLMGenerationConfig(ModelConfigOverride):
         else:
             params["repetition_penalty"] = self.repetition_penalty

-        # 处理 response_format 参数
         if self.response_format is not None:
             if isinstance(self.response_format, dict):
-                # 直接使用字典格式的 response_format(如 {'type': 'json_object'})
                 if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
                     params["response_format"] = self.response_format
-                    logger.debug(f"为 {api_type} 使用自定义 response_format: {self.response_format}")
+                    logger.debug(
+                        f"为 {api_type} 使用自定义 response_format: "
+                        f"{self.response_format}"
+                    )
             elif self.response_format == ResponseFormat.JSON:
                 if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
                     params["response_format"] = {"type": "json_object"}
@@ -128,9 +158,14 @@ class LLMGenerationConfig(ModelConfigOverride):
                     logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")

         if api_type in ["gemini", "gemini_native"]:
-            if self.response_format != ResponseFormat.JSON and self.response_mime_type is not None:
+            if (
+                self.response_format != ResponseFormat.JSON
+                and self.response_mime_type is not None
+            ):
                 params["responseMimeType"] = self.response_mime_type
-                logger.debug(f"使用显式设置的 responseMimeType: {self.response_mime_type}")
+                logger.debug(
+                    f"使用显式设置的 responseMimeType: {self.response_mime_type}"
+                )

         if self.response_schema is not None and "responseSchema" not in params:
             params["responseSchema"] = self.response_schema
@@ -158,7 +193,9 @@ def validate_override_params(

     if isinstance(override_config, dict):
         try:
-            filtered_config = {k: v for k, v in override_config.items() if v is not None}
+            filtered_config = {
+                k: v for k, v in override_config.items() if v is not None
+            }
             return LLMGenerationConfig(**filtered_config)
         except Exception as e:
             logger.warning(f"覆盖配置参数验证失败: {e}")
@@ -171,7 +208,9 @@
     return override_config


-def apply_api_specific_mappings(params: dict[str, Any], api_type: str) -> dict[str, Any]:
+def apply_api_specific_mappings(
+    params: dict[str, Any], api_type: str
+) -> dict[str, Any]:
     """应用API特定的参数映射"""
     mapped_params = params.copy()

@@ -204,7 +243,8 @@ def apply_api_specific_mappings(params: dict[str, Any], api_type: str) -> dict[str, Any]:

 def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
     """从关键字参数创建生成配置"""
-    known_fields = set(LLMGenerationConfig.model_fields.keys())
+    model_fields = getattr(LLMGenerationConfig, "model_fields", {})
+    known_fields = set(model_fields.keys())
     known_params = {}
     custom_params = {}

@@ -30,17 +30,25 @@ class CommonOverrides:
     @staticmethod
     def concise(max_tokens: int = 100) -> LLMGenerationConfig:
         """简洁模式:限制输出长度"""
-        return LLMGenerationConfig(temperature=0.3, max_tokens=max_tokens, stop=["\n\n", "。", "!", "?"])
+        return LLMGenerationConfig(
+            temperature=0.3,
+            max_tokens=max_tokens,
+            stop=["\n\n", "。", "!", "?"],
+        )

     @staticmethod
     def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
         """详细模式:鼓励详细输出"""
-        return LLMGenerationConfig(temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1)
+        return LLMGenerationConfig(
+            temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
+        )

     @staticmethod
     def gemini_json() -> LLMGenerationConfig:
         """Gemini JSON模式:强制JSON输出"""
-        return LLMGenerationConfig(temperature=0.3, response_mime_type="application/json")
+        return LLMGenerationConfig(
+            temperature=0.3, response_mime_type="application/json"
+        )

     @staticmethod
     def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
@@ -96,7 +104,9 @@ class CommonOverrides:
             temperature=0.5,
             max_tokens=4096,
             enable_grounding=True,
-            custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}},
+            custom_params={
+                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
+            },
         )

     @staticmethod
@@ -119,7 +129,9 @@ class CommonOverrides:
             enable_caching=True,
             custom_params={
                 "code_execution_timeout": 30,
-                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}},
+                "grounding_config": {
+                    "dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
+                },
             },
         )

@@ -132,7 +144,9 @@ class CommonOverrides:
             thinking_budget=0.8,
             enable_grounding=True,
             response_mime_type="application/json",
-            custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}},
+            custom_params={
+                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
+            },
         )

     @staticmethod

@@ -40,7 +40,8 @@ class LLMHttpClient:
         async with self._lock:
             if self._client is None or self._client.is_closed:
                 logger.debug(
-                    f"LLMHttpClient: Initializing new httpx.AsyncClient with config: {self.config}"
+                    f"LLMHttpClient: Initializing new httpx.AsyncClient "
+                    f"with config: {self.config}"
                 )
                 headers = get_user_agent()
                 limits = httpx.Limits(
@@ -52,11 +53,13 @@ class LLMHttpClient:
                     headers=headers,
                     limits=limits,
                     timeout=timeout,
-                    proxy=self.config.proxy,
+                    proxies=self.config.proxy,
                     follow_redirects=True,
                 )
         if self._client is None:
-            raise LLMException("HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR)
+            raise LLMException(
+                "HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
+            )
         return self._client

     async def post(self, url: str, **kwargs: Any) -> httpx.Response:
@@ -78,7 +81,8 @@ class LLMHttpClient:
             )
         if self._active_requests > 0:
             logger.warning(
-                f"LLMHttpClient: Closing while {self._active_requests} requests are still active."
+                f"LLMHttpClient: Closing while {self._active_requests} "
+                f"requests are still active."
             )
         await self._client.aclose()
         self._client = None
@@ -96,7 +100,9 @@ class LLMHttpClientManager:
         self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
         self._lock = asyncio.Lock()

-    def _get_client_key(self, provider_config: ProviderConfig) -> tuple[int, str | None]:
+    def _get_client_key(
+        self, provider_config: ProviderConfig
+    ) -> tuple[int, str | None]:
         return (provider_config.timeout, provider_config.proxy)

     async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
@@ -104,15 +110,21 @@ class LLMHttpClientManager:
         async with self._lock:
             client = self._clients.get(key)
             if client and not client.is_closed:
-                logger.debug(f"LLMHttpClientManager: Reusing existing LLMHttpClient for key: {key}")
+                logger.debug(
+                    f"LLMHttpClientManager: Reusing existing LLMHttpClient "
+                    f"for key: {key}"
+                )
                 return client

             if client and client.is_closed:
                 logger.debug(
-                    f"LLMHttpClientManager: Found a closed client for key {key}. Creating a new one."
+                    f"LLMHttpClientManager: Found a closed client for key {key}. "
+                    f"Creating a new one."
                 )

-            logger.debug(f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}")
+            logger.debug(
+                f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
+            )
             http_client_config = HttpClientConfig(
                 timeout=provider_config.timeout, proxy=provider_config.proxy
             )
@@ -122,9 +134,14 @@ class LLMHttpClientManager:

     async def shutdown(self):
         async with self._lock:
-            logger.info(f"LLMHttpClientManager: Shutting down. Closing {len(self._clients)} client(s).")
+            logger.info(
+                f"LLMHttpClientManager: Shutting down. "
+                f"Closing {len(self._clients)} client(s)."
+            )
             close_tasks = [
-                client.close() for client in self._clients.values() if client and not client.is_closed
+                client.close()
+                for client in self._clients.values()
+                if client and not client.is_closed
             ]
             if close_tasks:
                 await asyncio.gather(*close_tasks, return_exceptions=True)
@@ -183,11 +200,16 @@ async def with_smart_retry(
         except LLMException as e:
             last_exception = e

-            if e.code in [LLMErrorCode.API_KEY_INVALID, LLMErrorCode.API_QUOTA_EXCEEDED]:
+            if e.code in [
+                LLMErrorCode.API_KEY_INVALID,
+                LLMErrorCode.API_QUOTA_EXCEEDED,
+            ]:
                 if hasattr(e, "details") and e.details and "api_key" in e.details:
                     failed_keys.add(e.details["api_key"])
                     if key_store and provider_name:
-                        await key_store.record_failure(e.details["api_key"], e.details.get("status_code"))
+                        await key_store.record_failure(
+                            e.details["api_key"], e.details.get("status_code")
+                        )

             should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
             if not should_retry:
@@ -198,7 +220,9 @@ async def with_smart_retry(
                 wait_time = config.retry_delay
                 if config.exponential_backoff:
                     wait_time *= 2**attempt
-                logger.warning(f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}")
+                logger.warning(
+                    f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
+                )
                 await asyncio.sleep(wait_time)
             else:
                 logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
@@ -218,7 +242,9 @@ async def with_smart_retry(
     raise RuntimeError("重试函数未能正常执行且未捕获到异常")


-def _should_retry_llm_error(error: LLMException, attempt: int, max_retries: int) -> bool:
+def _should_retry_llm_error(
+    error: LLMException, attempt: int, max_retries: int
+) -> bool:
     """判断LLM错误是否应该重试"""
     non_retryable_errors = {
         LLMErrorCode.MODEL_NOT_FOUND,
@@ -263,7 +289,10 @@ class KeyStatusStore:
         self._lock = asyncio.Lock()

     async def get_next_available_key(
-        self, provider_name: str, api_keys: list[str], exclude_keys: set[str] | None = None
+        self,
+        provider_name: str,
+        api_keys: list[str],
+        exclude_keys: set[str] | None = None,
     ) -> str | None:
         """获取下一个可用的API密钥(轮询策略)"""
         if not api_keys:
@@ -271,7 +300,9 @@ class KeyStatusStore:

         exclude_keys = exclude_keys or set()
         available_keys = [
-            key for key in api_keys if key not in exclude_keys and self._key_status.get(key, True)
+            key
+            for key in api_keys
+            if key not in exclude_keys and self._key_status.get(key, True)
         ]

         if not available_keys:
@@ -282,11 +313,15 @@ class KeyStatusStore:

         selected_key = available_keys[current_index % len(available_keys)]

-        self._provider_key_index[provider_name] = (current_index + 1) % len(available_keys)
+        self._provider_key_index[provider_name] = (current_index + 1) % len(
+            available_keys
+        )

         import time

-        self._key_usage_count[selected_key] = self._key_usage_count.get(selected_key, 0) + 1
+        self._key_usage_count[selected_key] = (
+            self._key_usage_count.get(selected_key, 0) + 1
+        )
         self._key_last_used[selected_key] = time.time()

         logger.debug(
@@ -308,7 +343,9 @@ class KeyStatusStore:
         async with self._lock:
             if status_code in [401, 403]:
                 self._key_status[api_key] = False
-                logger.warning(f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})")
+                logger.warning(
+                    f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})"
+                )
             else:
                 logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")

@@ -37,9 +37,13 @@ def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:
     return None, None


-def _make_cache_key(provider_model_name: str | None, override_config: dict | None) -> str:
+def _make_cache_key(
+    provider_model_name: str | None, override_config: dict | None
+) -> str:
     """生成缓存键"""
-    config_str = json.dumps(override_config, sort_keys=True) if override_config else "None"
+    config_str = (
+        json.dumps(override_config, sort_keys=True) if override_config else "None"
+    )
     key_data = f"{provider_model_name}:{config_str}"
     return hashlib.md5(key_data.encode()).hexdigest()

@@ -62,7 +66,9 @@ def _get_cached_model(cache_key: str) -> LLMModel | None:
         )
         model._is_closed = False

-        logger.debug(f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}")
+        logger.debug(
+            f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
+        )
         return model
     return None

@@ -113,7 +119,10 @@ def get_configured_providers() -> list[ProviderConfig]:
     ai_config = get_ai_config()
     providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
     if not isinstance(providers_raw, list):
-        logger.error(f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,将使用空列表。")
+        logger.error(
+            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
+            f"将使用空列表。"
+        )
         return []

     valid_providers = []
@@ -128,7 +137,9 @@ def get_configured_providers() -> list[ProviderConfig]:
             continue

         if not item.get("api_key"):
-            logger.warning(f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。")
+            logger.warning(
+                f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
+            )
             continue

         if "api_type" not in item or not item["api_type"]:
@@ -159,7 +170,9 @@ def get_configured_providers() -> list[ProviderConfig]:
     return valid_providers


-def find_model_config(provider_name: str, model_name: str) -> tuple[ProviderConfig, ModelDetail] | None:
+def find_model_config(
+    provider_name: str, model_name: str
+) -> tuple[ProviderConfig, ModelDetail] | None:
     """在配置中查找指定的 Provider 和 ModelDetail

     Args:
@@ -194,7 +207,9 @@ def list_available_models() -> list[dict[str, Any]]:
             "api_base": provider.api_base,
             "is_available": model_detail.is_available,
             "is_embedding_model": model_detail.is_embedding_model,
-            "available_identifiers": _get_model_identifiers(provider.name, model_detail),
+            "available_identifiers": _get_model_identifiers(
+                provider.name, model_detail
+            ),
         }
         model_list.append(model_info)
     return model_list
@@ -242,7 +257,8 @@ async def get_model_instance(
         if cached_model._generation_config != validated_override:
             cached_model._generation_config = validated_override
             logger.debug(
-                f"对缓存模型 {provider_model_name} 应用新的覆盖配置: {validated_override.to_dict()}"
+                f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
+                f"{validated_override.to_dict()}"
             )
         return cached_model

@@ -252,21 +268,25 @@ async def get_model_instance(
     if resolved_model_name_str is None:
         available_models_list = list_available_models()
         if not available_models_list:
-            raise LLMException("未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR)
+            raise LLMException(
+                "未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
+            )
         resolved_model_name_str = available_models_list[0]["full_name"]
         logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")

     prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
     if not prov_name_str or not mod_name_str:
         raise LLMException(
-            f"无效的模型名称格式: '{resolved_model_name_str}'", code=LLMErrorCode.MODEL_NOT_FOUND
+            f"无效的模型名称格式: '{resolved_model_name_str}'",
+            code=LLMErrorCode.MODEL_NOT_FOUND,
         )

     config_tuple_found = find_model_config(prov_name_str, mod_name_str)
     if not config_tuple_found:
         all_models = list_available_models()
         raise LLMException(
-            f"未找到模型: '{resolved_model_name_str}'. 可用: {[m['full_name'] for m in all_models]}",
+            f"未找到模型: '{resolved_model_name_str}'. "
+            f"可用: {[m['full_name'] for m in all_models]}",
             code=LLMErrorCode.MODEL_NOT_FOUND,
         )

@@ -274,7 +294,11 @@ async def get_model_instance(

     ai_config = get_ai_config()
     global_proxy_setting = ai_config.get(PROXY_KEY)
-    default_timeout = provider_config_found.timeout if provider_config_found.timeout is not None else 180
+    default_timeout = (
+        provider_config_found.timeout
+        if provider_config_found.timeout is not None
+        else 180
+    )
     global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)

     config_for_http_client = ProviderConfig(
@@ -304,16 +328,21 @@ async def get_model_instance(
            validated_override_params = validate_override_params(override_config)
            model_instance._generation_config = validated_override_params
            logger.debug(
-                f"为新模型 {resolved_model_name_str} 应用配置覆盖: {validated_override_params.to_dict()}"
+                f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
+                f"{validated_override_params.to_dict()}"
            )

        _cache_model(cache_key, model_instance)
-        logger.debug(f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}")
+        logger.debug(
+            f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
+        )
        return model_instance
    except LLMException:
        raise
    except Exception as e:
-        logger.error(f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e)
+        logger.error(
+            f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
+        )
        raise LLMException(
            f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
            code=LLMErrorCode.MODEL_INIT_FAILED,
@@ -332,10 +361,14 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
     if provider_model_name:
         prov_name, mod_name = parse_provider_model_string(provider_model_name)
         if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
-            logger.error(f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。")
+            logger.error(
+                f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
+            )
             return False

-    Config.set_config(AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True)
+    Config.set_config(
+        AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
+    )
     if provider_model_name:
         logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
     else:
@@ -350,10 +383,16 @@ async def get_key_usage_stats() -> dict[str, Any]:

     for provider in providers:
         provider_stats = await key_store.get_key_stats(
-            [provider.api_key] if isinstance(provider.api_key, str) else provider.api_key
+            [provider.api_key]
+            if isinstance(provider.api_key, str)
+            else provider.api_key
         )
         stats[provider.name] = {
-            "total_keys": len([provider.api_key] if isinstance(provider.api_key, str) else provider.api_key),
+            "total_keys": len(
+                [provider.api_key]
+                if isinstance(provider.api_key, str)
+                else provider.api_key
+            ),
             "key_stats": provider_stats,
         }

@@ -375,7 +414,9 @@ async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
         return False

     provider_keys = (
-        [target_provider.api_key] if isinstance(target_provider.api_key, str) else target_provider.api_key
+        [target_provider.api_key]
+        if isinstance(target_provider.api_key, str)
+        else target_provider.api_key
     )

     if api_key:

@@ -89,7 +89,9 @@ class LLMModel(LLMModelBase):
         self.api_type = provider_config.api_type
         self.api_base = provider_config.api_base
         self.api_keys = (
-            [provider_config.api_key] if isinstance(provider_config.api_key, str) else provider_config.api_key
+            [provider_config.api_key]
+            if isinstance(provider_config.api_key, str)
+            else provider_config.api_key
         )
         self.model_name = model_detail.model_name
         self.temperature = model_detail.temperature
@@ -101,9 +103,12 @@ class LLMModel(LLMModelBase):
         """获取HTTP客户端"""
         if self.http_client.is_closed:
             logger.debug(
-                f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,正在获取新的客户端"
+                f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,"
+                "正在获取新的客户端"
             )
-            self.http_client = await http_client_manager.get_client(self.provider_config)
+            self.http_client = await http_client_manager.get_client(
+                self.provider_config
+            )
         return self.http_client

     async def _select_api_key(self, failed_keys: set[str] | None = None) -> str:
@@ -122,7 +127,10 @@ class LLMModel(LLMModelBase):
             raise LLMException(
                 f"提供商 {self.provider_name} 的所有API密钥当前都不可用",
                 code=LLMErrorCode.NO_AVAILABLE_KEYS,
-                details={"total_keys": len(self.api_keys), "failed_keys": len(failed_keys or set())},
+                details={
+                    "total_keys": len(self.api_keys),
+                    "failed_keys": len(failed_keys or set()),
+                },
             )

         return selected_key
@@ -154,7 +162,9 @@ class LLMModel(LLMModelBase):

             if http_response.status_code != 200:
                 error_text = http_response.text
-                logger.error(f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}")
+                logger.error(
+                    f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
+                )
                 await self.key_store.record_failure(api_key, http_response.status_code)

                 error_code = LLMErrorCode.API_REQUEST_FAILED
@@ -262,7 +272,9 @@ class LLMModel(LLMModelBase):

             if http_response.status_code != 200:
                 error_text = http_response.text
-                logger.error(f"HTTP请求失败: {http_response.status_code} - {error_text}")
+                logger.error(
+                    f"HTTP请求失败: {http_response.status_code} - {error_text}"
+                )

                 await self.key_store.record_failure(api_key, http_response.status_code)

@@ -304,7 +316,10 @@ class LLMModel(LLMModelBase):
                 try:
                     response_tool_calls.append(LLMToolCall(**tc_data))
                 except Exception as e:
-                    logger.warning(f"无法将工具调用数据转换为LLMToolCall: {tc_data}, error: {e}")
+                    logger.warning(
+                        f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
+                        f"error: {e}"
+                    )
             else:
                 logger.warning(f"工具调用数据格式未知: {tc_data}")

@@ -355,11 +370,16 @@ class LLMModel(LLMModelBase):
         if self._is_closed:
             return
         self._is_closed = True
-        logger.debug(f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)")
+        logger.debug(
+            f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)"
+        )

     async def __aenter__(self):
         if self._is_closed:
-            logger.debug(f"Re-entering context for closed LLMModel {self}. Resetting _is_closed to False.")
+            logger.debug(
+                f"Re-entering context for closed LLMModel {self}. "
+                f"Resetting _is_closed to False."
+            )
             self._is_closed = False
         self._check_not_closed()
         return self
@@ -393,12 +413,15 @@ class LLMModel(LLMModelBase):

         messages.append(LLMMessage.user(prompt))

+        model_fields = getattr(LLMGenerationConfig, "model_fields", {})
         request_specific_config_dict = {
-            k: v for k, v in kwargs.items() if k in LLMGenerationConfig.model_fields
+            k: v for k, v in kwargs.items() if k in model_fields
         }
         request_specific_config = None
         if request_specific_config_dict:
-            request_specific_config = LLMGenerationConfig(**request_specific_config_dict)
+            request_specific_config = LLMGenerationConfig(
+                **request_specific_config_dict
+            )

         for key in request_specific_config_dict:
             kwargs.pop(key, None)
@@ -450,7 +473,8 @@ class LLMModel(LLMModelBase):
         tools_dict = []
         for tool in tools:
             if hasattr(tool, "model_dump"):
-                tools_dict.append(tool.model_dump(exclude_none=True))
+                model_dump_func = getattr(tool, "model_dump")
+                tools_dict.append(model_dump_func(exclude_none=True))
             elif isinstance(tool, dict):
                 tools_dict.append(tool)
             else:
@@ -497,7 +521,9 @@ class LLMModel(LLMModelBase):
                     logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")

                     tool_result = await tool_executor(tool_name, tool_args_dict)
-                    logger.debug(f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}...")
+                    logger.debug(
+                        f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
+                    )

                     tool_response_messages.append(
                         LLMMessage.tool_response(
@@ -508,13 +534,17 @@ class LLMModel(LLMModelBase):
                    )
                 except json.JSONDecodeError as e:
                     logger.error(
-                        f"工具 '{tool_name}' 参数JSON解析失败: {tool_call.function.arguments}, 错误: {e}"
+                        f"工具 '{tool_name}' 参数JSON解析失败: "
+                        f"{tool_call.function.arguments}, 错误: {e}"
                     )
                     tool_response_messages.append(
                         LLMMessage.tool_response(
                             tool_call_id=tool_call.id,
                             function_name=tool_name,
-                            result={"error": "Argument JSON parsing failed", "details": str(e)},
+                            result={
+                                "error": "Argument JSON parsing failed",
+                                "details": str(e),
+                            },
                         )
                     )
                 except Exception as e:
@@ -523,7 +553,10 @@ class LLMModel(LLMModelBase):
                         LLMMessage.tool_response(
                             tool_call_id=tool_call.id,
                             function_name=tool_name,
-                            result={"error": "Tool execution failed", "details": str(e)},
+                            result={
+                                "error": "Tool execution failed",
+                                "details": str(e),
+                            },
                         )
                     )

@@ -533,7 +566,10 @@ class LLMModel(LLMModelBase):
         raise LLMException(
             "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
             code=LLMErrorCode.GENERATION_FAILED,
-            details={"iterations": max_tool_iterations, "last_messages": current_messages[-2:]},
+            details={
+                "iterations": max_tool_iterations,
+                "last_messages": current_messages[-2:],
+            },
         )

     async def generate_embeddings(
@@ -561,7 +597,9 @@ class LLMModel(LLMModelBase):
         ai_config = get_ai_config()
         default_max_retries = ai_config.get("max_retries_llm", 3)
         default_retry_delay = ai_config.get("retry_delay_llm", 2)
-        max_retries_embed = kwargs.get("max_retries_embed", max(1, default_max_retries // 2))
+        max_retries_embed = kwargs.get(
+            "max_retries_embed", max(1, default_max_retries // 2)
+        )
         retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2)

         retry_config = RetryConfig(

@ -58,7 +58,9 @@ class LLMContentPart(BaseModel):
        return cls(type="image", image_source=url)

    @classmethod
    def image_base64_part(cls, data: str, mime_type: str = "image/png") -> "LLMContentPart":
    def image_base64_part(
        cls, data: str, mime_type: str = "image/png"
    ) -> "LLMContentPart":
        """创建Base64图片内容部分"""
        data_url = f"data:{mime_type};base64,{data}"
        return cls(type="image", image_source=data_url)
@ -74,13 +76,17 @@ class LLMContentPart(BaseModel):
        return cls(type="video", video_source=url, mime_type=mime_type)

    @classmethod
    def video_base64_part(cls, data: str, mime_type: str = "video/mp4") -> "LLMContentPart":
    def video_base64_part(
        cls, data: str, mime_type: str = "video/mp4"
    ) -> "LLMContentPart":
        """创建Base64视频内容部分"""
        data_url = f"data:{mime_type};base64,{data}"
        return cls(type="video", video_source=data_url, mime_type=mime_type)

    @classmethod
    def audio_base64_part(cls, data: str, mime_type: str = "audio/wav") -> "LLMContentPart":
    def audio_base64_part(
        cls, data: str, mime_type: str = "audio/wav"
    ) -> "LLMContentPart":
        """创建Base64音频内容部分"""
        data_url = f"data:{mime_type};base64,{data}"
        return cls(type="audio", audio_source=data_url, mime_type=mime_type)
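
All three `*_base64_part` constructors build the same data-URL shape; a one-function sketch:

import base64


def to_data_url(raw: bytes, mime_type: str = "image/png") -> str:
    # Same "data:<mime>;base64,<data>" string the constructors above produce.
    data = base64.b64encode(raw).decode("utf-8")
    return f"data:{mime_type};base64,{data}"


print(to_data_url(b"\x89PNG", "image/png"))  # data:image/png;base64,iVBORw==
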
@ -101,7 +107,9 @@ class LLMContentPart(BaseModel):
        )

    @classmethod
    async def from_path(cls, path_like: str | Path, target_api: str | None = None) -> "LLMContentPart | None":
    async def from_path(
        cls, path_like: str | Path, target_api: str | None = None
    ) -> "LLMContentPart | None":
        """
        从本地文件路径创建 LLMContentPart。
        自动检测MIME类型,并根据类型(如图片)可能加载为Base64。
@ -116,7 +124,9 @@ class LLMContentPart(BaseModel):
        mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())

        if not mime_type:
            logger.warning(f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。")
            logger.warning(
                f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。"
            )
            try:
                async with aiofiles.open(path, encoding="utf-8") as f:
                    text_content = await f.read()
@ -131,17 +141,22 @@ class LLMContentPart(BaseModel):
                async with aiofiles.open(path, "rb") as f:
                    img_bytes = await f.read()
                base64_data = base64.b64encode(img_bytes).decode("utf-8")
                return cls.image_base64_part(data=base64_data, mime_type=mime_type)
                return cls.image_base64_part(
                    data=base64_data, mime_type=mime_type
                )
            except Exception as e:
                logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
                return None
        else:
            logger.warning(
                f"为本地图片路径 {path.name} 生成 image_url_part。实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"  # noqa: E501
                f"为本地图片路径 {path.name} 生成 image_url_part。"
                "实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
            )
            return cls.image_url_part(url=path.resolve().as_uri())
        elif mime_type.startswith("audio/"):
            return cls.audio_url_part(url=path.resolve().as_uri(), mime_type=mime_type)
            return cls.audio_url_part(
                url=path.resolve().as_uri(), mime_type=mime_type
            )
        elif mime_type.startswith("video/"):
            if target_api == "gemini":
                # 对于 Gemini API,将视频转换为 base64
@ -149,12 +164,16 @@ class LLMContentPart(BaseModel):
                    async with aiofiles.open(path, "rb") as f:
                        video_bytes = await f.read()
                    base64_data = base64.b64encode(video_bytes).decode("utf-8")
                    return cls.video_base64_part(data=base64_data, mime_type=mime_type)
                    return cls.video_base64_part(
                        data=base64_data, mime_type=mime_type
                    )
                except Exception as e:
                    logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
                    return None
            else:
                return cls.video_url_part(url=path.resolve().as_uri(), mime_type=mime_type)
                return cls.video_url_part(
                    url=path.resolve().as_uri(), mime_type=mime_type
                )
        elif (
            mime_type.startswith("text/")
            or mime_type == "application/json"
@ -168,7 +187,9 @@ class LLMContentPart(BaseModel):
                logger.error(f"读取文本类文件 {path.name} 失败: {e}")
                return None
        else:
            logger.info(f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。")
            logger.info(
                f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
            )
            return cls.file_uri_part(
                file_uri=path.resolve().as_uri(),
                mime_type=mime_type,
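
A compressed sketch of `from_path`'s dispatch order, simplified (the real method also treats `application/json` and similar as text and actually loads file contents):

import mimetypes
from pathlib import Path


def classify(path_like: str) -> str:
    # Guess the MIME type from the file URI, then branch by family,
    # in the same order as from_path above.
    path = Path(path_like)
    mime, _ = mimetypes.guess_type(path.resolve().as_uri())
    if not mime:
        return "text"  # fallback: try reading as UTF-8 text
    for family in ("image", "audio", "video", "text"):
        if mime.startswith(family + "/"):
            return family
    return "file"  # anything else becomes a generic file URI part


print(classify("a.png"), classify("a.mp3"), classify("a.mp4"), classify("a.xyz"))
# image audio video text
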
@ -228,9 +249,13 @@ class LLMContentPart(BaseModel):
                        return {"inlineData": {"mimeType": mime_type, "data": data}}
                    else:
                        # 如果无法解析 Base64 数据,抛出异常
                        raise ValueError(f"无法解析Base64图像数据: {self.image_source[:50]}...")
                        raise ValueError(
                            f"无法解析Base64图像数据: {self.image_source[:50]}..."
                        )
                else:
                    logger.warning(f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}")
                    logger.warning(
                        f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
                    )
                    return {
                        "inlineData": {
                            "mimeType": "image/jpeg",
@ -253,10 +278,14 @@ class LLMContentPart(BaseModel):
                        mime_type = header.split(";")[0].replace("data:", "")
                        return {"inlineData": {"mimeType": mime_type, "data": data}}
                    except (ValueError, IndexError):
                        raise ValueError(f"无法解析Base64视频数据: {self.video_source[:50]}...")
                        raise ValueError(
                            f"无法解析Base64视频数据: {self.video_source[:50]}..."
                        )
                else:
                    # 对于 URL 或其他格式,暂时不支持直接内联
                    raise ValueError("Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL")
                    raise ValueError(
                        "Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
                    )
            else:
                # 其他 API 可能不支持视频
                raise ValueError(f"API类型 '{api_type}' 不支持视频内容")
@ -273,17 +302,25 @@ class LLMContentPart(BaseModel):
                        mime_type = header.split(";")[0].replace("data:", "")
                        return {"inlineData": {"mimeType": mime_type, "data": data}}
                    except (ValueError, IndexError):
                        raise ValueError(f"无法解析Base64音频数据: {self.audio_source[:50]}...")
                        raise ValueError(
                            f"无法解析Base64音频数据: {self.audio_source[:50]}..."
                        )
                else:
                    raise ValueError("Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL")
                    raise ValueError(
                        "Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
                    )
            else:
                raise ValueError(f"API类型 '{api_type}' 不支持音频内容")

        elif self.type == "file":
            if api_type == "gemini" and self.file_uri:
                return {"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}}
                return {
                    "fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
                }
            elif self.file_source:
                file_name = self.metadata.get("name", "file") if self.metadata else "file"
                file_name = (
                    self.metadata.get("name", "file") if self.metadata else "file"
                )
                if api_type == "gemini":
                    return {"text": f"[文件: {file_name}]\n{self.file_source}"}
                else:
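
Every Gemini branch above relies on the same `data:<mime>;base64,<data>` split; as a sketch:

def data_url_to_inline_data(data_url: str) -> dict:
    # Split the data URL into its header and payload, exactly as the
    # image/video/audio branches above do.
    header, data = data_url.split(",", 1)
    mime_type = header.split(";")[0].replace("data:", "")
    return {"inlineData": {"mimeType": mime_type, "data": data}}


print(data_url_to_inline_data("data:image/png;base64,iVBORw0KGgo="))
# {'inlineData': {'mimeType': 'image/png', 'data': 'iVBORw0KGgo='}}
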
@ -317,7 +354,8 @@ class LLMMessage(BaseModel):
            raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
        if self.role == "tool" and not isinstance(self.content, str):
            logger.warning(
                f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 将尝试转换为字符串。"
                f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
                "将尝试转换为字符串。"
            )
            try:
                self.content = str(self.content)
@ -339,7 +377,9 @@ class LLMMessage(BaseModel):
        return cls(role="assistant", content=content, tool_calls=tool_calls)

    @classmethod
    def assistant_text_response(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
    def assistant_text_response(
        cls, content: str | list[LLMContentPart]
    ) -> "LLMMessage":
        """创建助手纯文本回复的消息"""
        return cls(role="assistant", content=content, tool_calls=None)

@ -356,10 +396,19 @@ class LLMMessage(BaseModel):
        try:
            content_str = json.dumps(result)
        except TypeError as e:
            logger.error(f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}")
            content_str = json.dumps({"error": "Tool result not JSON serializable", "details": str(e)})
            logger.error(
                f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
            )
            content_str = json.dumps(
                {"error": "Tool result not JSON serializable", "details": str(e)}
            )

        return cls(role="tool", content=content_str, tool_call_id=tool_call_id, name=function_name)
        return cls(
            role="tool",
            content=content_str,
            tool_call_id=tool_call_id,
            name=function_name,
        )

    @classmethod
    def system(cls, content: str) -> "LLMMessage":

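The fallback in `tool_response` keeps the conversation alive even when a tool returns something `json.dumps` cannot handle; a sketch of that step:

import json
from typing import Any


def safe_tool_content(result: Any) -> str:
    # Non-serializable results become a structured error string rather
    # than crashing message construction, as in the hunk above.
    try:
        return json.dumps(result)
    except TypeError as e:
        return json.dumps(
            {"error": "Tool result not JSON serializable", "details": str(e)}
        )


print(safe_tool_content({"hits": 3}))   # {"hits": 3}
print(safe_tool_content(object()))      # serialization-error payload
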
@ -36,11 +36,15 @@ class LLMException(Exception):
    error_messages = {
        LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到,请检查配置或联系管理员。",
        LLMErrorCode.API_KEY_INVALID: "API密钥无效,请联系管理员更新配置。",
        LLMErrorCode.API_QUOTA_EXCEEDED: "API使用配额已用尽,请稍后再试或联系管理员。",
        LLMErrorCode.API_QUOTA_EXCEEDED: (
            "API使用配额已用尽,请稍后再试或联系管理员。"
        ),
        LLMErrorCode.API_TIMEOUT: "AI服务响应超时,请稍后再试。",
        LLMErrorCode.API_RATE_LIMITED: "请求过于频繁,已被AI服务限流,请稍后再试。",
        LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败,请联系管理员检查配置。",
        LLMErrorCode.NO_AVAILABLE_KEYS: "当前所有API密钥均不可用,请稍后再试或联系管理员。",
        LLMErrorCode.NO_AVAILABLE_KEYS: (
            "当前所有API密钥均不可用,请稍后再试或联系管理员。"
        ),
        LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
            "当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。"
        ),

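Presumably this table backs a lookup with a generic fallback for unmapped codes; a sketch with a hypothetical stand-in enum (the real `LLMErrorCode` has more members):

from enum import Enum, auto


class ErrorCodeSketch(Enum):  # hypothetical stand-in for LLMErrorCode
    API_TIMEOUT = auto()
    API_RATE_LIMITED = auto()


ERROR_MESSAGES = {
    ErrorCodeSketch.API_TIMEOUT: "AI服务响应超时,请稍后再试。",
}


def user_friendly_message(code: ErrorCodeSketch) -> str:
    # Unknown codes fall back to a generic message instead of raising.
    return ERROR_MESSAGES.get(code, "AI service error, please try again later.")


print(user_friendly_message(ErrorCodeSketch.API_RATE_LIMITED))
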
@ -5,7 +5,16 @@ LLM 模块的工具和转换函数
import base64
from pathlib import Path

from nonebot_plugin_alconna.uniseg import At, File, Image, Reply, Text, UniMessage, Video, Voice
from nonebot_plugin_alconna.uniseg import (
    At,
    File,
    Image,
    Reply,
    Text,
    UniMessage,
    Video,
    Voice,
)

from zhenxun.services.log import logger

@ -29,7 +38,11 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
            elif seg.url:
                part = LLMContentPart.image_url_part(seg.url)
            elif hasattr(seg, "raw") and seg.raw:
                mime_type = getattr(seg, "mimetype", "image/png") if hasattr(seg, "mimetype") else "image/png"
                mime_type = (
                    getattr(seg, "mimetype", "image/png")
                    if hasattr(seg, "mimetype")
                    else "image/png"
                )
                if isinstance(seg.raw, bytes):
                    b64_data = base64.b64encode(seg.raw).decode("utf-8")
                    part = LLMContentPart.image_base64_part(b64_data, mime_type)
@ -38,8 +51,13 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
            if seg.path:
                part = await LLMContentPart.from_path(seg.path)
            elif seg.url:
                logger.warning(f"直接使用 URL 的 {type(seg).__name__} 段,API 可能不支持: {seg.url}")
                part = LLMContentPart.text_part(f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]")
                logger.warning(
                    f"直接使用 URL 的 {type(seg).__name__} 段,"
                    f"API 可能不支持: {seg.url}"
                )
                part = LLMContentPart.text_part(
                    f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
                )
            elif hasattr(seg, "raw") and seg.raw:
                mime_type = getattr(seg, "mimetype", None)
                if isinstance(seg.raw, bytes):
@ -48,18 +66,29 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
                    if isinstance(seg, Video):
                        if not mime_type:
                            mime_type = "video/mp4"
                        part = LLMContentPart.video_base64_part(data=b64_data, mime_type=mime_type)
                        logger.debug(f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
                        part = LLMContentPart.video_base64_part(
                            data=b64_data, mime_type=mime_type
                        )
                        logger.debug(
                            f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
                        )
                    elif isinstance(seg, Voice):
                        if not mime_type:
                            mime_type = "audio/wav"
                        part = LLMContentPart.audio_base64_part(data=b64_data, mime_type=mime_type)
                        logger.debug(f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
                        part = LLMContentPart.audio_base64_part(
                            data=b64_data, mime_type=mime_type
                        )
                        logger.debug(
                            f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
                        )
                    else:
                        part = LLMContentPart.text_part(
                            f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
                        )
                        logger.debug(f"处理其他文件字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
                        logger.debug(
                            f"处理其他文件字节数据: {mime_type}, "
                            f"大小: {len(seg.raw)} bytes"
                        )

        elif isinstance(seg, At):
            if seg.flag == "all":
@ -76,7 +105,9 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
                else:
                    reply_text = str(seg.msg).strip()
                    if reply_text:
                        part = LLMContentPart.text_part(f'[Replied to: "{reply_text[:50]}..."]')
                        part = LLMContentPart.text_part(
                            f'[Replied to: "{reply_text[:50]}..."]'
                        )
            except Exception:
                part = LLMContentPart.text_part("[Replied to a message]")

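A compressed sketch of the segment dispatch above; `Seg` is a hypothetical stand-in for UniMessage segments and the dict output stands in for `LLMContentPart`:

import base64


class Seg:  # hypothetical stand-in for Text/Image/Video/Voice segments
    def __init__(self, kind: str, text: str = "", raw: bytes | None = None,
                 mimetype: str | None = None):
        self.kind, self.text, self.raw, self.mimetype = kind, text, raw, mimetype


def seg_to_part(seg: Seg) -> dict:
    # Text passes through; raw bytes become a Base64 data URL with a
    # per-kind default MIME type, as in the hunks above.
    if seg.kind == "text":
        return {"type": "text", "text": seg.text}
    if seg.raw is not None:
        defaults = {"image": "image/png", "video": "video/mp4", "voice": "audio/wav"}
        mime = seg.mimetype or defaults[seg.kind]
        b64 = base64.b64encode(seg.raw).decode("utf-8")
        return {"type": seg.kind, "source": f"data:{mime};base64,{b64}"}
    return {"type": "text", "text": f"[{seg.kind.upper()} FILE]"}


print(seg_to_part(Seg("voice", raw=b"\x00\x01")))
# {'type': 'voice', 'source': 'data:audio/wav;base64,AAE='}
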
@ -118,10 +149,14 @@ def create_multimodal_message(
        msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")

        # 文本 + 多张图片
        msg = create_multimodal_message("比较图片", images=["/path/1.jpg", "/path/2.jpg"])
        msg = create_multimodal_message(
            "比较图片", images=["/path/1.jpg", "/path/2.jpg"]
        )

        # 文本 + 图片字节数据
        msg = create_multimodal_message("分析", images=image_data, image_mimetypes="image/jpeg")
        msg = create_multimodal_message(
            "分析", images=image_data, image_mimetypes="image/jpeg"
        )

        # 文本 + 视频
        msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")

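Assuming the keyword arguments can be combined (the docstring shows them one at a time), a mixed-modality call would look like:

msg = create_multimodal_message(
    "对比图片并总结视频",
    images=["/path/1.jpg", "/path/2.jpg"],
    videos="/path/to/video.mp4",
)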