feat(llm): 实现LLM服务模块,支持多提供商统一接口和高级功能 (#1923)

*  feat(llm): 实现LLM服务模块,支持多提供商统一接口和高级功能

* 🎨 Ruff

*  Config配置类支持BaseModel存储

* 🎨 代码格式化

* 🎨 代码格式化

* 🎨 格式化代码

*  feat(llm): 添加 AI 对话历史管理

*  feat(llmConfig): 引入 LLM 配置模型及管理功能

* 🎨 Ruff

---------

Co-authored-by: fccckaug <xxxmio123123@gmail.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
Co-authored-by: HibiKier <775757368@qq.com>
Co-authored-by: fccckaug <xxxmcsmiomio3@gmail.com>
Co-authored-by: webjoin111 <455457521@qq.com>
Rumio 2025-06-21 16:33:21 +08:00 committed by GitHub
parent 14f5842f10
commit a020ea5c87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 6373 additions and 311 deletions

View File

@@ -58,7 +58,7 @@ def _generate_simple_config(exists_module: list[str]):
生成简易配置
异常:
AttributeError: _description_
AttributeError: AttributeError
"""
# 读取用户配置
_data = {}
@@ -74,7 +74,9 @@ def _generate_simple_config(exists_module: list[str]):
if _data.get(module) and k in _data[module].keys():
Config.set_config(module, k, _data[module][k])
if f"{module}:{k}".lower() in exists_module:
_tmp_data[module][k] = Config.get_config(module, k)
_tmp_data[module][k] = Config.get_config(
module, k, build_model=False
)
except AttributeError as e:
raise AttributeError(f"{e}\n可能为config.yaml配置文件填写不规范") from e
if not _tmp_data[module]:

View File

@@ -1,89 +1,82 @@
from collections.abc import Callable
import copy
from datetime import datetime
from pathlib import Path
from typing import Any, Literal
from typing import Any, TypeVar, get_args, get_origin
import cattrs
from nonebot.compat import model_dump
from pydantic import BaseModel, Field
from pydantic import VERSION, BaseModel, Field
from ruamel.yaml import YAML
from ruamel.yaml.scanner import ScannerError
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType
from .models import (
AICallableParam,
AICallableProperties,
AICallableTag,
BaseBlock,
Command,
ConfigModel,
Example,
PluginCdBlock,
PluginCountBlock,
PluginExtraData,
PluginSetting,
RegisterConfig,
Task,
)
_yaml = YAML(pure=True)
_yaml.indent = 2
_yaml.allow_unicode = True
T = TypeVar("T")
class Example(BaseModel):
    """
    示例
    """

    exec: str
    """执行命令"""
    description: str = ""
    """命令描述"""


class Command(BaseModel):
    """
    具体参数说明
    """

    command: str
    """命令名称"""
    params: list[str] = Field(default_factory=list)
    """参数"""
    description: str = ""
    """描述"""
    examples: list[Example] = Field(default_factory=list)
    """示例列表"""


class RegisterConfig(BaseModel):
    """
    注册配置项
    """

    key: str
    """配置项键"""
    value: Any
    """配置项值"""
    module: str | None = None
    """模块名"""
    help: str | None
    """配置注解"""
    default_value: Any | None = None
    """默认值"""
    type: Any = None
    """参数类型"""
    arg_parser: Callable | None = None
    """参数解析"""


class ConfigModel(BaseModel):
    """
    配置项
    """

    value: Any
    """配置项值"""
    help: str | None
    """配置注解"""
    default_value: Any | None = None
    """默认值"""
    type: Any = None
    """参数类型"""
    arg_parser: Callable | None = None
    """参数解析"""

    def to_dict(self, **kwargs):
        return model_dump(self, **kwargs)


class NoSuchConfig(Exception):
    pass


def _dump_pydantic_obj(obj: Any) -> Any:
    """
    递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典,
    支持单个实例、实例列表、实例字典等情况
    """
    if isinstance(obj, BaseModel):
        return model_dump(obj)
    if isinstance(obj, list):
        return [_dump_pydantic_obj(item) for item in obj]
    if isinstance(obj, dict):
        return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
    return obj


def _is_pydantic_type(t: Any) -> bool:
    """
    递归检查一个类型注解是否与 Pydantic BaseModel 相关
    """
    if t is None:
        return False
    origin = get_origin(t)
    if origin:
        return any(_is_pydantic_type(arg) for arg in get_args(t))
    return isinstance(t, type) and issubclass(t, BaseModel)


def parse_as(type_: type[T], obj: Any) -> T:
    """
    一个兼容 Pydantic V1 的 parse_obj_as 和 V2 的 TypeAdapter.validate_python 的辅助函数
    """
    if VERSION.startswith("1"):
        from pydantic import parse_obj_as

        return parse_obj_as(type_, obj)
    else:
        from pydantic import TypeAdapter  # type: ignore

        return TypeAdapter(type_).validate_python(obj)
class ConfigGroup(BaseModel):
@@ -98,202 +91,41 @@ class ConfigGroup(BaseModel):
configs: dict[str, ConfigModel] = Field(default_factory=dict)
"""配置项列表"""
def get(self, c: str, default: Any = None) -> Any:
cfg = self.configs.get(c.upper())
if cfg is not None:
if cfg.value is not None:
return cfg.value
if cfg.default_value is not None:
return cfg.default_value
return default
def get(self, c: str, default: Any = None, *, build_model: bool = True) -> Any:
"""
获取配置项的值,如果指定了类型会自动构建实例
"""
key = c.upper()
cfg = self.configs.get(key)
if cfg is None:
return default
value_to_process = cfg.value if cfg.value is not None else cfg.default_value
if value_to_process is None:
return default
if cfg.type:
if _is_pydantic_type(cfg.type):
if build_model:
try:
return parse_as(cfg.type, value_to_process)
except Exception as e:
logger.warning(
f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e
)
try:
return cattrs.structure(value_to_process, cfg.type)
except Exception as e:
logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
return value_to_process
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)
class BaseBlock(BaseModel):
"""
插件阻断基本类插件阻断限制
"""
status: bool = True
"""限制状态"""
check_type: BlockType = BlockType.ALL
"""检查类型"""
watch_type: LimitWatchType = LimitWatchType.USER
"""监听对象"""
result: str | None = None
"""阻断时回复内容"""
_type: PluginLimitType = PluginLimitType.BLOCK
"""类型"""
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)
class PluginCdBlock(BaseBlock):
"""
插件cd限制
"""
cd: int = 5
"""cd"""
_type: PluginLimitType = PluginLimitType.CD
"""类型"""
class PluginCountBlock(BaseBlock):
"""
插件次数限制
"""
max_count: int
"""最大调用次数"""
_type: PluginLimitType = PluginLimitType.COUNT
"""类型"""
class PluginSetting(BaseModel):
"""
插件基本配置
"""
level: int = 5
"""群权限等级"""
default_status: bool = True
"""进群默认开关状态"""
limit_superuser: bool = False
"""是否限制超级用户"""
cost_gold: int = 0
"""调用插件花费金币"""
impression: float = 0.0
"""调用插件好感度限制"""
class AICallableProperties(BaseModel):
type: str
"""参数类型"""
description: str
"""参数描述"""
enums: list[str] | None = None
"""参数枚举"""
class AICallableParam(BaseModel):
type: str
"""类型"""
properties: dict[str, AICallableProperties]
"""参数列表"""
required: list[str]
"""必要参数"""
class AICallableTag(BaseModel):
name: str
"""工具名称"""
parameters: AICallableParam | None = None
"""工具参数"""
description: str
"""工具描述"""
func: Callable | None = None
"""工具函数"""
def to_dict(self):
result = model_dump(self)
del result["func"]
return result
class SchedulerModel(BaseModel):
trigger: Literal["date", "interval", "cron"]
"""trigger"""
day: int | None = None
"""天数"""
hour: int | None = None
"""小时"""
minute: int | None = None
"""分钟"""
second: int | None = None
""""""
run_date: datetime | None = None
"""运行日期"""
id: str | None = None
"""id"""
max_instances: int | None = None
"""最大运行实例"""
args: list | None = None
"""参数"""
kwargs: dict | None = None
"""参数"""
class Task(BaseBlock):
module: str
"""被动技能模块名"""
name: str
"""被动技能名称"""
status: bool = True
"""全局开关状态"""
create_status: bool = False
"""初次加载默认开关状态"""
default_status: bool = True
"""进群时默认状态"""
scheduler: SchedulerModel | None = None
"""定时任务配置"""
run_func: Callable | None = None
"""运行函数"""
check: Callable | None = None
"""检查函数"""
check_args: list = Field(default_factory=list)
"""检查函数参数"""
class PluginExtraData(BaseModel):
"""
插件扩展信息
"""
author: str | None = None
"""作者"""
version: str | None = None
"""版本"""
plugin_type: PluginType = PluginType.NORMAL
"""插件类型"""
menu_type: str = "功能"
"""菜单类型"""
admin_level: int | None = None
"""管理员插件所需权限等级"""
configs: list[RegisterConfig] | None = None
"""插件配置"""
setting: PluginSetting | None = None
"""插件基本配置"""
limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None
"""插件限制"""
commands: list[Command] = Field(default_factory=list)
"""命令列表,用于说明帮助"""
ignore_prompt: bool = False
"""是否忽略阻断提示"""
tasks: list[Task] | None = None
"""技能被动"""
superuser_help: str | None = None
"""超级用户帮助"""
aliases: set[str] = Field(default_factory=set)
"""额外名称"""
sql_list: list[str] | None = None
"""常用sql"""
is_show: bool = True
"""是否显示在菜单中"""
smart_tools: list[AICallableTag] | None = None
"""智能模式函数工具集"""
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)
class NoSuchConfig(Exception):
pass
class ConfigsManager:
"""
插件配置 资源 管理器
@@ -366,23 +198,32 @@ class ConfigsManager:
if not module or not key:
raise ValueError("add_plugin_config: module和key不能为空")
if isinstance(value, BaseModel):
value = model_dump(value)
if isinstance(default_value, BaseModel):
default_value = model_dump(default_value)
processed_value = _dump_pydantic_obj(value)
processed_default_value = _dump_pydantic_obj(default_value)
self.add_module.append(f"{module}:{key}".lower())
if module in self._data and (config := self._data[module].configs.get(key)):
config.help = help
config.arg_parser = arg_parser
config.type = type
if _override:
config.value = value
config.default_value = default_value
config.value = processed_value
config.default_value = processed_default_value
else:
key = key.upper()
if not self._data.get(module):
self._data[module] = ConfigGroup(module=module)
self._data[module].configs[key] = ConfigModel(
value=value,
value=processed_value,
help=help,
default_value=default_value,
default_value=processed_default_value,
type=type,
arg_parser=arg_parser,
)
def set_config(
@@ -402,6 +243,8 @@ class ConfigsManager:
"""
key = key.upper()
if module in self._data:
if module not in self._simple_data:
self._simple_data[module] = {}
if self._data[module].configs.get(key):
self._data[module].configs[key].value = value
else:
@@ -410,63 +253,68 @@ class ConfigsManager:
if auto_save:
self.save(save_simple_data=True)
def get_config(self, module: str, key: str, default: Any = None) -> Any:
    """获取指定配置值

    参数:
        module: 模块名
        key: 配置键
        default: 没有key值内容的默认返回值.

    异常:
        NoSuchConfig: 未查询到配置

    返回:
        Any: 配置值
    """
    key = key.upper()
    value = None
    if module in self._data.keys():
        config = self._data[module].configs.get(key) or self._data[
            module
        ].configs.get(key)
        if not config:
            raise NoSuchConfig(
                f"未查询到配置项 MODULE: [ {module} ] | KEY: [ {key} ]"
            )
        try:
            if config.arg_parser:
                value = config.arg_parser(value or config.default_value)
            elif config.value is not None:
                value = (
                    cattrs.structure(config.value, config.type)
                    if config.type
                    else config.value
                )
            elif config.default_value is not None:
                value = (
                    cattrs.structure(config.default_value, config.type)
                    if config.type
                    else config.default_value
                )
        except Exception as e:
            logger.debug(
                f"配置项类型转换 MODULE: [<u><y>{module}</y></u>]"
                f" | KEY: [<u><y>{key}</y></u>] 将使用原始值",
                e=e,
            )
            value = config.value or config.default_value
    if value is None:
        value = default
    logger.debug(
        f"获取配置 MODULE: [<u><y>{module}</y></u>] | "
        f" KEY: [<u><y>{key}</y></u>] -> [<u><c>{value}</c></u>]"
    )
    return value
def get_config(
    self,
    module: str,
    key: str,
    default: Any = None,
    *,
    build_model: bool = True,
) -> Any:
    """
    获取指定配置值,自动构建Pydantic模型或其它类型实例
    - 兼容Pydantic V1/V2
    - 支持 list[BaseModel] 等泛型容器
    - 优先使用Pydantic原生方式,解析失败后回退到cattrs
    """
    logger.debug(
        f"尝试获取配置MODULE: [<u><y>{module}</y></u>] | KEY: [<u><y>{key}</y></u>]"
    )
    key = key.upper()
    config_group = self._data.get(module)
    if not config_group:
        return default
    config = config_group.configs.get(key)
    if not config:
        return default
    value_to_process = (
        config.value if config.value is not None else config.default_value
    )
    if value_to_process is None:
        return default
    # 1. 最高优先级:自定义的参数解析器
    if config.arg_parser:
        try:
            return config.arg_parser(value_to_process)
        except Exception as e:
            logger.debug(
                f"配置项类型转换 MODULE: [<u><y>{module}</y></u>]"
                f" | KEY: [<u><y>{key}</y></u>] 将使用原始值",
                e=e,
            )
    if config.type:
        if _is_pydantic_type(config.type):
            if build_model:
                try:
                    return parse_as(config.type, value_to_process)
                except Exception as e:
                    logger.warning(
                        f"pydantic类型转换失败 MODULE: [<u><y>{module}</y></u>] | "
                        f"KEY: [<u><y>{key}</y></u>].",
                        e=e,
                    )
        else:
            try:
                return cattrs.structure(value_to_process, config.type)
            except Exception as e:
                logger.warning(
                    f"cattrs类型转换失败 MODULE: [<u><y>{module}</y></u>] | "
                    f"KEY: [<u><y>{key}</y></u>].",
                    e=e,
                )
    return value_to_process
def get(self, key: str) -> ConfigGroup:
"""获取插件配置数据
@@ -490,16 +338,16 @@ class ConfigsManager:
with open(self._simple_file, "w", encoding="utf8") as f:
_yaml.dump(self._simple_data, f)
path = path or self.file
data = {}
for module in self._data:
data[module] = {}
for config in self._data[module].configs:
value = self._data[module].configs[config].dict()
del value["type"]
del value["arg_parser"]
data[module][config] = value
save_data = {}
for module, config_group in self._data.items():
save_data[module] = {}
for config_key, config_model in config_group.configs.items():
save_data[module][config_key] = model_dump(
config_model, exclude={"type", "arg_parser"}
)
with open(path, "w", encoding="utf8") as f:
_yaml.dump(data, f)
_yaml.dump(save_data, f)
def reload(self):
"""重新加载配置文件"""
@@ -558,3 +406,23 @@ class ConfigsManager:
def __getitem__(self, key):
return self._data[key]
__all__ = [
"AICallableParam",
"AICallableProperties",
"AICallableTag",
"BaseBlock",
"Command",
"ConfigGroup",
"ConfigModel",
"ConfigsManager",
"Example",
"NoSuchConfig",
"PluginCdBlock",
"PluginCountBlock",
"PluginExtraData",
"PluginSetting",
"RegisterConfig",
"Task",
]

View File

@@ -0,0 +1,270 @@
from collections.abc import Callable
from datetime import datetime
from typing import Any, Literal
from nonebot.compat import model_dump
from pydantic import BaseModel, Field
from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType
__all__ = [
"AICallableParam",
"AICallableProperties",
"AICallableTag",
"BaseBlock",
"Command",
"ConfigModel",
"Example",
"PluginCdBlock",
"PluginCountBlock",
"PluginExtraData",
"PluginSetting",
"RegisterConfig",
"Task",
]
class Example(BaseModel):
"""
示例
"""
exec: str
"""执行命令"""
description: str = ""
"""命令描述"""
class Command(BaseModel):
"""
具体参数说明
"""
command: str
"""命令名称"""
params: list[str] = Field(default_factory=list)
"""参数"""
description: str = ""
"""描述"""
examples: list[Example] = Field(default_factory=list)
"""示例列表"""
class RegisterConfig(BaseModel):
"""
注册配置项
"""
key: str
"""配置项键"""
value: Any
"""配置项值"""
module: str | None = None
"""模块名"""
help: str | None
"""配置注解"""
default_value: Any | None = None
"""默认值"""
type: Any = None
"""参数类型"""
arg_parser: Callable | None = None
"""参数解析"""
class ConfigModel(BaseModel):
"""
配置项
"""
value: Any
"""配置项值"""
help: str | None
"""配置注解"""
default_value: Any | None = None
"""默认值"""
type: Any = None
"""参数类型"""
arg_parser: Callable | None = None
"""参数解析"""
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)
class BaseBlock(BaseModel):
"""
插件阻断基本类插件阻断限制
"""
status: bool = True
"""限制状态"""
check_type: BlockType = BlockType.ALL
"""检查类型"""
watch_type: LimitWatchType = LimitWatchType.USER
"""监听对象"""
result: str | None = None
"""阻断时回复内容"""
_type: PluginLimitType = PluginLimitType.BLOCK
"""类型"""
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)
class PluginCdBlock(BaseBlock):
"""
插件cd限制
"""
cd: int = 5
"""cd"""
_type: PluginLimitType = PluginLimitType.CD
"""类型"""
class PluginCountBlock(BaseBlock):
"""
插件次数限制
"""
max_count: int
"""最大调用次数"""
_type: PluginLimitType = PluginLimitType.COUNT
"""类型"""
class PluginSetting(BaseModel):
"""
插件基本配置
"""
level: int = 5
"""群权限等级"""
default_status: bool = True
"""进群默认开关状态"""
limit_superuser: bool = False
"""是否限制超级用户"""
cost_gold: int = 0
"""调用插件花费金币"""
impression: float = 0.0
"""调用插件好感度限制"""
class AICallableProperties(BaseModel):
type: str
"""参数类型"""
description: str
"""参数描述"""
enums: list[str] | None = None
"""参数枚举"""
class AICallableParam(BaseModel):
type: str
"""类型"""
properties: dict[str, AICallableProperties]
"""参数列表"""
required: list[str]
"""必要参数"""
class AICallableTag(BaseModel):
name: str
"""工具名称"""
parameters: AICallableParam | None = None
"""工具参数"""
description: str
"""工具描述"""
func: Callable | None = None
"""工具函数"""
def to_dict(self):
result = model_dump(self)
del result["func"]
return result
class SchedulerModel(BaseModel):
trigger: Literal["date", "interval", "cron"]
"""trigger"""
day: int | None = None
"""天数"""
hour: int | None = None
"""小时"""
minute: int | None = None
"""分钟"""
second: int | None = None
""""""
run_date: datetime | None = None
"""运行日期"""
id: str | None = None
"""id"""
max_instances: int | None = None
"""最大运行实例"""
args: list | None = None
"""参数"""
kwargs: dict | None = None
"""参数"""
class Task(BaseBlock):
module: str
"""被动技能模块名"""
name: str
"""被动技能名称"""
status: bool = True
"""全局开关状态"""
create_status: bool = False
"""初次加载默认开关状态"""
default_status: bool = True
"""进群时默认状态"""
scheduler: SchedulerModel | None = None
"""定时任务配置"""
run_func: Callable | None = None
"""运行函数"""
check: Callable | None = None
"""检查函数"""
check_args: list = Field(default_factory=list)
"""检查函数参数"""
class PluginExtraData(BaseModel):
"""
插件扩展信息
"""
author: str | None = None
"""作者"""
version: str | None = None
"""版本"""
plugin_type: PluginType = PluginType.NORMAL
"""插件类型"""
menu_type: str = "功能"
"""菜单类型"""
admin_level: int | None = None
"""管理员插件所需权限等级"""
configs: list[RegisterConfig] | None = None
"""插件配置"""
setting: PluginSetting | None = None
"""插件基本配置"""
limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None
"""插件限制"""
commands: list[Command] = Field(default_factory=list)
"""命令列表,用于说明帮助"""
ignore_prompt: bool = False
"""是否忽略阻断提示"""
tasks: list[Task] | None = None
"""技能被动"""
superuser_help: str | None = None
"""超级用户帮助"""
aliases: set[str] = Field(default_factory=set)
"""额外名称"""
sql_list: list[str] | None = None
"""常用sql"""
is_show: bool = True
"""是否显示在菜单中"""
smart_tools: list[AICallableTag] | None = None
"""智能模式函数工具集"""
def to_dict(self, **kwargs):
return model_dump(self, **kwargs)

View File

@@ -0,0 +1,731 @@
# Zhenxun LLM 服务模块
## 📑 目录
- [📖 概述](#-概述)
- [🌟 主要特性](#-主要特性)
- [🚀 快速开始](#-快速开始)
- [📚 API 参考](#-api-参考)
- [⚙️ 配置](#-配置)
- [🔧 高级功能](#-高级功能)
- [🏗️ 架构设计](#-架构设计)
- [🔌 支持的提供商](#-支持的提供商)
- [🎯 使用场景](#-使用场景)
- [📊 性能优化](#-性能优化)
- [🛠️ 故障排除](#-故障排除)
- [❓ 常见问题](#-常见问题)
- [📝 示例项目](#-示例项目)
- [🤝 贡献](#-贡献)
- [📄 许可证](#-许可证)
## 📖 概述
Zhenxun LLM 服务模块是一个现代化的AI服务框架,提供统一的接口来访问多个大语言模型提供商。该模块采用模块化设计,支持异步操作、智能重试、Key轮询和负载均衡等高级功能。
### 🌟 主要特性
- **多提供商支持**: OpenAI、Gemini、智谱AI、DeepSeek等
- **统一接口**: 简洁一致的API设计
- **智能Key轮询**: 自动负载均衡和故障转移
- **异步高性能**: 基于asyncio的并发处理
- **模型缓存**: 智能缓存机制提升性能
- **工具调用**: 支持Function Calling
- **嵌入向量**: 文本向量化支持
- **错误处理**: 完善的异常处理和重试机制
- **多模态支持**: 文本、图像、音频、视频处理
- **代码执行**: Gemini代码执行功能
- **搜索增强**: Google搜索集成
## 🚀 快速开始
### 基本使用
```python
from zhenxun.services.llm import chat, code, search, analyze
# 简单聊天
response = await chat("你好,请介绍一下自己")
print(response)
# 代码执行
result = await code("计算斐波那契数列的前10项")
print(result["text"])
print(result["code_executions"])
# 搜索功能
search_result = await search("Python异步编程最佳实践")
print(search_result["text"])
# 多模态分析
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
message = UniMessage([
Text("分析这张图片"),
Image(path="image.jpg")
])
analysis = await analyze(message, model="Gemini/gemini-2.0-flash")
print(analysis)
```
### 使用AI类
```python
from zhenxun.services.llm import AI, AIConfig, CommonOverrides
# 创建AI实例
ai = AI(AIConfig(model="OpenAI/gpt-4"))
# 聊天对话
response = await ai.chat("解释量子计算的基本原理")
# 多模态分析
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
multimodal_msg = UniMessage([
Text("这张图片显示了什么?"),
Image(path="image.jpg")
])
result = await ai.analyze(multimodal_msg)
# 便捷的多模态函数
result = await analyze_with_images(
"分析这张图片",
images="image.jpg",
model="Gemini/gemini-2.0-flash"
)
```
## 📚 API 参考
### 快速函数
#### `chat(message, *, model=None, **kwargs) -> str`
简单聊天对话
**参数:**
- `message`: 消息内容字符串、LLMMessage或内容部分列表
- `model`: 模型名称(可选)
- `**kwargs`: 额外配置参数
#### `code(prompt, *, model=None, timeout=None, **kwargs) -> dict`
代码执行功能
**返回:**
```python
{
"text": "执行结果说明",
"code_executions": [{"code": "...", "output": "..."}],
"success": True
}
```
#### `search(query, *, model=None, instruction="", **kwargs) -> dict`
搜索增强生成
**返回:**
```python
{
"text": "搜索结果和分析",
"grounding_metadata": {...},
"success": True
}
```
#### `analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse`
高级分析功能,支持多模态输入和工具调用
#### `analyze_with_images(text, images, *, instruction="", model=None, **kwargs) -> str`
图片分析便捷函数
#### `analyze_multimodal(text=None, images=None, videos=None, audios=None, *, instruction="", model=None, **kwargs) -> str`
多模态分析便捷函数
#### `embed(texts, *, model=None, task_type="RETRIEVAL_DOCUMENT", **kwargs) -> list[list[float]]`
文本嵌入向量
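一个最小调用示意(假设配置中已注册可用的嵌入模型,返回的向量维度取决于具体模型):
```python
from zhenxun.services.llm import embed

# 批量向量化两段文本,返回"文本数 × 维度"的浮点数列表
vectors = await embed(
    ["Zhenxun 是一个机器人框架", "LLM 服务模块支持文本向量化"],
    task_type="RETRIEVAL_DOCUMENT",
)
print(len(vectors), len(vectors[0]))
```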
### AI类方法
#### `AI.chat(message, *, model=None, **kwargs) -> str`
聊天对话方法,支持简单多模态输入
#### `AI.analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse`
高级分析方法接收UniMessage进行多模态分析和工具调用
### 模型管理
```python
from zhenxun.services.llm import (
get_model_instance,
list_available_models,
set_global_default_model_name,
clear_model_cache
)
# 获取模型实例
model = await get_model_instance("OpenAI/gpt-4o")
# 列出可用模型
models = list_available_models()
# 设置默认模型
set_global_default_model_name("Gemini/gemini-2.0-flash")
# 清理缓存
clear_model_cache()
```
## ⚙️ 配置
### 预设配置
```python
from zhenxun.services.llm import CommonOverrides
# 创意模式
creative_config = CommonOverrides.creative()
# 精确模式
precise_config = CommonOverrides.precise()
# Gemini特殊功能
json_config = CommonOverrides.gemini_json()
thinking_config = CommonOverrides.gemini_thinking()
code_exec_config = CommonOverrides.gemini_code_execution()
grounding_config = CommonOverrides.gemini_grounding()
```
### 自定义配置
```python
from zhenxun.services.llm import LLMGenerationConfig
config = LLMGenerationConfig(
temperature=0.7,
max_tokens=2048,
top_p=0.9,
frequency_penalty=0.1,
presence_penalty=0.1,
stop=["END", "STOP"],
response_mime_type="application/json",
enable_code_execution=True,
enable_grounding=True
)
response = await chat("你的问题", override_config=config)
```
## 🔧 高级功能
### 工具调用 (Function Calling)
```python
from zhenxun.services.llm import LLMTool, get_model_instance
# 定义工具
tools = [
LLMTool(
name="get_weather",
description="获取天气信息",
parameters={
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名称"}
},
"required": ["city"]
}
)
]
# 工具执行器
async def tool_executor(tool_name: str, args: dict) -> str:
if tool_name == "get_weather":
return f"{args['city']}今天晴天25°C"
return "未知工具"
# 使用工具
model = await get_model_instance("OpenAI/gpt-4")
response = await model.generate_response(
messages=[{"role": "user", "content": "北京天气如何?"}],
tools=tools,
tool_executor=tool_executor
)
```
### 多模态处理
```python
from zhenxun.services.llm import create_multimodal_message, analyze_multimodal, analyze_with_images
# 方法1使用便捷函数
result = await analyze_multimodal(
text="分析这些媒体文件",
images="image.jpg",
audios="audio.mp3",
model="Gemini/gemini-2.0-flash"
)
# 方法2使用create_multimodal_message
message = create_multimodal_message(
text="分析这张图片和音频",
images="image.jpg",
audios="audio.mp3"
)
result = await analyze(message)
# 方法3图片分析专用函数
result = await analyze_with_images(
"这张图片显示了什么?",
images=["image1.jpg", "image2.jpg"]
)
```
## 🛠️ 故障排除
### 常见错误
1. **配置错误**: 检查API密钥和模型配置
2. **网络问题**: 检查代理设置和网络连接
3. **模型不可用**: 使用 `list_available_models()` 检查可用模型
4. **超时错误**: 调整timeout参数或使用更快的模型
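排查上述错误时,可以捕获模块统一导出的 `LLMException`(下面是一个最小示意,错误码枚举 `LLMErrorCode` 的具体取值以 `types/exceptions.py` 的实际定义为准):
```python
from zhenxun.services.llm import LLMException, chat

try:
    response = await chat("你好")
except LLMException as e:
    # e.code 为 LLMErrorCode 枚举,可据此区分配置、网络或模型问题
    print(f"LLM 调用失败: {e}")
```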
### 调试技巧
```python
from zhenxun.services.llm import get_cache_stats
from zhenxun.services.log import logger
# 查看缓存状态
stats = get_cache_stats()
print(f"缓存命中率: {stats['hit_rate']}")
# 启用详细日志
logger.setLevel("DEBUG")
```
## ❓ 常见问题
### Q: 如何处理多模态输入?
**A:** 有多种方式处理多模态输入:
```python
# 方法1使用便捷函数
result = await analyze_with_images("分析这张图片", images="image.jpg")
# 方法2使用analyze函数
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
message = UniMessage([Text("分析这张图片"), Image(path="image.jpg")])
result = await analyze(message)
# 方法3使用create_multimodal_message
from zhenxun.services.llm import create_multimodal_message
message = create_multimodal_message(text="分析这张图片", images="image.jpg")
result = await analyze(message)
```
### Q: 如何自定义工具调用?
**A:** 使用analyze函数的tools参数
```python
# 定义工具
tools = [{
"name": "calculator",
"description": "计算数学表达式",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式"}
},
"required": ["expression"]
}
}]
# 使用工具
from nonebot_plugin_alconna.uniseg import UniMessage, Text
message = UniMessage([Text("计算 2+3*4")])
response = await analyze(message, tools=tools, tool_config={"mode": "auto"})
# 如果返回LLMResponse说明有工具调用
if hasattr(response, 'tool_calls'):
for tool_call in response.tool_calls:
print(f"调用工具: {tool_call.function.name}")
print(f"参数: {tool_call.function.arguments}")
```
### Q: 如何确保输出格式?
**A:** 使用结构化输出:
```python
# JSON格式输出
config = CommonOverrides.gemini_json()
# 自定义Schema
schema = {
"type": "object",
"properties": {
"answer": {"type": "string"},
"confidence": {"type": "number"}
}
}
config = CommonOverrides.gemini_structured(schema)
```
## 📝 示例项目
### 完整示例
#### 1. 智能客服机器人
```python
from zhenxun.services.llm import AI, CommonOverrides
from typing import Dict, List
class CustomerService:
def __init__(self):
self.ai = AI()
self.sessions: Dict[str, List[dict]] = {}
async def handle_query(self, user_id: str, query: str) -> str:
# 获取或创建会话历史
if user_id not in self.sessions:
self.sessions[user_id] = []
history = self.sessions[user_id]
# 添加系统提示
if not history:
history.append({
"role": "system",
"content": "你是一个专业的客服助手,请友好、准确地回答用户问题。"
})
# 添加用户问题
history.append({"role": "user", "content": query})
# 生成回复
response = await self.ai.chat(
query,
history=history[-20:], # 保留最近20轮对话
override_config=CommonOverrides.balanced()
)
# 保存回复到历史
history.append({"role": "assistant", "content": response})
return response
```
#### 2. 文档智能问答
```python
from zhenxun.services.llm import embed, analyze
import numpy as np
from typing import List, Tuple
class DocumentQA:
def __init__(self):
self.documents: List[str] = []
self.embeddings: List[List[float]] = []
async def add_document(self, text: str):
"""添加文档到知识库"""
self.documents.append(text)
# 生成嵌入向量
embedding = await embed([text])
self.embeddings.extend(embedding)
async def query(self, question: str, top_k: int = 3) -> str:
"""查询文档并生成答案"""
if not self.documents:
return "知识库为空,请先添加文档。"
# 生成问题的嵌入向量
question_embedding = await embed([question])
# 计算相似度并找到最相关的文档
similarities = []
for doc_embedding in self.embeddings:
similarity = np.dot(question_embedding[0], doc_embedding)
similarities.append(similarity)
# 获取最相关的文档
top_indices = np.argsort(similarities)[-top_k:][::-1]
relevant_docs = [self.documents[i] for i in top_indices]
# 构建上下文
context = "\n\n".join(relevant_docs)
prompt = f"""
基于以下文档内容回答问题:
文档内容:
{context}
问题:{question}
请基于文档内容给出准确的答案,如果文档中没有相关信息,请说明。
"""
result = await analyze(prompt)
return result["text"]
```
#### 3. 代码审查助手
```python
from zhenxun.services.llm import CommonOverrides, analyze, code
import os
class CodeReviewer:
async def review_file(self, file_path: str) -> dict:
"""审查代码文件"""
if not os.path.exists(file_path):
return {"error": "文件不存在"}
with open(file_path, 'r', encoding='utf-8') as f:
code_content = f.read()
prompt = f"""
请审查以下代码,提供详细的反馈:
文件:{file_path}
代码:
```
{code_content}
```
请从以下方面进行审查:
1. 代码质量和可读性
2. 潜在的bug和安全问题
3. 性能优化建议
4. 最佳实践建议
5. 代码风格问题
请以JSON格式返回结果。
"""
result = await analyze(
prompt,
model="DeepSeek/deepseek-coder",
override_config=CommonOverrides.gemini_json()
)
return {
"file": file_path,
"review": result["text"],
"success": True
}
    async def suggest_improvements(self, code_str: str, language: str = "python") -> str:
        """建议代码改进"""
        # 参数不命名为 code,避免遮蔽上面导入的 code() 函数
        prompt = f"""
        请改进以下{language}代码,使其更加高效、可读和符合最佳实践:
        原代码:
        ```{language}
        {code_str}
        ```
        请提供改进后的代码和说明。
        """
        result = await code(prompt, model="DeepSeek/deepseek-coder")
        return result["text"]
```
## 🏗️ 架构设计
### 模块结构
```
zhenxun/services/llm/
├── __init__.py # 包入口导入和暴露公共API
├── api.py # 高级API接口AI类、便捷函数
├── core.py # 核心基础设施HTTP客户端、重试逻辑、KeyStore
├── service.py # LLM模型实现类
├── utils.py # 工具和转换函数
├── manager.py # 模型管理和缓存
├── adapters/ # 适配器模块
│ ├── __init__.py # 适配器包入口
│ ├── base.py # 基础适配器
│ ├── factory.py # 适配器工厂
│ ├── openai.py # OpenAI适配器
│ ├── gemini.py # Gemini适配器
│ └── zhipu.py # 智谱AI适配器
├── config/ # 配置模块
│ ├── __init__.py # 配置包入口
│ ├── generation.py # 生成配置
│ ├── presets.py # 预设配置
│ └── providers.py # 提供商配置
└── types/ # 类型定义
├── __init__.py # 类型包入口
├── content.py # 内容类型
├── enums.py # 枚举定义
├── exceptions.py # 异常定义
└── models.py # 数据模型
```
### 模块职责
- **`__init__.py`**: 纯粹的包入口,只负责导入和暴露公共API
- **`api.py`**: 高级API接口,包含AI类和所有便捷函数
- **`core.py`**: 核心基础设施,包含HTTP客户端管理、重试逻辑和KeyStore
- **`service.py`**: LLM模型实现类,专注于模型逻辑
- **`utils.py`**: 工具和转换函数,如多模态消息处理
- **`manager.py`**: 模型管理和缓存机制
- **`adapters/`**: 各大提供商的适配器模块,负责与不同API的交互
  - `base.py`: 定义适配器的基础接口
  - `factory.py`: 适配器工厂,用于动态加载和实例化适配器(注册自定义适配器的示意见本节末尾)
- `openai.py`: OpenAI API适配器
- `gemini.py`: Google Gemini API适配器
- `zhipu.py`: 智谱AI API适配器
- **`config/`**: 配置管理模块
- `generation.py`: 生成配置和预设
- `presets.py`: 预设配置
- `providers.py`: 提供商配置
- **`types/`**: 类型定义模块
- `content.py`: 内容类型定义
- `enums.py`: 枚举定义
- `exceptions.py`: 异常定义
- `models.py`: 数据模型定义
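基于这一工厂机制,接入新的 OpenAI 兼容服务商时只需注册一个自定义适配器。下面是一个最小示意(`my_provider` 与端点路径均为假设值,实际接入通常还需要相应的提供商配置):
```python
from zhenxun.services.llm.adapters import OpenAICompatAdapter, register_adapter


class MyProviderAdapter(OpenAICompatAdapter):
    """假设的自定义服务商适配器,复用 OpenAI 兼容的请求构建与响应解析逻辑"""

    @property
    def api_type(self) -> str:
        return "my_provider"

    @property
    def supported_api_types(self) -> list[str]:
        return ["my_provider"]

    def get_chat_endpoint(self) -> str:
        return "/v1/chat/completions"

    def get_embedding_endpoint(self) -> str:
        return "/v1/embeddings"


# 注册后,工厂即可按 api_type 路由到该适配器
register_adapter(MyProviderAdapter())
```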
## 🔌 支持的提供商
### OpenAI 兼容
- **OpenAI**: GPT-4o, GPT-3.5-turbo等
- **DeepSeek**: deepseek-chat, deepseek-reasoner等
- **其他OpenAI兼容API**: 支持自定义端点
```python
# OpenAI
await chat("Hello", model="OpenAI/gpt-4o")
# DeepSeek
await chat("写代码", model="DeepSeek/deepseek-reasoner")
```
### Google Gemini
- **Gemini Pro**: gemini-2.5-flash-preview-05-20、gemini-2.0-flash等
- **特殊功能**: 代码执行、搜索增强、思考模式
```python
# 基础使用
await chat("你好", model="Gemini/gemini-2.0-flash")
# 代码执行
await code("计算质数", model="Gemini/gemini-2.0-flash")
# 搜索增强
await search("最新AI发展", model="Gemini/gemini-2.5-flash-preview-05-20")
```
### 智谱AI
- **GLM系列**: glm-4, glm-4v等
- **支持功能**: 文本生成、多模态理解
```python
await chat("介绍北京", model="Zhipu/glm-4")
```
## 🎯 使用场景
### 1. 聊天机器人
```python
from zhenxun.services.llm import AI, CommonOverrides
class ChatBot:
def __init__(self):
self.ai = AI()
self.history = []
async def chat(self, user_input: str) -> str:
# 添加历史记录
self.history.append({"role": "user", "content": user_input})
# 生成回复
response = await self.ai.chat(
user_input,
history=self.history[-10:], # 保留最近10轮对话
override_config=CommonOverrides.balanced()
)
self.history.append({"role": "assistant", "content": response})
return response
```
### 2. 代码助手
```python
async def code_assistant(task: str) -> dict:
"""代码生成和执行助手"""
result = await code(
f"请帮我{task},并执行代码验证结果",
model="Gemini/gemini-2.0-flash",
timeout=60
)
return {
"explanation": result["text"],
"code_blocks": result["code_executions"],
"success": result["success"]
}
# 使用示例
result = await code_assistant("实现快速排序算法")
```
### 3. 文档分析
```python
from zhenxun.services.llm import analyze_with_images
async def analyze_document(image_path: str, question: str) -> str:
"""分析文档图片并回答问题"""
result = await analyze_with_images(
f"请分析这个文档并回答:{question}",
images=image_path,
model="Gemini/gemini-2.0-flash"
)
return result
```
### 4. 智能搜索
```python
async def smart_search(query: str) -> dict:
"""智能搜索和总结"""
result = await search(
query,
model="Gemini/gemini-2.0-flash",
instruction="请提供准确、最新的信息,并注明信息来源"
)
return {
"summary": result["text"],
"sources": result.get("grounding_metadata", {}),
"confidence": result.get("confidence_score", 0.0)
}
```
## 🔧 配置管理
### 动态配置
```python
from zhenxun.services.llm import list_available_models, set_global_default_model_name
# 运行时更改默认模型
set_global_default_model_name("OpenAI/gpt-4")
# 检查可用模型
models = list_available_models()
for model in models:
print(f"{model.provider}/{model.name} - {model.description}")
```

View File

@@ -0,0 +1,96 @@
"""
LLM 服务模块 - 公共 API 入口
提供统一的 AI 服务调用接口、核心类型定义和模型管理功能
"""
from .api import (
AI,
AIConfig,
TaskType,
analyze,
analyze_multimodal,
analyze_with_images,
chat,
code,
embed,
search,
search_multimodal,
)
from .config import (
CommonOverrides,
LLMGenerationConfig,
register_llm_configs,
)
register_llm_configs()
from .api import ModelName
from .manager import (
clear_model_cache,
get_cache_stats,
get_global_default_model_name,
get_model_instance,
list_available_models,
list_embedding_models,
list_model_identifiers,
set_global_default_model_name,
)
from .types import (
EmbeddingTaskType,
LLMContentPart,
LLMErrorCode,
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelDetail,
ModelInfo,
ModelProvider,
ResponseFormat,
ToolCategory,
ToolMetadata,
UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
__all__ = [
"AI",
"AIConfig",
"CommonOverrides",
"EmbeddingTaskType",
"LLMContentPart",
"LLMErrorCode",
"LLMException",
"LLMGenerationConfig",
"LLMMessage",
"LLMResponse",
"LLMTool",
"ModelDetail",
"ModelInfo",
"ModelName",
"ModelProvider",
"ResponseFormat",
"TaskType",
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"analyze",
"analyze_multimodal",
"analyze_with_images",
"chat",
"clear_model_cache",
"code",
"create_multimodal_message",
"embed",
"get_cache_stats",
"get_global_default_model_name",
"get_model_instance",
"list_available_models",
"list_embedding_models",
"list_model_identifiers",
"register_llm_configs",
"search",
"search_multimodal",
"set_global_default_model_name",
"unimsg_to_llm_parts",
]

View File

@@ -0,0 +1,26 @@
"""
LLM 适配器模块
提供不同LLM服务商的API适配器,实现统一接口调用方式
"""
from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
LLMAdapterFactory.initialize()
__all__ = [
"BaseAdapter",
"GeminiAdapter",
"LLMAdapterFactory",
"OpenAIAdapter",
"OpenAICompatAdapter",
"RequestData",
"ResponseData",
"ZhipuAdapter",
"get_adapter_for_api_type",
"register_adapter",
]

View File

@@ -0,0 +1,508 @@
"""
LLM 适配器基类和通用数据结构
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from zhenxun.services.log import logger
from ..types.exceptions import LLMErrorCode, LLMException
from ..types.models import LLMToolCall
if TYPE_CHECKING:
from ..config.generation import LLMGenerationConfig
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
class RequestData(BaseModel):
"""请求数据封装"""
url: str
headers: dict[str, str]
body: dict[str, Any]
class ResponseData(BaseModel):
"""响应数据封装 - 支持所有高级功能"""
text: str
usage_info: dict[str, Any] | None = None
raw_response: dict[str, Any] | None = None
tool_calls: list[LLMToolCall] | None = None
code_executions: list[Any] | None = None
grounding_metadata: Any | None = None
cache_info: Any | None = None
code_execution_results: list[dict[str, Any]] | None = None
search_results: list[dict[str, Any]] | None = None
function_calls: list[dict[str, Any]] | None = None
safety_ratings: list[dict[str, Any]] | None = None
citations: list[dict[str, Any]] | None = None
class BaseAdapter(ABC):
"""LLM API适配器基类"""
@property
@abstractmethod
def api_type(self) -> str:
"""API类型标识"""
pass
@property
@abstractmethod
def supported_api_types(self) -> list[str]:
"""支持的API类型列表"""
pass
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求
默认实现将简单请求转换为高级请求格式
子类可以重写此方法以提供特定的优化实现
"""
from ..types.content import LLMMessage
messages: list[LLMMessage] = []
if history:
for msg in history:
role = msg.get("role", "user")
content = msg.get("content", "")
messages.append(LLMMessage(role=role, content=content))
messages.append(LLMMessage(role="user", content=prompt))
config = model._generation_config
return self.prepare_advanced_request(
model=model,
api_key=api_key,
messages=messages,
config=config,
tools=None,
tool_choice=None,
)
@abstractmethod
def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
pass
@abstractmethod
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析API响应"""
pass
@abstractmethod
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
task_type: "EmbeddingTaskType | str",
**kwargs: Any,
) -> RequestData:
"""准备文本嵌入请求"""
pass
@abstractmethod
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
"""解析文本嵌入响应"""
pass
def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
"""验证嵌入API响应"""
if "error" in response_json:
error_info = response_json["error"]
msg = (
error_info.get("message", str(error_info))
if isinstance(error_info, dict)
else str(error_info)
)
raise LLMException(
f"嵌入API错误: {msg}",
code=LLMErrorCode.EMBEDDING_FAILED,
details=response_json,
)
def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
"""构建API URL"""
if not model.api_base:
raise LLMException(
f"模型 {model.model_name} 的 api_base 未设置",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
return f"{model.api_base.rstrip('/')}{endpoint}"
def get_base_headers(self, api_key: str) -> dict[str, str]:
"""获取基础请求头"""
from zhenxun.utils.user_agent import get_user_agent
headers = get_user_agent()
headers.update(
{
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
)
return headers
def convert_messages_to_openai_format(
self, messages: list["LLMMessage"]
) -> list[dict[str, Any]]:
"""将LLMMessage转换为OpenAI格式 - 通用方法"""
openai_messages: list[dict[str, Any]] = []
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role}
if msg.role == "tool":
openai_msg["tool_call_id"] = msg.tool_call_id
openai_msg["name"] = msg.name
openai_msg["content"] = msg.content
else:
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
content_parts = []
for part in msg.content:
if part.type == "text":
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image":
content_parts.append(
{
"type": "image_url",
"image_url": {"url": part.image_source},
}
)
openai_msg["content"] = content_parts
if msg.role == "assistant" and msg.tool_calls:
assistant_tool_calls = []
for call in msg.tool_calls:
assistant_tool_calls.append(
{
"id": call.id,
"type": "function",
"function": {
"name": call.function.name,
"arguments": call.function.arguments,
},
}
)
openai_msg["tool_calls"] = assistant_tool_calls
if msg.name and msg.role != "tool":
openai_msg["name"] = msg.name
openai_messages.append(openai_msg)
return openai_messages
def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData:
"""解析OpenAI格式的响应 - 通用方法"""
self.validate_response(response_json)
try:
choices = response_json.get("choices", [])
if not choices:
logger.debug("OpenAI响应中没有choices可能为空回复或流结束。")
return ResponseData(text="", raw_response=response_json)
choice = choices[0]
message = choice.get("message", {})
content = message.get("content", "")
parsed_tool_calls: list[LLMToolCall] | None = None
if message_tool_calls := message.get("tool_calls"):
from ..types.models import LLMToolFunction
parsed_tool_calls = []
for tc_data in message_tool_calls:
try:
if tc_data.get("type") == "function":
parsed_tool_calls.append(
LLMToolCall(
id=tc_data["id"],
function=LLMToolFunction(
name=tc_data["function"]["name"],
arguments=tc_data["function"]["arguments"],
),
)
)
except KeyError as e:
logger.warning(
f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(
f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
)
if not parsed_tool_calls:
parsed_tool_calls = None
final_text = content if content is not None else ""
if not final_text and parsed_tool_calls:
final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"
usage_info = response_json.get("usage")
return ResponseData(
text=final_text,
tool_calls=parsed_tool_calls,
usage_info=usage_info,
raw_response=response_json,
)
except Exception as e:
logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def validate_response(self, response_json: dict[str, Any]) -> None:
"""验证API响应解析不同API的错误结构"""
if "error" in response_json:
error_info = response_json["error"]
if isinstance(error_info, dict):
error_message = error_info.get("message", "未知错误")
error_code = error_info.get("code", "unknown")
error_type = error_info.get("type", "api_error")
error_code_mapping = {
"invalid_api_key": LLMErrorCode.API_KEY_INVALID,
"authentication_failed": LLMErrorCode.API_KEY_INVALID,
"rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
"quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
"model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
"invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
"context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
"max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
}
llm_error_code = error_code_mapping.get(
error_code, LLMErrorCode.API_RESPONSE_INVALID
)
logger.error(
f"API返回错误: {error_message} "
f"(代码: {error_code}, 类型: {error_type})"
)
else:
error_message = str(error_info)
error_code = "unknown"
llm_error_code = LLMErrorCode.API_RESPONSE_INVALID
logger.error(f"API返回错误: {error_message}")
raise LLMException(
f"API请求失败: {error_message}",
code=llm_error_code,
details={"api_error": error_info, "error_code": error_code},
)
if "candidates" in response_json:
candidates = response_json.get("candidates", [])
if candidates:
candidate = candidates[0]
finish_reason = candidate.get("finishReason")
if finish_reason in ["SAFETY", "RECITATION"]:
safety_ratings = candidate.get("safetyRatings", [])
logger.warning(
f"Gemini内容被安全过滤: {finish_reason}, "
f"安全评级: {safety_ratings}"
)
raise LLMException(
f"内容被安全过滤: {finish_reason}",
code=LLMErrorCode.CONTENT_FILTERED,
details={
"finish_reason": finish_reason,
"safety_ratings": safety_ratings,
},
)
if not response_json:
logger.error("API返回空响应")
raise LLMException(
"API返回空响应",
code=LLMErrorCode.API_RESPONSE_INVALID,
details={"response": response_json},
)
def _apply_generation_config(
self,
model: "LLMModel",
config: "LLMGenerationConfig | None" = None,
) -> dict[str, Any]:
"""通用的配置应用逻辑"""
if config is not None:
return config.to_api_params(model.api_type, model.model_name)
if model._generation_config is not None:
return model._generation_config.to_api_params(
model.api_type, model.model_name
)
base_config = {}
if model.temperature is not None:
base_config["temperature"] = model.temperature
if model.max_tokens is not None:
if model.api_type in ["gemini", "gemini_native"]:
base_config["maxOutputTokens"] = model.max_tokens
else:
base_config["max_tokens"] = model.max_tokens
return base_config
def apply_config_override(
self,
model: "LLMModel",
body: dict[str, Any],
config: "LLMGenerationConfig | None" = None,
) -> dict[str, Any]:
"""应用配置覆盖"""
config_params = self._apply_generation_config(model, config)
body.update(config_params)
return body
class OpenAICompatAdapter(BaseAdapter):
"""
处理所有 OpenAI 兼容 API 的通用适配器
消除 OpenAIAdapter 和 ZhipuAdapter 之间的代码重复
"""
@abstractmethod
def get_chat_endpoint(self) -> str:
"""子类必须实现,返回 chat completions 的端点"""
pass
@abstractmethod
def get_embedding_endpoint(self) -> str:
"""子类必须实现,返回 embeddings 的端点"""
pass
def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
openai_messages = self.convert_messages_to_openai_format(messages)
body = {
"model": model.model_name,
"messages": openai_messages,
}
if tools:
body["tools"] = tools
if tool_choice:
body["tool_choice"] = tool_choice
body = self.apply_config_override(model, body, config)
return RequestData(url=url, headers=headers, body=body)
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析响应 - 直接使用基类的 OpenAI 格式解析"""
_ = model, is_advanced # 未使用的参数
return self.parse_openai_response(response_json)
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
task_type: "EmbeddingTaskType | str",
**kwargs: Any,
) -> RequestData:
"""准备嵌入请求 - OpenAI兼容格式"""
_ = task_type # 未使用的参数
url = self.get_api_url(model, self.get_embedding_endpoint())
headers = self.get_base_headers(api_key)
body = {
"model": model.model_name,
"input": texts,
}
# 应用额外的配置参数
if kwargs:
body.update(kwargs)
return RequestData(url=url, headers=headers, body=body)
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
"""解析嵌入响应 - OpenAI兼容格式"""
self.validate_embedding_response(response_json)
try:
data = response_json.get("data", [])
if not data:
raise LLMException(
"嵌入响应中没有数据",
code=LLMErrorCode.EMBEDDING_FAILED,
details=response_json,
)
embeddings = []
for item in data:
if "embedding" in item:
embeddings.append(item["embedding"])
else:
raise LLMException(
"嵌入响应格式错误缺少embedding字段",
code=LLMErrorCode.EMBEDDING_FAILED,
details=item,
)
return embeddings
except Exception as e:
logger.error(f"解析嵌入响应失败: {e}", e=e)
raise LLMException(
f"解析嵌入响应失败: {e}",
code=LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)

View File

@@ -0,0 +1,78 @@
"""
LLM 适配器工厂类
"""
from typing import ClassVar
from ..types.exceptions import LLMErrorCode, LLMException
from .base import BaseAdapter
class LLMAdapterFactory:
"""LLM适配器工厂类"""
_adapters: ClassVar[dict[str, BaseAdapter]] = {}
_api_type_mapping: ClassVar[dict[str, str]] = {}
@classmethod
def initialize(cls) -> None:
"""初始化默认适配器"""
if cls._adapters:
return
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter
cls.register_adapter(OpenAIAdapter())
cls.register_adapter(ZhipuAdapter())
cls.register_adapter(GeminiAdapter())
@classmethod
def register_adapter(cls, adapter: BaseAdapter) -> None:
"""注册适配器"""
adapter_key = adapter.api_type
cls._adapters[adapter_key] = adapter
for api_type in adapter.supported_api_types:
cls._api_type_mapping[api_type] = adapter_key
@classmethod
def get_adapter(cls, api_type: str) -> BaseAdapter:
"""获取适配器"""
cls.initialize()
adapter_key = cls._api_type_mapping.get(api_type)
if not adapter_key:
raise LLMException(
f"不支持的API类型: {api_type}",
code=LLMErrorCode.UNKNOWN_API_TYPE,
details={
"api_type": api_type,
"supported_types": list(cls._api_type_mapping.keys()),
},
)
return cls._adapters[adapter_key]
@classmethod
def list_supported_types(cls) -> list[str]:
"""列出所有支持的API类型"""
cls.initialize()
return list(cls._api_type_mapping.keys())
@classmethod
def list_adapters(cls) -> dict[str, BaseAdapter]:
"""列出所有注册的适配器"""
cls.initialize()
return cls._adapters.copy()
def get_adapter_for_api_type(api_type: str) -> BaseAdapter:
"""获取指定API类型的适配器"""
return LLMAdapterFactory.get_adapter(api_type)
def register_adapter(adapter: BaseAdapter) -> None:
"""注册新的适配器"""
LLMAdapterFactory.register_adapter(adapter)

View File

@@ -0,0 +1,596 @@
"""
Gemini API 适配器
"""
from typing import TYPE_CHECKING, Any
from zhenxun.services.log import logger
from ..types.exceptions import LLMErrorCode, LLMException
from .base import BaseAdapter, RequestData, ResponseData
if TYPE_CHECKING:
from ..config.generation import LLMGenerationConfig
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMToolCall
class GeminiAdapter(BaseAdapter):
"""Gemini API 适配器"""
@property
def api_type(self) -> str:
return "gemini"
@property
def supported_api_types(self) -> list[str]:
return ["gemini"]
def get_base_headers(self, api_key: str) -> dict[str, str]:
"""获取基础请求头"""
from zhenxun.utils.user_agent import get_user_agent
headers = get_user_agent()
headers.update({"Content-Type": "application/json"})
headers["x-goog-api-key"] = api_key
return headers
def prepare_advanced_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
return self._prepare_request(
model, api_key, messages, config, tools, tool_choice
)
def _prepare_request(
self,
model: "LLMModel",
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备 Gemini API 请求 - 支持所有高级功能"""
effective_config = config if config is not None else model._generation_config
endpoint = self._get_gemini_endpoint(model, effective_config)
url = self.get_api_url(model, endpoint)
headers = self.get_base_headers(api_key)
gemini_contents: list[dict[str, Any]] = []
system_instruction_parts: list[dict[str, Any]] | None = None
for msg in messages:
current_parts: list[dict[str, Any]] = []
if msg.role == "system":
if isinstance(msg.content, str):
system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list):
system_instruction_parts = [
part.convert_for_api("gemini") for part in msg.content
]
continue
elif msg.role == "user":
if isinstance(msg.content, str):
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
gemini_contents.append({"role": "user", "parts": current_parts})
elif msg.role == "assistant" or msg.role == "model":
if isinstance(msg.content, str) and msg.content:
current_parts.append({"text": msg.content})
elif isinstance(msg.content, list):
for part_obj in msg.content:
current_parts.append(part_obj.convert_for_api("gemini"))
if msg.tool_calls:
import json
for call in msg.tool_calls:
current_parts.append(
{
"functionCall": {
"name": call.function.name,
"args": json.loads(call.function.arguments),
}
}
)
if current_parts:
gemini_contents.append({"role": "model", "parts": current_parts})
elif msg.role == "tool":
if not msg.name:
raise ValueError("Gemini 工具消息必须包含 'name' 字段(函数名)。")
import json
try:
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
tool_result_obj = json.loads(content_str)
except json.JSONDecodeError:
content_str = (
msg.content
if isinstance(msg.content, str)
else str(msg.content)
)
logger.warning(
f"工具 {msg.name} 的结果不是有效的 JSON: {content_str}. "
f"包装为原始字符串。"
)
tool_result_obj = {"raw_output": content_str}
current_parts.append(
{
"functionResponse": {
"name": msg.name,
"response": tool_result_obj,
}
}
)
gemini_contents.append({"role": "function", "parts": current_parts})
body: dict[str, Any] = {"contents": gemini_contents}
if system_instruction_parts:
body["systemInstruction"] = {"parts": system_instruction_parts}
all_tools_for_request = []
if tools:
for tool_item in tools:
if isinstance(tool_item, dict):
if "name" in tool_item and "description" in tool_item:
all_tools_for_request.append(
{"functionDeclarations": [tool_item]}
)
else:
all_tools_for_request.append(tool_item)
else:
all_tools_for_request.append(tool_item)
if effective_config:
if getattr(effective_config, "enable_grounding", False):
has_explicit_gs_tool = any(
"googleSearch" in tool_item for tool_item in all_tools_for_request
)
if not has_explicit_gs_tool:
all_tools_for_request.append({"googleSearch": {}})
logger.debug("隐式启用 Google Search 工具进行信息来源关联。")
if getattr(effective_config, "enable_code_execution", False):
has_explicit_ce_tool = any(
"codeExecution" in tool_item for tool_item in all_tools_for_request
)
if not has_explicit_ce_tool:
all_tools_for_request.append({"codeExecution": {}})
logger.debug("隐式启用代码执行工具。")
if all_tools_for_request:
gemini_api_tools = self._convert_tools_to_gemini_format(
all_tools_for_request
)
if gemini_api_tools:
body["tools"] = gemini_api_tools
final_tool_choice = tool_choice
if final_tool_choice is None and effective_config:
final_tool_choice = getattr(effective_config, "tool_choice", None)
if final_tool_choice:
if isinstance(final_tool_choice, str):
mode_upper = final_tool_choice.upper()
if mode_upper in ["AUTO", "NONE", "ANY"]:
body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
else:
body["toolConfig"] = self._convert_tool_choice_to_gemini(
final_tool_choice
)
else:
body["toolConfig"] = self._convert_tool_choice_to_gemini(
final_tool_choice
)
final_generation_config = self._build_gemini_generation_config(
model, effective_config
)
if final_generation_config:
body["generationConfig"] = final_generation_config
safety_settings = self._build_safety_settings(effective_config)
if safety_settings:
body["safetySettings"] = safety_settings
return RequestData(url=url, headers=headers, body=body)
def apply_config_override(
self,
model: "LLMModel",
body: dict[str, Any],
config: "LLMGenerationConfig | None" = None,
) -> dict[str, Any]:
"""应用配置覆盖 - Gemini 不需要额外的配置覆盖"""
return body
def _get_gemini_endpoint(
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> str:
"""根据配置选择Gemini API端点"""
if config:
if getattr(config, "enable_code_execution", False):
return f"/v1beta/models/{model.model_name}:generateContent"
if getattr(config, "enable_grounding", False):
return f"/v1beta/models/{model.model_name}:generateContent"
return f"/v1beta/models/{model.model_name}:generateContent"
def _convert_tools_to_gemini_format(
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""转换工具格式为Gemini格式"""
gemini_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
gemini_tool = {
"functionDeclarations": [
{
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
]
}
gemini_tools.append(gemini_tool)
elif tool.get("type") == "code_execution":
gemini_tools.append(
{"codeExecution": {"language": tool.get("language", "python")}}
)
elif tool.get("type") == "google_search":
gemini_tools.append({"googleSearch": {}})
elif "googleSearch" in tool:
gemini_tools.append({"googleSearch": tool["googleSearch"]})
elif "codeExecution" in tool:
gemini_tools.append({"codeExecution": tool["codeExecution"]})
return gemini_tools
def _convert_tool_choice_to_gemini(
self, tool_choice_value: str | dict[str, Any]
) -> dict[str, Any]:
"""转换工具选择策略为Gemini格式"""
if isinstance(tool_choice_value, str):
mode_upper = tool_choice_value.upper()
if mode_upper in ["AUTO", "NONE", "ANY"]:
return {"functionCallingConfig": {"mode": mode_upper}}
else:
logger.warning(
f"不支持的 tool_choice 字符串值: '{tool_choice_value}'"
f"回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
elif isinstance(tool_choice_value, dict):
if (
tool_choice_value.get("type") == "function"
and "function" in tool_choice_value
):
func_name = tool_choice_value["function"].get("name")
if func_name:
return {
"functionCallingConfig": {
"mode": "ANY",
"allowedFunctionNames": [func_name],
}
}
else:
logger.warning(
f"tool_choice dict 中的函数名无效: {tool_choice_value}"
f"回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
elif "functionCallingConfig" in tool_choice_value:
return {
"functionCallingConfig": tool_choice_value["functionCallingConfig"]
}
else:
logger.warning(
f"不支持的 tool_choice dict 值: {tool_choice_value}。回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
logger.warning(
f"tool_choice 的类型无效: {type(tool_choice_value)}。回退到 AUTO。"
)
return {"functionCallingConfig": {"mode": "AUTO"}}
def _build_gemini_generation_config(
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> dict[str, Any]:
"""构建Gemini生成配置"""
generation_config: dict[str, Any] = {}
effective_config = config if config is not None else model._generation_config
if effective_config:
base_api_params = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
)
generation_config.update(base_api_params)
if getattr(effective_config, "response_mime_type", None):
generation_config["responseMimeType"] = (
effective_config.response_mime_type
)
if getattr(effective_config, "response_schema", None):
generation_config["responseSchema"] = effective_config.response_schema
thinking_budget = getattr(effective_config, "thinking_budget", None)
if thinking_budget is not None:
if "thinkingConfig" not in generation_config:
generation_config["thinkingConfig"] = {}
generation_config["thinkingConfig"]["thinkingBudget"] = thinking_budget
if getattr(effective_config, "response_modalities", None):
modalities = effective_config.response_modalities
if isinstance(modalities, list):
generation_config["responseModalities"] = [
m.upper() for m in modalities
]
elif isinstance(modalities, str):
generation_config["responseModalities"] = [modalities.upper()]
generation_config = {
k: v for k, v in generation_config.items() if v is not None
}
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
return generation_config
def _build_safety_settings(
self, config: "LLMGenerationConfig | None" = None
) -> list[dict[str, Any]] | None:
"""构建安全设置"""
if not config:
return None
safety_settings = []
safety_categories = [
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
]
custom_safety_settings = getattr(config, "safety_settings", None)
if custom_safety_settings:
for category, threshold in custom_safety_settings.items():
safety_settings.append({"category": category, "threshold": threshold})
else:
for category in safety_categories:
safety_settings.append(
{"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
)
return safety_settings if safety_settings else None
def parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析API响应"""
return self._parse_response(model, response_json, is_advanced)
def _parse_response(
self,
model: "LLMModel",
response_json: dict[str, Any],
is_advanced: bool = False,
) -> ResponseData:
"""解析 Gemini API 响应"""
_ = is_advanced
self.validate_response(response_json)
try:
candidates = response_json.get("candidates", [])
if not candidates:
logger.debug("Gemini响应中没有candidates。")
return ResponseData(text="", raw_response=response_json)
candidate = candidates[0]
if candidate.get("finishReason") in [
"RECITATION",
"OTHER",
] and not candidate.get("content"):
logger.warning(
f"Gemini candidate finished with reason "
f"'{candidate.get('finishReason')}' and no content."
)
return ResponseData(
text="",
raw_response=response_json,
usage_info=response_json.get("usageMetadata"),
)
content_data = candidate.get("content", {})
parts = content_data.get("parts", [])
text_content = ""
parsed_tool_calls: list["LLMToolCall"] | None = None
for part in parts:
if "text" in part:
text_content += part["text"]
elif "functionCall" in part:
if parsed_tool_calls is None:
parsed_tool_calls = []
fc_data = part["functionCall"]
try:
import json
from ..types.models import LLMToolCall, LLMToolFunction
call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
parsed_tool_calls.append(
LLMToolCall(
id=call_id,
function=LLMToolFunction(
name=fc_data["name"],
arguments=json.dumps(fc_data["args"]),
),
)
)
except KeyError as e:
logger.warning(
f"解析Gemini functionCall时缺少键: {fc_data}, 错误: {e}"
)
except Exception as e:
logger.warning(
f"解析Gemini functionCall时出错: {fc_data}, 错误: {e}"
)
elif "codeExecutionResult" in part:
result = part["codeExecutionResult"]
if result.get("outcome") == "OK":
output = result.get("output", "")
text_content += f"\n[代码执行结果]:\n{output}\n"
else:
text_content += (
f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
)
usage_info = response_json.get("usageMetadata")
grounding_metadata_obj = None
if grounding_data := candidate.get("groundingMetadata"):
try:
from ..types.models import LLMGroundingMetadata
grounding_metadata_obj = LLMGroundingMetadata(**grounding_data)
except Exception as e:
logger.warning(f"无法解析Grounding元数据: {grounding_data}, {e}")
return ResponseData(
text=text_content,
tool_calls=parsed_tool_calls,
usage_info=usage_info,
raw_response=response_json,
grounding_metadata=grounding_metadata_obj,
)
except Exception as e:
logger.error(f"解析 Gemini 响应失败: {e}", e=e)
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def prepare_embedding_request(
self,
model: "LLMModel",
api_key: str,
texts: list[str],
task_type: "EmbeddingTaskType | str",
**kwargs: Any,
) -> RequestData:
"""准备文本嵌入请求"""
api_model_name = model.model_name
if not api_model_name.startswith("models/"):
api_model_name = f"models/{api_model_name}"
url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents")
headers = self.get_base_headers(api_key)
requests_payload = []
for text_content in texts:
request_item: dict[str, Any] = {
"content": {"parts": [{"text": text_content}]},
}
from ..types.enums import EmbeddingTaskType
if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT:
request_item["task_type"] = str(task_type).upper()
if title := kwargs.get("title"):
request_item["title"] = title
if output_dimensionality := kwargs.get("output_dimensionality"):
request_item["output_dimensionality"] = output_dimensionality
requests_payload.append(request_item)
body = {"requests": requests_payload}
return RequestData(url=url, headers=headers, body=body)
def parse_embedding_response(
self, response_json: dict[str, Any]
) -> list[list[float]]:
"""解析文本嵌入响应"""
try:
embeddings_data = response_json["embeddings"]
return [item["values"] for item in embeddings_data]
except KeyError as e:
logger.error(f"解析Gemini嵌入响应时缺少键: {e}. 响应: {response_json}")
raise LLMException(
"Gemini嵌入响应格式错误",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"error": str(e)},
)
except Exception as e:
logger.error(
f"解析Gemini嵌入响应时发生未知错误: {e}. 响应: {response_json}"
)
raise LLMException(
f"解析Gemini嵌入响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
"""验证嵌入响应"""
super().validate_embedding_response(response_json)
if "embeddings" not in response_json or not isinstance(
response_json["embeddings"], list
):
raise LLMException(
"Gemini嵌入响应缺少'embeddings'字段或格式不正确",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details=response_json,
)
for item in response_json["embeddings"]:
if "values" not in item:
raise LLMException(
"Gemini嵌入响应的条目中缺少'values'字段",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details=response_json,
)
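# 用法示意(演示注释,非本次提交内容;假设适配器类名为 GeminiAdapter,
# 该类定义未出现在本片段中):Gemini 的 functionCallingConfig 仅支持
# AUTO / NONE / ANY,OpenAI 风格的工具选择会被映射为 ANY + allowedFunctionNames:
#
#     adapter = GeminiAdapter()
#     adapter._convert_tool_choice_to_gemini("auto")
#     # -> {"functionCallingConfig": {"mode": "AUTO"}}
#     adapter._convert_tool_choice_to_gemini(
#         {"type": "function", "function": {"name": "get_weather"}}
#     )
#     # -> {"functionCallingConfig": {"mode": "ANY",
#     #        "allowedFunctionNames": ["get_weather"]}}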

View File

@ -0,0 +1,57 @@
"""
OpenAI API 适配器
支持 OpenAI、DeepSeek 和其他 OpenAI 兼容的 API 服务
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
if TYPE_CHECKING:
from ..service import LLMModel
class OpenAIAdapter(OpenAICompatAdapter):
"""OpenAI兼容API适配器"""
@property
def api_type(self) -> str:
return "openai"
@property
def supported_api_types(self) -> list[str]:
return ["openai", "deepseek", "general_openai_compat"]
def get_chat_endpoint(self) -> str:
"""返回聊天完成端点"""
return "/v1/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回嵌入端点"""
return "/v1/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - OpenAI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

View File

@ -0,0 +1,57 @@
"""
智谱 AI API 适配器
支持智谱 AI GLM 系列模型,使用 OpenAI 兼容的接口格式
"""
from typing import TYPE_CHECKING
from .base import OpenAICompatAdapter, RequestData
if TYPE_CHECKING:
from ..service import LLMModel
class ZhipuAdapter(OpenAICompatAdapter):
"""智谱AI适配器 - 使用智谱AI专用的OpenAI兼容接口"""
@property
def api_type(self) -> str:
return "zhipu"
@property
def supported_api_types(self) -> list[str]:
return ["zhipu"]
def get_chat_endpoint(self) -> str:
"""返回智谱AI聊天完成端点"""
return "/api/paas/v4/chat/completions"
def get_embedding_endpoint(self) -> str:
"""返回智谱AI嵌入端点"""
return "/v4/embeddings"
def prepare_simple_request(
self,
model: "LLMModel",
api_key: str,
prompt: str,
history: list[dict[str, str]] | None = None,
) -> RequestData:
"""准备简单文本生成请求 - 智谱AI优化实现"""
url = self.get_api_url(model, self.get_chat_endpoint())
headers = self.get_base_headers(api_key)
messages = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": prompt})
body = {
"model": model.model_name,
"messages": messages,
}
body = self.apply_config_override(model, body)
return RequestData(url=url, headers=headers, body=body)

530 zhenxun/services/llm/api.py Normal file
View File

@ -0,0 +1,530 @@
"""
LLM 服务的高级 API 接口
"""
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from nonebot_plugin_alconna.uniseg import UniMessage
from zhenxun.services.log import logger
from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .types import (
EmbeddingTaskType,
LLMContentPart,
LLMErrorCode,
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelName,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
class TaskType(Enum):
"""任务类型枚举"""
CHAT = "chat"
CODE = "code"
SEARCH = "search"
ANALYSIS = "analysis"
GENERATION = "generation"
MULTIMODAL = "multimodal"
@dataclass
class AIConfig:
"""AI配置类 - 简化版本"""
model: ModelName = None
default_embedding_model: ModelName = None
temperature: float | None = None
max_tokens: int | None = None
enable_cache: bool = False
enable_code: bool = False
enable_search: bool = False
timeout: int | None = None
enable_gemini_json_mode: bool = False
enable_gemini_thinking: bool = False
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
def __post_init__(self):
"""初始化后从配置中读取默认值"""
ai_config = get_ai_config()
if self.model is None:
self.model = ai_config.get("default_model_name")
if self.timeout is None:
self.timeout = ai_config.get("timeout", 180)
class AI:
"""统一的AI服务类 - 平衡设计版本
    提供三层API:
    1. 简单方法:ai.chat(), ai.code(), ai.search()
    2. 标准方法:ai.analyze() 支持复杂参数
    3. 高级方法:通过 get_model_instance() 直接访问
"""
def __init__(
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
):
"""
初始化AI服务
Args:
config: AI 配置.
history: 可选的初始对话历史.
"""
self.config = config or AIConfig()
self.history = history or []
def clear_history(self):
"""清空当前会话的历史记录"""
self.history = []
logger.info("AI session history cleared.")
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
**kwargs: Any,
) -> str:
"""
进行一次聊天对话
此方法会自动使用和更新会话内的历史记录
"""
current_message: LLMMessage
if isinstance(message, str):
current_message = LLMMessage.user(message)
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
current_message = LLMMessage.user(message)
elif isinstance(message, LLMMessage):
current_message = message
else:
raise LLMException(
f"AI.chat 不支持的消息类型: {type(message)}. "
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
code=LLMErrorCode.API_REQUEST_FAILED,
)
final_messages = [*self.history, current_message]
response = await self._execute_generation(
final_messages, model, "聊天失败", kwargs
)
self.history.append(current_message)
self.history.append(LLMMessage.assistant_text_response(response.text))
return response.text
async def code(
self,
prompt: str,
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
messages = [LLMMessage.user(prompt)]
response = await self._execute_generation(
messages, resolved_model, "代码执行失败", kwargs, base_config=config
)
return {
"text": response.text,
"code_executions": response.code_executions or [],
"success": True,
}
async def search(
self,
query: str | UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索 - 支持多模态输入"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
if isinstance(query, str):
messages = [LLMMessage.user(query)]
elif isinstance(query, UniMessage):
content_parts = await unimsg_to_llm_parts(query)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
messages = final_messages
else:
raise LLMException(
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(
messages, resolved_model, "信息搜索失败", kwargs, base_config=config
)
result = {
"text": response.text,
"sources": [],
"queries": [],
"success": True,
}
if response.grounding_metadata:
result["sources"] = response.grounding_metadata.grounding_attributions or []
result["queries"] = response.grounding_metadata.web_search_queries or []
return result
async def analyze(
self,
message: UniMessage,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""
        内容分析 - 接收 UniMessage 对象,进行多模态分析和工具调用。
        这是处理复杂交互的主要方法。
"""
content_parts = await unimsg_to_llm_parts(message)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools = None
if tools:
llm_tools = []
for tool_dict in tools:
if isinstance(tool_dict, dict):
if "name" in tool_dict and "description" in tool_dict:
llm_tool = LLMTool(
type="function",
function={
"name": tool_dict["name"],
"description": tool_dict["description"],
"parameters": tool_dict.get("parameters", {}),
},
)
llm_tools.append(llm_tool)
else:
llm_tools.append(LLMTool(**tool_dict))
else:
llm_tools.append(tool_dict)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode == "auto":
tool_choice = "auto"
elif mode == "any":
tool_choice = "any"
elif mode == "none":
tool_choice = "none"
response = await self._execute_generation(
final_messages,
model,
"内容分析失败",
kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
if response.tool_calls:
return response
return response.text
async def _execute_generation(
self,
messages: list[LLMMessage],
model_name: ModelName,
error_message: str,
config_overrides: dict[str, Any],
llm_tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法,封装重复的模型获取、配置合并和异常处理逻辑"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
)
final_config_dict = self._merge_config(
config_overrides, base_config=base_config
)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages, tools=llm_tools, tool_choice=tool_choice
)
except LLMException:
raise
except Exception as e:
logger.error(f"{error_message}: {e}", e=e)
raise LLMException(f"{error_message}: {e}", cause=e)
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
if model_name:
return model_name
default_model = get_global_default_model_name()
if default_model:
return default_model
raise LLMException(
"未指定模型名称且未设置全局默认模型",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
def _merge_config(
self,
user_config: dict[str, Any],
base_config: LLMGenerationConfig | None = None,
) -> dict[str, Any]:
"""合并配置"""
final_config = {}
if base_config:
final_config.update(base_config.to_dict())
if self.config.temperature is not None:
final_config["temperature"] = self.config.temperature
if self.config.max_tokens is not None:
final_config["max_tokens"] = self.config.max_tokens
if self.config.enable_cache:
final_config["enable_caching"] = True
if self.config.enable_code:
final_config["enable_code_execution"] = True
if self.config.enable_search:
final_config["enable_grounding"] = True
if self.config.enable_gemini_json_mode:
final_config["response_mime_type"] = "application/json"
if self.config.enable_gemini_thinking:
final_config["thinking_budget"] = 0.8
if self.config.enable_gemini_safe_mode:
final_config["safety_settings"] = (
CommonOverrides.gemini_safe().safety_settings
)
if self.config.enable_gemini_multimodal:
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
if self.config.enable_gemini_grounding:
final_config["enable_grounding"] = True
final_config.update(user_config)
return final_config
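    # 合并优先级示意(说明注释,非源码):base_config 为底座,AIConfig 的
    # 开关次之,调用方传入的 user_config(即各方法的 **kwargs)最后覆盖。
    # 例如 AIConfig(temperature=0.3) 搭配 ai.chat(..., temperature=0.9) 时,
    # 最终生效的 temperature 为 0.9。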
async def embed(
self,
texts: list[str] | str,
*,
model: ModelName = None,
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
try:
resolved_model_str = (
model or self.config.default_embedding_model or self.config.model
)
if not resolved_model_str:
raise LLMException(
"使用 embed 功能时必须指定嵌入模型名称,"
"或在 AIConfig 中配置 default_embedding_model。",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
resolved_model_str = self._resolve_model_name(resolved_model_str)
async with await get_model_instance(
resolved_model_str,
override_config=None,
) as embedding_model_instance:
return await embedding_model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
async def chat(
message: str | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
**kwargs: Any,
) -> str:
"""聊天对话便捷函数"""
ai = AI()
return await ai.chat(message, model=model, **kwargs)
async def code(
prompt: str,
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""代码执行便捷函数"""
ai = AI()
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
async def search(
query: str | UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
"""信息搜索便捷函数"""
ai = AI()
return await ai.search(query, model=model, instruction=instruction, **kwargs)
async def analyze(
message: UniMessage,
*,
instruction: str = "",
model: ModelName = None,
tools: list[dict[str, Any]] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""内容分析便捷函数"""
ai = AI()
return await ai.analyze(
message,
instruction=instruction,
model=model,
tools=tools,
tool_config=tool_config,
**kwargs,
)
async def analyze_with_images(
text: str,
images: list[str | Path | bytes] | str | Path | bytes,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""图片分析便捷函数"""
message = create_multimodal_message(text=text, images=images)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def analyze_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""多模态分析便捷函数"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def search_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> dict[str, Any]:
"""多模态搜索便捷函数"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
ai = AI()
return await ai.search(message, model=model, instruction=instruction, **kwargs)
async def embed(
texts: list[str] | str,
*,
model: ModelName = None,
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""文本嵌入便捷函数"""
ai = AI()
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
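# 用法示意(演示注释,非本次提交内容;假设配置中已注册 Gemini/gemini-2.0-flash):
#
#     from zhenxun.services.llm.api import AI, AIConfig, chat
#
#     async def demo():
#         # 第一层:便捷函数,一次性调用
#         reply = await chat("你好")
#         # 第二层:带会话历史的 AI 实例
#         ai = AI(AIConfig(model="Gemini/gemini-2.0-flash", temperature=0.7))
#         first = await ai.chat("介绍一下你自己")
#         followup = await ai.chat("换一种说法")  # 自动携带并更新历史
#         ai.clear_history()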

View File

@ -0,0 +1,35 @@
"""
LLM 配置模块
提供生成配置、预设配置和配置验证功能
"""
from .generation import (
LLMGenerationConfig,
ModelConfigOverride,
apply_api_specific_mappings,
create_generation_config_from_kwargs,
validate_override_params,
)
from .presets import CommonOverrides
from .providers import (
LLMConfig,
get_llm_config,
register_llm_configs,
set_default_model,
validate_llm_config,
)
__all__ = [
"CommonOverrides",
"LLMConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_llm_config",
"register_llm_configs",
"set_default_model",
"validate_llm_config",
"validate_override_params",
]

View File

@ -0,0 +1,260 @@
"""
LLM 生成配置相关类和函数
"""
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from ..types.enums import ResponseFormat
from ..types.exceptions import LLMErrorCode, LLMException
class ModelConfigOverride(BaseModel):
"""模型配置覆盖参数"""
temperature: float | None = Field(
default=None, ge=0.0, le=2.0, description="生成温度"
)
max_tokens: int | None = Field(default=None, gt=0, description="最大输出token数")
top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="核采样参数")
top_k: int | None = Field(default=None, gt=0, description="Top-K采样参数")
frequency_penalty: float | None = Field(
default=None, ge=-2.0, le=2.0, description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None, ge=-2.0, le=2.0, description="存在惩罚"
)
repetition_penalty: float | None = Field(
default=None, ge=0.0, le=2.0, description="重复惩罚"
)
stop: list[str] | str | None = Field(default=None, description="停止序列")
response_format: ResponseFormat | dict[str, Any] | None = Field(
default=None, description="期望的响应格式"
)
response_mime_type: str | None = Field(
default=None, description="响应MIME类型Gemini专用"
)
response_schema: dict[str, Any] | None = Field(
default=None, description="JSON响应模式"
)
thinking_budget: float | None = Field(
default=None, ge=0.0, le=1.0, description="思考预算"
)
safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
response_modalities: list[str] | None = Field(
default=None, description="响应模态类型"
)
enable_code_execution: bool | None = Field(
default=None, description="是否启用代码执行"
)
enable_grounding: bool | None = Field(
default=None, description="是否启用信息来源关联"
)
enable_caching: bool | None = Field(default=None, description="是否启用响应缓存")
custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")
    def to_dict(self) -> dict[str, Any]:
        """转换为字典,排除None值"""
        if hasattr(self, "model_dump"):
            model_data = self.model_dump()  # Pydantic V2
        else:
            model_data = self.dict()  # Pydantic V1 没有 model_dump,退回 dict()
        result = {}
        for key, value in model_data.items():
            if value is not None:
                if key == "custom_params" and isinstance(value, dict):
                    result.update(value)
                else:
                    result[key] = value
        return result
def merge_with_base_config(
self,
base_temperature: float | None = None,
base_max_tokens: int | None = None,
) -> dict[str, Any]:
"""与基础配置合并,覆盖参数优先"""
merged = {}
if base_temperature is not None:
merged["temperature"] = base_temperature
if base_max_tokens is not None:
merged["max_tokens"] = base_max_tokens
override_dict = self.to_dict()
merged.update(override_dict)
return merged
class LLMGenerationConfig(ModelConfigOverride):
"""LLM 生成配置,继承模型配置覆盖参数"""
def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]:
"""转换为API参数支持不同API类型的参数名映射"""
_ = model_name
params = {}
if self.temperature is not None:
params["temperature"] = self.temperature
if self.max_tokens is not None:
if api_type in ["gemini", "gemini_native"]:
params["maxOutputTokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
if api_type in ["gemini", "gemini_native"]:
if self.top_k is not None:
params["topK"] = self.top_k
if self.top_p is not None:
params["topP"] = self.top_p
else:
if self.top_k is not None:
params["top_k"] = self.top_k
if self.top_p is not None:
params["top_p"] = self.top_p
if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
if self.frequency_penalty is not None:
params["frequency_penalty"] = self.frequency_penalty
if self.presence_penalty is not None:
params["presence_penalty"] = self.presence_penalty
if self.repetition_penalty is not None:
if api_type == "openai":
logger.warning("OpenAI官方API不支持repetition_penalty参数已忽略")
else:
params["repetition_penalty"] = self.repetition_penalty
if self.response_format is not None:
if isinstance(self.response_format, dict):
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = self.response_format
logger.debug(
f"{api_type} 使用自定义 response_format: "
f"{self.response_format}"
)
elif self.response_format == ResponseFormat.JSON:
if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
params["response_format"] = {"type": "json_object"}
logger.debug(f"{api_type} 启用 JSON 对象输出模式")
elif api_type in ["gemini", "gemini_native"]:
params["responseMimeType"] = "application/json"
if self.response_schema:
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if api_type in ["gemini", "gemini_native"]:
if (
self.response_format != ResponseFormat.JSON
and self.response_mime_type is not None
):
params["responseMimeType"] = self.response_mime_type
logger.debug(
f"使用显式设置的 responseMimeType: {self.response_mime_type}"
)
if self.response_schema is not None and "responseSchema" not in params:
params["responseSchema"] = self.response_schema
if self.thinking_budget is not None:
params["thinkingBudget"] = self.thinking_budget
if self.safety_settings is not None:
params["safetySettings"] = self.safety_settings
if self.response_modalities is not None:
params["responseModalities"] = self.response_modalities
if self.custom_params:
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
params.update(custom_mapped)
logger.debug(f"{api_type}转换配置参数: {len(params)}个参数")
return params
def validate_override_params(
override_config: dict[str, Any] | LLMGenerationConfig | None,
) -> LLMGenerationConfig:
"""验证和标准化覆盖参数"""
if override_config is None:
return LLMGenerationConfig()
if isinstance(override_config, dict):
try:
filtered_config = {
k: v for k, v in override_config.items() if v is not None
}
return LLMGenerationConfig(**filtered_config)
except Exception as e:
logger.warning(f"覆盖配置参数验证失败: {e}")
raise LLMException(
f"无效的覆盖配置参数: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
return override_config
def apply_api_specific_mappings(
params: dict[str, Any], api_type: str
) -> dict[str, Any]:
"""应用API特定的参数映射"""
mapped_params = params.copy()
if api_type in ["gemini", "gemini_native"]:
if "max_tokens" in mapped_params:
mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
if "top_k" in mapped_params:
mapped_params["topK"] = mapped_params.pop("top_k")
if "top_p" in mapped_params:
mapped_params["topP"] = mapped_params.pop("top_p")
unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"]
for param in unsupported:
if param in mapped_params:
logger.warning(f"Gemini 原生API不支持参数 '{param}',已忽略")
mapped_params.pop(param)
elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
if "repetition_penalty" in mapped_params and api_type == "openai":
logger.warning("OpenAI官方API不支持repetition_penalty参数已忽略")
mapped_params.pop("repetition_penalty")
if "stop" in mapped_params:
stop_value = mapped_params["stop"]
if isinstance(stop_value, str):
mapped_params["stop"] = [stop_value]
return mapped_params
def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
"""从关键字参数创建生成配置"""
    model_fields = getattr(LLMGenerationConfig, "model_fields", None) or getattr(
        LLMGenerationConfig, "__fields__", {}  # Pydantic V1 使用 __fields__
    )
    known_fields = set(model_fields.keys())
known_params = {}
custom_params = {}
for key, value in kwargs.items():
if key in known_fields:
known_params[key] = value
else:
custom_params[key] = value
if custom_params:
known_params["custom_params"] = custom_params
return LLMGenerationConfig(**known_params)
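# 用法示意(演示注释,非源码;vendor_flag 为虚构参数):未在字段中声明的
# 参数会进入 custom_params,并在 to_api_params 时由
# apply_api_specific_mappings 透传或改名:
#
#     config = create_generation_config_from_kwargs(
#         temperature=0.2, max_tokens=1024, vendor_flag=True
#     )
#     config.to_api_params(api_type="gemini", model_name="gemini-2.0-flash")
#     # -> {"temperature": 0.2, "maxOutputTokens": 1024, "vendor_flag": True}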

View File

@ -0,0 +1,169 @@
"""
LLM 预设配置
提供常用的配置预设,特别是针对 Gemini 的高级功能
"""
from typing import Any
from .generation import LLMGenerationConfig
class CommonOverrides:
"""常用的配置覆盖预设"""
@staticmethod
def creative() -> LLMGenerationConfig:
"""创意模式:高温度,鼓励创新"""
return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1)
@staticmethod
def precise() -> LLMGenerationConfig:
"""精确模式:低温度,确定性输出"""
return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0)
@staticmethod
def balanced() -> LLMGenerationConfig:
"""平衡模式:中等温度"""
return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0)
@staticmethod
def concise(max_tokens: int = 100) -> LLMGenerationConfig:
"""简洁模式:限制输出长度"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=max_tokens,
stop=["\n\n", "", "", ""],
)
@staticmethod
def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
"""详细模式:鼓励详细输出"""
return LLMGenerationConfig(
temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
)
@staticmethod
def gemini_json() -> LLMGenerationConfig:
"""Gemini JSON模式强制JSON输出"""
return LLMGenerationConfig(
temperature=0.3, response_mime_type="application/json"
)
@staticmethod
def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
"""Gemini 思考模式:使用思考预算"""
return LLMGenerationConfig(temperature=0.7, thinking_budget=budget)
@staticmethod
def gemini_creative() -> LLMGenerationConfig:
"""Gemini 创意模式:高温度创意输出"""
return LLMGenerationConfig(temperature=0.9, top_p=0.95)
@staticmethod
def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig:
"""Gemini 结构化输出自定义JSON模式"""
return LLMGenerationConfig(
temperature=0.3,
response_mime_type="application/json",
response_schema=schema,
)
@staticmethod
def gemini_safe() -> LLMGenerationConfig:
"""Gemini 安全模式:严格安全设置"""
return LLMGenerationConfig(
temperature=0.5,
safety_settings={
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
},
)
@staticmethod
def gemini_multimodal() -> LLMGenerationConfig:
"""Gemini 多模态模式:优化多模态处理"""
return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8)
@staticmethod
def gemini_code_execution() -> LLMGenerationConfig:
"""Gemini 代码执行模式:启用代码执行功能"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=4096,
enable_code_execution=True,
custom_params={"code_execution_timeout": 30},
)
@staticmethod
def gemini_grounding() -> LLMGenerationConfig:
"""Gemini 信息来源关联模式启用Google搜索"""
return LLMGenerationConfig(
temperature=0.5,
max_tokens=4096,
enable_grounding=True,
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
@staticmethod
def gemini_cached() -> LLMGenerationConfig:
"""Gemini 缓存模式:启用响应缓存"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=2048,
enable_caching=True,
)
@staticmethod
def gemini_advanced() -> LLMGenerationConfig:
"""Gemini 高级模式:启用所有高级功能"""
return LLMGenerationConfig(
temperature=0.5,
max_tokens=4096,
enable_code_execution=True,
enable_grounding=True,
enable_caching=True,
custom_params={
"code_execution_timeout": 30,
"grounding_config": {
"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
},
},
)
@staticmethod
def gemini_research() -> LLMGenerationConfig:
"""Gemini 研究模式:思考+搜索+结构化输出"""
return LLMGenerationConfig(
temperature=0.6,
max_tokens=4096,
thinking_budget=0.8,
enable_grounding=True,
response_mime_type="application/json",
custom_params={
"grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
},
)
@staticmethod
def gemini_analysis() -> LLMGenerationConfig:
"""Gemini 分析模式:深度思考+详细输出"""
return LLMGenerationConfig(
temperature=0.4,
max_tokens=6000,
thinking_budget=0.9,
top_p=0.8,
)
@staticmethod
def gemini_fast_response() -> LLMGenerationConfig:
"""Gemini 快速响应模式:低延迟+简洁输出"""
return LLMGenerationConfig(
temperature=0.3,
max_tokens=512,
top_p=0.8,
)
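# 用法示意(演示注释,非源码):预设返回的是普通的 LLMGenerationConfig,
# 取得后仍可按需微调,再交给模型或 to_api_params:
#
#     config = CommonOverrides.gemini_json()
#     config.temperature = 0.1  # 在预设基础上调整
#     params = config.to_api_params(
#         api_type="gemini", model_name="gemini-2.0-flash"
#     )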

View File

@ -0,0 +1,328 @@
"""
LLM 提供商配置管理
负责注册和管理 AI 服务提供商的配置项
"""
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
from ..types.models import ModelDetail, ProviderConfig
AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
class LLMConfig(BaseModel):
"""LLM 服务配置类"""
default_model_name: str | None = Field(
default=None,
description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
)
proxy: str | None = Field(
default=None,
description="LLM服务请求使用的网络代理例如 http://127.0.0.1:7890",
)
    timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)")
max_retries_llm: int = Field(
default=3, description="LLM服务请求失败时的最大重试次数"
)
    retry_delay_llm: int = Field(
        default=2, description="LLM服务请求重试的基础延迟时间(秒)"
    )
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
参数:
name: 提供商名称
返回:
            ProviderConfig | None: 提供商配置,如果未找到则返回 None
"""
for provider in self.providers:
if provider.name == name:
return provider
return None
def get_model_by_provider_and_name(
self, provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""根据提供商名称和模型名称获取配置
参数:
provider_name: 提供商名称
model_name: 模型名称
返回:
tuple[ProviderConfig, ModelDetail] | None: 提供商配置和模型详情的元组
如果未找到则返回 None
"""
provider = self.get_provider_by_name(provider_name)
if not provider:
return None
for model in provider.models:
if model.model_name == model_name:
return provider, model
return None
def list_available_models(self) -> list[dict[str, Any]]:
"""列出所有可用的模型
返回:
list[dict[str, Any]]: 模型信息列表
"""
models = []
for provider in self.providers:
for model in provider.models:
models.append(
{
"provider_name": provider.name,
"model_name": model.model_name,
"full_name": f"{provider.name}/{model.model_name}",
"is_available": model.is_available,
"is_embedding_model": model.is_embedding_model,
"api_type": provider.api_type,
}
)
return models
def validate_model_name(self, provider_model_name: str) -> bool:
"""验证模型名称格式是否正确
参数:
provider_model_name: 格式为 "ProviderName/ModelName" 的字符串
返回:
bool: 是否有效
"""
if not provider_model_name or "/" not in provider_model_name:
return False
parts = provider_model_name.split("/", 1)
if len(parts) != 2:
return False
provider_name, model_name = parts
return (
self.get_model_by_provider_and_name(provider_name, model_name) is not None
)
def get_ai_config():
"""获取 AI 配置组"""
return Config.get(AI_CONFIG_GROUP)
def get_default_providers() -> list[dict[str, Any]]:
"""获取默认的提供商配置
返回:
list[dict[str, Any]]: 默认提供商配置列表
"""
return [
{
"name": "DeepSeek",
"api_key": "sk-******",
"api_base": "https://api.deepseek.com",
"api_type": "openai",
"models": [
{
"model_name": "deepseek-chat",
"max_tokens": 4096,
"temperature": 0.7,
},
{
"model_name": "deepseek-reasoner",
},
],
},
{
"name": "GLM",
"api_key": "",
"api_base": "https://open.bigmodel.cn",
"api_type": "zhipu",
"models": [
{"model_name": "glm-4-flash"},
{"model_name": "glm-4-plus"},
],
},
{
"name": "Gemini",
"api_key": [
"AIzaSy*****************************",
"AIzaSy*****************************",
"AIzaSy*****************************",
],
"api_base": "https://generativelanguage.googleapis.com",
"api_type": "gemini",
"models": [
{"model_name": "gemini-2.0-flash"},
{"model_name": "gemini-2.5-flash-preview-05-20"},
],
},
]
def register_llm_configs():
"""注册 LLM 服务的配置项"""
logger.info("注册 LLM 服务的配置项")
llm_config = LLMConfig()
Config.add_plugin_config(
AI_CONFIG_GROUP,
"default_model_name",
llm_config.default_model_name,
help="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
type=str,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"proxy",
llm_config.proxy,
help="LLM服务请求使用的网络代理例如 http://127.0.0.1:7890",
type=str,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"timeout",
llm_config.timeout,
help="LLM服务API请求超时时间",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"max_retries_llm",
llm_config.max_retries_llm,
help="LLM服务请求失败时的最大重试次数",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
"retry_delay_llm",
llm_config.retry_delay_llm,
help="LLM服务请求重试的基础延迟时间",
type=int,
)
Config.add_plugin_config(
AI_CONFIG_GROUP,
PROVIDERS_CONFIG_KEY,
get_default_providers(),
help="配置多个 AI 服务提供商及其模型信息",
default_value=[],
type=list[ProviderConfig],
)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例
返回:
LLMConfig: LLM 配置实例
"""
ai_config = get_ai_config()
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
"timeout": ai_config.get("timeout", 180),
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
"providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
}
return LLMConfig(**config_data)
def validate_llm_config() -> tuple[bool, list[str]]:
"""验证 LLM 配置的有效性
返回:
tuple[bool, list[str]]: (是否有效, 错误信息列表)
"""
errors = []
try:
llm_config = get_llm_config()
if llm_config.timeout <= 0:
errors.append("timeout 必须大于 0")
if llm_config.max_retries_llm < 0:
errors.append("max_retries_llm 不能小于 0")
if llm_config.retry_delay_llm <= 0:
errors.append("retry_delay_llm 必须大于 0")
if not llm_config.providers:
errors.append("至少需要配置一个 AI 服务提供商")
else:
provider_names = set()
for provider in llm_config.providers:
if provider.name in provider_names:
errors.append(f"提供商名称重复: {provider.name}")
provider_names.add(provider.name)
if not provider.api_key:
errors.append(f"提供商 {provider.name} 缺少 API Key")
if not provider.models:
errors.append(f"提供商 {provider.name} 没有配置任何模型")
else:
model_names = set()
for model in provider.models:
if model.model_name in model_names:
errors.append(
f"提供商 {provider.name} 中模型名称重复: "
f"{model.model_name}"
)
model_names.add(model.model_name)
if llm_config.default_model_name:
if not llm_config.validate_model_name(llm_config.default_model_name):
errors.append(
f"默认模型 {llm_config.default_model_name} 在配置中不存在"
)
except Exception as e:
errors.append(f"配置解析失败: {e!s}")
return len(errors) == 0, errors
def set_default_model(provider_model_name: str | None) -> bool:
"""设置默认模型
参数:
        provider_model_name: 模型名称,格式为 "ProviderName/ModelName";None 表示清除
返回:
bool: 是否设置成功
"""
if provider_model_name:
llm_config = get_llm_config()
if not llm_config.validate_model_name(provider_model_name):
logger.error(f"模型 {provider_model_name} 在配置中不存在")
return False
Config.set_config(
AI_CONFIG_GROUP, "default_model_name", provider_model_name, auto_save=True
)
if provider_model_name:
logger.info(f"默认模型已设置为: {provider_model_name}")
else:
logger.info("默认模型已清除")
return True
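# 用法示意(演示注释,非源码):插件加载时先注册配置项,再做一次整体校验:
#
#     register_llm_configs()
#     ok, errors = validate_llm_config()
#     if not ok:
#         for msg in errors:
#             logger.warning(f"LLM 配置问题: {msg}")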

View File

@ -0,0 +1,378 @@
"""
LLM 核心基础设施模块
包含执行 LLM 请求所需的底层组件:HTTP 客户端、API Key 存储和智能重试逻辑
"""
import asyncio
from typing import Any
import httpx
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.user_agent import get_user_agent
from .types import ProviderConfig
from .types.exceptions import LLMErrorCode, LLMException
class HttpClientConfig(BaseModel):
"""HTTP客户端配置"""
timeout: int = 180
max_connections: int = 100
max_keepalive_connections: int = 20
proxy: str | None = None
class LLMHttpClient:
"""LLM服务专用HTTP客户端"""
def __init__(self, config: HttpClientConfig | None = None):
self.config = config or HttpClientConfig()
self._client: httpx.AsyncClient | None = None
self._active_requests = 0
self._lock = asyncio.Lock()
async def _ensure_client_initialized(self) -> httpx.AsyncClient:
if self._client is None or self._client.is_closed:
async with self._lock:
if self._client is None or self._client.is_closed:
logger.debug(
f"LLMHttpClient: Initializing new httpx.AsyncClient "
f"with config: {self.config}"
)
headers = get_user_agent()
limits = httpx.Limits(
max_connections=self.config.max_connections,
max_keepalive_connections=self.config.max_keepalive_connections,
)
timeout = httpx.Timeout(self.config.timeout)
self._client = httpx.AsyncClient(
headers=headers,
limits=limits,
timeout=timeout,
proxies=self.config.proxy,
follow_redirects=True,
)
if self._client is None:
raise LLMException(
"HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
)
return self._client
async def post(self, url: str, **kwargs: Any) -> httpx.Response:
client = await self._ensure_client_initialized()
async with self._lock:
self._active_requests += 1
try:
return await client.post(url, **kwargs)
finally:
async with self._lock:
self._active_requests -= 1
async def close(self):
async with self._lock:
if self._client and not self._client.is_closed:
logger.debug(
f"LLMHttpClient: Closing with config: {self.config}. "
f"Active requests: {self._active_requests}"
)
if self._active_requests > 0:
logger.warning(
f"LLMHttpClient: Closing while {self._active_requests} "
f"requests are still active."
)
await self._client.aclose()
self._client = None
logger.debug(f"LLMHttpClient for config {self.config} definitively closed.")
@property
def is_closed(self) -> bool:
return self._client is None or self._client.is_closed
class LLMHttpClientManager:
"""管理 LLMHttpClient 实例的工厂和池"""
def __init__(self):
self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
self._lock = asyncio.Lock()
def _get_client_key(
self, provider_config: ProviderConfig
) -> tuple[int, str | None]:
return (provider_config.timeout, provider_config.proxy)
async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
key = self._get_client_key(provider_config)
async with self._lock:
client = self._clients.get(key)
if client and not client.is_closed:
logger.debug(
f"LLMHttpClientManager: Reusing existing LLMHttpClient "
f"for key: {key}"
)
return client
if client and client.is_closed:
logger.debug(
f"LLMHttpClientManager: Found a closed client for key {key}. "
f"Creating a new one."
)
logger.debug(
f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
)
http_client_config = HttpClientConfig(
timeout=provider_config.timeout, proxy=provider_config.proxy
)
new_client = LLMHttpClient(config=http_client_config)
self._clients[key] = new_client
return new_client
async def shutdown(self):
async with self._lock:
logger.info(
f"LLMHttpClientManager: Shutting down. "
f"Closing {len(self._clients)} client(s)."
)
close_tasks = [
client.close()
for client in self._clients.values()
if client and not client.is_closed
]
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
self._clients.clear()
logger.info("LLMHttpClientManager: Shutdown complete.")
http_client_manager = LLMHttpClientManager()
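# 用法示意(演示注释,非源码):相同 (timeout, proxy) 组合会复用同一个
# HTTP 客户端,程序退出时统一关闭:
#
#     client = await http_client_manager.get_client(provider_config)
#     response = await client.post(url, json=body, headers=headers)
#     await http_client_manager.shutdown()  # 关闭所有已创建的客户端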
async def create_llm_http_client(
timeout: int = 180,
proxy: str | None = None,
) -> LLMHttpClient:
"""创建LLM HTTP客户端"""
config = HttpClientConfig(timeout=timeout, proxy=proxy)
return LLMHttpClient(config)
class RetryConfig:
"""重试配置"""
def __init__(
self,
max_retries: int = 3,
retry_delay: float = 1.0,
exponential_backoff: bool = True,
key_rotation: bool = True,
):
self.max_retries = max_retries
self.retry_delay = retry_delay
self.exponential_backoff = exponential_backoff
self.key_rotation = key_rotation
async def with_smart_retry(
func,
*args,
retry_config: RetryConfig | None = None,
key_store: "KeyStatusStore | None" = None,
provider_name: str | None = None,
**kwargs: Any,
) -> Any:
"""智能重试装饰器 - 支持Key轮询和错误分类"""
config = retry_config or RetryConfig()
last_exception: Exception | None = None
failed_keys: set[str] = set()
for attempt in range(config.max_retries + 1):
try:
if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
kwargs["failed_keys"] = failed_keys
return await func(*args, **kwargs)
except LLMException as e:
last_exception = e
if e.code in [
LLMErrorCode.API_KEY_INVALID,
LLMErrorCode.API_QUOTA_EXCEEDED,
]:
if hasattr(e, "details") and e.details and "api_key" in e.details:
failed_keys.add(e.details["api_key"])
if key_store and provider_name:
await key_store.record_failure(
e.details["api_key"], e.details.get("status_code")
)
should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
if not should_retry:
logger.error(f"不可重试的错误,停止重试: {e}")
raise
if attempt < config.max_retries:
wait_time = config.retry_delay
if config.exponential_backoff:
wait_time *= 2**attempt
logger.warning(
f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
)
await asyncio.sleep(wait_time)
else:
logger.error(f"重试{config.max_retries}次后仍然失败: {e}")
except Exception as e:
last_exception = e
logger.error(f"非LLM异常停止重试: {e}")
raise LLMException(
f"操作失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
if last_exception:
raise last_exception
else:
raise RuntimeError("重试函数未能正常执行且未捕获到异常")
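# 用法示意(演示注释,非源码;call_api 为虚构的请求函数):重试器按
# LLMErrorCode 分类决定是否重试;若目标函数签名包含 failed_keys,
# 会把已失效的 Key 集合传入以配合轮换:
#
#     async def call_api(payload, failed_keys: set[str] | None = None): ...
#
#     result = await with_smart_retry(
#         call_api,
#         payload,
#         retry_config=RetryConfig(max_retries=3, retry_delay=1.0),
#         key_store=key_store,
#         provider_name="Gemini",
#     )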
def _should_retry_llm_error(
error: LLMException, attempt: int, max_retries: int
) -> bool:
"""判断LLM错误是否应该重试"""
non_retryable_errors = {
LLMErrorCode.MODEL_NOT_FOUND,
LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
LLMErrorCode.USER_LOCATION_NOT_SUPPORTED,
LLMErrorCode.CONFIGURATION_ERROR,
}
if error.code in non_retryable_errors:
return False
retryable_errors = {
LLMErrorCode.API_REQUEST_FAILED,
LLMErrorCode.API_TIMEOUT,
LLMErrorCode.API_RATE_LIMITED,
LLMErrorCode.API_RESPONSE_INVALID,
LLMErrorCode.RESPONSE_PARSE_ERROR,
LLMErrorCode.GENERATION_FAILED,
LLMErrorCode.CONTENT_FILTERED,
LLMErrorCode.API_KEY_INVALID,
LLMErrorCode.API_QUOTA_EXCEEDED,
}
if error.code in retryable_errors:
if error.code == LLMErrorCode.API_QUOTA_EXCEEDED:
return attempt < min(2, max_retries)
elif error.code == LLMErrorCode.CONTENT_FILTERED:
return attempt < min(1, max_retries)
return True
return False
class KeyStatusStore:
"""API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""
def __init__(self):
self._key_status: dict[str, bool] = {}
self._key_usage_count: dict[str, int] = {}
self._key_last_used: dict[str, float] = {}
self._provider_key_index: dict[str, int] = {}
self._lock = asyncio.Lock()
async def get_next_available_key(
self,
provider_name: str,
api_keys: list[str],
exclude_keys: set[str] | None = None,
) -> str | None:
"""获取下一个可用的API密钥轮询策略"""
if not api_keys:
return None
exclude_keys = exclude_keys or set()
available_keys = [
key
for key in api_keys
if key not in exclude_keys and self._key_status.get(key, True)
]
if not available_keys:
return api_keys[0] if api_keys else None
async with self._lock:
current_index = self._provider_key_index.get(provider_name, 0)
selected_key = available_keys[current_index % len(available_keys)]
self._provider_key_index[provider_name] = (current_index + 1) % len(
available_keys
)
import time
self._key_usage_count[selected_key] = (
self._key_usage_count.get(selected_key, 0) + 1
)
self._key_last_used[selected_key] = time.time()
logger.debug(
f"轮询选择API密钥: {self._get_key_id(selected_key)} "
f"(使用次数: {self._key_usage_count[selected_key]})"
)
return selected_key
async def record_success(self, api_key: str):
"""记录成功使用"""
async with self._lock:
self._key_status[api_key] = True
logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")
async def record_failure(self, api_key: str, status_code: int | None):
"""记录失败使用"""
key_id = self._get_key_id(api_key)
async with self._lock:
if status_code in [401, 403]:
self._key_status[api_key] = False
logger.warning(
f"API密钥认证失败标记为不可用: {key_id} (状态码: {status_code})"
)
else:
logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")
async def reset_key_status(self, api_key: str):
"""重置密钥状态(用于恢复机制)"""
async with self._lock:
self._key_status[api_key] = True
logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")
async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
"""获取密钥使用统计"""
stats = {}
async with self._lock:
for key in api_keys:
key_id = self._get_key_id(key)
stats[key_id] = {
"available": self._key_status.get(key, True),
"usage_count": self._key_usage_count.get(key, 0),
"last_used": self._key_last_used.get(key, 0),
}
return stats
def _get_key_id(self, api_key: str) -> str:
"""获取API密钥的标识符用于日志"""
if len(api_key) <= 8:
return api_key
return f"{api_key[:4]}...{api_key[-4:]}"
key_store = KeyStatusStore()
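# 用法示意(演示注释,非源码;AuthError 为虚构的异常类型):Key 按轮询分配,
# 401/403 会被标记为不可用,直到 reset_key_status 恢复:
#
#     key = await key_store.get_next_available_key("Gemini", ["k1", "k2", "k3"])
#     try:
#         ...  # 使用 key 发起请求
#         await key_store.record_success(key)
#     except AuthError:
#         await key_store.record_failure(key, status_code=401)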

View File

@ -0,0 +1,434 @@
"""
LLM 模型管理器
负责模型实例的创建、缓存、配置管理和生命周期管理
"""
import hashlib
import json
import time
from typing import Any
from zhenxun.configs.config import Config
from zhenxun.services.log import logger
from .config import validate_override_params
from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig
DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
TIMEOUT_KEY = "timeout"
_model_cache: dict[str, tuple[LLMModel, float]] = {}
_cache_ttl = 3600
_max_cache_size = 10
def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:
"""解析 'ProviderName/ModelName' 格式的字符串"""
if not name_str or "/" not in name_str:
return None, None
parts = name_str.split("/", 1)
if len(parts) == 2 and parts[0].strip() and parts[1].strip():
return parts[0].strip(), parts[1].strip()
return None, None
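# 例如:parse_provider_model_string("Gemini/gemini-2.0-flash") 返回
# ("Gemini", "gemini-2.0-flash");缺少 "/" 或任一侧为空时返回 (None, None)。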
def _make_cache_key(
provider_model_name: str | None, override_config: dict | None
) -> str:
"""生成缓存键"""
config_str = (
json.dumps(override_config, sort_keys=True) if override_config else "None"
)
key_data = f"{provider_model_name}:{config_str}"
return hashlib.md5(key_data.encode()).hexdigest()
def _get_cached_model(cache_key: str) -> LLMModel | None:
"""从缓存获取模型"""
if cache_key in _model_cache:
model, created_time = _model_cache[cache_key]
current_time = time.time()
if current_time - created_time > _cache_ttl:
del _model_cache[cache_key]
logger.debug(f"模型缓存已过期: {cache_key}")
return None
if model._is_closed:
logger.debug(
f"缓存的模型 {cache_key} ({model.provider_name}/{model.model_name}) "
f"处于_is_closed=True状态重置为False以供复用。"
)
model._is_closed = False
logger.debug(
f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
)
return model
return None
def _cache_model(cache_key: str, model: LLMModel):
"""缓存模型实例"""
current_time = time.time()
if len(_model_cache) >= _max_cache_size:
oldest_key = min(_model_cache.keys(), key=lambda k: _model_cache[k][1])
del _model_cache[oldest_key]
_model_cache[cache_key] = (model, current_time)
def clear_model_cache():
"""清空模型缓存"""
global _model_cache
_model_cache.clear()
logger.info("已清空模型缓存")
def get_cache_stats() -> dict[str, Any]:
"""获取缓存统计信息"""
return {
"cache_size": len(_model_cache),
"max_cache_size": _max_cache_size,
"cache_ttl": _cache_ttl,
"cached_models": list(_model_cache.keys()),
}
def get_default_api_base_for_type(api_type: str) -> str | None:
"""根据API类型获取默认的API基础地址"""
default_api_bases = {
"openai": "https://api.openai.com",
"deepseek": "https://api.deepseek.com",
"zhipu": "https://open.bigmodel.cn",
"gemini": "https://generativelanguage.googleapis.com",
"general_openai_compat": None,
}
return default_api_bases.get(api_type)
def get_configured_providers() -> list[ProviderConfig]:
"""从配置中获取Provider列表 - 简化版本"""
ai_config = get_ai_config()
providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
if not isinstance(providers_raw, list):
logger.error(
f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
f"将使用空列表。"
)
return []
valid_providers = []
for i, item in enumerate(providers_raw):
if not isinstance(item, dict):
logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
continue
try:
if not item.get("name"):
logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
continue
if not item.get("api_key"):
logger.warning(
f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
)
continue
if "api_type" not in item or not item["api_type"]:
provider_name = item.get("name", "").lower()
if "glm" in provider_name or "zhipu" in provider_name:
item["api_type"] = "zhipu"
elif "gemini" in provider_name or "google" in provider_name:
item["api_type"] = "gemini"
else:
item["api_type"] = "openai"
if "api_base" not in item or not item["api_base"]:
api_type = item.get("api_type")
if api_type:
default_api_base = get_default_api_base_for_type(api_type)
if default_api_base:
item["api_base"] = default_api_base
if "models" not in item:
item["models"] = [{"model_name": item.get("name", "default")}]
provider_conf = ProviderConfig(**item)
valid_providers.append(provider_conf)
except Exception as e:
logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")
return valid_providers
def find_model_config(
provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
"""在配置中查找指定的 Provider 和 ModelDetail
Args:
provider_name: 提供商名称
model_name: 模型名称
Returns:
        找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None
"""
providers = get_configured_providers()
for provider in providers:
if provider.name.lower() == provider_name.lower():
for model_detail in provider.models:
if model_detail.model_name.lower() == model_name.lower():
return provider, model_detail
return None
def list_available_models() -> list[dict[str, Any]]:
"""列出所有配置的可用模型"""
providers = get_configured_providers()
model_list = []
for provider in providers:
for model_detail in provider.models:
model_info = {
"provider_name": provider.name,
"model_name": model_detail.model_name,
"full_name": f"{provider.name}/{model_detail.model_name}",
"api_type": provider.api_type or "auto-detect",
"api_base": provider.api_base,
"is_available": model_detail.is_available,
"is_embedding_model": model_detail.is_embedding_model,
"available_identifiers": _get_model_identifiers(
provider.name, model_detail
),
}
model_list.append(model_info)
return model_list
def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]:
"""获取模型的所有可用标识符"""
return [f"{provider_name}/{model_detail.model_name}"]
def list_model_identifiers() -> dict[str, list[str]]:
"""列出所有模型的可用标识符
Returns:
        字典,键为模型的完整名称,值为该模型的所有可用标识符列表
"""
providers = get_configured_providers()
result = {}
for provider in providers:
for model_detail in provider.models:
full_name = f"{provider.name}/{model_detail.model_name}"
identifiers = _get_model_identifiers(provider.name, model_detail)
result[full_name] = identifiers
return result
def list_embedding_models() -> list[dict[str, Any]]:
"""列出所有配置的嵌入模型"""
all_models = list_available_models()
return [model for model in all_models if model.get("is_embedding_model", False)]
async def get_model_instance(
provider_model_name: str | None = None,
override_config: dict[str, Any] | None = None,
) -> LLMModel:
"""根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
cache_key = _make_cache_key(provider_model_name, override_config)
cached_model = _get_cached_model(cache_key)
if cached_model:
if override_config:
validated_override = validate_override_params(override_config)
if cached_model._generation_config != validated_override:
cached_model._generation_config = validated_override
logger.debug(
f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
f"{validated_override.to_dict()}"
)
return cached_model
resolved_model_name_str = provider_model_name
if resolved_model_name_str is None:
resolved_model_name_str = get_global_default_model_name()
if resolved_model_name_str is None:
available_models_list = list_available_models()
if not available_models_list:
raise LLMException(
"未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
)
resolved_model_name_str = available_models_list[0]["full_name"]
logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")
prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
if not prov_name_str or not mod_name_str:
raise LLMException(
f"无效的模型名称格式: '{resolved_model_name_str}'",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
config_tuple_found = find_model_config(prov_name_str, mod_name_str)
if not config_tuple_found:
all_models = list_available_models()
raise LLMException(
f"未找到模型: '{resolved_model_name_str}'. "
f"可用: {[m['full_name'] for m in all_models]}",
code=LLMErrorCode.MODEL_NOT_FOUND,
)
provider_config_found, model_detail_found = config_tuple_found
ai_config = get_ai_config()
global_proxy_setting = ai_config.get(PROXY_KEY)
default_timeout = (
provider_config_found.timeout
if provider_config_found.timeout is not None
else 180
)
global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)
config_for_http_client = ProviderConfig(
name=provider_config_found.name,
api_key=provider_config_found.api_key,
models=provider_config_found.models,
timeout=global_timeout_setting,
proxy=global_proxy_setting,
api_base=provider_config_found.api_base,
api_type=provider_config_found.api_type,
openai_compat=provider_config_found.openai_compat,
temperature=provider_config_found.temperature,
max_tokens=provider_config_found.max_tokens,
)
shared_http_client = await http_client_manager.get_client(config_for_http_client)
try:
model_instance = LLMModel(
provider_config=config_for_http_client,
model_detail=model_detail_found,
key_store=key_store,
http_client=shared_http_client,
)
if override_config:
validated_override_params = validate_override_params(override_config)
model_instance._generation_config = validated_override_params
logger.debug(
f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
f"{validated_override_params.to_dict()}"
)
_cache_model(cache_key, model_instance)
logger.debug(
f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
)
return model_instance
except LLMException:
raise
except Exception as e:
logger.error(
f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
)
raise LLMException(
f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
code=LLMErrorCode.MODEL_INIT_FAILED,
cause=e,
)
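# 用法示意(演示注释,非源码):调用方通常以异步上下文管理器方式使用
# 返回的模型实例,与 api.py 中 _execute_generation 的用法一致:
#
#     async with await get_model_instance(
#         "Gemini/gemini-2.0-flash", override_config={"temperature": 0.3}
#     ) as model:
#         response = await model.generate_response(messages)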
def get_global_default_model_name() -> str | None:
"""获取全局默认模型名称"""
ai_config = get_ai_config()
return ai_config.get(DEFAULT_MODEL_NAME_KEY)
def set_global_default_model_name(provider_model_name: str | None) -> bool:
"""设置全局默认模型名称"""
if provider_model_name:
prov_name, mod_name = parse_provider_model_string(provider_model_name)
if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
logger.error(
f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
)
return False
Config.set_config(
AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
)
if provider_model_name:
logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
else:
logger.info("LLM 服务全局默认模型已清除。")
return True
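Both helpers operate on the same config key; a round-trip sketch, assuming the placeholder name "MyProvider/my-chat" is configured:

if set_global_default_model_name("MyProvider/my-chat"):  # validated first
    assert get_global_default_model_name() == "MyProvider/my-chat"
set_global_default_model_name(None)  # passing None clears the default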
async def get_key_usage_stats() -> dict[str, Any]:
"""获取所有Provider的Key使用统计"""
providers = get_configured_providers()
stats = {}
for provider in providers:
provider_stats = await key_store.get_key_stats(
[provider.api_key]
if isinstance(provider.api_key, str)
else provider.api_key
)
stats[provider.name] = {
"total_keys": len(
[provider.api_key]
if isinstance(provider.api_key, str)
else provider.api_key
),
"key_stats": provider_stats,
}
return stats
async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
"""重置指定Provider的Key状态"""
providers = get_configured_providers()
target_provider = None
for provider in providers:
if provider.name.lower() == provider_name.lower():
target_provider = provider
break
if not target_provider:
logger.error(f"未找到Provider: {provider_name}")
return False
provider_keys = (
[target_provider.api_key]
if isinstance(target_provider.api_key, str)
else target_provider.api_key
)
if api_key:
if api_key in provider_keys:
await key_store.reset_key_status(api_key)
logger.info(f"已重置Provider '{provider_name}' 的指定Key状态")
return True
else:
logger.error(f"指定的Key不属于Provider '{provider_name}'")
return False
else:
for key in provider_keys:
await key_store.reset_key_status(key)
logger.info(f"已重置Provider '{provider_name}' 的所有Key状态")
return True
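A sketch combining the two key-management helpers above; the provider name is a placeholder and the dict layout mirrors what get_key_usage_stats builds:

async def audit_keys() -> None:
    stats = await get_key_usage_stats()
    for provider_name, info in stats.items():
        print(provider_name, info["total_keys"], info["key_stats"])
    # Omitting api_key resets every key of the provider;
    # pass a single key to reset only that one.
    await reset_key_status("MyProvider")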

View File

@ -0,0 +1,632 @@
"""
LLM 模型实现类
包含 LLM 模型的抽象基类和具体实现,负责与各种 AI 提供商的 API 交互
"""
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
import json
from typing import Any
from zhenxun.services.log import logger
from .config import LLMGenerationConfig
from .config.providers import get_ai_config
from .core import (
KeyStatusStore,
LLMHttpClient,
RetryConfig,
http_client_manager,
with_smart_retry,
)
from .types import (
EmbeddingTaskType,
LLMErrorCode,
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelDetail,
ProviderConfig,
)
class LLMModelBase(ABC):
"""LLM模型抽象基类"""
@abstractmethod
async def generate_text(
self,
prompt: str,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
pass
@abstractmethod
async def generate_response(
self,
messages: list[LLMMessage],
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应"""
pass
@abstractmethod
async def generate_embeddings(
self,
texts: list[str],
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
pass
class LLMModel(LLMModelBase):
"""LLM 模型实现类"""
def __init__(
self,
provider_config: ProviderConfig,
model_detail: ModelDetail,
key_store: KeyStatusStore,
http_client: LLMHttpClient,
config_override: LLMGenerationConfig | None = None,
):
self.provider_config = provider_config
self.model_detail = model_detail
self.key_store = key_store
self.http_client: LLMHttpClient = http_client
self._generation_config = config_override
self.provider_name = provider_config.name
self.api_type = provider_config.api_type
self.api_base = provider_config.api_base
self.api_keys = (
[provider_config.api_key]
if isinstance(provider_config.api_key, str)
else provider_config.api_key
)
self.model_name = model_detail.model_name
self.temperature = model_detail.temperature
self.max_tokens = model_detail.max_tokens
self._is_closed = False
async def _get_http_client(self) -> LLMHttpClient:
"""获取HTTP客户端"""
if self.http_client.is_closed:
logger.debug(
f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,"
"正在获取新的客户端"
)
self.http_client = await http_client_manager.get_client(
self.provider_config
)
return self.http_client
async def _select_api_key(self, failed_keys: set[str] | None = None) -> str:
"""选择可用的API密钥使用轮询策略"""
if not self.api_keys:
raise LLMException(
f"提供商 {self.provider_name} 没有配置API密钥",
code=LLMErrorCode.NO_AVAILABLE_KEYS,
)
selected_key = await self.key_store.get_next_available_key(
self.provider_name, self.api_keys, failed_keys
)
if not selected_key:
raise LLMException(
f"提供商 {self.provider_name} 的所有API密钥当前都不可用",
code=LLMErrorCode.NO_AVAILABLE_KEYS,
details={
"total_keys": len(self.api_keys),
"failed_keys": len(failed_keys or set()),
},
)
return selected_key
async def _execute_embedding_request(
self,
adapter,
texts: list[str],
task_type: EmbeddingTaskType | str,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> list[list[float]]:
"""执行单次嵌入请求 - 供重试机制调用"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_embedding_request(
model=self,
api_key=api_key,
texts=texts,
task_type=task_type,
)
http_response = await http_client.post(
request_data.url,
headers=request_data.headers,
json=request_data.body,
)
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
error_code = LLMErrorCode.API_REQUEST_FAILED
if http_response.status_code in [401, 403]:
error_code = LLMErrorCode.API_KEY_INVALID
elif http_response.status_code == 429:
error_code = LLMErrorCode.API_RATE_LIMITED
raise LLMException(
f"HTTP嵌入请求失败: {http_response.status_code}",
code=error_code,
details={
"status_code": http_response.status_code,
"response": error_text,
"api_key": api_key,
},
)
try:
response_json = http_response.json()
adapter.validate_embedding_response(response_json)
embeddings = adapter.parse_embedding_response(response_json)
except Exception as e:
logger.error(f"解析嵌入响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API嵌入响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return embeddings
except LLMException:
raise
except Exception as e:
logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成嵌入失败: {e}",
code=LLMErrorCode.EMBEDDING_FAILED,
cause=e,
)
async def _execute_with_smart_retry(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
"""智能重试机制 - 使用统一的重试装饰器"""
ai_config = get_ai_config()
max_retries = ai_config.get("max_retries_llm", 3)
retry_delay = ai_config.get("retry_delay_llm", 2)
retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)
return await with_smart_retry(
self._execute_single_request,
adapter,
messages,
config,
tools_dict,
tool_choice,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
async def _execute_single_request(
self,
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools_dict: list[dict[str, Any]] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
) -> LLMResponse:
"""执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
api_key = await self._select_api_key(failed_keys)
try:
request_data = adapter.prepare_advanced_request(
model=self,
api_key=api_key,
messages=messages,
config=config,
tools=tools_dict,
tool_choice=tool_choice,
)
http_response = await http_client.post(
request_data.url,
headers=request_data.headers,
json=request_data.body,
)
if http_response.status_code != 200:
error_text = http_response.text
logger.error(
f"HTTP请求失败: {http_response.status_code} - {error_text}"
)
await self.key_store.record_failure(api_key, http_response.status_code)
if http_response.status_code in [401, 403]:
error_code = LLMErrorCode.API_KEY_INVALID
elif http_response.status_code == 429:
error_code = LLMErrorCode.API_RATE_LIMITED
elif http_response.status_code in [402, 413]:
error_code = LLMErrorCode.API_QUOTA_EXCEEDED
else:
error_code = LLMErrorCode.API_REQUEST_FAILED
raise LLMException(
f"HTTP请求失败: {http_response.status_code}",
code=error_code,
details={
"status_code": http_response.status_code,
"response": error_text,
"api_key": api_key,
},
)
try:
response_json = http_response.json()
response_data = adapter.parse_response(
model=self,
response_json=response_json,
is_advanced=True,
)
from .types.models import LLMToolCall
response_tool_calls = []
if response_data.tool_calls:
for tc_data in response_data.tool_calls:
if isinstance(tc_data, LLMToolCall):
response_tool_calls.append(tc_data)
elif isinstance(tc_data, dict):
try:
response_tool_calls.append(LLMToolCall(**tc_data))
except Exception as e:
logger.warning(
f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
f"error: {e}"
)
else:
logger.warning(f"工具调用数据格式未知: {tc_data}")
llm_response = LLMResponse(
text=response_data.text,
usage_info=response_data.usage_info,
raw_response=response_data.raw_response,
tool_calls=response_tool_calls if response_tool_calls else None,
code_executions=response_data.code_executions,
grounding_metadata=response_data.grounding_metadata,
cache_info=response_data.cache_info,
)
except Exception as e:
logger.error(f"解析响应失败: {e}", e=e)
await self.key_store.record_failure(api_key, None)
if isinstance(e, LLMException):
raise
else:
raise LLMException(
f"解析API响应失败: {e}",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
cause=e,
)
await self.key_store.record_success(api_key)
return llm_response
except LLMException:
raise
except Exception as e:
logger.error(f"生成响应时发生未预期错误: {e}", e=e)
await self.key_store.record_failure(api_key, None)
raise LLMException(
f"生成响应失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
)
async def close(self):
"""
标记模型实例的当前使用周期结束
共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭
"""
if self._is_closed:
return
self._is_closed = True
logger.debug(
f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)"
)
async def __aenter__(self):
if self._is_closed:
logger.debug(
f"Re-entering context for closed LLMModel {self}. "
f"Resetting _is_closed to False."
)
self._is_closed = False
self._check_not_closed()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
_ = exc_type, exc_val, exc_tb
await self.close()
def _check_not_closed(self):
"""检查实例是否已关闭"""
if self._is_closed:
raise RuntimeError(f"LLMModel实例已关闭: {self}")
async def generate_text(
self,
prompt: str,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本 - 通过 generate_response 实现"""
self._check_not_closed()
messages: list[LLMMessage] = []
if history:
for msg in history:
role = msg.get("role", "user")
content_text = msg.get("content", "")
messages.append(LLMMessage(role=role, content=content_text))
messages.append(LLMMessage.user(prompt))
model_fields = getattr(LLMGenerationConfig, "model_fields", {})
request_specific_config_dict = {
k: v for k, v in kwargs.items() if k in model_fields
}
request_specific_config = None
if request_specific_config_dict:
request_specific_config = LLMGenerationConfig(
**request_specific_config_dict
)
for key in request_specific_config_dict:
kwargs.pop(key, None)
response = await self.generate_response(
messages,
config=request_specific_config,
**kwargs,
)
return response.text
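Only kwargs that match LLMGenerationConfig fields are split out into the per-request config; everything else is forwarded untouched. A sketch, with `model` obtained from get_model_instance and illustrative history:

# Inside an async context:
reply = await model.generate_text(
    "那第二高的山呢?",
    history=[
        {"role": "user", "content": "世界上最高的山是什么?"},
        {"role": "assistant", "content": "珠穆朗玛峰。"},
    ],
    temperature=0.3,  # an LLMGenerationConfig field, so it becomes config
)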
async def generate_response(
self,
messages: list[LLMMessage],
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
max_tool_iterations: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应 - 实现完整的工具调用循环"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
from .config.generation import create_generation_config_from_kwargs
adapter = get_adapter_for_api_type(self.api_type)
if not adapter:
raise LLMException(
f"未找到适用于 API 类型 '{self.api_type}' 的适配器",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
final_request_config = self._generation_config or LLMGenerationConfig()
if kwargs:
kwargs_config = create_generation_config_from_kwargs(**kwargs)
merged_dict = final_request_config.to_dict()
merged_dict.update(kwargs_config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
if config is not None:
merged_dict = final_request_config.to_dict()
merged_dict.update(config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
tools_dict: list[dict[str, Any]] | None = None
if tools:
tools_dict = []
for tool in tools:
if hasattr(tool, "model_dump"):
model_dump_func = getattr(tool, "model_dump")
tools_dict.append(model_dump_func(exclude_none=True))
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
try:
tools_dict.append(dict(tool))
except (TypeError, ValueError):
logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")
http_client = await self._get_http_client()
current_messages = list(messages)
for iteration in range(max_tool_iterations):
logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")
llm_response = await self._execute_with_smart_retry(
adapter,
current_messages,
final_request_config,
tools_dict if iteration == 0 else None,
tool_choice if iteration == 0 else None,
http_client,
)
response_tool_calls = llm_response.tool_calls or []
if not response_tool_calls or not tool_executor:
logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
return llm_response
logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")
assistant_message_content = llm_response.text if llm_response.text else ""
current_messages.append(
LLMMessage.assistant_tool_calls(
content=assistant_message_content, tool_calls=response_tool_calls
)
)
tool_response_messages: list[LLMMessage] = []
for tool_call in response_tool_calls:
tool_name = tool_call.function.name
try:
tool_args_dict = json.loads(tool_call.function.arguments)
logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")
tool_result = await tool_executor(tool_name, tool_args_dict)
logger.debug(
f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result=tool_result,
)
)
except json.JSONDecodeError as e:
logger.error(
f"工具 '{tool_name}' 参数JSON解析失败: "
f"{tool_call.function.arguments}, 错误: {e}"
)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Argument JSON parsing failed",
"details": str(e),
},
)
)
except Exception as e:
logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
tool_response_messages.append(
LLMMessage.tool_response(
tool_call_id=tool_call.id,
function_name=tool_name,
result={
"error": "Tool execution failed",
"details": str(e),
},
)
)
current_messages.extend(tool_response_messages)
logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
raise LLMException(
"已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
code=LLMErrorCode.GENERATION_FAILED,
details={
"iterations": max_tool_iterations,
"last_messages": current_messages[-2:],
},
)
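The loop above alternates model turns with local tool execution until the model answers in plain text or the iteration cap trips. A caller-side sketch with a hypothetical weather tool (`model` from get_model_instance):

from typing import Any

async def get_weather(city: str) -> dict[str, Any]:
    return {"city": city, "temp_c": 23}  # stand-in for a real lookup

async def tool_executor(name: str, args: dict[str, Any]) -> Any:
    if name == "get_weather":
        return await get_weather(**args)
    return {"error": f"unknown tool: {name}"}

weather_tool = LLMTool.create(
    name="get_weather",
    description="查询城市当前天气",
    parameters={"city": {"type": "string"}},
    required=["city"],
)

async def demo() -> str:
    response = await model.generate_response(
        [LLMMessage.user("北京现在多少度?")],
        tools=[weather_tool],
        tool_executor=tool_executor,
        max_tool_iterations=3,
    )
    return response.text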
async def generate_embeddings(
self,
texts: list[str],
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
**kwargs: Any,
) -> list[list[float]]:
"""生成文本嵌入向量"""
self._check_not_closed()
if not texts:
return []
from .adapters import get_adapter_for_api_type
adapter = get_adapter_for_api_type(self.api_type)
if not adapter:
raise LLMException(
f"未找到适用于 API 类型 '{self.api_type}' 的嵌入适配器",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
http_client = await self._get_http_client()
ai_config = get_ai_config()
default_max_retries = ai_config.get("max_retries_llm", 3)
default_retry_delay = ai_config.get("retry_delay_llm", 2)
max_retries_embed = kwargs.get(
"max_retries_embed", max(1, default_max_retries // 2)
)
retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2)
retry_config = RetryConfig(
max_retries=max_retries_embed,
retry_delay=retry_delay_embed,
exponential_backoff=True,
key_rotation=True,
)
return await with_smart_retry(
self._execute_embedding_request,
adapter,
texts,
task_type,
http_client,
retry_config=retry_config,
key_store=self.key_store,
provider_name=self.provider_name,
)
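A sketch of batch embedding; the two retry kwargs are consumed above in place of the halved defaults:

# Inside an async context, with an embedding-capable `model`:
vectors = await model.generate_embeddings(
    ["真寻是谁?", "如何配置API密钥?"],
    task_type=EmbeddingTaskType.RETRIEVAL_QUERY,
    max_retries_embed=1,
)
assert len(vectors) == 2  # one float vector per input text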
def __str__(self) -> str:
status = "closed" if self._is_closed else "active"
return f"LLMModel({self.provider_name}/{self.model_name}, {status})"
def __repr__(self) -> str:
status = "closed" if self._is_closed else "active"
return (
f"LLMModel(provider={self.provider_name}, model={self.model_name}, "
f"api_type={self.api_type}, status={status})"
)

View File

@ -0,0 +1,54 @@
"""
LLM 类型定义模块
统一导出所有核心类型、协议和异常定义
"""
from .content import (
LLMContentPart,
LLMMessage,
LLMResponse,
)
from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory
from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
from .models import (
LLMCacheInfo,
LLMCodeExecution,
LLMGroundingAttribution,
LLMGroundingMetadata,
LLMTool,
LLMToolCall,
LLMToolFunction,
ModelDetail,
ModelInfo,
ModelName,
ProviderConfig,
ToolMetadata,
UsageInfo,
)
__all__ = [
"EmbeddingTaskType",
"LLMCacheInfo",
"LLMCodeExecution",
"LLMContentPart",
"LLMErrorCode",
"LLMException",
"LLMGroundingAttribution",
"LLMGroundingMetadata",
"LLMMessage",
"LLMResponse",
"LLMTool",
"LLMToolCall",
"LLMToolFunction",
"ModelDetail",
"ModelInfo",
"ModelName",
"ModelProvider",
"ProviderConfig",
"ResponseFormat",
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"get_user_friendly_error_message",
]

View File

@ -0,0 +1,428 @@
"""
LLM 内容类型定义
包含多模态内容部分、消息和响应的数据模型
"""
import base64
import mimetypes
from pathlib import Path
from typing import Any
import aiofiles
from pydantic import BaseModel
from zhenxun.services.log import logger
class LLMContentPart(BaseModel):
"""LLM 消息内容部分 - 支持多模态内容"""
type: str
text: str | None = None
image_source: str | None = None
audio_source: str | None = None
video_source: str | None = None
document_source: str | None = None
file_uri: str | None = None
file_source: str | None = None
url: str | None = None
mime_type: str | None = None
metadata: dict[str, Any] | None = None
def model_post_init(self, /, __context: Any) -> None:
"""验证内容部分的有效性"""
_ = __context
validation_rules = {
"text": lambda: self.text,
"image": lambda: self.image_source,
"audio": lambda: self.audio_source,
"video": lambda: self.video_source,
"document": lambda: self.document_source,
"file": lambda: self.file_uri or self.file_source,
"url": lambda: self.url,
}
if self.type in validation_rules:
if not validation_rules[self.type]():
raise ValueError(f"{self.type}类型的内容部分必须包含相应字段")
@classmethod
def text_part(cls, text: str) -> "LLMContentPart":
"""创建文本内容部分"""
return cls(type="text", text=text)
@classmethod
def image_url_part(cls, url: str) -> "LLMContentPart":
"""创建图片URL内容部分"""
return cls(type="image", image_source=url)
@classmethod
def image_base64_part(
cls, data: str, mime_type: str = "image/png"
) -> "LLMContentPart":
"""创建Base64图片内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="image", image_source=data_url)
@classmethod
def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart":
"""创建音频URL内容部分"""
return cls(type="audio", audio_source=url, mime_type=mime_type)
@classmethod
def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart":
"""创建视频URL内容部分"""
return cls(type="video", video_source=url, mime_type=mime_type)
@classmethod
def video_base64_part(
cls, data: str, mime_type: str = "video/mp4"
) -> "LLMContentPart":
"""创建Base64视频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="video", video_source=data_url, mime_type=mime_type)
@classmethod
def audio_base64_part(
cls, data: str, mime_type: str = "audio/wav"
) -> "LLMContentPart":
"""创建Base64音频内容部分"""
data_url = f"data:{mime_type};base64,{data}"
return cls(type="audio", audio_source=data_url, mime_type=mime_type)
@classmethod
def file_uri_part(
cls,
file_uri: str,
mime_type: str | None = None,
metadata: dict[str, Any] | None = None,
) -> "LLMContentPart":
"""创建Gemini File API URI内容部分"""
return cls(
type="file",
file_uri=file_uri,
mime_type=mime_type,
metadata=metadata or {},
)
@classmethod
async def from_path(
cls, path_like: str | Path, target_api: str | None = None
) -> "LLMContentPart | None":
"""
从本地文件路径创建 LLMContentPart
自动检测MIME类型,并根据类型(如图片)可能加载为Base64
target_api 可以用于提示如何最好地准备数据,例如 'gemini' 可能偏好 base64
"""
try:
path = Path(path_like)
if not path.exists() or not path.is_file():
logger.warning(f"文件不存在或不是一个文件: {path}")
return None
mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
if not mime_type:
logger.warning(
f"无法猜测文件 {path.name} 的MIME类型将尝试作为文本文件处理。"
)
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
return cls.text_part(text_content)
except Exception as e:
logger.error(f"读取文本文件 {path.name} 失败: {e}")
return None
if mime_type.startswith("image/"):
if target_api == "gemini" or not path.is_absolute():
try:
async with aiofiles.open(path, "rb") as f:
img_bytes = await f.read()
base64_data = base64.b64encode(img_bytes).decode("utf-8")
return cls.image_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
return None
else:
logger.warning(
f"为本地图片路径 {path.name} 生成 image_url_part。"
"实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
)
return cls.image_url_part(url=path.resolve().as_uri())
elif mime_type.startswith("audio/"):
return cls.audio_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif mime_type.startswith("video/"):
if target_api == "gemini":
# 对于 Gemini API将视频转换为 base64
try:
async with aiofiles.open(path, "rb") as f:
video_bytes = await f.read()
base64_data = base64.b64encode(video_bytes).decode("utf-8")
return cls.video_base64_part(
data=base64_data, mime_type=mime_type
)
except Exception as e:
logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
return None
else:
return cls.video_url_part(
url=path.resolve().as_uri(), mime_type=mime_type
)
elif (
mime_type.startswith("text/")
or mime_type == "application/json"
or mime_type == "application/xml"
):
try:
async with aiofiles.open(path, encoding="utf-8") as f:
text_content = await f.read()
return cls.text_part(text_content)
except Exception as e:
logger.error(f"读取文本类文件 {path.name} 失败: {e}")
return None
else:
logger.info(
f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
)
return cls.file_uri_part(
file_uri=path.resolve().as_uri(),
mime_type=mime_type,
metadata={"name": path.name, "source": "local_path"},
)
except Exception as e:
logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}")
return None
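from_path never raises on unreadable input; it logs and returns None. A sketch with a placeholder path:

# Inside an async context:
part = await LLMContentPart.from_path("/tmp/cat.png", target_api="gemini")
if part is None:
    ...  # missing or unreadable file
elif part.is_image_base64():
    mime_type, b64_data = part.get_base64_data() or ("", "")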
def is_image_url(self) -> bool:
"""检查图像源是否为URL"""
if not self.image_source:
return False
return self.image_source.startswith(("http://", "https://"))
def is_image_base64(self) -> bool:
"""检查图像源是否为Base64 Data URL"""
if not self.image_source:
return False
return self.image_source.startswith("data:")
def get_base64_data(self) -> tuple[str, str] | None:
"""从Data URL中提取Base64数据和MIME类型"""
if not self.is_image_base64() or not self.image_source:
return None
try:
header, data = self.image_source.split(",", 1)
mime_part = header.split(";")[0].replace("data:", "")
return mime_part, data
except (ValueError, IndexError):
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
return None
def convert_for_api(self, api_type: str) -> dict[str, Any]:
"""根据API类型转换多模态内容格式"""
if self.type == "text":
if api_type == "openai":
return {"type": "text", "text": self.text}
elif api_type == "gemini":
return {"text": self.text}
else:
return {"type": "text", "text": self.text}
elif self.type == "image":
if not self.image_source:
raise ValueError("图像类型的内容必须包含image_source")
if api_type == "openai":
return {"type": "image_url", "image_url": {"url": self.image_source}}
elif api_type == "gemini":
if self.is_image_base64():
base64_info = self.get_base64_data()
if base64_info:
mime_type, data = base64_info
return {"inlineData": {"mimeType": mime_type, "data": data}}
else:
# 如果无法解析 Base64 数据,抛出异常
raise ValueError(
f"无法解析Base64图像数据: {self.image_source[:50]}..."
)
else:
logger.warning(
f"Gemini API需要Base64格式但提供的是URL: {self.image_source}"
)
return {
"inlineData": {
"mimeType": "image/jpeg",
"data": self.image_source,
}
}
else:
return {"type": "image_url", "image_url": {"url": self.image_source}}
elif self.type == "video":
if not self.video_source:
raise ValueError("视频类型的内容必须包含video_source")
if api_type == "gemini":
# Gemini 支持视频,但需要通过 File API 上传
if self.video_source.startswith("data:"):
# 处理 base64 视频数据
try:
header, data = self.video_source.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
return {"inlineData": {"mimeType": mime_type, "data": data}}
except (ValueError, IndexError):
raise ValueError(
f"无法解析Base64视频数据: {self.video_source[:50]}..."
)
else:
# 对于 URL 或其他格式,暂时不支持直接内联
raise ValueError(
"Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
)
else:
# 其他 API 可能不支持视频
raise ValueError(f"API类型 '{api_type}' 不支持视频内容")
elif self.type == "audio":
if not self.audio_source:
raise ValueError("音频类型的内容必须包含audio_source")
if api_type == "gemini":
# Gemini 支持音频,处理方式类似视频
if self.audio_source.startswith("data:"):
try:
header, data = self.audio_source.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
return {"inlineData": {"mimeType": mime_type, "data": data}}
except (ValueError, IndexError):
raise ValueError(
f"无法解析Base64音频数据: {self.audio_source[:50]}..."
)
else:
raise ValueError(
"Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
)
else:
raise ValueError(f"API类型 '{api_type}' 不支持音频内容")
elif self.type == "file":
if api_type == "gemini" and self.file_uri:
return {
"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
}
elif self.file_source:
file_name = (
self.metadata.get("name", "file") if self.metadata else "file"
)
if api_type == "gemini":
return {"text": f"[文件: {file_name}]\n{self.file_source}"}
else:
return {
"type": "text",
"text": f"[文件: {file_name}]\n{self.file_source}",
}
else:
raise ValueError("文件类型的内容必须包含file_uri或file_source")
else:
raise ValueError(f"不支持的内容类型: {self.type}")
class LLMMessage(BaseModel):
"""LLM 消息"""
role: str
content: str | list[LLMContentPart]
name: str | None = None
tool_calls: list[Any] | None = None
tool_call_id: str | None = None
def model_post_init(self, /, __context: Any) -> None:
"""验证消息的有效性"""
_ = __context
if self.role == "tool":
if not self.tool_call_id:
raise ValueError("工具角色的消息必须包含 tool_call_id")
if not self.name:
raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
if self.role == "tool" and not isinstance(self.content, str):
logger.warning(
f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
"将尝试转换为字符串。"
)
try:
self.content = str(self.content)
except Exception as e:
raise ValueError(f"无法将工具角色的内容转换为字符串: {e}")
@classmethod
def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
"""创建用户消息"""
return cls(role="user", content=content)
@classmethod
def assistant_tool_calls(
cls,
tool_calls: list[Any],
content: str | list[LLMContentPart] = "",
) -> "LLMMessage":
"""创建助手请求工具调用的消息"""
return cls(role="assistant", content=content, tool_calls=tool_calls)
@classmethod
def assistant_text_response(
cls, content: str | list[LLMContentPart]
) -> "LLMMessage":
"""创建助手纯文本回复的消息"""
return cls(role="assistant", content=content, tool_calls=None)
@classmethod
def tool_response(
cls,
tool_call_id: str,
function_name: str,
result: Any,
) -> "LLMMessage":
"""创建工具执行结果的消息"""
import json
try:
content_str = json.dumps(result)
except TypeError as e:
logger.error(
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
)
content_str = json.dumps(
{"error": "Tool result not JSON serializable", "details": str(e)}
)
return cls(
role="tool",
content=content_str,
tool_call_id=tool_call_id,
name=function_name,
)
@classmethod
def system(cls, content: str) -> "LLMMessage":
"""创建系统消息"""
return cls(role="system", content=content)
class LLMResponse(BaseModel):
"""LLM 响应"""
text: str
usage_info: dict[str, Any] | None = None
raw_response: dict[str, Any] | None = None
tool_calls: list[Any] | None = None
code_executions: list[Any] | None = None
grounding_metadata: Any | None = None
cache_info: Any | None = None

View File

@ -0,0 +1,67 @@
"""
LLM 枚举类型定义
"""
from enum import Enum, auto
class ModelProvider(Enum):
"""模型提供商枚举"""
OPENAI = "openai"
GEMINI = "gemini"
ZHIPU = "zhipu"
CUSTOM = "custom"
class ResponseFormat(Enum):
"""响应格式枚举"""
TEXT = "text"
JSON = "json"
MULTIMODAL = "multimodal"
class EmbeddingTaskType(str, Enum):
"""文本嵌入任务类型 (主要用于Gemini)"""
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
CLASSIFICATION = "CLASSIFICATION"
CLUSTERING = "CLUSTERING"
QUESTION_ANSWERING = "QUESTION_ANSWERING"
FACT_VERIFICATION = "FACT_VERIFICATION"
class ToolCategory(Enum):
"""工具分类枚举"""
FILE_SYSTEM = auto()
NETWORK = auto()
SYSTEM_INFO = auto()
CALCULATION = auto()
DATA_PROCESSING = auto()
CUSTOM = auto()
class LLMErrorCode(Enum):
"""LLM 服务相关的错误代码枚举"""
MODEL_INIT_FAILED = 2000
MODEL_NOT_FOUND = 2001
API_REQUEST_FAILED = 2002
API_RESPONSE_INVALID = 2003
API_KEY_INVALID = 2004
API_QUOTA_EXCEEDED = 2005
API_TIMEOUT = 2006
API_RATE_LIMITED = 2007
NO_AVAILABLE_KEYS = 2008
UNKNOWN_API_TYPE = 2009
CONFIGURATION_ERROR = 2010
RESPONSE_PARSE_ERROR = 2011
CONTEXT_LENGTH_EXCEEDED = 2012
CONTENT_FILTERED = 2013
USER_LOCATION_NOT_SUPPORTED = 2014
GENERATION_FAILED = 2015
EMBEDDING_FAILED = 2016

View File

@ -0,0 +1,80 @@
"""
LLM 异常类型定义
"""
from typing import Any
from .enums import LLMErrorCode
class LLMException(Exception):
"""LLM 服务相关的基础异常类"""
def __init__(
self,
message: str,
code: LLMErrorCode = LLMErrorCode.API_REQUEST_FAILED,
details: dict[str, Any] | None = None,
recoverable: bool = True,
cause: Exception | None = None,
):
self.message = message
self.code = code
self.details = details or {}
self.recoverable = recoverable
self.cause = cause
super().__init__(message)
def __str__(self) -> str:
if self.details:
return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})"
return f"{self.message} (错误码: {self.code.name})"
@property
def user_friendly_message(self) -> str:
"""返回适合向用户展示的错误消息"""
error_messages = {
LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到请检查配置或联系管理员。",
LLMErrorCode.API_KEY_INVALID: "API密钥无效请联系管理员更新配置。",
LLMErrorCode.API_QUOTA_EXCEEDED: (
"API使用配额已用尽请稍后再试或联系管理员。"
),
LLMErrorCode.API_TIMEOUT: "AI服务响应超时请稍后再试。",
LLMErrorCode.API_RATE_LIMITED: "请求过于频繁已被AI服务限流请稍后再试。",
LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败请联系管理员检查配置。",
LLMErrorCode.NO_AVAILABLE_KEYS: (
"当前所有API密钥均不可用请稍后再试或联系管理员。"
),
LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
"当前地区暂不支持此AI服务请联系管理员或尝试其他模型。"
),
LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败请稍后再试。",
LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常请稍后再试。",
LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误请联系管理员。",
LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。",
LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。",
LLMErrorCode.RESPONSE_PARSE_ERROR: "AI服务响应解析失败请稍后再试。",
LLMErrorCode.UNKNOWN_API_TYPE: "不支持的AI服务类型请联系管理员。",
}
return error_messages.get(self.code, "AI服务暂时不可用请稍后再试。")
def get_user_friendly_error_message(error: Exception) -> str:
"""将任何异常转换为用户友好的错误消息"""
if isinstance(error, LLMException):
return error.user_friendly_message
error_str = str(error).lower()
if "timeout" in error_str or "超时" in error_str:
return "请求超时,请稍后再试。"
elif "connection" in error_str or "连接" in error_str:
return "网络连接失败,请检查网络后重试。"
elif "permission" in error_str or "权限" in error_str:
return "权限不足,请联系管理员。"
elif "not found" in error_str or "未找到" in error_str:
return "请求的资源未找到,请检查配置。"
elif "invalid" in error_str or "无效" in error_str:
return "请求参数无效,请检查输入。"
else:
return "服务暂时不可用,请稍后再试。"

View File

@ -0,0 +1,160 @@
"""
LLM 数据模型定义
包含模型信息、配置、工具定义和响应数据的模型类
"""
from dataclasses import dataclass, field
from typing import Any
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
ModelName = str | None
@dataclass(frozen=True)
class ModelInfo:
"""模型信息(不可变数据类)"""
name: str
provider: ModelProvider
max_tokens: int = 4096
supports_tools: bool = False
supports_vision: bool = False
supports_audio: bool = False
cost_per_1k_tokens: float = 0.0
@dataclass
class UsageInfo:
"""使用信息数据类"""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
cost: float = 0.0
@property
def efficiency_ratio(self) -> float:
"""计算效率比(输出/输入)"""
return self.completion_tokens / max(self.prompt_tokens, 1)
@dataclass
class ToolMetadata:
"""工具元数据"""
name: str
description: str
category: ToolCategory
read_only: bool = True
destructive: bool = False
open_world: bool = False
parameters: dict[str, Any] = field(default_factory=dict)
required_params: list[str] = field(default_factory=list)
class ModelDetail(BaseModel):
"""模型详细信息"""
model_name: str
is_available: bool = True
is_embedding_model: bool = False
temperature: float | None = None
max_tokens: int | None = None
class ProviderConfig(BaseModel):
"""LLM 提供商配置"""
name: str = Field(..., description="Provider 的唯一名称标识")
api_key: str | list[str] = Field(..., description="API Key 或 Key 列表")
api_base: str | None = Field(None, description="API Base URL如果为空则使用默认值")
api_type: str = Field(default="openai", description="API 类型")
openai_compat: bool = Field(default=False, description="是否使用 OpenAI 兼容模式")
temperature: float | None = Field(default=0.7, description="默认温度参数")
max_tokens: int | None = Field(default=None, description="默认最大输出 token 限制")
models: list[ModelDetail] = Field(..., description="支持的模型列表")
timeout: int = Field(default=180, description="请求超时时间")
proxy: str | None = Field(default=None, description="代理设置")
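A sketch of a provider entry as this model parses it; the name, keys, and URL are placeholders:

provider = ProviderConfig(
    name="MyProvider",
    api_key=["sk-key-1", "sk-key-2"],  # a list enables key rotation
    api_base="https://api.example.com/v1",
    api_type="openai",
    models=[
        ModelDetail(model_name="my-chat-model"),
        ModelDetail(model_name="my-embed-model", is_embedding_model=True),
    ],
)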
class LLMToolFunction(BaseModel):
"""LLM 工具函数定义"""
name: str
arguments: str
class LLMToolCall(BaseModel):
"""LLM 工具调用"""
id: str
function: LLMToolFunction
class LLMTool(BaseModel):
"""LLM 工具定义(支持 MCP 风格)"""
type: str = "function"
function: dict[str, Any]
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
@classmethod
def create(
cls,
name: str,
description: str,
parameters: dict[str, Any],
required: list[str] | None = None,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""创建工具"""
function_def = {
"name": name,
"description": description,
"parameters": {
"type": "object",
"properties": parameters,
"required": required or [],
},
}
return cls(type="function", function=function_def, annotations=annotations)
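create expands its arguments into a JSON-Schema style function declaration; a sketch of the resulting shape:

tool = LLMTool.create(
    name="search",
    description="搜索网页",
    parameters={"query": {"type": "string", "description": "搜索关键词"}},
    required=["query"],
)
assert tool.function["parameters"]["type"] == "object"
assert tool.function["parameters"]["required"] == ["query"]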
class LLMCodeExecution(BaseModel):
"""代码执行结果"""
code: str
output: str | None = None
error: str | None = None
execution_time: float | None = None
files_generated: list[str] | None = None
class LLMGroundingAttribution(BaseModel):
"""信息来源关联"""
title: str | None = None
uri: str | None = None
snippet: str | None = None
confidence_score: float | None = None
class LLMGroundingMetadata(BaseModel):
"""信息来源关联元数据"""
web_search_queries: list[str] | None = None
grounding_attributions: list[LLMGroundingAttribution] | None = None
search_suggestions: list[dict[str, Any]] | None = None
class LLMCacheInfo(BaseModel):
"""缓存信息"""
cache_hit: bool = False
cache_key: str | None = None
cache_ttl: int | None = None
created_at: str | None = None

View File

@ -0,0 +1,218 @@
"""
LLM 模块的工具和转换函数
"""
import base64
from pathlib import Path
from nonebot_plugin_alconna.uniseg import (
At,
File,
Image,
Reply,
Text,
UniMessage,
Video,
Voice,
)
from zhenxun.services.log import logger
from .types import LLMContentPart
async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
"""
UniMessage 实例转换为一个 LLMContentPart 列表
这是处理多模态输入的核心转换逻辑
"""
parts: list[LLMContentPart] = []
for seg in message:
part = None
if isinstance(seg, Text):
if seg.text.strip():
part = LLMContentPart.text_part(seg.text)
elif isinstance(seg, Image):
if seg.path:
part = await LLMContentPart.from_path(seg.path, target_api="gemini")
elif seg.url:
part = LLMContentPart.image_url_part(seg.url)
elif hasattr(seg, "raw") and seg.raw:
mime_type = (
getattr(seg, "mimetype", "image/png")
if hasattr(seg, "mimetype")
else "image/png"
)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
part = LLMContentPart.image_base64_part(b64_data, mime_type)
elif isinstance(seg, File | Voice | Video):
if seg.path:
part = await LLMContentPart.from_path(seg.path)
elif seg.url:
logger.warning(
f"直接使用 URL 的 {type(seg).__name__} 段,"
f"API 可能不支持: {seg.url}"
)
part = LLMContentPart.text_part(
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
)
elif hasattr(seg, "raw") and seg.raw:
mime_type = getattr(seg, "mimetype", None)
if isinstance(seg.raw, bytes):
b64_data = base64.b64encode(seg.raw).decode("utf-8")
if isinstance(seg, Video):
if not mime_type:
mime_type = "video/mp4"
part = LLMContentPart.video_base64_part(
data=b64_data, mime_type=mime_type
)
logger.debug(
f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
)
elif isinstance(seg, Voice):
if not mime_type:
mime_type = "audio/wav"
part = LLMContentPart.audio_base64_part(
data=b64_data, mime_type=mime_type
)
logger.debug(
f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
)
else:
part = LLMContentPart.text_part(
f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
)
logger.debug(
f"处理其他文件字节数据: {mime_type}, "
f"大小: {len(seg.raw)} bytes"
)
elif isinstance(seg, At):
if seg.flag == "all":
part = LLMContentPart.text_part("[Mentioned Everyone]")
else:
part = LLMContentPart.text_part(f"[Mentioned user: {seg.target}]")
elif isinstance(seg, Reply):
if seg.msg:
try:
extract_method = getattr(seg.msg, "extract_plain_text", None)
if extract_method and callable(extract_method):
reply_text = str(extract_method()).strip()
else:
reply_text = str(seg.msg).strip()
if reply_text:
part = LLMContentPart.text_part(
f'[Replied to: "{reply_text[:50]}..."]'
)
except Exception:
part = LLMContentPart.text_part("[Replied to a message]")
if part:
parts.append(part)
return parts
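A sketch of feeding an incoming event message to a model (`model` from get_model_instance):

async def handle(message: UniMessage) -> str:
    parts = await unimsg_to_llm_parts(message)
    if not parts:
        return "（空消息）"
    response = await model.generate_response([LLMMessage.user(parts)])
    return response.text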
def create_multimodal_message(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
image_mimetypes: list[str] | str | None = None,
video_mimetypes: list[str] | str | None = None,
audio_mimetypes: list[str] | str | None = None,
) -> UniMessage:
"""
创建多模态消息的便捷函数,方便第三方调用
Args:
text: 文本内容
images: 图片数据,支持路径、字节数据或URL
videos: 视频数据,支持路径、字节数据或URL
audios: 音频数据,支持路径、字节数据或URL
image_mimetypes: 图片MIME类型,当images为bytes时需要指定
video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
Returns:
UniMessage: 构建好的多模态消息
Examples:
# 纯文本
msg = create_multimodal_message("请分析这段文字")
# 文本 + 单张图片(路径)
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
# 文本 + 多张图片
msg = create_multimodal_message(
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
)
# 文本 + 图片字节数据
msg = create_multimodal_message(
"分析", images=image_data, image_mimetypes="image/jpeg"
)
# 文本 + 视频
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
# 文本 + 音频
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
# 混合多模态
msg = create_multimodal_message(
"分析这些媒体文件",
images="/path/to/image.jpg",
videos="/path/to/video.mp4",
audios="/path/to/audio.wav"
)
"""
message = UniMessage()
if text:
message.append(Text(text))
if images is not None:
_add_media_to_message(message, images, image_mimetypes, Image, "image/png")
if videos is not None:
_add_media_to_message(message, videos, video_mimetypes, Video, "video/mp4")
if audios is not None:
_add_media_to_message(message, audios, audio_mimetypes, Voice, "audio/wav")
return message
def _add_media_to_message(
message: UniMessage,
media_items: list[str | Path | bytes] | str | Path | bytes,
mimetypes: list[str] | str | None,
media_class: type,
default_mimetype: str,
) -> None:
"""添加媒体文件到 UniMessage 的辅助函数"""
if not isinstance(media_items, list):
media_items = [media_items]
mime_list = []
if mimetypes is not None:
if isinstance(mimetypes, str):
mime_list = [mimetypes] * len(media_items)
else:
mime_list = list(mimetypes)
for i, item in enumerate(media_items):
if isinstance(item, str | Path):
if str(item).startswith(("http://", "https://")):
message.append(media_class(url=str(item)))
else:
message.append(media_class(path=Path(item)))
elif isinstance(item, bytes):
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
message.append(media_class(raw=item, mimetype=mimetype))