Mirror of https://github.com/zhenxun-org/zhenxun_bot.git, synced 2025-12-14 21:52:56 +08:00

✨ feat(llm): implement an LLM service module with a unified multi-provider interface and advanced features (#1923)

* ✨ feat(llm): implement an LLM service module with a unified multi-provider interface and advanced features
* 🎨 Ruff
* ✨ Let the Config class store BaseModel values
* 🎨 Code formatting
* 🎨 Code formatting
* 🎨 Code formatting
* ✨ feat(llm): add AI conversation history management
* ✨ feat(llmConfig): introduce LLM configuration models and management
* 🎨 Ruff

---------

Co-authored-by: fccckaug <xxxmio123123@gmail.com>
Co-authored-by: HibiKier <45528451+HibiKier@users.noreply.github.com>
Co-authored-by: HibiKier <775757368@qq.com>
Co-authored-by: fccckaug <xxxmcsmiomio3@gmail.com>
Co-authored-by: webjoin111 <455457521@qq.com>

Parent: 14f5842f10
Commit: a020ea5c87
@@ -58,7 +58,7 @@ def _generate_simple_config(exists_module: list[str]):
     生成简易配置

     异常:
-        AttributeError: _description_
+        AttributeError: AttributeError
     """
     # 读取用户配置
     _data = {}
@@ -74,7 +74,9 @@ def _generate_simple_config(exists_module: list[str]):
             if _data.get(module) and k in _data[module].keys():
                 Config.set_config(module, k, _data[module][k])
             if f"{module}:{k}".lower() in exists_module:
-                _tmp_data[module][k] = Config.get_config(module, k)
+                _tmp_data[module][k] = Config.get_config(
+                    module, k, build_model=False
+                )
     except AttributeError as e:
         raise AttributeError(f"{e}\n可能为config.yaml配置文件填写不规范") from e
         if not _tmp_data[module]:
@@ -1,89 +1,82 @@
 from collections.abc import Callable
 import copy
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, TypeVar, get_args, get_origin

 import cattrs
 from nonebot.compat import model_dump
-from pydantic import BaseModel, Field
+from pydantic import VERSION, BaseModel, Field
 from ruamel.yaml import YAML
 from ruamel.yaml.scanner import ScannerError

 from zhenxun.configs.path_config import DATA_PATH
 from zhenxun.services.log import logger
 from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType

+from .models import (
+    AICallableParam,
+    AICallableProperties,
+    AICallableTag,
+    BaseBlock,
+    Command,
+    ConfigModel,
+    Example,
+    PluginCdBlock,
+    PluginCountBlock,
+    PluginExtraData,
+    PluginSetting,
+    RegisterConfig,
+    Task,
+)
+
 _yaml = YAML(pure=True)
 _yaml.indent = 2
 _yaml.allow_unicode = True

-
-class Example(BaseModel):
-    """
-    示例
-    """
-
-    exec: str
-    """执行命令"""
-    description: str = ""
-    """命令描述"""
-
-
-class Command(BaseModel):
-    """
-    具体参数说明
-    """
-
-    command: str
-    """命令名称"""
-    params: list[str] = Field(default_factory=list)
-    """参数"""
-    description: str = ""
-    """描述"""
-    examples: list[Example] = Field(default_factory=list)
-    """示例列表"""
-
-
-class RegisterConfig(BaseModel):
-    """
-    注册配置项
-    """
-
-    key: str
-    """配置项键"""
-    value: Any
-    """配置项值"""
-    module: str | None = None
-    """模块名"""
-    help: str | None
-    """配置注解"""
-    default_value: Any | None = None
-    """默认值"""
-    type: Any = None
-    """参数类型"""
-    arg_parser: Callable | None = None
-    """参数解析"""
-
-
-class ConfigModel(BaseModel):
-    """
-    配置项
-    """
-
-    value: Any
-    """配置项值"""
-    help: str | None
-    """配置注解"""
-    default_value: Any | None = None
-    """默认值"""
-    type: Any = None
-    """参数类型"""
-    arg_parser: Callable | None = None
-    """参数解析"""
-
-    def to_dict(self, **kwargs):
-        return model_dump(self, **kwargs)
+T = TypeVar("T")
+
+
+class NoSuchConfig(Exception):
+    pass
+
+
+def _dump_pydantic_obj(obj: Any) -> Any:
+    """
+    递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典。
+    支持单个实例、实例列表、实例字典等情况。
+    """
+    if isinstance(obj, BaseModel):
+        return model_dump(obj)
+    if isinstance(obj, list):
+        return [_dump_pydantic_obj(item) for item in obj]
+    if isinstance(obj, dict):
+        return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
+    return obj
+
+
+def _is_pydantic_type(t: Any) -> bool:
+    """
+    递归检查一个类型注解是否与 Pydantic BaseModel 相关。
+    """
+    if t is None:
+        return False
+    origin = get_origin(t)
+    if origin:
+        return any(_is_pydantic_type(arg) for arg in get_args(t))
+    return isinstance(t, type) and issubclass(t, BaseModel)
+
+
+def parse_as(type_: type[T], obj: Any) -> T:
+    """
+    一个兼容 Pydantic V1 的 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数。
+    """
+    if VERSION.startswith("1"):
+        from pydantic import parse_obj_as
+
+        return parse_obj_as(type_, obj)
+    else:
+        from pydantic import TypeAdapter  # type: ignore
+
+        return TypeAdapter(type_).validate_python(obj)


 class ConfigGroup(BaseModel):
@@ -98,202 +91,41 @@ class ConfigGroup(BaseModel):
     configs: dict[str, ConfigModel] = Field(default_factory=dict)
     """配置项列表"""

-    def get(self, c: str, default: Any = None) -> Any:
-        cfg = self.configs.get(c.upper())
-        if cfg is not None:
-            if cfg.value is not None:
-                return cfg.value
-            if cfg.default_value is not None:
-                return cfg.default_value
-        return default
+    def get(self, c: str, default: Any = None, *, build_model: bool = True) -> Any:
+        """
+        获取配置项的值。如果指定了类型,会自动构建实例。
+        """
+        key = c.upper()
+        cfg = self.configs.get(key)
+
+        if cfg is None:
+            return default
+
+        value_to_process = cfg.value if cfg.value is not None else cfg.default_value
+
+        if value_to_process is None:
+            return default
+
+        if cfg.type:
+            if _is_pydantic_type(cfg.type):
+                if build_model:
+                    try:
+                        return parse_as(cfg.type, value_to_process)
+                    except Exception as e:
+                        logger.warning(
+                            f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e
+                        )
+            try:
+                return cattrs.structure(value_to_process, cfg.type)
+            except Exception as e:
+                logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
+
+        return value_to_process

     def to_dict(self, **kwargs):
         return model_dump(self, **kwargs)


-class BaseBlock(BaseModel):
-    """
-    插件阻断基本类(插件阻断限制)
-    """
-
-    status: bool = True
-    """限制状态"""
-    check_type: BlockType = BlockType.ALL
-    """检查类型"""
-    watch_type: LimitWatchType = LimitWatchType.USER
-    """监听对象"""
-    result: str | None = None
-    """阻断时回复内容"""
-    _type: PluginLimitType = PluginLimitType.BLOCK
-    """类型"""
-
-    def to_dict(self, **kwargs):
-        return model_dump(self, **kwargs)
-
-
-class PluginCdBlock(BaseBlock):
-    """
-    插件cd限制
-    """
-
-    cd: int = 5
-    """cd"""
-    _type: PluginLimitType = PluginLimitType.CD
-    """类型"""
-
-
-class PluginCountBlock(BaseBlock):
-    """
-    插件次数限制
-    """
-
-    max_count: int
-    """最大调用次数"""
-    _type: PluginLimitType = PluginLimitType.COUNT
-    """类型"""
-
-
-class PluginSetting(BaseModel):
-    """
-    插件基本配置
-    """
-
-    level: int = 5
-    """群权限等级"""
-    default_status: bool = True
-    """进群默认开关状态"""
-    limit_superuser: bool = False
-    """是否限制超级用户"""
-    cost_gold: int = 0
-    """调用插件花费金币"""
-    impression: float = 0.0
-    """调用插件好感度限制"""
-
-
-class AICallableProperties(BaseModel):
-    type: str
-    """参数类型"""
-    description: str
-    """参数描述"""
-    enums: list[str] | None = None
-    """参数枚举"""
-
-
-class AICallableParam(BaseModel):
-    type: str
-    """类型"""
-    properties: dict[str, AICallableProperties]
-    """参数列表"""
-    required: list[str]
-    """必要参数"""
-
-
-class AICallableTag(BaseModel):
-    name: str
-    """工具名称"""
-    parameters: AICallableParam | None = None
-    """工具参数"""
-    description: str
-    """工具描述"""
-    func: Callable | None = None
-    """工具函数"""
-
-    def to_dict(self):
-        result = model_dump(self)
-        del result["func"]
-        return result
-
-
-class SchedulerModel(BaseModel):
-    trigger: Literal["date", "interval", "cron"]
-    """trigger"""
-    day: int | None = None
-    """天数"""
-    hour: int | None = None
-    """小时"""
-    minute: int | None = None
-    """分钟"""
-    second: int | None = None
-    """秒"""
-    run_date: datetime | None = None
-    """运行日期"""
-    id: str | None = None
-    """id"""
-    max_instances: int | None = None
-    """最大运行实例"""
-    args: list | None = None
-    """参数"""
-    kwargs: dict | None = None
-    """参数"""
-
-
-class Task(BaseBlock):
-    module: str
-    """被动技能模块名"""
-    name: str
-    """被动技能名称"""
-    status: bool = True
-    """全局开关状态"""
-    create_status: bool = False
-    """初次加载默认开关状态"""
-    default_status: bool = True
-    """进群时默认状态"""
-    scheduler: SchedulerModel | None = None
-    """定时任务配置"""
-    run_func: Callable | None = None
-    """运行函数"""
-    check: Callable | None = None
-    """检查函数"""
-    check_args: list = Field(default_factory=list)
-    """检查函数参数"""
-
-
-class PluginExtraData(BaseModel):
-    """
-    插件扩展信息
-    """
-
-    author: str | None = None
-    """作者"""
-    version: str | None = None
-    """版本"""
-    plugin_type: PluginType = PluginType.NORMAL
-    """插件类型"""
-    menu_type: str = "功能"
-    """菜单类型"""
-    admin_level: int | None = None
-    """管理员插件所需权限等级"""
-    configs: list[RegisterConfig] | None = None
-    """插件配置"""
-    setting: PluginSetting | None = None
-    """插件基本配置"""
-    limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None
-    """插件限制"""
-    commands: list[Command] = Field(default_factory=list)
-    """命令列表,用于说明帮助"""
-    ignore_prompt: bool = False
-    """是否忽略阻断提示"""
-    tasks: list[Task] | None = None
-    """技能被动"""
-    superuser_help: str | None = None
-    """超级用户帮助"""
-    aliases: set[str] = Field(default_factory=set)
-    """额外名称"""
-    sql_list: list[str] | None = None
-    """常用sql"""
-    is_show: bool = True
-    """是否显示在菜单中"""
-    smart_tools: list[AICallableTag] | None = None
-    """智能模式函数工具集"""
-
-    def to_dict(self, **kwargs):
-        return model_dump(self, **kwargs)
-
-
-class NoSuchConfig(Exception):
-    pass
-
-
 class ConfigsManager:
     """
     插件配置 与 资源 管理器
@@ -366,23 +198,32 @@ class ConfigsManager:

         if not module or not key:
             raise ValueError("add_plugin_config: module和key不能为为空")
-        if isinstance(value, BaseModel):
-            value = model_dump(value)
-        if isinstance(default_value, BaseModel):
-            default_value = model_dump(default_value)
+
+        processed_value = _dump_pydantic_obj(value)
+        processed_default_value = _dump_pydantic_obj(default_value)
+
         self.add_module.append(f"{module}:{key}".lower())
         if module in self._data and (config := self._data[module].configs.get(key)):
             config.help = help
             config.arg_parser = arg_parser
             config.type = type
             if _override:
-                config.value = value
-                config.default_value = default_value
+                config.value = processed_value
+                config.default_value = processed_default_value
         else:
             key = key.upper()
             if not self._data.get(module):
                 self._data[module] = ConfigGroup(module=module)
             self._data[module].configs[key] = ConfigModel(
-                value=value,
+                value=processed_value,
                 help=help,
-                default_value=default_value,
+                default_value=processed_default_value,
                 type=type,
                 arg_parser=arg_parser,
             )

     def set_config(
@@ -402,6 +243,8 @@ class ConfigsManager:
         """
         key = key.upper()
         if module in self._data:
+            if module not in self._simple_data:
+                self._simple_data[module] = {}
             if self._data[module].configs.get(key):
                 self._data[module].configs[key].value = value
         else:
@@ -410,63 +253,68 @@
         if auto_save:
             self.save(save_simple_data=True)

-    def get_config(self, module: str, key: str, default: Any = None) -> Any:
-        """获取指定配置值
-
-        参数:
-            module: 模块名
-            key: 配置键
-            default: 没有key值内容的默认返回值.
-
-        异常:
-            NoSuchConfig: 未查询到配置
-
-        返回:
-            Any: 配置值
-        """
+    def get_config(
+        self,
+        module: str,
+        key: str,
+        default: Any = None,
+        *,
+        build_model: bool = True,
+    ) -> Any:
+        """
+        获取指定配置值,自动构建Pydantic模型或其它类型实例。
+        - 兼容Pydantic V1/V2。
+        - 支持 list[BaseModel] 等泛型容器。
+        - 优先使用Pydantic原生方式解析,失败后回退到cattrs。
+        """
         logger.debug(
             f"尝试获取配置MODULE: [<u><y>{module}</y></u>] | KEY: [<u><y>{key}</y></u>]"
         )
         key = key.upper()
-        value = None
-        if module in self._data.keys():
-            config = self._data[module].configs.get(key) or self._data[
-                module
-            ].configs.get(key)
-            if not config:
-                raise NoSuchConfig(
-                    f"未查询到配置项 MODULE: [ {module} ] | KEY: [ {key} ]"
-                )
-            try:
-                if config.arg_parser:
-                    value = config.arg_parser(value or config.default_value)
-                elif config.value is not None:
-                    # try:
-                    value = (
-                        cattrs.structure(config.value, config.type)
-                        if config.type
-                        else config.value
-                    )
-                elif config.default_value is not None:
-                    value = (
-                        cattrs.structure(config.default_value, config.type)
-                        if config.type
-                        else config.default_value
-                    )
-            except Exception as e:
-                logger.debug(
-                    f"配置项类型转换 MODULE: [<u><y>{module}</y></u>]"
-                    f" | KEY: [<u><y>{key}</y></u>] 将使用原始值",
-                    e=e,
-                )
-                value = config.value or config.default_value
-        if value is None:
-            value = default
-        logger.debug(
-            f"获取配置 MODULE: [<u><y>{module}</y></u>] | "
-            f" KEY: [<u><y>{key}</y></u>] -> [<u><c>{value}</c></u>]"
-        )
-        return value
+        config_group = self._data.get(module)
+        if not config_group:
+            return default
+
+        config = config_group.configs.get(key)
+        if not config:
+            return default
+
+        value_to_process = (
+            config.value if config.value is not None else config.default_value
+        )
+        if value_to_process is None:
+            return default
+
+        # 1. 最高优先级:自定义的参数解析器
+        if config.arg_parser:
+            try:
+                return config.arg_parser(value_to_process)
+            except Exception as e:
+                logger.debug(
+                    f"配置项类型转换 MODULE: [<u><y>{module}</y></u>]"
+                    f" | KEY: [<u><y>{key}</y></u>] 将使用原始值",
+                    e=e,
+                )
+
+        if config.type:
+            if _is_pydantic_type(config.type):
+                if build_model:
+                    try:
+                        return parse_as(config.type, value_to_process)
+                    except Exception as e:
+                        logger.warning(
+                            f"pydantic类型转换失败 MODULE: [<u><y>{module}</y></u>] | "
+                            f"KEY: [<u><y>{key}</y></u>].",
+                            e=e,
+                        )
+            else:
+                try:
+                    return cattrs.structure(value_to_process, config.type)
+                except Exception as e:
+                    logger.warning(
+                        f"cattrs类型转换失败 MODULE: [<u><y>{module}</y></u>] | "
+                        f"KEY: [<u><y>{key}</y></u>].",
+                        e=e,
+                    )
+
+        return value_to_process

     def get(self, key: str) -> ConfigGroup:
         """获取插件配置数据
@@ -490,16 +338,16 @@ class ConfigsManager:
             with open(self._simple_file, "w", encoding="utf8") as f:
                 _yaml.dump(self._simple_data, f)
         path = path or self.file
-        data = {}
-        for module in self._data:
-            data[module] = {}
-            for config in self._data[module].configs:
-                value = self._data[module].configs[config].dict()
-                del value["type"]
-                del value["arg_parser"]
-                data[module][config] = value
+        save_data = {}
+        for module, config_group in self._data.items():
+            save_data[module] = {}
+            for config_key, config_model in config_group.configs.items():
+                save_data[module][config_key] = model_dump(
+                    config_model, exclude={"type", "arg_parser"}
+                )
+
         with open(path, "w", encoding="utf8") as f:
-            _yaml.dump(data, f)
+            _yaml.dump(save_data, f)

     def reload(self):
         """重新加载配置文件"""
@@ -558,3 +406,23 @@ class ConfigsManager:

     def __getitem__(self, key):
         return self._data[key]
+
+
+__all__ = [
+    "AICallableParam",
+    "AICallableProperties",
+    "AICallableTag",
+    "BaseBlock",
+    "Command",
+    "ConfigGroup",
+    "ConfigModel",
+    "ConfigsManager",
+    "Example",
+    "NoSuchConfig",
+    "PluginCdBlock",
+    "PluginCountBlock",
+    "PluginExtraData",
+    "PluginSetting",
+    "RegisterConfig",
+    "Task",
+]
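Taken together, these hunks let plugin configs round-trip Pydantic models: values are recursively dumped to plain dicts before YAML storage, and rebuilt into model instances on read unless build_model=False. A minimal sketch of the resulting usage (the exact add_plugin_config signature and the Config import path are assumptions, not shown in this diff):

    from pydantic import BaseModel

    from zhenxun.configs.config import Config  # assumed import path


    class ApiSetting(BaseModel):
        url: str
        timeout: int = 30


    # BaseModel values are dumped via _dump_pydantic_obj before storage.
    Config.add_plugin_config(
        "my_plugin",
        "API_SETTING",
        ApiSetting(url="https://example.com"),
        help="示例配置",
        type=ApiSetting,
    )

    # Reads rebuild the model via parse_as (pydantic V1/V2 compatible) ...
    setting = Config.get_config("my_plugin", "API_SETTING")  # -> ApiSetting instance
    # ... unless build_model=False, which keeps the raw stored value.
    raw = Config.get_config("my_plugin", "API_SETTING", build_model=False)  # -> dict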
zhenxun/configs/utils/models.py (new file, 270 lines)
@@ -0,0 +1,270 @@
from collections.abc import Callable
from datetime import datetime
from typing import Any, Literal

from nonebot.compat import model_dump
from pydantic import BaseModel, Field

from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType

__all__ = [
    "AICallableParam",
    "AICallableProperties",
    "AICallableTag",
    "BaseBlock",
    "Command",
    "ConfigModel",
    "Example",
    "PluginCdBlock",
    "PluginCountBlock",
    "PluginExtraData",
    "PluginSetting",
    "RegisterConfig",
    "Task",
]


class Example(BaseModel):
    """
    示例
    """

    exec: str
    """执行命令"""
    description: str = ""
    """命令描述"""


class Command(BaseModel):
    """
    具体参数说明
    """

    command: str
    """命令名称"""
    params: list[str] = Field(default_factory=list)
    """参数"""
    description: str = ""
    """描述"""
    examples: list[Example] = Field(default_factory=list)
    """示例列表"""


class RegisterConfig(BaseModel):
    """
    注册配置项
    """

    key: str
    """配置项键"""
    value: Any
    """配置项值"""
    module: str | None = None
    """模块名"""
    help: str | None
    """配置注解"""
    default_value: Any | None = None
    """默认值"""
    type: Any = None
    """参数类型"""
    arg_parser: Callable | None = None
    """参数解析"""


class ConfigModel(BaseModel):
    """
    配置项
    """

    value: Any
    """配置项值"""
    help: str | None
    """配置注解"""
    default_value: Any | None = None
    """默认值"""
    type: Any = None
    """参数类型"""
    arg_parser: Callable | None = None
    """参数解析"""

    def to_dict(self, **kwargs):
        return model_dump(self, **kwargs)


class BaseBlock(BaseModel):
    """
    插件阻断基本类(插件阻断限制)
    """

    status: bool = True
    """限制状态"""
    check_type: BlockType = BlockType.ALL
    """检查类型"""
    watch_type: LimitWatchType = LimitWatchType.USER
    """监听对象"""
    result: str | None = None
    """阻断时回复内容"""
    _type: PluginLimitType = PluginLimitType.BLOCK
    """类型"""

    def to_dict(self, **kwargs):
        return model_dump(self, **kwargs)


class PluginCdBlock(BaseBlock):
    """
    插件cd限制
    """

    cd: int = 5
    """cd"""
    _type: PluginLimitType = PluginLimitType.CD
    """类型"""


class PluginCountBlock(BaseBlock):
    """
    插件次数限制
    """

    max_count: int
    """最大调用次数"""
    _type: PluginLimitType = PluginLimitType.COUNT
    """类型"""


class PluginSetting(BaseModel):
    """
    插件基本配置
    """

    level: int = 5
    """群权限等级"""
    default_status: bool = True
    """进群默认开关状态"""
    limit_superuser: bool = False
    """是否限制超级用户"""
    cost_gold: int = 0
    """调用插件花费金币"""
    impression: float = 0.0
    """调用插件好感度限制"""


class AICallableProperties(BaseModel):
    type: str
    """参数类型"""
    description: str
    """参数描述"""
    enums: list[str] | None = None
    """参数枚举"""


class AICallableParam(BaseModel):
    type: str
    """类型"""
    properties: dict[str, AICallableProperties]
    """参数列表"""
    required: list[str]
    """必要参数"""


class AICallableTag(BaseModel):
    name: str
    """工具名称"""
    parameters: AICallableParam | None = None
    """工具参数"""
    description: str
    """工具描述"""
    func: Callable | None = None
    """工具函数"""

    def to_dict(self):
        result = model_dump(self)
        del result["func"]
        return result


class SchedulerModel(BaseModel):
    trigger: Literal["date", "interval", "cron"]
    """trigger"""
    day: int | None = None
    """天数"""
    hour: int | None = None
    """小时"""
    minute: int | None = None
    """分钟"""
    second: int | None = None
    """秒"""
    run_date: datetime | None = None
    """运行日期"""
    id: str | None = None
    """id"""
    max_instances: int | None = None
    """最大运行实例"""
    args: list | None = None
    """参数"""
    kwargs: dict | None = None
    """参数"""


class Task(BaseBlock):
    module: str
    """被动技能模块名"""
    name: str
    """被动技能名称"""
    status: bool = True
    """全局开关状态"""
    create_status: bool = False
    """初次加载默认开关状态"""
    default_status: bool = True
    """进群时默认状态"""
    scheduler: SchedulerModel | None = None
    """定时任务配置"""
    run_func: Callable | None = None
    """运行函数"""
    check: Callable | None = None
    """检查函数"""
    check_args: list = Field(default_factory=list)
    """检查函数参数"""


class PluginExtraData(BaseModel):
    """
    插件扩展信息
    """

    author: str | None = None
    """作者"""
    version: str | None = None
    """版本"""
    plugin_type: PluginType = PluginType.NORMAL
    """插件类型"""
    menu_type: str = "功能"
    """菜单类型"""
    admin_level: int | None = None
    """管理员插件所需权限等级"""
    configs: list[RegisterConfig] | None = None
    """插件配置"""
    setting: PluginSetting | None = None
    """插件基本配置"""
    limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None
    """插件限制"""
    commands: list[Command] = Field(default_factory=list)
    """命令列表,用于说明帮助"""
    ignore_prompt: bool = False
    """是否忽略阻断提示"""
    tasks: list[Task] | None = None
    """技能被动"""
    superuser_help: str | None = None
    """超级用户帮助"""
    aliases: set[str] = Field(default_factory=set)
    """额外名称"""
    sql_list: list[str] | None = None
    """常用sql"""
    is_show: bool = True
    """是否显示在菜单中"""
    smart_tools: list[AICallableTag] | None = None
    """智能模式函数工具集"""

    def to_dict(self, **kwargs):
        return model_dump(self, **kwargs)
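These are the models a plugin typically attaches to its metadata. A minimal sketch of declaring them (field values here are illustrative, not from this commit):

    from zhenxun.configs.utils import PluginExtraData, RegisterConfig

    extra = PluginExtraData(
        author="example",
        version="0.1",
        configs=[
            RegisterConfig(
                key="TOKEN",
                value=None,
                help="示例配置项",
            )
        ],
    )
    # to_dict() uses nonebot.compat.model_dump, so it works on pydantic V1 and V2.
    data = extra.to_dict()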
zhenxun/services/llm/README.md (new file, 731 lines)
@@ -0,0 +1,731 @@
# Zhenxun LLM Service Module

## 📑 Table of Contents

- [📖 Overview](#-overview)
- [🌟 Key Features](#-key-features)
- [🚀 Quick Start](#-quick-start)
- [📚 API Reference](#-api-reference)
- [⚙️ Configuration](#️-configuration)
- [🔧 Advanced Features](#-advanced-features)
- [🏗️ Architecture](#️-architecture)
- [🔌 Supported Providers](#-supported-providers)
- [🎯 Use Cases](#-use-cases)
- [📊 Performance Optimization](#-performance-optimization)
- [🛠️ Troubleshooting](#️-troubleshooting)
- [❓ FAQ](#-faq)
- [📝 Example Projects](#-example-projects)
- [🤝 Contributing](#-contributing)
- [📄 License](#-license)

## 📖 Overview

The Zhenxun LLM service module is a modern AI service framework that provides a unified interface to multiple large language model providers. It is modular by design and supports advanced features such as asynchronous operation, smart retries, API-key rotation, and load balancing.

### 🌟 Key Features

- **Multi-provider support**: OpenAI, Gemini, Zhipu AI, DeepSeek, and more
- **Unified interface**: a concise, consistent API
- **Smart key rotation**: automatic load balancing and failover
- **Async and fast**: concurrency built on asyncio
- **Model caching**: a smart cache to improve performance
- **Tool invocation**: function calling support
- **Embeddings**: text vectorization support
- **Error handling**: thorough exception handling and retry logic
- **Multimodal support**: text, image, audio, and video processing
- **Code execution**: Gemini code-execution support
- **Search grounding**: Google Search integration

## 🚀 Quick Start

### Basic usage

```python
from zhenxun.services.llm import chat, code, search, analyze

# Simple chat
response = await chat("你好,请介绍一下自己")
print(response)

# Code execution
result = await code("计算斐波那契数列的前10项")
print(result["text"])
print(result["code_executions"])

# Search
search_result = await search("Python异步编程最佳实践")
print(search_result["text"])

# Multimodal analysis
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
message = UniMessage([
    Text("分析这张图片"),
    Image(path="image.jpg")
])
analysis = await analyze(message, model="Gemini/gemini-2.0-flash")
print(analysis)
```

### Using the AI class

```python
from zhenxun.services.llm import AI, AIConfig, analyze_with_images

# Create an AI instance
ai = AI(AIConfig(model="OpenAI/gpt-4"))

# Chat
response = await ai.chat("解释量子计算的基本原理")

# Multimodal analysis
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text

multimodal_msg = UniMessage([
    Text("这张图片显示了什么?"),
    Image(path="image.jpg")
])
result = await ai.analyze(multimodal_msg)

# Convenience multimodal helper
result = await analyze_with_images(
    "分析这张图片",
    images="image.jpg",
    model="Gemini/gemini-2.0-flash"
)
```

## 📚 API Reference

### Quick functions

#### `chat(message, *, model=None, **kwargs) -> str`
Simple chat.

**Parameters:**
- `message`: message content (a string, an LLMMessage, or a list of content parts)
- `model`: model name (optional)
- `**kwargs`: extra configuration parameters

#### `code(prompt, *, model=None, timeout=None, **kwargs) -> dict`
Code execution.

**Returns:**
```python
{
    "text": "执行结果说明",
    "code_executions": [{"code": "...", "output": "..."}],
    "success": True
}
```

#### `search(query, *, model=None, instruction="", **kwargs) -> dict`
Search-grounded generation.

**Returns:**
```python
{
    "text": "搜索结果和分析",
    "grounding_metadata": {...},
    "success": True
}
```

#### `analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse`
Advanced analysis supporting multimodal input and tool calls.

#### `analyze_with_images(text, images, *, instruction="", model=None, **kwargs) -> str`
Convenience helper for image analysis.

#### `analyze_multimodal(text=None, images=None, videos=None, audios=None, *, instruction="", model=None, **kwargs) -> str`
Convenience helper for multimodal analysis.

#### `embed(texts, *, model=None, task_type="RETRIEVAL_DOCUMENT", **kwargs) -> list[list[float]]`
Text embedding vectors.
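The embedding helper pairs naturally with vector similarity search. A minimal sketch (`embed` returns plain Python lists, so numpy is optional; `RETRIEVAL_QUERY` is assumed to be accepted alongside the documented `RETRIEVAL_DOCUMENT` default, following Gemini task-type naming):

```python
import numpy as np

from zhenxun.services.llm import embed

docs = ["猫是哺乳动物", "Python 是一种编程语言"]
doc_vecs = await embed(docs, task_type="RETRIEVAL_DOCUMENT")
query_vec = (await embed(["什么是 Python?"], task_type="RETRIEVAL_QUERY"))[0]

def cosine(a: list[float], b: list[float]) -> float:
    # Cosine similarity between two embedding vectors
    va, vb = np.asarray(a), np.asarray(b)
    return float(va @ vb / (np.linalg.norm(va) * np.linalg.norm(vb)))

best = max(range(len(docs)), key=lambda i: cosine(query_vec, doc_vecs[i]))
print(docs[best])
```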
### AI class methods

#### `AI.chat(message, *, model=None, **kwargs) -> str`
Chat method; accepts simple multimodal input.

#### `AI.analyze(message, *, instruction="", model=None, tools=None, tool_config=None, **kwargs) -> str | LLMResponse`
Advanced analysis method; takes a UniMessage for multimodal analysis and tool calls.

### Model management

```python
from zhenxun.services.llm import (
    get_model_instance,
    list_available_models,
    set_global_default_model_name,
    clear_model_cache
)

# Get a model instance
model = await get_model_instance("OpenAI/gpt-4o")

# List available models
models = list_available_models()

# Set the default model
set_global_default_model_name("Gemini/gemini-2.0-flash")

# Clear the model cache
clear_model_cache()
```

## ⚙️ Configuration

### Preset configurations

```python
from zhenxun.services.llm import CommonOverrides

# Creative mode
creative_config = CommonOverrides.creative()

# Precise mode
precise_config = CommonOverrides.precise()

# Gemini-specific features
json_config = CommonOverrides.gemini_json()
thinking_config = CommonOverrides.gemini_thinking()
code_exec_config = CommonOverrides.gemini_code_execution()
grounding_config = CommonOverrides.gemini_grounding()
```

### Custom configuration

```python
from zhenxun.services.llm import LLMGenerationConfig

config = LLMGenerationConfig(
    temperature=0.7,
    max_tokens=2048,
    top_p=0.9,
    frequency_penalty=0.1,
    presence_penalty=0.1,
    stop=["END", "STOP"],
    response_mime_type="application/json",
    enable_code_execution=True,
    enable_grounding=True
)

response = await chat("你的问题", override_config=config)
```

## 🔧 Advanced Features

### Tool invocation (function calling)

```python
from zhenxun.services.llm import LLMTool, get_model_instance

# Define tools
tools = [
    LLMTool(
        name="get_weather",
        description="获取天气信息",
        parameters={
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "城市名称"}
            },
            "required": ["city"]
        }
    )
]

# Tool executor
async def tool_executor(tool_name: str, args: dict) -> str:
    if tool_name == "get_weather":
        return f"{args['city']}今天晴天,25°C"
    return "未知工具"

# Use the tools
model = await get_model_instance("OpenAI/gpt-4")
response = await model.generate_response(
    messages=[{"role": "user", "content": "北京天气如何?"}],
    tools=tools,
    tool_executor=tool_executor
)
```

### Multimodal processing

```python
from zhenxun.services.llm import create_multimodal_message, analyze_multimodal, analyze_with_images

# Option 1: convenience helper
result = await analyze_multimodal(
    text="分析这些媒体文件",
    images="image.jpg",
    audios="audio.mp3",
    model="Gemini/gemini-2.0-flash"
)

# Option 2: create_multimodal_message
message = create_multimodal_message(
    text="分析这张图片和音频",
    images="image.jpg",
    audios="audio.mp3"
)
result = await analyze(message)

# Option 3: image-analysis helper
result = await analyze_with_images(
    "这张图片显示了什么?",
    images=["image1.jpg", "image2.jpg"]
)
```

## 🛠️ Troubleshooting

### Common errors

1. **Configuration errors**: check the API key and model configuration
2. **Network problems**: check proxy settings and connectivity
3. **Model unavailable**: use `list_available_models()` to check what is available
4. **Timeouts**: raise the timeout parameter or switch to a faster model; all of these surface as exceptions you can catch, as in the sketch below
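A minimal sketch of catching these failures as `LLMException` (assuming the exception exposes the `code` it was raised with, per `zhenxun.services.llm.types`):

```python
from zhenxun.services.llm import chat, LLMErrorCode, LLMException

async def safe_chat(prompt: str) -> str:
    try:
        return await chat(prompt)
    except LLMException as e:
        # Error codes such as API_KEY_INVALID, API_RATE_LIMITED and
        # MODEL_NOT_FOUND are mapped from provider errors by the adapters.
        if e.code == LLMErrorCode.API_RATE_LIMITED:
            return "模型繁忙,请稍后再试"
        raise
```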
### Debugging tips

```python
from zhenxun.services.llm import get_cache_stats
from zhenxun.services.log import logger

# Inspect the cache
stats = get_cache_stats()
print(f"缓存命中率: {stats['hit_rate']}")

# Enable verbose logging
logger.setLevel("DEBUG")
```

## ❓ FAQ

### Q: How do I handle multimodal input?

**A:** There are several options:
```python
# Option 1: convenience helper
result = await analyze_with_images("分析这张图片", images="image.jpg")

# Option 2: the analyze function
from nonebot_plugin_alconna.uniseg import UniMessage, Image, Text
message = UniMessage([Text("分析这张图片"), Image(path="image.jpg")])
result = await analyze(message)

# Option 3: create_multimodal_message
from zhenxun.services.llm import create_multimodal_message
message = create_multimodal_message(text="分析这张图片", images="image.jpg")
result = await analyze(message)
```

### Q: How do I use custom tool calls?

**A:** Use the tools parameter of analyze:
```python
# Define a tool
tools = [{
    "name": "calculator",
    "description": "计算数学表达式",
    "parameters": {
        "type": "object",
        "properties": {
            "expression": {"type": "string", "description": "数学表达式"}
        },
        "required": ["expression"]
    }
}]

# Use it
from nonebot_plugin_alconna.uniseg import UniMessage, Text
message = UniMessage([Text("计算 2+3*4")])
response = await analyze(message, tools=tools, tool_config={"mode": "auto"})

# If an LLMResponse comes back, the model requested tool calls
if hasattr(response, 'tool_calls'):
    for tool_call in response.tool_calls:
        print(f"调用工具: {tool_call.function.name}")
        print(f"参数: {tool_call.function.arguments}")
```

### Q: How do I enforce the output format?

**A:** Use structured output:
```python
# JSON output
config = CommonOverrides.gemini_json()

# Custom schema
schema = {
    "type": "object",
    "properties": {
        "answer": {"type": "string"},
        "confidence": {"type": "number"}
    }
}
config = CommonOverrides.gemini_structured(schema)
```

## 📝 Example Projects

### Complete examples

#### 1. Customer-service bot

```python
from zhenxun.services.llm import AI, CommonOverrides
from typing import Dict, List

class CustomerService:
    def __init__(self):
        self.ai = AI()
        self.sessions: Dict[str, List[dict]] = {}

    async def handle_query(self, user_id: str, query: str) -> str:
        # Fetch or create the session history
        if user_id not in self.sessions:
            self.sessions[user_id] = []

        history = self.sessions[user_id]

        # Add the system prompt
        if not history:
            history.append({
                "role": "system",
                "content": "你是一个专业的客服助手,请友好、准确地回答用户问题。"
            })

        # Add the user question
        history.append({"role": "user", "content": query})

        # Generate a reply
        response = await self.ai.chat(
            query,
            history=history[-20:],  # keep the last 20 turns
            override_config=CommonOverrides.balanced()
        )

        # Store the reply in the history
        history.append({"role": "assistant", "content": response})

        return response
```

#### 2. Document Q&A

```python
from zhenxun.services.llm import embed, analyze
import numpy as np
from typing import List

class DocumentQA:
    def __init__(self):
        self.documents: List[str] = []
        self.embeddings: List[List[float]] = []

    async def add_document(self, text: str):
        """Add a document to the knowledge base"""
        self.documents.append(text)

        # Generate its embedding
        embedding = await embed([text])
        self.embeddings.extend(embedding)

    async def query(self, question: str, top_k: int = 3) -> str:
        """Query the documents and generate an answer"""
        if not self.documents:
            return "知识库为空,请先添加文档。"

        # Embed the question
        question_embedding = await embed([question])

        # Compute similarities and find the most relevant documents
        similarities = []
        for doc_embedding in self.embeddings:
            similarity = np.dot(question_embedding[0], doc_embedding)
            similarities.append(similarity)

        # Take the most relevant documents
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        relevant_docs = [self.documents[i] for i in top_indices]

        # Build the context
        context = "\n\n".join(relevant_docs)
        prompt = f"""
基于以下文档内容回答问题:

文档内容:
{context}

问题:{question}

请基于文档内容给出准确的答案,如果文档中没有相关信息,请说明。
"""

        # analyze() returns the answer text here (no tools are passed)
        return await analyze(prompt)
```

#### 3. Code-review assistant

```python
from zhenxun.services.llm import CommonOverrides, analyze, code
import os

class CodeReviewer:
    async def review_file(self, file_path: str) -> dict:
        """Review a source file"""
        if not os.path.exists(file_path):
            return {"error": "文件不存在"}

        with open(file_path, 'r', encoding='utf-8') as f:
            code_content = f.read()

        prompt = f"""
请审查以下代码,提供详细的反馈:

文件:{file_path}
代码:
```
{code_content}
```

请从以下方面进行审查:
1. 代码质量和可读性
2. 潜在的bug和安全问题
3. 性能优化建议
4. 最佳实践建议
5. 代码风格问题

请以JSON格式返回结果。
"""

        review = await analyze(
            prompt,
            model="DeepSeek/deepseek-coder",
            override_config=CommonOverrides.gemini_json()
        )

        return {
            "file": file_path,
            "review": review,
            "success": True
        }

    async def suggest_improvements(self, code_snippet: str, language: str = "python") -> str:
        """Suggest code improvements"""
        prompt = f"""
请改进以下{language}代码,使其更加高效、可读和符合最佳实践:

原代码:
```{language}
{code_snippet}
```

请提供改进后的代码和说明。
"""

        result = await code(prompt, model="DeepSeek/deepseek-coder")
        return result["text"]
```

## 🏗️ Architecture

### Module layout

```
zhenxun/services/llm/
├── __init__.py          # package entry; imports and re-exports the public API
├── api.py               # high-level API (the AI class and convenience functions)
├── core.py              # core infrastructure (HTTP client, retry logic, KeyStore)
├── service.py           # the LLM model implementation
├── utils.py             # utilities and conversion helpers
├── manager.py           # model management and caching
├── adapters/            # adapter modules
│   ├── __init__.py      # adapter package entry
│   ├── base.py          # base adapter
│   ├── factory.py       # adapter factory
│   ├── openai.py        # OpenAI adapter
│   ├── gemini.py        # Gemini adapter
│   └── zhipu.py         # Zhipu AI adapter
├── config/              # configuration module
│   ├── __init__.py      # config package entry
│   ├── generation.py    # generation config
│   ├── presets.py       # preset configs
│   └── providers.py     # provider configs
└── types/               # type definitions
    ├── __init__.py      # types package entry
    ├── content.py       # content types
    ├── enums.py         # enum definitions
    ├── exceptions.py    # exception definitions
    └── models.py        # data models
```

### Module responsibilities

- **`__init__.py`**: a pure package entry that only imports and exposes the public API
- **`api.py`**: the high-level API, containing the AI class and all convenience functions
- **`core.py`**: core infrastructure: HTTP client management, retry logic, and the KeyStore
- **`service.py`**: the LLM model implementation, focused on model logic
- **`utils.py`**: utilities and conversion helpers, e.g. multimodal message handling
- **`manager.py`**: model management and the cache
- **`adapters/`**: per-provider adapters that talk to the different APIs (a sketch of adding your own follows this list)
  - `base.py`: defines the base adapter interface
  - `factory.py`: the adapter factory, which loads and instantiates adapters dynamically
  - `openai.py`: OpenAI API adapter
  - `gemini.py`: Google Gemini API adapter
  - `zhipu.py`: Zhipu AI API adapter
- **`config/`**: configuration management
  - `generation.py`: generation config and presets
  - `presets.py`: preset configs
  - `providers.py`: provider configs
- **`types/`**: type definitions
  - `content.py`: content types
  - `enums.py`: enums
  - `exceptions.py`: exceptions
  - `models.py`: data models
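The factory's `register_adapter` hook (re-exported from `zhenxun.services.llm.adapters`) means a new OpenAI-compatible provider can be added without touching the core. A hedged sketch, assuming `register_adapter` accepts an adapter instance; the provider name and endpoints are placeholders:

```python
from zhenxun.services.llm.adapters import OpenAICompatAdapter, register_adapter

class MyProviderAdapter(OpenAICompatAdapter):
    """Hypothetical adapter for an OpenAI-compatible endpoint."""

    @property
    def api_type(self) -> str:
        return "myprovider"

    @property
    def supported_api_types(self) -> list[str]:
        return ["myprovider"]

    # OpenAICompatAdapter only asks subclasses for the two endpoints;
    # request building and response parsing are inherited.
    def get_chat_endpoint(self) -> str:
        return "/v1/chat/completions"

    def get_embedding_endpoint(self) -> str:
        return "/v1/embeddings"

register_adapter(MyProviderAdapter())
```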
## 🔌 Supported Providers

### OpenAI-compatible

- **OpenAI**: GPT-4o, GPT-3.5-turbo, etc.
- **DeepSeek**: deepseek-chat, deepseek-reasoner, etc.
- **Other OpenAI-compatible APIs**: custom endpoints are supported

```python
# OpenAI
await chat("Hello", model="OpenAI/gpt-4o")

# DeepSeek
await chat("写代码", model="DeepSeek/deepseek-reasoner")
```

### Google Gemini

- **Gemini models**: gemini-2.5-flash-preview-05-20, gemini-2.0-flash, etc.
- **Special features**: code execution, search grounding, thinking mode

```python
# Basic usage
await chat("你好", model="Gemini/gemini-2.0-flash")

# Code execution
await code("计算质数", model="Gemini/gemini-2.0-flash")

# Search grounding
await search("最新AI发展", model="Gemini/gemini-2.5-flash-preview-05-20")
```

### Zhipu AI

- **GLM series**: glm-4, glm-4v, etc.
- **Supported features**: text generation, multimodal understanding

```python
await chat("介绍北京", model="Zhipu/glm-4")
```

## 🎯 Use Cases

### 1. Chat bot

```python
from zhenxun.services.llm import AI, CommonOverrides

class ChatBot:
    def __init__(self):
        self.ai = AI()
        self.history = []

    async def chat(self, user_input: str) -> str:
        # Record the user turn
        self.history.append({"role": "user", "content": user_input})

        # Generate a reply
        response = await self.ai.chat(
            user_input,
            history=self.history[-10:],  # keep the last 10 turns
            override_config=CommonOverrides.balanced()
        )

        self.history.append({"role": "assistant", "content": response})
        return response
```

### 2. Code assistant

```python
from zhenxun.services.llm import code

async def code_assistant(task: str) -> dict:
    """Generate and execute code"""
    result = await code(
        f"请帮我{task},并执行代码验证结果",
        model="Gemini/gemini-2.0-flash",
        timeout=60
    )

    return {
        "explanation": result["text"],
        "code_blocks": result["code_executions"],
        "success": result["success"]
    }

# Example
result = await code_assistant("实现快速排序算法")
```

### 3. Document analysis

```python
from zhenxun.services.llm import analyze_with_images

async def analyze_document(image_path: str, question: str) -> str:
    """Analyze a document image and answer a question about it"""
    result = await analyze_with_images(
        f"请分析这个文档并回答:{question}",
        images=image_path,
        model="Gemini/gemini-2.0-flash"
    )
    return result
```

### 4. Smart search

```python
from zhenxun.services.llm import search

async def smart_search(query: str) -> dict:
    """Search and summarize"""
    result = await search(
        query,
        model="Gemini/gemini-2.0-flash",
        instruction="请提供准确、最新的信息,并注明信息来源"
    )

    return {
        "summary": result["text"],
        "sources": result.get("grounding_metadata", {}),
        "confidence": result.get("confidence_score", 0.0)
    }
```

## 🔧 Configuration Management

### Dynamic configuration

```python
from zhenxun.services.llm import set_global_default_model_name

# Change the default model at runtime
set_global_default_model_name("OpenAI/gpt-4")

# Inspect the available models
models = list_available_models()
for model in models:
    print(f"{model.provider}/{model.name} - {model.description}")
```
zhenxun/services/llm/__init__.py (new file, 96 lines)
@@ -0,0 +1,96 @@
"""
LLM 服务模块 - 公共 API 入口

提供统一的 AI 服务调用接口、核心类型定义和模型管理功能。
"""

from .api import (
    AI,
    AIConfig,
    TaskType,
    analyze,
    analyze_multimodal,
    analyze_with_images,
    chat,
    code,
    embed,
    search,
    search_multimodal,
)
from .config import (
    CommonOverrides,
    LLMGenerationConfig,
    register_llm_configs,
)

register_llm_configs()

from .api import ModelName
from .manager import (
    clear_model_cache,
    get_cache_stats,
    get_global_default_model_name,
    get_model_instance,
    list_available_models,
    list_embedding_models,
    list_model_identifiers,
    set_global_default_model_name,
)
from .types import (
    EmbeddingTaskType,
    LLMContentPart,
    LLMErrorCode,
    LLMException,
    LLMMessage,
    LLMResponse,
    LLMTool,
    ModelDetail,
    ModelInfo,
    ModelProvider,
    ResponseFormat,
    ToolCategory,
    ToolMetadata,
    UsageInfo,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts

__all__ = [
    "AI",
    "AIConfig",
    "CommonOverrides",
    "EmbeddingTaskType",
    "LLMContentPart",
    "LLMErrorCode",
    "LLMException",
    "LLMGenerationConfig",
    "LLMMessage",
    "LLMResponse",
    "LLMTool",
    "ModelDetail",
    "ModelInfo",
    "ModelName",
    "ModelProvider",
    "ResponseFormat",
    "TaskType",
    "ToolCategory",
    "ToolMetadata",
    "UsageInfo",
    "analyze",
    "analyze_multimodal",
    "analyze_with_images",
    "chat",
    "clear_model_cache",
    "code",
    "create_multimodal_message",
    "embed",
    "get_cache_stats",
    "get_global_default_model_name",
    "get_model_instance",
    "list_available_models",
    "list_embedding_models",
    "list_model_identifiers",
    "register_llm_configs",
    "search",
    "search_multimodal",
    "set_global_default_model_name",
    "unimsg_to_llm_parts",
]
zhenxun/services/llm/adapters/__init__.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""
LLM 适配器模块

提供不同LLM服务商的API适配器实现,统一接口调用方式。
"""

from .base import BaseAdapter, OpenAICompatAdapter, RequestData, ResponseData
from .factory import LLMAdapterFactory, get_adapter_for_api_type, register_adapter
from .gemini import GeminiAdapter
from .openai import OpenAIAdapter
from .zhipu import ZhipuAdapter

LLMAdapterFactory.initialize()

__all__ = [
    "BaseAdapter",
    "GeminiAdapter",
    "LLMAdapterFactory",
    "OpenAIAdapter",
    "OpenAICompatAdapter",
    "RequestData",
    "ResponseData",
    "ZhipuAdapter",
    "get_adapter_for_api_type",
    "register_adapter",
]
zhenxun/services/llm/adapters/base.py (new file, 508 lines)
@@ -0,0 +1,508 @@
"""
LLM 适配器基类和通用数据结构
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

from zhenxun.services.log import logger

from ..types.exceptions import LLMErrorCode, LLMException
from ..types.models import LLMToolCall

if TYPE_CHECKING:
    from ..config.generation import LLMGenerationConfig
    from ..service import LLMModel
    from ..types.content import LLMMessage
    from ..types.enums import EmbeddingTaskType


class RequestData(BaseModel):
    """请求数据封装"""

    url: str
    headers: dict[str, str]
    body: dict[str, Any]


class ResponseData(BaseModel):
    """响应数据封装 - 支持所有高级功能"""

    text: str
    usage_info: dict[str, Any] | None = None
    raw_response: dict[str, Any] | None = None
    tool_calls: list[LLMToolCall] | None = None
    code_executions: list[Any] | None = None
    grounding_metadata: Any | None = None
    cache_info: Any | None = None

    code_execution_results: list[dict[str, Any]] | None = None
    search_results: list[dict[str, Any]] | None = None
    function_calls: list[dict[str, Any]] | None = None
    safety_ratings: list[dict[str, Any]] | None = None
    citations: list[dict[str, Any]] | None = None


class BaseAdapter(ABC):
    """LLM API适配器基类"""

    @property
    @abstractmethod
    def api_type(self) -> str:
        """API类型标识"""
        pass

    @property
    @abstractmethod
    def supported_api_types(self) -> list[str]:
        """支持的API类型列表"""
        pass

    def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """准备简单文本生成请求

        默认实现:将简单请求转换为高级请求格式
        子类可以重写此方法以提供特定的优化实现
        """
        from ..types.content import LLMMessage

        messages: list[LLMMessage] = []

        if history:
            for msg in history:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                messages.append(LLMMessage(role=role, content=content))

        messages.append(LLMMessage(role="user", content=prompt))

        config = model._generation_config

        return self.prepare_advanced_request(
            model=model,
            api_key=api_key,
            messages=messages,
            config=config,
            tools=None,
            tool_choice=None,
        )

    @abstractmethod
    def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """准备高级请求"""
        pass

    @abstractmethod
    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """解析API响应"""
        pass

    @abstractmethod
    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        task_type: "EmbeddingTaskType | str",
        **kwargs: Any,
    ) -> RequestData:
        """准备文本嵌入请求"""
        pass

    @abstractmethod
    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """解析文本嵌入响应"""
        pass

    def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
        """验证嵌入API响应"""
        if "error" in response_json:
            error_info = response_json["error"]
            msg = (
                error_info.get("message", str(error_info))
                if isinstance(error_info, dict)
                else str(error_info)
            )
            raise LLMException(
                f"嵌入API错误: {msg}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                details=response_json,
            )

    def get_api_url(self, model: "LLMModel", endpoint: str) -> str:
        """构建API URL"""
        if not model.api_base:
            raise LLMException(
                f"模型 {model.model_name} 的 api_base 未设置",
                code=LLMErrorCode.CONFIGURATION_ERROR,
            )
        return f"{model.api_base.rstrip('/')}{endpoint}"

    def get_base_headers(self, api_key: str) -> dict[str, str]:
        """获取基础请求头"""
        from zhenxun.utils.user_agent import get_user_agent

        headers = get_user_agent()
        headers.update(
            {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def convert_messages_to_openai_format(
        self, messages: list["LLMMessage"]
    ) -> list[dict[str, Any]]:
        """将LLMMessage转换为OpenAI格式 - 通用方法"""
        openai_messages: list[dict[str, Any]] = []
        for msg in messages:
            openai_msg: dict[str, Any] = {"role": msg.role}

            if msg.role == "tool":
                openai_msg["tool_call_id"] = msg.tool_call_id
                openai_msg["name"] = msg.name
                openai_msg["content"] = msg.content
            else:
                if isinstance(msg.content, str):
                    openai_msg["content"] = msg.content
                else:
                    content_parts = []
                    for part in msg.content:
                        if part.type == "text":
                            content_parts.append({"type": "text", "text": part.text})
                        elif part.type == "image":
                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {"url": part.image_source},
                                }
                            )
                    openai_msg["content"] = content_parts

            if msg.role == "assistant" and msg.tool_calls:
                assistant_tool_calls = []
                for call in msg.tool_calls:
                    assistant_tool_calls.append(
                        {
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.function.name,
                                "arguments": call.function.arguments,
                            },
                        }
                    )
                openai_msg["tool_calls"] = assistant_tool_calls

            if msg.name and msg.role != "tool":
                openai_msg["name"] = msg.name

            openai_messages.append(openai_msg)
        return openai_messages

    def parse_openai_response(self, response_json: dict[str, Any]) -> ResponseData:
        """解析OpenAI格式的响应 - 通用方法"""
        self.validate_response(response_json)

        try:
            choices = response_json.get("choices", [])
            if not choices:
                logger.debug("OpenAI响应中没有choices,可能为空回复或流结束。")
                return ResponseData(text="", raw_response=response_json)

            choice = choices[0]
            message = choice.get("message", {})
            content = message.get("content", "")

            parsed_tool_calls: list[LLMToolCall] | None = None
            if message_tool_calls := message.get("tool_calls"):
                from ..types.models import LLMToolFunction

                parsed_tool_calls = []
                for tc_data in message_tool_calls:
                    try:
                        if tc_data.get("type") == "function":
                            parsed_tool_calls.append(
                                LLMToolCall(
                                    id=tc_data["id"],
                                    function=LLMToolFunction(
                                        name=tc_data["function"]["name"],
                                        arguments=tc_data["function"]["arguments"],
                                    ),
                                )
                            )
                    except KeyError as e:
                        logger.warning(
                            f"解析OpenAI工具调用数据时缺少键: {tc_data}, 错误: {e}"
                        )
                    except Exception as e:
                        logger.warning(
                            f"解析OpenAI工具调用数据时出错: {tc_data}, 错误: {e}"
                        )
                if not parsed_tool_calls:
                    parsed_tool_calls = None

            final_text = content if content is not None else ""
            if not final_text and parsed_tool_calls:
                final_text = f"请求调用 {len(parsed_tool_calls)} 个工具。"

            usage_info = response_json.get("usage")

            return ResponseData(
                text=final_text,
                tool_calls=parsed_tool_calls,
                usage_info=usage_info,
                raw_response=response_json,
            )

        except Exception as e:
            logger.error(f"解析OpenAI格式响应失败: {e}", e=e)
            raise LLMException(
                f"解析API响应失败: {e}",
                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                cause=e,
            )

    def validate_response(self, response_json: dict[str, Any]) -> None:
        """验证API响应,解析不同API的错误结构"""
        if "error" in response_json:
            error_info = response_json["error"]

            if isinstance(error_info, dict):
                error_message = error_info.get("message", "未知错误")
                error_code = error_info.get("code", "unknown")
                error_type = error_info.get("type", "api_error")

                error_code_mapping = {
                    "invalid_api_key": LLMErrorCode.API_KEY_INVALID,
                    "authentication_failed": LLMErrorCode.API_KEY_INVALID,
                    "rate_limit_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "quota_exceeded": LLMErrorCode.API_RATE_LIMITED,
                    "model_not_found": LLMErrorCode.MODEL_NOT_FOUND,
                    "invalid_model": LLMErrorCode.MODEL_NOT_FOUND,
                    "context_length_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                    "max_tokens_exceeded": LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
                }

                llm_error_code = error_code_mapping.get(
                    error_code, LLMErrorCode.API_RESPONSE_INVALID
                )

                logger.error(
                    f"API返回错误: {error_message} "
                    f"(代码: {error_code}, 类型: {error_type})"
                )
            else:
                error_message = str(error_info)
                error_code = "unknown"
                llm_error_code = LLMErrorCode.API_RESPONSE_INVALID

                logger.error(f"API返回错误: {error_message}")

            raise LLMException(
                f"API请求失败: {error_message}",
                code=llm_error_code,
                details={"api_error": error_info, "error_code": error_code},
            )

        if "candidates" in response_json:
            candidates = response_json.get("candidates", [])
            if candidates:
                candidate = candidates[0]
                finish_reason = candidate.get("finishReason")
                if finish_reason in ["SAFETY", "RECITATION"]:
                    safety_ratings = candidate.get("safetyRatings", [])
                    logger.warning(
                        f"Gemini内容被安全过滤: {finish_reason}, "
                        f"安全评级: {safety_ratings}"
                    )
                    raise LLMException(
                        f"内容被安全过滤: {finish_reason}",
                        code=LLMErrorCode.CONTENT_FILTERED,
                        details={
                            "finish_reason": finish_reason,
                            "safety_ratings": safety_ratings,
                        },
                    )

        if not response_json:
            logger.error("API返回空响应")
            raise LLMException(
                "API返回空响应",
                code=LLMErrorCode.API_RESPONSE_INVALID,
                details={"response": response_json},
            )

    def _apply_generation_config(
        self,
        model: "LLMModel",
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """通用的配置应用逻辑"""
        if config is not None:
            return config.to_api_params(model.api_type, model.model_name)

        if model._generation_config is not None:
            return model._generation_config.to_api_params(
                model.api_type, model.model_name
            )

        base_config = {}
        if model.temperature is not None:
            base_config["temperature"] = model.temperature
        if model.max_tokens is not None:
            if model.api_type in ["gemini", "gemini_native"]:
                base_config["maxOutputTokens"] = model.max_tokens
            else:
                base_config["max_tokens"] = model.max_tokens

        return base_config

    def apply_config_override(
        self,
        model: "LLMModel",
        body: dict[str, Any],
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """应用配置覆盖"""
        config_params = self._apply_generation_config(model, config)
        body.update(config_params)
        return body


class OpenAICompatAdapter(BaseAdapter):
    """
    处理所有 OpenAI 兼容 API 的通用适配器。
    消除 OpenAIAdapter 和 ZhipuAdapter 之间的代码重复。
    """

    @abstractmethod
    def get_chat_endpoint(self) -> str:
        """子类必须实现,返回 chat completions 的端点"""
        pass

    @abstractmethod
    def get_embedding_endpoint(self) -> str:
        """子类必须实现,返回 embeddings 的端点"""
        pass

    def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """准备高级请求 - OpenAI兼容格式"""
        url = self.get_api_url(model, self.get_chat_endpoint())
        headers = self.get_base_headers(api_key)
        openai_messages = self.convert_messages_to_openai_format(messages)

        body = {
            "model": model.model_name,
            "messages": openai_messages,
        }

        if tools:
            body["tools"] = tools
        if tool_choice:
            body["tool_choice"] = tool_choice

        body = self.apply_config_override(model, body, config)
        return RequestData(url=url, headers=headers, body=body)

    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """解析响应 - 直接使用基类的 OpenAI 格式解析"""
        _ = model, is_advanced  # 未使用的参数
        return self.parse_openai_response(response_json)

    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        task_type: "EmbeddingTaskType | str",
        **kwargs: Any,
    ) -> RequestData:
        """准备嵌入请求 - OpenAI兼容格式"""
        _ = task_type  # 未使用的参数
        url = self.get_api_url(model, self.get_embedding_endpoint())
        headers = self.get_base_headers(api_key)

        body = {
            "model": model.model_name,
            "input": texts,
        }

        # 应用额外的配置参数
        if kwargs:
            body.update(kwargs)

        return RequestData(url=url, headers=headers, body=body)

    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """解析嵌入响应 - OpenAI兼容格式"""
        self.validate_embedding_response(response_json)

        try:
            data = response_json.get("data", [])
            if not data:
                raise LLMException(
                    "嵌入响应中没有数据",
                    code=LLMErrorCode.EMBEDDING_FAILED,
                    details=response_json,
                )

            embeddings = []
            for item in data:
                if "embedding" in item:
                    embeddings.append(item["embedding"])
|
||||
else:
|
||||
raise LLMException(
|
||||
"嵌入响应格式错误:缺少embedding字段",
|
||||
code=LLMErrorCode.EMBEDDING_FAILED,
|
||||
details=item,
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析嵌入响应失败: {e}", e=e)
|
||||
raise LLMException(
|
||||
f"解析嵌入响应失败: {e}",
|
||||
code=LLMErrorCode.EMBEDDING_FAILED,
|
||||
cause=e,
|
||||
)
|
||||
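
A minimal sketch of how a new OpenAI-compatible provider could plug into this hierarchy; the MyProviderAdapter name and its endpoint paths are illustrative assumptions, not part of this commit:

# Hypothetical subclass: only the two endpoint paths need to be supplied;
# request building and response parsing are inherited from OpenAICompatAdapter.
from zhenxun.services.llm.adapters.base import OpenAICompatAdapter


class MyProviderAdapter(OpenAICompatAdapter):
    @property
    def api_type(self) -> str:
        return "my_provider"  # assumed provider key

    @property
    def supported_api_types(self) -> list[str]:
        return ["my_provider"]

    def get_chat_endpoint(self) -> str:
        return "/v1/chat/completions"

    def get_embedding_endpoint(self) -> str:
        return "/v1/embeddings"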
78  zhenxun/services/llm/adapters/factory.py  Normal file
@ -0,0 +1,78 @@
"""
LLM adapter factory.
"""

from typing import ClassVar

from ..types.exceptions import LLMErrorCode, LLMException
from .base import BaseAdapter


class LLMAdapterFactory:
    """Factory class for LLM adapters."""

    _adapters: ClassVar[dict[str, BaseAdapter]] = {}
    _api_type_mapping: ClassVar[dict[str, str]] = {}

    @classmethod
    def initialize(cls) -> None:
        """Register the default adapters."""
        if cls._adapters:
            return

        from .gemini import GeminiAdapter
        from .openai import OpenAIAdapter
        from .zhipu import ZhipuAdapter

        cls.register_adapter(OpenAIAdapter())
        cls.register_adapter(ZhipuAdapter())
        cls.register_adapter(GeminiAdapter())

    @classmethod
    def register_adapter(cls, adapter: BaseAdapter) -> None:
        """Register an adapter."""
        adapter_key = adapter.api_type
        cls._adapters[adapter_key] = adapter

        for api_type in adapter.supported_api_types:
            cls._api_type_mapping[api_type] = adapter_key

    @classmethod
    def get_adapter(cls, api_type: str) -> BaseAdapter:
        """Look up an adapter by API type."""
        cls.initialize()

        adapter_key = cls._api_type_mapping.get(api_type)
        if not adapter_key:
            raise LLMException(
                f"Unsupported API type: {api_type}",
                code=LLMErrorCode.UNKNOWN_API_TYPE,
                details={
                    "api_type": api_type,
                    "supported_types": list(cls._api_type_mapping.keys()),
                },
            )

        return cls._adapters[adapter_key]

    @classmethod
    def list_supported_types(cls) -> list[str]:
        """List all supported API types."""
        cls.initialize()
        return list(cls._api_type_mapping.keys())

    @classmethod
    def list_adapters(cls) -> dict[str, BaseAdapter]:
        """List all registered adapters."""
        cls.initialize()
        return cls._adapters.copy()


def get_adapter_for_api_type(api_type: str) -> BaseAdapter:
    """Get the adapter for the given API type."""
    return LLMAdapterFactory.get_adapter(api_type)


def register_adapter(adapter: BaseAdapter) -> None:
    """Register a new adapter."""
    LLMAdapterFactory.register_adapter(adapter)
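
A short usage sketch of the factory: built-in adapters are registered lazily on first use, and the MyProviderAdapter reference continues the hypothetical example above:

from zhenxun.services.llm.adapters.factory import LLMAdapterFactory

# "deepseek" resolves to the OpenAIAdapter registration.
adapter = LLMAdapterFactory.get_adapter("deepseek")
print(LLMAdapterFactory.list_supported_types())

# Third-party adapters can be registered at runtime:
# LLMAdapterFactory.register_adapter(MyProviderAdapter())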
596  zhenxun/services/llm/adapters/gemini.py  Normal file
@ -0,0 +1,596 @@
"""
Gemini API adapter.
"""

from typing import TYPE_CHECKING, Any

from zhenxun.services.log import logger

from ..types.exceptions import LLMErrorCode, LLMException
from .base import BaseAdapter, RequestData, ResponseData

if TYPE_CHECKING:
    from ..config.generation import LLMGenerationConfig
    from ..service import LLMModel
    from ..types.content import LLMMessage
    from ..types.enums import EmbeddingTaskType
    from ..types.models import LLMToolCall


class GeminiAdapter(BaseAdapter):
    """Gemini API adapter."""

    @property
    def api_type(self) -> str:
        return "gemini"

    @property
    def supported_api_types(self) -> list[str]:
        return ["gemini"]

    def get_base_headers(self, api_key: str) -> dict[str, str]:
        """Build the base request headers."""
        from zhenxun.utils.user_agent import get_user_agent

        headers = get_user_agent()
        headers.update({"Content-Type": "application/json"})
        headers["x-goog-api-key"] = api_key

        return headers

    def prepare_advanced_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """Prepare an advanced request."""
        return self._prepare_request(
            model, api_key, messages, config, tools, tool_choice
        )

    def _prepare_request(
        self,
        model: "LLMModel",
        api_key: str,
        messages: list["LLMMessage"],
        config: "LLMGenerationConfig | None" = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> RequestData:
        """Prepare a Gemini API request, supporting all advanced features."""
        effective_config = config if config is not None else model._generation_config

        endpoint = self._get_gemini_endpoint(model, effective_config)
        url = self.get_api_url(model, endpoint)
        headers = self.get_base_headers(api_key)

        gemini_contents: list[dict[str, Any]] = []
        system_instruction_parts: list[dict[str, Any]] | None = None

        for msg in messages:
            current_parts: list[dict[str, Any]] = []
            if msg.role == "system":
                if isinstance(msg.content, str):
                    system_instruction_parts = [{"text": msg.content}]
                elif isinstance(msg.content, list):
                    system_instruction_parts = [
                        part.convert_for_api("gemini") for part in msg.content
                    ]
                continue

            elif msg.role == "user":
                if isinstance(msg.content, str):
                    current_parts.append({"text": msg.content})
                elif isinstance(msg.content, list):
                    for part_obj in msg.content:
                        current_parts.append(part_obj.convert_for_api("gemini"))
                gemini_contents.append({"role": "user", "parts": current_parts})

            elif msg.role == "assistant" or msg.role == "model":
                if isinstance(msg.content, str) and msg.content:
                    current_parts.append({"text": msg.content})
                elif isinstance(msg.content, list):
                    for part_obj in msg.content:
                        current_parts.append(part_obj.convert_for_api("gemini"))

                if msg.tool_calls:
                    import json

                    for call in msg.tool_calls:
                        current_parts.append(
                            {
                                "functionCall": {
                                    "name": call.function.name,
                                    "args": json.loads(call.function.arguments),
                                }
                            }
                        )
                if current_parts:
                    gemini_contents.append({"role": "model", "parts": current_parts})

            elif msg.role == "tool":
                if not msg.name:
                    raise ValueError(
                        "Gemini tool messages must include a 'name' field "
                        "(the function name)."
                    )

                import json

                try:
                    content_str = (
                        msg.content
                        if isinstance(msg.content, str)
                        else str(msg.content)
                    )
                    tool_result_obj = json.loads(content_str)
                except json.JSONDecodeError:
                    content_str = (
                        msg.content
                        if isinstance(msg.content, str)
                        else str(msg.content)
                    )
                    logger.warning(
                        f"Result of tool {msg.name} is not valid JSON: {content_str}. "
                        f"Wrapping it as a raw string."
                    )
                    tool_result_obj = {"raw_output": content_str}

                current_parts.append(
                    {
                        "functionResponse": {
                            "name": msg.name,
                            "response": tool_result_obj,
                        }
                    }
                )
                gemini_contents.append({"role": "function", "parts": current_parts})

        body: dict[str, Any] = {"contents": gemini_contents}

        if system_instruction_parts:
            body["systemInstruction"] = {"parts": system_instruction_parts}

        all_tools_for_request = []
        if tools:
            for tool_item in tools:
                if isinstance(tool_item, dict):
                    if "name" in tool_item and "description" in tool_item:
                        all_tools_for_request.append(
                            {"functionDeclarations": [tool_item]}
                        )
                    else:
                        all_tools_for_request.append(tool_item)
                else:
                    all_tools_for_request.append(tool_item)

        if effective_config:
            if getattr(effective_config, "enable_grounding", False):
                has_explicit_gs_tool = any(
                    "googleSearch" in tool_item for tool_item in all_tools_for_request
                )
                if not has_explicit_gs_tool:
                    all_tools_for_request.append({"googleSearch": {}})
                    logger.debug(
                        "Implicitly enabled the Google Search tool for grounding."
                    )

            if getattr(effective_config, "enable_code_execution", False):
                has_explicit_ce_tool = any(
                    "codeExecution" in tool_item for tool_item in all_tools_for_request
                )
                if not has_explicit_ce_tool:
                    all_tools_for_request.append({"codeExecution": {}})
                    logger.debug("Implicitly enabled the code execution tool.")

        if all_tools_for_request:
            gemini_api_tools = self._convert_tools_to_gemini_format(
                all_tools_for_request
            )
            if gemini_api_tools:
                body["tools"] = gemini_api_tools

        final_tool_choice = tool_choice
        if final_tool_choice is None and effective_config:
            final_tool_choice = getattr(effective_config, "tool_choice", None)

        if final_tool_choice:
            if isinstance(final_tool_choice, str):
                mode_upper = final_tool_choice.upper()
                if mode_upper in ["AUTO", "NONE", "ANY"]:
                    body["toolConfig"] = {"functionCallingConfig": {"mode": mode_upper}}
                else:
                    body["toolConfig"] = self._convert_tool_choice_to_gemini(
                        final_tool_choice
                    )
            else:
                body["toolConfig"] = self._convert_tool_choice_to_gemini(
                    final_tool_choice
                )

        final_generation_config = self._build_gemini_generation_config(
            model, effective_config
        )
        if final_generation_config:
            body["generationConfig"] = final_generation_config

        safety_settings = self._build_safety_settings(effective_config)
        if safety_settings:
            body["safetySettings"] = safety_settings

        return RequestData(url=url, headers=headers, body=body)
    def apply_config_override(
        self,
        model: "LLMModel",
        body: dict[str, Any],
        config: "LLMGenerationConfig | None" = None,
    ) -> dict[str, Any]:
        """Apply configuration overrides - Gemini needs no extra overrides here."""
        return body

    def _get_gemini_endpoint(
        self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
    ) -> str:
        """Select the Gemini API endpoint based on the configuration."""
        # All branches currently resolve to generateContent; they are kept
        # separate so feature-specific endpoints can diverge later.
        if config:
            if getattr(config, "enable_code_execution", False):
                return f"/v1beta/models/{model.model_name}:generateContent"

            if getattr(config, "enable_grounding", False):
                return f"/v1beta/models/{model.model_name}:generateContent"

        return f"/v1beta/models/{model.model_name}:generateContent"

    def _convert_tools_to_gemini_format(
        self, tools: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Convert tool definitions to the Gemini format."""
        gemini_tools = []

        for tool in tools:
            if tool.get("type") == "function":
                func = tool["function"]
                gemini_tool = {
                    "functionDeclarations": [
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    ]
                }
                gemini_tools.append(gemini_tool)
            elif tool.get("type") == "code_execution":
                gemini_tools.append(
                    {"codeExecution": {"language": tool.get("language", "python")}}
                )
            elif tool.get("type") == "google_search":
                gemini_tools.append({"googleSearch": {}})
            elif "googleSearch" in tool:
                gemini_tools.append({"googleSearch": tool["googleSearch"]})
            elif "codeExecution" in tool:
                gemini_tools.append({"codeExecution": tool["codeExecution"]})

        return gemini_tools

    def _convert_tool_choice_to_gemini(
        self, tool_choice_value: str | dict[str, Any]
    ) -> dict[str, Any]:
        """Convert a tool-choice strategy to the Gemini format."""
        if isinstance(tool_choice_value, str):
            mode_upper = tool_choice_value.upper()
            if mode_upper in ["AUTO", "NONE", "ANY"]:
                return {"functionCallingConfig": {"mode": mode_upper}}
            else:
                logger.warning(
                    f"Unsupported tool_choice string value: '{tool_choice_value}'. "
                    f"Falling back to AUTO."
                )
                return {"functionCallingConfig": {"mode": "AUTO"}}

        elif isinstance(tool_choice_value, dict):
            if (
                tool_choice_value.get("type") == "function"
                and "function" in tool_choice_value
            ):
                func_name = tool_choice_value["function"].get("name")
                if func_name:
                    return {
                        "functionCallingConfig": {
                            "mode": "ANY",
                            "allowedFunctionNames": [func_name],
                        }
                    }
                else:
                    logger.warning(
                        f"Invalid function name in tool_choice dict: "
                        f"{tool_choice_value}. Falling back to AUTO."
                    )
                    return {"functionCallingConfig": {"mode": "AUTO"}}

            elif "functionCallingConfig" in tool_choice_value:
                return {
                    "functionCallingConfig": tool_choice_value["functionCallingConfig"]
                }

            else:
                logger.warning(
                    f"Unsupported tool_choice dict value: {tool_choice_value}. "
                    f"Falling back to AUTO."
                )
                return {"functionCallingConfig": {"mode": "AUTO"}}

        logger.warning(
            f"Invalid tool_choice type: {type(tool_choice_value)}. "
            f"Falling back to AUTO."
        )
        return {"functionCallingConfig": {"mode": "AUTO"}}
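
    # Illustrative tool_choice mappings (a sketch; "get_weather" is a made-up
    # function name, and the left-hand shapes are assumed OpenAI-style inputs):
    #   "auto" -> {"functionCallingConfig": {"mode": "AUTO"}}
    #   {"type": "function", "function": {"name": "get_weather"}}
    #       -> {"functionCallingConfig": {"mode": "ANY",
    #           "allowedFunctionNames": ["get_weather"]}}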
    def _build_gemini_generation_config(
        self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
    ) -> dict[str, Any]:
        """Build the Gemini generation configuration."""
        generation_config: dict[str, Any] = {}

        effective_config = config if config is not None else model._generation_config

        if effective_config:
            base_api_params = effective_config.to_api_params(
                api_type="gemini", model_name=model.model_name
            )
            generation_config.update(base_api_params)

            if getattr(effective_config, "response_mime_type", None):
                generation_config["responseMimeType"] = (
                    effective_config.response_mime_type
                )

            if getattr(effective_config, "response_schema", None):
                generation_config["responseSchema"] = effective_config.response_schema

            thinking_budget = getattr(effective_config, "thinking_budget", None)
            if thinking_budget is not None:
                if "thinkingConfig" not in generation_config:
                    generation_config["thinkingConfig"] = {}
                generation_config["thinkingConfig"]["thinkingBudget"] = thinking_budget

            if getattr(effective_config, "response_modalities", None):
                modalities = effective_config.response_modalities
                if isinstance(modalities, list):
                    generation_config["responseModalities"] = [
                        m.upper() for m in modalities
                    ]
                elif isinstance(modalities, str):
                    generation_config["responseModalities"] = [modalities.upper()]

        generation_config = {
            k: v for k, v in generation_config.items() if v is not None
        }

        if generation_config:
            param_keys = list(generation_config.keys())
            logger.debug(
                f"Built Gemini generation config with "
                f"{len(generation_config)} parameter(s): {param_keys}"
            )

        return generation_config

    def _build_safety_settings(
        self, config: "LLMGenerationConfig | None" = None
    ) -> list[dict[str, Any]] | None:
        """Build the safety settings."""
        if not config:
            return None

        safety_settings = []

        safety_categories = [
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        ]

        custom_safety_settings = getattr(config, "safety_settings", None)
        if custom_safety_settings:
            for category, threshold in custom_safety_settings.items():
                safety_settings.append({"category": category, "threshold": threshold})
        else:
            for category in safety_categories:
                safety_settings.append(
                    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
                )

        return safety_settings if safety_settings else None

    def parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """Parse the API response."""
        return self._parse_response(model, response_json, is_advanced)

    def _parse_response(
        self,
        model: "LLMModel",
        response_json: dict[str, Any],
        is_advanced: bool = False,
    ) -> ResponseData:
        """Parse a Gemini API response."""
        _ = is_advanced
        self.validate_response(response_json)

        try:
            candidates = response_json.get("candidates", [])
            if not candidates:
                logger.debug("Gemini response contains no candidates.")
                return ResponseData(text="", raw_response=response_json)

            candidate = candidates[0]

            if candidate.get("finishReason") in [
                "RECITATION",
                "OTHER",
            ] and not candidate.get("content"):
                logger.warning(
                    f"Gemini candidate finished with reason "
                    f"'{candidate.get('finishReason')}' and no content."
                )
                return ResponseData(
                    text="",
                    raw_response=response_json,
                    usage_info=response_json.get("usageMetadata"),
                )

            content_data = candidate.get("content", {})
            parts = content_data.get("parts", [])

            text_content = ""
            parsed_tool_calls: list["LLMToolCall"] | None = None

            for part in parts:
                if "text" in part:
                    text_content += part["text"]
                elif "functionCall" in part:
                    if parsed_tool_calls is None:
                        parsed_tool_calls = []
                    fc_data = part["functionCall"]
                    try:
                        import json

                        from ..types.models import LLMToolCall, LLMToolFunction

                        call_id = f"call_{model.provider_name}_{len(parsed_tool_calls)}"
                        parsed_tool_calls.append(
                            LLMToolCall(
                                id=call_id,
                                function=LLMToolFunction(
                                    name=fc_data["name"],
                                    arguments=json.dumps(fc_data["args"]),
                                ),
                            )
                        )
                    except KeyError as e:
                        logger.warning(
                            f"Missing key while parsing Gemini functionCall: "
                            f"{fc_data}, error: {e}"
                        )
                    except Exception as e:
                        logger.warning(
                            f"Error while parsing Gemini functionCall: "
                            f"{fc_data}, error: {e}"
                        )
                elif "codeExecutionResult" in part:
                    result = part["codeExecutionResult"]
                    if result.get("outcome") == "OK":
                        output = result.get("output", "")
                        text_content += f"\n[code execution result]:\n{output}\n"
                    else:
                        text_content += (
                            f"\n[code execution failed]: "
                            f"{result.get('outcome', 'UNKNOWN')}\n"
                        )

            usage_info = response_json.get("usageMetadata")

            grounding_metadata_obj = None
            if grounding_data := candidate.get("groundingMetadata"):
                try:
                    from ..types.models import LLMGroundingMetadata

                    grounding_metadata_obj = LLMGroundingMetadata(**grounding_data)
                except Exception as e:
                    logger.warning(
                        f"Could not parse grounding metadata: {grounding_data}, {e}"
                    )

            return ResponseData(
                text=text_content,
                tool_calls=parsed_tool_calls,
                usage_info=usage_info,
                raw_response=response_json,
                grounding_metadata=grounding_metadata_obj,
            )

        except Exception as e:
            logger.error(f"Failed to parse Gemini response: {e}", e=e)
            raise LLMException(
                f"Failed to parse API response: {e}",
                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                cause=e,
            )

    def prepare_embedding_request(
        self,
        model: "LLMModel",
        api_key: str,
        texts: list[str],
        task_type: "EmbeddingTaskType | str",
        **kwargs: Any,
    ) -> RequestData:
        """Prepare a text embedding request."""
        api_model_name = model.model_name
        if not api_model_name.startswith("models/"):
            api_model_name = f"models/{api_model_name}"

        url = self.get_api_url(model, f"/{api_model_name}:batchEmbedContents")
        headers = self.get_base_headers(api_key)

        requests_payload = []
        for text_content in texts:
            request_item: dict[str, Any] = {
                "content": {"parts": [{"text": text_content}]},
            }

            from ..types.enums import EmbeddingTaskType

            if task_type and task_type != EmbeddingTaskType.RETRIEVAL_DOCUMENT:
                request_item["task_type"] = str(task_type).upper()
            if title := kwargs.get("title"):
                request_item["title"] = title
            if output_dimensionality := kwargs.get("output_dimensionality"):
                request_item["output_dimensionality"] = output_dimensionality

            requests_payload.append(request_item)

        body = {"requests": requests_payload}
        return RequestData(url=url, headers=headers, body=body)

    def parse_embedding_response(
        self, response_json: dict[str, Any]
    ) -> list[list[float]]:
        """Parse a text embedding response."""
        try:
            embeddings_data = response_json["embeddings"]
            return [item["values"] for item in embeddings_data]
        except KeyError as e:
            logger.error(
                f"Missing key while parsing Gemini embedding response: {e}. "
                f"Response: {response_json}"
            )
            raise LLMException(
                "Malformed Gemini embedding response",
                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                details={"error": str(e)},
            )
        except Exception as e:
            logger.error(
                f"Unexpected error while parsing Gemini embedding response: {e}. "
                f"Response: {response_json}"
            )
            raise LLMException(
                f"Failed to parse Gemini embedding response: {e}",
                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                cause=e,
            )

    def validate_embedding_response(self, response_json: dict[str, Any]) -> None:
        """Validate an embedding response."""
        super().validate_embedding_response(response_json)
        if "embeddings" not in response_json or not isinstance(
            response_json["embeddings"], list
        ):
            raise LLMException(
                "Gemini embedding response is missing the 'embeddings' field "
                "or it is malformed",
                code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                details=response_json,
            )
        for item in response_json["embeddings"]:
            if "values" not in item:
                raise LLMException(
                    "An entry in the Gemini embedding response is missing "
                    "the 'values' field",
                    code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                    details=response_json,
                )
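
A minimal sketch of the batchEmbedContents response shape that parse_embedding_response expects; the vector values are made up for illustration:

# Hypothetical response payload for two input texts:
response_json = {
    "embeddings": [
        {"values": [0.01, -0.02, 0.03]},
        {"values": [0.04, 0.05, -0.06]},
    ]
}
# GeminiAdapter().parse_embedding_response(response_json)
# -> [[0.01, -0.02, 0.03], [0.04, 0.05, -0.06]]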
57  zhenxun/services/llm/adapters/openai.py  Normal file
@ -0,0 +1,57 @@
"""
OpenAI API adapter.

Supports OpenAI, DeepSeek, and other OpenAI-compatible API services.
"""

from typing import TYPE_CHECKING

from .base import OpenAICompatAdapter, RequestData

if TYPE_CHECKING:
    from ..service import LLMModel


class OpenAIAdapter(OpenAICompatAdapter):
    """OpenAI-compatible API adapter."""

    @property
    def api_type(self) -> str:
        return "openai"

    @property
    def supported_api_types(self) -> list[str]:
        return ["openai", "deepseek", "general_openai_compat"]

    def get_chat_endpoint(self) -> str:
        """Return the chat completions endpoint."""
        return "/v1/chat/completions"

    def get_embedding_endpoint(self) -> str:
        """Return the embeddings endpoint."""
        return "/v1/embeddings"

    def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """Prepare a simple text generation request - OpenAI-optimized implementation."""
        url = self.get_api_url(model, self.get_chat_endpoint())
        headers = self.get_base_headers(api_key)

        messages = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})

        body = {
            "model": model.model_name,
            "messages": messages,
        }

        body = self.apply_config_override(model, body)

        return RequestData(url=url, headers=headers, body=body)
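
A sketch of the body that prepare_simple_request assembles before apply_config_override merges generation parameters in; the model name, prompt, and history values are illustrative:

# For prompt="Hi" with one prior exchange in history, the body is:
body = {
    "model": "gpt-4o",  # assumed model name
    "messages": [
        {"role": "user", "content": "earlier question"},
        {"role": "assistant", "content": "earlier answer"},
        {"role": "user", "content": "Hi"},
    ],
}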
57  zhenxun/services/llm/adapters/zhipu.py  Normal file
@ -0,0 +1,57 @@
"""
Zhipu AI API adapter.

Supports Zhipu AI's GLM model family via its OpenAI-compatible interface.
"""

from typing import TYPE_CHECKING

from .base import OpenAICompatAdapter, RequestData

if TYPE_CHECKING:
    from ..service import LLMModel


class ZhipuAdapter(OpenAICompatAdapter):
    """Zhipu AI adapter - uses Zhipu AI's dedicated OpenAI-compatible interface."""

    @property
    def api_type(self) -> str:
        return "zhipu"

    @property
    def supported_api_types(self) -> list[str]:
        return ["zhipu"]

    def get_chat_endpoint(self) -> str:
        """Return the Zhipu AI chat completions endpoint."""
        return "/api/paas/v4/chat/completions"

    def get_embedding_endpoint(self) -> str:
        """Return the Zhipu AI embeddings endpoint."""
        return "/v4/embeddings"

    def prepare_simple_request(
        self,
        model: "LLMModel",
        api_key: str,
        prompt: str,
        history: list[dict[str, str]] | None = None,
    ) -> RequestData:
        """Prepare a simple text generation request - Zhipu-optimized implementation."""
        url = self.get_api_url(model, self.get_chat_endpoint())
        headers = self.get_base_headers(api_key)

        messages = []
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})

        body = {
            "model": model.model_name,
            "messages": messages,
        }

        body = self.apply_config_override(model, body)

        return RequestData(url=url, headers=headers, body=body)
530  zhenxun/services/llm/api.py  Normal file
@ -0,0 +1,530 @@
"""
High-level API for the LLM service.
"""

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

from nonebot_plugin_alconna.uniseg import UniMessage

from zhenxun.services.log import logger

from .config import CommonOverrides, LLMGenerationConfig
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .types import (
    EmbeddingTaskType,
    LLMContentPart,
    LLMErrorCode,
    LLMException,
    LLMMessage,
    LLMResponse,
    LLMTool,
    ModelName,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts


class TaskType(Enum):
    """Task type enumeration."""

    CHAT = "chat"
    CODE = "code"
    SEARCH = "search"
    ANALYSIS = "analysis"
    GENERATION = "generation"
    MULTIMODAL = "multimodal"


@dataclass
class AIConfig:
    """AI configuration class - simplified version."""

    model: ModelName = None
    default_embedding_model: ModelName = None
    temperature: float | None = None
    max_tokens: int | None = None
    enable_cache: bool = False
    enable_code: bool = False
    enable_search: bool = False
    timeout: int | None = None

    enable_gemini_json_mode: bool = False
    enable_gemini_thinking: bool = False
    enable_gemini_safe_mode: bool = False
    enable_gemini_multimodal: bool = False
    enable_gemini_grounding: bool = False

    def __post_init__(self):
        """Read defaults from the configuration after initialization."""
        ai_config = get_ai_config()
        if self.model is None:
            self.model = ai_config.get("default_model_name")
        if self.timeout is None:
            self.timeout = ai_config.get("timeout", 180)


class AI:
    """Unified AI service class - a balanced design.

    Provides three API tiers:
    1. Simple methods: ai.chat(), ai.code(), ai.search()
    2. Standard method: ai.analyze(), which supports complex parameters
    3. Advanced access: use get_model_instance() directly
    """

    def __init__(
        self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
    ):
        """
        Initialize the AI service.

        Args:
            config: AI configuration.
            history: Optional initial conversation history.
        """
        self.config = config or AIConfig()
        self.history = history or []

    def clear_history(self):
        """Clear the history of the current session."""
        self.history = []
        logger.info("AI session history cleared.")

    async def chat(
        self,
        message: str | LLMMessage | list[LLMContentPart],
        *,
        model: ModelName = None,
        **kwargs: Any,
    ) -> str:
        """
        Run one round of chat.
        This method automatically uses and updates the session history.
        """
        current_message: LLMMessage
        if isinstance(message, str):
            current_message = LLMMessage.user(message)
        elif isinstance(message, list) and all(
            isinstance(part, LLMContentPart) for part in message
        ):
            current_message = LLMMessage.user(message)
        elif isinstance(message, LLMMessage):
            current_message = message
        else:
            raise LLMException(
                f"Unsupported message type for AI.chat: {type(message)}. "
                "Use str, LLMMessage, or list[LLMContentPart]. "
                "For more complex multimodal input or file paths, use AI.analyze().",
                code=LLMErrorCode.API_REQUEST_FAILED,
            )

        final_messages = [*self.history, current_message]

        response = await self._execute_generation(
            final_messages, model, "Chat failed", kwargs
        )

        self.history.append(current_message)
        self.history.append(LLMMessage.assistant_text_response(response.text))

        return response.text

    async def code(
        self,
        prompt: str,
        *,
        model: ModelName = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Code execution."""
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"

        config = CommonOverrides.gemini_code_execution()
        if timeout:
            config.custom_params = config.custom_params or {}
            config.custom_params["code_execution_timeout"] = timeout

        messages = [LLMMessage.user(prompt)]

        response = await self._execute_generation(
            messages,
            resolved_model,
            "Code execution failed",
            kwargs,
            base_config=config,
        )

        return {
            "text": response.text,
            "code_executions": response.code_executions or [],
            "success": True,
        }

    async def search(
        self,
        query: str | UniMessage,
        *,
        model: ModelName = None,
        instruction: str = "",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Information search - supports multimodal input."""
        resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
        config = CommonOverrides.gemini_grounding()

        if isinstance(query, str):
            messages = [LLMMessage.user(query)]
        elif isinstance(query, UniMessage):
            content_parts = await unimsg_to_llm_parts(query)

            final_messages: list[LLMMessage] = []
            if instruction:
                final_messages.append(LLMMessage.system(instruction))

            if not content_parts:
                if instruction:
                    final_messages.append(LLMMessage.user(instruction))
                else:
                    raise LLMException(
                        "Search content is empty or could not be processed.",
                        code=LLMErrorCode.API_REQUEST_FAILED,
                    )
            else:
                final_messages.append(LLMMessage.user(content_parts))

            messages = final_messages
        else:
            raise LLMException(
                f"Unsupported search input type: {type(query)}. "
                "Use str or UniMessage.",
                code=LLMErrorCode.API_REQUEST_FAILED,
            )

        response = await self._execute_generation(
            messages,
            resolved_model,
            "Information search failed",
            kwargs,
            base_config=config,
        )

        result = {
            "text": response.text,
            "sources": [],
            "queries": [],
            "success": True,
        }

        if response.grounding_metadata:
            result["sources"] = (
                response.grounding_metadata.grounding_attributions or []
            )
            result["queries"] = response.grounding_metadata.web_search_queries or []

        return result

    async def analyze(
        self,
        message: UniMessage,
        *,
        instruction: str = "",
        model: ModelName = None,
        tools: list[dict[str, Any]] | None = None,
        tool_config: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> str | LLMResponse:
        """
        Content analysis - accepts a UniMessage object for multimodal analysis
        and tool calling. This is the main entry point for complex interactions.
        """
        content_parts = await unimsg_to_llm_parts(message)

        final_messages: list[LLMMessage] = []
        if instruction:
            final_messages.append(LLMMessage.system(instruction))

        if not content_parts:
            if instruction:
                final_messages.append(LLMMessage.user(instruction))
            else:
                raise LLMException(
                    "Analysis content is empty or could not be processed.",
                    code=LLMErrorCode.API_REQUEST_FAILED,
                )
        else:
            final_messages.append(LLMMessage.user(content_parts))

        llm_tools = None
        if tools:
            llm_tools = []
            for tool_dict in tools:
                if isinstance(tool_dict, dict):
                    if "name" in tool_dict and "description" in tool_dict:
                        llm_tool = LLMTool(
                            type="function",
                            function={
                                "name": tool_dict["name"],
                                "description": tool_dict["description"],
                                "parameters": tool_dict.get("parameters", {}),
                            },
                        )
                        llm_tools.append(llm_tool)
                    else:
                        llm_tools.append(LLMTool(**tool_dict))
                else:
                    llm_tools.append(tool_dict)

        tool_choice = None
        if tool_config:
            mode = tool_config.get("mode", "auto")
            if mode == "auto":
                tool_choice = "auto"
            elif mode == "any":
                tool_choice = "any"
            elif mode == "none":
                tool_choice = "none"

        response = await self._execute_generation(
            final_messages,
            model,
            "Content analysis failed",
            kwargs,
            llm_tools=llm_tools,
            tool_choice=tool_choice,
        )

        if response.tool_calls:
            return response
        return response.text

    async def _execute_generation(
        self,
        messages: list[LLMMessage],
        model_name: ModelName,
        error_message: str,
        config_overrides: dict[str, Any],
        llm_tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        base_config: LLMGenerationConfig | None = None,
    ) -> LLMResponse:
        """Shared generation runner that wraps the repeated model lookup,
        config merging, and exception handling."""
        try:
            resolved_model_name = self._resolve_model_name(
                model_name or self.config.model
            )
            final_config_dict = self._merge_config(
                config_overrides, base_config=base_config
            )

            async with await get_model_instance(
                resolved_model_name, override_config=final_config_dict
            ) as model_instance:
                return await model_instance.generate_response(
                    messages, tools=llm_tools, tool_choice=tool_choice
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"{error_message}: {e}", e=e)
            raise LLMException(f"{error_message}: {e}", cause=e)

    def _resolve_model_name(self, model_name: ModelName) -> str:
        """Resolve the model name."""
        if model_name:
            return model_name

        default_model = get_global_default_model_name()
        if default_model:
            return default_model

        raise LLMException(
            "No model name was given and no global default model is set",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    def _merge_config(
        self,
        user_config: dict[str, Any],
        base_config: LLMGenerationConfig | None = None,
    ) -> dict[str, Any]:
        """Merge configurations."""
        final_config = {}
        if base_config:
            final_config.update(base_config.to_dict())

        if self.config.temperature is not None:
            final_config["temperature"] = self.config.temperature
        if self.config.max_tokens is not None:
            final_config["max_tokens"] = self.config.max_tokens

        if self.config.enable_cache:
            final_config["enable_caching"] = True
        if self.config.enable_code:
            final_config["enable_code_execution"] = True
        if self.config.enable_search:
            final_config["enable_grounding"] = True

        if self.config.enable_gemini_json_mode:
            final_config["response_mime_type"] = "application/json"
        if self.config.enable_gemini_thinking:
            final_config["thinking_budget"] = 0.8
        if self.config.enable_gemini_safe_mode:
            final_config["safety_settings"] = (
                CommonOverrides.gemini_safe().safety_settings
            )
        if self.config.enable_gemini_multimodal:
            final_config.update(CommonOverrides.gemini_multimodal().to_dict())
        if self.config.enable_gemini_grounding:
            final_config["enable_grounding"] = True

        final_config.update(user_config)

        return final_config
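
    # Precedence sketch for _merge_config (lowest to highest): base_config
    # preset values, then AIConfig fields and feature flags, then per-call
    # user_config kwargs. E.g., assuming a base_config with temperature=0.5
    # and a call passing temperature=0.9, the request uses 0.9.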
    async def embed(
        self,
        texts: list[str] | str,
        *,
        model: ModelName = None,
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """Generate text embedding vectors."""
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return []

        try:
            resolved_model_str = (
                model or self.config.default_embedding_model or self.config.model
            )
            if not resolved_model_str:
                raise LLMException(
                    "The embed feature requires an embedding model name, either "
                    "passed explicitly or configured as default_embedding_model "
                    "in AIConfig.",
                    code=LLMErrorCode.MODEL_NOT_FOUND,
                )
            resolved_model_str = self._resolve_model_name(resolved_model_str)

            async with await get_model_instance(
                resolved_model_str,
                override_config=None,
            ) as embedding_model_instance:
                return await embedding_model_instance.generate_embeddings(
                    texts, task_type=task_type, **kwargs
                )
        except LLMException:
            raise
        except Exception as e:
            logger.error(f"Text embedding failed: {e}", e=e)
            raise LLMException(
                f"Text embedding failed: {e}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                cause=e,
            )


async def chat(
    message: str | LLMMessage | list[LLMContentPart],
    *,
    model: ModelName = None,
    **kwargs: Any,
) -> str:
    """Convenience function for chat."""
    ai = AI()
    return await ai.chat(message, model=model, **kwargs)


async def code(
    prompt: str,
    *,
    model: ModelName = None,
    timeout: int | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Convenience function for code execution."""
    ai = AI()
    return await ai.code(prompt, model=model, timeout=timeout, **kwargs)


async def search(
    query: str | UniMessage,
    *,
    model: ModelName = None,
    instruction: str = "",
    **kwargs: Any,
) -> dict[str, Any]:
    """Convenience function for information search."""
    ai = AI()
    return await ai.search(query, model=model, instruction=instruction, **kwargs)


async def analyze(
    message: UniMessage,
    *,
    instruction: str = "",
    model: ModelName = None,
    tools: list[dict[str, Any]] | None = None,
    tool_config: dict[str, Any] | None = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """Convenience function for content analysis."""
    ai = AI()
    return await ai.analyze(
        message,
        instruction=instruction,
        model=model,
        tools=tools,
        tool_config=tool_config,
        **kwargs,
    )


async def analyze_with_images(
    text: str,
    images: list[str | Path | bytes] | str | Path | bytes,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """Convenience function for image analysis."""
    message = create_multimodal_message(text=text, images=images)
    return await analyze(message, instruction=instruction, model=model, **kwargs)


async def analyze_multimodal(
    text: str | None = None,
    images: list[str | Path | bytes] | str | Path | bytes | None = None,
    videos: list[str | Path | bytes] | str | Path | bytes | None = None,
    audios: list[str | Path | bytes] | str | Path | bytes | None = None,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> str | LLMResponse:
    """Convenience function for multimodal analysis."""
    message = create_multimodal_message(
        text=text, images=images, videos=videos, audios=audios
    )
    return await analyze(message, instruction=instruction, model=model, **kwargs)


async def search_multimodal(
    text: str | None = None,
    images: list[str | Path | bytes] | str | Path | bytes | None = None,
    videos: list[str | Path | bytes] | str | Path | bytes | None = None,
    audios: list[str | Path | bytes] | str | Path | bytes | None = None,
    *,
    instruction: str = "",
    model: ModelName = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Convenience function for multimodal search."""
    message = create_multimodal_message(
        text=text, images=images, videos=videos, audios=audios
    )
    ai = AI()
    return await ai.search(message, model=model, instruction=instruction, **kwargs)


async def embed(
    texts: list[str] | str,
    *,
    model: ModelName = None,
    task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
    **kwargs: Any,
) -> list[list[float]]:
    """Convenience function for text embedding."""
    ai = AI()
    return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
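
A quick usage sketch of the simpler API tiers; the model identifier is an assumed Provider/model string:

import asyncio

from zhenxun.services.llm.api import AI


async def demo() -> None:
    ai = AI()
    # Tier 1: simple chat; session history is kept on the instance.
    print(await ai.chat("Hello!"))
    # Tier 1: grounded search returns text plus sources/queries.
    result = await ai.search(
        "latest Python release", model="Gemini/gemini-2.0-flash"
    )
    print(result["text"], result["sources"])


asyncio.run(demo())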
35  zhenxun/services/llm/config/__init__.py  Normal file
@ -0,0 +1,35 @@
"""
LLM configuration module.

Provides generation configuration, preset configurations, and configuration
validation.
"""

from .generation import (
    LLMGenerationConfig,
    ModelConfigOverride,
    apply_api_specific_mappings,
    create_generation_config_from_kwargs,
    validate_override_params,
)
from .presets import CommonOverrides
from .providers import (
    LLMConfig,
    get_llm_config,
    register_llm_configs,
    set_default_model,
    validate_llm_config,
)

__all__ = [
    "CommonOverrides",
    "LLMConfig",
    "LLMGenerationConfig",
    "ModelConfigOverride",
    "apply_api_specific_mappings",
    "create_generation_config_from_kwargs",
    "get_llm_config",
    "register_llm_configs",
    "set_default_model",
    "validate_llm_config",
    "validate_override_params",
]
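
A minimal sketch of consuming the package-level exports:

from zhenxun.services.llm.config import CommonOverrides, LLMGenerationConfig

# Start from a preset, then tighten one field before use.
config: LLMGenerationConfig = CommonOverrides.precise()
config.max_tokens = 256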
260  zhenxun/services/llm/config/generation.py  Normal file
@ -0,0 +1,260 @@
"""
Classes and functions for LLM generation configuration.
"""

from typing import Any

from pydantic import BaseModel, Field

from zhenxun.services.log import logger

from ..types.enums import ResponseFormat
from ..types.exceptions import LLMErrorCode, LLMException


class ModelConfigOverride(BaseModel):
    """Model configuration override parameters."""

    temperature: float | None = Field(
        default=None, ge=0.0, le=2.0, description="Sampling temperature"
    )
    max_tokens: int | None = Field(
        default=None, gt=0, description="Maximum number of output tokens"
    )
    top_p: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Nucleus sampling parameter"
    )
    top_k: int | None = Field(
        default=None, gt=0, description="Top-K sampling parameter"
    )
    frequency_penalty: float | None = Field(
        default=None, ge=-2.0, le=2.0, description="Frequency penalty"
    )
    presence_penalty: float | None = Field(
        default=None, ge=-2.0, le=2.0, description="Presence penalty"
    )
    repetition_penalty: float | None = Field(
        default=None, ge=0.0, le=2.0, description="Repetition penalty"
    )

    stop: list[str] | str | None = Field(default=None, description="Stop sequences")

    response_format: ResponseFormat | dict[str, Any] | None = Field(
        default=None, description="Expected response format"
    )
    response_mime_type: str | None = Field(
        default=None, description="Response MIME type (Gemini-specific)"
    )
    response_schema: dict[str, Any] | None = Field(
        default=None, description="JSON response schema"
    )
    thinking_budget: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Thinking budget"
    )
    safety_settings: dict[str, str] | None = Field(
        default=None, description="Safety settings"
    )
    response_modalities: list[str] | None = Field(
        default=None, description="Response modality types"
    )

    enable_code_execution: bool | None = Field(
        default=None, description="Whether to enable code execution"
    )
    enable_grounding: bool | None = Field(
        default=None, description="Whether to enable grounding"
    )
    enable_caching: bool | None = Field(
        default=None, description="Whether to enable response caching"
    )

    custom_params: dict[str, Any] | None = Field(
        default=None, description="Custom parameters"
    )

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dict, dropping None values."""
        result = {}
        model_data = getattr(self, "model_dump", lambda: {})()
        if not model_data:
            model_data = {}
            for field_name, _ in self.__class__.__dict__.get(
                "model_fields", {}
            ).items():
                value = getattr(self, field_name, None)
                if value is not None:
                    model_data[field_name] = value
        for key, value in model_data.items():
            if value is not None:
                if key == "custom_params" and isinstance(value, dict):
                    result.update(value)
                else:
                    result[key] = value
        return result

    def merge_with_base_config(
        self,
        base_temperature: float | None = None,
        base_max_tokens: int | None = None,
    ) -> dict[str, Any]:
        """Merge with the base configuration; override parameters win."""
        merged = {}

        if base_temperature is not None:
            merged["temperature"] = base_temperature
        if base_max_tokens is not None:
            merged["max_tokens"] = base_max_tokens

        override_dict = self.to_dict()
        merged.update(override_dict)

        return merged


class LLMGenerationConfig(ModelConfigOverride):
    """LLM generation configuration; inherits the model override parameters."""

    def to_api_params(self, api_type: str, model_name: str) -> dict[str, Any]:
        """Convert to API parameters, mapping parameter names per API type."""
        _ = model_name
        params = {}

        if self.temperature is not None:
            params["temperature"] = self.temperature

        if self.max_tokens is not None:
            if api_type in ["gemini", "gemini_native"]:
                params["maxOutputTokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens

        if api_type in ["gemini", "gemini_native"]:
            if self.top_k is not None:
                params["topK"] = self.top_k
            if self.top_p is not None:
                params["topP"] = self.top_p
        else:
            if self.top_k is not None:
                params["top_k"] = self.top_k
            if self.top_p is not None:
                params["top_p"] = self.top_p

        if api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
            if self.frequency_penalty is not None:
                params["frequency_penalty"] = self.frequency_penalty
            if self.presence_penalty is not None:
                params["presence_penalty"] = self.presence_penalty

            if self.repetition_penalty is not None:
                if api_type == "openai":
                    logger.warning(
                        "The official OpenAI API does not support the "
                        "repetition_penalty parameter; it has been ignored"
                    )
                else:
                    params["repetition_penalty"] = self.repetition_penalty

        if self.response_format is not None:
            if isinstance(self.response_format, dict):
                if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
                    params["response_format"] = self.response_format
                    logger.debug(
                        f"Using custom response_format for {api_type}: "
                        f"{self.response_format}"
                    )
            elif self.response_format == ResponseFormat.JSON:
                if api_type in ["openai", "zhipu", "deepseek", "general_openai_compat"]:
                    params["response_format"] = {"type": "json_object"}
                    logger.debug(f"Enabled JSON object output mode for {api_type}")
                elif api_type in ["gemini", "gemini_native"]:
                    params["responseMimeType"] = "application/json"
                    if self.response_schema:
                        params["responseSchema"] = self.response_schema
                    logger.debug(f"Enabled JSON MIME type output mode for {api_type}")

        if api_type in ["gemini", "gemini_native"]:
            if (
                self.response_format != ResponseFormat.JSON
                and self.response_mime_type is not None
            ):
                params["responseMimeType"] = self.response_mime_type
                logger.debug(
                    f"Using explicitly set responseMimeType: {self.response_mime_type}"
                )

            if self.response_schema is not None and "responseSchema" not in params:
                params["responseSchema"] = self.response_schema
            if self.thinking_budget is not None:
                params["thinkingBudget"] = self.thinking_budget
            if self.safety_settings is not None:
                params["safetySettings"] = self.safety_settings
            if self.response_modalities is not None:
                params["responseModalities"] = self.response_modalities

        if self.custom_params:
            custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
            params.update(custom_mapped)

        logger.debug(f"Converted {len(params)} config parameter(s) for {api_type}")
        return params


def validate_override_params(
    override_config: dict[str, Any] | LLMGenerationConfig | None,
) -> LLMGenerationConfig:
    """Validate and normalize override parameters."""
    if override_config is None:
        return LLMGenerationConfig()

    if isinstance(override_config, dict):
        try:
            filtered_config = {
                k: v for k, v in override_config.items() if v is not None
            }
            return LLMGenerationConfig(**filtered_config)
        except Exception as e:
            logger.warning(f"Override configuration validation failed: {e}")
            raise LLMException(
                f"Invalid override configuration parameters: {e}",
                code=LLMErrorCode.CONFIGURATION_ERROR,
                cause=e,
            )

    return override_config


def apply_api_specific_mappings(
    params: dict[str, Any], api_type: str
) -> dict[str, Any]:
    """Apply API-specific parameter name mappings."""
    mapped_params = params.copy()

    if api_type in ["gemini", "gemini_native"]:
        if "max_tokens" in mapped_params:
            mapped_params["maxOutputTokens"] = mapped_params.pop("max_tokens")
        if "top_k" in mapped_params:
            mapped_params["topK"] = mapped_params.pop("top_k")
        if "top_p" in mapped_params:
            mapped_params["topP"] = mapped_params.pop("top_p")

        unsupported = ["frequency_penalty", "presence_penalty", "repetition_penalty"]
        for param in unsupported:
            if param in mapped_params:
                logger.warning(
                    f"The native Gemini API does not support the '{param}' "
                    f"parameter; it has been ignored"
                )
                mapped_params.pop(param)

    elif api_type in ["openai", "deepseek", "zhipu", "general_openai_compat"]:
        if "repetition_penalty" in mapped_params and api_type == "openai":
            logger.warning(
                "The official OpenAI API does not support the repetition_penalty "
                "parameter; it has been ignored"
            )
            mapped_params.pop("repetition_penalty")

        if "stop" in mapped_params:
            stop_value = mapped_params["stop"]
            if isinstance(stop_value, str):
                mapped_params["stop"] = [stop_value]

    return mapped_params


def create_generation_config_from_kwargs(**kwargs) -> LLMGenerationConfig:
    """Create a generation configuration from keyword arguments."""
    model_fields = getattr(LLMGenerationConfig, "model_fields", {})
    known_fields = set(model_fields.keys())
    known_params = {}
    custom_params = {}

    for key, value in kwargs.items():
        if key in known_fields:
            known_params[key] = value
        else:
            custom_params[key] = value

    if custom_params:
        known_params["custom_params"] = custom_params

    return LLMGenerationConfig(**known_params)
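
A small sketch of how keyword arguments are split into known fields versus custom_params, and how they are then mapped per API type; my_vendor_flag is a made-up parameter:

from zhenxun.services.llm.config.generation import (
    create_generation_config_from_kwargs,
)

config = create_generation_config_from_kwargs(
    temperature=0.2, max_tokens=512, my_vendor_flag=True
)
# my_vendor_flag is not a declared field, so it lands in custom_params.
print(config.to_api_params(api_type="gemini", model_name="gemini-2.0-flash"))
# -> {"temperature": 0.2, "maxOutputTokens": 512, "my_vendor_flag": True}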
169  zhenxun/services/llm/config/presets.py  Normal file
@ -0,0 +1,169 @@
"""
LLM 预设配置

提供常用的配置预设,特别是针对 Gemini 的高级功能。
"""

from typing import Any

from .generation import LLMGenerationConfig


class CommonOverrides:
    """常用的配置覆盖预设"""

    @staticmethod
    def creative() -> LLMGenerationConfig:
        """创意模式:高温度,鼓励创新"""
        return LLMGenerationConfig(temperature=0.9, top_p=0.95, frequency_penalty=0.1)

    @staticmethod
    def precise() -> LLMGenerationConfig:
        """精确模式:低温度,确定性输出"""
        return LLMGenerationConfig(temperature=0.1, top_p=0.9, frequency_penalty=0.0)

    @staticmethod
    def balanced() -> LLMGenerationConfig:
        """平衡模式:中等温度"""
        return LLMGenerationConfig(temperature=0.5, top_p=0.9, frequency_penalty=0.0)

    @staticmethod
    def concise(max_tokens: int = 100) -> LLMGenerationConfig:
        """简洁模式:限制输出长度"""
        return LLMGenerationConfig(
            temperature=0.3,
            max_tokens=max_tokens,
            stop=["\n\n", "。", "!", "?"],
        )

    @staticmethod
    def detailed(max_tokens: int = 2000) -> LLMGenerationConfig:
        """详细模式:鼓励详细输出"""
        return LLMGenerationConfig(
            temperature=0.7, max_tokens=max_tokens, frequency_penalty=-0.1
        )

    @staticmethod
    def gemini_json() -> LLMGenerationConfig:
        """Gemini JSON模式:强制JSON输出"""
        return LLMGenerationConfig(
            temperature=0.3, response_mime_type="application/json"
        )

    @staticmethod
    def gemini_thinking(budget: float = 0.8) -> LLMGenerationConfig:
        """Gemini 思考模式:使用思考预算"""
        return LLMGenerationConfig(temperature=0.7, thinking_budget=budget)

    @staticmethod
    def gemini_creative() -> LLMGenerationConfig:
        """Gemini 创意模式:高温度创意输出"""
        return LLMGenerationConfig(temperature=0.9, top_p=0.95)

    @staticmethod
    def gemini_structured(schema: dict[str, Any]) -> LLMGenerationConfig:
        """Gemini 结构化输出:自定义JSON模式"""
        return LLMGenerationConfig(
            temperature=0.3,
            response_mime_type="application/json",
            response_schema=schema,
        )

    @staticmethod
    def gemini_safe() -> LLMGenerationConfig:
        """Gemini 安全模式:严格安全设置"""
        return LLMGenerationConfig(
            temperature=0.5,
            safety_settings={
                "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
                "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
            },
        )

    @staticmethod
    def gemini_multimodal() -> LLMGenerationConfig:
        """Gemini 多模态模式:优化多模态处理"""
        return LLMGenerationConfig(temperature=0.6, max_tokens=2048, top_p=0.8)

    @staticmethod
    def gemini_code_execution() -> LLMGenerationConfig:
        """Gemini 代码执行模式:启用代码执行功能"""
        return LLMGenerationConfig(
            temperature=0.3,
            max_tokens=4096,
            enable_code_execution=True,
            custom_params={"code_execution_timeout": 30},
        )

    @staticmethod
    def gemini_grounding() -> LLMGenerationConfig:
        """Gemini 信息来源关联模式:启用Google搜索"""
        return LLMGenerationConfig(
            temperature=0.5,
            max_tokens=4096,
            enable_grounding=True,
            custom_params={
                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
            },
        )

    @staticmethod
    def gemini_cached() -> LLMGenerationConfig:
        """Gemini 缓存模式:启用响应缓存"""
        return LLMGenerationConfig(
            temperature=0.3,
            max_tokens=2048,
            enable_caching=True,
        )

    @staticmethod
    def gemini_advanced() -> LLMGenerationConfig:
        """Gemini 高级模式:启用所有高级功能"""
        return LLMGenerationConfig(
            temperature=0.5,
            max_tokens=4096,
            enable_code_execution=True,
            enable_grounding=True,
            enable_caching=True,
            custom_params={
                "code_execution_timeout": 30,
                "grounding_config": {
                    "dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}
                },
            },
        )

    @staticmethod
    def gemini_research() -> LLMGenerationConfig:
        """Gemini 研究模式:思考+搜索+结构化输出"""
        return LLMGenerationConfig(
            temperature=0.6,
            max_tokens=4096,
            thinking_budget=0.8,
            enable_grounding=True,
            response_mime_type="application/json",
            custom_params={
                "grounding_config": {"dynamicRetrievalConfig": {"mode": "MODE_DYNAMIC"}}
            },
        )

    @staticmethod
    def gemini_analysis() -> LLMGenerationConfig:
        """Gemini 分析模式:深度思考+详细输出"""
        return LLMGenerationConfig(
            temperature=0.4,
            max_tokens=6000,
            thinking_budget=0.9,
            top_p=0.8,
        )

    @staticmethod
    def gemini_fast_response() -> LLMGenerationConfig:
        """Gemini 快速响应模式:低延迟+简洁输出"""
        return LLMGenerationConfig(
            temperature=0.3,
            max_tokens=512,
            top_p=0.8,
        )
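
A usage sketch for these presets; `to_dict()` is the serialization helper used on these configs elsewhere in this PR, and the override flow through `get_model_instance` is shown later in manager.py:

from zhenxun.services.llm.config.presets import CommonOverrides

json_cfg = CommonOverrides.gemini_json()          # temperature=0.3 + application/json output
research_cfg = CommonOverrides.gemini_research()
override_dict = research_cfg.to_dict()            # can be passed as an override_config dict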

zhenxun/services/llm/config/providers.py
@ -0,0 +1,328 @@
"""
LLM 提供商配置管理

负责注册和管理 AI 服务提供商的配置项。
"""

from typing import Any

from pydantic import BaseModel, Field

from zhenxun.configs.config import Config
from zhenxun.services.log import logger

from ..types.models import ModelDetail, ProviderConfig

AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"


class LLMConfig(BaseModel):
    """LLM 服务配置类"""

    default_model_name: str | None = Field(
        default=None,
        description="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
    )
    proxy: str | None = Field(
        default=None,
        description="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890",
    )
    timeout: int = Field(default=180, description="LLM服务API请求超时时间(秒)")
    max_retries_llm: int = Field(
        default=3, description="LLM服务请求失败时的最大重试次数"
    )
    retry_delay_llm: int = Field(
        default=2, description="LLM服务请求重试的基础延迟时间(秒)"
    )
    providers: list[ProviderConfig] = Field(
        default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
    )

    def get_provider_by_name(self, name: str) -> ProviderConfig | None:
        """根据名称获取提供商配置

        参数:
            name: 提供商名称

        返回:
            ProviderConfig | None: 提供商配置,如果未找到则返回 None
        """
        for provider in self.providers:
            if provider.name == name:
                return provider
        return None

    def get_model_by_provider_and_name(
        self, provider_name: str, model_name: str
    ) -> tuple[ProviderConfig, ModelDetail] | None:
        """根据提供商名称和模型名称获取配置

        参数:
            provider_name: 提供商名称
            model_name: 模型名称

        返回:
            tuple[ProviderConfig, ModelDetail] | None: 提供商配置和模型详情的元组,
            如果未找到则返回 None
        """
        provider = self.get_provider_by_name(provider_name)
        if not provider:
            return None

        for model in provider.models:
            if model.model_name == model_name:
                return provider, model
        return None

    def list_available_models(self) -> list[dict[str, Any]]:
        """列出所有可用的模型

        返回:
            list[dict[str, Any]]: 模型信息列表
        """
        models = []
        for provider in self.providers:
            for model in provider.models:
                models.append(
                    {
                        "provider_name": provider.name,
                        "model_name": model.model_name,
                        "full_name": f"{provider.name}/{model.model_name}",
                        "is_available": model.is_available,
                        "is_embedding_model": model.is_embedding_model,
                        "api_type": provider.api_type,
                    }
                )
        return models

    def validate_model_name(self, provider_model_name: str) -> bool:
        """验证模型名称格式是否正确

        参数:
            provider_model_name: 格式为 "ProviderName/ModelName" 的字符串

        返回:
            bool: 是否有效
        """
        if not provider_model_name or "/" not in provider_model_name:
            return False

        parts = provider_model_name.split("/", 1)
        if len(parts) != 2:
            return False

        provider_name, model_name = parts
        return (
            self.get_model_by_provider_and_name(provider_name, model_name) is not None
        )


def get_ai_config():
    """获取 AI 配置组"""
    return Config.get(AI_CONFIG_GROUP)


def get_default_providers() -> list[dict[str, Any]]:
    """获取默认的提供商配置

    返回:
        list[dict[str, Any]]: 默认提供商配置列表
    """
    return [
        {
            "name": "DeepSeek",
            "api_key": "sk-******",
            "api_base": "https://api.deepseek.com",
            "api_type": "openai",
            "models": [
                {
                    "model_name": "deepseek-chat",
                    "max_tokens": 4096,
                    "temperature": 0.7,
                },
                {
                    "model_name": "deepseek-reasoner",
                },
            ],
        },
        {
            "name": "GLM",
            "api_key": "",
            "api_base": "https://open.bigmodel.cn",
            "api_type": "zhipu",
            "models": [
                {"model_name": "glm-4-flash"},
                {"model_name": "glm-4-plus"},
            ],
        },
        {
            "name": "Gemini",
            "api_key": [
                "AIzaSy*****************************",
                "AIzaSy*****************************",
                "AIzaSy*****************************",
            ],
            "api_base": "https://generativelanguage.googleapis.com",
            "api_type": "gemini",
            "models": [
                {"model_name": "gemini-2.0-flash"},
                {"model_name": "gemini-2.5-flash-preview-05-20"},
            ],
        },
    ]


def register_llm_configs():
    """注册 LLM 服务的配置项"""
    logger.info("注册 LLM 服务的配置项")

    llm_config = LLMConfig()

    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        "default_model_name",
        llm_config.default_model_name,
        help="LLM服务全局默认使用的模型名称 (格式: ProviderName/ModelName)",
        type=str,
    )
    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        "proxy",
        llm_config.proxy,
        help="LLM服务请求使用的网络代理,例如 http://127.0.0.1:7890",
        type=str,
    )
    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        "timeout",
        llm_config.timeout,
        help="LLM服务API请求超时时间(秒)",
        type=int,
    )
    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        "max_retries_llm",
        llm_config.max_retries_llm,
        help="LLM服务请求失败时的最大重试次数",
        type=int,
    )
    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        "retry_delay_llm",
        llm_config.retry_delay_llm,
        help="LLM服务请求重试的基础延迟时间(秒)",
        type=int,
    )

    Config.add_plugin_config(
        AI_CONFIG_GROUP,
        PROVIDERS_CONFIG_KEY,
        get_default_providers(),
        help="配置多个 AI 服务提供商及其模型信息",
        default_value=[],
        type=list[ProviderConfig],
    )


def get_llm_config() -> LLMConfig:
    """获取 LLM 配置实例

    返回:
        LLMConfig: LLM 配置实例
    """
    ai_config = get_ai_config()

    config_data = {
        "default_model_name": ai_config.get("default_model_name"),
        "proxy": ai_config.get("proxy"),
        "timeout": ai_config.get("timeout", 180),
        "max_retries_llm": ai_config.get("max_retries_llm", 3),
        "retry_delay_llm": ai_config.get("retry_delay_llm", 2),
        "providers": ai_config.get(PROVIDERS_CONFIG_KEY, []),
    }

    return LLMConfig(**config_data)


def validate_llm_config() -> tuple[bool, list[str]]:
    """验证 LLM 配置的有效性

    返回:
        tuple[bool, list[str]]: (是否有效, 错误信息列表)
    """
    errors = []

    try:
        llm_config = get_llm_config()

        if llm_config.timeout <= 0:
            errors.append("timeout 必须大于 0")

        if llm_config.max_retries_llm < 0:
            errors.append("max_retries_llm 不能小于 0")

        if llm_config.retry_delay_llm <= 0:
            errors.append("retry_delay_llm 必须大于 0")

        if not llm_config.providers:
            errors.append("至少需要配置一个 AI 服务提供商")
        else:
            provider_names = set()
            for provider in llm_config.providers:
                if provider.name in provider_names:
                    errors.append(f"提供商名称重复: {provider.name}")
                provider_names.add(provider.name)

                if not provider.api_key:
                    errors.append(f"提供商 {provider.name} 缺少 API Key")

                if not provider.models:
                    errors.append(f"提供商 {provider.name} 没有配置任何模型")
                else:
                    model_names = set()
                    for model in provider.models:
                        if model.model_name in model_names:
                            errors.append(
                                f"提供商 {provider.name} 中模型名称重复: "
                                f"{model.model_name}"
                            )
                        model_names.add(model.model_name)

        if llm_config.default_model_name:
            if not llm_config.validate_model_name(llm_config.default_model_name):
                errors.append(
                    f"默认模型 {llm_config.default_model_name} 在配置中不存在"
                )

    except Exception as e:
        errors.append(f"配置解析失败: {e!s}")

    return len(errors) == 0, errors
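
A startup-time sketch of how this validator might be wired in, assuming it runs after register_llm_configs():

is_valid, problems = validate_llm_config()
if not is_valid:
    for msg in problems:
        logger.warning(f"LLM 配置问题: {msg}")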


def set_default_model(provider_model_name: str | None) -> bool:
    """设置默认模型

    参数:
        provider_model_name: 模型名称,格式为 "ProviderName/ModelName",None 表示清除

    返回:
        bool: 是否设置成功
    """
    if provider_model_name:
        llm_config = get_llm_config()
        if not llm_config.validate_model_name(provider_model_name):
            logger.error(f"模型 {provider_model_name} 在配置中不存在")
            return False

    Config.set_config(
        AI_CONFIG_GROUP, "default_model_name", provider_model_name, auto_save=True
    )

    if provider_model_name:
        logger.info(f"默认模型已设置为: {provider_model_name}")
    else:
        logger.info("默认模型已清除")

    return True

zhenxun/services/llm/core.py
@ -0,0 +1,378 @@
"""
LLM 核心基础设施模块

包含执行 LLM 请求所需的底层组件,如 HTTP 客户端、API Key 存储和智能重试逻辑。
"""

import asyncio
from typing import Any

import httpx
from pydantic import BaseModel

from zhenxun.services.log import logger
from zhenxun.utils.user_agent import get_user_agent

from .types import ProviderConfig
from .types.exceptions import LLMErrorCode, LLMException


class HttpClientConfig(BaseModel):
    """HTTP客户端配置"""

    timeout: int = 180
    max_connections: int = 100
    max_keepalive_connections: int = 20
    proxy: str | None = None


class LLMHttpClient:
    """LLM服务专用HTTP客户端"""

    def __init__(self, config: HttpClientConfig | None = None):
        self.config = config or HttpClientConfig()
        self._client: httpx.AsyncClient | None = None
        self._active_requests = 0
        self._lock = asyncio.Lock()

    async def _ensure_client_initialized(self) -> httpx.AsyncClient:
        if self._client is None or self._client.is_closed:
            async with self._lock:
                if self._client is None or self._client.is_closed:
                    logger.debug(
                        f"LLMHttpClient: Initializing new httpx.AsyncClient "
                        f"with config: {self.config}"
                    )
                    headers = get_user_agent()
                    limits = httpx.Limits(
                        max_connections=self.config.max_connections,
                        max_keepalive_connections=self.config.max_keepalive_connections,
                    )
                    timeout = httpx.Timeout(self.config.timeout)
                    self._client = httpx.AsyncClient(
                        headers=headers,
                        limits=limits,
                        timeout=timeout,
                        proxies=self.config.proxy,
                        follow_redirects=True,
                    )
        if self._client is None:
            raise LLMException(
                "HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
            )
        return self._client

    async def post(self, url: str, **kwargs: Any) -> httpx.Response:
        client = await self._ensure_client_initialized()
        async with self._lock:
            self._active_requests += 1
        try:
            return await client.post(url, **kwargs)
        finally:
            async with self._lock:
                self._active_requests -= 1

    async def close(self):
        async with self._lock:
            if self._client and not self._client.is_closed:
                logger.debug(
                    f"LLMHttpClient: Closing with config: {self.config}. "
                    f"Active requests: {self._active_requests}"
                )
                if self._active_requests > 0:
                    logger.warning(
                        f"LLMHttpClient: Closing while {self._active_requests} "
                        f"requests are still active."
                    )
                await self._client.aclose()
            self._client = None
        logger.debug(f"LLMHttpClient for config {self.config} definitively closed.")

    @property
    def is_closed(self) -> bool:
        return self._client is None or self._client.is_closed


class LLMHttpClientManager:
    """管理 LLMHttpClient 实例的工厂和池"""

    def __init__(self):
        self._clients: dict[tuple[int, str | None], LLMHttpClient] = {}
        self._lock = asyncio.Lock()

    def _get_client_key(
        self, provider_config: ProviderConfig
    ) -> tuple[int, str | None]:
        return (provider_config.timeout, provider_config.proxy)

    async def get_client(self, provider_config: ProviderConfig) -> LLMHttpClient:
        key = self._get_client_key(provider_config)
        async with self._lock:
            client = self._clients.get(key)
            if client and not client.is_closed:
                logger.debug(
                    f"LLMHttpClientManager: Reusing existing LLMHttpClient "
                    f"for key: {key}"
                )
                return client

            if client and client.is_closed:
                logger.debug(
                    f"LLMHttpClientManager: Found a closed client for key {key}. "
                    f"Creating a new one."
                )

            logger.debug(
                f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
            )
            http_client_config = HttpClientConfig(
                timeout=provider_config.timeout, proxy=provider_config.proxy
            )
            new_client = LLMHttpClient(config=http_client_config)
            self._clients[key] = new_client
            return new_client

    async def shutdown(self):
        async with self._lock:
            logger.info(
                f"LLMHttpClientManager: Shutting down. "
                f"Closing {len(self._clients)} client(s)."
            )
            close_tasks = [
                client.close()
                for client in self._clients.values()
                if client and not client.is_closed
            ]
            if close_tasks:
                await asyncio.gather(*close_tasks, return_exceptions=True)
            self._clients.clear()
            logger.info("LLMHttpClientManager: Shutdown complete.")


http_client_manager = LLMHttpClientManager()


async def create_llm_http_client(
    timeout: int = 180,
    proxy: str | None = None,
) -> LLMHttpClient:
    """创建LLM HTTP客户端"""
    config = HttpClientConfig(timeout=timeout, proxy=proxy)
    return LLMHttpClient(config)
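
A sketch of the pooling behavior above: clients are keyed only by (timeout, proxy), so two providers with identical network settings share one underlying client. The ProviderConfig field names are taken from their use in manager.py later in this PR; `models=[]` and the key value are placeholders, and the snippet must run inside an async function:

cfg_a = ProviderConfig(name="A", api_key="k", models=[], timeout=180, proxy=None)
cfg_b = ProviderConfig(name="B", api_key="k", models=[], timeout=180, proxy=None)
client_a = await http_client_manager.get_client(cfg_a)
client_b = await http_client_manager.get_client(cfg_b)
assert client_a is client_b  # same (timeout, proxy) key -> shared client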


class RetryConfig:
    """重试配置"""

    def __init__(
        self,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        exponential_backoff: bool = True,
        key_rotation: bool = True,
    ):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.exponential_backoff = exponential_backoff
        self.key_rotation = key_rotation


async def with_smart_retry(
    func,
    *args,
    retry_config: RetryConfig | None = None,
    key_store: "KeyStatusStore | None" = None,
    provider_name: str | None = None,
    **kwargs: Any,
) -> Any:
    """智能重试装饰器 - 支持Key轮询和错误分类"""
    config = retry_config or RetryConfig()
    last_exception: Exception | None = None
    failed_keys: set[str] = set()

    for attempt in range(config.max_retries + 1):
        try:
            if config.key_rotation and "failed_keys" in func.__code__.co_varnames:
                kwargs["failed_keys"] = failed_keys

            return await func(*args, **kwargs)

        except LLMException as e:
            last_exception = e

            if e.code in [
                LLMErrorCode.API_KEY_INVALID,
                LLMErrorCode.API_QUOTA_EXCEEDED,
            ]:
                if hasattr(e, "details") and e.details and "api_key" in e.details:
                    failed_keys.add(e.details["api_key"])
                    if key_store and provider_name:
                        await key_store.record_failure(
                            e.details["api_key"], e.details.get("status_code")
                        )

            should_retry = _should_retry_llm_error(e, attempt, config.max_retries)
            if not should_retry:
                logger.error(f"不可重试的错误,停止重试: {e}")
                raise

            if attempt < config.max_retries:
                wait_time = config.retry_delay
                if config.exponential_backoff:
                    wait_time *= 2**attempt
                logger.warning(
                    f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}"
                )
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"重试{config.max_retries}次后仍然失败: {e}")

        except Exception as e:
            last_exception = e
            logger.error(f"非LLM异常,停止重试: {e}")
            raise LLMException(
                f"操作失败: {e}",
                code=LLMErrorCode.GENERATION_FAILED,
                cause=e,
            )

    if last_exception:
        raise last_exception
    else:
        raise RuntimeError("重试函数未能正常执行且未捕获到异常")
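
A self-contained sketch of the retry loop above, assuming the LLMException constructor accepts a message plus a `code` keyword (it is used that way throughout this PR); API_TIMEOUT is in the retryable set, so the second attempt succeeds:

attempts = {"n": 0}

async def flaky_call() -> str:
    attempts["n"] += 1
    if attempts["n"] == 1:
        raise LLMException("临时超时", code=LLMErrorCode.API_TIMEOUT)
    return "ok"

result = await with_smart_retry(
    flaky_call, retry_config=RetryConfig(max_retries=2, retry_delay=0.1)
)
assert result == "ok" and attempts["n"] == 2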


def _should_retry_llm_error(
    error: LLMException, attempt: int, max_retries: int
) -> bool:
    """判断LLM错误是否应该重试"""
    non_retryable_errors = {
        LLMErrorCode.MODEL_NOT_FOUND,
        LLMErrorCode.CONTEXT_LENGTH_EXCEEDED,
        LLMErrorCode.USER_LOCATION_NOT_SUPPORTED,
        LLMErrorCode.CONFIGURATION_ERROR,
    }

    if error.code in non_retryable_errors:
        return False

    retryable_errors = {
        LLMErrorCode.API_REQUEST_FAILED,
        LLMErrorCode.API_TIMEOUT,
        LLMErrorCode.API_RATE_LIMITED,
        LLMErrorCode.API_RESPONSE_INVALID,
        LLMErrorCode.RESPONSE_PARSE_ERROR,
        LLMErrorCode.GENERATION_FAILED,
        LLMErrorCode.CONTENT_FILTERED,
        LLMErrorCode.API_KEY_INVALID,
        LLMErrorCode.API_QUOTA_EXCEEDED,
    }

    if error.code in retryable_errors:
        if error.code == LLMErrorCode.API_QUOTA_EXCEEDED:
            return attempt < min(2, max_retries)
        elif error.code == LLMErrorCode.CONTENT_FILTERED:
            return attempt < min(1, max_retries)
        return True

    return False
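
Two spot-checks of the policy above: quota errors are retried at most twice, regardless of the configured maximum:

err = LLMException("配额超限", code=LLMErrorCode.API_QUOTA_EXCEEDED)
assert _should_retry_llm_error(err, attempt=0, max_retries=5) is True
assert _should_retry_llm_error(err, attempt=2, max_retries=5) is False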


class KeyStatusStore:
    """API Key 状态管理存储 - 优化版本,支持轮询和负载均衡"""

    def __init__(self):
        self._key_status: dict[str, bool] = {}
        self._key_usage_count: dict[str, int] = {}
        self._key_last_used: dict[str, float] = {}
        self._provider_key_index: dict[str, int] = {}
        self._lock = asyncio.Lock()

    async def get_next_available_key(
        self,
        provider_name: str,
        api_keys: list[str],
        exclude_keys: set[str] | None = None,
    ) -> str | None:
        """获取下一个可用的API密钥(轮询策略)"""
        if not api_keys:
            return None

        exclude_keys = exclude_keys or set()
        available_keys = [
            key
            for key in api_keys
            if key not in exclude_keys and self._key_status.get(key, True)
        ]

        if not available_keys:
            return api_keys[0] if api_keys else None

        async with self._lock:
            current_index = self._provider_key_index.get(provider_name, 0)

            selected_key = available_keys[current_index % len(available_keys)]

            self._provider_key_index[provider_name] = (current_index + 1) % len(
                available_keys
            )

            import time

            self._key_usage_count[selected_key] = (
                self._key_usage_count.get(selected_key, 0) + 1
            )
            self._key_last_used[selected_key] = time.time()

            logger.debug(
                f"轮询选择API密钥: {self._get_key_id(selected_key)} "
                f"(使用次数: {self._key_usage_count[selected_key]})"
            )

            return selected_key

    async def record_success(self, api_key: str):
        """记录成功使用"""
        async with self._lock:
            self._key_status[api_key] = True
            logger.debug(f"记录API密钥成功使用: {self._get_key_id(api_key)}")

    async def record_failure(self, api_key: str, status_code: int | None):
        """记录失败使用"""
        key_id = self._get_key_id(api_key)
        async with self._lock:
            if status_code in [401, 403]:
                self._key_status[api_key] = False
                logger.warning(
                    f"API密钥认证失败,标记为不可用: {key_id} (状态码: {status_code})"
                )
            else:
                logger.debug(f"记录API密钥失败使用: {key_id} (状态码: {status_code})")

    async def reset_key_status(self, api_key: str):
        """重置密钥状态(用于恢复机制)"""
        async with self._lock:
            self._key_status[api_key] = True
            logger.info(f"重置API密钥状态: {self._get_key_id(api_key)}")

    async def get_key_stats(self, api_keys: list[str]) -> dict[str, dict]:
        """获取密钥使用统计"""
        stats = {}
        async with self._lock:
            for key in api_keys:
                key_id = self._get_key_id(key)
                stats[key_id] = {
                    "available": self._key_status.get(key, True),
                    "usage_count": self._key_usage_count.get(key, 0),
                    "last_used": self._key_last_used.get(key, 0),
                }
        return stats

    def _get_key_id(self, api_key: str) -> str:
        """获取API密钥的标识符(用于日志)"""
        if len(api_key) <= 8:
            return api_key
        return f"{api_key[:4]}...{api_key[-4:]}"


key_store = KeyStatusStore()
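
A sketch of the round-robin selection, assuming a fresh store with no recorded failures for these (hypothetical) keys:

keys = ["key-A", "key-B"]
first = await key_store.get_next_available_key("Gemini", keys)
second = await key_store.get_next_available_key("Gemini", keys)
assert (first, second) == ("key-A", "key-B")  # the per-provider index advances on each call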

zhenxun/services/llm/manager.py
@ -0,0 +1,434 @@
"""
LLM 模型管理器

负责模型实例的创建、缓存、配置管理和生命周期管理。
"""

import hashlib
import json
import time
from typing import Any

from zhenxun.configs.config import Config
from zhenxun.services.log import logger

from .config import validate_override_params
from .config.providers import AI_CONFIG_GROUP, PROVIDERS_CONFIG_KEY, get_ai_config
from .core import http_client_manager, key_store
from .service import LLMModel
from .types import LLMErrorCode, LLMException, ModelDetail, ProviderConfig

DEFAULT_MODEL_NAME_KEY = "default_model_name"
PROXY_KEY = "proxy"
TIMEOUT_KEY = "timeout"

_model_cache: dict[str, tuple[LLMModel, float]] = {}
_cache_ttl = 3600
_max_cache_size = 10


def parse_provider_model_string(name_str: str | None) -> tuple[str | None, str | None]:
    """解析 'ProviderName/ModelName' 格式的字符串"""
    if not name_str or "/" not in name_str:
        return None, None
    parts = name_str.split("/", 1)
    if len(parts) == 2 and parts[0].strip() and parts[1].strip():
        return parts[0].strip(), parts[1].strip()
    return None, None
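
Two examples of the parser above:

assert parse_provider_model_string("Gemini/gemini-2.0-flash") == (
    "Gemini",
    "gemini-2.0-flash",
)
assert parse_provider_model_string("missing-slash") == (None, None)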


def _make_cache_key(
    provider_model_name: str | None, override_config: dict | None
) -> str:
    """生成缓存键"""
    config_str = (
        json.dumps(override_config, sort_keys=True) if override_config else "None"
    )
    key_data = f"{provider_model_name}:{config_str}"
    return hashlib.md5(key_data.encode()).hexdigest()


def _get_cached_model(cache_key: str) -> LLMModel | None:
    """从缓存获取模型"""
    if cache_key in _model_cache:
        model, created_time = _model_cache[cache_key]
        current_time = time.time()

        if current_time - created_time > _cache_ttl:
            del _model_cache[cache_key]
            logger.debug(f"模型缓存已过期: {cache_key}")
            return None

        if model._is_closed:
            logger.debug(
                f"缓存的模型 {cache_key} ({model.provider_name}/{model.model_name}) "
                f"处于_is_closed=True状态,重置为False以供复用。"
            )
            model._is_closed = False

        logger.debug(
            f"使用缓存的模型: {cache_key} -> {model.provider_name}/{model.model_name}"
        )
        return model
    return None


def _cache_model(cache_key: str, model: LLMModel):
    """缓存模型实例"""
    current_time = time.time()

    if len(_model_cache) >= _max_cache_size:
        oldest_key = min(_model_cache.keys(), key=lambda k: _model_cache[k][1])
        del _model_cache[oldest_key]

    _model_cache[cache_key] = (model, current_time)


def clear_model_cache():
    """清空模型缓存"""
    global _model_cache
    _model_cache.clear()
    logger.info("已清空模型缓存")


def get_cache_stats() -> dict[str, Any]:
    """获取缓存统计信息"""
    return {
        "cache_size": len(_model_cache),
        "max_cache_size": _max_cache_size,
        "cache_ttl": _cache_ttl,
        "cached_models": list(_model_cache.keys()),
    }


def get_default_api_base_for_type(api_type: str) -> str | None:
    """根据API类型获取默认的API基础地址"""
    default_api_bases = {
        "openai": "https://api.openai.com",
        "deepseek": "https://api.deepseek.com",
        "zhipu": "https://open.bigmodel.cn",
        "gemini": "https://generativelanguage.googleapis.com",
        "general_openai_compat": None,
    }

    return default_api_bases.get(api_type)


def get_configured_providers() -> list[ProviderConfig]:
    """从配置中获取Provider列表 - 简化版本"""
    ai_config = get_ai_config()
    providers_raw = ai_config.get(PROVIDERS_CONFIG_KEY, [])
    if not isinstance(providers_raw, list):
        logger.error(
            f"配置项 {AI_CONFIG_GROUP}.{PROVIDERS_CONFIG_KEY} 不是一个列表,"
            f"将使用空列表。"
        )
        return []

    valid_providers = []
    for i, item in enumerate(providers_raw):
        if not isinstance(item, dict):
            logger.warning(f"配置文件中第 {i + 1} 项不是字典格式,已跳过。")
            continue

        try:
            if not item.get("name"):
                logger.warning(f"Provider {i + 1} 缺少 'name' 字段,已跳过。")
                continue

            if not item.get("api_key"):
                logger.warning(
                    f"Provider '{item['name']}' 缺少 'api_key' 字段,已跳过。"
                )
                continue

            if "api_type" not in item or not item["api_type"]:
                provider_name = item.get("name", "").lower()
                if "glm" in provider_name or "zhipu" in provider_name:
                    item["api_type"] = "zhipu"
                elif "gemini" in provider_name or "google" in provider_name:
                    item["api_type"] = "gemini"
                else:
                    item["api_type"] = "openai"

            if "api_base" not in item or not item["api_base"]:
                api_type = item.get("api_type")
                if api_type:
                    default_api_base = get_default_api_base_for_type(api_type)
                    if default_api_base:
                        item["api_base"] = default_api_base

            if "models" not in item:
                item["models"] = [{"model_name": item.get("name", "default")}]

            provider_conf = ProviderConfig(**item)
            valid_providers.append(provider_conf)

        except Exception as e:
            logger.warning(f"解析配置文件中 Provider {i + 1} 时出错: {e},已跳过。")

    return valid_providers


def find_model_config(
    provider_name: str, model_name: str
) -> tuple[ProviderConfig, ModelDetail] | None:
    """在配置中查找指定的 Provider 和 ModelDetail

    Args:
        provider_name: 提供商名称
        model_name: 模型名称

    Returns:
        找到的 (ProviderConfig, ModelDetail) 元组,未找到则返回 None
    """
    providers = get_configured_providers()

    for provider in providers:
        if provider.name.lower() == provider_name.lower():
            for model_detail in provider.models:
                if model_detail.model_name.lower() == model_name.lower():
                    return provider, model_detail

    return None


def list_available_models() -> list[dict[str, Any]]:
    """列出所有配置的可用模型"""
    providers = get_configured_providers()
    model_list = []
    for provider in providers:
        for model_detail in provider.models:
            model_info = {
                "provider_name": provider.name,
                "model_name": model_detail.model_name,
                "full_name": f"{provider.name}/{model_detail.model_name}",
                "api_type": provider.api_type or "auto-detect",
                "api_base": provider.api_base,
                "is_available": model_detail.is_available,
                "is_embedding_model": model_detail.is_embedding_model,
                "available_identifiers": _get_model_identifiers(
                    provider.name, model_detail
                ),
            }
            model_list.append(model_info)
    return model_list


def _get_model_identifiers(provider_name: str, model_detail: ModelDetail) -> list[str]:
    """获取模型的所有可用标识符"""
    return [f"{provider_name}/{model_detail.model_name}"]


def list_model_identifiers() -> dict[str, list[str]]:
    """列出所有模型的可用标识符

    Returns:
        字典,键为模型的完整名称,值为该模型的所有可用标识符列表
    """
    providers = get_configured_providers()
    result = {}

    for provider in providers:
        for model_detail in provider.models:
            full_name = f"{provider.name}/{model_detail.model_name}"
            identifiers = _get_model_identifiers(provider.name, model_detail)
            result[full_name] = identifiers

    return result


def list_embedding_models() -> list[dict[str, Any]]:
    """列出所有配置的嵌入模型"""
    all_models = list_available_models()
    return [model for model in all_models if model.get("is_embedding_model", False)]


async def get_model_instance(
    provider_model_name: str | None = None,
    override_config: dict[str, Any] | None = None,
) -> LLMModel:
    """根据 'ProviderName/ModelName' 字符串获取并实例化 LLMModel (异步版本)"""
    cache_key = _make_cache_key(provider_model_name, override_config)
    cached_model = _get_cached_model(cache_key)
    if cached_model:
        if override_config:
            validated_override = validate_override_params(override_config)
            if cached_model._generation_config != validated_override:
                cached_model._generation_config = validated_override
                logger.debug(
                    f"对缓存模型 {provider_model_name} 应用新的覆盖配置: "
                    f"{validated_override.to_dict()}"
                )
        return cached_model

    resolved_model_name_str = provider_model_name
    if resolved_model_name_str is None:
        resolved_model_name_str = get_global_default_model_name()
    if resolved_model_name_str is None:
        available_models_list = list_available_models()
        if not available_models_list:
            raise LLMException(
                "未配置任何AI模型", code=LLMErrorCode.CONFIGURATION_ERROR
            )
        resolved_model_name_str = available_models_list[0]["full_name"]
        logger.warning(f"未指定模型,使用第一个可用模型: {resolved_model_name_str}")

    prov_name_str, mod_name_str = parse_provider_model_string(resolved_model_name_str)
    if not prov_name_str or not mod_name_str:
        raise LLMException(
            f"无效的模型名称格式: '{resolved_model_name_str}'",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    config_tuple_found = find_model_config(prov_name_str, mod_name_str)
    if not config_tuple_found:
        all_models = list_available_models()
        raise LLMException(
            f"未找到模型: '{resolved_model_name_str}'. "
            f"可用: {[m['full_name'] for m in all_models]}",
            code=LLMErrorCode.MODEL_NOT_FOUND,
        )

    provider_config_found, model_detail_found = config_tuple_found

    ai_config = get_ai_config()
    global_proxy_setting = ai_config.get(PROXY_KEY)
    default_timeout = (
        provider_config_found.timeout
        if provider_config_found.timeout is not None
        else 180
    )
    global_timeout_setting = ai_config.get(TIMEOUT_KEY, default_timeout)

    config_for_http_client = ProviderConfig(
        name=provider_config_found.name,
        api_key=provider_config_found.api_key,
        models=provider_config_found.models,
        timeout=global_timeout_setting,
        proxy=global_proxy_setting,
        api_base=provider_config_found.api_base,
        api_type=provider_config_found.api_type,
        openai_compat=provider_config_found.openai_compat,
        temperature=provider_config_found.temperature,
        max_tokens=provider_config_found.max_tokens,
    )

    shared_http_client = await http_client_manager.get_client(config_for_http_client)

    try:
        model_instance = LLMModel(
            provider_config=config_for_http_client,
            model_detail=model_detail_found,
            key_store=key_store,
            http_client=shared_http_client,
        )

        if override_config:
            validated_override_params = validate_override_params(override_config)
            model_instance._generation_config = validated_override_params
            logger.debug(
                f"为新模型 {resolved_model_name_str} 应用配置覆盖: "
                f"{validated_override_params.to_dict()}"
            )

        _cache_model(cache_key, model_instance)
        logger.debug(
            f"创建并缓存了新模型: {cache_key} -> {prov_name_str}/{mod_name_str}"
        )
        return model_instance
    except LLMException:
        raise
    except Exception as e:
        logger.error(
            f"实例化 LLMModel ({resolved_model_name_str}) 时发生内部错误: {e!s}", e=e
        )
        raise LLMException(
            f"初始化模型 '{resolved_model_name_str}' 失败: {e!s}",
            code=LLMErrorCode.MODEL_INIT_FAILED,
            cause=e,
        )
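
A usage sketch tying the pieces together; the model name is one entry from the default provider list registered earlier, and the override keys must be valid LLMGenerationConfig fields:

async with await get_model_instance(
    "Gemini/gemini-2.0-flash",
    override_config={"temperature": 0.2, "max_tokens": 512},
) as model:
    reply = await model.generate_text("你好")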


def get_global_default_model_name() -> str | None:
    """获取全局默认模型名称"""
    ai_config = get_ai_config()
    return ai_config.get(DEFAULT_MODEL_NAME_KEY)


def set_global_default_model_name(provider_model_name: str | None) -> bool:
    """设置全局默认模型名称"""
    if provider_model_name:
        prov_name, mod_name = parse_provider_model_string(provider_model_name)
        if not prov_name or not mod_name or not find_model_config(prov_name, mod_name):
            logger.error(
                f"尝试设置的全局默认模型 '{provider_model_name}' 无效或未配置。"
            )
            return False

    Config.set_config(
        AI_CONFIG_GROUP, DEFAULT_MODEL_NAME_KEY, provider_model_name, auto_save=True
    )
    if provider_model_name:
        logger.info(f"LLM 服务全局默认模型已更新为: {provider_model_name}")
    else:
        logger.info("LLM 服务全局默认模型已清除。")
    return True


async def get_key_usage_stats() -> dict[str, Any]:
    """获取所有Provider的Key使用统计"""
    providers = get_configured_providers()
    stats = {}

    for provider in providers:
        provider_stats = await key_store.get_key_stats(
            [provider.api_key]
            if isinstance(provider.api_key, str)
            else provider.api_key
        )
        stats[provider.name] = {
            "total_keys": len(
                [provider.api_key]
                if isinstance(provider.api_key, str)
                else provider.api_key
            ),
            "key_stats": provider_stats,
        }

    return stats


async def reset_key_status(provider_name: str, api_key: str | None = None) -> bool:
    """重置指定Provider的Key状态"""
    providers = get_configured_providers()
    target_provider = None

    for provider in providers:
        if provider.name.lower() == provider_name.lower():
            target_provider = provider
            break

    if not target_provider:
        logger.error(f"未找到Provider: {provider_name}")
        return False

    provider_keys = (
        [target_provider.api_key]
        if isinstance(target_provider.api_key, str)
        else target_provider.api_key
    )

    if api_key:
        if api_key in provider_keys:
            await key_store.reset_key_status(api_key)
            logger.info(f"已重置Provider '{provider_name}' 的指定Key状态")
            return True
        else:
            logger.error(f"指定的Key不属于Provider '{provider_name}'")
            return False
    else:
        for key in provider_keys:
            await key_store.reset_key_status(key)
        logger.info(f"已重置Provider '{provider_name}' 的所有Key状态")
        return True

zhenxun/services/llm/service.py
@ -0,0 +1,632 @@
"""
LLM 模型实现类

包含 LLM 模型的抽象基类和具体实现,负责与各种 AI 提供商的 API 交互。
"""

from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
import json
from typing import Any

from zhenxun.services.log import logger

from .config import LLMGenerationConfig
from .config.providers import get_ai_config
from .core import (
    KeyStatusStore,
    LLMHttpClient,
    RetryConfig,
    http_client_manager,
    with_smart_retry,
)
from .types import (
    EmbeddingTaskType,
    LLMErrorCode,
    LLMException,
    LLMMessage,
    LLMResponse,
    LLMTool,
    ModelDetail,
    ProviderConfig,
)


class LLMModelBase(ABC):
    """LLM模型抽象基类"""

    @abstractmethod
    async def generate_text(
        self,
        prompt: str,
        history: list[dict[str, str]] | None = None,
        **kwargs: Any,
    ) -> str:
        """生成文本"""
        pass

    @abstractmethod
    async def generate_response(
        self,
        messages: list[LLMMessage],
        config: LLMGenerationConfig | None = None,
        tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """生成高级响应"""
        pass

    @abstractmethod
    async def generate_embeddings(
        self,
        texts: list[str],
        task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
        **kwargs: Any,
    ) -> list[list[float]]:
        """生成文本嵌入向量"""
        pass


class LLMModel(LLMModelBase):
    """LLM 模型实现类"""

    def __init__(
        self,
        provider_config: ProviderConfig,
        model_detail: ModelDetail,
        key_store: KeyStatusStore,
        http_client: LLMHttpClient,
        config_override: LLMGenerationConfig | None = None,
    ):
        self.provider_config = provider_config
        self.model_detail = model_detail
        self.key_store = key_store
        self.http_client: LLMHttpClient = http_client
        self._generation_config = config_override

        self.provider_name = provider_config.name
        self.api_type = provider_config.api_type
        self.api_base = provider_config.api_base
        self.api_keys = (
            [provider_config.api_key]
            if isinstance(provider_config.api_key, str)
            else provider_config.api_key
        )
        self.model_name = model_detail.model_name
        self.temperature = model_detail.temperature
        self.max_tokens = model_detail.max_tokens

        self._is_closed = False

    async def _get_http_client(self) -> LLMHttpClient:
        """获取HTTP客户端"""
        if self.http_client.is_closed:
            logger.debug(
                f"LLMModel {self.provider_name}/{self.model_name} 的 HTTP 客户端已关闭,"
                "正在获取新的客户端"
            )
            self.http_client = await http_client_manager.get_client(
                self.provider_config
            )
        return self.http_client

    async def _select_api_key(self, failed_keys: set[str] | None = None) -> str:
        """选择可用的API密钥(使用轮询策略)"""
        if not self.api_keys:
            raise LLMException(
                f"提供商 {self.provider_name} 没有配置API密钥",
                code=LLMErrorCode.NO_AVAILABLE_KEYS,
            )

        selected_key = await self.key_store.get_next_available_key(
            self.provider_name, self.api_keys, failed_keys
        )

        if not selected_key:
            raise LLMException(
                f"提供商 {self.provider_name} 的所有API密钥当前都不可用",
                code=LLMErrorCode.NO_AVAILABLE_KEYS,
                details={
                    "total_keys": len(self.api_keys),
                    "failed_keys": len(failed_keys or set()),
                },
            )

        return selected_key

    async def _execute_embedding_request(
        self,
        adapter,
        texts: list[str],
        task_type: EmbeddingTaskType | str,
        http_client: LLMHttpClient,
        failed_keys: set[str] | None = None,
    ) -> list[list[float]]:
        """执行单次嵌入请求 - 供重试机制调用"""
        api_key = await self._select_api_key(failed_keys)

        try:
            request_data = adapter.prepare_embedding_request(
                model=self,
                api_key=api_key,
                texts=texts,
                task_type=task_type,
            )

            http_response = await http_client.post(
                request_data.url,
                headers=request_data.headers,
                json=request_data.body,
            )

            if http_response.status_code != 200:
                error_text = http_response.text
                logger.error(
                    f"HTTP嵌入请求失败: {http_response.status_code} - {error_text}"
                )
                await self.key_store.record_failure(api_key, http_response.status_code)

                error_code = LLMErrorCode.API_REQUEST_FAILED
                if http_response.status_code in [401, 403]:
                    error_code = LLMErrorCode.API_KEY_INVALID
                elif http_response.status_code == 429:
                    error_code = LLMErrorCode.API_RATE_LIMITED

                raise LLMException(
                    f"HTTP嵌入请求失败: {http_response.status_code}",
                    code=error_code,
                    details={
                        "status_code": http_response.status_code,
                        "response": error_text,
                        "api_key": api_key,
                    },
                )

            try:
                response_json = http_response.json()
                adapter.validate_embedding_response(response_json)
                embeddings = adapter.parse_embedding_response(response_json)
            except Exception as e:
                logger.error(f"解析嵌入响应失败: {e}", e=e)
                await self.key_store.record_failure(api_key, None)
                if isinstance(e, LLMException):
                    raise
                else:
                    raise LLMException(
                        f"解析API嵌入响应失败: {e}",
                        code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                        cause=e,
                    )

            await self.key_store.record_success(api_key)
            return embeddings

        except LLMException:
            raise
        except Exception as e:
            logger.error(f"生成嵌入时发生未预期错误: {e}", e=e)
            await self.key_store.record_failure(api_key, None)
            raise LLMException(
                f"生成嵌入失败: {e}",
                code=LLMErrorCode.EMBEDDING_FAILED,
                cause=e,
            )

    async def _execute_with_smart_retry(
        self,
        adapter,
        messages: list[LLMMessage],
        config: LLMGenerationConfig | None,
        tools_dict: list[dict[str, Any]] | None,
        tool_choice: str | dict[str, Any] | None,
        http_client: LLMHttpClient,
    ):
        """智能重试机制 - 使用统一的重试装饰器"""
        ai_config = get_ai_config()
        max_retries = ai_config.get("max_retries_llm", 3)
        retry_delay = ai_config.get("retry_delay_llm", 2)
        retry_config = RetryConfig(max_retries=max_retries, retry_delay=retry_delay)

        return await with_smart_retry(
            self._execute_single_request,
            adapter,
            messages,
            config,
            tools_dict,
            tool_choice,
            http_client,
            retry_config=retry_config,
            key_store=self.key_store,
            provider_name=self.provider_name,
        )

    async def _execute_single_request(
        self,
        adapter,
        messages: list[LLMMessage],
        config: LLMGenerationConfig | None,
        tools_dict: list[dict[str, Any]] | None,
        tool_choice: str | dict[str, Any] | None,
        http_client: LLMHttpClient,
        failed_keys: set[str] | None = None,
    ) -> LLMResponse:
        """执行单次请求 - 供重试机制调用,直接返回 LLMResponse"""
        api_key = await self._select_api_key(failed_keys)

        try:
            request_data = adapter.prepare_advanced_request(
                model=self,
                api_key=api_key,
                messages=messages,
                config=config,
                tools=tools_dict,
                tool_choice=tool_choice,
            )

            http_response = await http_client.post(
                request_data.url,
                headers=request_data.headers,
                json=request_data.body,
            )

            if http_response.status_code != 200:
                error_text = http_response.text
                logger.error(
                    f"HTTP请求失败: {http_response.status_code} - {error_text}"
                )

                await self.key_store.record_failure(api_key, http_response.status_code)

                if http_response.status_code in [401, 403]:
                    error_code = LLMErrorCode.API_KEY_INVALID
                elif http_response.status_code == 429:
                    error_code = LLMErrorCode.API_RATE_LIMITED
                elif http_response.status_code in [402, 413]:
                    error_code = LLMErrorCode.API_QUOTA_EXCEEDED
                else:
                    error_code = LLMErrorCode.API_REQUEST_FAILED

                raise LLMException(
                    f"HTTP请求失败: {http_response.status_code}",
                    code=error_code,
                    details={
                        "status_code": http_response.status_code,
                        "response": error_text,
                        "api_key": api_key,
                    },
                )

            try:
                response_json = http_response.json()
                response_data = adapter.parse_response(
                    model=self,
                    response_json=response_json,
                    is_advanced=True,
                )

                from .types.models import LLMToolCall

                response_tool_calls = []
                if response_data.tool_calls:
                    for tc_data in response_data.tool_calls:
                        if isinstance(tc_data, LLMToolCall):
                            response_tool_calls.append(tc_data)
                        elif isinstance(tc_data, dict):
                            try:
                                response_tool_calls.append(LLMToolCall(**tc_data))
                            except Exception as e:
                                logger.warning(
                                    f"无法将工具调用数据转换为LLMToolCall: {tc_data}, "
                                    f"error: {e}"
                                )
                        else:
                            logger.warning(f"工具调用数据格式未知: {tc_data}")

                llm_response = LLMResponse(
                    text=response_data.text,
                    usage_info=response_data.usage_info,
                    raw_response=response_data.raw_response,
                    tool_calls=response_tool_calls if response_tool_calls else None,
                    code_executions=response_data.code_executions,
                    grounding_metadata=response_data.grounding_metadata,
                    cache_info=response_data.cache_info,
                )

            except Exception as e:
                logger.error(f"解析响应失败: {e}", e=e)
                await self.key_store.record_failure(api_key, None)

                if isinstance(e, LLMException):
                    raise
                else:
                    raise LLMException(
                        f"解析API响应失败: {e}",
                        code=LLMErrorCode.RESPONSE_PARSE_ERROR,
                        cause=e,
                    )

            await self.key_store.record_success(api_key)

            return llm_response

        except LLMException:
            raise
        except Exception as e:
            logger.error(f"生成响应时发生未预期错误: {e}", e=e)
            await self.key_store.record_failure(api_key, None)

            raise LLMException(
                f"生成响应失败: {e}",
                code=LLMErrorCode.GENERATION_FAILED,
                cause=e,
            )

    async def close(self):
        """
        标记模型实例的当前使用周期结束。
        共享的 HTTP 客户端由 LLMHttpClientManager 管理,不由 LLMModel 关闭。
        """
        if self._is_closed:
            return
        self._is_closed = True
        logger.debug(
            f"LLMModel实例的使用周期已结束: {self} (共享HTTP客户端状态不受影响)"
        )

    async def __aenter__(self):
        if self._is_closed:
            logger.debug(
                f"Re-entering context for closed LLMModel {self}. "
                f"Resetting _is_closed to False."
            )
            self._is_closed = False
        self._check_not_closed()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """异步上下文管理器出口"""
        _ = exc_type, exc_val, exc_tb
        await self.close()

    def _check_not_closed(self):
        """检查实例是否已关闭"""
        if self._is_closed:
            raise RuntimeError(f"LLMModel实例已关闭: {self}")

    async def generate_text(
        self,
        prompt: str,
        history: list[dict[str, str]] | None = None,
        **kwargs: Any,
    ) -> str:
        """生成文本 - 通过 generate_response 实现"""
        self._check_not_closed()

        messages: list[LLMMessage] = []

        if history:
            for msg in history:
                role = msg.get("role", "user")
                content_text = msg.get("content", "")
                messages.append(LLMMessage(role=role, content=content_text))

        messages.append(LLMMessage.user(prompt))

        model_fields = getattr(LLMGenerationConfig, "model_fields", {})
        request_specific_config_dict = {
            k: v for k, v in kwargs.items() if k in model_fields
        }
        request_specific_config = None
        if request_specific_config_dict:
            request_specific_config = LLMGenerationConfig(
                **request_specific_config_dict
            )

        for key in request_specific_config_dict:
            kwargs.pop(key, None)

        response = await self.generate_response(
            messages,
            config=request_specific_config,
            **kwargs,
        )
        return response.text
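
A sketch of the plain-dict history format accepted above, assuming `model` is an active LLMModel instance:

history = [
    {"role": "user", "content": "我叫小真"},
    {"role": "assistant", "content": "你好,小真!"},
]
answer = await model.generate_text("我刚才说我叫什么?", history=history)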

    async def generate_response(
        self,
        messages: list[LLMMessage],
        config: LLMGenerationConfig | None = None,
        tools: list[LLMTool] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tool_executor: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None,
        max_tool_iterations: int = 5,
        **kwargs: Any,
    ) -> LLMResponse:
        """生成高级响应 - 实现完整的工具调用循环"""
        self._check_not_closed()

        from .adapters import get_adapter_for_api_type
        from .config.generation import create_generation_config_from_kwargs

        adapter = get_adapter_for_api_type(self.api_type)
        if not adapter:
            raise LLMException(
                f"未找到适用于 API 类型 '{self.api_type}' 的适配器",
                code=LLMErrorCode.CONFIGURATION_ERROR,
            )

        final_request_config = self._generation_config or LLMGenerationConfig()
        if kwargs:
            kwargs_config = create_generation_config_from_kwargs(**kwargs)
            merged_dict = final_request_config.to_dict()
            merged_dict.update(kwargs_config.to_dict())
            final_request_config = LLMGenerationConfig(**merged_dict)

        if config is not None:
            merged_dict = final_request_config.to_dict()
            merged_dict.update(config.to_dict())
            final_request_config = LLMGenerationConfig(**merged_dict)

        tools_dict: list[dict[str, Any]] | None = None
        if tools:
            tools_dict = []
            for tool in tools:
                if hasattr(tool, "model_dump"):
                    model_dump_func = getattr(tool, "model_dump")
                    tools_dict.append(model_dump_func(exclude_none=True))
                elif isinstance(tool, dict):
                    tools_dict.append(tool)
                else:
                    try:
                        tools_dict.append(dict(tool))
                    except (TypeError, ValueError):
                        logger.warning(f"工具 '{tool}' 无法转换为字典,已忽略。")

        http_client = await self._get_http_client()
        current_messages = list(messages)

        for iteration in range(max_tool_iterations):
            logger.debug(f"工具调用循环迭代: {iteration + 1}/{max_tool_iterations}")

            llm_response = await self._execute_with_smart_retry(
                adapter,
                current_messages,
                final_request_config,
                tools_dict if iteration == 0 else None,
                tool_choice if iteration == 0 else None,
                http_client,
            )

            response_tool_calls = llm_response.tool_calls or []

            if not response_tool_calls or not tool_executor:
                logger.debug("模型未请求工具调用,或未提供工具执行器。返回当前响应。")
                return llm_response

            logger.info(f"模型请求执行 {len(response_tool_calls)} 个工具。")

            assistant_message_content = llm_response.text if llm_response.text else ""
            current_messages.append(
                LLMMessage.assistant_tool_calls(
                    content=assistant_message_content, tool_calls=response_tool_calls
                )
            )

            tool_response_messages: list[LLMMessage] = []
            for tool_call in response_tool_calls:
                tool_name = tool_call.function.name
                try:
                    tool_args_dict = json.loads(tool_call.function.arguments)
                    logger.debug(f"执行工具: {tool_name},参数: {tool_args_dict}")

                    tool_result = await tool_executor(tool_name, tool_args_dict)
                    logger.debug(
                        f"工具 '{tool_name}' 执行结果: {str(tool_result)[:200]}..."
                    )

                    tool_response_messages.append(
                        LLMMessage.tool_response(
                            tool_call_id=tool_call.id,
                            function_name=tool_name,
                            result=tool_result,
                        )
                    )
                except json.JSONDecodeError as e:
                    logger.error(
                        f"工具 '{tool_name}' 参数JSON解析失败: "
                        f"{tool_call.function.arguments}, 错误: {e}"
                    )
                    tool_response_messages.append(
                        LLMMessage.tool_response(
                            tool_call_id=tool_call.id,
                            function_name=tool_name,
                            result={
                                "error": "Argument JSON parsing failed",
                                "details": str(e),
                            },
                        )
                    )
                except Exception as e:
                    logger.error(f"执行工具 '{tool_name}' 失败: {e}", e=e)
                    tool_response_messages.append(
                        LLMMessage.tool_response(
                            tool_call_id=tool_call.id,
                            function_name=tool_name,
                            result={
                                "error": "Tool execution failed",
                                "details": str(e),
                            },
                        )
                    )

            current_messages.extend(tool_response_messages)

        logger.warning(f"已达到最大工具调用迭代次数 ({max_tool_iterations})。")
        raise LLMException(
            "已达到最大工具调用迭代次数,但模型仍在请求工具调用或未提供最终文本回复。",
            code=LLMErrorCode.GENERATION_FAILED,
            details={
                "iterations": max_tool_iterations,
                "last_messages": current_messages[-2:],
            },
        )
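
A sketch of a caller-supplied tool_executor for the loop above; the tool name and the `my_tools` list are hypothetical, and any exception raised inside the executor is fed back to the model as an error result rather than aborting the loop:

from datetime import datetime

async def tool_executor(name: str, args: dict) -> dict:
    # 假设的本地工具表;真实实现由调用方提供
    if name == "get_time":
        return {"now": datetime.now().isoformat()}
    return {"error": f"unknown tool: {name}"}

response = await model.generate_response(
    [LLMMessage.user("现在几点?")],
    tools=my_tools,              # hypothetical list[LLMTool]
    tool_executor=tool_executor,
)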
|
||||
|
||||
async def generate_embeddings(
|
||||
self,
|
||||
texts: list[str],
|
||||
task_type: EmbeddingTaskType | str = EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
"""生成文本嵌入向量"""
|
||||
self._check_not_closed()
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
from .adapters import get_adapter_for_api_type
|
||||
|
||||
adapter = get_adapter_for_api_type(self.api_type)
|
||||
if not adapter:
|
||||
raise LLMException(
|
||||
f"未找到适用于 API 类型 '{self.api_type}' 的嵌入适配器",
|
||||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||
)
|
||||
|
||||
http_client = await self._get_http_client()
|
||||
|
||||
ai_config = get_ai_config()
|
||||
default_max_retries = ai_config.get("max_retries_llm", 3)
|
||||
default_retry_delay = ai_config.get("retry_delay_llm", 2)
|
||||
max_retries_embed = kwargs.get(
|
||||
"max_retries_embed", max(1, default_max_retries // 2)
|
||||
)
|
||||
retry_delay_embed = kwargs.get("retry_delay_embed", default_retry_delay / 2)
|
||||
|
||||
retry_config = RetryConfig(
|
||||
max_retries=max_retries_embed,
|
||||
retry_delay=retry_delay_embed,
|
||||
exponential_backoff=True,
|
||||
key_rotation=True,
|
||||
)
|
||||
|
||||
return await with_smart_retry(
|
||||
self._execute_embedding_request,
|
||||
adapter,
|
||||
texts,
|
||||
task_type,
|
||||
http_client,
|
||||
retry_config=retry_config,
|
||||
key_store=self.key_store,
|
||||
provider_name=self.provider_name,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
status = "closed" if self._is_closed else "active"
|
||||
return f"LLMModel({self.provider_name}/{self.model_name}, {status})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
status = "closed" if self._is_closed else "active"
|
||||
return (
|
||||
f"LLMModel(provider={self.provider_name}, model={self.model_name}, "
|
||||
f"api_type={self.api_type}, status={status})"
|
||||
)
|
||||
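Usage note (editor's sketch, not part of the diff): the tool-calling loop above drives any async `tool_executor(name, args)` callable whose return value is JSON-serializable; `my_tool_executor` and the `echo` tool below are hypothetical.

from typing import Any


async def my_tool_executor(tool_name: str, tool_args: dict[str, Any]) -> Any:
    # Must return something JSON-serializable; the loop wraps it into a
    # tool-role message via LLMMessage.tool_response().
    if tool_name == "echo":
        return {"echoed": tool_args}
    # Report unknown tools back to the model instead of raising, mirroring how
    # the loop converts executor failures into {"error": ...} results.
    return {"error": f"unknown tool: {tool_name}"}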
54 zhenxun/services/llm/types/__init__.py Normal file
@ -0,0 +1,54 @@
"""
|
||||
LLM 类型定义模块
|
||||
|
||||
统一导出所有核心类型、协议和异常定义。
|
||||
"""
|
||||
|
||||
from .content import (
|
||||
LLMContentPart,
|
||||
LLMMessage,
|
||||
LLMResponse,
|
||||
)
|
||||
from .enums import EmbeddingTaskType, ModelProvider, ResponseFormat, ToolCategory
|
||||
from .exceptions import LLMErrorCode, LLMException, get_user_friendly_error_message
|
||||
from .models import (
|
||||
LLMCacheInfo,
|
||||
LLMCodeExecution,
|
||||
LLMGroundingAttribution,
|
||||
LLMGroundingMetadata,
|
||||
LLMTool,
|
||||
LLMToolCall,
|
||||
LLMToolFunction,
|
||||
ModelDetail,
|
||||
ModelInfo,
|
||||
ModelName,
|
||||
ProviderConfig,
|
||||
ToolMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingTaskType",
|
||||
"LLMCacheInfo",
|
||||
"LLMCodeExecution",
|
||||
"LLMContentPart",
|
||||
"LLMErrorCode",
|
||||
"LLMException",
|
||||
"LLMGroundingAttribution",
|
||||
"LLMGroundingMetadata",
|
||||
"LLMMessage",
|
||||
"LLMResponse",
|
||||
"LLMTool",
|
||||
"LLMToolCall",
|
||||
"LLMToolFunction",
|
||||
"ModelDetail",
|
||||
"ModelInfo",
|
||||
"ModelName",
|
||||
"ModelProvider",
|
||||
"ProviderConfig",
|
||||
"ResponseFormat",
|
||||
"ToolCategory",
|
||||
"ToolMetadata",
|
||||
"UsageInfo",
|
||||
"get_user_friendly_error_message",
|
||||
]
|
||||
428 zhenxun/services/llm/types/content.py Normal file
@ -0,0 +1,428 @@
"""
|
||||
LLM 内容类型定义
|
||||
|
||||
包含多模态内容部分、消息和响应的数据模型。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
|
||||
class LLMContentPart(BaseModel):
|
||||
"""LLM 消息内容部分 - 支持多模态内容"""
|
||||
|
||||
type: str
|
||||
text: str | None = None
|
||||
image_source: str | None = None
|
||||
audio_source: str | None = None
|
||||
video_source: str | None = None
|
||||
document_source: str | None = None
|
||||
file_uri: str | None = None
|
||||
file_source: str | None = None
|
||||
url: str | None = None
|
||||
mime_type: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
def model_post_init(self, /, __context: Any) -> None:
|
||||
"""验证内容部分的有效性"""
|
||||
_ = __context
|
||||
validation_rules = {
|
||||
"text": lambda: self.text,
|
||||
"image": lambda: self.image_source,
|
||||
"audio": lambda: self.audio_source,
|
||||
"video": lambda: self.video_source,
|
||||
"document": lambda: self.document_source,
|
||||
"file": lambda: self.file_uri or self.file_source,
|
||||
"url": lambda: self.url,
|
||||
}
|
||||
|
||||
if self.type in validation_rules:
|
||||
if not validation_rules[self.type]():
|
||||
raise ValueError(f"{self.type}类型的内容部分必须包含相应字段")
|
||||
|
||||
@classmethod
|
||||
def text_part(cls, text: str) -> "LLMContentPart":
|
||||
"""创建文本内容部分"""
|
||||
return cls(type="text", text=text)
|
||||
|
||||
@classmethod
|
||||
def image_url_part(cls, url: str) -> "LLMContentPart":
|
||||
"""创建图片URL内容部分"""
|
||||
return cls(type="image", image_source=url)
|
||||
|
||||
@classmethod
|
||||
def image_base64_part(
|
||||
cls, data: str, mime_type: str = "image/png"
|
||||
) -> "LLMContentPart":
|
||||
"""创建Base64图片内容部分"""
|
||||
data_url = f"data:{mime_type};base64,{data}"
|
||||
return cls(type="image", image_source=data_url)
|
||||
|
||||
@classmethod
|
||||
def audio_url_part(cls, url: str, mime_type: str = "audio/wav") -> "LLMContentPart":
|
||||
"""创建音频URL内容部分"""
|
||||
return cls(type="audio", audio_source=url, mime_type=mime_type)
|
||||
|
||||
@classmethod
|
||||
def video_url_part(cls, url: str, mime_type: str = "video/mp4") -> "LLMContentPart":
|
||||
"""创建视频URL内容部分"""
|
||||
return cls(type="video", video_source=url, mime_type=mime_type)
|
||||
|
||||
@classmethod
|
||||
def video_base64_part(
|
||||
cls, data: str, mime_type: str = "video/mp4"
|
||||
) -> "LLMContentPart":
|
||||
"""创建Base64视频内容部分"""
|
||||
data_url = f"data:{mime_type};base64,{data}"
|
||||
return cls(type="video", video_source=data_url, mime_type=mime_type)
|
||||
|
||||
@classmethod
|
||||
def audio_base64_part(
|
||||
cls, data: str, mime_type: str = "audio/wav"
|
||||
) -> "LLMContentPart":
|
||||
"""创建Base64音频内容部分"""
|
||||
data_url = f"data:{mime_type};base64,{data}"
|
||||
return cls(type="audio", audio_source=data_url, mime_type=mime_type)
|
||||
|
||||
@classmethod
|
||||
def file_uri_part(
|
||||
cls,
|
||||
file_uri: str,
|
||||
mime_type: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> "LLMContentPart":
|
||||
"""创建Gemini File API URI内容部分"""
|
||||
return cls(
|
||||
type="file",
|
||||
file_uri=file_uri,
|
||||
mime_type=mime_type,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_path(
|
||||
cls, path_like: str | Path, target_api: str | None = None
|
||||
) -> "LLMContentPart | None":
|
||||
"""
|
||||
从本地文件路径创建 LLMContentPart。
|
||||
自动检测MIME类型,并根据类型(如图片)可能加载为Base64。
|
||||
target_api 可以用于提示如何最好地准备数据(例如 'gemini' 可能偏好 base64)
|
||||
"""
|
||||
try:
|
||||
path = Path(path_like)
|
||||
if not path.exists() or not path.is_file():
|
||||
logger.warning(f"文件不存在或不是一个文件: {path}")
|
||||
return None
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(path.resolve().as_uri())
|
||||
|
||||
if not mime_type:
|
||||
logger.warning(
|
||||
f"无法猜测文件 {path.name} 的MIME类型,将尝试作为文本文件处理。"
|
||||
)
|
||||
try:
|
||||
async with aiofiles.open(path, encoding="utf-8") as f:
|
||||
text_content = await f.read()
|
||||
return cls.text_part(text_content)
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件 {path.name} 失败: {e}")
|
||||
return None
|
||||
|
||||
if mime_type.startswith("image/"):
|
||||
if target_api == "gemini" or not path.is_absolute():
|
||||
try:
|
||||
async with aiofiles.open(path, "rb") as f:
|
||||
img_bytes = await f.read()
|
||||
base64_data = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return cls.image_base64_part(
|
||||
data=base64_data, mime_type=mime_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"读取或编码图片文件 {path.name} 失败: {e}")
|
||||
return None
|
||||
else:
|
||||
logger.warning(
|
||||
f"为本地图片路径 {path.name} 生成 image_url_part。"
|
||||
"实际API可能不支持 file:// URI。考虑使用Base64或公网URL。"
|
||||
)
|
||||
return cls.image_url_part(url=path.resolve().as_uri())
|
||||
elif mime_type.startswith("audio/"):
|
||||
return cls.audio_url_part(
|
||||
url=path.resolve().as_uri(), mime_type=mime_type
|
||||
)
|
||||
elif mime_type.startswith("video/"):
|
||||
if target_api == "gemini":
|
||||
# 对于 Gemini API,将视频转换为 base64
|
||||
try:
|
||||
async with aiofiles.open(path, "rb") as f:
|
||||
video_bytes = await f.read()
|
||||
base64_data = base64.b64encode(video_bytes).decode("utf-8")
|
||||
return cls.video_base64_part(
|
||||
data=base64_data, mime_type=mime_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"读取或编码视频文件 {path.name} 失败: {e}")
|
||||
return None
|
||||
else:
|
||||
return cls.video_url_part(
|
||||
url=path.resolve().as_uri(), mime_type=mime_type
|
||||
)
|
||||
elif (
|
||||
mime_type.startswith("text/")
|
||||
or mime_type == "application/json"
|
||||
or mime_type == "application/xml"
|
||||
):
|
||||
try:
|
||||
async with aiofiles.open(path, encoding="utf-8") as f:
|
||||
text_content = await f.read()
|
||||
return cls.text_part(text_content)
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本类文件 {path.name} 失败: {e}")
|
||||
return None
|
||||
else:
|
||||
logger.info(
|
||||
f"文件 {path.name} (MIME: {mime_type}) 将作为通用文件URI处理。"
|
||||
)
|
||||
return cls.file_uri_part(
|
||||
file_uri=path.resolve().as_uri(),
|
||||
mime_type=mime_type,
|
||||
metadata={"name": path.name, "source": "local_path"},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从路径 {path_like} 创建LLMContentPart时出错: {e}")
|
||||
return None
|
||||
|
||||
def is_image_url(self) -> bool:
|
||||
"""检查图像源是否为URL"""
|
||||
if not self.image_source:
|
||||
return False
|
||||
return self.image_source.startswith(("http://", "https://"))
|
||||
|
||||
def is_image_base64(self) -> bool:
|
||||
"""检查图像源是否为Base64 Data URL"""
|
||||
if not self.image_source:
|
||||
return False
|
||||
return self.image_source.startswith("data:")
|
||||
|
||||
def get_base64_data(self) -> tuple[str, str] | None:
|
||||
"""从Data URL中提取Base64数据和MIME类型"""
|
||||
if not self.is_image_base64() or not self.image_source:
|
||||
return None
|
||||
|
||||
try:
|
||||
header, data = self.image_source.split(",", 1)
|
||||
mime_part = header.split(";")[0].replace("data:", "")
|
||||
return mime_part, data
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"无法解析Base64图像数据: {self.image_source[:50]}...")
|
||||
return None
|
||||
|
||||
def convert_for_api(self, api_type: str) -> dict[str, Any]:
|
||||
"""根据API类型转换多模态内容格式"""
|
||||
if self.type == "text":
|
||||
if api_type == "openai":
|
||||
return {"type": "text", "text": self.text}
|
||||
elif api_type == "gemini":
|
||||
return {"text": self.text}
|
||||
else:
|
||||
return {"type": "text", "text": self.text}
|
||||
|
||||
elif self.type == "image":
|
||||
if not self.image_source:
|
||||
raise ValueError("图像类型的内容必须包含image_source")
|
||||
|
||||
if api_type == "openai":
|
||||
return {"type": "image_url", "image_url": {"url": self.image_source}}
|
||||
elif api_type == "gemini":
|
||||
if self.is_image_base64():
|
||||
base64_info = self.get_base64_data()
|
||||
if base64_info:
|
||||
mime_type, data = base64_info
|
||||
return {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||
else:
|
||||
# 如果无法解析 Base64 数据,抛出异常
|
||||
raise ValueError(
|
||||
f"无法解析Base64图像数据: {self.image_source[:50]}..."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Gemini API需要Base64格式,但提供的是URL: {self.image_source}"
|
||||
)
|
||||
return {
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": self.image_source,
|
||||
}
|
||||
}
|
||||
else:
|
||||
return {"type": "image_url", "image_url": {"url": self.image_source}}
|
||||
|
||||
elif self.type == "video":
|
||||
if not self.video_source:
|
||||
raise ValueError("视频类型的内容必须包含video_source")
|
||||
|
||||
if api_type == "gemini":
|
||||
# Gemini 支持视频,但需要通过 File API 上传
|
||||
if self.video_source.startswith("data:"):
|
||||
# 处理 base64 视频数据
|
||||
try:
|
||||
header, data = self.video_source.split(",", 1)
|
||||
mime_type = header.split(";")[0].replace("data:", "")
|
||||
return {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||
except (ValueError, IndexError):
|
||||
raise ValueError(
|
||||
f"无法解析Base64视频数据: {self.video_source[:50]}..."
|
||||
)
|
||||
else:
|
||||
# 对于 URL 或其他格式,暂时不支持直接内联
|
||||
raise ValueError(
|
||||
"Gemini API 的视频处理需要通过 File API 上传,不支持直接 URL"
|
||||
)
|
||||
else:
|
||||
# 其他 API 可能不支持视频
|
||||
raise ValueError(f"API类型 '{api_type}' 不支持视频内容")
|
||||
|
||||
elif self.type == "audio":
|
||||
if not self.audio_source:
|
||||
raise ValueError("音频类型的内容必须包含audio_source")
|
||||
|
||||
if api_type == "gemini":
|
||||
# Gemini 支持音频,处理方式类似视频
|
||||
if self.audio_source.startswith("data:"):
|
||||
try:
|
||||
header, data = self.audio_source.split(",", 1)
|
||||
mime_type = header.split(";")[0].replace("data:", "")
|
||||
return {"inlineData": {"mimeType": mime_type, "data": data}}
|
||||
except (ValueError, IndexError):
|
||||
raise ValueError(
|
||||
f"无法解析Base64音频数据: {self.audio_source[:50]}..."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Gemini API 的音频处理需要通过 File API 上传,不支持直接 URL"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"API类型 '{api_type}' 不支持音频内容")
|
||||
|
||||
elif self.type == "file":
|
||||
if api_type == "gemini" and self.file_uri:
|
||||
return {
|
||||
"fileData": {"mimeType": self.mime_type, "fileUri": self.file_uri}
|
||||
}
|
||||
elif self.file_source:
|
||||
file_name = (
|
||||
self.metadata.get("name", "file") if self.metadata else "file"
|
||||
)
|
||||
if api_type == "gemini":
|
||||
return {"text": f"[文件: {file_name}]\n{self.file_source}"}
|
||||
else:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[文件: {file_name}]\n{self.file_source}",
|
||||
}
|
||||
else:
|
||||
raise ValueError("文件类型的内容必须包含file_uri或file_source")
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的内容类型: {self.type}")
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
"""LLM 消息"""
|
||||
|
||||
role: str
|
||||
content: str | list[LLMContentPart]
|
||||
name: str | None = None
|
||||
tool_calls: list[Any] | None = None
|
||||
tool_call_id: str | None = None
|
||||
|
||||
def model_post_init(self, /, __context: Any) -> None:
|
||||
"""验证消息的有效性"""
|
||||
_ = __context
|
||||
if self.role == "tool":
|
||||
if not self.tool_call_id:
|
||||
raise ValueError("工具角色的消息必须包含 tool_call_id")
|
||||
if not self.name:
|
||||
raise ValueError("工具角色的消息必须包含函数名 (在 name 字段中)")
|
||||
if self.role == "tool" and not isinstance(self.content, str):
|
||||
logger.warning(
|
||||
f"工具角色消息的内容期望是字符串,但得到的是: {type(self.content)}. "
|
||||
"将尝试转换为字符串。"
|
||||
)
|
||||
try:
|
||||
self.content = str(self.content)
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法将工具角色的内容转换为字符串: {e}")
|
||||
|
||||
@classmethod
|
||||
def user(cls, content: str | list[LLMContentPart]) -> "LLMMessage":
|
||||
"""创建用户消息"""
|
||||
return cls(role="user", content=content)
|
||||
|
||||
@classmethod
|
||||
def assistant_tool_calls(
|
||||
cls,
|
||||
tool_calls: list[Any],
|
||||
content: str | list[LLMContentPart] = "",
|
||||
) -> "LLMMessage":
|
||||
"""创建助手请求工具调用的消息"""
|
||||
return cls(role="assistant", content=content, tool_calls=tool_calls)
|
||||
|
||||
@classmethod
|
||||
def assistant_text_response(
|
||||
cls, content: str | list[LLMContentPart]
|
||||
) -> "LLMMessage":
|
||||
"""创建助手纯文本回复的消息"""
|
||||
return cls(role="assistant", content=content, tool_calls=None)
|
||||
|
||||
@classmethod
|
||||
def tool_response(
|
||||
cls,
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
result: Any,
|
||||
) -> "LLMMessage":
|
||||
"""创建工具执行结果的消息"""
|
||||
import json
|
||||
|
||||
try:
|
||||
content_str = json.dumps(result)
|
||||
except TypeError as e:
|
||||
logger.error(
|
||||
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
|
||||
)
|
||||
content_str = json.dumps(
|
||||
{"error": "Tool result not JSON serializable", "details": str(e)}
|
||||
)
|
||||
|
||||
return cls(
|
||||
role="tool",
|
||||
content=content_str,
|
||||
tool_call_id=tool_call_id,
|
||||
name=function_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def system(cls, content: str) -> "LLMMessage":
|
||||
"""创建系统消息"""
|
||||
return cls(role="system", content=content)
|
||||
|
||||
|
||||
class LLMResponse(BaseModel):
|
||||
"""LLM 响应"""
|
||||
|
||||
text: str
|
||||
usage_info: dict[str, Any] | None = None
|
||||
raw_response: dict[str, Any] | None = None
|
||||
tool_calls: list[Any] | None = None
|
||||
code_executions: list[Any] | None = None
|
||||
grounding_metadata: Any | None = None
|
||||
cache_info: Any | None = None
|
||||
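Usage note (editor's sketch, not part of the diff): the content types above compose like this; the base64 string is a truncated placeholder.

from zhenxun.services.llm.types import LLMContentPart, LLMMessage

parts = [
    LLMContentPart.text_part("请描述这张图片"),
    LLMContentPart.image_base64_part(data="iVBORw0KGgo...", mime_type="image/png"),
]
message = LLMMessage.user(parts)

# A base64 image converts to Gemini's inlineData wire format:
payload = parts[1].convert_for_api("gemini")
# -> {"inlineData": {"mimeType": "image/png", "data": "iVBORw0KGgo..."}}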
67 zhenxun/services/llm/types/enums.py Normal file
@ -0,0 +1,67 @@
"""
|
||||
LLM 枚举类型定义
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class ModelProvider(Enum):
|
||||
"""模型提供商枚举"""
|
||||
|
||||
OPENAI = "openai"
|
||||
GEMINI = "gemini"
|
||||
ZHIXPU = "zhipu"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ResponseFormat(Enum):
|
||||
"""响应格式枚举"""
|
||||
|
||||
TEXT = "text"
|
||||
JSON = "json"
|
||||
MULTIMODAL = "multimodal"
|
||||
|
||||
|
||||
class EmbeddingTaskType(str, Enum):
|
||||
"""文本嵌入任务类型 (主要用于Gemini)"""
|
||||
|
||||
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
|
||||
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
|
||||
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
|
||||
CLASSIFICATION = "CLASSIFICATION"
|
||||
CLUSTERING = "CLUSTERING"
|
||||
QUESTION_ANSWERING = "QUESTION_ANSWERING"
|
||||
FACT_VERIFICATION = "FACT_VERIFICATION"
|
||||
|
||||
|
||||
class ToolCategory(Enum):
|
||||
"""工具分类枚举"""
|
||||
|
||||
FILE_SYSTEM = auto()
|
||||
NETWORK = auto()
|
||||
SYSTEM_INFO = auto()
|
||||
CALCULATION = auto()
|
||||
DATA_PROCESSING = auto()
|
||||
CUSTOM = auto()
|
||||
|
||||
|
||||
class LLMErrorCode(Enum):
|
||||
"""LLM 服务相关的错误代码枚举"""
|
||||
|
||||
MODEL_INIT_FAILED = 2000
|
||||
MODEL_NOT_FOUND = 2001
|
||||
API_REQUEST_FAILED = 2002
|
||||
API_RESPONSE_INVALID = 2003
|
||||
API_KEY_INVALID = 2004
|
||||
API_QUOTA_EXCEEDED = 2005
|
||||
API_TIMEOUT = 2006
|
||||
API_RATE_LIMITED = 2007
|
||||
NO_AVAILABLE_KEYS = 2008
|
||||
UNKNOWN_API_TYPE = 2009
|
||||
CONFIGURATION_ERROR = 2010
|
||||
RESPONSE_PARSE_ERROR = 2011
|
||||
CONTEXT_LENGTH_EXCEEDED = 2012
|
||||
CONTENT_FILTERED = 2013
|
||||
USER_LOCATION_NOT_SUPPORTED = 2014
|
||||
GENERATION_FAILED = 2015
|
||||
EMBEDDING_FAILED = 2016
|
||||
80
zhenxun/services/llm/types/exceptions.py
Normal file
80
zhenxun/services/llm/types/exceptions.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
LLM 异常类型定义
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .enums import LLMErrorCode
|
||||
|
||||
|
||||
class LLMException(Exception):
|
||||
"""LLM 服务相关的基础异常类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: LLMErrorCode = LLMErrorCode.API_REQUEST_FAILED,
|
||||
details: dict[str, Any] | None = None,
|
||||
recoverable: bool = True,
|
||||
cause: Exception | None = None,
|
||||
):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.details = details or {}
|
||||
self.recoverable = recoverable
|
||||
self.cause = cause
|
||||
super().__init__(message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.details:
|
||||
return f"{self.message} (错误码: {self.code.name}, 详情: {self.details})"
|
||||
return f"{self.message} (错误码: {self.code.name})"
|
||||
|
||||
@property
|
||||
def user_friendly_message(self) -> str:
|
||||
"""返回适合向用户展示的错误消息"""
|
||||
error_messages = {
|
||||
LLMErrorCode.MODEL_NOT_FOUND: "AI模型未找到,请检查配置或联系管理员。",
|
||||
LLMErrorCode.API_KEY_INVALID: "API密钥无效,请联系管理员更新配置。",
|
||||
LLMErrorCode.API_QUOTA_EXCEEDED: (
|
||||
"API使用配额已用尽,请稍后再试或联系管理员。"
|
||||
),
|
||||
LLMErrorCode.API_TIMEOUT: "AI服务响应超时,请稍后再试。",
|
||||
LLMErrorCode.API_RATE_LIMITED: "请求过于频繁,已被AI服务限流,请稍后再试。",
|
||||
LLMErrorCode.MODEL_INIT_FAILED: "AI模型初始化失败,请联系管理员检查配置。",
|
||||
LLMErrorCode.NO_AVAILABLE_KEYS: (
|
||||
"当前所有API密钥均不可用,请稍后再试或联系管理员。"
|
||||
),
|
||||
LLMErrorCode.USER_LOCATION_NOT_SUPPORTED: (
|
||||
"当前地区暂不支持此AI服务,请联系管理员或尝试其他模型。"
|
||||
),
|
||||
LLMErrorCode.API_REQUEST_FAILED: "AI服务请求失败,请稍后再试。",
|
||||
LLMErrorCode.API_RESPONSE_INVALID: "AI服务响应异常,请稍后再试。",
|
||||
LLMErrorCode.CONFIGURATION_ERROR: "AI服务配置错误,请联系管理员。",
|
||||
LLMErrorCode.CONTEXT_LENGTH_EXCEEDED: "输入内容过长,请缩短后重试。",
|
||||
LLMErrorCode.CONTENT_FILTERED: "内容被安全过滤,请修改后重试。",
|
||||
LLMErrorCode.RESPONSE_PARSE_ERROR: "AI服务响应解析失败,请稍后再试。",
|
||||
LLMErrorCode.UNKNOWN_API_TYPE: "不支持的AI服务类型,请联系管理员。",
|
||||
}
|
||||
return error_messages.get(self.code, "AI服务暂时不可用,请稍后再试。")
|
||||
|
||||
|
||||
def get_user_friendly_error_message(error: Exception) -> str:
|
||||
"""将任何异常转换为用户友好的错误消息"""
|
||||
if isinstance(error, LLMException):
|
||||
return error.user_friendly_message
|
||||
|
||||
error_str = str(error).lower()
|
||||
|
||||
if "timeout" in error_str or "超时" in error_str:
|
||||
return "请求超时,请稍后再试。"
|
||||
elif "connection" in error_str or "连接" in error_str:
|
||||
return "网络连接失败,请检查网络后重试。"
|
||||
elif "permission" in error_str or "权限" in error_str:
|
||||
return "权限不足,请联系管理员。"
|
||||
elif "not found" in error_str or "未找到" in error_str:
|
||||
return "请求的资源未找到,请检查配置。"
|
||||
elif "invalid" in error_str or "无效" in error_str:
|
||||
return "请求参数无效,请检查输入。"
|
||||
else:
|
||||
return "服务暂时不可用,请稍后再试。"
|
||||
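Usage note (editor's sketch, not part of the diff): a typical plugin-boundary pattern with these exceptions, so users see the friendly text instead of the raw error.

from zhenxun.services.llm.types import (
    LLMErrorCode,
    LLMException,
    get_user_friendly_error_message,
)

try:
    raise LLMException("provider returned 429", code=LLMErrorCode.API_RATE_LIMITED)
except Exception as e:
    reply = get_user_friendly_error_message(e)
    # reply == "请求过于频繁,已被AI服务限流,请稍后再试。"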
160 zhenxun/services/llm/types/models.py Normal file
@ -0,0 +1,160 @@
"""
|
||||
LLM 数据模型定义
|
||||
|
||||
包含模型信息、配置、工具定义和响应数据的模型类。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .enums import ModelProvider, ToolCategory
|
||||
|
||||
ModelName = str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInfo:
|
||||
"""模型信息(不可变数据类)"""
|
||||
|
||||
name: str
|
||||
provider: ModelProvider
|
||||
max_tokens: int = 4096
|
||||
supports_tools: bool = False
|
||||
supports_vision: bool = False
|
||||
supports_audio: bool = False
|
||||
cost_per_1k_tokens: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageInfo:
|
||||
"""使用信息数据类"""
|
||||
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
cost: float = 0.0
|
||||
|
||||
@property
|
||||
def efficiency_ratio(self) -> float:
|
||||
"""计算效率比(输出/输入)"""
|
||||
return self.completion_tokens / max(self.prompt_tokens, 1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""工具元数据"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
category: ToolCategory
|
||||
read_only: bool = True
|
||||
destructive: bool = False
|
||||
open_world: bool = False
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
required_params: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelDetail(BaseModel):
|
||||
"""模型详细信息"""
|
||||
|
||||
model_name: str
|
||||
is_available: bool = True
|
||||
is_embedding_model: bool = False
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""LLM 提供商配置"""
|
||||
|
||||
name: str = Field(..., description="Provider 的唯一名称标识")
|
||||
api_key: str | list[str] = Field(..., description="API Key 或 Key 列表")
|
||||
api_base: str | None = Field(None, description="API Base URL,如果为空则使用默认值")
|
||||
api_type: str = Field(default="openai", description="API 类型")
|
||||
openai_compat: bool = Field(default=False, description="是否使用 OpenAI 兼容模式")
|
||||
temperature: float | None = Field(default=0.7, description="默认温度参数")
|
||||
max_tokens: int | None = Field(default=None, description="默认最大输出 token 限制")
|
||||
models: list[ModelDetail] = Field(..., description="支持的模型列表")
|
||||
timeout: int = Field(default=180, description="请求超时时间")
|
||||
proxy: str | None = Field(default=None, description="代理设置")
|
||||
|
||||
|
||||
class LLMToolFunction(BaseModel):
|
||||
"""LLM 工具函数定义"""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class LLMToolCall(BaseModel):
|
||||
"""LLM 工具调用"""
|
||||
|
||||
id: str
|
||||
function: LLMToolFunction
|
||||
|
||||
|
||||
class LLMTool(BaseModel):
|
||||
"""LLM 工具定义(支持 MCP 风格)"""
|
||||
|
||||
type: str = "function"
|
||||
function: dict[str, Any]
|
||||
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, Any],
|
||||
required: list[str] | None = None,
|
||||
annotations: dict[str, Any] | None = None,
|
||||
) -> "LLMTool":
|
||||
"""创建工具"""
|
||||
function_def = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required or [],
|
||||
},
|
||||
}
|
||||
return cls(type="function", function=function_def, annotations=annotations)
|
||||
|
||||
|
||||
class LLMCodeExecution(BaseModel):
|
||||
"""代码执行结果"""
|
||||
|
||||
code: str
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
execution_time: float | None = None
|
||||
files_generated: list[str] | None = None
|
||||
|
||||
|
||||
class LLMGroundingAttribution(BaseModel):
|
||||
"""信息来源关联"""
|
||||
|
||||
title: str | None = None
|
||||
uri: str | None = None
|
||||
snippet: str | None = None
|
||||
confidence_score: float | None = None
|
||||
|
||||
|
||||
class LLMGroundingMetadata(BaseModel):
|
||||
"""信息来源关联元数据"""
|
||||
|
||||
web_search_queries: list[str] | None = None
|
||||
grounding_attributions: list[LLMGroundingAttribution] | None = None
|
||||
search_suggestions: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class LLMCacheInfo(BaseModel):
|
||||
"""缓存信息"""
|
||||
|
||||
cache_hit: bool = False
|
||||
cache_key: str | None = None
|
||||
cache_ttl: int | None = None
|
||||
created_at: str | None = None
|
||||
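Usage note (editor's sketch, not part of the diff): LLMTool.create wraps a JSON-Schema fragment into the OpenAI-style tools format; the weather tool below is hypothetical.

from zhenxun.services.llm.types import LLMTool

weather_tool = LLMTool.create(
    name="get_weather",
    description="查询指定城市的当前天气",
    parameters={"city": {"type": "string", "description": "城市名称"}},
    required=["city"],
)
# weather_tool.function["parameters"] is a standard JSON-Schema "object" block,
# ready for any adapter that speaks the OpenAI tools format.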
218 zhenxun/services/llm/utils.py Normal file
@ -0,0 +1,218 @@
"""
|
||||
LLM 模块的工具和转换函数
|
||||
"""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from nonebot_plugin_alconna.uniseg import (
|
||||
At,
|
||||
File,
|
||||
Image,
|
||||
Reply,
|
||||
Text,
|
||||
UniMessage,
|
||||
Video,
|
||||
Voice,
|
||||
)
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .types import LLMContentPart
|
||||
|
||||
|
||||
async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
|
||||
"""
|
||||
将 UniMessage 实例转换为一个 LLMContentPart 列表。
|
||||
这是处理多模态输入的核心转换逻辑。
|
||||
"""
|
||||
parts: list[LLMContentPart] = []
|
||||
for seg in message:
|
||||
part = None
|
||||
if isinstance(seg, Text):
|
||||
if seg.text.strip():
|
||||
part = LLMContentPart.text_part(seg.text)
|
||||
elif isinstance(seg, Image):
|
||||
if seg.path:
|
||||
part = await LLMContentPart.from_path(seg.path, target_api="gemini")
|
||||
elif seg.url:
|
||||
part = LLMContentPart.image_url_part(seg.url)
|
||||
elif hasattr(seg, "raw") and seg.raw:
|
||||
mime_type = (
|
||||
getattr(seg, "mimetype", "image/png")
|
||||
if hasattr(seg, "mimetype")
|
||||
else "image/png"
|
||||
)
|
||||
if isinstance(seg.raw, bytes):
|
||||
b64_data = base64.b64encode(seg.raw).decode("utf-8")
|
||||
part = LLMContentPart.image_base64_part(b64_data, mime_type)
|
||||
|
||||
elif isinstance(seg, File | Voice | Video):
|
||||
if seg.path:
|
||||
part = await LLMContentPart.from_path(seg.path)
|
||||
elif seg.url:
|
||||
logger.warning(
|
||||
f"直接使用 URL 的 {type(seg).__name__} 段,"
|
||||
f"API 可能不支持: {seg.url}"
|
||||
)
|
||||
part = LLMContentPart.text_part(
|
||||
f"[{type(seg).__name__.upper()} FILE: {seg.name or seg.url}]"
|
||||
)
|
||||
elif hasattr(seg, "raw") and seg.raw:
|
||||
mime_type = getattr(seg, "mimetype", None)
|
||||
if isinstance(seg.raw, bytes):
|
||||
b64_data = base64.b64encode(seg.raw).decode("utf-8")
|
||||
|
||||
if isinstance(seg, Video):
|
||||
if not mime_type:
|
||||
mime_type = "video/mp4"
|
||||
part = LLMContentPart.video_base64_part(
|
||||
data=b64_data, mime_type=mime_type
|
||||
)
|
||||
logger.debug(
|
||||
f"处理视频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
|
||||
)
|
||||
elif isinstance(seg, Voice):
|
||||
if not mime_type:
|
||||
mime_type = "audio/wav"
|
||||
part = LLMContentPart.audio_base64_part(
|
||||
data=b64_data, mime_type=mime_type
|
||||
)
|
||||
logger.debug(
|
||||
f"处理音频字节数据: {mime_type}, 大小: {len(seg.raw)} bytes"
|
||||
)
|
||||
else:
|
||||
part = LLMContentPart.text_part(
|
||||
f"[FILE: {mime_type or 'unknown'}, {len(seg.raw)} bytes]"
|
||||
)
|
||||
logger.debug(
|
||||
f"处理其他文件字节数据: {mime_type}, "
|
||||
f"大小: {len(seg.raw)} bytes"
|
||||
)
|
||||
|
||||
elif isinstance(seg, At):
|
||||
if seg.flag == "all":
|
||||
part = LLMContentPart.text_part("[Mentioned Everyone]")
|
||||
else:
|
||||
part = LLMContentPart.text_part(f"[Mentioned user: {seg.target}]")
|
||||
|
||||
elif isinstance(seg, Reply):
|
||||
if seg.msg:
|
||||
try:
|
||||
extract_method = getattr(seg.msg, "extract_plain_text", None)
|
||||
if extract_method and callable(extract_method):
|
||||
reply_text = str(extract_method()).strip()
|
||||
else:
|
||||
reply_text = str(seg.msg).strip()
|
||||
if reply_text:
|
||||
part = LLMContentPart.text_part(
|
||||
f'[Replied to: "{reply_text[:50]}..."]'
|
||||
)
|
||||
except Exception:
|
||||
part = LLMContentPart.text_part("[Replied to a message]")
|
||||
|
||||
if part:
|
||||
parts.append(part)
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def create_multimodal_message(
|
||||
text: str | None = None,
|
||||
images: list[str | Path | bytes] | str | Path | bytes | None = None,
|
||||
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
|
||||
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
|
||||
image_mimetypes: list[str] | str | None = None,
|
||||
video_mimetypes: list[str] | str | None = None,
|
||||
audio_mimetypes: list[str] | str | None = None,
|
||||
) -> UniMessage:
|
||||
"""
|
||||
创建多模态消息的便捷函数,方便第三方调用。
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
images: 图片数据,支持路径、字节数据或URL
|
||||
videos: 视频数据,支持路径、字节数据或URL
|
||||
audios: 音频数据,支持路径、字节数据或URL
|
||||
image_mimetypes: 图片MIME类型,当images为bytes时需要指定
|
||||
video_mimetypes: 视频MIME类型,当videos为bytes时需要指定
|
||||
audio_mimetypes: 音频MIME类型,当audios为bytes时需要指定
|
||||
|
||||
Returns:
|
||||
UniMessage: 构建好的多模态消息
|
||||
|
||||
Examples:
|
||||
# 纯文本
|
||||
msg = create_multimodal_message("请分析这段文字")
|
||||
|
||||
# 文本 + 单张图片(路径)
|
||||
msg = create_multimodal_message("分析图片", images="/path/to/image.jpg")
|
||||
|
||||
# 文本 + 多张图片
|
||||
msg = create_multimodal_message(
|
||||
"比较图片", images=["/path/1.jpg", "/path/2.jpg"]
|
||||
)
|
||||
|
||||
# 文本 + 图片字节数据
|
||||
msg = create_multimodal_message(
|
||||
"分析", images=image_data, image_mimetypes="image/jpeg"
|
||||
)
|
||||
|
||||
# 文本 + 视频
|
||||
msg = create_multimodal_message("分析视频", videos="/path/to/video.mp4")
|
||||
|
||||
# 文本 + 音频
|
||||
msg = create_multimodal_message("转录音频", audios="/path/to/audio.wav")
|
||||
|
||||
# 混合多模态
|
||||
msg = create_multimodal_message(
|
||||
"分析这些媒体文件",
|
||||
images="/path/to/image.jpg",
|
||||
videos="/path/to/video.mp4",
|
||||
audios="/path/to/audio.wav"
|
||||
)
|
||||
"""
|
||||
message = UniMessage()
|
||||
|
||||
if text:
|
||||
message.append(Text(text))
|
||||
|
||||
if images is not None:
|
||||
_add_media_to_message(message, images, image_mimetypes, Image, "image/png")
|
||||
|
||||
if videos is not None:
|
||||
_add_media_to_message(message, videos, video_mimetypes, Video, "video/mp4")
|
||||
|
||||
if audios is not None:
|
||||
_add_media_to_message(message, audios, audio_mimetypes, Voice, "audio/wav")
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def _add_media_to_message(
|
||||
message: UniMessage,
|
||||
media_items: list[str | Path | bytes] | str | Path | bytes,
|
||||
mimetypes: list[str] | str | None,
|
||||
media_class: type,
|
||||
default_mimetype: str,
|
||||
) -> None:
|
||||
"""添加媒体文件到 UniMessage 的辅助函数"""
|
||||
if not isinstance(media_items, list):
|
||||
media_items = [media_items]
|
||||
|
||||
mime_list = []
|
||||
if mimetypes is not None:
|
||||
if isinstance(mimetypes, str):
|
||||
mime_list = [mimetypes] * len(media_items)
|
||||
else:
|
||||
mime_list = list(mimetypes)
|
||||
|
||||
for i, item in enumerate(media_items):
|
||||
if isinstance(item, str | Path):
|
||||
if str(item).startswith(("http://", "https://")):
|
||||
message.append(media_class(url=str(item)))
|
||||
else:
|
||||
message.append(media_class(path=Path(item)))
|
||||
elif isinstance(item, bytes):
|
||||
mimetype = mime_list[i] if i < len(mime_list) else default_mimetype
|
||||
message.append(media_class(raw=item, mimetype=mimetype))
|
||||
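Usage note (editor's sketch, not part of the diff): the two helpers above cover both directions, parsing an inbound UniMessage and building an outbound one; the handler wiring and file paths below are assumptions.

from nonebot_plugin_alconna.uniseg import UniMessage

from zhenxun.services.llm.utils import create_multimodal_message, unimsg_to_llm_parts


async def handle(msg: UniMessage) -> None:
    # Inbound: flatten Text/Image/Voice/Video/At/Reply segments into LLM parts.
    parts = await unimsg_to_llm_parts(msg)
    # Outbound: build a request message from plain Python values.
    question = create_multimodal_message(text="总结这张图", images="/tmp/screenshot.png")
    _ = (parts, question)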