🎨 Ruff

fccckaug 2025-06-13 16:21:27 +08:00 committed by fccckaug
parent 6c045055a0
commit d09a5b1c72
11 changed files with 596 additions and 186 deletions

View File

@ -131,7 +131,9 @@ class BaseAdapter(ABC):
pass
@abstractmethod
def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
"""解析文本嵌入响应"""
pass
@ -145,7 +147,9 @@ class BaseAdapter(ABC):
else str(error_info)
)
raise LLMException(
f"嵌入API错误: {msg}", code=LLMErrorCode.EMBEDDING_FAILED, details=response_json
f"嵌入API错误: {msg}",
code=LLMErrorCode.EMBEDDING_FAILED,
details=response_json,
)
def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
@ -170,7 +174,9 @@ class BaseAdapter(ABC):
)
return headers
def convert_messages_to_openai_format(self, messages: list["LLMMessage"]) -> list[dict[str, Any]]:
def convert_messages_to_openai_format(
self, messages: list["LLMMessage"]
) -> list[dict[str, Any]]:
"""将LLMMessage转换为OpenAI格式 - 通用方法"""
openai_messages: list[dict[str, Any]] = []
for msg in messages:
@ -190,7 +196,10 @@ class BaseAdapter(ABC):
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image":
content_parts.append(
{"type": "image_url", "image_url": {"url": part.image_source}}
{
"type": "image_url",
"image_url": {"url": part.image_source},
}
)
openai_msg["content"] = content_parts
@ -247,9 +256,13 @@ class BaseAdapter(ABC):
)
)
except KeyError as e:
logger.warning(f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}")
logger.warning(
f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}")
logger.warning(
f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
)
if not parsed_tool_calls:
parsed_tool_calls = None
@ -268,7 +281,11 @@ class BaseAdapter(ABC):
except Exception as e:
logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e)
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def validate_response(self, response_json: dict[str, Any]) -> None:
"""验证API响应解析不同API的错误结构"""
@ -291,9 +308,14 @@ class BaseAdapter(ABC):
"max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
}
llm_error_code = error_code_mapping.get(error_code, LLMErrorCode.API_RESPONSE_INVALID)
llm_error_code = error_code_mapping.get(
error_code, LLMErrorCode.API_RESPONSE_INVALID
)
logger.error(f"API返回错误: {error_message} (代码: {error_code}, 类型: {error_type})")
logger.error(
f"API返回错误: {error_message} "
f"(代码: {error_code}, 类型: {error_type})"
)
else:
error_message = str(error_info)
error_code = "unknown"
@ -314,7 +336,10 @@ class BaseAdapter(ABC):
finish_reason = candidate.get("finishReason")
if finish_reason in ["SAFETY", "RECITATION"]:
safety_ratings = candidate.get("safetyRatings", [])
logger.warning(f"Gemini内容被安全过滤: {finish_reason}, 安全评级: {safety_ratings}")
logger.warning(
f"Gemini内容被安全过滤: {finish_reason}, "
f"安全评级: {safety_ratings}"
)
raise LLMException(
f"内容被安全过滤: {finish_reason}",
code=LLMErrorCode.CONTENT_FILTERED,
@ -342,7 +367,9 @@ class BaseAdapter(ABC):
return config.to_api_params(model.api_type, model.model_name)
if model._generation_config is not None:
return model._generation_config.to_api_params(model.api_type, model.model_name)
return model._generation_config.to_api_params(
model.api_type, model.model_name
)
base_config = {}
if model.temperature is not None:
@ -417,6 +444,7 @@ class OpenAICompatAdapter(BaseAdapter):
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
_ = model, is_advanced # 未使用的参数
return self.parse_openai_response(response_json)
def prepare_embedding_request(
@ -428,6 +456,7 @@ class OpenAICompatAdapter(BaseAdapter):
**kwargs: Any,
) -> RequestData:
"""准备嵌入请求 - OpenAI兼容格式"""
_ = task_type # 未使用的参数
url = self.get_api_url(model, self.get_embedding_endpoint())
headers = self.get_base_headers(api_key)
@ -442,7 +471,9 @@ class OpenAICompatAdapter(BaseAdapter):
return RequestData(url=url, headers=headers, body=body)
def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
"""解析嵌入响应 - OpenAI兼容格式"""
self.validate_embedding_response(response_json)
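As a point of reference for the parser above: OpenAI-compatible embedding endpoints return vectors under data[*].embedding. The following is a minimal standalone sketch of that extraction (assuming the standard payload layout; an illustration, not the project's actual implementation):

from typing import Any

def parse_openai_embeddings(response_json: dict[str, Any]) -> list[list[float]]:
    # Standard OpenAI-compatible shape: {"data": [{"embedding": [...]}, ...]}
    return [item["embedding"] for item in response_json.get("data", [])]

sample = {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}
assert parse_openai_embeddings(sample) == [[0.1, 0.2], [0.3, 0.4]]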

View File

@ -48,7 +48,9 @@ class GeminiAdapter(BaseAdapter):
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
return self._prepare_request(model, api_key, messages, config, tools, tool_choice)
return self._prepare_request(
model, api_key, messages, config, tools, tool_choice
)
def _prepare_request(
self,
@ -75,7 +77,9 @@ class GeminiAdapter(BaseAdapter):
if isinstance(msg.content, str):
system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list):
system_instruction_parts = [part.convert_for_api("gemini") for part in msg.content]
system_instruction_parts = [
part.convert_for_api("gemini") for part in msg.content
]
continue
elif msg.role == "user":
@ -115,12 +119,21 @@ class GeminiAdapter(BaseAdapter):
import json
try:
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
tool_result_obj = json.loads(content_str)
except json.JSONDecodeError:
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
logger.warning(
f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. 包装为原始字符串。"
f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. "
f"包装为原始字符串。"
)
tool_result_obj = {"raw_output": content_str}
@ -144,7 +157,9 @@ class GeminiAdapter(BaseAdapter):
for tool_item in tools:
if isinstance(tool_item, dict):
if "name" in tool_item and "description" in tool_item:
all_tools_for_request.append({"functionDeclarations": [tool_item]})
all_tools_for_request.append(
{"functionDeclarations": [tool_item]}
)
else:
all_tools_for_request.append(tool_item)
else:
@ -152,7 +167,9 @@ class GeminiAdapter(BaseAdapter):
if effective_config:
if getattr(effective_config, "enable_grounding", False):
has_explicit_gs_tool = any("googleSearch" in tool_item for tool_item in all_tools_for_request)
has_explicit_gs_tool = any(
"googleSearch" in tool_item for tool_item in all_tools_for_request
)
if not has_explicit_gs_tool:
all_tools_for_request.append({"googleSearch": {}})
logger.debug("隐式启用 Google Search 工具进行信息来源关联。")
@ -166,7 +183,9 @@ class GeminiAdapter(BaseAdapter):
logger.debug("隐式启用代码执行工具。")
if all_tools_for_request:
gemini_api_tools = self._convert_tools_to_gemini_format(all_tools_for_request)
gemini_api_tools = self._convert_tools_to_gemini_format(
all_tools_for_request
)
if gemini_api_tools:
body["tools"] = gemini_api_tools
@ -180,11 +199,17 @@ class GeminiAdapter(BaseAdapter):
if mode_upper in ["AUTO", "NONE", "ANY"]:
body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
else:
body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice)
body["toolConfig"] = self._convert_tool_choice_to_gemini(
final_tool_choice
)
else:
body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice)
body["toolConfig"] = self._convert_tool_choice_to_gemini(
final_tool_choice
)
final_generation_config = self._build_gemini_generation_config(model, effective_config)
final_generation_config = self._build_gemini_generation_config(
model, effective_config
)
if final_generation_config:
body["generationConfig"] = final_generation_config
@ -203,7 +228,9 @@ class GeminiAdapter(BaseAdapter):
"""应用配置覆盖 - Gemini 不需要额外的配置覆盖"""
return body
def _get_gemini_endpoint(self, model: "LLMModel", config: "LLMGenerationConfig | None" = None) -> str:
def _get_gemini_endpoint(
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> str:
"""根据配置选择Gemini API端点"""
if config:
if getattr(config, "enable_code_execution", False):
@ -214,7 +241,9 @@ class GeminiAdapter(BaseAdapter):
return f"/v1beta/models/{model.model_name}:generateContent"
def _convert_tools_to_gemini_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _convert_tools_to_gemini_format(
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""转换工具格式为Gemini格式"""
gemini_tools = []
@ -232,7 +261,9 @@ class GeminiAdapter(BaseAdapter):
}
gemini_tools.append(gemini_tool)
elif tool.get("type") == "code_execution":
gemini_tools.append({"codeExecution": {"language": tool.get("language", "python")}})
gemini_tools.append(
{"codeExecution": {"language": tool.get("language", "python")}}
)
elif tool.get("type") == "google_search":
gemini_tools.append({"googleSearch": {}})
elif "googleSearch" in tool:
@ -242,18 +273,26 @@ class GeminiAdapter(BaseAdapter):
return gemini_tools
def _convert_tool_choice_to_gemini(self, tool_choice_value: str | dict[str, Any]) -> dict[str, Any]:
def _convert_tool_choice_to_gemini(
self, tool_choice_value: str | dict[str, Any]
) -> dict[str, Any]:
"""转换工具选择策略为Gemini格式"""
if isinstance(tool_choice_value, str):
mode_upper = tool_choice_value.upper()
if mode_upper in ["AUTO", "NONE", "ANY"]:
return {"functionCallingConfig": {"mode": mode_upper}}
else:
logger.warning(f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。回退到 AUTO。")
logger.warning(
f"不支持的 tool_choice 字符串值: '{tool_choice_value}'"
f"回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
elif isinstance(tool_choice_value, dict):
if tool_choice_value.get("type") == "function" and "function" in tool_choice_value:
if (
tool_choice_value.get("type") == "function"
and "function" in tool_choice_value
):
func_name = tool_choice_value["function"].get("name")
if func_name:
return {
@ -263,17 +302,26 @@ class GeminiAdapter(BaseAdapter):
}
}
else:
logger.warning(f"tool_choice dict 中的函数名无效: {tool_choice_value}。回退到 AUTO。")
logger.warning(
f"tool_choice dict 中的函数名无效: {tool_choice_value}"
f"回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
elif "functionCallingConfig" in tool_choice_value:
return {"functionCallingConfig": tool_choice_value["functionCallingConfig"]}
return {
"functionCallingConfig": tool_choice_value["functionCallingConfig"]
}
else:
logger.warning(f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。")
logger.warning(
f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
logger.warning(f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。")
logger.warning(
f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
def _build_gemini_generation_config(
@ -285,11 +333,15 @@ class GeminiAdapter(BaseAdapter):
effective_config = config if config is not None else model._generation_config
if effective_config:
base_api_params = effective_config.to_api_params(api_type="gemini", model_name=model.model_name)
base_api_params = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
)
generation_config.update(base_api_params)
if getattr(effective_config, "response_mime_type", None):
generation_config["responseMimeType"] = effective_config.response_mime_type
generation_config["responseMimeType"] = (
effective_config.response_mime_type
)
if getattr(effective_config, "response_schema", None):
generation_config["responseSchema"] = effective_config.response_schema
@ -303,15 +355,22 @@ class GeminiAdapter(BaseAdapter):
if getattr(effective_config, "response_modalities", None):
modalities = effective_config.response_modalities
if isinstance(modalities, list):
generation_config["responseModalities"] = [m.upper() for m in modalities]
generation_config["responseModalities"] = [
m.upper() for m in modalities
]
elif isinstance(modalities, str):
generation_config["responseModalities"] = [modalities.upper()]
generation_config = {k: v for k, v in generation_config.items() if v is not None}
generation_config = {
k: v for k, v in generation_config.items() if v is not None
}
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: {param_keys}")
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
return generation_config
@ -337,7 +396,9 @@ class GeminiAdapter(BaseAdapter):
safety_settings.append({"category": category, "threshold": threshold})
else:
for category in safety_categories:
safety_settings.append({"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"})
safety_settings.append(
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
)
return safety_settings if safety_settings else None
@ -368,12 +429,18 @@ class GeminiAdapter(BaseAdapter):
candidate = candidates[0]
if candidate.get("finishReason") in ["RECITATION", "OTHER"] and not candidate.get("content"):
if candidate.get("finishReason") in [
"RECITATION",
"OTHER",
] and not candidate.get("content"):
logger.warning(
f"Gemini candidate finished with reason '{candidate.get('finishReason')}' and no content."
f"Gemini candidate finished with reason "
f"'{candidate.get('finishReason')}' and no content."
)
return ResponseData(
text="", raw_response=response_json, usage_info=response_json.get("usageMetadata")
text="",
raw_response=response_json,
usage_info=response_json.get("usageMetadata"),
)
content_data = candidate.get("content", {})
@ -394,9 +461,10 @@ class GeminiAdapter(BaseAdapter):
from ..types.models import LLMToolCall, LLMToolFunction
call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
parsed_tool_calls.append(
LLMToolCall(
id=f"call_{model.provider_name}_{len(parsed_tool_calls)}",
id=call_id,
function=LLMToolFunction(
name=fc_data["name"],
arguments=json.dumps(fc_data["args"]),
@ -404,16 +472,22 @@ class GeminiAdapter(BaseAdapter):
)
)
except KeyError as e:
logger.warning(f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}")
logger.warning(
f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}")
logger.warning(
f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
)
elif "codeExecutionResult" in part:
result = part["codeExecutionResult"]
if result.get("outcome") == "OK":
output = result.get("output", "")
text_content += f"\n[代码执行结果]:\n{output}\n"
else:
text_content += f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
text_content += (
f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
)
usage_info = response_json.get("usageMetadata")
@ -436,7 +510,11 @@ class GeminiAdapter(BaseAdapter):
except Exception as e:
logger.error(f"解析 Gemini 响应失败: {e}", e=e)
raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e)
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def prepare_embedding_request(
self,
@ -474,7 +552,9 @@ class GeminiAdapter(BaseAdapter):
body = {"requests": requests_payload}
return RequestData(url=url, headers=headers, body=body)
def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
"""解析文本嵌入响应"""
try:
embeddings_data = response_json["embeddings"]
@ -482,18 +562,26 @@ class GeminiAdapter(BaseAdapter):
except KeyError as e:
logger.error(f"解析Gemini嵌入响应时缺少键: {e}. 响应: {response_json}")
raise LLMException(
"Gemini嵌入响应格式错误", code=LLMErrorCode.RESPONSE_PARSE_ERROR, details={"error": str(e)}
"Gemini嵌入响应格式错误",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"error": str(e)},
)
except Exception as e:
logger.error(f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}")
logger.error(
f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}"
)
raise LLMException(
f"解析Gemini嵌入响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e
f"解析Gemini嵌入响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
"""验证嵌入响应"""
super().validate_embedding_response(response_json)
if "embeddings" not in response_json or not isinstance(response_json["embeddings"], list):
if "embeddings" not in response_json or not isinstance(
response_json["embeddings"], list
):
raise LLMException(
"Gemini嵌入响应缺少'embeddings'字段或格式不正确",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
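For orientation, the tool_choice conversion reformatted above maps strings and dicts onto Gemini's functionCallingConfig. A stripped-down standalone sketch (the allowedFunctionNames field is an assumption about the dict branch, which the diff truncates):

from typing import Any

def tool_choice_to_gemini(choice: str | dict[str, Any]) -> dict[str, Any]:
    if isinstance(choice, str) and choice.upper() in {"AUTO", "NONE", "ANY"}:
        return {"functionCallingConfig": {"mode": choice.upper()}}
    if isinstance(choice, dict) and choice.get("type") == "function":
        name = choice.get("function", {}).get("name")
        if name:
            # Pin the model to a single function (field name assumed)
            return {"functionCallingConfig": {"mode": "ANY", "allowedFunctionNames": [name]}}
    # Anything unrecognized falls back to AUTO, matching the warnings above
    return {"functionCallingConfig": {"mode": "AUTO"}}

assert tool_choice_to_gemini("any") == {"functionCallingConfig": {"mode": "ANY"}}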

View File

@ -91,7 +91,9 @@ class AI:
if isinstance(message, str):
llm_messages = [LLMMessage.user(message)]
elif isinstance(message, list) and all(isinstance(part, LLMContentPart) for part in message):
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
llm_messages = [LLMMessage.user(message)]
elif isinstance(message, LLMMessage):
llm_messages = [message]
@ -103,7 +105,9 @@ class AI:
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(llm_messages, model, "聊天失败", kwargs)
response = await self._execute_generation(
llm_messages, model, "聊天失败", kwargs
)
return response.text
async def code(
@ -135,7 +139,12 @@ class AI:
}
async def search(
self, query: str | UniMessage, *, model: ModelName = None, instruction: str = "", **kwargs: Any
self,
query: str | UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索 - 支持多模态输入"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
@ -154,7 +163,9 @@ class AI:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException("搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED)
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
@ -206,7 +217,9 @@ class AI:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException("分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED)
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
@ -241,7 +254,12 @@ class AI:
tool_choice = "none"
response = await self._execute_generation(
final_messages, model, "内容分析失败", kwargs, llm_tools=llm_tools, tool_choice=tool_choice
final_messages,
model,
"内容分析失败",
kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
if response.tool_calls:
@ -260,8 +278,12 @@ class AI:
) -> LLMResponse:
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
try:
resolved_model_name = self._resolve_model_name(model_name or self.config.model)
final_config_dict = self._merge_config(config_overrides, base_config=base_config)
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
)
final_config_dict = self._merge_config(
config_overrides, base_config=base_config
)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
@ -316,7 +338,9 @@ class AI:
if self.config.enable_gemini_thinking:
final_config["thinking_budget"] = 0.8
if self.config.enable_gemini_safe_mode:
final_config["safety_settings"] = CommonOverrides.gemini_safe().safety_settings
final_config["safety_settings"] = (
CommonOverrides.gemini_safe().safety_settings
)
if self.config.enable_gemini_multimodal:
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
if self.config.enable_gemini_grounding:
@ -341,10 +365,13 @@ class AI:
return []
try:
resolved_model_str = model or self.config.default_embedding_model or self.config.model
resolved_model_str = (
model or self.config.default_embedding_model or self.config.model
)
if not resolved_model_str:
raise LLMException(
"使用 embed 功能时必须指定嵌入模型名称,或在 AIConfig 中配置 default_embedding_model。",
"使用 embed 功能时必须指定嵌入模型名称,"
"或在 AIConfig 中配置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
@ -360,7 +387,9 @@ class AI:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
async def chat(
@ -443,7 +472,9 @@ async def analyze_multimodal(
**kwargs: Any,
) -> str | LLMResponse:
"""多模态分析便捷函数"""
message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios)
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
return await analyze(message, instruction=instruction, model=model, **kwargs)
@ -458,7 +489,9 @@ async def search_multimodal(
**kwargs: Any,
) -> dict[str, Any]:
"""多模态搜索便捷函数"""
message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios)
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
ai = AI()
return await ai.search(message, model=model, instruction=instruction, **kwargs)
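A hypothetical call into these wrappers might look like the sketch below (the model name and file path are placeholders, and the import location of search_multimodal is assumed):

import asyncio

async def demo() -> None:
    # search_multimodal is the convenience wrapper defined above
    result = await search_multimodal(
        text="What does this screenshot show?",
        images="/path/to/screenshot.png",
        model="Gemini/gemini-2.0-flash",
    )
    print(result)

asyncio.run(demo())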

View File

@ -15,27 +15,47 @@ from ..types.exceptions import LLMErrorCode, LLMException
class ModelConfigOverride(BaseModel):
"""模型配置覆盖参数"""
temperature: float | None = Field(default=None, ge=0.0, le=2.0, description="生成温度")
temperature: float | None = Field(
default=None, ge=0.0, le=2.0, description="生成温度"
)
max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数")
top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数")
top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数")
frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="频率惩罚")
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="存在惩罚")
repetition_penalty: float | None = Field(default=None, ge=0.0, le=2.0, description="重复惩罚")
frequency_penalty: float | None = Field(
default=None, ge=-2.0, le=2.0, description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None, ge=-2.0, le=2.0, description="存在惩罚"
)
repetition_penalty: float | None = Field(
default=None, ge=0.0, le=2.0, description="重复惩罚"
)
stop: list[str] | str | None = Field(default=None, description="停止序列")
response_format: ResponseFormat | dict[str, Any] | None = Field(
default=None, description="期望的响应格式"
)
response_mime_type: str | None = Field(default=None, description="响应MIME类型Gemini专用")
response_schema: dict[str, Any] | None = Field(default=None, description="JSON响应模式")
thinking_budget: float | None = Field(default=None, ge=0.0, le=1.0, description="思考预算")
response_mime_type: str | None = Field(
default=None, description="响应MIME类型Gemini专用"
)
response_schema: dict[str, Any] | None = Field(
default=None, description="JSON响应模式"
)
thinking_budget: float | None = Field(
default=None, ge=0.0, le=1.0, description="思考预算"
)
safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
response_modalities: list[str] | None = Field(default=None, description="响应模态类型")
response_modalities: list[str] | None = Field(
default=None, description="响应模态类型"
)
enable_code_execution: bool | None = Field(default=None, description="是否启用代码执行")
enable_grounding: bool | None = Field(default=None, description="是否启用信息来源关联")
enable_code_execution: bool | None = Field(
default=None, description="是否启用代码执行"
)
enable_grounding: bool | None = Field(
default=None, description="是否启用信息来源关联"
)
enable_caching: bool | None = Field(default=None, description="是否启用响应缓存")
custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")
@ -43,7 +63,16 @@ class ModelConfigOverride(BaseModel):
def to_dict(self) -> dict[str, Any]:
"""转换为字典排除None值"""
result = {}
for key, value in self.model_dump().items():
model_data = getattr(self, "model_dump", lambda: {})()
if not model_data:
model_data = {}
for field_name, _ in self.__class__.__dict__.get(
"model_fields", {}
).items():
value = getattr(self, field_name, None)
if value is not None:
model_data[field_name] = value
for key, value in model_data.items():
if value is not None:
if key == "custom_params" and isinstance(value, dict):
result.update(value)
@ -110,13 +139,14 @@ class LLMGenerationConfig(ModelConfigOverride):
else:
params["repetition_penalty"] = self.repetition_penalty
# Handle the response_format parameter
if self.response_format is not None:
if isinstance(self.response_format, dict):
# Use the dict-form response_format directly, e.g. {'type': 'json_object'}
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = self.response_format
logger.debug(f"{api_type} 使用自定义 response_format: {self.response_format}")
logger.debug(
f"{api_type} 使用自定义 response_format: "
f"{self.response_format}"
)
elif self.response_format == ResponseFormat.JSON:
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = {"type": "json_object"}
@ -128,9 +158,14 @@ class LLMGenerationConfig(ModelConfigOverride):
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if api_type in ["gemini", "gemini_native"]:
if self.response_format != ResponseFormat.JSON and self.response_mime_type is not None:
if (
self.response_format != ResponseFormat.JSON
and self.response_mime_type is not None
):
params["responseMimeType"] = self.response_mime_type
logger.debug(f"使用显式设置的 responseMimeType: {self.response_mime_type}")
logger.debug(
f"使用显式设置的 responseMimeType: {self.response_mime_type}"
)
if self.response_schema is not None and "responseSchema" not in params:
params["responseSchema"] = self.response_schema
@ -158,7 +193,9 @@ def validate_override_params(
if isinstance(override_config, dict):
try:
filtered_config = {k: v for k, v in override_config.items() if v is not None}
filtered_config = {
k: v for k, v in override_config.items() if v is not None
}
return LLMGenerationConfig(**filtered_config)
except Exception as e:
logger.warning(f"覆盖配置参数验证失败: {e}")
@ -171,7 +208,9 @@ def validate_override_params(
return override_config
def apply_api_specific_mappings(params: dict[str, Any], api_type: str) -> dict[str, Any]:
def apply_api_specific_mappings(
params: dict[str, Any], api_type: str
) -> dict[str, Any]:
"""应用API特定的参数映射"""
mapped_params = params.copy()
@ -204,7 +243,8 @@ def apply_api_specific_mappings(params: dict[str, Any], api_type: str) -> dict[s
def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
"""从关键字参数创建生成配置"""
known_fields = set(LLMGenerationConfig.model_fields.keys())
model_fields = getattr(LLMGenerationConfig, "model_fields", {})
known_fields = set(model_fields.keys())
known_params = {}
custom_params = {}
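The to_dict contract above - drop None values and splice custom_params into the top level - can be illustrated standalone (a minimal sketch independent of Pydantic):

from typing import Any

def flatten_config(fields: dict[str, Any]) -> dict[str, Any]:
    result: dict[str, Any] = {}
    for key, value in fields.items():
        if value is None:
            continue  # exclude unset overrides
        if key == "custom_params" and isinstance(value, dict):
            result.update(value)  # splice custom params into the top level
        else:
            result[key] = value
    return result

cfg = {"temperature": 0.3, "max_tokens": None, "custom_params": {"thinking_budget": 0.8}}
assert flatten_config(cfg) == {"temperature": 0.3, "thinking_budget": 0.8}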

View File

@ -30,17 +30,25 @@ class CommonOverrides:
@staticmethod
def concise(max_tokens: int = 100) -> LLMGenerationConfig:
"""简洁模式:限制输出长度"""
return LLMGenerationConfig(temperature=0.3, max_tokens=max_tokens, stop=["\n\n", "", "", ""])
return LLMGenerationConfig(
temperature=0.3,
max_tokens=max_tokens,
stop=["\n\n", "", "", ""],
)
@staticmethod
def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
"""详细模式:鼓励详细输出"""
return LLMGenerationConfig(temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1)
return LLMGenerationConfig(
temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
)
@staticmethod
def gemini_json() -> LLMGenerationConfig:
"""Gemini JSON模式强制JSON输出"""
return LLMGenerationConfig(temperature=0.3, response_mime_type="application/json")
return LLMGenerationConfig(
temperature=0.3, response_mime_type="application/json"
)
@staticmethod
def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
@ -96,7 +104,9 @@ class CommonOverrides:
temperature=0.5,
max_tokens=4096,
enable_grounding=True,
custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}},
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
@staticmethod
@ -119,7 +129,9 @@ class CommonOverrides:
enable_caching=True,
custom_params={
"code_execution_timeout": 30,
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}},
"grounding_config": {
"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
},
},
)
@ -132,7 +144,9 @@ class CommonOverrides:
thinking_budget=0.8,
enable_grounding=True,
response_mime_type="application/json",
custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}},
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
@staticmethod
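As a usage sketch, these presets are meant to be handed to the generation layer as override configs (a hypothetical flow; the exact call site may differ):

# Force JSON output on a Gemini model via the preset above
config = CommonOverrides.gemini_json()
params = config.to_api_params(api_type="gemini", model_name="gemini-2.0-flash")
# params now carries temperature=0.3 plus the JSON output mapping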

View File

@ -40,7 +40,8 @@ class LLMHttpClient:
async with self._lock:
if self._client is None or self._client.is_closed:
logger.debug(
f"LLMHttpClient: Initializing new httpx.AsyncClient with config: {self.config}"
f"LLMHttpClient: Initializing new httpx.AsyncClient "
f"with config: {self.config}"
)
headers = get_user_agent()
limits = httpx.Limits(
@ -52,11 +53,13 @@ class LLMHttpClient:
headers=headers,
limits=limits,
timeout=timeout,
proxy=self.config.proxy,
proxies=self.config.proxy,
follow_redirects=True,
)
if self._client is None:
raise LLMException("HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR)
raise LLMException(
"HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
)
return self._client
async def post(self, url: str, **kwargs: Any) -> httpx.Response:
@ -78,7 +81,8 @@ class LLMHttpClient:
)
if self._active_requests > 0:
logger.warning(
f"LLMHttpClient: Closing while {self._active_requests} requests are still active."
f"LLMHttpClient: Closing while {self._active_requests} "
f"requests are still active."
)
await self._client.aclose()
self._client = None
@ -96,7 +100,9 @@ class LLMHttpClientManager:
self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
self._lock = asyncio.Lock()
def _get_client_key(self, provider_config: ProviderConfig) -> tuple[int, str | None]:
def _get_client_key(
self, provider_config: ProviderConfig
) -> tuple[int, str | None]:
return (provider_config.timeout, provider_config.proxy)
async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
@ -104,15 +110,21 @@ class LLMHttpClientManager:
async with self._lock:
client = self._clients.get(key)
if client and not client.is_closed:
logger.debug(f"LLMHttpClientManager: Reusing existing LLMHttpClient for key: {key}")
logger.debug(
f"LLMHttpClientManager: Reusing existing LLMHttpClient "
f"for key: {key}"
)
return client
if client and client.is_closed:
logger.debug(
f"LLMHttpClientManager: Found a closed client for key {key}. Creating a new one."
f"LLMHttpClientManager: Found a closed client for key {key}. "
f"Creating a new one."
)
logger.debug(f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}")
logger.debug(
f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
)
http_client_config = HttpClientConfig(
timeout=provider_config.timeout, proxy=provider_config.proxy
)
@ -122,9 +134,14 @@ class LLMHttpClientManager:
async def shutdown(self):
async with self._lock:
logger.info(f"LLMHttpClientManager: Shutting down. Closing {len(self._clients)} client(s).")
logger.info(
f"LLMHttpClientManager: Shutting down. "
f"Closing {len(self._clients)} client(s)."
)
close_tasks = [
client.close() for client in self._clients.values() if client and not client.is_closed
client.close()
for client in self._clients.values()
if client and not client.is_closed
]
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
@ -183,11 +200,16 @@ async def with_smart_retry(
except LLMException as e:
last_exception = e
if e.code in [LLMErrorCode.API_KEY_INVALID, LLMErrorCode.API_QUOTA_EXCEEDED]:
if e.code in [
LLMErrorCode.API_KEY_INVALID,
LLMErrorCode.API_QUOTA_EXCEEDED,
]:
if hasattr(e, "details") and e.details and "api_key" in e.details:
failed_keys.add(e.details["api_key"])
if key_store and provider_name:
await key_store.record_failure(e.details["api_key"], e.details.get("status_code"))
await key_store.record_failure(
e.details["api_key"], e.details.get("status_code")
)
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
if not should_retry:
@ -198,7 +220,9 @@ async def with_smart_retry(
wait_time = config.retry_delay
if config.exponential_backoff:
wait_time *= 2**attempt
logger.warning(f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}")
logger.warning(
f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
)
await asyncio.sleep(wait_time)
else:
logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
@ -218,7 +242,9 @@ async def with_smart_retry(
raise RuntimeError("重试函数未能正常执行且未捕获到异常")
def _should_retry_llm_error(error: LLMException, attempt: int, max_retries: int) -> bool:
def _should_retry_llm_error(
error: LLMException, attempt: int, max_retries: int
) -> bool:
"""判断LLM错误是否应该重试"""
non_retryable_errors = {
LLMErrorCode.MODEL_NOT_FOUND,
@ -263,7 +289,10 @@ class KeyStatusStore:
self._lock = asyncio.Lock()
async def get_next_available_key(
self, provider_name: str, api_keys: list[str], exclude_keys: set[str] | None = None
self,
provider_name: str,
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""获取下一个可用的API密钥轮询策略"""
if not api_keys:
@ -271,7 +300,9 @@ class KeyStatusStore:
exclude_keys = exclude_keys or set()
available_keys = [
key for key in api_keys if key not in exclude_keys and self._key_status.get(key, True)
key
for key in api_keys
if key not in exclude_keys and self._key_status.get(key, True)
]
if not available_keys:
@ -282,11 +313,15 @@ class KeyStatusStore:
selected_key = available_keys[current_index % len(available_keys)]
self._provider_key_index[provider_name] = (current_index + 1) % len(available_keys)
self._provider_key_index[provider_name] = (current_index + 1) % len(
available_keys
)
import time
self._key_usage_count[selected_key] = self._key_usage_count.get(selected_key, 0) + 1
self._key_usage_count[selected_key] = (
self._key_usage_count.get(selected_key, 0) + 1
)
self._key_last_used[selected_key] = time.time()
logger.debug(
@ -308,7 +343,9 @@ class KeyStatusStore:
async with self._lock:
if status_code in [401, 403]:
self._key_status[api_key] = False
logger.warning(f"API密钥认证失败标记为不可用: {key_id} (状态码: {status_code})")
logger.warning(
f"API密钥认证失败标记为不可用: {key_id} (状态码: {status_code})"
)
else:
logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")

View File

@ -37,9 +37,13 @@ def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str |
return None, None
def _make_cache_key(provider_model_name: str | None, override_config: dict | None) -> str:
def _make_cache_key(
provider_model_name: str | None, override_config: dict | None
) -> str:
"""生成缓存键"""
config_str = json.dumps(override_config, sort_keys=True) if override_config else "None"
config_str = (
json.dumps(override_config, sort_keys=True) if override_config else "None"
)
key_data = f"{provider_model_name}:{config_str}"
return hashlib.md5(key_data.encode()).hexdigest()
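The cache key above is an MD5 digest over the model name plus a canonical JSON dump of the overrides; a standalone equivalent:

import hashlib
import json

def make_cache_key(model_name: str | None, overrides: dict | None) -> str:
    # sort_keys canonicalizes the dump so equal configs hash identically
    config_str = json.dumps(overrides, sort_keys=True) if overrides else "None"
    return hashlib.md5(f"{model_name}:{config_str}".encode()).hexdigest()

assert make_cache_key("Gemini/gemini-2.0-flash", {"a": 1, "b": 2}) == make_cache_key(
    "Gemini/gemini-2.0-flash", {"b": 2, "a": 1}
)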
@ -62,7 +66,9 @@ def _get_cached_model(cache_key: str) -> LLMModel | None:
)
model._is_closed = False
logger.debug(f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}")
logger.debug(
f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
)
return model
return None
@ -113,7 +119,10 @@ def get_configured_providers() -> list[ProviderConfig]:
ai_config = get_ai_config()
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers_raw, list):
logger.error(f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,将使用空列表。")
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
@ -128,7 +137,9 @@ def get_configured_providers() -> list[ProviderConfig]:
continue
if not item.get("api_key"):
logger.warning(f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。")
logger.warning(
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
)
continue
if "api_type" not in item or not item["api_type"]:
@ -159,7 +170,9 @@ def get_configured_providers() -> list[ProviderConfig]:
return valid_providers
def find_model_config(provider_name: str, model_name: str) -> tuple[ProviderConfig, ModelDetail] | None:
def find_model_config(
provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""在配置中查找指定的 Provider 和 ModelDetail
Args:
@ -194,7 +207,9 @@ def list_available_models() -> list[dict[str, Any]]:
"api_base": provider.api_base,
"is_available": model_detail.is_available,
"is_embedding_model": model_detail.is_embedding_model,
"available_identifiers": _get_model_identifiers(provider.name, model_detail),
"available_identifiers": _get_model_identifiers(
provider.name, model_detail
),
}
model_list.append(model_info)
return model_list
@ -242,7 +257,8 @@ async def get_model_instance(
if cached_model._generation_config != validated_override:
cached_model._generation_config = validated_override
logger.debug(
f"对缓存模型 {provider_model_name} 应用新的覆盖配置: {validated_override.to_dict()}"
f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
f"{validated_override.to_dict()}"
)
return cached_model
@ -252,21 +268,25 @@ async def get_model_instance(
if resolved_model_name_str is None:
available_models_list = list_available_models()
if not available_models_list:
raise LLMException("未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR)
raise LLMException(
"未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
)
resolved_model_name_str = available_models_list[0]["full_name"]
logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")
prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
if not prov_name_str or not mod_name_str:
raise LLMException(
f"无效的模型名称格式: '{resolved_model_name_str}'", code=LLMErrorCode.MODEL_NOT_FOUND
f"无效的模型名称格式: '{resolved_model_name_str}'",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
config_tuple_found = find_model_config(prov_name_str, mod_name_str)
if not config_tuple_found:
all_models = list_available_models()
raise LLMException(
f"未找到模型: '{resolved_model_name_str}'. 可用: {[m['full_name'] for m in all_models]}",
f"未找到模型: '{resolved_model_name_str}'. "
f"可用: {[m['full_name'] for m in all_models]}",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
@ -274,7 +294,11 @@ async def get_model_instance(
ai_config = get_ai_config()
global_proxy_setting = ai_config.get(PROXY_KEY)
default_timeout = provider_config_found.timeout if provider_config_found.timeout is not None else 180
default_timeout = (
provider_config_found.timeout
if provider_config_found.timeout is not None
else 180
)
global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)
config_for_http_client = ProviderConfig(
@ -304,16 +328,21 @@ async def get_model_instance(
validated_override_params = validate_override_params(override_config)
model_instance._generation_config = validated_override_params
logger.debug(
f"为新模型 {resolved_model_name_str} 应用配置覆盖: {validated_override_params.to_dict()}"
f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
f"{validated_override_params.to_dict()}"
)
_cache_model(cache_key, model_instance)
logger.debug(f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}")
logger.debug(
f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
)
return model_instance
except LLMException:
raise
except Exception as e:
logger.error(f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e)
logger.error(
f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
)
raise LLMException(
f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
code=LLMErrorCode.MODEL_INIT_FAILED,
@ -332,10 +361,14 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
if provider_model_name:
prov_name, mod_name = parse_provider_model_string(provider_model_name)
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
logger.error(f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。")
logger.error(
f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
)
return False
Config.set_config(AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True)
Config.set_config(
AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
)
if provider_model_name:
logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
else:
@ -350,10 +383,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
for provider in providers:
provider_stats = await key_store.get_key_stats(
[provider.api_key] if isinstance(provider.api_key, str) else provider.api_key
[provider.api_key]
if isinstance(provider.api_key, str)
else provider.api_key
)
stats[provider.name] = {
"total_keys": len([provider.api_key] if isinstance(provider.api_key, str) else provider.api_key),
"total_keys": len(
[provider.api_key]
if isinstance(provider.api_key, str)
else provider.api_key
),
"key_stats": provider_stats,
}
@ -375,7 +414,9 @@ async def reset_key_status(provider_name: str, api_key: str | None = None) -> bo
return False
provider_keys = (
[target_provider.api_key] if isinstance(target_provider.api_key, str) else target_provider.api_key
[target_provider.api_key]
if isinstance(target_provider.api_key, str)
else target_provider.api_key
)
if api_key:

View File

@ -89,7 +89,9 @@ class LLMModel(LLMModelBase):
self.api_type = provider_config.api_type
self.api_base = provider_config.api_base
self.api_keys = (
[provider_config.api_key] if isinstance(provider_config.api_key, str) else provider_config.api_key
[provider_config.api_key]
if isinstance(provider_config.api_key, str)
else provider_config.api_key
)
self.model_name = model_detail.model_name
self.temperature = model_detail.temperature
@ -101,9 +103,12 @@ class LLMModel(LLMModelBase):
"""获取HTTP客户端"""
if self.http_client.is_closed:
logger.debug(
f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,正在获取新的客户端"
f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,"
"正在获取新的客户端"
)
self.http_client = await http_client_manager.get_client(
self.provider_config
)
self.http_client = await http_client_manager.get_client(self.provider_config)
return self.http_client
async def _select_api_key(self, failed_keys: set[str] | None = None) -> str:
@ -122,7 +127,10 @@ class LLMModel(LLMModelBase):
raise LLMException(
f"提供商 {self.provider_name} 的所有API密钥当前都不可用",
code=LLMErrorCode.NO_AVAILABLE_KEYS,
details={"total_keys": len(self.api_keys), "failed_keys": len(failed_keys or set())},
details={
"total_keys": len(self.api_keys),
"failed_keys": len(failed_keys or set()),
},
)
return selected_key
@ -154,7 +162,9 @@ class LLMModel(LLMModelBase):
if http_response.status_code != 200:
error_text = http_response.text
logger.error(f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}")
logger.error(
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
error_code = LLMErrorCode.API_REQUEST_FAILED
@ -262,7 +272,9 @@ class LLMModel(LLMModelBase):
if http_response.status_code != 200:
error_text = http_response.text
logger.error(f"HTTP请求失败: {http_response.status_code} - {error_text}")
logger.error(
f"HTTP请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
@ -304,7 +316,10 @@ class LLMModel(LLMModelBase):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(f"无法将工具调用数据转换为LLMToolCall: {tc_data}, error: {e}")
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
@ -355,11 +370,16 @@ class LLMModel(LLMModelBase):
if self._is_closed:
return
self._is_closed = True
logger.debug(f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)")
logger.debug(
f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)"
)
async def __aenter__(self):
if self._is_closed:
logger.debug(f"Re-entering context for closed LLMModel {self}. Resetting _is_closed to False.")
logger.debug(
f"Re-entering context for closed LLMModel {self}. "
f"Resetting _is_closed to False."
)
self._is_closed = False
self._check_not_closed()
return self
@ -393,12 +413,15 @@ class LLMModel(LLMModelBase):
messages.append(LLMMessage.user(prompt))
model_fields = getattr(LLMGenerationConfig, "model_fields", {})
request_specific_config_dict = {
k: v for k, v in kwargs.items() if k in LLMGenerationConfig.model_fields
k: v for k, v in kwargs.items() if k in model_fields
}
request_specific_config = None
if request_specific_config_dict:
request_specific_config = LLMGenerationConfig(**request_specific_config_dict)
request_specific_config = LLMGenerationConfig(
**request_specific_config_dict
)
for key in request_specific_config_dict:
kwargs.pop(key, None)
@ -450,7 +473,8 @@ class LLMModel(LLMModelBase):
tools_dict = []
for tool in tools:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump(exclude_none=True))
model_dump_func = getattr(tool, "model_dump")
tools_dict.append(model_dump_func(exclude_none=True))
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
@ -497,7 +521,9 @@ class LLMModel(LLMModelBase):
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
tool_result = await tool_executor(tool_name, tool_args_dict)
logger.debug(f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}...")
logger.debug(
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
)
tool_response_messages.append(
LLMMessage.tool_response(
@ -508,13 +534,17 @@ class LLMModel(LLMModelBase):
)
except json.JSONDecodeError as e:
logger.error(
f"工具 '{tool_name}' 参数JSON解析失败: {tool_call.function.arguments}, 错误: {e}"
f"工具 '{tool_name}' 参数JSON解析失败: "
f"{tool_call.function.arguments}, 错误: {e}"
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={"error": "Argument JSON parsing failed", "details": str(e)},
result={
"error": "Argument JSON parsing failed",
"details": str(e),
},
)
)
except Exception as e:
@ -523,7 +553,10 @@ class LLMModel(LLMModelBase):
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={"error": "Tool execution failed", "details": str(e)},
result={
"error": "Tool execution failed",
"details": str(e),
},
)
)
@ -533,7 +566,10 @@ class LLMModel(LLMModelBase):
raise LLMException(
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
code=LLMErrorCode.GENERATION_FAILED,
details={"iterations": max_tool_iterations, "last_messages": current_messages[-2:]},
details={
"iterations": max_tool_iterations,
"last_messages": current_messages[-2:],
},
)
async def generate_embeddings(
@ -561,7 +597,9 @@ class LLMModel(LLMModelBase):
ai_config = get_ai_config()
default_max_retries = ai_config.get("max_retries_llm", 3)
default_retry_delay = ai_config.get("retry_delay_llm", 2)
max_retries_embed = kwargs.get("max_retries_embed", max(1, default_max_retries // 2))
max_retries_embed = kwargs.get(
"max_retries_embed", max(1, default_max_retries // 2)
)
retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2)
retry_config = RetryConfig(
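Schematically, the tool-calling flow these hunks reformat follows the usual generate / execute / re-generate loop (a sketch with assumed attribute and helper names, not the project's exact control flow):

async def run_with_tools(generate, execute_tool, messages: list, max_iters: int = 5):
    for _ in range(max_iters):
        response = await generate(messages)
        if not response.tool_calls:  # no further tool requests: final answer
            return response
        messages.append(response.to_assistant_message())  # assumed helper
        for call in response.tool_calls:
            messages.append(await execute_tool(call))  # tool-role result message
    raise RuntimeError("reached max tool iterations without a final reply")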

View File

@ -58,7 +58,9 @@ class LLMContentPart(BaseModel):
return cls(type="image", image_source=url)
@classmethod
def image_base64_part(cls, data: str, mime_type: str = "image/png") -> "LLMContentPart":
def image_base64_part(
cls, data: str, mime_type: str = "image/png"
) -> "LLMContentPart":
"""创建Base64图片内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="image", image_source=data_url)
@ -74,13 +76,17 @@ class LLMContentPart(BaseModel):
return cls(type="video", video_source=url, mime_type=mime_type)
@classmethod
def video_base64_part(cls, data: str, mime_type: str = "video/mp4") -> "LLMContentPart":
def video_base64_part(
cls, data: str, mime_type: str = "video/mp4"
) -> "LLMContentPart":
"""创建Base64视频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="video", video_source=data_url, mime_type=mime_type)
@classmethod
def audio_base64_part(cls, data: str, mime_type: str = "audio/wav") -> "LLMContentPart":
def audio_base64_part(
cls, data: str, mime_type: str = "audio/wav"
) -> "LLMContentPart":
"""创建Base64音频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="audio", audio_source=data_url, mime_type=mime_type)
@ -101,7 +107,9 @@ class LLMContentPart(BaseModel):
)
@classmethod
async def from_path(cls, path_like: str | Path, target_api: str | None = None) -> "LLMContentPart | None":
async def from_path(
cls, path_like: str | Path, target_api: str | None = None
) -> "LLMContentPart | None":
"""
Create an LLMContentPart from a local file path.
Auto-detects the MIME type and, depending on it (e.g. for images), may load the content as Base64.
@ -116,7 +124,9 @@ class LLMContentPart(BaseModel):
mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
if not mime_type:
logger.warning(f"无法猜测文件 {path.name} 的MIME类型将尝试作为文本文件处理。")
logger.warning(
f"无法猜测文件 {path.name} 的MIME类型将尝试作为文本文件处理。"
)
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
@ -131,17 +141,22 @@ class LLMContentPart(BaseModel):
async with aiofiles.open(path, "rb") as f:
img_bytes = await f.read()
base64_data = base64.b64encode(img_bytes).decode("utf-8")
return cls.image_base64_part(data=base64_data, mime_type=mime_type)
return cls.image_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
return None
else:
logger.warning(
f"为本地图片路径 {path.name} 生成 image_url_part。实际API可能不支持 file:// URI。考虑使用Base64或公网URL。" # noqa: E501
f"为本地图片路径 {path.name} 生成 image_url_part。"
"实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
)
return cls.image_url_part(url=path.resolve().as_uri())
elif mime_type.startswith("audio/"):
return cls.audio_url_part(url=path.resolve().as_uri(), mime_type=mime_type)
return cls.audio_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif mime_type.startswith("video/"):
if target_api == "gemini":
# For the Gemini API, convert the video to base64
@ -149,12 +164,16 @@ class LLMContentPart(BaseModel):
async with aiofiles.open(path, "rb") as f:
video_bytes = await f.read()
base64_data = base64.b64encode(video_bytes).decode("utf-8")
return cls.video_base64_part(data=base64_data, mime_type=mime_type)
return cls.video_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
return None
else:
return cls.video_url_part(url=path.resolve().as_uri(), mime_type=mime_type)
return cls.video_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif (
mime_type.startswith("text/")
or mime_type == "application/json"
@ -168,7 +187,9 @@ class LLMContentPart(BaseModel):
logger.error(f"读取文本类文件 {path.name} 失败: {e}")
return None
else:
logger.info(f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。")
logger.info(
f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
)
return cls.file_uri_part(
file_uri=path.resolve().as_uri(),
mime_type=mime_type,
@ -228,9 +249,13 @@ class LLMContentPart(BaseModel):
return {"inlineData": {"mimeType": mime_type, "data": data}}
else:
# Raise if the Base64 image data cannot be parsed
raise ValueError(f"无法解析Base64图像数据: {self.image_source[:50]}...")
raise ValueError(
f"无法解析Base64图像数据: {self.image_source[:50]}..."
)
else:
logger.warning(f"Gemini API需要Base64格式但提供的是URL: {self.image_source}")
logger.warning(
f"Gemini API需要Base64格式但提供的是URL: {self.image_source}"
)
return {
"inlineData": {
"mimeType": "image/jpeg",
@ -253,10 +278,14 @@ class LLMContentPart(BaseModel):
mime_type = header.split(";")[0].replace("data:", "")
return {"inlineData": {"mimeType": mime_type, "data": data}}
except (ValueError, IndexError):
raise ValueError(f"无法解析Base64视频数据: {self.video_source[:50]}...")
raise ValueError(
f"无法解析Base64视频数据: {self.video_source[:50]}..."
)
else:
# URLs and other formats cannot be inlined directly for now
raise ValueError("Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL")
raise ValueError(
"Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
)
else:
# Other APIs may not support video content
raise ValueError(f"API类型 '{api_type}' 不支持视频内容")
@ -273,17 +302,25 @@ class LLMContentPart(BaseModel):
mime_type = header.split(";")[0].replace("data:", "")
return {"inlineData": {"mimeType": mime_type, "data": data}}
except (ValueError, IndexError):
raise ValueError(f"无法解析Base64音频数据: {self.audio_source[:50]}...")
raise ValueError(
f"无法解析Base64音频数据: {self.audio_source[:50]}..."
)
else:
raise ValueError("Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL")
raise ValueError(
"Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
)
else:
raise ValueError(f"API类型 '{api_type}' 不支持音频内容")
elif self.type == "file":
if api_type == "gemini" and self.file_uri:
return {"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}}
return {
"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
}
elif self.file_source:
file_name = self.metadata.get("name", "file") if self.metadata else "file"
file_name = (
self.metadata.get("name", "file") if self.metadata else "file"
)
if api_type == "gemini":
return {"text": f"[文件: {file_name}]\n{self.file_source}"}
else:
@ -317,7 +354,8 @@ class LLMMessage(BaseModel):
raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
if self.role == "tool" and not isinstance(self.content, str):
logger.warning(
f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 将尝试转换为字符串。"
f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
"将尝试转换为字符串。"
)
try:
self.content = str(self.content)
@ -339,7 +377,9 @@ class LLMMessage(BaseModel):
return cls(role="assistant", content=content, tool_calls=tool_calls)
@classmethod
def assistant_text_response(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
def assistant_text_response(
cls, content: str | list[LLMContentPart]
) -> "LLMMessage":
"""创建助手纯文本回复的消息"""
return cls(role="assistant", content=content, tool_calls=None)
@ -356,10 +396,19 @@ class LLMMessage(BaseModel):
try:
content_str = json.dumps(result)
except TypeError as e:
logger.error(f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}")
content_str = json.dumps({"error": "Tool result not JSON serializable", "details": str(e)})
logger.error(
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
)
content_str = json.dumps(
{"error": "Tool result not JSON serializable", "details": str(e)}
)
return cls(role="tool", content=content_str, tool_call_id=tool_call_id, name=function_name)
return cls(
role="tool",
content=content_str,
tool_call_id=tool_call_id,
name=function_name,
)
@classmethod
def system(cls, content: str) -> "LLMMessage":
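Several branches above split data: URLs into a MIME type and a Base64 payload; the parsing they rely on reduces to this minimal standalone sketch:

def split_data_url(data_url: str) -> tuple[str, str]:
    # "data:image/png;base64,AAAA" -> ("image/png", "AAAA")
    header, data = data_url.split(",", 1)
    mime_type = header.split(";")[0].replace("data:", "")
    return mime_type, data

assert split_data_url("data:image/png;base64,iVBOR") == ("image/png", "iVBOR")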

View File

@ -36,11 +36,15 @@ class LLMException(Exception):
error_messages = {
LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到请检查配置或联系管理员。",
LLMErrorCode.API_KEY_INVALID: "API密钥无效请联系管理员更新配置。",
LLMErrorCode.API_QUOTA_EXCEEDED: "API使用配额已用尽请稍后再试或联系管理员。",
LLMErrorCode.API_QUOTA_EXCEEDED: (
"API使用配额已用尽请稍后再试或联系管理员。"
),
LLMErrorCode.API_TIMEOUT: "AI服务响应超时请稍后再试。",
LLMErrorCode.API_RATE_LIMITED: "请求过于频繁已被AI服务限流请稍后再试。",
LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败请联系管理员检查配置。",
LLMErrorCode.NO_AVAILABLE_KEYS: "当前所有API密钥均不可用请稍后再试或联系管理员。",
LLMErrorCode.NO_AVAILABLE_KEYS: (
"当前所有API密钥均不可用请稍后再试或联系管理员。"
),
LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
"当前地区暂不支持此AI服务请联系管理员或尝试其他模型。"
),

View File

@ -5,7 +5,16 @@ LLM 模块的工具和转换函数
import base64
from pathlib import Path
from nonebot_plugin_alconna.uniseg import At, File, Image, Reply, Text, UniMessage, Video, Voice
from nonebot_plugin_alconna.uniseg import (
At,
File,
Image,
Reply,
Text,
UniMessage,
Video,
Voice,
)
from zhenxun.services.log import logger
@ -29,7 +38,11 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
elif seg.url:
part = LLMContentPart.image_url_part(seg.url)
elif hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", "image/png") if hasattr(seg, "mimetype") else "image/png"
mime_type = (
getattr(seg, "mimetype", "image/png")
if hasattr(seg, "mimetype")
else "image/png"
)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
part = LLMContentPart.image_base64_part(b64_data, mime_type)
@ -38,8 +51,13 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
if seg.path:
part = await LLMContentPart.from_path(seg.path)
elif seg.url:
logger.warning(f"直接使用 URL 的 {type(seg).__name__}API 可能不支持: {seg.url}")
part = LLMContentPart.text_part(f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]")
logger.warning(
f"直接使用 URL 的 {type(seg).__name__} 段,"
f"API 可能不支持: {seg.url}"
)
part = LLMContentPart.text_part(
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
)
elif hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", None)
if isinstance(seg.raw, bytes):
@ -48,18 +66,29 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
if isinstance(seg, Video):
if not mime_type:
mime_type = "video/mp4"
part = LLMContentPart.video_base64_part(data=b64_data, mime_type=mime_type)
logger.debug(f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
part = LLMContentPart.video_base64_part(
data=b64_data, mime_type=mime_type
)
logger.debug(
f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
)
elif isinstance(seg, Voice):
if not mime_type:
mime_type = "audio/wav"
part = LLMContentPart.audio_base64_part(data=b64_data, mime_type=mime_type)
logger.debug(f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
part = LLMContentPart.audio_base64_part(
data=b64_data, mime_type=mime_type
)
logger.debug(
f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
)
else:
part = LLMContentPart.text_part(
f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
)
logger.debug(f"处理其他文件字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
logger.debug(
f"处理其他文件字节数据: {mime_type}, "
f"大小: {len(seg.raw)} bytes"
)
elif isinstance(seg, At):
if seg.flag == "all":
@ -76,7 +105,9 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
else:
reply_text = str(seg.msg).strip()
if reply_text:
part = LLMContentPart.text_part(f'[Replied to: "{reply_text[:50]}..."]')
part = LLMContentPart.text_part(
f'[Replied to: "{reply_text[:50]}..."]'
)
except Exception:
part = LLMContentPart.text_part("[Replied to a message]")
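The byte-handling branches above all reduce to base64-encoding raw segment data into a data: URL for API transport; a standalone sketch:

import base64

def to_data_url(raw: bytes, mime_type: str = "image/png") -> str:
    # Mirrors image_base64_part/video_base64_part: wrap raw bytes as a data: URL
    b64 = base64.b64encode(raw).decode("utf-8")
    return f"data:{mime_type};base64,{b64}"

assert to_data_url(b"\x89PNG", "image/png").startswith("data:image/png;base64,")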
@ -118,10 +149,14 @@ def create_multimodal_message(
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
# Text + multiple images
msg = create_multimodal_message("比较图片", images=["/path/1.jpg", "/path/2.jpg"])
msg = create_multimodal_message(
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
)
# Text + raw image bytes
msg = create_multimodal_message("分析", images=image_data, image_mimetypes="image/jpeg")
msg = create_multimodal_message(
"分析", images=image_data, image_mimetypes="image/jpeg"
)
# Text + video
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")