diff --git a/zhenxun/services/llm/adapters/base.py b/zhenxun/services/llm/adapters/base.py
index f94c22cd..499f9248 100644
--- a/zhenxun/services/llm/adapters/base.py
+++ b/zhenxun/services/llm/adapters/base.py
@@ -131,7 +131,9 @@ class BaseAdapter(ABC):
         pass
 
     @abstractmethod
-    def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
         """解析文本嵌入响应"""
         pass
 
@@ -145,7 +147,9 @@ class BaseAdapter(ABC):
                 else str(error_info)
             )
             raise LLMException(
-                f"嵌入API错误: {msg}", code=LLMErrorCode.EMBEDDING_FAILED, details=response_json
+                f"嵌入API错误: {msg}",
+                code=LLMErrorCode.EMBEDDING_FAILED,
+                details=response_json,
             )
 
     def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
@@ -170,7 +174,9 @@ class BaseAdapter(ABC):
             )
         return headers
 
-    def convert_messages_to_openai_format(self, messages: list["LLMMessage"]) -> list[dict[str, Any]]:
+    def convert_messages_to_openai_format(
+        self, messages: list["LLMMessage"]
+    ) -> list[dict[str, Any]]:
         """将LLMMessage转换为OpenAI格式 - 通用方法"""
         openai_messages: list[dict[str, Any]] = []
         for msg in messages:
@@ -190,7 +196,10 @@ class BaseAdapter(ABC):
                         content_parts.append({"type": "text", "text": part.text})
                     elif part.type == "image":
                         content_parts.append(
-                            {"type": "image_url", "image_url": {"url": part.image_source}}
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": part.image_source},
+                            }
                         )
 
                 openai_msg["content"] = content_parts
@@ -247,9 +256,13 @@ class BaseAdapter(ABC):
                             )
                         )
                     except KeyError as e:
-                        logger.warning(f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}")
+                        logger.warning(
+                            f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
+                        )
                     except Exception as e:
-                        logger.warning(f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}")
+                        logger.warning(
+                            f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
+                        )
 
                 if not parsed_tool_calls:
                     parsed_tool_calls = None
@@ -268,7 +281,11 @@ class BaseAdapter(ABC):
 
         except Exception as e:
             logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
-            raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e)
+            raise LLMException(
+                f"解析API响应失败: {e}",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                cause=e,
+            )
 
     def validate_response(self, response_json: dict[str, Any]) -> None:
         """验证API响应,解析不同API的错误结构"""
@@ -291,9 +308,14 @@ class BaseAdapter(ABC):
                     "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                 }
 
-                llm_error_code = error_code_mapping.get(error_code, LLMErrorCode.API_RESPONSE_INVALID)
+                llm_error_code = error_code_mapping.get(
+                    error_code, LLMErrorCode.API_RESPONSE_INVALID
+                )
 
-                logger.error(f"API返回错误: {error_message} (代码: {error_code}, 类型: {error_type})")
+                logger.error(
+                    f"API返回错误: {error_message} "
+                    f"(代码: {error_code}, 类型: {error_type})"
+                )
             else:
                 error_message = str(error_info)
                 error_code = "unknown"
@@ -314,7 +336,10 @@ class BaseAdapter(ABC):
                 finish_reason = candidate.get("finishReason")
                 if finish_reason in ["SAFETY", "RECITATION"]:
                     safety_ratings = candidate.get("safetyRatings", [])
-                    logger.warning(f"Gemini内容被安全过滤: {finish_reason}, 安全评级: {safety_ratings}")
+                    logger.warning(
+                        f"Gemini内容被安全过滤: {finish_reason}, "
+                        f"安全评级: {safety_ratings}"
+                    )
                     raise LLMException(
                         f"内容被安全过滤: {finish_reason}",
                         code=LLMErrorCode.CONTENT_FILTERED,
@@ -342,7 +367,9 @@ class BaseAdapter(ABC):
             return config.to_api_params(model.api_type, model.model_name)
 
         if model._generation_config is not None:
-            return model._generation_config.to_api_params(model.api_type, model.model_name)
+            return model._generation_config.to_api_params(
+                model.api_type, model.model_name
+            )
 
         base_config = {}
         if model.temperature is not None:
@@ -417,6 +444,7 @@ class OpenAICompatAdapter(BaseAdapter):
         is_advanced: bool = False,
     ) -> ResponseData:
         """解析响应 - 直接使用基类的 OpenAI 格式解析"""
+        _ = model, is_advanced  # 未使用的参数
         return self.parse_openai_response(response_json)
 
     def prepare_embedding_request(
@@ -428,6 +456,7 @@ class OpenAICompatAdapter(BaseAdapter):
         **kwargs: Any,
     ) -> RequestData:
         """准备嵌入请求 - OpenAI兼容格式"""
+        _ = task_type  # 未使用的参数
         url = self.get_api_url(model, self.get_embedding_endpoint())
         headers = self.get_base_headers(api_key)
 
@@ -442,7 +471,9 @@ class OpenAICompatAdapter(BaseAdapter):
 
         return RequestData(url=url, headers=headers, body=body)
 
-    def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
         """解析嵌入响应 - OpenAI兼容格式"""
 
         self.validate_embedding_response(response_json)
diff --git a/zhenxun/services/llm/adapters/gemini.py b/zhenxun/services/llm/adapters/gemini.py
index 3c6a7681..0ca22185 100644
--- a/zhenxun/services/llm/adapters/gemini.py
+++ b/zhenxun/services/llm/adapters/gemini.py
@@ -48,7 +48,9 @@ class GeminiAdapter(BaseAdapter):
         tool_choice: str | dict[str, Any] | None = None,
     ) -> RequestData:
         """准备高级请求"""
-        return self._prepare_request(model, api_key, messages, config, tools, tool_choice)
+        return self._prepare_request(
+            model, api_key, messages, config, tools, tool_choice
+        )
 
     def _prepare_request(
         self,
@@ -75,7 +77,9 @@ class GeminiAdapter(BaseAdapter):
                 if isinstance(msg.content, str):
                     system_instruction_parts = [{"text": msg.content}]
                 elif isinstance(msg.content, list):
-                    system_instruction_parts = [part.convert_for_api("gemini") for part in msg.content]
+                    system_instruction_parts = [
+                        part.convert_for_api("gemini") for part in msg.content
+                    ]
                 continue
 
             elif msg.role == "user":
@@ -115,12 +119,21 @@ class GeminiAdapter(BaseAdapter):
                 import json
 
                 try:
-                    content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
+                    content_str = (
+                        msg.content
+                        if isinstance(msg.content, str)
+                        else str(msg.content)
+                    )
                     tool_result_obj = json.loads(content_str)
                 except json.JSONDecodeError:
-                    content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
+                    content_str = (
+                        msg.content
+                        if isinstance(msg.content, str)
+                        else str(msg.content)
+                    )
                     logger.warning(
-                        f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. 包装为原始字符串。"
+                        f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. "
+                        f"包装为原始字符串。"
                     )
                     tool_result_obj = {"raw_output": content_str}
 
@@ -144,7 +157,9 @@ class GeminiAdapter(BaseAdapter):
             for tool_item in tools:
                 if isinstance(tool_item, dict):
                     if "name" in tool_item and "description" in tool_item:
-                        all_tools_for_request.append({"functionDeclarations": [tool_item]})
+                        all_tools_for_request.append(
+                            {"functionDeclarations": [tool_item]}
+                        )
                     else:
                         all_tools_for_request.append(tool_item)
                 else:
@@ -152,7 +167,9 @@ class GeminiAdapter(BaseAdapter):
 
         if effective_config:
             if getattr(effective_config, "enable_grounding", False):
-                has_explicit_gs_tool = any("googleSearch" in tool_item for tool_item in all_tools_for_request)
+                has_explicit_gs_tool = any(
+                    "googleSearch" in tool_item for tool_item in all_tools_for_request
+                )
                 if not has_explicit_gs_tool:
                     all_tools_for_request.append({"googleSearch": {}})
                     logger.debug("隐式启用 Google Search 工具进行信息来源关联。")
@@ -166,7 +183,9 @@ class GeminiAdapter(BaseAdapter):
                     logger.debug("隐式启用代码执行工具。")
 
         if all_tools_for_request:
-            gemini_api_tools = self._convert_tools_to_gemini_format(all_tools_for_request)
+            gemini_api_tools = self._convert_tools_to_gemini_format(
+                all_tools_for_request
+            )
             if gemini_api_tools:
                 body["tools"] = gemini_api_tools
 
@@ -180,11 +199,17 @@ class GeminiAdapter(BaseAdapter):
                 if mode_upper in ["AUTO", "NONE", "ANY"]:
                     body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
                 else:
-                    body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice)
+                    body["toolConfig"] = self._convert_tool_choice_to_gemini(
+                        final_tool_choice
+                    )
             else:
