Compare commits

...

4 Commits

Author SHA1 Message Date
HibiKier
1641f95e74
Merge 1de1ded65c into 7c153721f0 2025-08-04 15:36:39 +00:00
Rumio
7c153721f0
♻️ refactor!: 重构LLM服务架构并统一Pydantic兼容性处理 (#2002)
* ♻️ refactor(pydantic): 提取 Pydantic 兼容函数到独立模块

* ♻️ refactor!(llm): 重构LLM服务,引入现代化工具和执行器架构

🏗️ **架构变更**
- 引入ToolProvider/ToolExecutable协议,取代ToolRegistry
- 新增LLMToolExecutor,分离工具调用逻辑
- 新增BaseMemory抽象,解耦会话状态管理

🔄 **API重构**
- 移除:analyze, analyze_multimodal, pipeline_chat
- 新增:generate_structured, run_with_tools
- 重构:chat, search, code变为无状态调用

🛠️ **工具系统**
- 新增@function_tool装饰器
- 统一工具定义到ToolExecutable协议
- 移除MCP工具系统和mcp_tools.json
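A rough sketch of how the reworked tool flow fits together, based on the `run_with_tools` signature and the `function_tool` export added elsewhere in this diff; the exact form of the `@function_tool` decorator is not shown here, so its bare usage below is an assumption:

```python
import asyncio

from zhenxun.services.llm import function_tool, run_with_tools


@function_tool  # assumed bare usage; the decorator's real signature is not shown in this diff
async def query_stock_price(stock_symbol: str) -> dict:
    """Stub tool: return a fake quote for the given stock symbol."""
    return {"symbol": stock_symbol, "price": 175.5, "currency": "USD"}


async def main() -> None:
    # run_with_tools is stateless: it resolves the named local function tools,
    # runs the bounded tool-call loop, and returns the final assistant reply.
    response = await run_with_tools(
        "帮我查一下 AAPL 的股价",
        tools=["query_stock_price"],
        max_cycles=5,
    )
    print(response.text)


asyncio.run(main())
```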

---------

Co-authored-by: webjoin111 <455457521@qq.com>
2025-08-04 23:36:12 +08:00
molanp
1de1ded65c
test(auto_update): 修改更新检测消息格式 (#2003)
- 移除了不必要的版本号后缀(如 "-e6f17c4")
- 统一了版本更新消息的格式,删除了冗余信息
2025-08-04 23:35:57 +08:00
HibiKier
70f363c0ce feat(update): 移除资源管理器,重构更新逻辑,支持通过ZhenxunRepoManager进行资源和Web UI的更新 2025-08-04 23:32:11 +08:00
42 changed files with 2059 additions and 3725 deletions

View File

@ -1,89 +0,0 @@
SUPERUSERS=[""]
COMMAND_START=[""]
SESSION_RUNNING_EXPRESSION="别急呀,小真寻要宕机了!QAQ"
NICKNAME=["真寻", "小真寻", "绪山真寻", "小寻子"]
SESSION_EXPIRE_TIMEOUT=00:00:30
ALCONNA_USE_COMMAND_START=True
# 全局图片统一使用bytes发送当真寻与协议端不在同一服务器上时为True
IMAGE_TO_BYTES = True
# 回复消息时自称
SELF_NICKNAME="小真寻"
# 官bot appid:bot账号
QBOT_ID_DATA = '{
}'
# 数据库配置
# 示例: "postgres://user:password@127.0.0.1:5432/database"
# 示例: "mysql://user:password@127.0.0.1:3306/database"
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
DB_URL = ""
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
CACHE_MODE = NONE
# REDIS配置使用REDIS替换Cache内存缓存
# REDIS地址
# REDIS_HOST = "127.0.0.1"
# REDIS端口
# REDIS_PORT = 6379
# REDIS密码
# REDIS_PASSWORD = ""
# REDIS过期时间
# REDIS_EXPIRE = 600
# 系统代理
# SYSTEM_PROXY = "http://127.0.0.1:7890"
PLATFORM_SUPERUSERS = '
{
"qq": [""],
"dodo": [""]
}
'
DRIVER=~fastapi+~httpx+~websockets
# LOG_LEVEL = DEBUG
# 服务器和端口
HOST = 127.0.0.1
PORT = 8080
# kook adapter token
# kaiheila_bots =[{"token": ""}]
# # discord adapter
# DISCORD_BOTS='
# [
# {
# "token": "",
# "intent": {
# "guild_messages": true,
# "direct_messages": true
# },
# "application_commands": {"*": ["*"]}
# }
# ]
# '
# DISCORD_PROXY=''
# # dodo adapter
# DODO_BOTS='
# [
# {
# "client_id": "",
# "token": ""
# }
# ]
# '
# application_commands的{"*": ["*"]}代表将全部应用命令注册为全局应用命令
# {"admin": ["123", "456"]}则代表将admin命令注册为id是123、456服务器的局部命令其余命令不注册

View File

@ -324,7 +324,7 @@ async def test_check_update_release(
ctx.should_call_api(
"send_msg",
_v11_private_message_send(
message="检测真寻已更新版本更新v0.2.2 -> v0.2.2\n开始更新...",
message="检测真寻已更新版本更新v0.2.2\n开始更新...",
user_id=UserId.SUPERUSER,
),
)
@ -420,8 +420,7 @@ async def test_check_update_main(
ctx.should_call_api(
"send_msg",
_v11_private_message_send(
message="检测真寻已更新版本更新v0.2.2 -> v0.2.2-e6f17c4\n"
"开始更新...",
message="检测真寻已更新版本更新v0.2.2\n开始更新...",
user_id=UserId.SUPERUSER,
),
)

View File

@ -17,7 +17,7 @@ from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.decorator.shop import shop_register
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.manager.resource_manager import ResourceManager
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
from zhenxun.utils.platform import PlatformUtils
driver: Driver = nonebot.get_driver()
@ -85,7 +85,8 @@ from bag_users t1
@PriorityLifecycle.on_startup(priority=5)
async def _():
await ResourceManager.init_resources()
if not ZhenxunRepoManager.check_resources_exists():
await ZhenxunRepoManager.resources_update(branch="test")
"""签到与用户的数据迁移"""
if goods_list := await GoodsInfo.filter(uuid__isnull=True).all():
for goods in goods_list:

View File

@ -16,10 +16,6 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.utils import PluginExtraData
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.manager.resource_manager import (
DownloadResourceException,
ResourceManager,
)
from zhenxun.utils.message import MessageUtils
from ._data_source import UpdateManager
@ -68,7 +64,6 @@ _matcher = on_alconna(
Option("-f|--force", action=store_true, help_text="强制更新"),
Option("-s", Args["source?", ["git", "ali"]], help_text="更新源"),
Option("-z|--zip", action=store_true, help_text="下载zip文件"),
Option("-t", Args["update_type?", ["git", "download"]], help_text="更新方式"),
),
priority=1,
block=True,
@ -86,39 +81,52 @@ async def _(
force: Query[bool] = Query("force", False),
source: Query[str] = Query("source", "ali"),
zip: Query[bool] = Query("zip", False),
update_type: Query[str] = Query("update_type", "git"),
):
result = ""
await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True)
if ver_type.result in {"main", "release"}:
ver_type_str = ver_type.result
source_str = source.result
if ver_type_str in {"main", "release"}:
if not ver_type.available:
result = await UpdateManager.check_version()
result += await UpdateManager.check_version()
logger.info("查看当前版本...", "检查更新", session=session)
await MessageUtils.build_message(result).finish()
try:
result = await UpdateManager.update(
result += await UpdateManager.update_zhenxun(
bot,
session.user.id,
ver_type.result,
ver_type_str, # type: ignore
force.result,
source.result,
source_str, # type: ignore
zip.result,
update_type.result,
)
await MessageUtils.build_message(result).finish(reply_to=True)
except Exception as e:
logger.error("版本更新失败...", "检查更新", session=session, e=e)
await MessageUtils.build_message(f"更新版本失败...e: {e}").finish()
elif ver_type.result == "webui":
result = await UpdateManager.update_webui(zip.result, source.result)
if zip.result:
source_str = None
try:
result += await UpdateManager.update_webui(
source_str, # type: ignore
"dist",
)
except Exception as e:
logger.error("WebUI更新失败...", "检查更新", session=session, e=e)
result += "\nWebUI更新错误..."
if resource.result or ver_type.result == "resource":
try:
await ResourceManager.init_resources(True, zip.result, source.result)
result += "\n资源文件更新成功!"
except DownloadResourceException:
result += "\n资源更新下载失败..."
if zip.result:
source_str = None
result += await UpdateManager.update_resources(
source_str, # type: ignore
"main",
force.result,
)
except Exception as e:
logger.error("资源更新下载失败...", "检查更新", session=session, e=e)
result += "\n资源更新未知错误..."
result += "\n资源更新错误..."
if result:
await MessageUtils.build_message(result.strip()).finish()
await MessageUtils.build_message("更新版本失败...").finish()

View File

@ -1,167 +1,16 @@
import os
from pathlib import Path
import shutil
import tarfile
import zipfile
from typing import Literal
from nonebot.adapters import Bot
from nonebot.utils import run_sync
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import RepoInfo
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
from zhenxun.utils.platform import PlatformUtils
from zhenxun.utils.repo_utils import AliyunRepoManager, GithubRepoManager
from .config import (
BACKUP_PATH,
BASE_PATH,
BASE_PATH_STRING,
COMMAND,
DEFAULT_GITHUB_URL,
DOWNLOAD_GZ_FILE,
DOWNLOAD_ZIP_FILE,
GIT_GITHUB_URL,
GIT_WEBUI_UI_URL,
PYPROJECT_FILE,
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE,
PYPROJECT_LOCK_FILE_STRING,
RELEASE_URL,
REPLACE_FOLDERS,
REQ_TXT_FILE,
REQ_TXT_FILE_STRING,
TMP_PATH,
VERSION_FILE,
)
@run_sync
def _file_handle(latest_version: str | None):
"""文件移动操作
参数:
latest_version: 版本号
"""
BACKUP_PATH.mkdir(exist_ok=True, parents=True)
logger.debug("开始解压文件压缩包...", COMMAND)
download_file = DOWNLOAD_GZ_FILE
if DOWNLOAD_GZ_FILE.exists():
tf = tarfile.open(DOWNLOAD_GZ_FILE)
else:
download_file = DOWNLOAD_ZIP_FILE
tf = zipfile.ZipFile(DOWNLOAD_ZIP_FILE)
tf.extractall(TMP_PATH)
logger.debug("解压文件压缩包完成...", COMMAND)
download_file_path = TMP_PATH / next(
x for x in os.listdir(TMP_PATH) if (TMP_PATH / x).is_dir()
)
_pyproject = download_file_path / PYPROJECT_FILE_STRING
_lock_file = download_file_path / PYPROJECT_LOCK_FILE_STRING
_req_file = download_file_path / REQ_TXT_FILE_STRING
extract_path = download_file_path / BASE_PATH_STRING
target_path = BASE_PATH
if PYPROJECT_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", COMMAND)
shutil.move(PYPROJECT_FILE, BACKUP_PATH / PYPROJECT_FILE_STRING)
if PYPROJECT_LOCK_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", COMMAND)
shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / PYPROJECT_LOCK_FILE_STRING)
if REQ_TXT_FILE.exists():
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", COMMAND)
shutil.move(REQ_TXT_FILE, BACKUP_PATH / REQ_TXT_FILE_STRING)
if _pyproject.exists():
logger.debug("移动文件: pyproject.toml", COMMAND)
shutil.move(_pyproject, PYPROJECT_FILE)
if _lock_file.exists():
logger.debug("移动文件: poetry.lock", COMMAND)
shutil.move(_lock_file, PYPROJECT_LOCK_FILE)
if _req_file.exists():
logger.debug("移动文件: requirements.txt", COMMAND)
shutil.move(_req_file, REQ_TXT_FILE)
for folder in REPLACE_FOLDERS:
"""移动指定文件夹"""
_dir = BASE_PATH / folder
_backup_dir = BACKUP_PATH / folder
if _backup_dir.exists():
logger.debug(f"删除备份文件夹 {_backup_dir}", COMMAND)
shutil.rmtree(_backup_dir)
if _dir.exists():
logger.debug(f"移动旧文件夹 {_dir}", COMMAND)
shutil.move(_dir, _backup_dir)
else:
logger.warning(f"文件夹 {_dir} 不存在,跳过删除", COMMAND)
for folder in REPLACE_FOLDERS:
src_folder_path = extract_path / folder
dest_folder_path = target_path / folder
if src_folder_path.exists():
logger.debug(
f"移动文件夹: {src_folder_path} -> {dest_folder_path}", COMMAND
)
shutil.move(src_folder_path, dest_folder_path)
else:
logger.debug(f"源文件夹不存在: {src_folder_path}", COMMAND)
if tf:
tf.close()
if download_file.exists():
logger.debug(f"删除下载文件: {download_file}", COMMAND)
download_file.unlink()
if extract_path.exists():
logger.debug(f"删除解压文件夹: {extract_path}", COMMAND)
shutil.rmtree(extract_path)
if TMP_PATH.exists():
shutil.rmtree(TMP_PATH)
if latest_version:
with open(VERSION_FILE, "w", encoding="utf8") as f:
f.write(f"__version__: {latest_version}")
from .config import LOG_COMMAND, REQUIREMENTS_FILE, VERSION_FILE
class UpdateManager:
@classmethod
async def update_webui(cls, is_zip: bool, source: str) -> str:
from zhenxun.builtin_plugins.web_ui.public.data_source import (
update_webui_assets,
)
WEBUI_PATH = DATA_PATH / "web_ui" / "public"
BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public"
GIT_WEBUI_PATH = DATA_PATH / "web_ui" / "git_web_ui"
if WEBUI_PATH.exists():
if BACKUP_PATH.exists():
logger.debug(f"删除旧的备份webui文件夹 {BACKUP_PATH}", COMMAND)
shutil.rmtree(BACKUP_PATH)
WEBUI_PATH.rename(BACKUP_PATH)
try:
if is_zip:
await update_webui_assets()
logger.info("更新webui成功...", COMMAND)
else:
if source == "ali":
result = await AliyunRepoManager.update(
GIT_WEBUI_UI_URL, GIT_WEBUI_PATH, "dist", force=True
)
else:
result = await GithubRepoManager.update(
GIT_WEBUI_UI_URL, GIT_WEBUI_PATH, "dist", force=True
)
if not result.success:
return f"Webui更新失败...错误: {result.error_message}"
shutil.rmtree(WEBUI_PATH, ignore_errors=True)
shutil.copytree(GIT_WEBUI_PATH / "dist", WEBUI_PATH)
if BACKUP_PATH.exists():
logger.debug(f"删除旧的webui文件夹 {BACKUP_PATH}", COMMAND)
shutil.rmtree(BACKUP_PATH)
return "Webui更新成功"
except Exception as e:
logger.error("更新webui失败...", COMMAND, e=e)
if BACKUP_PATH.exists():
logger.debug(f"恢复旧的webui文件夹 {BACKUP_PATH}", COMMAND)
BACKUP_PATH.rename(WEBUI_PATH)
raise e
@classmethod
async def check_version(cls) -> str:
"""检查更新版本
@ -170,71 +19,88 @@ class UpdateManager:
str: 更新信息
"""
cur_version = cls.__get_version()
data = await cls.__get_latest_data()
if not data:
release_data = await ZhenxunRepoManager.zhenxun_get_latest_releases_data()
if not release_data:
return "检查更新获取版本失败..."
return (
"检测到当前版本更新\n"
f"当前版本:{cur_version}\n"
f"最新版本:{data.get('name')}\n"
f"创建日期:{data.get('created_at')}\n"
f"更新内容:\n{data.get('body')}"
f"最新版本:{release_data.get('name')}\n"
f"创建日期:{release_data.get('created_at')}\n"
f"更新内容:\n{release_data.get('body')}"
)
@classmethod
async def __zip_update(cls, version_type: str):
logger.info("开始下载真寻最新版文件....", COMMAND)
cur_version = cls.__get_version()
url = None
new_version = None
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
if version_type in {"main"}:
repo_info.branch = version_type
new_version = await cls.__get_version_from_repo(repo_info)
if new_version:
new_version = new_version.split(":")[-1].strip()
url = await repo_info.get_archive_download_urls()
elif version_type == "release":
data = await cls.__get_latest_data()
if not data:
return "获取更新版本失败..."
new_version = data.get("name", "")
url = await repo_info.get_release_source_download_urls_tgz(new_version)
if not url:
return "获取版本下载链接失败..."
if TMP_PATH.exists():
logger.debug(f"删除临时文件夹 {TMP_PATH}", COMMAND)
shutil.rmtree(TMP_PATH)
logger.debug(
f"开始更新版本:{cur_version} -> {new_version} | 下载链接:{url}",
COMMAND,
async def update_webui(
cls,
source: Literal["git", "ali"] | None,
branch: str = "main",
force: bool = False,
):
"""更新WebUI
参数:
source: 更新源
branch: 分支
force: 是否强制更新
返回:
str: 返回消息
"""
if not source:
await ZhenxunRepoManager.webui_zip_update()
return "WebUI更新完成!"
result = await ZhenxunRepoManager.webui_git_update(
source,
branch=branch,
force=force,
)
download_file = (
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
if not result.success:
logger.error(f"WebUI更新失败...错误: {result.error_message}", LOG_COMMAND)
return f"WebUI更新失败...错误: {result.error_message}"
return "WebUI更新完成!"
@classmethod
async def update_resources(
cls,
source: Literal["git", "ali"] | None,
branch: str = "main",
force: bool = False,
) -> str:
"""更新资源
参数:
source: 更新源
branch: 分支
force: 是否强制更新
返回:
str: 返回消息
"""
if not source:
await ZhenxunRepoManager.resources_zip_update()
return "真寻资源更新完成!"
result = await ZhenxunRepoManager.resources_git_update(
source,
branch=branch,
force=force,
)
if await AsyncHttpx.download_file(url, download_file, stream=True):
logger.debug("下载真寻最新版文件完成...", COMMAND)
await _file_handle(new_version)
result = "版本更新完成"
return (
f"{result}\n"
f"版本: {cur_version} -> {new_version}\n"
"请重新启动真寻以完成更新!"
if not result.success:
logger.error(
f"真寻资源更新失败...错误: {result.error_message}", LOG_COMMAND
)
else:
logger.debug("下载真寻最新版文件失败...", COMMAND)
return ""
return f"真寻资源更新失败...错误: {result.error_message}"
return "真寻资源更新完成!"
@classmethod
async def update(
async def update_zhenxun(
cls,
bot: Bot,
user_id: str,
version_type: str,
version_type: Literal["main", "release"],
force: bool,
source: str,
source: Literal["git", "ali"],
zip: bool,
update_type: str,
) -> str:
"""更新操作
@ -257,33 +123,38 @@ class UpdateManager:
user_id,
)
if zip:
return await cls.__zip_update(version_type)
elif source == "git":
result = await GithubRepoManager.update(
GIT_GITHUB_URL,
Path(),
use_git=update_type == "git",
force=force,
new_version = await ZhenxunRepoManager.zhenxun_zip_update(version_type)
await PlatformUtils.send_superuser(
bot, "真寻更新完成,开始安装依赖...", user_id
)
await VirtualEnvPackageManager.install_requirement(REQUIREMENTS_FILE)
return (
f"版本更新完成!\n版本: {cur_version} -> {new_version}\n"
"请重新启动真寻以完成更新!"
)
else:
result = await AliyunRepoManager.update(
GIT_GITHUB_URL,
Path(),
result = await ZhenxunRepoManager.zhenxun_git_update(
source,
branch=version_type,
force=force,
)
if not result.success:
return f"版本更新失败...错误: {result.error_message}"
await PlatformUtils.send_superuser(
bot, "真寻更新完成,开始安装依赖...", user_id
)
await VirtualEnvPackageManager.install_requirement(REQ_TXT_FILE)
return (
f"版本更新完成!\n"
f"版本: {cur_version} -> {result.new_version}\n"
f"变更文件个数: {len(result.changed_files)}"
f"{'' if source == 'git' else '(阿里云更新不支持查看变更文件)'}\n"
"请重新启动真寻以完成更新!"
)
if not result.success:
logger.error(
f"真寻版本更新失败...错误: {result.error_message}",
LOG_COMMAND,
)
return f"版本更新失败...错误: {result.error_message}"
await PlatformUtils.send_superuser(
bot, "真寻更新完成,开始安装依赖...", user_id
)
await VirtualEnvPackageManager.install_requirement(REQUIREMENTS_FILE)
return (
f"版本更新完成!\n"
f"版本: {cur_version} -> {result.new_version}\n"
f"变更文件个数: {len(result.changed_files)}"
f"{'' if source == 'git' else '(阿里云更新不支持查看变更文件)'}\n"
"请重新启动真寻以完成更新!"
)
@classmethod
def __get_version(cls) -> str:
@ -297,40 +168,3 @@ class UpdateManager:
if text := VERSION_FILE.open(encoding="utf8").readline():
_version = text.split(":")[-1].strip()
return _version
@classmethod
async def __get_latest_data(cls) -> dict:
"""获取最新版本信息
返回:
dict: 最新版本数据
"""
for _ in range(3):
try:
res = await AsyncHttpx.get(RELEASE_URL)
if res.status_code == 200:
return res.json()
except TimeoutError:
pass
except Exception as e:
logger.error("检查更新真寻获取版本失败", e=e)
return {}
@classmethod
async def __get_version_from_repo(cls, repo_info: RepoInfo) -> str:
"""从指定分支获取版本号
参数:
branch: 分支名称
返回:
str: 版本号
"""
version_url = await repo_info.get_raw_download_urls(path="__version__")
try:
res = await AsyncHttpx.get(version_url)
if res.status_code == 200:
return res.text.strip()
except Exception as e:
logger.error(f"获取 {repo_info.branch} 分支版本失败", e=e)
return "未知版本"

View File

@ -1,42 +1,7 @@
from pathlib import Path
from zhenxun.configs.path_config import TEMP_PATH
LOG_COMMAND = "AutoUpdate"
GIT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot.git"
VERSION_FILE = Path() / "__version__"
DEFAULT_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot/tree/main"
RELEASE_URL = "https://api.github.com/repos/HibiKier/zhenxun_bot/releases/latest"
GIT_WEBUI_UI_URL = "https://github.com/HibiKier/zhenxun_bot_webui.git"
VERSION_FILE_STRING = "__version__"
VERSION_FILE = Path() / VERSION_FILE_STRING
PYPROJECT_FILE_STRING = "pyproject.toml"
PYPROJECT_FILE = Path() / PYPROJECT_FILE_STRING
PYPROJECT_LOCK_FILE_STRING = "poetry.lock"
PYPROJECT_LOCK_FILE = Path() / PYPROJECT_LOCK_FILE_STRING
REQ_TXT_FILE_STRING = "requirements.txt"
REQ_TXT_FILE = Path() / REQ_TXT_FILE_STRING
BASE_PATH_STRING = "zhenxun"
BASE_PATH = Path() / BASE_PATH_STRING
TMP_PATH = TEMP_PATH / "auto_update"
BACKUP_PATH = Path() / "backup"
DOWNLOAD_GZ_FILE_STRING = "download_latest_file.tar.gz"
DOWNLOAD_ZIP_FILE_STRING = "download_latest_file.zip"
DOWNLOAD_GZ_FILE = TMP_PATH / DOWNLOAD_GZ_FILE_STRING
DOWNLOAD_ZIP_FILE = TMP_PATH / DOWNLOAD_ZIP_FILE_STRING
REPLACE_FOLDERS = [
"builtin_plugins",
"services",
"utils",
"models",
"configs",
]
COMMAND = "检查更新"
REQUIREMENTS_FILE = Path() / "requirements.txt"

View File

@ -12,6 +12,7 @@ from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.manager import (
reset_key_status,
)
from zhenxun.services.llm.types import LLMMessage
class DataSource:
@ -58,7 +59,7 @@ class DataSource:
start_time = time.monotonic()
try:
async with await get_model_instance(model_name_str) as model:
await model.generate_text("你好")
await model.generate_response([LLMMessage.user("你好")])
end_time = time.monotonic()
latency = (end_time - start_time) * 1000
return (

View File

@ -1,16 +1,21 @@
from collections.abc import Callable
import copy
from pathlib import Path
from typing import Any, TypeVar, get_args, get_origin
from typing import Any, TypeVar
import cattrs
from nonebot.compat import model_dump
from pydantic import VERSION, BaseModel, Field
from pydantic import BaseModel, Field
from ruamel.yaml import YAML
from ruamel.yaml.scanner import ScannerError
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import (
_dump_pydantic_obj,
_is_pydantic_type,
model_dump,
parse_as,
)
from .models import (
AICallableParam,
@ -39,46 +44,6 @@ class NoSuchConfig(Exception):
pass
def _dump_pydantic_obj(obj: Any) -> Any:
"""
递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典
支持单个实例实例列表实例字典等情况
"""
if isinstance(obj, BaseModel):
return model_dump(obj)
if isinstance(obj, list):
return [_dump_pydantic_obj(item) for item in obj]
if isinstance(obj, dict):
return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
return obj
def _is_pydantic_type(t: Any) -> bool:
"""
递归检查一个类型注解是否与 Pydantic BaseModel 相关
"""
if t is None:
return False
origin = get_origin(t)
if origin:
return any(_is_pydantic_type(arg) for arg in get_args(t))
return isinstance(t, type) and issubclass(t, BaseModel)
def parse_as(type_: type[T], obj: Any) -> T:
"""
一个兼容 Pydantic V1 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数
"""
if VERSION.startswith("1"):
from pydantic import parse_obj_as
return parse_obj_as(type_, obj)
else:
from pydantic import TypeAdapter # type: ignore
return TypeAdapter(type_).validate_python(obj)
class ConfigGroup(BaseModel):
"""
配置组
@ -194,16 +159,11 @@ class ConfigsManager:
"""
result = dict(original_data)
# 遍历新数据的键
for key, value in new_data.items():
# 如果键不在原数据中,添加它
if key not in original_data:
result[key] = value
# 如果两边都是字典,递归处理
elif isinstance(value, dict) and isinstance(original_data[key], dict):
result[key] = self._merge_dicts(value, original_data[key])
# 如果键已存在,保留原值,不覆盖
# (不做任何操作,保持原值)
return result
@ -217,15 +177,11 @@ class ConfigsManager:
返回:
标准化后的值
"""
# 处理BaseModel
processed_value = _dump_pydantic_obj(value)
# 如果处理后的值是字典,且原始值也存在
if isinstance(processed_value, dict) and original_value is not None:
# 处理原始值
processed_original = _dump_pydantic_obj(original_value)
# 如果原始值也是字典,合并它们
if isinstance(processed_original, dict):
return self._merge_dicts(processed_value, processed_original)
@ -263,12 +219,10 @@ class ConfigsManager:
if not module or not key:
raise ValueError("add_plugin_config: module和key不能为为空")
# 获取现有配置值(如果存在)
existing_value = None
if module in self._data and (config := self._data[module].configs.get(key)):
existing_value = config.value
# 标准化值和默认值
processed_value = self._normalize_config_data(value, existing_value)
processed_default_value = self._normalize_config_data(default_value)
@ -348,7 +302,6 @@ class ConfigsManager:
if value_to_process is None:
return default
# 1. 最高优先级:自定义的参数解析器
if config.arg_parser:
try:
return config.arg_parser(value_to_process)

View File

@ -2,10 +2,10 @@ from collections.abc import Callable
from datetime import datetime
from typing import Any, Literal
from nonebot.compat import model_dump
from pydantic import BaseModel, Field
from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType
from zhenxun.utils.pydantic_compat import model_dump
__all__ = [
"AICallableParam",

View File

@ -27,23 +27,19 @@ from .llm import (
LLMException,
LLMGenerationConfig,
LLMMessage,
analyze,
analyze_multimodal,
chat,
clear_model_cache,
code,
create_multimodal_message,
embed,
generate,
generate_structured,
get_cache_stats,
get_model_instance,
list_available_models,
list_embedding_models,
pipeline_chat,
search,
search_multimodal,
set_global_default_model_name,
tool_registry,
)
from .log import logger
from .plugin_init import PluginInit, PluginInitManager
@ -60,8 +56,6 @@ __all__ = [
"Model",
"PluginInit",
"PluginInitManager",
"analyze",
"analyze_multimodal",
"chat",
"clear_model_cache",
"code",
@ -69,16 +63,14 @@ __all__ = [
"disconnect",
"embed",
"generate",
"generate_structured",
"get_cache_stats",
"get_model_instance",
"list_available_models",
"list_embedding_models",
"logger",
"pipeline_chat",
"scheduler_manager",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"with_db_timeout",
]

View File

@ -1,558 +0,0 @@
---
# 🚀 Zhenxun LLM 服务模块
本模块是一个功能强大、高度可扩展的统一大语言模型(LLM)服务框架。它旨在将各种不同的 LLM 提供商(如 OpenAI、Gemini、智谱AI等)的 API 封装在一个统一、易于使用的接口之后,让开发者可以无缝切换和使用不同的模型,同时支持多模态输入、工具调用、智能重试和缓存等高级功能。
## 目录
- [🚀 Zhenxun LLM 服务模块](#-zhenxun-llm-服务模块)
- [目录](#目录)
- [✨ 核心特性](#-核心特性)
- [🧠 核心概念](#-核心概念)
- [🛠️ 安装与配置](#-安装与配置)
- [服务提供商配置 (`config.yaml`)](#服务提供商配置-configyaml)
- [MCP 工具配置 (`mcp_tools.json`)](#mcp-工具配置-mcp_toolsjson)
- [📘 使用指南](#-使用指南)
- [**等级1: 便捷函数** - 最快速的调用方式](#等级1-便捷函数---最快速的调用方式)
- [**等级2: `AI` 会话类** - 管理有状态的对话](#等级2-ai-会话类---管理有状态的对话)
- [**等级3: 直接模型控制** - `get_model_instance`](#等级3-直接模型控制---get_model_instance)
- [🌟 功能深度剖析](#-功能深度剖析)
- [精细化控制模型生成 (`LLMGenerationConfig` 与 `CommonOverrides`)](#精细化控制模型生成-llmgenerationconfig-与-commonoverrides)
- [赋予模型能力:工具使用 (Function Calling)](#赋予模型能力工具使用-function-calling)
- [1. 注册工具](#1-注册工具)
- [函数工具注册](#函数工具注册)
- [MCP工具注册](#mcp工具注册)
- [2. 调用带工具的模型](#2-调用带工具的模型)
- [处理多模态输入](#处理多模态输入)
- [🔧 高级主题与扩展](#-高级主题与扩展)
- [模型与密钥管理](#模型与密钥管理)
- [缓存管理](#缓存管理)
- [错误处理 (`LLMException`)](#错误处理-llmexception)
- [自定义适配器 (Adapter)](#自定义适配器-adapter)
- [📚 API 快速参考](#-api-快速参考)
---
## ✨ 核心特性
- **多提供商支持**: 内置对 OpenAI、Gemini、智谱AI 等多种 API 的适配器,并可通过通用 OpenAI 兼容适配器轻松接入更多服务。
- **统一的 API**: 提供从简单到高级的三层 API,满足不同场景的需求,无论是快速聊天还是复杂的分析任务。
- **强大的工具调用 (Function Calling)**: 支持标准的函数调用和实验性的 MCP (Model Context Protocol) 工具,让 LLM 能够与外部世界交互。
- **多模态能力**: 无缝集成 `UniMessage`,轻松处理文本、图片、音频、视频等混合输入,支持多模态搜索和分析。
- **文本嵌入向量化**: 提供统一的嵌入接口,支持语义搜索、相似度计算和文本聚类等应用。
- **智能重试与 Key 轮询**: 内置健壮的请求重试逻辑,当 API Key 失效或达到速率限制时,能自动轮询使用备用 Key。
- **灵活的配置系统**: 通过配置文件和代码中的 `LLMGenerationConfig`,可以精细控制模型的生成行为(如温度、最大Token等)。
- **高性能缓存机制**: 内置模型实例缓存,减少重复初始化开销,提供缓存管理和监控功能。
- **丰富的配置预设**: 提供 `CommonOverrides`,包含创意模式、精确模式、JSON输出等多种常用配置预设。
- **可扩展的适配器架构**: 开发者可以轻松编写自己的适配器来支持新的 LLM 服务。
## 🧠 核心概念
- **适配器 (Adapter)**: 这是连接我们统一接口和特定 LLM 提供商 API 的“翻译官”。例如,`GeminiAdapter` 知道如何将我们的标准请求格式转换为 Google Gemini API 需要的格式,并解析其响应。
- **模型实例 (`LLMModel`)**: 这是框架中的核心操作对象,代表一个**具体配置好**的模型。例如,一个 `LLMModel` 实例可能代表使用特定 API Key、特定代理的 `Gemini/gemini-1.5-pro`。所有与模型交互的操作都通过这个类的实例进行。
- **生成配置 (`LLMGenerationConfig`)**: 这是一个数据类,用于控制模型在生成内容时的行为,例如 `temperature` (温度)、`max_tokens` (最大输出长度)、`response_format` (响应格式) 等。
- **工具 (Tool)**: 代表一个可以让 LLM 调用的函数。它可以是一个简单的 Python 函数,也可以是一个更复杂的、有状态的 MCP 服务。
- **多模态内容 (`LLMContentPart`)**: 这是处理多模态输入的基础单元,一个 `LLMMessage` 可以包含多个 `LLMContentPart`,如一个文本部分和多个图片部分。
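The following minimal sketch ties these concepts together, using the `get_model_instance`, `LLMMessage`, and `LLMGenerationConfig` APIs documented later in this README (see the Level-3 guide below for the full walkthrough):

```python
from zhenxun.services.llm import LLMMessage, get_model_instance
from zhenxun.services.llm.config import LLMGenerationConfig


async def demo() -> None:
    # "Gemini/gemini-1.5-pro" selects a provider/model pair from config.yaml;
    # the returned LLMModel instance talks to the API through its adapter.
    async with await get_model_instance("Gemini/gemini-1.5-pro") as model:
        messages = [
            LLMMessage.system("你是一个简洁的助手。"),
            LLMMessage.user("用一句话介绍你自己。"),
        ]
        # LLMGenerationConfig only affects this call's generation behaviour.
        config = LLMGenerationConfig(temperature=0.3, max_tokens=128)
        response = await model.generate_response(messages, config=config)
        print(response.text)
```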
## 🛠️ 安装与配置
该模块作为 `zhenxun` 项目的一部分被集成,无需额外安装。核心配置主要涉及两个文件。
### 服务提供商配置 (`config.yaml`)
核心配置位于项目 `/data/config.yaml` 文件中的 `AI` 部分。
```yaml
# /data/configs/config.yaml
AI:
# (可选) 全局默认模型,格式: "ProviderName/ModelName"
default_model_name: Gemini/gemini-2.5-flash
# (可选) 全局代理设置
proxy: http://127.0.0.1:7890
# (可选) 全局超时设置 (秒)
timeout: 180
# (可选) Gemini 的安全过滤阈值
gemini_safety_threshold: BLOCK_MEDIUM_AND_ABOVE
# 配置你的AI服务提供商
PROVIDERS:
# 示例1: Gemini
- name: Gemini
api_key:
- "AIzaSy_KEY_1" # 支持多个Key会自动轮询
- "AIzaSy_KEY_2"
api_base: https://generativelanguage.googleapis.com
api_type: gemini
models:
- model_name: gemini-2.5-pro
- model_name: gemini-2.5-flash
- model_name: gemini-2.0-flash
- model_name: embedding-001
is_embedding_model: true # 标记为嵌入模型
max_input_tokens: 2048 # 嵌入模型特有配置
# 示例2: 智谱AI
- name: GLM
api_key: "YOUR_ZHIPU_API_KEY"
api_type: zhipu # 适配器类型
models:
- model_name: glm-4-flash
- model_name: glm-4-plus
temperature: 0.8 # 可以为特定模型设置默认温度
# 示例3: 一个兼容OpenAI的自定义服务
- name: MyOpenAIService
api_key: "sk-my-custom-key"
api_base: "http://localhost:8080/v1"
api_type: general_openai_compat # 使用通用OpenAI兼容适配器
models:
- model_name: Llama3-8B-Instruct
max_tokens: 2048 # 可以为特定模型设置默认最大Token
```
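Once a provider is configured, every API in this module addresses a model by the `"ProviderName/ModelName"` string shown above. A minimal sketch using the `chat` convenience function documented in the usage guide below:

```python
from zhenxun.services.llm import chat


async def ask() -> str:
    # "GLM/glm-4-flash" refers to the provider named "GLM" and its
    # glm-4-flash model entry above; omitting model= falls back to
    # AI.default_model_name from this config file.
    return await chat("你好,介绍一下你自己。", model="GLM/glm-4-flash")
```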
### MCP 工具配置 (`mcp_tools.json`)
此文件位于 `/data/llm/mcp_tools.json`,用于配置通过 MCP 协议启动的外部工具服务。如果文件不存在,系统会自动创建一个包含示例的默认文件。
```json
{
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {
"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"
},
"description": "百度地图工具,提供地理编码、路线规划等功能。"
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。"
}
}
}
```
## 📘 使用指南
我们提供了三层 API,以满足从简单到复杂的各种需求。
### **等级1: 便捷函数** - 最快速的调用方式
这些函数位于 `zhenxun.services.llm` 包的顶层,为你处理了所有的底层细节。
```python
from zhenxun.services.llm import chat, search, code, pipeline_chat, embed, analyze_multimodal, search_multimodal
from zhenxun.services.llm.utils import create_multimodal_message
# 1. 纯文本聊天
response_text = await chat("你好,请用苏轼的风格写一首关于月亮的诗。")
print(response_text)
# 2. 带网络搜索的问答
search_result = await search("马斯克的Neuralink公司最近有什么新进展")
print(search_result['text'])
# print(search_result['sources']) # 查看信息来源
# 3. 执行代码
code_result = await code("用Python画一个心形图案。")
print(code_result['text']) # 包含代码和解释的回复
# 4. 链式调用
image_msg = create_multimodal_message(images="path/to/cat.jpg")
final_poem = await pipeline_chat(
message=image_msg,
model_chain=["Gemini/gemini-1.5-pro", "GLM/glm-4-flash"],
initial_instruction="详细描述这只猫的外观和姿态。",
final_instruction="将上述描述凝练成一首可爱的短诗。"
)
print(final_poem.text)
# 5. 文本嵌入向量生成
texts_to_embed = ["今天天气真好", "我喜欢打篮球", "这部电影很感人"]
vectors = await embed(texts_to_embed, model="Gemini/embedding-001")
print(f"生成了 {len(vectors)} 个向量,每个向量维度: {len(vectors[0])}")
# 6. 多模态分析便捷函数
response = await analyze_multimodal(
text="请分析这张图片中的内容",
images="path/to/image.jpg",
model="Gemini/gemini-1.5-pro"
)
print(response)
# 7. 多模态搜索便捷函数
search_result = await search_multimodal(
text="搜索与这张图片相关的信息",
images="path/to/image.jpg",
model="Gemini/gemini-1.5-pro"
)
print(search_result['text'])
```
### **等级2: `AI` 会话类** - 管理有状态的对话
当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。
```python
from zhenxun.services.llm import AI, AIConfig
# 初始化一个AI会话可以传入自定义配置
ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7)
ai_session = AI(config=ai_config)
# 更完整的AIConfig配置示例
advanced_config = AIConfig(
model="GLM/glm-4-flash",
default_embedding_model="Gemini/embedding-001", # 默认嵌入模型
temperature=0.7,
max_tokens=2000,
enable_cache=True, # 启用模型缓存
enable_code=True, # 启用代码执行功能
enable_search=True, # 启用搜索功能
timeout=180, # 请求超时时间(秒)
# Gemini特定配置选项
enable_gemini_json_mode=True, # 启用Gemini JSON模式
enable_gemini_thinking=True, # 启用Gemini 思考模式
enable_gemini_safe_mode=True, # 启用Gemini 安全模式
enable_gemini_multimodal=True, # 启用Gemini 多模态优化
enable_gemini_grounding=True, # 启用Gemini 信息来源关联
)
advanced_session = AI(config=advanced_config)
# 进行连续对话
await ai_session.chat("我最喜欢的城市是成都。")
response = await ai_session.chat("它有什么好吃的?") # AI会知道“它”指的是成都
print(response)
# 在同一个会话中,临时切换模型进行一次调用
response_gemini = await ai_session.chat(
"从AI的角度分析一下成都的科技发展潜力。",
model="Gemini/gemini-1.5-pro"
)
print(response_gemini)
# 清空历史,开始新一轮对话
ai_session.clear_history()
```
### **等级3: 直接模型控制** - `get_model_instance`
这是最底层的 API,为你提供对模型实例的完全控制。推荐使用 `async with` 语句来优雅地管理模型实例的生命周期。
```python
from zhenxun.services.llm import get_model_instance, LLMMessage
from zhenxun.services.llm.config import LLMGenerationConfig
# 1. 获取模型实例
# get_model_instance 返回一个异步上下文管理器
async with await get_model_instance("Gemini/gemini-1.5-pro") as model:
# 2. 准备消息列表
messages = [
LLMMessage.system("你是一个专业的营养师。"),
LLMMessage.user("我今天吃了汉堡和可乐,请给我一些健康建议。")
]
# 3. (可选) 定义本次调用的生成配置
gen_config = LLMGenerationConfig(
temperature=0.2, # 更严谨的回复
max_tokens=300
)
# 4. 生成响应
response = await model.generate_response(messages, config=gen_config)
# 5. 处理响应
print(response.text)
if response.usage_info:
print(f"Token 消耗: {response.usage_info['total_tokens']}")
```
## 🌟 功能深度剖析
### 精细化控制模型生成 (`LLMGenerationConfig` 与 `CommonOverrides`)
- **`LLMGenerationConfig`**: 一个 Pydantic 模型,用于覆盖模型的默认生成参数。
- **`CommonOverrides`**: 一个包含多种常用配置预设的类,如 `creative()`, `precise()`, `gemini_json()` 等,能极大地简化配置过程。
```python
from zhenxun.services.llm.config import LLMGenerationConfig, CommonOverrides
# LLMGenerationConfig 完整参数示例
comprehensive_config = LLMGenerationConfig(
temperature=0.7, # 生成温度 (0.0-2.0)
max_tokens=1000, # 最大输出token数
top_p=0.9, # 核采样参数 (0.0-1.0)
top_k=40, # Top-K采样参数
frequency_penalty=0.0, # 频率惩罚 (-2.0-2.0)
presence_penalty=0.0, # 存在惩罚 (-2.0-2.0)
repetition_penalty=1.0, # 重复惩罚 (0.0-2.0)
stop=["END", "\n\n"], # 停止序列
response_format={"type": "json_object"}, # 响应格式
response_mime_type="application/json", # Gemini专用MIME类型
response_schema={...}, # JSON响应模式
thinking_budget=0.8, # Gemini思考预算 (0.0-1.0)
enable_code_execution=True, # 启用代码执行
safety_settings={...}, # 安全设置
response_modalities=["TEXT"], # 响应模态类型
)
# 创建一个配置要求模型输出JSON格式
json_config = LLMGenerationConfig(
temperature=0.1,
response_mime_type="application/json" # Gemini特有
)
# 对于OpenAI兼容API可以这样做
json_config_openai = LLMGenerationConfig(
temperature=0.1,
response_format={"type": "json_object"}
)
# 使用框架提供的预设 - 基础预设
safe_config = CommonOverrides.gemini_safe()
creative_config = CommonOverrides.creative()
precise_config = CommonOverrides.precise()
balanced_config = CommonOverrides.balanced()
# 更多实用预设
concise_config = CommonOverrides.concise(max_tokens=50) # 简洁模式
detailed_config = CommonOverrides.detailed(max_tokens=3000) # 详细模式
json_config = CommonOverrides.gemini_json() # JSON输出模式
thinking_config = CommonOverrides.gemini_thinking(budget=0.8) # 思考模式
# Gemini特定高级预设
code_config = CommonOverrides.gemini_code_execution() # 代码执行模式
grounding_config = CommonOverrides.gemini_grounding() # 信息来源关联模式
multimodal_config = CommonOverrides.gemini_multimodal() # 多模态优化模式
# 在调用时传入config对象
# await model.generate_response(messages, config=json_config)
```
### 赋予模型能力:工具使用 (Function Calling)
工具调用让 LLM 能够与外部函数、API 或服务进行交互。
#### 1. 注册工具
##### 函数工具注册
使用 `@tool_registry.function_tool` 装饰器注册一个简单的函数工具。
```python
from zhenxun.services.llm import tool_registry
@tool_registry.function_tool(
name="query_stock_price",
description="查询指定股票代码的当前价格。",
parameters={
"stock_symbol": {"type": "string", "description": "股票代码, 例如 'AAPL' 或 'GOOG'"}
},
required=["stock_symbol"]
)
async def query_stock_price(stock_symbol: str) -> dict:
"""一个查询股票价格的伪函数"""
print(f"--- 正在查询 {stock_symbol} 的价格 ---")
if stock_symbol == "AAPL":
return {"symbol": "AAPL", "price": 175.50, "currency": "USD"}
return {"error": "未知的股票代码"}
```
##### MCP工具注册
对于更复杂的、有状态的工具,可以使用 `@tool_registry.mcp_tool` 装饰器注册MCP工具。
```python
from contextlib import asynccontextmanager
from pydantic import BaseModel
from zhenxun.services.llm import tool_registry
# 定义工具的配置模型
class MyToolConfig(BaseModel):
api_key: str
endpoint: str
timeout: int = 30
# 注册MCP工具
@tool_registry.mcp_tool(name="my-custom-tool", config_model=MyToolConfig)
@asynccontextmanager
async def my_tool_factory(config: MyToolConfig):
"""MCP工具工厂函数"""
# 初始化工具会话
session = MyToolSession(config)
try:
await session.initialize()
yield session
finally:
await session.cleanup()
```
#### 2. 调用带工具的模型
`analyze``generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。
```python
from zhenxun.services.llm import analyze
from nonebot_plugin_alconna.uniseg import UniMessage
response = await analyze(
UniMessage("帮我查一下苹果公司的股价"),
use_tools=["query_stock_price"]
)
print(response.text) # 输出应为 "苹果公司(AAPL)的当前股价为175.5美元。" 或类似内容
```
### 处理多模态输入
本模块通过 `UniMessage``LLMContentPart` 完美支持多模态。
- **`create_multimodal_message`**: 推荐的、用于从代码中便捷地创建多模态消息的函数。
- **`unimsg_to_llm_parts`**: 框架内部使用的核心转换函数,将 `UniMessage` 的各个段(文本、图片等)转换为 `LLMContentPart` 列表。
```python
from zhenxun.services.llm import analyze
from zhenxun.services.llm.utils import create_multimodal_message
from pathlib import Path
# 从本地文件创建消息
message = create_multimodal_message(
text="请分析这张图片和这个视频。图片里是什么?视频里发生了什么?",
images=[Path("path/to/your/image.jpg")],
videos=[Path("path/to/your/video.mp4")]
)
response = await analyze(message, model="Gemini/gemini-1.5-pro")
print(response.text)
```
## 🔧 高级主题与扩展
### 模型与密钥管理
模块提供了一些工具函数来管理你的模型配置。
```python
from zhenxun.services.llm.manager import (
list_available_models,
list_embedding_models,
set_global_default_model_name,
get_global_default_model_name,
get_key_usage_stats,
reset_key_status
)
# 列出所有在config.yaml中配置的可用模型
models = list_available_models()
print([m['full_name'] for m in models])
# 列出所有可用的嵌入模型
embedding_models = list_embedding_models()
print([m['full_name'] for m in embedding_models])
# 动态设置全局默认模型
success = set_global_default_model_name("GLM/glm-4-plus")
# 获取所有Key的使用统计
stats = await get_key_usage_stats()
print(stats)
# 重置'Gemini'提供商的所有Key
await reset_key_status("Gemini")
```
### 缓存管理
模块提供了模型实例缓存功能,可以提高性能并减少重复初始化的开销。
```python
from zhenxun.services.llm import clear_model_cache, get_cache_stats
# 获取缓存统计信息
stats = get_cache_stats()
print(f"缓存大小: {stats['cache_size']}/{stats['max_cache_size']}")
print(f"缓存TTL: {stats['cache_ttl']}秒")
print(f"已缓存模型: {stats['cached_models']}")
# 清空模型缓存(在内存不足或需要强制重新初始化时使用)
clear_model_cache()
print("模型缓存已清空")
```
### 错误处理 (`LLMException`)
所有模块内的预期错误都会被包装成 `LLMException`,方便统一处理。
```python
from zhenxun.services.llm import chat, LLMException, LLMErrorCode
try:
await chat("test", model="InvalidProvider/invalid_model")
except LLMException as e:
print(f"捕获到LLM异常: {e}")
print(f"错误码: {e.code}") # 例如 LLMErrorCode.MODEL_NOT_FOUND
print(f"用户友好提示: {e.user_friendly_message}")
```
### 自定义适配器 (Adapter)
如果你想支持一个新的、非 OpenAI 兼容的 LLM 服务,可以通过实现自己的适配器来完成。
1. **创建适配器类**: 继承 `BaseAdapter` 并实现其抽象方法。
```python
# my_adapters/custom_adapter.py
from zhenxun.services.llm.adapters import BaseAdapter, RequestData, ResponseData
class MyCustomAdapter(BaseAdapter):
@property
def api_type(self) -> str: return "my_custom_api"
@property
def supported_api_types(self) -> list[str]: return ["my_custom_api"]
# ... 实现 prepare_advanced_request, parse_response 等方法
```
2. **注册适配器**: 在你的插件初始化代码中注册你的适配器。
```python
from zhenxun.services.llm.adapters import register_adapter
from .my_adapters.custom_adapter import MyCustomAdapter
register_adapter(MyCustomAdapter())
```
3. **在 `config.yaml` 中使用**:
```yaml
AI:
PROVIDERS:
- name: MyAwesomeLLM
api_key: "my-secret-key"
api_type: "my_custom_api" # 关键!使用你注册的 api_type
# ...
```
## 📚 API 快速参考
| 类/函数 | 主要用途 | 推荐场景 |
| ------------------------------------- | ---------------------------------------------------------------------- | ------------------------------------------------------------ |
| `llm.chat()` | 进行简单的、无状态的文本对话。 | 快速实现单轮问答。 |
| `llm.search()` | 执行带网络搜索的问答。 | 需要最新信息或回答事实性问题时。 |
| `llm.code()` | 请求模型执行代码。 | 计算、数据处理、代码生成等。 |
| `llm.pipeline_chat()` | 将多个模型串联,处理复杂任务流。 | 需要多模型协作完成的任务,如“图生文再润色”。 |
| `llm.analyze()` | 处理复杂的多模态输入 (`UniMessage`) 和工具调用。 | 插件中处理用户命令需要解析图片、at、回复等复杂消息时。 |
| `llm.AI` (类) | 管理一个有状态的、连续的对话会话。 | 需要实现上下文关联的连续对话机器人。 |
| `llm.get_model_instance()` | 获取一个底层的、可直接控制的 `LLMModel` 实例。 | 需要对模型进行最精细控制的复杂或自定义场景。 |
| `llm.config.LLMGenerationConfig` (类) | 定义模型生成的具体参数如温度、最大Token等。 | 当需要微调模型输出风格或格式时。 |
| `llm.tools.tool_registry` (实例) | 注册和管理可供LLM调用的函数工具。 | 当你想让LLM拥有与外部世界交互的能力时。 |
| `llm.embed()` | 生成文本的嵌入向量表示。 | 语义搜索、相似度计算、文本聚类等。 |
| `llm.search_multimodal()` | 执行带网络搜索的多模态问答。 | 需要基于图片、视频等多模态内容进行搜索时。 |
| `llm.analyze_multimodal()` | 便捷的多模态分析函数。 | 直接分析文本、图片、视频、音频等多模态内容。 |
| `llm.AIConfig` (类) | AI会话的配置类包含模型、温度等参数。 | 配置AI会话的行为和特性。 |
| `llm.clear_model_cache()` | 清空模型实例缓存。 | 内存管理或强制重新初始化模型时。 |
| `llm.get_cache_stats()` | 获取模型缓存的统计信息。 | 监控缓存使用情况和性能优化。 |
| `llm.list_embedding_models()` | 列出所有可用的嵌入模型。 | 选择合适的嵌入模型进行向量化任务。 |
| `llm.config.CommonOverrides` (类) | 提供常用的配置预设,如创意模式、精确模式等。 | 快速应用常见的模型配置组合。 |
| `llm.utils.create_multimodal_message` | 便捷地从文本、图片、音视频等数据创建 `UniMessage`。 | 在代码中以编程方式构建多模态输入时。 |

View File

@ -5,15 +5,13 @@ LLM 服务模块 - 公共 API 入口
"""
from .api import (
analyze,
analyze_multimodal,
chat,
code,
embed,
generate,
pipeline_chat,
generate_structured,
run_with_tools,
search,
search_multimodal,
)
from .config import (
CommonOverrides,
@ -34,7 +32,7 @@ from .manager import (
set_global_default_model_name,
)
from .session import AI, AIConfig
from .tools import tool_registry
from .tools import function_tool, tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -42,8 +40,6 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
MCPCompatible,
ModelDetail,
ModelInfo,
ModelProvider,
@ -66,8 +62,6 @@ __all__ = [
"LLMGenerationConfig",
"LLMMessage",
"LLMResponse",
"LLMTool",
"MCPCompatible",
"ModelDetail",
"ModelInfo",
"ModelName",
@ -77,14 +71,14 @@ __all__ = [
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"analyze",
"analyze_multimodal",
"chat",
"clear_model_cache",
"code",
"create_multimodal_message",
"embed",
"function_tool",
"generate",
"generate_structured",
"get_cache_stats",
"get_global_default_model_name",
"get_model_instance",
@ -92,11 +86,10 @@ __all__ = [
"list_embedding_models",
"list_model_identifiers",
"message_to_unimessage",
"pipeline_chat",
"register_llm_configs",
"run_with_tools",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"tool_provider_manager",
"unimsg_to_llm_parts",
]

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool
from ..types.protocols import ToolExecutable
class RequestData(BaseModel):
@ -103,7 +103,7 @@ class BaseAdapter(ABC):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list["LLMTool"] | None = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -401,7 +401,6 @@ class BaseAdapter(ABC):
class OpenAICompatAdapter(BaseAdapter):
"""
处理所有 OpenAI 兼容 API 的通用适配器
消除 OpenAIAdapter ZhipuAdapter 之间的代码重复
"""
@abstractmethod
@ -445,7 +444,7 @@ class OpenAICompatAdapter(BaseAdapter):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list["LLMTool"] | None = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
@ -459,21 +458,20 @@ class OpenAICompatAdapter(BaseAdapter):
}
if tools:
openai_tools = []
for tool in tools:
if tool.type == "function" and tool.function:
openai_tools.append({"type": "function", "function": tool.function})
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
openai_tools.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
import asyncio
from zhenxun.utils.pydantic_compat import model_dump
definition_tasks = [
executable.get_definition() for executable in tools.values()
]
openai_tools = await asyncio.gather(*definition_tasks)
if openai_tools:
body["tools"] = openai_tools
body["tools"] = [
{"type": "function", "function": model_dump(tool)}
for tool in openai_tools
]
if tool_choice:
body["tool_choice"] = tool_choice

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
from zhenxun.services.log import logger
from ..types.exceptions import LLMErrorCode, LLMException
from ..utils import sanitize_schema_for_llm
from .base import BaseAdapter, RequestData, ResponseData
if TYPE_CHECKING:
@ -14,7 +15,8 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool, LLMToolCall
from ..types.models import LLMToolCall
from ..types.protocols import ToolExecutable
class GeminiAdapter(BaseAdapter):
@ -44,7 +46,7 @@ class GeminiAdapter(BaseAdapter):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list["LLMTool"] | None = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -128,11 +130,22 @@ class GeminiAdapter(BaseAdapter):
)
tool_result_obj = {"raw_output": content_str}
if isinstance(tool_result_obj, list):
logger.debug(
f"工具 '{msg.name}' 的返回结果是列表,"
f"正在为Gemini API包装为JSON对象。"
)
final_response_payload = {"result": tool_result_obj}
elif not isinstance(tool_result_obj, dict):
final_response_payload = {"result": tool_result_obj}
else:
final_response_payload = tool_result_obj
current_parts.append(
{
"functionResponse": {
"name": msg.name,
"response": tool_result_obj,
"response": final_response_payload,
}
}
)
@ -145,22 +158,26 @@ class GeminiAdapter(BaseAdapter):
all_tools_for_request = []
if tools:
for tool in tools:
if tool.type == "function" and tool.function:
all_tools_for_request.append(
{"functionDeclarations": [tool.function]}
)
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
all_tools_for_request.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
elif tool.type == "google_search":
all_tools_for_request.append({"googleSearch": {}})
import asyncio
from zhenxun.utils.pydantic_compat import model_dump
definition_tasks = [
executable.get_definition() for executable in tools.values()
]
tool_definitions = await asyncio.gather(*definition_tasks)
function_declarations = []
for tool_def in tool_definitions:
tool_def.parameters = sanitize_schema_for_llm(
tool_def.parameters, api_type="gemini"
)
function_declarations.append(model_dump(tool_def))
if function_declarations:
all_tools_for_request.append(
{"functionDeclarations": function_declarations}
)
if effective_config:
if getattr(effective_config, "enable_grounding", False):
@ -289,49 +306,21 @@ class GeminiAdapter(BaseAdapter):
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> dict[str, Any]:
"""构建Gemini生成配置"""
generation_config: dict[str, Any] = {}
effective_config = config if config is not None else model._generation_config
if effective_config:
base_api_params = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
if not effective_config:
return {}
generation_config = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
)
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
generation_config.update(base_api_params)
if getattr(effective_config, "response_mime_type", None):
generation_config["responseMimeType"] = (
effective_config.response_mime_type
)
if getattr(effective_config, "response_schema", None):
generation_config["responseSchema"] = effective_config.response_schema
thinking_budget = getattr(effective_config, "thinking_budget", None)
if thinking_budget is not None:
if "thinkingConfig" not in generation_config:
generation_config["thinkingConfig"] = {}
generation_config["thinkingConfig"]["thinkingBudget"] = thinking_budget
if getattr(effective_config, "response_modalities", None):
modalities = effective_config.response_modalities
if isinstance(modalities, list):
generation_config["responseModalities"] = [
m.upper() for m in modalities
]
elif isinstance(modalities, str):
generation_config["responseModalities"] = [modalities.upper()]
generation_config = {
k: v for k, v in generation_config.items() if v is not None
}
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
return generation_config
@ -410,10 +399,16 @@ class GeminiAdapter(BaseAdapter):
text_content = ""
parsed_tool_calls: list["LLMToolCall"] | None = None
thought_summary_parts = []
answer_parts = []
for part in parts:
if "text" in part:
text_content += part["text"]
answer_parts.append(part["text"])
elif "thought" in part:
thought_summary_parts.append(part["thought"])
elif "thoughtSummary" in part:
thought_summary_parts.append(part["thoughtSummary"])
elif "functionCall" in part:
if parsed_tool_calls is None:
parsed_tool_calls = []
@ -445,12 +440,27 @@ class GeminiAdapter(BaseAdapter):
result = part["codeExecutionResult"]
if result.get("outcome") == "OK":
output = result.get("output", "")
text_content += f"\n[代码执行结果]:\n{output}\n"
answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n")
else:
text_content += (
answer_parts.append(
f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
)
if thought_summary_parts:
full_thought_summary = "\n".join(thought_summary_parts).strip()
full_answer = "".join(answer_parts).strip()
formatted_parts = []
if full_thought_summary:
formatted_parts.append(f"🤔 **思考过程**\n\n{full_thought_summary}")
if full_answer:
separator = "\n\n---\n\n" if full_thought_summary else ""
formatted_parts.append(f"{separator}✅ **回答**\n\n{full_answer}")
text_content = "".join(formatted_parts)
else:
text_content = "".join(answer_parts)
usage_info = response_json.get("usageMetadata")
grounding_metadata_obj = None

View File

@ -1,16 +1,19 @@
"""
LLM 服务的高级 API 接口 - 便捷函数入口
LLM 服务的高级 API 接口 - 便捷函数入口 (无状态)
"""
from pathlib import Path
from typing import Any
from typing import Any, TypeVar
from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel
from zhenxun.services.log import logger
from .config import CommonOverrides
from .config.generation import create_generation_config_from_kwargs
from .manager import get_model_instance
from .session import AI
from .tools.manager import tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -18,37 +21,53 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelName,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
T = TypeVar("T", bound=BaseModel)
async def chat(
message: str | LLMMessage | list[LLMContentPart],
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
tools: list[LLMTool] | None = None,
instruction: str | None = None,
tools: list[dict[str, Any] | str] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
聊天对话便捷函数
无状态的聊天对话便捷函数通过临时的AI会话实例与LLM模型交互
参数:
message: 用户输入的消息
model: 要使用的模型名称
tools: 本次对话可用的工具列表
tool_choice: 强制模型使用的工具
**kwargs: 传递给模型的其他参数
message: 用户输入的消息内容支持多种格式
model: 要使用的模型名称如果为None则使用默认模型
instruction: 系统指令用于指导AI的行为和回复风格
tools: 可用的工具列表支持字典配置或字符串标识符
tool_choice: 工具选择策略控制AI如何选择和使用工具
**kwargs: 额外的生成配置参数会被转换为LLMGenerationConfig
返回:
LLMResponse: 模型的完整响应可能包含文本或工具调用请求
LLMResponse: 包含AI回复内容使用信息和工具调用等的完整响应对象
"""
ai = AI()
return await ai.chat(
message, model=model, tools=tools, tool_choice=tool_choice, **kwargs
)
try:
config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
ai_session = AI()
return await ai_session.chat(
message,
model=model,
instruction=instruction,
tools=tools,
tool_choice=tool_choice,
config=config,
)
except LLMException:
raise
except Exception as e:
logger.error(f"执行 chat 函数失败: {e}", e=e)
raise LLMException(f"聊天执行失败: {e}", cause=e)
async def code(
@ -57,143 +76,68 @@ async def code(
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
) -> LLMResponse:
"""
代码执行便捷函数
无状态的代码执行便捷函数支持在沙箱环境中执行代码
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
prompt: 代码执行的提示词描述要执行的代码任务
model: 要使用的模型名称默认使用Gemini/gemini-2.0-flash
timeout: 代码执行超时时间防止长时间运行的代码阻塞
**kwargs: 额外的生成配置参数
返回:
dict[str, Any]: 包含执行结果的字典
LLMResponse: 包含代码执行结果的完整响应对象
"""
ai = AI()
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
resolved_model = model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
final_config = config.to_dict()
final_config.update(kwargs)
return await chat(prompt, model=resolved_model, **final_config)
async def search(
query: str | UniMessage,
query: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str = "",
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
**kwargs: Any,
) -> dict[str, Any]:
) -> LLMResponse:
"""
信息搜索便捷函数
无状态的信息搜索便捷函数利用搜索工具获取实时信息
参数:
query: 搜索查询内容
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
query: 搜索查询内容支持多种输入格式
model: 要使用的模型名称如果为None则使用默认模型
instruction: 搜索任务的系统指令指导AI如何处理搜索结果
**kwargs: 额外的生成配置参数
返回:
dict[str, Any]: 包含搜索结果的字典
LLMResponse: 包含搜索结果和AI整合回复的完整响应对象
"""
ai = AI()
return await ai.search(query, model=model, instruction=instruction, **kwargs)
logger.debug("执行无状态 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
final_config = search_config.to_dict()
final_config.update(kwargs)
async def analyze(
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""
内容分析便捷函数
参数:
message: 要分析的消息内容
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
ai = AI()
return await ai.analyze(
message,
instruction=instruction,
return await chat(
query,
model=model,
use_tools=use_tools,
tool_config=tool_config,
**kwargs,
instruction=instruction,
**final_config,
)
async def analyze_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""
多模态分析便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 分析指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def search_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
多模态搜索便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 搜索指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
ai = AI()
return await ai.search(message, model=model, instruction=instruction, **kwargs)
async def embed(
texts: list[str] | str,
*,
@ -202,140 +146,104 @@ async def embed(
**kwargs: Any,
) -> list[list[float]]:
"""
文本嵌入便捷函数
无状态的文本嵌入便捷函数将文本转换为向量表示
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
texts: 要生成嵌入的文本内容支持单个字符串或字符串列表
model: 要使用的嵌入模型名称如果为None则使用默认模型
task_type: 嵌入任务类型影响向量的优化方向如检索分类等
**kwargs: 额外的模型配置参数
返回:
list[list[float]]: 文本的嵌入向量列表
list[list[float]]: 文本对应的嵌入向量列表每个向量为浮点数列表
"""
ai = AI()
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
async def pipeline_chat(
message: UniMessage | str | list[LLMContentPart],
model_chain: list[ModelName],
*,
initial_instruction: str = "",
final_instruction: str = "",
**kwargs: Any,
) -> LLMResponse:
"""
AI模型链式调用前一个模型的输出作为下一个模型的输入
参数:
message: 初始输入消息支持多模态
model_chain: 模型名称列表
initial_instruction: 第一个模型的系统指令
final_instruction: 最后一个模型的系统指令
**kwargs: 传递给模型实例的其他参数
返回:
LLMResponse: 最后一个模型的响应结果
"""
if not model_chain:
raise ValueError("模型链`model_chain`不能为空。")
current_content: str | list[LLMContentPart]
if isinstance(message, UniMessage):
current_content = await unimsg_to_llm_parts(message)
elif isinstance(message, str):
current_content = message
elif isinstance(message, list):
current_content = message
else:
raise TypeError(f"不支持的消息类型: {type(message)}")
final_response: LLMResponse | None = None
for i, model_name in enumerate(model_chain):
if not model_name:
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
is_first_step = i == 0
is_last_step = i == len(model_chain) - 1
messages_for_step: list[LLMMessage] = []
instruction_for_step = ""
if is_first_step and initial_instruction:
instruction_for_step = initial_instruction
elif is_last_step and final_instruction:
instruction_for_step = final_instruction
if instruction_for_step:
messages_for_step.append(LLMMessage.system(instruction_for_step))
messages_for_step.append(LLMMessage.user(current_content))
logger.info(
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
f"使用模型 '{model_name}' 进行处理..."
)
try:
async with await get_model_instance(model_name, **kwargs) as model:
response = await model.generate_response(messages_for_step)
final_response = response
current_content = response.text.strip()
if not current_content and not is_last_step:
logger.warning(
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
)
break
except Exception as e:
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
raise LLMException(
f"流水线在模型 '{model_name}' 处执行失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
try:
async with await get_model_instance(model) as model_instance:
return await model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
if final_response is None:
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
return final_response
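# 用法示例(示意草稿,仅作说明):插件侧对无状态 embed 便捷函数的最小调用。
# 假设:embed 可从 zhenxun.services.llm 顶层导入;"Gemini/text-embedding-004" 仅为示意的模型名。
from zhenxun.services.llm import embed


async def _embed_example() -> None:
    vectors = await embed(
        ["真寻今天吃什么", "明天天气如何"],
        model="Gemini/text-embedding-004",
    )
    # 每个输入文本对应一个浮点数向量
    assert len(vectors) == 2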
async def generate_structured(
message: str | LLMMessage | list[LLMContentPart],
response_model: type[T],
*,
model: ModelName = None,
instruction: str | None = None,
**kwargs: Any,
) -> T:
"""
无状态地生成结构化响应,并自动解析为指定的Pydantic模型。
参数:
message: 用户输入的消息内容支持多种格式
response_model: 用于解析和验证响应的Pydantic模型类
model: 要使用的模型名称如果为None则使用默认模型
instruction: 系统指令用于指导AI生成符合要求的结构化输出
**kwargs: 额外的生成配置参数
返回:
T: 解析后的Pydantic模型实例类型为response_model指定的类型
"""
try:
config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
ai_session = AI()
return await ai_session.generate_structured(
message,
response_model,
model=model,
instruction=instruction,
config=config,
)
except LLMException:
raise
except Exception as e:
logger.error(f"生成结构化响应失败: {e}", e=e)
raise LLMException(f"生成结构化响应失败: {e}", cause=e)
async def generate(
messages: list[LLMMessage],
*,
model: ModelName = None,
tools: list[LLMTool] | None = None,
tools: list[dict[str, Any] | str] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
根据完整的消息列表包括系统指令生成一次性响应
这是一个便捷的函数不使用或修改任何会话历史
根据完整的消息列表生成一次性响应。这是一个无状态的底层函数。
参数:
messages: 用于生成响应的完整消息列表
model: 要使用的模型名称
tools: 可用的工具列表
tool_choice: 工具选择策略
**kwargs: 传递给模型的其他参数
messages: 完整的消息历史列表包括系统指令用户消息和助手回复
model: 要使用的模型名称如果为None则使用默认模型
tools: 可用的工具列表支持字典配置或字符串标识符
tool_choice: 工具选择策略控制AI如何选择和使用工具
**kwargs: 额外的生成配置参数会覆盖默认配置
返回:
LLMResponse: 模型的完整响应对象
LLMResponse: 包含AI回复内容使用信息和工具调用等的完整响应对象
"""
try:
ai_instance = AI()
resolved_model_name = ai_instance._resolve_model_name(model)
final_config_dict = ai_instance._merge_config(kwargs)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
model, override_config=kwargs
) as model_instance:
return await model_instance.generate_response(
messages,
tools=tools,
tools=tools, # type: ignore
tool_choice=tool_choice,
)
except LLMException:
@ -343,3 +251,55 @@ async def generate(
except Exception as e:
logger.error(f"生成响应失败: {e}", e=e)
raise LLMException(f"生成响应失败: {e}", cause=e)
async def run_with_tools(
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str | None = None,
tools: list[str],
max_cycles: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""
无状态地执行一个带本地Python函数的LLM调用循环
参数:
message: 用户输入
model: 使用的模型
instruction: 系统指令
tools: 要使用的本地函数工具名称列表 (必须已通过 @function_tool 注册)
max_cycles: 最大工具调用循环次数
**kwargs: 额外的生成配置参数
返回:
LLMResponse: 包含最终回复的响应对象
"""
from .executor import ExecutionConfig, LLMToolExecutor
from .utils import normalize_to_llm_messages
messages = await normalize_to_llm_messages(message, instruction)
async with await get_model_instance(
model, override_config=kwargs
) as model_instance:
resolved_tools = await tool_provider_manager.get_function_tools(tools)
if not resolved_tools:
logger.warning(
"run_with_tools 未找到任何可用的本地函数工具,将作为普通聊天执行。"
)
return await model_instance.generate_response(messages, tools=None)
executor = LLMToolExecutor(model_instance)
config = ExecutionConfig(max_cycles=max_cycles)
final_history = await executor.run(messages, resolved_tools, config)
for msg in reversed(final_history):
if msg.role == "assistant":
text = msg.content if isinstance(msg.content, str) else str(msg.content)
return LLMResponse(text=text, tool_calls=msg.tool_calls)
raise LLMException(
"带工具的执行循环未能产生有效的助手回复。", code=LLMErrorCode.GENERATION_FAILED
)
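# 用法示例(示意草稿,仅作说明):通过 @function_tool 注册本地函数,并交给 run_with_tools 调用。
# 假设:function_tool 与 run_with_tools 可从 zhenxun.services.llm(或其 tools 子模块)导入;
# 天气数据为硬编码示意,实际实现应调用真实接口。
from pydantic import BaseModel, Field

from zhenxun.services.llm import function_tool, run_with_tools


class WeatherParams(BaseModel):
    city: str = Field(..., description="要查询的城市名称")


@function_tool(
    name="get_weather",
    description="查询指定城市的当前天气",
    params_model=WeatherParams,
)
def get_weather(params: WeatherParams) -> dict:
    # params 由执行器根据 LLM 生成的参数实例化
    return {"city": params.city, "weather": "晴", "temperature": 26}


async def _tool_example() -> str:
    response = await run_with_tools(
        "北京今天天气怎么样?",
        tools=["get_weather"],
        instruction="必要时调用工具获取实时信息",
    )
    return response.text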

View File

@ -14,7 +14,6 @@ from .generation import (
from .presets import CommonOverrides
from .providers import (
LLMConfig,
ToolConfig,
get_gemini_safety_threshold,
get_llm_config,
register_llm_configs,
@ -27,7 +26,6 @@ __all__ = [
"LLMConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"ToolConfig",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_gemini_safety_threshold",

View File

@ -7,6 +7,7 @@ from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump
from ..types.enums import ResponseFormat
from ..types.exceptions import LLMErrorCode, LLMException
@ -45,6 +46,9 @@ class ModelConfigOverride(BaseModel):
thinking_budget: float | None = Field(
default=None, ge=0.0, le=1.0, description="思考预算"
)
include_thoughts: bool | None = Field(
default=None, description="是否在响应中包含思维过程Gemini专用"
)
safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
response_modalities: list[str] | None = Field(
default=None, description="响应模态类型"
@ -62,22 +66,16 @@ class ModelConfigOverride(BaseModel):
def to_dict(self) -> dict[str, Any]:
"""转换为字典排除None值"""
model_data = model_dump(self, exclude_none=True)
result = {}
model_data = getattr(self, "model_dump", lambda: {})()
if not model_data:
model_data = {}
for field_name, _ in self.__class__.__dict__.get(
"model_fields", {}
).items():
value = getattr(self, field_name, None)
if value is not None:
model_data[field_name] = value
for key, value in model_data.items():
if value is not None:
if key == "custom_params" and isinstance(value, dict):
result.update(value)
else:
result[key] = value
if key == "custom_params" and isinstance(value, dict):
result.update(value)
else:
result[key] = value
return result
def merge_with_base_config(
@ -157,6 +155,10 @@ class LLMGenerationConfig(ModelConfigOverride):
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if self.custom_params:
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
params.update(custom_mapped)
if api_type == "gemini":
if (
self.response_format != ResponseFormat.JSON
@ -169,17 +171,28 @@ class LLMGenerationConfig(ModelConfigOverride):
if self.response_schema is not None and "responseSchema" not in params:
params["responseSchema"] = self.response_schema
if self.thinking_budget is not None:
params["thinkingBudget"] = self.thinking_budget
if self.thinking_budget is not None or self.include_thoughts is not None:
thinking_config = params.setdefault("thinkingConfig", {})
if self.thinking_budget is not None:
max_budget = 24576
budget_value = int(self.thinking_budget * max_budget)
thinking_config["thinkingBudget"] = budget_value
logger.debug(
f"已将 thinking_budget (float: {self.thinking_budget}) "
f"转换为 Gemini API 的整数格式: {budget_value}"
)
if self.include_thoughts is not None:
thinking_config["includeThoughts"] = self.include_thoughts
logger.debug(f"已设置 includeThoughts: {self.include_thoughts}")
if self.safety_settings is not None:
params["safetySettings"] = self.safety_settings
if self.response_modalities is not None:
params["responseModalities"] = self.response_modalities
if self.custom_params:
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
params.update(custom_mapped)
logger.debug(f"{api_type}转换配置参数: {len(params)}个参数")
return params
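# 换算示例(示意草稿,仅作说明):thinking_budget 浮点值到 Gemini 整数预算的映射。
thinking_budget = 0.5         # 配置中的相对预算 (0.0 ~ 1.0)
max_budget = 24576            # 上文代码使用的 Gemini 整数预算上限
budget_value = int(thinking_budget * max_budget)
assert budget_value == 12288  # 最终写入 thinkingConfig["thinkingBudget"]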

View File

@ -5,33 +5,19 @@ LLM 提供商配置管理
"""
from functools import lru_cache
import json
import sys
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ..core import key_store
from ..tools import tool_provider_manager
from ..types.models import ModelDetail, ProviderConfig
class ToolConfig(BaseModel):
"""MCP类型工具的配置定义"""
type: str = "mcp"
name: str = Field(..., description="工具的唯一名称标识")
description: str | None = Field(None, description="工具功能的描述")
mcp_config: dict[str, Any] | BaseModel = Field(
..., description="MCP服务器的特定配置"
)
AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
@ -57,9 +43,6 @@ class LLMConfig(BaseModel):
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
mcp_tools: list[ToolConfig] = Field(
default_factory=list, description="配置可用的外部MCP工具"
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
@ -218,33 +201,6 @@ def get_default_providers() -> list[dict[str, Any]]:
]
def get_default_mcp_tools() -> dict[str, Any]:
"""
获取默认的MCP工具配置用于在文件不存在时创建
包含了 baidu-map, Context7, sequential-thinking.
"""
return {
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
"description": "百度地图工具,提供地理编码、路线规划等功能。",
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
},
"Context7": {
"command": "npx",
"args": ["-y", "@upstash/context7-mcp@latest"],
"description": "Upstash 提供的上下文管理和记忆工具。",
},
}
}
def register_llm_configs():
"""注册 LLM 服务的配置项"""
logger.info("注册 LLM 服务的配置项")
@ -312,88 +268,9 @@ def register_llm_configs():
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
"""获取 LLM 配置实例,不再加载 MCP 工具配置"""
ai_config = get_ai_config()
llm_data_path = DATA_PATH / "llm"
mcp_tools_path = llm_data_path / "mcp_tools.json"
mcp_tools_list = []
mcp_servers_dict = {}
if not mcp_tools_path.exists():
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
llm_data_path.mkdir(parents=True, exist_ok=True)
default_mcp_config = get_default_mcp_tools()
try:
with mcp_tools_path.open("w", encoding="utf-8") as f:
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
except Exception as e:
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
mcp_servers_dict = {}
else:
try:
with mcp_tools_path.open("r", encoding="utf-8") as f:
mcp_data = json.load(f)
mcp_servers_dict = mcp_data.get("mcpServers", {})
if not isinstance(mcp_servers_dict, dict):
logger.warning(
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
f"将使用空配置。"
)
mcp_servers_dict = {}
except json.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
except Exception as e:
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
mcp_servers_dict = {}
if sys.platform == "win32":
logger.debug("检测到Windows平台正在调整MCP工具的npx命令...")
for name, config in mcp_servers_dict.items():
if isinstance(config, dict) and config.get("command") == "npx":
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
original_args = config.get("args", [])
config["command"] = "cmd"
config["args"] = ["/c", "npx", *original_args]
if mcp_servers_dict:
mcp_tools_list = [
{
"name": name,
"type": "mcp",
"description": config.get("description", f"MCP tool for {name}"),
"mcp_config": config,
}
for name, config in mcp_servers_dict.items()
if isinstance(config, dict)
]
from ..tools.registry import tool_registry
for tool_dict in mcp_tools_list:
if isinstance(tool_dict, dict):
tool_name = tool_dict.get("name")
if not tool_name:
continue
config_model = tool_registry.get_mcp_config_model(tool_name)
if not config_model:
logger.debug(
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
f"将跳过特定配置验证,直接使用原始配置字典。"
)
continue
mcp_config_data = tool_dict.get("mcp_config", {})
try:
parsed_mcp_config = parse_as(config_model, mcp_config_data)
tool_dict["mcp_config"] = parsed_mcp_config
except Exception as e:
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
@ -401,7 +278,6 @@ def get_llm_config() -> LLMConfig:
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
"mcp_tools": mcp_tools_list,
}
return parse_as(LLMConfig, config_data)
@ -504,12 +380,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
async def _init_llm_config_on_startup():
"""
在服务启动时主动调用一次 get_llm_config key_store.initialize
以触发必要的初始化操作
并预热工具提供者管理器
"""
logger.info("正在初始化 LLM 配置并加载密钥状态...")
try:
get_llm_config()
await key_store.initialize()
logger.info("LLM 配置和密钥状态初始化完成。")
logger.debug("LLM 配置和密钥状态初始化完成。")
logger.debug("正在预热 LLM 工具提供者管理器...")
await tool_provider_manager.initialize()
logger.debug("LLM 工具提供者管理器预热完成。")
except Exception as e:
logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e)

View File

@ -335,10 +335,10 @@ async def with_smart_retry(
latency = (time.monotonic() - start_time) * 1000
if key_store and isinstance(result, tuple) and len(result) == 2:
final_result, api_key_used = result
_, api_key_used = result
if api_key_used:
await key_store.record_success(api_key_used, latency)
return final_result
return result
else:
return result

View File

@ -0,0 +1,193 @@
"""
LLM 轻量级工具执行器
提供驱动 LLM 与本地函数工具之间交互的核心循环
"""
import asyncio
from enum import Enum
import json
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.pydantic_compat import model_dump
from .service import LLMModel
from .types import (
LLMErrorCode,
LLMException,
LLMMessage,
ToolExecutable,
ToolResult,
)
class ExecutionConfig(BaseModel):
"""
轻量级执行器的配置
"""
max_cycles: int = Field(default=5, description="工具调用循环的最大次数。")
class ToolErrorType(str, Enum):
"""结构化工具错误的类型枚举。"""
TOOL_NOT_FOUND = "ToolNotFound"
INVALID_ARGUMENTS = "InvalidArguments"
EXECUTION_ERROR = "ExecutionError"
USER_CANCELLATION = "UserCancellation"
class ToolErrorResult(BaseModel):
"""一个结构化的工具执行错误模型,用于返回给 LLM。"""
error_type: ToolErrorType = Field(..., description="错误的类型。")
message: str = Field(..., description="对错误的详细描述。")
is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。")
def model_dump(self, **kwargs):
return model_dump(self, **kwargs)
def _is_exception_retryable(e: Exception) -> bool:
"""判断一个异常是否应该触发重试。"""
if isinstance(e, LLMException):
retryable_codes = {
LLMErrorCode.API_REQUEST_FAILED,
LLMErrorCode.API_TIMEOUT,
LLMErrorCode.API_RATE_LIMITED,
}
return e.code in retryable_codes
return True
class LLMToolExecutor:
"""
一个通用的执行器负责驱动 LLM 与工具之间的多轮交互
"""
def __init__(self, model: LLMModel):
self.model = model
async def run(
self,
messages: list[LLMMessage],
tools: dict[str, ToolExecutable],
config: ExecutionConfig | None = None,
) -> list[LLMMessage]:
"""
执行完整的思考-行动循环
"""
effective_config = config or ExecutionConfig()
execution_history = list(messages)
for i in range(effective_config.max_cycles):
response = await self.model.generate_response(
execution_history, tools=tools
)
assistant_message = LLMMessage(
role="assistant",
content=response.text,
tool_calls=response.tool_calls,
)
execution_history.append(assistant_message)
if not response.tool_calls:
logger.info("✅ LLMToolExecutor模型未请求工具调用执行结束。")
return execution_history
logger.info(
f"🛠️ LLMToolExecutor模型请求并行调用 {len(response.tool_calls)} 个工具"
)
tool_results = await self._execute_tools_parallel_safely(
response.tool_calls,
tools,
)
execution_history.extend(tool_results)
raise LLMException(
f"超过最大工具调用循环次数 ({effective_config.max_cycles})。",
code=LLMErrorCode.GENERATION_FAILED,
)
async def _execute_single_tool_safely(
self, tool_call: Any, available_tools: dict[str, ToolExecutable]
) -> tuple[Any, ToolResult]:
"""安全地执行单个工具调用。"""
tool_name = tool_call.function.name
arguments = {}
try:
if tool_call.function.arguments:
arguments = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
error_result = ToolErrorResult(
error_type=ToolErrorType.INVALID_ARGUMENTS,
message=f"参数解析失败: {e}",
is_retryable=False,
)
return tool_call, ToolResult(output=model_dump(error_result))
try:
executable = available_tools.get(tool_name)
if not executable:
raise LLMException(
f"Tool '{tool_name}' not found.",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
@Retry.simple(
stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None
)
async def execute_with_retry():
return await executable.execute(**arguments)
execution_result = await execute_with_retry()
if execution_result is None:
raise LLMException("工具执行在多次重试后仍然失败。")
return tool_call, execution_result
except Exception as e:
error_type = ToolErrorType.EXECUTION_ERROR
is_retryable = _is_exception_retryable(e)
if (
isinstance(e, LLMException)
and e.code == LLMErrorCode.CONFIGURATION_ERROR
):
error_type = ToolErrorType.TOOL_NOT_FOUND
is_retryable = False
error_result = ToolErrorResult(
error_type=error_type, message=str(e), is_retryable=is_retryable
)
return tool_call, ToolResult(output=model_dump(error_result))
async def _execute_tools_parallel_safely(
self,
tool_calls: list[Any],
available_tools: dict[str, ToolExecutable],
) -> list[LLMMessage]:
"""并行执行所有工具调用,并对每个调用的错误进行隔离。"""
if not tool_calls:
return []
tasks = [
self._execute_single_tool_safely(call, available_tools)
for call in tool_calls
]
results = await asyncio.gather(*tasks)
tool_messages = [
LLMMessage.tool_response(
tool_call_id=original_call.id,
function_name=original_call.function.name,
result=result.output,
)
for original_call, result in results
]
return tool_messages
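# 用法示例(示意草稿,仅作说明):直接驱动 LLMToolExecutor 执行思考-行动循环。
# 假设:get_model_instance 与 tool_provider_manager 的导入路径如下;
# "get_weather" 为已通过 @function_tool 注册的示意工具。
async def _executor_example() -> list[LLMMessage]:
    from .manager import get_model_instance
    from .tools.manager import tool_provider_manager

    messages = [LLMMessage.user("帮我查一下北京的天气")]
    tools = await tool_provider_manager.get_function_tools(["get_weather"])
    async with await get_model_instance("Gemini/gemini-2.0-flash") as model:
        executor = LLMToolExecutor(model)
        config = ExecutionConfig(max_cycles=3)
        # 返回完整消息历史,最后一条 assistant 消息即为最终回复
        return await executor.run(messages, tools, config)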

View File

@ -86,14 +86,23 @@ def _cache_model(cache_key: str, model: LLMModel):
def clear_model_cache():
"""清空模型缓存"""
"""
清空模型缓存释放所有缓存的模型实例
用于在内存不足或需要强制重新加载模型配置时清理缓存
"""
global _model_cache
_model_cache.clear()
logger.info("已清空模型缓存")
def get_cache_stats() -> dict[str, Any]:
"""获取缓存统计信息"""
"""
获取模型缓存的统计信息
返回:
dict[str, Any]: 包含缓存大小最大容量TTL和已缓存模型列表的统计信息
"""
return {
"cache_size": len(_model_cache),
"max_cache_size": _max_cache_size,
@ -169,7 +178,13 @@ def find_model_config(
def list_available_models() -> list[dict[str, Any]]:
"""列出所有配置的可用模型"""
"""
列出所有配置的可用模型及其详细信息
返回:
list[dict[str, Any]]: 模型信息列表每个字典包含提供商名称模型名称
能力信息是否为嵌入模型等详细信息
"""
providers = get_configured_providers()
model_list = []
for provider in providers:
@ -215,7 +230,13 @@ def list_model_identifiers() -> dict[str, list[str]]:
def list_embedding_models() -> list[dict[str, Any]]:
"""列出所有配置的嵌入模型"""
"""
列出所有配置的嵌入模型
返回:
list[dict[str, Any]]: 嵌入模型信息列表从所有可用模型中筛选出
支持嵌入功能的模型
"""
all_models = list_available_models()
return [model for model in all_models if model.get("is_embedding_model", False)]
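# 用法示例(示意草稿,仅作说明):查询模型缓存统计与嵌入模型列表。
def _print_model_overview() -> None:
    stats = get_cache_stats()
    print(f"已缓存模型: {stats['cache_size']}/{stats['max_cache_size']}")
    for info in list_embedding_models():
        # 每项为包含提供商、模型名与能力信息的字典,具体字段以实际实现为准
        print(info)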

View File

@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any
from .types import LLMMessage
class BaseMemory(ABC):
"""
记忆系统的抽象基类
定义了任何记忆后端都必须实现的接口
"""
@abstractmethod
async def get_history(self, session_id: str) -> list[LLMMessage]:
"""根据会话ID获取历史记录。"""
raise NotImplementedError
@abstractmethod
async def add_message(self, session_id: str, message: LLMMessage) -> None:
"""向指定会话添加一条消息。"""
raise NotImplementedError
@abstractmethod
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
"""向指定会话添加多条消息。"""
raise NotImplementedError
@abstractmethod
async def clear_history(self, session_id: str) -> None:
"""清空指定会话的历史记录。"""
raise NotImplementedError
class InMemoryMemory(BaseMemory):
"""
一个简单的默认的内存记忆后端
将历史记录存储在进程内存中的字典里
"""
def __init__(self, **kwargs: Any):
self._history: dict[str, list[LLMMessage]] = defaultdict(list)
async def get_history(self, session_id: str) -> list[LLMMessage]:
return self._history.get(session_id, []).copy()
async def add_message(self, session_id: str, message: LLMMessage) -> None:
self._history[session_id].append(message)
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
self._history[session_id].extend(messages)
async def clear_history(self, session_id: str) -> None:
if session_id in self._history:
del self._history[session_id]
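# 扩展示例(示意草稿,仅作说明):基于 InMemoryMemory 的自定义记忆后端,
# 限制每个会话保留的消息条数;max_messages 为此处假设的参数。
class TruncatedMemory(InMemoryMemory):
    def __init__(self, max_messages: int = 20, **kwargs: Any):
        super().__init__(**kwargs)
        self._max_messages = max_messages

    async def add_message(self, session_id: str, message: LLMMessage) -> None:
        await super().add_message(session_id, message)
        # 仅保留最近 max_messages 条历史
        self._history[session_id] = self._history[session_id][-self._max_messages:]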

View File

@ -6,9 +6,10 @@ LLM 模型实现类
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
import json
from typing import Any
from typing import Any, TypeVar
from pydantic import BaseModel
from zhenxun.services.log import logger
@ -28,33 +29,25 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelDetail,
ProviderConfig,
ToolExecutable,
)
from .types.capabilities import ModelCapabilities, ModelModality
from .utils import _sanitize_request_body_for_logging
T = TypeVar("T", bound=BaseModel)
class LLMModelBase(ABC):
"""LLM模型抽象基类"""
@abstractmethod
async def generate_text(
self,
prompt: str,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
pass
@abstractmethod
async def generate_response(
self,
messages: list[LLMMessage],
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tools: dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
@ -311,7 +304,7 @@ class LLMModel(LLMModelBase):
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tools: dict[str, ToolExecutable] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
@ -339,7 +332,7 @@ class LLMModel(LLMModelBase):
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tools: dict[str, ToolExecutable] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
@ -428,66 +421,23 @@ class LLMModel(LLMModelBase):
if self._is_closed:
raise RuntimeError(f"LLMModel实例已关闭: {self}")
async def generate_text(
self,
prompt: str,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
self._check_not_closed()
messages: list[LLMMessage] = []
if history:
for msg in history:
role = msg.get("role", "user")
content_text = msg.get("content", "")
messages.append(LLMMessage(role=role, content=content_text))
messages.append(LLMMessage.user(prompt))
model_fields = getattr(LLMGenerationConfig, "model_fields", {})
request_specific_config_dict = {
k: v for k, v in kwargs.items() if k in model_fields
}
request_specific_config = None
if request_specific_config_dict:
request_specific_config = LLMGenerationConfig(
**request_specific_config_dict
)
for key in request_specific_config_dict:
kwargs.pop(key, None)
response = await self.generate_response(
messages,
config=request_specific_config,
**kwargs,
)
return response.text
async def generate_response(
self,
messages: list[LLMMessage],
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tools: dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应"""
"""
生成高级响应
此方法现在只执行 *单次* LLM API 调用,并将结果(包括工具调用请求)返回。
"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
from .config.generation import create_generation_config_from_kwargs
adapter = get_adapter_for_api_type(self.api_type)
if not adapter:
raise LLMException(
f"未找到适用于 API 类型 '{self.api_type}' 的适配器",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
final_request_config = self._generation_config or LLMGenerationConfig()
if kwargs:
kwargs_config = create_generation_config_from_kwargs(**kwargs)
@ -500,43 +450,19 @@ class LLMModel(LLMModelBase):
merged_dict.update(config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
adapter = get_adapter_for_api_type(self.api_type)
http_client = await self._get_http_client()
async with AsyncExitStack() as stack:
activated_tools = []
if tools:
for tool in tools:
if tool.type == "mcp" and callable(tool.mcp_session):
func_obj = getattr(tool.mcp_session, "func", None)
tool_name = (
getattr(func_obj, "__name__", "unknown")
if func_obj
else "unknown"
)
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
response, _ = await self._execute_with_smart_retry(
adapter,
messages,
final_request_config,
tools,
tool_choice,
http_client,
)
active_session = await stack.enter_async_context(
tool.mcp_session()
)
activated_tools.append(
LLMTool.from_mcp_session(
session=active_session, annotations=tool.annotations
)
)
else:
activated_tools.append(tool)
llm_response = await self._execute_with_smart_retry(
adapter,
messages,
final_request_config,
activated_tools if activated_tools else None,
tool_choice,
http_client,
)
return llm_response
return response
async def generate_embeddings(
self,

View File

@ -5,17 +5,27 @@ LLM 服务 - 会话客户端
"""
import copy
from dataclasses import dataclass
from typing import Any
from dataclasses import dataclass, field
import json
from typing import Any, TypeVar
import uuid
from jinja2 import Environment
from nonebot.compat import type_validate_json
from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel, ValidationError
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_json_schema
from .config import CommonOverrides, LLMGenerationConfig
from .config import (
CommonOverrides,
LLMGenerationConfig,
)
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .memory import BaseMemory, InMemoryMemory
from .tools.manager import tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -23,67 +33,93 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelName,
ResponseFormat,
ToolExecutable,
ToolProvider,
)
from .utils import unimsg_to_llm_parts
from .utils import normalize_to_llm_messages
T = TypeVar("T", bound=BaseModel)
jinja_env = Environment(autoescape=False)
@dataclass
class AIConfig:
"""AI配置类 - 简化版本"""
"""AI配置类 - [重构后] 简化版本"""
model: ModelName = None
default_embedding_model: ModelName = None
temperature: float | None = None
max_tokens: int | None = None
enable_cache: bool = False
enable_code: bool = False
enable_search: bool = False
timeout: int | None = None
enable_gemini_json_mode: bool = False
enable_gemini_thinking: bool = False
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
tool_providers: list[ToolProvider] = field(default_factory=list)
def __post_init__(self):
"""初始化后从配置中读取默认值"""
ai_config = get_ai_config()
if self.model is None:
self.model = ai_config.get("default_model_name")
if self.timeout is None:
self.timeout = ai_config.get("timeout", 180)
class AI:
"""统一的AI服务类 - 平衡设计版本
提供三层API
1. 简单方法ai.chat(), ai.code(), ai.search()
2. 标准方法ai.analyze() 支持复杂参数
3. 高级方法通过get_model_instance()直接访问
"""
统一的AI服务类 - 提供了带记忆的会话接口。
不再执行自主工具循环:当LLM返回工具调用时,会直接将请求返回给调用者。
"""
def __init__(
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
self,
session_id: str | None = None,
config: AIConfig | None = None,
memory: BaseMemory | None = None,
default_generation_config: LLMGenerationConfig | None = None,
):
"""
初始化AI服务
参数:
session_id: 唯一的会话ID用于隔离记忆
config: AI 配置.
history: 可选的初始对话历史.
memory: 可选的自定义记忆后端如果为None则使用默认的InMemoryMemory
default_generation_config: (新增) 此AI实例的默认生成配置
"""
self.session_id = session_id or str(uuid.uuid4())
self.config = config or AIConfig()
self.history = history or []
self.memory = memory or InMemoryMemory()
self.default_generation_config = (
default_generation_config or LLMGenerationConfig()
)
def clear_history(self):
"""清空当前会话的历史记录"""
self.history = []
logger.info("AI session history cleared.")
global_providers = tool_provider_manager._providers
config_providers = self.config.tool_providers
self._tool_providers = list(dict.fromkeys(global_providers + config_providers))
async def clear_history(self):
"""清空当前会话的历史记录。"""
await self.memory.clear_history(self.session_id)
logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})")
async def add_user_message_to_history(
self, message: str | LLMMessage | list[LLMContentPart]
):
"""
将一条用户消息标准化并添加到会话历史中
参数:
message: 用户消息内容
"""
user_message = await self._normalize_input_to_message(message)
await self.memory.add_message(self.session_id, user_message)
async def add_assistant_response_to_history(self, response_text: str):
"""
将助手的文本回复添加到会话历史中
参数:
response_text: 助手的回复文本
"""
assistant_message = LLMMessage.assistant_text_response(response_text)
await self.memory.add_message(self.session_id, assistant_message)
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
@ -121,83 +157,122 @@ class AI:
sanitized_message.content = new_content_parts
return sanitized_message
async def _normalize_input_to_message(
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
) -> LLMMessage:
"""
[重构后] 内部辅助方法:将各种输入类型统一转换为单个 LLMMessage 对象。
它调用共享的工具函数,并提取最后一条消息(通常是用户输入)。
"""
messages = await normalize_to_llm_messages(message)
if not messages:
raise LLMException(
"无法将输入标准化为有效的消息。", code=LLMErrorCode.CONFIGURATION_ERROR
)
return messages[-1]
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str | None = None,
template_vars: dict[str, Any] | None = None,
preserve_media_in_history: bool | None = None,
tools: list[LLMTool] | None = None,
tools: list[dict[str, Any] | str] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""
进行一次聊天对话支持工具调用
此方法会自动使用和更新会话内的历史记录
核心交互方法:管理会话历史,并执行单次LLM调用。
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留用于深度多轮媒体分析
- False: 不保留替换为占位符提高效率
- None (默认): 使用AI实例配置的默认值
tools: 本次对话可用的工具列表
tool_choice: 强制模型使用的工具
**kwargs: 传递给模型的其他生成参数
message: 用户输入的消息内容支持文本UniMessageLLMMessage或
内容部分列表
model: 要使用的模型名称如果为None则使用配置中的默认模型
instruction: 本次调用的特定系统指令会与全局指令合并
template_vars: 模板变量字典用于在指令中进行变量替换
preserve_media_in_history: 是否在历史记录中保留媒体内容
None时使用默认配置
tools: 可用的工具列表或工具字典支持临时工具和预配置工具
tool_choice: 工具选择策略控制AI如何选择和使用工具
config: 生成配置对象用于覆盖默认的生成参数
返回:
LLMResponse: 模型的完整响应可能包含文本或工具调用请求
LLMResponse: 包含AI回复工具调用请求使用信息等的完整响应对象
"""
current_message: LLMMessage
if isinstance(message, str):
current_message = LLMMessage.user(message)
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
current_message = LLMMessage.user(message)
elif isinstance(message, LLMMessage):
current_message = message
else:
raise LLMException(
f"AI.chat 不支持的消息类型: {type(message)}. "
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
code=LLMErrorCode.API_REQUEST_FAILED,
current_message = await self._normalize_input_to_message(message)
messages_for_run = []
final_instruction = instruction
if final_instruction and template_vars:
try:
template = jinja_env.from_string(final_instruction)
final_instruction = template.render(**template_vars)
logger.debug(f"渲染后的系统指令: {final_instruction}")
except Exception as e:
logger.error(f"渲染系统指令模板失败: {e}", e=e)
if final_instruction:
messages_for_run.append(LLMMessage.system(final_instruction))
current_history = await self.memory.get_history(self.session_id)
messages_for_run.extend(current_history)
messages_for_run.append(current_message)
try:
resolved_model_name = self._resolve_model_name(model or self.config.model)
final_config = model_copy(self.default_generation_config, deep=True)
if config:
update_dict = model_dump(config, exclude_unset=True)
final_config = model_copy(final_config, update=update_dict)
ad_hoc_tools = None
if tools:
if isinstance(tools, dict):
ad_hoc_tools = tools
else:
ad_hoc_tools = await self._resolve_tools(tools)
async with await get_model_instance(
resolved_model_name,
override_config=final_config.to_dict(),
) as model_instance:
response = await model_instance.generate_response(
messages_for_run, tools=ad_hoc_tools, tool_choice=tool_choice
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
user_msg_to_store = (
current_message
if should_preserve
else self._sanitize_message_for_history(current_message)
)
assistant_response_msg = LLMMessage.assistant_text_response(response.text)
if response.tool_calls:
assistant_response_msg = LLMMessage.assistant_tool_calls(
response.tool_calls, response.text
)
await self.memory.add_messages(
self.session_id, [user_msg_to_store, assistant_response_msg]
)
final_messages = [*self.history, current_message]
return response
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="聊天失败",
config_overrides=kwargs,
llm_tools=tools,
tool_choice=tool_choice,
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(
LLMMessage(
role="assistant", content=response.text, tool_calls=response.tool_calls
except Exception as e:
raise (
e
if isinstance(e, LLMException)
else LLMException(f"聊天执行失败: {e}", cause=e)
)
)
return response
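# 用法示例(示意草稿,仅作说明):带记忆的 AI 会话在插件侧的典型调用方式。
# 假设:AI 可从 zhenxun.services.llm 顶层导入;session_id 通常取自用户或群组标识。
#
# from zhenxun.services.llm import AI
#
# async def _chat_example(user_id: str) -> str:
#     ai = AI(session_id=f"chat:{user_id}")
#     await ai.chat("我叫小鱼,请记住我的名字", instruction="你是真寻,回答要简短")
#     response = await ai.chat("我叫什么名字?")
#     # 历史由 memory 自动维护,第二轮调用可以读取到第一轮对话
#     return response.text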
async def code(
self,
@ -205,8 +280,8 @@ class AI:
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""
代码执行
@ -214,217 +289,120 @@ class AI:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
config: (可选) 覆盖默认的生成配置
返回:
dict[str, Any]: 包含执行结果的字典包含textcode_executions和success字段
LLMResponse: 包含执行结果的完整响应对象
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
code_config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
code_config.custom_params = code_config.custom_params or {}
code_config.custom_params["code_execution_timeout"] = timeout
messages = [LLMMessage.user(prompt)]
if config:
update_dict = model_dump(config, exclude_unset=True)
code_config = model_copy(code_config, update=update_dict)
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="代码执行失败",
config_overrides=kwargs,
base_config=config,
)
return {
"text": response.text,
"code_executions": response.code_executions or [],
"success": True,
}
return await self.chat(prompt, model=resolved_model, config=code_config)
async def search(
self,
query: str | UniMessage,
query: UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
template_vars: dict[str, Any] | None = None,
config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典包含textsourcesqueries和success字段
信息搜索的便捷入口原生支持多模态查询
"""
from nonebot_plugin_alconna.uniseg import UniMessage
logger.info("执行 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
if config:
update_dict = model_dump(config, exclude_unset=True)
search_config = model_copy(search_config, update=update_dict)
if isinstance(query, str):
messages = [LLMMessage.user(query)]
elif isinstance(query, UniMessage):
content_parts = await unimsg_to_llm_parts(query)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
messages = final_messages
else:
raise LLMException(
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="信息搜索失败",
config_overrides=kwargs,
base_config=config,
return await self.chat(
query,
model=model,
instruction=instruction,
template_vars=template_vars,
config=search_config,
)
result = {
"text": response.text,
"sources": [],
"queries": [],
"success": True,
}
if response.grounding_metadata:
result["sources"] = response.grounding_metadata.grounding_attributions or []
result["queries"] = response.grounding_metadata.web_search_queries or []
return result
async def analyze(
async def generate_structured(
self,
message: UniMessage | None,
message: str | LLMMessage | list[LLMContentPart],
response_model: type[T],
*,
instruction: str = "",
model: ModelName = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> LLMResponse:
instruction: str | None = None,
config: LLMGenerationConfig | None = None,
) -> T:
"""
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫
生成结构化响应并自动解析为指定的Pydantic模型
参数:
message: 要分析的消息内容支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
message: 用户输入的消息内容支持多种格式
response_model: 用于解析和验证响应的Pydantic模型类
model: 要使用的模型名称如果为None则使用配置中的默认模型
instruction: 本次调用的特定系统指令会与JSON Schema指令合并
config: 生成配置对象用于覆盖默认的生成参数
返回:
LLMResponse: 模型的完整响应结果
T: 解析后的Pydantic模型实例类型为response_model指定的类型
异常:
LLMException: 如果模型返回的不是有效的JSON或验证失败
"""
from nonebot_plugin_alconna.uniseg import UniMessage
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="内容分析失败",
config_overrides=kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
return response
async def _execute_generation(
self,
messages: list[LLMMessage],
model_name: ModelName,
error_message: str,
config_overrides: dict[str, Any],
llm_tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
)
final_config_dict = self._merge_config(
config_overrides, base_config=base_config
)
json_schema = model_json_schema(response_model)
except AttributeError:
json_schema = response_model.schema()
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
system_prompt = (
(f"{instruction}\n\n" if instruction else "")
+ "你必须严格按照以下 JSON Schema 格式进行响应。"
+ "不要包含任何额外的解释、注释或代码块标记,只返回纯粹的 JSON 对象。\n\n"
)
system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
final_config = model_copy(config) if config else LLMGenerationConfig()
final_config.response_format = ResponseFormat.JSON
final_config.response_schema = json_schema
response = await self.chat(
message, model=model, instruction=system_prompt, config=final_config
)
try:
return type_validate_json(response_model, response.text)
except ValidationError as e:
logger.error(f"LLM结构化输出验证失败: {e}", e=e)
raise LLMException(
"LLM返回的JSON未能通过结构验证。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": response.text, "validation_error": str(e)},
cause=e,
)
except Exception as e:
logger.error(f"{error_message}: {e}", e=e)
raise LLMException(f"{error_message}: {e}", cause=e)
logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
raise LLMException(
"解析LLM的JSON输出时失败。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": response.text},
cause=e,
)
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
@ -440,45 +418,6 @@ class AI:
code=LLMErrorCode.MODEL_NOT_FOUND,
)
def _merge_config(
self,
user_config: dict[str, Any],
base_config: LLMGenerationConfig | None = None,
) -> dict[str, Any]:
"""合并配置"""
final_config = {}
if base_config:
final_config.update(base_config.to_dict())
if self.config.temperature is not None:
final_config["temperature"] = self.config.temperature
if self.config.max_tokens is not None:
final_config["max_tokens"] = self.config.max_tokens
if self.config.enable_cache:
final_config["enable_caching"] = True
if self.config.enable_code:
final_config["enable_code_execution"] = True
if self.config.enable_search:
final_config["enable_grounding"] = True
if self.config.enable_gemini_json_mode:
final_config["response_mime_type"] = "application/json"
if self.config.enable_gemini_thinking:
final_config["thinking_budget"] = 0.8
if self.config.enable_gemini_safe_mode:
final_config["safety_settings"] = (
CommonOverrides.gemini_safe().safety_settings
)
if self.config.enable_gemini_multimodal:
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
if self.config.enable_gemini_grounding:
final_config["enable_grounding"] = True
final_config.update(user_config)
return final_config
async def embed(
self,
texts: list[str] | str,
@ -488,16 +427,19 @@ class AI:
**kwargs: Any,
) -> list[list[float]]:
"""
生成文本嵌入向量
生成文本嵌入向量将文本转换为数值向量表示
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
texts: 要生成嵌入的文本内容支持单个字符串或字符串列表
model: 嵌入模型名称如果为None则使用配置中的默认嵌入模型
task_type: 嵌入任务类型影响向量的优化方向如检索分类等
**kwargs: 传递给嵌入模型的额外参数
返回:
list[list[float]]: 文本的嵌入向量列表
list[list[float]]: 文本对应的嵌入向量列表每个向量为浮点数列表
异常:
LLMException: 如果嵌入生成失败或模型配置错误
"""
if isinstance(texts, str):
texts = [texts]
@ -530,3 +472,44 @@ class AI:
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
async def _resolve_tools(
self,
tool_configs: list[Any],
) -> dict[str, ToolExecutable]:
"""
使用注入的 ToolProvider 异步解析 ad-hoc临时工具配置
返回一个从工具名称到可执行对象的字典
"""
resolved: dict[str, ToolExecutable] = {}
for config in tool_configs:
name = config if isinstance(config, str) else config.get("name")
if not name:
raise LLMException(
"工具配置字典必须包含 'name' 字段。",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
if isinstance(config, str):
config_dict = {"name": name, "type": "function"}
elif isinstance(config, dict):
config_dict = config
else:
raise TypeError(f"不支持的工具配置类型: {type(config)}")
executable = None
for provider in self._tool_providers:
executable = await provider.get_tool_executable(name, config_dict)
if executable:
break
if not executable:
raise LLMException(
f"没有为 ad-hoc 工具 '{name}' 找到合适的提供者。",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
resolved[name] = executable
return resolved
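# 用法示例(示意草稿,仅作说明):向 AI.chat 传入 ad-hoc 工具。
# 假设:"get_weather" 已通过 @function_tool 注册;字符串形式的工具默认按 function 类型解析。
async def _ad_hoc_tools_example() -> LLMResponse:
    ai = AI(session_id="demo")
    response = await ai.chat(
        "上海今天适合出门吗?",
        tools=["get_weather"],
        tool_choice="auto",
    )
    # 本次重构后 chat 不再自动执行工具:若模型请求调用工具,
    # response.tool_calls 将包含请求,由调用方自行执行并继续对话
    return response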

View File

@ -2,6 +2,12 @@
工具模块导出
"""
from .registry import tool_registry
from .manager import tool_provider_manager
__all__ = ["tool_registry"]
function_tool = tool_provider_manager.function_tool
__all__ = [
"function_tool",
"tool_provider_manager",
]

View File

@ -0,0 +1,293 @@
"""
工具提供者管理器
负责注册、生命周期管理(包括懒加载)和统一提供所有工具。
"""
import asyncio
from collections.abc import Callable
import inspect
from typing import Any
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_json_schema
from ..types import ToolExecutable, ToolProvider
from ..types.models import ToolDefinition, ToolResult
class FunctionExecutable(ToolExecutable):
"""一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。"""
def __init__(
self,
func: Callable,
name: str,
description: str,
params_model: type[BaseModel] | None,
):
self._func = func
self._name = name
self._description = description
self._params_model = params_model
async def get_definition(self) -> ToolDefinition:
if not self._params_model:
return ToolDefinition(
name=self._name,
description=self._description,
parameters={"type": "object", "properties": {}},
)
schema = model_json_schema(self._params_model)
return ToolDefinition(
name=self._name,
description=self._description,
parameters={
"type": "object",
"properties": schema.get("properties", {}),
"required": schema.get("required", []),
},
)
async def execute(self, **kwargs: Any) -> ToolResult:
raw_result: Any
if self._params_model:
try:
params_instance = self._params_model(**kwargs)
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(params_instance)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(params_instance)
)
except Exception as e:
logger.error(
f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
)
raise
else:
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(**kwargs)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(**kwargs)
)
return ToolResult(output=raw_result, display_content=str(raw_result))
class BuiltinFunctionToolProvider(ToolProvider):
"""一个内置的 ToolProvider用于处理通过装饰器注册的函数。"""
def __init__(self):
self._functions: dict[str, dict[str, Any]] = {}
def register(
self,
name: str,
func: Callable,
description: str,
params_model: type[BaseModel] | None,
):
self._functions[name] = {
"func": func,
"description": description,
"params_model": params_model,
}
async def initialize(self) -> None:
pass
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
executables = {}
for name, info in self._functions.items():
executables[name] = FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return executables
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
if config.get("type") == "function" and name in self._functions:
info = self._functions[name]
return FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return None
class ToolProviderManager:
"""工具提供者的中心化管理器,采用单例模式。"""
_instance: "ToolProviderManager | None" = None
def __new__(cls) -> "ToolProviderManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, "_initialized") and self._initialized:
return
self._providers: list[ToolProvider] = []
self._resolved_tools: dict[str, ToolExecutable] | None = None
self._init_lock = asyncio.Lock()
self._init_promise: asyncio.Task | None = None
self._builtin_function_provider = BuiltinFunctionToolProvider()
self.register(self._builtin_function_provider)
self._initialized = True
def register(self, provider: ToolProvider):
"""注册一个新的 ToolProvider。"""
if provider not in self._providers:
self._providers.append(provider)
logger.info(f"已注册工具提供者: {provider.__class__.__name__}")
def function_tool(
self,
name: str,
description: str,
params_model: type[BaseModel] | None = None,
):
"""装饰器:将一个函数注册为内置工具。"""
def decorator(func: Callable):
if name in self._builtin_function_provider._functions:
logger.warning(f"正在覆盖已注册的函数工具: {name}")
self._builtin_function_provider.register(
name=name,
func=func,
description=description,
params_model=params_model,
)
logger.info(f"已注册函数工具: '{name}'")
return func
return decorator
async def initialize(self) -> None:
"""懒加载初始化所有已注册的 ToolProvider。"""
if not self._init_promise:
async with self._init_lock:
if not self._init_promise:
self._init_promise = asyncio.create_task(
self._initialize_providers()
)
await self._init_promise
async def _initialize_providers(self) -> None:
"""内部初始化逻辑。"""
logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...")
init_tasks = [provider.initialize() for provider in self._providers]
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.info("所有工具提供者初始化完成。")
async def get_resolved_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
获取所有已发现和解析的工具
此方法会触发懒加载初始化,并根据是否传入过滤器来决定是否使用全局缓存。
"""
await self.initialize()
has_filters = allowed_servers is not None or excluded_servers is not None
if not has_filters and self._resolved_tools is not None:
logger.debug("使用全局工具缓存。")
return self._resolved_tools
if has_filters:
logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
logger.debug(
f"过滤器详情: allowed_servers={allowed_servers}, "
f"excluded_servers={excluded_servers}"
)
else:
logger.info("未应用过滤器,开始全局工具发现...")
all_tools: dict[str, ToolExecutable] = {}
discover_tasks = []
for provider in self._providers:
sig = inspect.signature(provider.discover_tools)
params_to_pass = {}
if "allowed_servers" in sig.parameters:
params_to_pass["allowed_servers"] = allowed_servers
if "excluded_servers" in sig.parameters:
params_to_pass["excluded_servers"] = excluded_servers
discover_tasks.append(provider.discover_tools(**params_to_pass))
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
for i, provider_result in enumerate(results):
provider_name = self._providers[i].__class__.__name__
if isinstance(provider_result, dict):
logger.debug(
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
)
for name, executable in provider_result.items():
if name in all_tools:
logger.warning(
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
)
all_tools[name] = executable
elif isinstance(provider_result, Exception):
logger.error(
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
)
if not has_filters:
self._resolved_tools = all_tools
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
else:
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
return all_tools
async def get_function_tools(
self, names: list[str] | None = None
) -> dict[str, ToolExecutable]:
"""
仅从内置的函数提供者中解析指定的工具
"""
all_function_tools = await self._builtin_function_provider.discover_tools()
if names is None:
return all_function_tools
resolved_tools = {}
for name in names:
if name in all_function_tools:
resolved_tools[name] = all_function_tools[name]
else:
logger.warning(
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
)
return resolved_tools
tool_provider_manager = ToolProviderManager()
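# 扩展示例(示意草稿,仅作说明):注册一个自定义的 ToolProvider。
# MyPluginToolProvider 为假设的空实现,实际插件可在 discover_tools 中返回自己的工具。
class MyPluginToolProvider:
    async def initialize(self) -> None:
        pass

    async def discover_tools(
        self,
        allowed_servers: list[str] | None = None,
        excluded_servers: list[str] | None = None,
    ) -> dict[str, ToolExecutable]:
        return {}

    async def get_tool_executable(
        self, name: str, config: dict[str, Any]
    ) -> ToolExecutable | None:
        return None


tool_provider_manager.register(MyPluginToolProvider())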

View File

@ -1,181 +0,0 @@
"""
工具注册表
负责加载管理和实例化来自配置的工具
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import TYPE_CHECKING
from pydantic import BaseModel
from zhenxun.services.log import logger
from ..types import LLMTool
if TYPE_CHECKING:
from ..config.providers import ToolConfig
from ..types.protocols import MCPCompatible
class ToolRegistry:
"""工具注册表,用于管理和实例化配置的工具。"""
def __init__(self):
self._function_tools: dict[str, LLMTool] = {}
self._mcp_config_models: dict[str, type[BaseModel]] = {}
if TYPE_CHECKING:
self._mcp_factories: dict[
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
] = {}
else:
self._mcp_factories: dict[str, Callable] = {}
self._tool_configs: dict[str, "ToolConfig"] | None = None
self._tool_cache: dict[str, "LLMTool"] = {}
def _load_configs_if_needed(self):
"""如果尚未加载则从主配置中加载MCP工具定义。"""
if self._tool_configs is None:
logger.debug("首次访问正在加载MCP工具配置...")
from ..config.providers import get_llm_config
llm_config = get_llm_config()
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
def function_tool(
self,
name: str,
description: str,
parameters: dict,
required: list[str] | None = None,
):
"""
装饰器在代码中注册一个简单的无状态的函数工具
参数:
name: 工具的唯一名称
description: 工具功能的描述
parameters: OpenAPI格式的函数参数schema的properties部分
required: 必需的参数列表
"""
def decorator(func: Callable):
if name in self._function_tools or name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的工具: {name}")
tool_definition = LLMTool.create(
name=name,
description=description,
parameters=parameters,
required=required,
)
self._function_tools[name] = tool_definition
logger.info(f"已在代码中注册函数工具: '{name}'")
tool_definition.annotations = tool_definition.annotations or {}
tool_definition.annotations["executable"] = func
return func
return decorator
def mcp_tool(self, name: str, config_model: type[BaseModel]):
"""
装饰器注册一个MCP工具及其配置模型
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
config_model: 一个Pydantic模型用于定义和验证该工具的 `mcp_config`
"""
def decorator(factory_func: Callable):
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory_func
self._mcp_config_models[name] = config_model
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
return factory_func
return decorator
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
"""根据名称获取MCP工具的配置模型。"""
return self._mcp_config_models.get(name)
def register_mcp_factory(
self,
name: str,
factory: Callable,
):
"""
在代码中注册一个 MCP 会话工厂将其与配置中的工具名称关联
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
factory: 一个返回异步生成器的可调用对象会话工厂
"""
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory
logger.info(f"已注册 MCP 会话工厂: '{name}'")
def get_tool(self, name: str) -> "LLMTool":
"""
根据名称获取一个 LLMTool 定义
对于MCP工具返回的 LLMTool 实例包含一个可调用的会话工厂
而不是一个已激活的会话
"""
logger.debug(f"🔍 请求获取工具定义: {name}")
if name in self._tool_cache:
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
return self._tool_cache[name]
if name in self._function_tools:
logger.debug(f"🛠️ 获取函数工具定义: {name}")
tool = self._function_tools[name]
self._tool_cache[name] = tool
return tool
self._load_configs_if_needed()
if self._tool_configs is None or name not in self._tool_configs:
known_tools = list(self._function_tools.keys()) + (
list(self._tool_configs.keys()) if self._tool_configs else []
)
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
logger.debug(f"📋 可用工具定义列表: {known_tools}")
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
config = self._tool_configs[name]
tool: "LLMTool"
if name not in self._mcp_factories:
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
available_factories = list(self._mcp_factories.keys())
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
raise ValueError(
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
)
logger.info(f"🔧 创建MCP工具定义: {name}")
factory = self._mcp_factories[name]
typed_mcp_config = config.mcp_config
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
configured_factory = partial(factory, config=typed_mcp_config)
tool = LLMTool.from_mcp_session(session=configured_factory)
self._tool_cache[name] = tool
logger.debug(f"💾 MCP工具定义已缓存: {name}")
return tool
def get_tools(self, names: list[str]) -> list["LLMTool"]:
"""根据名称列表获取多个 LLMTool 实例。"""
return [self.get_tool(name) for name in names]
tool_registry = ToolRegistry()

View File

@ -23,7 +23,6 @@ from .models import (
LLMCodeExecution,
LLMGroundingAttribution,
LLMGroundingMetadata,
LLMTool,
LLMToolCall,
LLMToolFunction,
ModelDetail,
@ -31,9 +30,10 @@ from .models import (
ModelName,
ProviderConfig,
ToolMetadata,
ToolResult,
UsageInfo,
)
from .protocols import MCPCompatible
from .protocols import ToolExecutable, ToolProvider
__all__ = [
"EmbeddingTaskType",
@ -46,10 +46,8 @@ __all__ = [
"LLMGroundingMetadata",
"LLMMessage",
"LLMResponse",
"LLMTool",
"LLMToolCall",
"LLMToolFunction",
"MCPCompatible",
"ModelCapabilities",
"ModelDetail",
"ModelInfo",
@ -60,7 +58,10 @@ __all__ = [
"ResponseFormat",
"TaskType",
"ToolCategory",
"ToolExecutable",
"ToolMetadata",
"ToolProvider",
"ToolResult",
"UsageInfo",
"get_model_capabilities",
"get_user_friendly_error_message",

View File

@ -405,7 +405,7 @@ class LLMMessage(BaseModel):
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
)
content_str = json.dumps(
{"error": "Tool result not JSON serializable", "details": str(e)}
{"error": "工具结果无法JSON序列化", "details": str(e)}
)
return cls(

View File

@ -4,28 +4,39 @@ LLM 数据模型定义
包含模型信息配置工具定义和响应数据的模型类
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import Any
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
if TYPE_CHECKING:
from .protocols import MCPCompatible
MCPSessionType = (
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
)
else:
MCPCompatible = object
MCPSessionType = Any
ModelName = str | None
class ToolDefinition(BaseModel):
"""
一个结构化的工具定义模型用于向LLM描述工具
"""
name: str = Field(..., description="工具的唯一名称标识")
description: str = Field(..., description="工具功能的清晰描述")
parameters: dict[str, Any] = Field(
default_factory=dict, description="符合JSON Schema规范的参数定义"
)
class ToolResult(BaseModel):
"""
一个结构化的工具执行结果模型
"""
output: Any = Field(..., description="返回给LLM的、可JSON序列化的原始输出")
display_content: str | None = Field(
default=None, description="用于日志或UI展示的人类可读的执行摘要"
)
@dataclass(frozen=True)
class ModelInfo:
"""模型信息(不可变数据类)"""
@ -107,55 +118,6 @@ class LLMToolCall(BaseModel):
function: LLMToolFunction
class LLMTool(BaseModel):
"""LLM 工具定义(支持 MCP 风格)"""
model_config = {"arbitrary_types_allowed": True}
type: str = "function"
function: dict[str, Any] | None = None
mcp_session: MCPSessionType = None
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
def model_post_init(self, /, __context: Any) -> None:
"""验证工具定义的有效性"""
_ = __context
if self.type == "function" and self.function is None:
raise ValueError("函数类型的工具必须包含 'function' 字段。")
if self.type == "mcp" and self.mcp_session is None:
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
@classmethod
def create(
cls,
name: str,
description: str,
parameters: dict[str, Any],
required: list[str] | None = None,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""创建函数工具"""
function_def = {
"name": name,
"description": description,
"parameters": {
"type": "object",
"properties": parameters,
"required": required or [],
},
}
return cls(type="function", function=function_def, annotations=annotations)
@classmethod
def from_mcp_session(
cls,
session: Any,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""从 MCP 会话创建工具"""
return cls(type="mcp", mcp_session=session, annotations=annotations)
class LLMCodeExecution(BaseModel):
"""代码执行结果"""

View File

@ -4,21 +4,62 @@ LLM 模块的协议定义
from typing import Any, Protocol
from .models import ToolDefinition, ToolResult
class MCPCompatible(Protocol):
class ToolExecutable(Protocol):
"""
一个协议定义了与LLM模块兼容的MCP会话对象应具备的行为
任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible
一个协议定义了所有可被LLM调用的工具必须实现的行为
它将工具的"定义"(给LLM看)和"执行"(由框架调用)封装在一起
"""
def to_api_tool(self, api_type: str) -> dict[str, Any]:
async def get_definition(self) -> ToolDefinition:
"""
将此MCP会话转换为特定LLM提供商API所需的工具格式
参数:
api_type: 目标API的类型 (例如 'gemini', 'openai')
返回:
dict[str, Any]: 一个字典代表可以在API请求中使用的工具定义
异步地获取一个结构化的工具定义
"""
...
async def execute(self, **kwargs: Any) -> ToolResult:
"""
异步执行工具并返回一个结构化的结果
参数由LLM根据工具定义生成
"""
...
class ToolProvider(Protocol):
"""
一个协议定义了"工具提供者"的行为
工具提供者负责发现或实例化具体的 ToolExecutable 对象
"""
async def initialize(self) -> None:
"""
异步初始化提供者
此方法应是幂等的,即多次调用只会执行一次初始化逻辑
用于执行耗时的I/O操作如网络请求或启动子进程
"""
...
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
异步发现此提供者提供的所有工具
`initialize` 成功调用后才应被调用
返回:
一个从工具名称到 ToolExecutable 实例的字典
"""
...
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
"""
(保留)如果此提供者能处理名为 'name' 的工具,则返回一个可执行实例
此方法主要用于按需解析 ad-hoc 工具
"""
...
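To make the new protocol pair concrete, here is a minimal, hypothetical ToolExecutable implementation built only from the ToolDefinition/ToolResult models shown earlier in this diff. The EchoTool class, its parameter schema, and the import path are illustrative assumptions, not part of the commit:

from zhenxun.services.llm.types import ToolDefinition, ToolResult  # assumed import path

class EchoTool:
    """Hypothetical tool: satisfies ToolExecutable structurally (no inheritance required)."""

    async def get_definition(self) -> ToolDefinition:
        # The definition is what gets shown to the LLM
        return ToolDefinition(
            name="echo",
            description="Return the given text unchanged.",
            parameters={
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        )

    async def execute(self, **kwargs) -> ToolResult:
        # Arguments are produced by the LLM according to the definition above
        text = str(kwargs.get("text", ""))
        return ToolResult(output={"text": text}, display_content=f"echoed {len(text)} characters")

A ToolProvider could then simply return {"echo": EchoTool()} from discover_tools().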

View File

@ -5,6 +5,7 @@ LLM 模块的工具和转换函数
import base64
import copy
from pathlib import Path
from typing import Any
from nonebot.adapters import Message as PlatformMessage
from nonebot_plugin_alconna.uniseg import (
@ -21,7 +22,7 @@ from nonebot_plugin_alconna.uniseg import (
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from .types import LLMContentPart
from .types import LLMContentPart, LLMMessage
async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
@ -112,9 +113,9 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
elif isinstance(seg, At):
if seg.flag == "all":
part = LLMContentPart.text_part("[Mentioned Everyone]")
part = LLMContentPart.text_part("[提及所有人]")
else:
part = LLMContentPart.text_part(f"[Mentioned user: {seg.target}]")
part = LLMContentPart.text_part(f"[提及用户: {seg.target}]")
elif isinstance(seg, Reply):
if seg.msg:
@ -126,10 +127,10 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
reply_text = str(seg.msg).strip()
if reply_text:
part = LLMContentPart.text_part(
f'[Replied to: "{reply_text[:50]}..."]'
f'[回复消息: "{reply_text[:50]}..."]'
)
except Exception:
part = LLMContentPart.text_part("[Replied to a message]")
part = LLMContentPart.text_part("[回复了一条消息]")
if part:
parts.append(part)
@ -137,6 +138,42 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
return parts
async def normalize_to_llm_messages(
message: str | UniMessage | LLMMessage | list[LLMContentPart] | list[LLMMessage],
instruction: str | None = None,
) -> list[LLMMessage]:
"""
将多种输入格式标准化为 LLMMessage 列表并可选地添加系统指令
这是处理 LLM 输入的核心工具函数
参数:
message: 要标准化的输入消息
instruction: 可选的系统指令
返回:
list[LLMMessage]: 标准化后的消息列表
"""
messages = []
if instruction:
messages.append(LLMMessage.system(instruction))
if isinstance(message, LLMMessage):
messages.append(message)
elif isinstance(message, list) and all(isinstance(m, LLMMessage) for m in message):
messages.extend(message)
elif isinstance(message, str):
messages.append(LLMMessage.user(message))
elif isinstance(message, UniMessage):
content_parts = await unimsg_to_llm_parts(message)
messages.append(LLMMessage.user(content_parts))
elif isinstance(message, list):
messages.append(LLMMessage.user(message)) # type: ignore
else:
raise TypeError(f"不支持的消息类型: {type(message)}")
return messages
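# Illustrative usage (not part of the commit): normalize_to_llm_messages accepts a plain
# string, a UniMessage, a single LLMMessage, or lists of content parts / messages, and
# always returns list[LLMMessage]; an optional instruction is prepended as a system message:
#   msgs = await normalize_to_llm_messages("你好", instruction="You are a helper")
#   # -> [LLMMessage.system("You are a helper"), LLMMessage.user("你好")]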
def create_multimodal_message(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
@ -282,3 +319,37 @@ def _sanitize_request_body_for_logging(body: dict) -> dict:
except Exception as e:
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
return body
def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
"""
递归地净化 JSON Schema,移除特定 LLM API 不支持的关键字
参数:
schema: 要净化的 JSON Schema (可以是字典列表或其它类型)
api_type: 目标 API 的类型例如 'gemini'
返回:
Any: 净化后的 JSON Schema
"""
if isinstance(schema, dict):
schema_copy = {}
for key, value in schema.items():
if api_type == "gemini":
unsupported_keys = ["exclusiveMinimum", "exclusiveMaximum", "default"]
if key in unsupported_keys:
continue
if key == "format" and isinstance(value, str):
supported_formats = ["enum", "date-time"]
if value not in supported_formats:
continue
schema_copy[key] = sanitize_schema_for_llm(value, api_type)
return schema_copy
elif isinstance(schema, list):
return [sanitize_schema_for_llm(item, api_type) for item in schema]
else:
return schema
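A short, self-contained sketch of how the sanitizer above behaves in the Gemini case; the schema literal below is invented for illustration only:

# Hypothetical Pydantic-style schema containing keywords Gemini rejects
schema = {
    "type": "object",
    "properties": {
        "count": {"type": "integer", "exclusiveMinimum": 0, "default": 1},
        "link": {"type": "string", "format": "uri"},
        "when": {"type": "string", "format": "date-time"},
    },
}
cleaned = sanitize_schema_for_llm(schema, api_type="gemini")
# "exclusiveMinimum" and "default" are dropped, the unsupported "uri" format is removed,
# and the supported "date-time" format is kept; for any other api_type the schema passes
# through unchanged.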

View File

@ -1,101 +0,0 @@
import os
from pathlib import Path
import shutil
import zipfile
from zhenxun.configs.path_config import FONT_PATH, TEMP_PATH
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.repo_utils import AliyunRepoManager, GithubRepoManager
from zhenxun.utils.repo_utils.utils import clean_git
LOG_COMMAND = "ResourceManager"
class DownloadResourceException(Exception):
pass
class ResourceManager:
GITHUB_URL = "https://github.com/zhenxun-org/zhenxun-bot-resources/tree/main"
RESOURCE_PATH = Path() / "resources"
TMP_PATH = TEMP_PATH / "_resource_tmp"
ZIP_FILE = TMP_PATH / "resources.zip"
UNZIP_PATH = None
@classmethod
async def init_resources(
cls, force: bool = False, is_zip: bool = False, git_source: str = "ali"
):
if (FONT_PATH.exists() and os.listdir(FONT_PATH)) and not force:
return
if is_zip:
if cls.TMP_PATH.exists():
logger.debug(
"resources临时文件夹已存在移除resources临时文件夹", LOG_COMMAND
)
await clean_git(cls.TMP_PATH)
shutil.rmtree(cls.TMP_PATH, ignore_errors=True)
cls.TMP_PATH.mkdir(parents=True, exist_ok=True)
try:
await cls.__download_resources()
cls.file_handle()
except Exception as e:
logger.error("获取resources资源包失败", LOG_COMMAND, e=e)
else:
if git_source == "ali":
await AliyunRepoManager.update(cls.GITHUB_URL, cls.RESOURCE_PATH)
else:
await GithubRepoManager.update(cls.GITHUB_URL, cls.RESOURCE_PATH)
cls.UNZIP_PATH = cls.TMP_PATH / "resources"
cls.file_handle()
if cls.TMP_PATH.exists():
logger.debug("移除resources临时文件夹", LOG_COMMAND)
await clean_git(cls.TMP_PATH)
shutil.rmtree(cls.TMP_PATH)
@classmethod
def file_handle(cls):
if not cls.UNZIP_PATH:
return
cls.__recursive_folder(cls.UNZIP_PATH, ".")
@classmethod
def __recursive_folder(cls, dir: Path, parent_path: str):
for file in dir.iterdir():
if file.is_dir():
cls.__recursive_folder(file, f"{parent_path}/{file.name}")
else:
res_file = Path(parent_path) / file.name
if res_file.exists():
res_file.unlink()
res_file.parent.mkdir(parents=True, exist_ok=True)
file.rename(res_file)
@classmethod
async def __download_resources(cls):
"""获取resources文件夹"""
repo_info = GithubUtils.parse_github_url(cls.GITHUB_URL)
url = await repo_info.get_archive_download_urls()
logger.debug("开始下载resources资源包...", LOG_COMMAND)
if not await AsyncHttpx.download_file(url, cls.ZIP_FILE, stream=True):
logger.error(
"下载resources资源包失败请尝试重启重新下载或前往 "
"https://github.com/zhenxun-org/zhenxun-bot-resources 手动下载..."
)
raise DownloadResourceException("下载resources资源包失败...")
logger.debug("下载resources资源文件压缩包完成...", LOG_COMMAND)
tf = zipfile.ZipFile(cls.ZIP_FILE)
tf.extractall(cls.TMP_PATH)
logger.debug("解压文件压缩包完成...", LOG_COMMAND)
download_file_path = cls.TMP_PATH / next(
x for x in os.listdir(cls.TMP_PATH) if (cls.TMP_PATH / x).is_dir()
)
cls.UNZIP_PATH = download_file_path / "resources"
if tf:
tf.close()

File diff suppressed because it is too large

View File

@ -0,0 +1,88 @@
"""
Pydantic V1 & V2 兼容层模块
Pydantic V1 V2 版本提供统一的便捷函数与类
包括 model_dump, model_copy, model_json_schema, parse_as
"""
from typing import Any, TypeVar, get_args, get_origin
from nonebot.compat import PYDANTIC_V2, model_dump
from pydantic import VERSION, BaseModel
T = TypeVar("T", bound=BaseModel)
V = TypeVar("V")
__all__ = [
"PYDANTIC_V2",
"_dump_pydantic_obj",
"_is_pydantic_type",
"model_copy",
"model_dump",
"model_json_schema",
"parse_as",
]
def model_copy(
model: T, *, update: dict[str, Any] | None = None, deep: bool = False
) -> T:
"""
Pydantic `model.copy()` (v1) 与 `model.model_copy()` (v2) 的兼容函数
"""
if PYDANTIC_V2:
return model.model_copy(update=update, deep=deep)
else:
update_dict = update or {}
return model.copy(update=update_dict, deep=deep)
def model_json_schema(model_class: type[BaseModel], **kwargs: Any) -> dict[str, Any]:
"""
Pydantic `Model.schema()` (v1) 与 `Model.model_json_schema()` (v2) 的兼容函数
"""
if PYDANTIC_V2:
return model_class.model_json_schema(**kwargs)
else:
return model_class.schema(by_alias=kwargs.get("by_alias", True))
def _is_pydantic_type(t: Any) -> bool:
"""
递归检查一个类型注解是否与 Pydantic BaseModel 相关
"""
if t is None:
return False
origin = get_origin(t)
if origin:
return any(_is_pydantic_type(arg) for arg in get_args(t))
return isinstance(t, type) and issubclass(t, BaseModel)
def _dump_pydantic_obj(obj: Any) -> Any:
"""
递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典
支持单个实例、实例列表、实例字典等情况
"""
if isinstance(obj, BaseModel):
return model_dump(obj)
if isinstance(obj, list):
return [_dump_pydantic_obj(item) for item in obj]
if isinstance(obj, dict):
return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
return obj
def parse_as(type_: type[V], obj: Any) -> V:
"""
一个兼容 Pydantic V1 的 parse_obj_as 和 V2 的 TypeAdapter.validate_python 的辅助函数
"""
if VERSION.startswith("1"):
from pydantic import parse_obj_as
return parse_obj_as(type_, obj)
else:
from pydantic import TypeAdapter # type: ignore
return TypeAdapter(type_).validate_python(obj)
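A brief sketch of the compatibility helpers in use; the Item model is a stand-in defined only for this example:

from pydantic import BaseModel

class Item(BaseModel):
    name: str
    price: float

# Behaves the same on Pydantic v1 (parse_obj_as) and v2 (TypeAdapter.validate_python)
items = parse_as(list[Item], [{"name": "tea", "price": "3.5"}])
updated = model_copy(items[0], update={"price": 4.0})
schema = model_json_schema(Item)
payload = _dump_pydantic_obj({"first": items[0]})  # -> {"first": {"name": "tea", "price": 3.5}}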

View File

@ -24,9 +24,6 @@ from .models import (
RepoFileInfo,
RepoType,
RepoUpdateResult,
SubmoduleConfig,
SubmoduleInfo,
SubmoduleUpdateResult,
)
from .utils import check_git, filter_files, glob_to_regex, run_git_command
@ -56,9 +53,6 @@ __all__ = [
"RepoType",
"RepoUpdateError",
"RepoUpdateResult",
"SubmoduleConfig",
"SubmoduleInfo",
"SubmoduleUpdateResult",
"check_git",
"filter_files",
"glob_to_regex",

View File

@ -283,10 +283,11 @@ class BaseRepoManager(ABC):
return result
# 如果目录存在检查是否是Git仓库
success, _, _ = await run_git_command(
"rev-parse --is-inside-work-tree", cwd=local_path
)
if not success:
# 首先检查目录本身是否有.git文件夹
git_dir = local_path / ".git"
is_git_repo = git_dir.exists() and git_dir.is_dir()
if not is_git_repo:
# 如果不是Git仓库尝试初始化它
logger.info(f"目录 {local_path} 不是Git仓库尝试初始化", LOG_COMMAND)
init_success, _, init_stderr = await run_git_command(
@ -338,7 +339,7 @@ class BaseRepoManager(ABC):
)
# 获取远程更新
logger.info("获取远程更新", LOG_COMMAND)
logger.info(f"获取远程更新: {repo_url}", LOG_COMMAND)
success, _, stderr = await run_git_command("fetch origin", cwd=local_path)
if not success:
return RepoUpdateResult(
@ -373,7 +374,7 @@ class BaseRepoManager(ABC):
)
# 拉取最新代码
logger.info("拉取最新代码", LOG_COMMAND)
logger.info(f"拉取最新代码: {repo_url}", LOG_COMMAND)
pull_cmd = f"pull origin {branch}"
if force:
pull_cmd = f"fetch --all && git reset --hard origin/{branch}"

View File

@ -29,11 +29,7 @@ from .models import (
RepoFileInfo,
RepoType,
RepoUpdateResult,
SubmoduleConfig,
SubmoduleInfo,
SubmoduleUpdateResult,
)
from .submodule_manager import SubmoduleManager
class GithubManager(BaseRepoManager):
@ -47,7 +43,6 @@ class GithubManager(BaseRepoManager):
config: 配置如果为None则使用默认配置
"""
super().__init__(config)
self.submodule_manager = SubmoduleManager(self)
async def update_repo(
self,
@ -529,158 +524,3 @@ class GithubManager(BaseRepoManager):
raise RepoDownloadError("下载文件失败")
raise RepoDownloadError("下载文件失败: 超过最大重试次数")
# 子模块相关方法
async def init_submodules(
self,
main_repo_path: Path,
submodule_configs: list[SubmoduleConfig],
) -> bool:
"""
初始化子模块
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
bool: 是否成功
"""
return await self.submodule_manager.init_submodules(
main_repo_path, submodule_configs
)
async def update_submodules(
self,
main_repo_path: Path,
submodule_configs: list[SubmoduleConfig],
) -> list[SubmoduleUpdateResult]:
"""
更新子模块
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
list[SubmoduleUpdateResult]: 更新结果列表
"""
return await self.submodule_manager.update_submodules(
main_repo_path, submodule_configs
)
async def get_submodule_info(
self,
main_repo_path: Path,
submodule_configs: list[SubmoduleConfig],
) -> list[SubmoduleInfo]:
"""
获取子模块信息
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
list[SubmoduleInfo]: 子模块信息列表
"""
return await self.submodule_manager.get_submodule_info(
main_repo_path, submodule_configs
)
def save_submodule_configs(
self,
main_repo_path: Path,
submodule_configs: list[SubmoduleConfig],
) -> bool:
"""
保存子模块配置到文件
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
bool: 是否成功
"""
return self.submodule_manager.save_submodule_configs(
main_repo_path, submodule_configs
)
async def load_submodule_configs(
self, main_repo_path: Path
) -> list[SubmoduleConfig]:
"""
从文件加载子模块配置
参数:
main_repo_path: 主仓库路径
返回:
list[SubmoduleConfig]: 子模块配置列表
"""
return await self.submodule_manager.load_submodule_configs(main_repo_path)
async def update_with_submodules(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
submodule_configs: list[SubmoduleConfig] | None = None,
use_git: bool = True,
force: bool = False,
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库并处理子模块
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地保存路径
branch: 分支名称
submodule_configs: 子模块配置列表
use_git: 是否使用Git命令更新
force: 是否强制更新
include_patterns: 包含的文件模式列表
exclude_patterns: 排除的文件模式列表
返回:
RepoUpdateResult: 更新结果
"""
# 更新主仓库
result = await self.update(
repo_url,
local_path,
branch,
use_git,
force,
include_patterns,
exclude_patterns,
)
# 如果没有子模块配置,直接返回结果
if not submodule_configs:
return result
# 处理子模块
try:
submodule_results = await self.update_submodules(
local_path, submodule_configs
)
result.submodule_results = submodule_results
# 检查子模块更新是否成功
failed_submodules = [r for r in submodule_results if not r.success]
if failed_submodules:
logger.warning(
"部分子模块更新失败:"
f" {[r.submodule_name for r in failed_submodules]}",
LOG_COMMAND,
)
except Exception as e:
logger.error(f"处理子模块时发生错误: {e}", LOG_COMMAND)
result.error_message += f"; 子模块处理失败: {e}"
return result

View File

@ -15,62 +15,6 @@ class RepoType(str, Enum):
ALIYUN = "aliyun"
@dataclass
class SubmoduleConfig:
"""子模块配置"""
# 子模块名称
name: str
# 子模块路径(相对于主仓库)
path: str
# 子模块仓库URL
repo_url: str
# 分支名称
branch: str = "main"
# 是否启用
enabled: bool = True
# 包含的文件模式列表
include_patterns: list[str] | None = None
# 排除的文件模式列表
exclude_patterns: list[str] | None = None
@dataclass
class SubmoduleInfo:
"""子模块信息"""
# 子模块配置
config: SubmoduleConfig
# 当前版本
current_version: str = ""
# 最新版本
latest_version: str = ""
# 最后更新时间
last_update: datetime | None = None
# 更新状态
update_status: str = "unknown" # unknown, up_to_date, outdated, error
@dataclass
class SubmoduleUpdateResult:
"""子模块更新结果"""
# 子模块名称
submodule_name: str
# 子模块路径
submodule_path: str
# 旧版本
old_version: str
# 新版本
new_version: str
# 是否成功
success: bool = False
# 错误消息
error_message: str = ""
# 变更的文件列表
changed_files: list[str] = field(default_factory=list)
@dataclass
class RepoFileInfo:
"""仓库文件信息"""
@ -123,8 +67,6 @@ class RepoUpdateResult:
error_message: str = ""
# 变更的文件列表
changed_files: list[str] = field(default_factory=list)
# 子模块更新结果
submodule_results: list[SubmoduleUpdateResult] = field(default_factory=list)
@dataclass

View File

@ -1,408 +0,0 @@
"""
子模块管理工具
"""
import json
from pathlib import Path
from zhenxun.services.log import logger
from .config import LOG_COMMAND
from .github_manager import GithubManager
from .models import SubmoduleConfig, SubmoduleInfo, SubmoduleUpdateResult
from .utils import run_git_command
class SubmoduleManager:
"""子模块管理器"""
def __init__(self, github_manager: GithubManager):
"""
初始化子模块管理器
参数:
github_manager: GitHub管理器实例
"""
self.github_manager = github_manager
async def init_submodules(
self, main_repo_path: Path, submodule_configs: list[SubmoduleConfig]
) -> bool:
"""
初始化子模块
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
bool: 是否成功
"""
try:
# 检查是否在Git仓库中
success, stdout, stderr = await run_git_command("status", main_repo_path)
if not success:
logger.error(f"路径 {main_repo_path} 不是有效的Git仓库", LOG_COMMAND)
return False
# 初始化每个子模块
for config in submodule_configs:
if not config.enabled:
continue
await self._init_single_submodule(main_repo_path, config)
# 更新子模块
await self._update_submodules(main_repo_path)
return True
except Exception as e:
logger.error(f"初始化子模块失败: {e}", LOG_COMMAND)
return False
async def _init_single_submodule(
self, main_repo_path: Path, config: SubmoduleConfig
) -> bool:
"""
初始化单个子模块
参数:
main_repo_path: 主仓库路径
config: 子模块配置
返回:
bool: 是否成功
"""
try:
submodule_path = main_repo_path / config.path
# 检查子模块是否已存在
if submodule_path.exists() and (submodule_path / ".git").exists():
logger.info(f"子模块 {config.name} 已存在,跳过初始化", LOG_COMMAND)
return True
# 添加子模块
success, stdout, stderr = await run_git_command(
f"submodule add -b {config.branch} {config.repo_url} {config.path}",
main_repo_path,
)
if not success:
logger.error(f"添加子模块 {config.name} 失败: {stderr}", LOG_COMMAND)
return False
logger.info(f"成功添加子模块 {config.name}", LOG_COMMAND)
return True
except Exception as e:
logger.error(f"初始化子模块 {config.name} 失败: {e}", LOG_COMMAND)
return False
async def _update_submodules(self, main_repo_path: Path) -> bool:
"""
更新所有子模块
参数:
main_repo_path: 主仓库路径
返回:
bool: 是否成功
"""
try:
# 更新子模块
success, stdout, stderr = await run_git_command(
"submodule update --init --recursive", main_repo_path
)
if not success:
logger.error(f"更新子模块失败: {stderr}", LOG_COMMAND)
return False
logger.info("成功更新所有子模块", LOG_COMMAND)
return True
except Exception as e:
logger.error(f"更新子模块失败: {e}", LOG_COMMAND)
return False
async def update_submodules(
self, main_repo_path: Path, submodule_configs: list[SubmoduleConfig]
) -> list[SubmoduleUpdateResult]:
"""
更新子模块
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
List[SubmoduleUpdateResult]: 更新结果列表
"""
results = []
for config in submodule_configs:
if not config.enabled:
continue
result = await self._update_single_submodule(main_repo_path, config)
results.append(result)
return results
async def _update_single_submodule(
self, main_repo_path: Path, config: SubmoduleConfig
) -> SubmoduleUpdateResult:
"""
更新单个子模块
参数:
main_repo_path: 主仓库路径
config: 子模块配置
返回:
SubmoduleUpdateResult: 更新结果
"""
result = SubmoduleUpdateResult(
submodule_name=config.name,
submodule_path=config.path,
old_version="",
new_version="",
)
try:
submodule_path = main_repo_path / config.path
# 检查子模块是否存在
if not submodule_path.exists():
result.error_message = f"子模块路径不存在: {submodule_path}"
return result
# 获取当前版本
success, stdout, stderr = await run_git_command(
"rev-parse HEAD", submodule_path
)
if not success:
result.error_message = f"获取当前版本失败: {stderr}"
return result
old_version = stdout.strip()
result.old_version = old_version
# 获取远程最新版本
success, stdout, stderr = await run_git_command(
f"ls-remote origin {config.branch}", submodule_path
)
if not success:
result.error_message = f"获取远程版本失败: {stderr}"
return result
# 解析最新版本
lines = stdout.strip().split("\n")
if not lines or not lines[0]:
result.error_message = "无法获取远程版本信息"
return result
latest_version = lines[0].split("\t")[0]
result.new_version = latest_version
# 检查是否需要更新
if old_version == latest_version:
result.success = True
logger.info(f"子模块 {config.name} 已是最新版本", LOG_COMMAND)
return result
# 更新子模块
success, stdout, stderr = await run_git_command(
f"pull origin {config.branch}", submodule_path
)
if not success:
result.error_message = f"更新子模块失败: {stderr}"
return result
# 更新主仓库中的子模块引用
success, stdout, stderr = await run_git_command(
f"add {config.path}", main_repo_path
)
if not success:
result.error_message = f"更新主仓库引用失败: {stderr}"
return result
result.success = True
logger.info(
f"成功更新子模块 {config.name}: {old_version} -> {latest_version}",
LOG_COMMAND,
)
except Exception as e:
result.error_message = f"更新子模块时发生错误: {e}"
logger.error(f"更新子模块 {config.name} 失败: {e}", LOG_COMMAND)
return result
async def get_submodule_info(
self, main_repo_path: Path, submodule_configs: list[SubmoduleConfig]
) -> list[SubmoduleInfo]:
"""
获取子模块信息
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
List[SubmoduleInfo]: 子模块信息列表
"""
submodule_infos = []
for config in submodule_configs:
if not config.enabled:
continue
info = await self._get_single_submodule_info(main_repo_path, config)
submodule_infos.append(info)
return submodule_infos
async def _get_single_submodule_info(
self, main_repo_path: Path, config: SubmoduleConfig
) -> SubmoduleInfo:
"""
获取单个子模块信息
参数:
main_repo_path: 主仓库路径
config: 子模块配置
返回:
SubmoduleInfo: 子模块信息
"""
info = SubmoduleInfo(config=config)
try:
submodule_path = main_repo_path / config.path
if not submodule_path.exists():
info.update_status = "error"
return info
# 获取当前版本
success, stdout, stderr = await run_git_command(
"rev-parse HEAD", submodule_path
)
if success:
info.current_version = stdout.strip()
# 获取远程最新版本
success, stdout, stderr = await run_git_command(
f"ls-remote origin {config.branch}", submodule_path
)
if success and stdout.strip():
lines = stdout.strip().split("\n")
if lines and lines[0]:
info.latest_version = lines[0].split("\t")[0]
# 确定更新状态
if info.current_version and info.latest_version:
if info.current_version == info.latest_version:
info.update_status = "up_to_date"
else:
info.update_status = "outdated"
else:
info.update_status = "unknown"
except Exception as e:
info.update_status = "error"
logger.error(f"获取子模块 {config.name} 信息失败: {e}", LOG_COMMAND)
return info
def save_submodule_configs(
self, main_repo_path: Path, submodule_configs: list[SubmoduleConfig]
) -> bool:
"""
保存子模块配置到文件
参数:
main_repo_path: 主仓库路径
submodule_configs: 子模块配置列表
返回:
bool: 是否成功
"""
try:
config_file = main_repo_path / ".submodules.json"
# 转换为字典格式
configs_dict = []
for config in submodule_configs:
config_dict = {
"name": config.name,
"path": config.path,
"repo_url": config.repo_url,
"branch": config.branch,
"enabled": config.enabled,
"include_patterns": config.include_patterns,
"exclude_patterns": config.exclude_patterns,
}
configs_dict.append(config_dict)
# 保存到文件
with open(config_file, "w", encoding="utf-8") as f:
json.dump(configs_dict, f, indent=2, ensure_ascii=False)
logger.info(f"子模块配置已保存到 {config_file}", LOG_COMMAND)
return True
except Exception as e:
logger.error(f"保存子模块配置失败: {e}", LOG_COMMAND)
return False
def load_submodule_configs(self, main_repo_path: Path) -> list[SubmoduleConfig]:
"""
从文件加载子模块配置
参数:
main_repo_path: 主仓库路径
返回:
List[SubmoduleConfig]: 子模块配置列表
"""
try:
config_file = main_repo_path / ".submodules.json"
if not config_file.exists():
logger.warning(f"子模块配置文件不存在: {config_file}", LOG_COMMAND)
return []
with open(config_file, encoding="utf-8") as f:
configs_dict = json.load(f)
# 转换为SubmoduleConfig对象
configs = []
for config_dict in configs_dict:
config = SubmoduleConfig(
name=config_dict["name"],
path=config_dict["path"],
repo_url=config_dict["repo_url"],
branch=config_dict.get("branch", "main"),
enabled=config_dict.get("enabled", True),
include_patterns=config_dict.get("include_patterns"),
exclude_patterns=config_dict.get("exclude_patterns"),
)
configs.append(config)
logger.info(
f"{config_file} 加载了 {len(configs)} 个子模块配置", LOG_COMMAND
)
return configs
except Exception as e:
logger.error(f"加载子模块配置失败: {e}", LOG_COMMAND)
return []

View File

@ -1,210 +0,0 @@
#!/usr/bin/env python3
"""
GitHub子模块快速设置脚本
"""
import asyncio
from pathlib import Path
import sys
from zhenxun.services.log import logger
from zhenxun.utils.repo_utils import (
GithubRepoManager,
SubmoduleConfig,
)
def create_sample_configs():
"""创建示例子模块配置"""
return [
SubmoduleConfig(
name="frontend-ui",
path="frontend/ui",
repo_url="https://github.com/your-org/frontend-ui",
branch="main",
enabled=True,
include_patterns=["*.js", "*.css", "*.html", "*.vue", "*.ts"],
exclude_patterns=["node_modules/*", "*.log", "dist/*", "coverage/*"],
),
SubmoduleConfig(
name="backend-api",
path="backend/api",
repo_url="https://github.com/your-org/backend-api",
branch="develop",
enabled=True,
include_patterns=["*.py", "*.json", "requirements.txt", "*.yml"],
exclude_patterns=["__pycache__/*", "*.pyc", "venv/*", ".pytest_cache/*"],
),
SubmoduleConfig(
name="shared-lib",
path="libs/shared",
repo_url="https://github.com/your-org/shared-lib",
branch="main",
enabled=True,
include_patterns=["*.py", "*.js", "*.ts", "*.json"],
exclude_patterns=["tests/*", "docs/*", "examples/*"],
),
]
async def setup_submodules(project_path: str, configs: list[SubmoduleConfig]):
"""设置子模块"""
main_repo_path = Path(project_path)
logger.info(f"正在为项目 {project_path} 设置子模块...")
# 检查路径是否存在
if not main_repo_path.exists():
logger.info(f"错误: 项目路径 {project_path} 不存在")
return False
# 检查是否是Git仓库
git_dir = main_repo_path / ".git"
if not git_dir.exists():
logger.info(f"错误: {project_path} 不是Git仓库")
logger.info("请先执行: git init")
return False
# 初始化子模块
logger.info("正在初始化子模块...")
success = await GithubRepoManager.init_submodules(main_repo_path, configs)
if not success:
logger.info("子模块初始化失败!")
return False
# 保存配置
logger.info("正在保存子模块配置...")
await GithubRepoManager.save_submodule_configs(main_repo_path, configs)
logger.info("✓ 子模块设置完成!")
logger.info(f"配置文件已保存到: {main_repo_path / '.submodules.json'}")
return True
async def update_submodules(project_path: str):
"""更新子模块"""
main_repo_path = Path(project_path)
logger.info(f"正在更新项目 {project_path} 的子模块...")
# 加载配置
configs = await GithubRepoManager.load_submodule_configs(main_repo_path)
if not configs:
logger.info("未找到子模块配置")
return False
logger.info(f"找到 {len(configs)} 个子模块配置")
# 获取子模块信息
infos = await GithubRepoManager.get_submodule_info(main_repo_path, configs)
logger.info("\n子模块状态:")
for info in infos:
status_icon = (
""
if info.update_status == "up_to_date"
else ""
if info.update_status == "outdated"
else ""
)
logger.info(
f"{status_icon} {info.config.name}"
f"({info.config.path}) - {info.update_status}"
)
# 更新子模块
logger.info("\n正在更新子模块...")
results = await GithubRepoManager.update_submodules(main_repo_path, configs)
success_count = 0
for result in results:
if result.success:
success_count += 1
if result.old_version != result.new_version:
logger.info(f"{result.submodule_name} 已更新")
else:
logger.info(f"{result.submodule_name} 已是最新版本")
else:
logger.info(f"{result.submodule_name} 更新失败: {result.error_message}")
logger.info(f"\n更新完成: {success_count}/{len(results)} 个子模块更新成功")
return success_count == len(results)
async def show_submodule_info(project_path: str):
"""显示子模块信息"""
main_repo_path = Path(project_path)
logger.info(f"项目 {project_path} 的子模块信息:")
# 加载配置
configs = await GithubRepoManager.load_submodule_configs(main_repo_path)
if not configs:
logger.info("未找到子模块配置")
return
# 获取详细信息
infos = await GithubRepoManager.get_submodule_info(main_repo_path, configs)
for info in infos:
logger.info(f"\n子模块: {info.config.name}")
logger.info(f" 路径: {info.config.path}")
logger.info(f" 仓库: {info.config.repo_url}")
logger.info(f" 分支: {info.config.branch}")
logger.info(f" 状态: {info.update_status}")
logger.info(f" 启用: {info.config.enabled}")
if info.current_version:
logger.info(f" 当前版本: {info.current_version[:8]}")
if info.latest_version:
logger.info(f" 最新版本: {info.latest_version[:8]}")
if info.config.include_patterns:
logger.info(f" 包含文件: {', '.join(info.config.include_patterns)}")
if info.config.exclude_patterns:
logger.info(f" 排除文件: {', '.join(info.config.exclude_patterns)}")
def print_info_usage():
"""打印使用说明"""
logger.info("GitHub子模块管理工具")
logger.info("用法:")
logger.info(" python submodule_setup.py setup <项目路径>")
logger.info(" python submodule_setup.py update <项目路径>")
logger.info(" python submodule_setup.py info <项目路径>")
logger.info("示例:")
logger.info(" python submodule_setup.py setup ./my_project")
logger.info(" python submodule_setup.py update ./my_project")
logger.info(" python submodule_setup.py info ./my_project")
async def main():
"""主函数"""
if len(sys.argv) < 3:
print_info_usage()
return
command = sys.argv[1]
project_path = sys.argv[2]
if command == "setup":
configs = create_sample_configs()
await setup_submodules(project_path, configs)
elif command == "update":
await update_submodules(project_path)
elif command == "info":
await show_submodule_info(project_path)
else:
logger.info(f"未知命令: {command}")
print_info_usage()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -57,11 +57,13 @@ async def run_git_command(
"""
try:
full_command = f"git {command}"
# 将Path对象转换为字符串
cwd_str = str(cwd) if cwd else None
process = await asyncio.create_subprocess_shell(
full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
cwd=cwd_str,
)
stdout_bytes, stderr_bytes = await process.communicate()
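A minimal call sketch for the helper above (not part of the commit). The signature and the (success, stdout, stderr) return tuple are inferred from the callers shown earlier in this diff, and the repository path is a placeholder:

from pathlib import Path

# Inside an async function; cwd may be a Path, since the helper now converts it to str
# before spawning the subprocess
success, stdout, stderr = await run_git_command("status --porcelain", cwd=Path("resources"))
if not success:
    raise RuntimeError(f"git status failed: {stderr}")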