-                body["toolConfig"] = self._convert_tool_choice_to_gemini(final_tool_choice)
+                body["toolConfig"] = self._convert_tool_choice_to_gemini(
+                    final_tool_choice
+                )
 
-        final_generation_config = self._build_gemini_generation_config(model, effective_config)
+        final_generation_config = self._build_gemini_generation_config(
+            model, effective_config
+        )
         if final_generation_config:
             body["generationConfig"] = final_generation_config
 
@@ -203,7 +228,9 @@ class GeminiAdapter(BaseAdapter):
         """应用配置覆盖 - Gemini 不需要额外的配置覆盖"""
         return body
 
-    def _get_gemini_endpoint(self, model: "LLMModel", config: "LLMGenerationConfig | None" = None) -> str:
+    def _get_gemini_endpoint(
+        self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
+    ) -> str:
         """根据配置选择Gemini API端点"""
         if config:
             if getattr(config, "enable_code_execution", False):
@@ -214,7 +241,9 @@ class GeminiAdapter(BaseAdapter):
 
         return f"/v1beta/models/{model.model_name}:generateContent"
 
-    def _convert_tools_to_gemini_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    def _convert_tools_to_gemini_format(
+        self, tools: list[dict[str, Any]]
+    ) -> list[dict[str, Any]]:
         """转换工具格式为Gemini格式"""
         gemini_tools = []
 
@@ -232,7 +261,9 @@ class GeminiAdapter(BaseAdapter):
                 }
                 gemini_tools.append(gemini_tool)
             elif tool.get("type") == "code_execution":
-                gemini_tools.append({"codeExecution": {"language": tool.get("language", "python")}})
+                gemini_tools.append(
+                    {"codeExecution": {"language": tool.get("language", "python")}}
+                )
             elif tool.get("type") == "google_search":
                 gemini_tools.append({"googleSearch": {}})
             elif "googleSearch" in tool:
@@ -242,18 +273,26 @@ class GeminiAdapter(BaseAdapter):
 
         return gemini_tools
 
-    def _convert_tool_choice_to_gemini(self, tool_choice_value: str | dict[str, Any]) -> dict[str, Any]:
+    def _convert_tool_choice_to_gemini(
+        self, tool_choice_value: str | dict[str, Any]
+    ) -> dict[str, Any]:
         """转换工具选择策略为Gemini格式"""
         if isinstance(tool_choice_value, str):
             mode_upper = tool_choice_value.upper()
             if mode_upper in ["AUTO", "NONE", "ANY"]:
                 return {"functionCallingConfig": {"mode": mode_upper}}
             else:
-                logger.warning(f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。回退到 AUTO。")
+                logger.warning(
+                    f"不支持的 tool_choice 字符串值: '{tool_choice_value}'。"
+                    f"回退到 AUTO。"
+                )
                 return {"functionCallingConfig": {"mode": "AUTO"}}
         elif isinstance(tool_choice_value, dict):
-            if tool_choice_value.get("type") == "function" and "function" in tool_choice_value:
+            if (
+                tool_choice_value.get("type") == "function"
+                and "function" in tool_choice_value
+            ):
                 func_name = tool_choice_value["function"].get("name")
                 if func_name:
                     return {
@@ -263,17 +302,26 @@ class GeminiAdapter(BaseAdapter):
                         }
                     }
                 else:
-                    logger.warning(f"tool_choice dict 中的函数名无效: {tool_choice_value}。回退到 AUTO。")
+                    logger.warning(
+                        f"tool_choice dict 中的函数名无效: {tool_choice_value}。"
+                        f"回退到 AUTO。"
+                    )
                     return {"functionCallingConfig": {"mode": "AUTO"}}
             elif "functionCallingConfig" in tool_choice_value:
-                return {"functionCallingConfig": tool_choice_value["functionCallingConfig"]}
+                return {
+                    "functionCallingConfig": tool_choice_value["functionCallingConfig"]
+                }
             else:
-                logger.warning(f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。")
+                logger.warning(
+                    f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。"
+                )
                 return {"functionCallingConfig": {"mode": "AUTO"}}
 
-        logger.warning(f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。")
+        logger.warning(
+            f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。"
+        )
         return {"functionCallingConfig": {"mode": "AUTO"}}
 
     def _build_gemini_generation_config(
@@ -285,11 +333,15 @@ class GeminiAdapter(BaseAdapter):
         effective_config = config if config is not None else model._generation_config
 
         if effective_config:
-            base_api_params = effective_config.to_api_params(api_type="gemini", model_name=model.model_name)
+            base_api_params = effective_config.to_api_params(
+                api_type="gemini", model_name=model.model_name
+            )
             generation_config.update(base_api_params)
 
             if getattr(effective_config, "response_mime_type", None):
-                generation_config["responseMimeType"] = effective_config.response_mime_type
+                generation_config["responseMimeType"] = (
+                    effective_config.response_mime_type
+                )
 
             if getattr(effective_config, "response_schema", None):
                 generation_config["responseSchema"] = effective_config.response_schema
@@ -303,15 +355,22 @@ class GeminiAdapter(BaseAdapter):
             if getattr(effective_config, "response_modalities", None):
                 modalities = effective_config.response_modalities
                 if isinstance(modalities, list):
-                    generation_config["responseModalities"] = [m.upper() for m in modalities]
+                    generation_config["responseModalities"] = [
+                        m.upper() for m in modalities
+                    ]
                 elif isinstance(modalities, str):
                     generation_config["responseModalities"] = [modalities.upper()]
 
-        generation_config = {k: v for k, v in generation_config.items() if v is not None}
+        generation_config = {
+            k: v for k, v in generation_config.items() if v is not None
+        }
 
         if generation_config:
            param_keys = list(generation_config.keys())
-            logger.debug(f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: {param_keys}")
+            logger.debug(
+                f"构建Gemini生成配置完成,包含 {len(generation_config)} 个参数: "
+                f"{param_keys}"
+            )
 
         return generation_config
 
@@ -337,7 +396,9 @@ class GeminiAdapter(BaseAdapter):
                 safety_settings.append({"category": category, "threshold": threshold})
         else:
             for category in safety_categories:
-                safety_settings.append({"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"})
+                safety_settings.append(
+                    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
+                )
 
         return safety_settings if safety_settings else None
 
@@ -368,12 +429,18 @@ class GeminiAdapter(BaseAdapter):
 
             candidate = candidates[0]
 
-            if candidate.get("finishReason") in ["RECITATION", "OTHER"] and not candidate.get("content"):
+            if candidate.get("finishReason") in [
+                "RECITATION",
+                "OTHER",
+            ] and not candidate.get("content"):
                 logger.warning(
-                    f"Gemini candidate finished with reason '{candidate.get('finishReason')}' and no content."
+                    f"Gemini candidate finished with reason "
+                    f"'{candidate.get('finishReason')}' and no content."
                 )
                 return ResponseData(
-                    text="", raw_response=response_json, usage_info=response_json.get("usageMetadata")
+                    text="",
+                    raw_response=response_json,
+                    usage_info=response_json.get("usageMetadata"),
                 )
 
             content_data = candidate.get("content", {})
@@ -394,9 +461,10 @@ class GeminiAdapter(BaseAdapter):
 
                         from ..types.models import LLMToolCall, LLMToolFunction
 
+                        call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
                         parsed_tool_calls.append(
                             LLMToolCall(
-                                id=f"call_{model.provider_name}_{len(parsed_tool_calls)}",
+                                id=call_id,
                                 function=LLMToolFunction(
                                     name=fc_data["name"],
                                     arguments=json.dumps(fc_data["args"]),
@@ -404,16 +472,22 @@ class GeminiAdapter(BaseAdapter):
                                 )
                             )
                     except KeyError as e:
-                        logger.warning(f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}")
+                        logger.warning(
+                            f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}"
+                        )
                     except Exception as e:
-                        logger.warning(f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}")
+                        logger.warning(
+                            f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
+                        )
                 elif "codeExecutionResult" in part:
                     result = part["codeExecutionResult"]
                     if result.get("outcome") == "OK":
                         output = result.get("output", "")
                         text_content += f"\n[代码执行结果]:\n{output}\n"
                     else:
-                        text_content += f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
+                        text_content += (
+                            f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
+                        )
 
             usage_info = response_json.get("usageMetadata")
 
@@ -436,7 +510,11 @@ class GeminiAdapter(BaseAdapter):
 
         except Exception as e:
             logger.error(f"解析 Gemini 响应失败: {e}", e=e)
-            raise LLMException(f"解析API响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e)
+            raise LLMException(
+                f"解析API响应失败: {e}",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                cause=e,
+            )
 
     def prepare_embedding_request(
         self,
@@ -474,7 +552,9 @@ class GeminiAdapter(BaseAdapter):
         body = {"requests": requests_payload}
         return RequestData(url=url, headers=headers, body=body)
 
-    def parse_embedding_response(self, response_json: dict[str, Any]) -> list[list[float]]:
+    def parse_embedding_response(
+        self, response_json: dict[str, Any]
+    ) -> list[list[float]]:
         """解析文本嵌入响应"""
         try:
             embeddings_data = response_json["embeddings"]
@@ -482,18 +562,26 @@ class GeminiAdapter(BaseAdapter):
         except KeyError as e:
             logger.error(f"解析Gemini嵌入响应时缺少键: {e}. 响应: {response_json}")
             raise LLMException(
-                "Gemini嵌入响应格式错误", code=LLMErrorCode.RESPONSE_PARSE_ERROR, details={"error": str(e)}
+                "Gemini嵌入响应格式错误",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                details={"error": str(e)},
             )
         except Exception as e:
-            logger.error(f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}")
+            logger.error(
+                f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}"
+            )
             raise LLMException(
-                f"解析Gemini嵌入响应失败: {e}", code=LLMErrorCode.RESPONSE_PARSE_ERROR, cause=e
+                f"解析Gemini嵌入响应失败: {e}",
+                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
+                cause=e,
             )
 
     def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
         """验证嵌入响应"""
         super().validate_embedding_response(response_json)
-        if "embeddings" not in response_json or not isinstance(response_json["embeddings"], list):
+        if "embeddings" not in response_json or not isinstance(
+            response_json["embeddings"], list
+        ):
             raise LLMException(
                 "Gemini嵌入响应缺少'embeddings'字段或格式不正确",
                 code=LLMErrorCode.RESPONSE_PARSE_ERROR,
diff --git a/zhenxun/services/llm/api.py b/zhenxun/services/llm/api.py
index a4ffe90f..c722b17c 100644
--- a/zhenxun/services/llm/api.py
+++ b/zhenxun/services/llm/api.py
@@ -91,7 +91,9 @@ class AI:
 
         if isinstance(message, str):
             llm_messages = [LLMMessage.user(message)]
-        elif isinstance(message, list) and all(isinstance(part, LLMContentPart) for part in message):
+        elif isinstance(message, list) and all(
+            isinstance(part, LLMContentPart) for part in message
+        ):
             llm_messages = [LLMMessage.user(message)]
         elif isinstance(message, LLMMessage):
             llm_messages = [message]
@@ -103,7 +105,9 @@ class AI:
                 code=LLMErrorCode.API_REQUEST_FAILED,
             )
 
-        response = await self._execute_generation(llm_messages, model, "聊天失败", kwargs)
+        response = await self._execute_generation(
+            llm_messages, model, "聊天失败", kwargs
+        )
         return response.text
 
     async def code(
@@ -135,7 +139,12 @@ class AI:
         }
 
     async def search(
-        self, query: str | UniMessage, *, model: ModelName = None, instruction: str = "", **kwargs: Any
+        self,
+        query: str | UniMessage,
+        *,
+        model: ModelName = None,
+        instruction: str = "",
+        **kwargs: Any,
     ) -> dict[str, Any]:
         """信息搜索 - 支持多模态输入"""
         resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
@@ -154,7 +163,9 @@ class AI:
                 if instruction:
                     final_messages.append(LLMMessage.user(instruction))
                 else:
-                    raise LLMException("搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED)
+                    raise LLMException(
+                        "搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
+                    )
             else:
                 final_messages.append(LLMMessage.user(content_parts))
 
@@ -206,7 +217,9 @@ class AI:
                 if instruction:
                     final_messages.append(LLMMessage.user(instruction))
                 else:
-                    raise LLMException("分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED)
+                    raise LLMException(
+                        "分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
+                    )
             else:
                 final_messages.append(LLMMessage.user(content_parts))
 
@@ -241,7 +254,12 @@ class AI:
                 tool_choice = "none"
 
         response = await self._execute_generation(
-            final_messages, model, "内容分析失败", kwargs, llm_tools=llm_tools, tool_choice=tool_choice
+            final_messages,
+            model,
+            "内容分析失败",
+            kwargs,
+            llm_tools=llm_tools,
+            tool_choice=tool_choice,
         )
 
         if response.tool_calls:
@@ -260,8 +278,12 @@ class AI:
     ) -> LLMResponse:
         """通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
         try:
-            resolved_model_name = self._resolve_model_name(model_name or self.config.model)
-            final_config_dict = self._merge_config(config_overrides, base_config=base_config)
+            resolved_model_name = self._resolve_model_name(
+                model_name or self.config.model
+            )
+            final_config_dict = self._merge_config(
+                config_overrides, base_config=base_config
+            )
 
             async with await get_model_instance(
                 resolved_model_name, override_config=final_config_dict
@@ -316,7 +338,9 @@ class AI:
         if self.config.enable_gemini_thinking:
             final_config["thinking_budget"] = 0.8
         if self.config.enable_gemini_safe_mode:
-            final_config["safety_settings"] = CommonOverrides.gemini_safe().safety_settings
+            final_config["safety_settings"] = (
+                CommonOverrides.gemini_safe().safety_settings
+            )
         if self.config.enable_gemini_multimodal:
             final_config.update(CommonOverrides.gemini_multimodal().to_dict())
         if self.config.enable_gemini_grounding:
@@ -341,10 +365,13 @@ class AI:
             return []
 
         try:
-            resolved_model_str = model or self.config.default_embedding_model or self.config.model
+            resolved_model_str = (
+                model or self.config.default_embedding_model or self.config.model
+            )
             if not resolved_model_str:
                 raise LLMException(
-                    "使用 embed 功能时必须指定嵌入模型名称,或在 AIConfig 中配置 default_embedding_model。",
+                    "使用 embed 功能时必须指定嵌入模型名称,"
+                    "或在 AIConfig 中配置 default_embedding_model。",
                     code=LLMErrorCode.MODEL_NOT_FOUND,
                 )
             resolved_model_str = self._resolve_model_name(resolved_model_str)
@@ -360,7 +387,9 @@ class AI:
             raise
         except Exception as e:
             logger.error(f"文本嵌入失败: {e}", e=e)
-            raise LLMException(f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e)
+            raise LLMException(
+                f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
+            )
 
 
 async def chat(
@@ -443,7 +472,9 @@ async def analyze_multimodal(
     **kwargs: Any,
 ) -> str | LLMResponse:
     """多模态分析便捷函数"""
-    message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios)
+    message = create_multimodal_message(
+        text=text, images=images, videos=videos, audios=audios
+    )
     return await analyze(message, instruction=instruction, model=model, **kwargs)
 
 
@@ -458,7 +489,9 @@ async def search_multimodal(
     **kwargs: Any,
 ) -> dict[str, Any]:
     """多模态搜索便捷函数"""
-    message = create_multimodal_message(text=text, images=images, videos=videos, audios=audios)
+    message = create_multimodal_message(
+        text=text, images=images, videos=videos, audios=audios
+    )
     ai = AI()
     return await ai.search(message, model=model, instruction=instruction, **kwargs)
diff --git a/zhenxun/services/llm/config/generation.py b/zhenxun/services/llm/config/generation.py
index a7ad9171..a143dedd 100644
--- a/zhenxun/services/llm/config/generation.py
+++ b/zhenxun/services/llm/config/generation.py
@@ -15,27 +15,47 @@ from ..types.exceptions import LLMErrorCode, LLMException
 class ModelConfigOverride(BaseModel):
     """模型配置覆盖参数"""
 
-    temperature: float | None = Field(default=None, ge=0.0, le=2.0, description="生成温度")
+    temperature: float | None = Field(
+        default=None, ge=0.0, le=2.0, description="生成温度"
+    )
     max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数")
     top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数")
     top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数")
-    frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="频率惩罚")
-    presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0, description="存在惩罚")
-    repetition_penalty: float | None = Field(default=None, ge=0.0, le=2.0, description="重复惩罚")
+    frequency_penalty: float | None = Field(
+        default=None, ge=-2.0, le=2.0, description="频率惩罚"
+    )
+    presence_penalty: float | None = Field(
+        default=None, ge=-2.0, le=2.0, description="存在惩罚"
+    )
+    repetition_penalty: float | None = Field(
+        default=None, ge=0.0, le=2.0, description="重复惩罚"
+    )
     stop: list[str] | str | None = Field(default=None, description="停止序列")
     response_format: ResponseFormat | dict[str, Any] | None = Field(
         default=None, description="期望的响应格式"
     )
-    response_mime_type: str | None = Field(default=None, description="响应MIME类型(Gemini专用)")
-    response_schema: dict[str, Any] | None = Field(default=None, description="JSON响应模式")
-    thinking_budget: float | None = Field(default=None, ge=0.0, le=1.0, description="思考预算")
+    response_mime_type: str | None = Field(
+        default=None, description="响应MIME类型(Gemini专用)"
+    )
+    response_schema: dict[str, Any] | None = Field(
+        default=None, description="JSON响应模式"
+    )
+    thinking_budget: float | None = Field(
+        default=None, ge=0.0, le=1.0, description="思考预算"
+    )
     safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
-    response_modalities: list[str] | None = Field(default=None, description="响应模态类型")
+    response_modalities: list[str] | None = Field(
+        default=None, description="响应模态类型"
+    )
 
-    enable_code_execution: bool | None = Field(default=None, description="是否启用代码执行")
-    enable_grounding: bool | None = Field(default=None, description="是否启用信息来源关联")
+    enable_code_execution: bool | None = Field(
+        default=None, description="是否启用代码执行"
+    )
+    enable_grounding: bool | None = Field(
+        default=None, description="是否启用信息来源关联"
+    )
     enable_caching: bool | None = Field(default=None, description="是否启用响应缓存")
     custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")
 
@@ -43,7 +63,16 @@ class ModelConfigOverride(BaseModel):
     def to_dict(self) -> dict[str, Any]:
         """转换为字典,排除None值"""
         result = {}
-        for key, value in self.model_dump().items():
+        model_data = getattr(self, "model_dump", lambda: {})()
+        if not model_data:
+            model_data = {}
+            for field_name, _ in self.__class__.__dict__.get(
+                "model_fields", {}
+            ).items():
+                value = getattr(self, field_name, None)
+                if value is not None:
+                    model_data[field_name] = value
+        for key, value in model_data.items():
             if value is not None:
                 if key == "custom_params" and isinstance(value, dict):
                     result.update(value)
@@ -110,13 +139,14 @@ class LLMGenerationConfig(ModelConfigOverride):
             else:
                 params["repetition_penalty"] = self.repetition_penalty
 
-        # 处理 response_format 参数
         if self.response_format is not None:
             if isinstance(self.response_format, dict):
-                # 直接使用字典格式的 response_format(如 {'type': 'json_object'})
                 if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
                     params["response_format"] = self.response_format
-                    logger.debug(f"为 {api_type} 使用自定义 response_format: {self.response_format}")
+                    logger.debug(
+                        f"为 {api_type} 使用自定义 response_format: "
+                        f"{self.response_format}"
+                    )
             elif self.response_format == ResponseFormat.JSON:
                 if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
                     params["response_format"] = {"type": "json_object"}
@@ -128,9 +158,14 @@ class LLMGenerationConfig(ModelConfigOverride):
                     logger.debug(f"为 {api_type} 启用 JSON MIME 类型输出模式")
 
         if api_type in ["gemini", "gemini_native"]:
-            if self.response_format != ResponseFormat.JSON and self.response_mime_type is not None:
+            if (
+                self.response_format != ResponseFormat.JSON
+                and self.response_mime_type is not None
+            ):
                 params["responseMimeType"] = self.response_mime_type
-                logger.debug(f"使用显式设置的 responseMimeType: {self.response_mime_type}")
+                logger.debug(
+                    f"使用显式设置的 responseMimeType: {self.response_mime_type}"
+                )
 
             if self.response_schema is not None and "responseSchema" not in params:
                 params["responseSchema"] = self.response_schema
@@ -158,7 +193,9 @@ def validate_override_params(
 
     if isinstance(override_config, dict):
         try:
-            filtered_config = {k: v for k, v in override_config.items() if v is not None}
+            filtered_config = {
+                k: v for k, v in override_config.items() if v is not None
+            }
             return LLMGenerationConfig(**filtered_config)
         except Exception as e:
             logger.warning(f"覆盖配置参数验证失败: {e}")
@@ -171,7 +208,9 @@ def validate_override_params(
     return override_config
 
 
-def apply_api_specific_mappings(params: dict[str, Any], api_type: str) -> dict[str, Any]:
+def apply_api_specific_mappings(
+    params: dict[str, Any], api_type: str
+) -> dict[str, Any]:
     """应用API特定的参数映射"""
     mapped_params = params.copy()
 
@@ -204,7 +243,8 @@ def apply_api_specific_mappings(
 
 def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
     """从关键字参数创建生成配置"""
-    known_fields = set(LLMGenerationConfig.model_fields.keys())
+    model_fields = getattr(LLMGenerationConfig, "model_fields", {})
+    known_fields = set(model_fields.keys())
 
     known_params = {}
     custom_params = {}
diff --git a/zhenxun/services/llm/config/presets.py b/zhenxun/services/llm/config/presets.py
index 04a72dab..7a6023d5 100644
--- a/zhenxun/services/llm/config/presets.py
+++ b/zhenxun/services/llm/config/presets.py
@@ -30,17 +30,25 @@ class CommonOverrides:
     @staticmethod
     def concise(max_tokens: int = 100) -> LLMGenerationConfig:
         """简洁模式:限制输出长度"""
-        return LLMGenerationConfig(temperature=0.3, max_tokens=max_tokens, stop=["\n\n", "。", "!", "?"])
+        return LLMGenerationConfig(
+            temperature=0.3,
+            max_tokens=max_tokens,
+            stop=["\n\n", "。", "!", "?"],
+        )
 
     @staticmethod
     def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
         """详细模式:鼓励详细输出"""
-        return LLMGenerationConfig(temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1)
+        return LLMGenerationConfig(
+            temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
+        )
 
     @staticmethod
     def gemini_json() -> LLMGenerationConfig:
         """Gemini JSON模式:强制JSON输出"""
-        return LLMGenerationConfig(temperature=0.3, response_mime_type="application/json")
+        return LLMGenerationConfig(
+            temperature=0.3, response_mime_type="application/json"
+        )
 
     @staticmethod
     def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
@@ -96,7 +104,9 @@ class CommonOverrides:
             temperature=0.5,
             max_tokens=4096,
             enable_grounding=True,
-            custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}},
+            custom_params={
+                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
+            },
         )
 
     @staticmethod
@@ -119,7 +129,9 @@ class CommonOverrides:
             enable_caching=True,
             custom_params={
                 "code_execution_timeout": 30,
-                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}},
+                "grounding_config": {
+                    "dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
+                },
             },
         )
 
@@ -132,7 +144,9 @@ class CommonOverrides:
             thinking_budget=0.8,
             enable_grounding=True,
             response_mime_type="application/json",
-            custom_params={"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}},
+            custom_params={
+                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
+            },
         )
 
     @staticmethod
diff --git a/zhenxun/services/llm/core.py b/zhenxun/services/llm/core.py
index 1c1c67aa..ffd900cf 100644
--- a/zhenxun/services/llm/core.py
+++ b/zhenxun/services/llm/core.py
@@ -40,7 +40,8 @@ class LLMHttpClient:
         async with self._lock:
             if self._client is None or self._client.is_closed:
                 logger.debug(
-                    f"LLMHttpClient: Initializing new httpx.AsyncClient with config: {self.config}"
+                    f"LLMHttpClient: Initializing new httpx.AsyncClient "
+                    f"with config: {self.config}"
                 )
                 headers = get_user_agent()
                 limits = httpx.Limits(
@@ -52,11 +53,13 @@ class LLMHttpClient:
                     headers=headers,
                     limits=limits,
                     timeout=timeout,
-                    proxy=self.config.proxy,
+                    proxies=self.config.proxy,
                     follow_redirects=True,
                 )
         if self._client is None:
-            raise LLMException("HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR)
+            raise LLMException(
+                "HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
+            )
         return self._client
 
     async def post(self, url: str, **kwargs: Any) -> httpx.Response:
@@ -78,7 +81,8 @@ class LLMHttpClient:
             )
             if self._active_requests > 0:
                 logger.warning(
-                    f"LLMHttpClient: Closing while {self._active_requests} requests are still active."
+                    f"LLMHttpClient: Closing while {self._active_requests} "
+                    f"requests are still active."
                 )
             await self._client.aclose()
             self._client = None
@@ -96,7 +100,9 @@ class LLMHttpClientManager:
         self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
         self._lock = asyncio.Lock()
 
-    def _get_client_key(self, provider_config: ProviderConfig) -> tuple[int, str | None]:
+    def _get_client_key(
+        self, provider_config: ProviderConfig
+    ) -> tuple[int, str | None]:
         return (provider_config.timeout, provider_config.proxy)
 
     async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
@@ -104,15 +110,21 @@ class LLMHttpClientManager:
         async with self._lock:
             client = self._clients.get(key)
             if client and not client.is_closed:
-                logger.debug(f"LLMHttpClientManager: Reusing existing LLMHttpClient for key: {key}")
+                logger.debug(
+                    f"LLMHttpClientManager: Reusing existing LLMHttpClient "
+                    f"for key: {key}"
+                )
                 return client
             if client and client.is_closed:
                 logger.debug(
-                    f"LLMHttpClientManager: Found a closed client for key {key}. Creating a new one."
+                    f"LLMHttpClientManager: Found a closed client for key {key}. "
+                    f"Creating a new one."
                 )
 
-            logger.debug(f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}")
+            logger.debug(
+                f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
+            )
             http_client_config = HttpClientConfig(
                 timeout=provider_config.timeout, proxy=provider_config.proxy
             )
@@ -122,9 +134,14 @@ class LLMHttpClientManager:
 
     async def shutdown(self):
         async with self._lock:
-            logger.info(f"LLMHttpClientManager: Shutting down. Closing {len(self._clients)} client(s).")
+            logger.info(
+                f"LLMHttpClientManager: Shutting down. "
+                f"Closing {len(self._clients)} client(s)."
+            )
             close_tasks = [
-                client.close() for client in self._clients.values() if client and not client.is_closed
+                client.close()
+                for client in self._clients.values()
+                if client and not client.is_closed
             ]
             if close_tasks:
                 await asyncio.gather(*close_tasks, return_exceptions=True)
@@ -183,11 +200,16 @@ async def with_smart_retry(
         except LLMException as e:
             last_exception = e
 
-            if e.code in [LLMErrorCode.API_KEY_INVALID, LLMErrorCode.API_QUOTA_EXCEEDED]:
+            if e.code in [
+                LLMErrorCode.API_KEY_INVALID,
+                LLMErrorCode.API_QUOTA_EXCEEDED,
+            ]:
                 if hasattr(e, "details") and e.details and "api_key" in e.details:
                     failed_keys.add(e.details["api_key"])
                     if key_store and provider_name:
-                        await key_store.record_failure(e.details["api_key"], e.details.get("status_code"))
+                        await key_store.record_failure(
+                            e.details["api_key"], e.details.get("status_code")
+                        )
 
             should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
             if not should_retry:
@@ -198,7 +220,9 @@ async def with_smart_retry(
                 wait_time = config.retry_delay
                 if config.exponential_backoff:
                     wait_time *= 2**attempt
-                logger.warning(f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}")
+                logger.warning(
+                    f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
+                )
                 await asyncio.sleep(wait_time)
             else:
                 logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
@@ -218,7 +242,9 @@ async def with_smart_retry(
     raise RuntimeError("重试函数未能正常执行且未捕获到异常")
 
 
-def _should_retry_llm_error(error: LLMException, attempt: int, max_retries: int) -> bool:
+def _should_retry_llm_error(
+    error: LLMException, attempt: int, max_retries: int
+) -> bool:
     """判断LLM错误是否应该重试"""
     non_retryable_errors = {
         LLMErrorCode.MODEL_NOT_FOUND,
@@ -263,7 +289,10 @@ class KeyStatusStore:
         self._lock = asyncio.Lock()
 
     async def get_next_available_key(
-        self, provider_name: str, api_keys: list[str], exclude_keys: set[str] | None = None
+        self,
+        provider_name: str,
+        api_keys: list[str],
+        exclude_keys: set[str] | None = None,
     ) -> str | None:
         """获取下一个可用的API密钥(轮询策略)"""
         if not api_keys:
@@ -271,7 +300,9 @@ class KeyStatusStore:
 
         exclude_keys = exclude_keys or set()
         available_keys = [
-            key for key in api_keys if key not in exclude_keys and self._key_status.get(key, True)
+            key
+            for key in api_keys
+            if key not in exclude_keys and self._key_status.get(key, True)
         ]
 
         if not available_keys:
@@ -282,11 +313,15 @@ class KeyStatusStore:
 
             selected_key = available_keys[current_index % len(available_keys)]
 
-            self._provider_key_index[provider_name] = (current_index + 1) % len(available_keys)
+            self._provider_key_index[provider_name] = (current_index + 1) % len(
+                available_keys
+            )
 
             import time
 
-            self._key_usage_count[selected_key] = self._key_usage_count.get(selected_key, 0) + 1
+            self._key_usage_count[selected_key] = (
+                self._key_usage_count.get(selected_key, 0) + 1
+            )
             self._key_last_used[selected_key] = time.time()
 
             logger.debug(
@@ -308,7 +343,9 @@ class KeyStatusStore:
         async with self._lock:
             if status_code in [401, 403]:
                 self._key_status[api_key] = False
-                logger.warning(f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})")
+                logger.warning(
+                    f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})"
+                )
             else:
                 logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
 
diff --git a/zhenxun/services/llm/manager.py b/zhenxun/services/llm/manager.py
index 6ba2db79..f23dfa50 100644
--- a/zhenxun/services/llm/manager.py
+++ b/zhenxun/services/llm/manager.py
@@ -37,9 +37,13 @@ def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str |
     return None, None
 
 
-def _make_cache_key(provider_model_name: str | None, override_config: dict | None) -> str:
+def _make_cache_key(
+    provider_model_name: str | None, override_config: dict | None
+) -> str:
     """生成缓存键"""
-    config_str = json.dumps(override_config, sort_keys=True) if override_config else "None"
+    config_str = (
+        json.dumps(override_config, sort_keys=True) if override_config else "None"
+    )
     key_data = f"{provider_model_name}:{config_str}"
     return hashlib.md5(key_data.encode()).hexdigest()
 
@@ -62,7 +66,9 @@ def _get_cached_model(cache_key: str) -> LLMModel | None:
             )
             model._is_closed = False
 
-            logger.debug(f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}")
+            logger.debug(
+                f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
+            )
             return model
     return None
 
@@ -113,7 +119,10 @@ def get_configured_providers() -> list[ProviderConfig]:
     ai_config = get_ai_config()
     providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
     if not isinstance(providers_raw, list):
-        logger.error(f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,将使用空列表。")
+        logger.error(
+            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
+            f"将使用空列表。"
+        )
         return []
 
     valid_providers = []
@@ -128,7 +137,9 @@ def get_configured_providers() -> list[ProviderConfig]:
                 continue
 
             if not item.get("api_key"):
-                logger.warning(f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。")
+                logger.warning(
+                    f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
+                )
                 continue
 
             if "api_type" not in item or not item["api_type"]:
@@ -159,7 +170,9 @@ def get_configured_providers() -> list[ProviderConfig]:
     return valid_providers
 
 
-def find_model_config(provider_name: str, model_name: str) -> tuple[ProviderConfig, ModelDetail] | None:
+def find_model_config(
+    provider_name: str, model_name: str
+) -> tuple[ProviderConfig, ModelDetail] | None:
     """在配置中查找指定的 Provider 和 ModelDetail
 
     Args:
@@ -194,7 +207,9 @@ def list_available_models() -> list[dict[str, Any]]:
                 "api_base": provider.api_base,
                 "is_available": model_detail.is_available,
                 "is_embedding_model": model_detail.is_embedding_model,
-                "available_identifiers": _get_model_identifiers(provider.name, model_detail),
+                "available_identifiers": _get_model_identifiers(
+                    provider.name, model_detail
+                ),
             }
             model_list.append(model_info)
     return model_list
@@ -242,7 +257,8 @@ async def get_model_instance(
             if cached_model._generation_config != validated_override:
                 cached_model._generation_config = validated_override
                 logger.debug(
-                    f"对缓存模型 {provider_model_name} 应用新的覆盖配置: {validated_override.to_dict()}"
+                    f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
+                    f"{validated_override.to_dict()}"
                 )
         return cached_model
 
@@ -252,21 +268,25 @@ async def get_model_instance(
     if resolved_model_name_str is None:
         available_models_list = list_available_models()
         if not available_models_list:
-            raise LLMException("未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR)
+            raise LLMException(
+                "未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
+            )
         resolved_model_name_str = available_models_list[0]["full_name"]
         logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")
 
     prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
     if not prov_name_str or not mod_name_str:
         raise LLMException(
-            f"无效的模型名称格式: '{resolved_model_name_str}'", code=LLMErrorCode.MODEL_NOT_FOUND
+            f"无效的模型名称格式: '{resolved_model_name_str}'",
+            code=LLMErrorCode.MODEL_NOT_FOUND,
         )
 
     config_tuple_found = find_model_config(prov_name_str, mod_name_str)
     if not config_tuple_found:
         all_models = list_available_models()
         raise LLMException(
-            f"未找到模型: '{resolved_model_name_str}'. 可用: {[m['full_name'] for m in all_models]}",
+            f"未找到模型: '{resolved_model_name_str}'. "
+            f"可用: {[m['full_name'] for m in all_models]}",
             code=LLMErrorCode.MODEL_NOT_FOUND,
         )
 
@@ -274,7 +294,11 @@ async def get_model_instance(
     ai_config = get_ai_config()
     global_proxy_setting = ai_config.get(PROXY_KEY)
 
-    default_timeout = provider_config_found.timeout if provider_config_found.timeout is not None else 180
+    default_timeout = (
+        provider_config_found.timeout
+        if provider_config_found.timeout is not None
+        else 180
+    )
     global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)
 
     config_for_http_client = ProviderConfig(
@@ -304,16 +328,21 @@ async def get_model_instance(
             validated_override_params = validate_override_params(override_config)
             model_instance._generation_config = validated_override_params
             logger.debug(
-                f"为新模型 {resolved_model_name_str} 应用配置覆盖: {validated_override_params.to_dict()}"
+                f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
+                f"{validated_override_params.to_dict()}"
             )
 
         _cache_model(cache_key, model_instance)
-        logger.debug(f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}")
+        logger.debug(
+            f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
+        )
         return model_instance
 
     except LLMException:
         raise
     except Exception as e:
-        logger.error(f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e)
+        logger.error(
+            f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
+        )
         raise LLMException(
             f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
             code=LLMErrorCode.MODEL_INIT_FAILED,
@@ -332,10 +361,14 @@ def set_global_default_model_name(provider_model_name: str | None) -> bool:
     if provider_model_name:
         prov_name, mod_name = parse_provider_model_string(provider_model_name)
         if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
-            logger.error(f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。")
+            logger.error(
+                f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
+            )
             return False
 
-    Config.set_config(AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True)
+    Config.set_config(
+        AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
+    )
     if provider_model_name:
         logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
     else:
@@ -350,10 +383,16 @@ async def get_key_usage_stats() -> dict[str, Any]:
 
     for provider in providers:
         provider_stats = await key_store.get_key_stats(
-            [provider.api_key] if isinstance(provider.api_key, str) else provider.api_key
+            [provider.api_key]
+            if isinstance(provider.api_key, str)
+            else provider.api_key
         )
         stats[provider.name] = {
-            "total_keys": len([provider.api_key] if isinstance(provider.api_key, str) else provider.api_key),
+            "total_keys": len(
+                [provider.api_key]
+                if isinstance(provider.api_key, str)
+                else provider.api_key
+            ),
             "key_stats": provider_stats,
         }
 
@@ -375,7 +414,9 @@ async def reset_key_status(provider_name: str, api_key: str | None = None) -> bo
         return False
 
     provider_keys = (
-        [target_provider.api_key] if isinstance(target_provider.api_key, str) else target_provider.api_key
+        [target_provider.api_key]
+        if isinstance(target_provider.api_key, str)
+        else target_provider.api_key
     )
 
     if api_key:
diff --git a/zhenxun/services/llm/service.py b/zhenxun/services/llm/service.py
index 7a0c95d3..d054ca9b 100644
--- a/zhenxun/services/llm/service.py
+++ b/zhenxun/services/llm/service.py
@@ -89,7 +89,9 @@ class LLMModel(LLMModelBase):
         self.api_type = provider_config.api_type
         self.api_base = provider_config.api_base
         self.api_keys = (
-            [provider_config.api_key] if isinstance(provider_config.api_key, str) else provider_config.api_key
+            [provider_config.api_key]
+            if isinstance(provider_config.api_key, str)
+            else provider_config.api_key
         )
         self.model_name = model_detail.model_name
         self.temperature = model_detail.temperature
@@ -101,9 +103,12 @@ class LLMModel(LLMModelBase):
         """获取HTTP客户端"""
         if self.http_client.is_closed:
             logger.debug(
-                f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,正在获取新的客户端"
+                f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,"
+                "正在获取新的客户端"
+            )
+            self.http_client = await http_client_manager.get_client(
+                self.provider_config
             )
-            self.http_client = await http_client_manager.get_client(self.provider_config)
         return self.http_client
 
     async def _select_api_key(self, failed_keys: set[str] | None = None) -> str:
@@ -122,7 +127,10 @@ class LLMModel(LLMModelBase):
             raise LLMException(
                 f"提供商 {self.provider_name} 的所有API密钥当前都不可用",
                 code=LLMErrorCode.NO_AVAILABLE_KEYS,
-                details={"total_keys": len(self.api_keys), "failed_keys": len(failed_keys or set())},
+                details={
+                    "total_keys": len(self.api_keys),
+                    "failed_keys": len(failed_keys or set()),
+                },
             )
 
         return selected_key
@@ -154,7 +162,9 @@ class LLMModel(LLMModelBase):
 
             if http_response.status_code != 200:
                 error_text = http_response.text
-                logger.error(f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}")
+                logger.error(
+                    f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
+                )
                 await self.key_store.record_failure(api_key, http_response.status_code)
 
                 error_code = LLMErrorCode.API_REQUEST_FAILED
@@ -262,7 +272,9 @@ class LLMModel(LLMModelBase):
 
             if http_response.status_code != 200:
                 error_text = http_response.text
-                logger.error(f"HTTP请求失败: {http_response.status_code} - {error_text}")
+                logger.error(
+                    f"HTTP请求失败: {http_response.status_code} - {error_text}"
+                )
 
                 await self.key_store.record_failure(api_key, http_response.status_code)
 
@@ -304,7 +316,10 @@ class LLMModel(LLMModelBase):
                         try:
                             response_tool_calls.append(LLMToolCall(**tc_data))
                         except Exception as e:
-                            logger.warning(f"无法将工具调用数据转换为LLMToolCall: {tc_data}, error: {e}")
+                            logger.warning(
+                                f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
+                                f"error: {e}"
+                            )
                     else:
                         logger.warning(f"工具调用数据格式未知: {tc_data}")
 
@@ -355,11 +370,16 @@ class LLMModel(LLMModelBase):
         if self._is_closed:
             return
         self._is_closed = True
-        logger.debug(f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)")
+        logger.debug(
+            f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)"
+        )
 
     async def __aenter__(self):
         if self._is_closed:
-            logger.debug(f"Re-entering context for closed LLMModel {self}. Resetting _is_closed to False.")
+            logger.debug(
+                f"Re-entering context for closed LLMModel {self}. "
+                f"Resetting _is_closed to False."
+            )
             self._is_closed = False
         self._check_not_closed()
         return self
@@ -393,12 +413,15 @@ class LLMModel(LLMModelBase):
 
         messages.append(LLMMessage.user(prompt))
 
+        model_fields = getattr(LLMGenerationConfig, "model_fields", {})
         request_specific_config_dict = {
-            k: v for k, v in kwargs.items() if k in LLMGenerationConfig.model_fields
+            k: v for k, v in kwargs.items() if k in model_fields
         }
         request_specific_config = None
         if request_specific_config_dict:
-            request_specific_config = LLMGenerationConfig(**request_specific_config_dict)
+            request_specific_config = LLMGenerationConfig(
+                **request_specific_config_dict
+            )
             for key in request_specific_config_dict:
                 kwargs.pop(key, None)
 
@@ -450,7 +473,8 @@ class LLMModel(LLMModelBase):
             tools_dict = []
             for tool in tools:
                 if hasattr(tool, "model_dump"):
-                    tools_dict.append(tool.model_dump(exclude_none=True))
+                    model_dump_func = getattr(tool, "model_dump")
+                    tools_dict.append(model_dump_func(exclude_none=True))
                 elif isinstance(tool, dict):
                     tools_dict.append(tool)
                 else:
@@ -497,7 +521,9 @@ class LLMModel(LLMModelBase):
                     logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
                     tool_result = await tool_executor(tool_name, tool_args_dict)
 
-                    logger.debug(f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}...")
+                    logger.debug(
+                        f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
+                    )
 
                     tool_response_messages.append(
                         LLMMessage.tool_response(
@@ -508,13 +534,17 @@ class LLMModel(LLMModelBase):
                     )
                 except json.JSONDecodeError as e:
                     logger.error(
-                        f"工具 '{tool_name}' 参数JSON解析失败: {tool_call.function.arguments}, 错误: {e}"
+                        f"工具 '{tool_name}' 参数JSON解析失败: "
+                        f"{tool_call.function.arguments}, 错误: {e}"
                     )
                     tool_response_messages.append(
                         LLMMessage.tool_response(
                             tool_call_id=tool_call.id,
                             function_name=tool_name,
-                            result={"error": "Argument JSON parsing failed", "details": str(e)},
+                            result={
+                                "error": "Argument JSON parsing failed",
+                                "details": str(e),
+                            },
                         )
                     )
                 except Exception as e:
@@ -523,7 +553,10 @@ class LLMModel(LLMModelBase):
                         LLMMessage.tool_response(
                             tool_call_id=tool_call.id,
                             function_name=tool_name,
-                            result={"error": "Tool execution failed", "details": str(e)},
+                            result={
+                                "error": "Tool execution failed",
+                                "details": str(e),
+                            },
                         )
                     )
 
@@ -533,7 +566,10 @@ class LLMModel(LLMModelBase):
         raise LLMException(
             "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
             code=LLMErrorCode.GENERATION_FAILED,
-            details={"iterations": max_tool_iterations, "last_messages": current_messages[-2:]},
+            details={
+                "iterations": max_tool_iterations,
+                "last_messages": current_messages[-2:],
+            },
         )
 
     async def generate_embeddings(
@@ -561,7 +597,9 @@ class LLMModel(LLMModelBase):
         ai_config = get_ai_config()
         default_max_retries = ai_config.get("max_retries_llm", 3)
         default_retry_delay = ai_config.get("retry_delay_llm", 2)
-        max_retries_embed = kwargs.get("max_retries_embed", max(1, default_max_retries // 2))
+        max_retries_embed = kwargs.get(
+            "max_retries_embed", max(1, default_max_retries // 2)
+        )
         retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2)
 
         retry_config = RetryConfig(
diff --git a/zhenxun/services/llm/types/content.py b/zhenxun/services/llm/types/content.py
index e24c0568..54887bc3 100644
--- a/zhenxun/services/llm/types/content.py
+++ b/zhenxun/services/llm/types/content.py
@@ -58,7 +58,9 @@ class LLMContentPart(BaseModel):
         return cls(type="image", image_source=url)
 
     @classmethod
-    def image_base64_part(cls, data: str, mime_type: str = "image/png") -> "LLMContentPart":
+    def image_base64_part(
+        cls, data: str, mime_type: str = "image/png"
+    ) -> "LLMContentPart":
         """创建Base64图片内容部分"""
         data_url = f"data:{mime_type};base64,{data}"
         return cls(type="image", image_source=data_url)
@@ -74,13 +76,17 @@ class LLMContentPart(BaseModel):
         return cls(type="video", video_source=url, mime_type=mime_type)
 
     @classmethod
-    def video_base64_part(cls, data: str, mime_type: str = "video/mp4") -> "LLMContentPart":
+    def video_base64_part(
+        cls, data: str, mime_type: str = "video/mp4"
+    ) -> "LLMContentPart":
         """创建Base64视频内容部分"""
         data_url = f"data:{mime_type};base64,{data}"
         return cls(type="video", video_source=data_url, mime_type=mime_type)
 
     @classmethod
-    def audio_base64_part(cls, data: str, mime_type: str = "audio/wav") -> "LLMContentPart":
+    def audio_base64_part(
+        cls, data: str, mime_type: str = "audio/wav"
+    ) -> "LLMContentPart":
         """创建Base64音频内容部分"""
         data_url = f"data:{mime_type};base64,{data}"
         return cls(type="audio", audio_source=data_url, mime_type=mime_type)
@@ -101,7 +107,9 @@ class LLMContentPart(BaseModel):
         )
 
     @classmethod
-    async def from_path(cls, path_like: str | Path, target_api: str | None = None) -> "LLMContentPart | None":
+    async def from_path(
+        cls, path_like: str | Path, target_api: str | None = None
+    ) -> "LLMContentPart | None":
         """
         从本地文件路径创建 LLMContentPart。
         自动检测MIME类型,并根据类型(如图片)可能加载为Base64。
@@ -116,7 +124,9 @@ class LLMContentPart(BaseModel):
         mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
 
         if not mime_type:
-            logger.warning(f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。")
+            logger.warning(
+                f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。"
+            )
             try:
                 async with aiofiles.open(path, encoding="utf-8") as f:
                     text_content = await f.read()
@@ -131,17 +141,22 @@ class LLMContentPart(BaseModel):
                     async with aiofiles.open(path, "rb") as f:
                         img_bytes = await f.read()
                     base64_data = base64.b64encode(img_bytes).decode("utf-8")
-                    return cls.image_base64_part(data=base64_data, mime_type=mime_type)
+                    return cls.image_base64_part(
+                        data=base64_data, mime_type=mime_type
+                    )
                 except Exception as e:
                     logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
                     return None
             else:
                 logger.warning(
-                    f"为本地图片路径 {path.name} 生成 image_url_part。实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"  # noqa: E501
+                    f"为本地图片路径 {path.name} 生成 image_url_part。"
+                    "实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
                 )
                 return cls.image_url_part(url=path.resolve().as_uri())
         elif mime_type.startswith("audio/"):
-            return cls.audio_url_part(url=path.resolve().as_uri(), mime_type=mime_type)
+            return cls.audio_url_part(
+                url=path.resolve().as_uri(), mime_type=mime_type
+            )
         elif mime_type.startswith("video/"):
             if target_api == "gemini":
                 # 对于 Gemini API,将视频转换为 base64
@@ -149,12 +164,16 @@ class LLMContentPart(BaseModel):
                     async with aiofiles.open(path, "rb") as f:
                         video_bytes = await f.read()
                     base64_data = base64.b64encode(video_bytes).decode("utf-8")
-                    return cls.video_base64_part(data=base64_data, mime_type=mime_type)
+                    return cls.video_base64_part(
+                        data=base64_data, mime_type=mime_type
+                    )
                 except Exception as e:
                     logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
                     return None
             else:
-                return cls.video_url_part(url=path.resolve().as_uri(), mime_type=mime_type)
+                return cls.video_url_part(
+                    url=path.resolve().as_uri(), mime_type=mime_type
+                )
         elif (
             mime_type.startswith("text/")
             or mime_type == "application/json"
@@ -168,7 +187,9 @@ class LLMContentPart(BaseModel):
                 logger.error(f"读取文本类文件 {path.name} 失败: {e}")
                 return None
         else:
-            logger.info(f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。")
+            logger.info(
+                f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
+            )
             return cls.file_uri_part(
                 file_uri=path.resolve().as_uri(),
                 mime_type=mime_type,
@@ -228,9 +249,13 @@ class LLMContentPart(BaseModel):
                         return {"inlineData": {"mimeType": mime_type, "data": data}}
                     else:
                         # 如果无法解析 Base64 数据,抛出异常
-                        raise ValueError(f"无法解析Base64图像数据: {self.image_source[:50]}...")
+                        raise ValueError(
+                            f"无法解析Base64图像数据: {self.image_source[:50]}..."
+                        )
                 else:
-                    logger.warning(f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}")
+                    logger.warning(
+                        f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
+                    )
                     return {
                         "inlineData": {
                             "mimeType": "image/jpeg",
@@ -253,10 +278,14 @@ class LLMContentPart(BaseModel):
                         mime_type = header.split(";")[0].replace("data:", "")
                         return {"inlineData": {"mimeType": mime_type, "data": data}}
                     except (ValueError, IndexError):
-                        raise ValueError(f"无法解析Base64视频数据: {self.video_source[:50]}...")
+                        raise ValueError(
+                            f"无法解析Base64视频数据: {self.video_source[:50]}..."
+                        )
                 else:
                     # 对于 URL 或其他格式,暂时不支持直接内联
-                    raise ValueError("Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL")
+                    raise ValueError(
+                        "Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
+                    )
             else:
                 # 其他 API 可能不支持视频
                 raise ValueError(f"API类型 '{api_type}' 不支持视频内容")
@@ -273,17 +302,25 @@ class LLMContentPart(BaseModel):
                         mime_type = header.split(";")[0].replace("data:", "")
                         return {"inlineData": {"mimeType": mime_type, "data": data}}
                     except (ValueError, IndexError):
-                        raise ValueError(f"无法解析Base64音频数据: {self.audio_source[:50]}...")
+                        raise ValueError(
+                            f"无法解析Base64音频数据: {self.audio_source[:50]}..."
+                        )
                 else:
-                    raise ValueError("Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL")
+                    raise ValueError(
+                        "Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
+                    )
             else:
                 raise ValueError(f"API类型 '{api_type}' 不支持音频内容")
 
         elif self.type == "file":
             if api_type == "gemini" and self.file_uri:
-                return {"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}}
+                return {
+                    "fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
+                }
             elif self.file_source:
-                file_name = self.metadata.get("name", "file") if self.metadata else "file"
+                file_name = (
+                    self.metadata.get("name", "file") if self.metadata else "file"
+                )
                 if api_type == "gemini":
                     return {"text": f"[文件: {file_name}]\n{self.file_source}"}
                 else:
@@ -317,7 +354,8 @@ class LLMMessage(BaseModel):
             raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
         if self.role == "tool" and not isinstance(self.content, str):
             logger.warning(
-                f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. 将尝试转换为字符串。"
+                f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
+                "将尝试转换为字符串。"
             )
             try:
                 self.content = str(self.content)
@@ -339,7 +377,9 @@ class LLMMessage(BaseModel):
         return cls(role="assistant", content=content, tool_calls=tool_calls)
 
     @classmethod
-    def assistant_text_response(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
+    def assistant_text_response(
+        cls, content: str | list[LLMContentPart]
+    ) -> "LLMMessage":
         """创建助手纯文本回复的消息"""
         return cls(role="assistant", content=content, tool_calls=None)
 
@@ -356,10 +396,19 @@ class LLMMessage(BaseModel):
         try:
             content_str = json.dumps(result)
         except TypeError as e:
-            logger.error(f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}")
-            content_str = json.dumps({"error": "Tool result not JSON serializable", "details": str(e)})
+            logger.error(
+                f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
+            )
+            content_str = json.dumps(
+                {"error": "Tool result not JSON serializable", "details": str(e)}
+            )
 
-        return cls(role="tool", content=content_str, tool_call_id=tool_call_id, name=function_name)
+        return cls(
+            role="tool",
+            content=content_str,
+            tool_call_id=tool_call_id,
+            name=function_name,
+        )
 
     @classmethod
     def system(cls, content: str) -> "LLMMessage":
diff --git a/zhenxun/services/llm/types/exceptions.py b/zhenxun/services/llm/types/exceptions.py
index 9621c09d..623d4c26 100644
--- a/zhenxun/services/llm/types/exceptions.py
+++ b/zhenxun/services/llm/types/exceptions.py
@@ -36,11 +36,15 @@ class LLMException(Exception):
         error_messages = {
             LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到,请检查配置或联系管理员。",
             LLMErrorCode.API_KEY_INVALID: "API密钥无效,请联系管理员更新配置。",
-            LLMErrorCode.API_QUOTA_EXCEEDED: "API使用配额已用尽,请稍后再试或联系管理员。",
+            LLMErrorCode.API_QUOTA_EXCEEDED: (
+                "API使用配额已用尽,请稍后再试或联系管理员。"
+            ),
             LLMErrorCode.API_TIMEOUT: "AI服务响应超时,请稍后再试。",
             LLMErrorCode.API_RATE_LIMITED: "请求过于频繁,已被AI服务限流,请稍后再试。",
             LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败,请联系管理员检查配置。",
-            LLMErrorCode.NO_AVAILABLE_KEYS: "当前所有API密钥均不可用,请稍后再试或联系管理员。",
+            LLMErrorCode.NO_AVAILABLE_KEYS: (
+                "当前所有API密钥均不可用,请稍后再试或联系管理员。"
+            ),
             LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
                 "当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。"
             ),
diff --git a/zhenxun/services/llm/utils.py b/zhenxun/services/llm/utils.py
index 2e41f34a..3610df27 100644
--- a/zhenxun/services/llm/utils.py
+++ b/zhenxun/services/llm/utils.py
@@ -5,7 +5,16 @@ LLM 模块的工具和转换函数
 
 import base64
 from pathlib import Path
 
-from nonebot_plugin_alconna.uniseg import At, File, Image, Reply, Text, UniMessage, Video, Voice
+from nonebot_plugin_alconna.uniseg import (
+    At,
+    File,
+    Image,
+    Reply,
+    Text,
+    UniMessage,
+    Video,
+    Voice,
+)
 
 from zhenxun.services.log import logger
 
@@ -29,7 +38,11 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
             elif seg.url:
                 part = LLMContentPart.image_url_part(seg.url)
             elif hasattr(seg, "raw") and seg.raw:
-                mime_type = getattr(seg, "mimetype", "image/png") if hasattr(seg, "mimetype") else "image/png"
+                mime_type = (
+                    getattr(seg, "mimetype", "image/png")
+                    if hasattr(seg, "mimetype")
+                    else "image/png"
+                )
                 if isinstance(seg.raw, bytes):
                     b64_data = base64.b64encode(seg.raw).decode("utf-8")
                     part = LLMContentPart.image_base64_part(b64_data, mime_type)
@@ -38,8 +51,13 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
             if seg.path:
                 part = await LLMContentPart.from_path(seg.path)
             elif seg.url:
-                logger.warning(f"直接使用 URL 的 {type(seg).__name__} 段,API 可能不支持: {seg.url}")
-                part = LLMContentPart.text_part(f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]")
+                logger.warning(
+                    f"直接使用 URL 的 {type(seg).__name__} 段,"
+                    f"API 可能不支持: {seg.url}"
+                )
+                part = LLMContentPart.text_part(
+                    f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
+                )
             elif hasattr(seg, "raw") and seg.raw:
                 mime_type = getattr(seg, "mimetype", None)
                 if isinstance(seg.raw, bytes):
@@ -48,18 +66,29 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
                     if isinstance(seg, Video):
                         if not mime_type:
                             mime_type = "video/mp4"
-                        part = LLMContentPart.video_base64_part(data=b64_data, mime_type=mime_type)
-                        logger.debug(f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
+                        part = LLMContentPart.video_base64_part(
+                            data=b64_data, mime_type=mime_type
+                        )
+                        logger.debug(
+                            f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
+                        )
                     elif isinstance(seg, Voice):
                         if not mime_type:
                             mime_type = "audio/wav"
-                        part = LLMContentPart.audio_base64_part(data=b64_data, mime_type=mime_type)
-                        logger.debug(f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
+                        part = LLMContentPart.audio_base64_part(
+                            data=b64_data, mime_type=mime_type
+                        )
+                        logger.debug(
+                            f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
+                        )
                     else:
                         part = LLMContentPart.text_part(
                             f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
                         )
-                        logger.debug(f"处理其他文件字节数据: {mime_type}, 大小: {len(seg.raw)} bytes")
+                        logger.debug(
+                            f"处理其他文件字节数据: {mime_type}, "
+                            f"大小: {len(seg.raw)} bytes"
+                        )
 
         elif isinstance(seg, At):
             if seg.flag == "all":
@@ -76,7 +105,9 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
                 else:
                     reply_text = str(seg.msg).strip()
                 if reply_text:
-                    part = LLMContentPart.text_part(f'[Replied to: "{reply_text[:50]}..."]')
+                    part = LLMContentPart.text_part(
+                        f'[Replied to: "{reply_text[:50]}..."]'
+                    )
             except Exception:
                 part = LLMContentPart.text_part("[Replied to a message]")
 
@@ -118,10 +149,14 @@ def create_multimodal_message(
         msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
 
         # 文本 + 多张图片
-        msg = create_multimodal_message("比较图片", images=["/path/1.jpg", "/path/2.jpg"])
+        msg = create_multimodal_message(
+            "比较图片", images=["/path/1.jpg", "/path/2.jpg"]
+        )
 
         # 文本 + 图片字节数据
-        msg = create_multimodal_message("分析", images=image_data, image_mimetypes="image/jpeg")
+        msg = create_multimodal_message(
+            "分析", images=image_data, image_mimetypes="image/jpeg"
+        )
 
         # 文本 + 视频
         msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")