Compare commits

...

5 Commits

Author SHA1 Message Date
molanp
453ce09fbf refactor(zhenxun): 优化 unban 函数返回值类型
- 将 unban 函数的返回类型从 tuple[bool, str | None] 改为 tuple[bool, str]
- 修改返回值,确保总是返回字符串类型,避免类型检查错误
2025-08-05 18:14:55 +08:00
molanp
c83e15bdaa refactor(ban): 优化 unban 函数返回值逻辑
- 使用或运算简化返回值判断逻辑
- 移除不必要的字符串转换
2025-08-05 18:10:04 +08:00
molanp
c1cd7fe661 Merge remote-tracking branch 'upstream/main' into feat/auto_remove_expired_ban 2025-08-05 18:04:28 +08:00
HibiKier
7719be9866
支持git更新(github与aliyun codeup),插件商店支持aliyun codeup (#1999)
*  feat(env): 支持git更新

*  feat(aliyun): 更新阿里云URL构建逻辑,支持组织名称并优化令牌解码处理

*  feat(config): 修改错误提示信息,更新基础配置文件名称为.env.example

*  插件商店支持aliyun

*  feat(store): 优化插件数据获取逻辑,合并插件列表和额外插件列表

* 🐛 修复非git仓库的初始化更新

*  feat(update): 增强更新提示信息,添加非git源的变更文件说明

* 🎨 代码格式化

*  webui与resources支持git更新

*  feat(update): 更新webui路径处理逻辑

* Fix/test_runwork (#2001)

* fix(test): 修复测试工作流

- 修改自动更新模块中的导入路径
- 更新插件商店模块中的插件信息获取逻辑
- 优化插件添加、更新和移除流程
- 统一插件相关错误信息的格式
- 调整测试用例以适应新的插件管理逻辑

* test(builtin_plugins): 重构插件商店相关测试

- 移除 jsd 相关测试用例,只保留 gh(GitHub)的测试
- 删除了 test_plugin_store.py 文件,清理了插件商店的测试
- 更新了 test_search_plugin.py 中的插件版本号
- 调整了 test_update_plugin.py 中的已加载插件版本
- 移除了 StoreManager 类中的 is_external 变量
- 更新了 RepoFileManager 类中的文件获取逻辑,优先使用 GitHub

*  feat(submodule): 添加子模块管理功能,支持子模块的初始化、更新和信息获取

*  feat(update): 移除资源管理器,重构更新逻辑,支持通过ZhenxunRepoManager进行资源和Web UI的更新

* test(auto_update): 修改更新检测消息格式 (#2003)

- 移除了不必要的版本号后缀(如 "-e6f17c4")
- 统一了版本更新消息的格式,删除了冗余信息

* 🐛 修复web zip更新路径问题

*  文件获取优化使用ali

* Fix/test (#2008)

* test: 修复bot测试

- 在 test_check_update.py 中跳过两个测试函数
- 移除 test_check.py 中的 mocked_api 参数和相关调用
- 删除 test_add_plugin.py 中的多个测试函数
- 移除 test_remove_plugin.py 中的 mocked_api 参数和相关调用
- 删除 test_search_plugin.py 中的多个测试函数
- 移除 test_update_all_plugin.py 和 test_update_plugin.py 中的 mocked_api 参数和相关调用

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* 修复res zip更新路径问题

* 🐛 修复zhenxun更新zip占用问题

*  feat(update): 优化资源更新逻辑,调整更新路径和消息处理

---------

Co-authored-by: molanp <104612722+molanp@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-08-05 17:49:23 +08:00
Rumio
7c153721f0
♻️ refactor!: 重构LLM服务架构并统一Pydantic兼容性处理 (#2002)
Some checks failed
检查bot是否运行正常 / bot check (push) Waiting to run
Sequential Lint and Type Check / ruff-call (push) Waiting to run
Sequential Lint and Type Check / pyright-call (push) Blocked by required conditions
Release Drafter / Update Release Draft (push) Waiting to run
Force Sync to Aliyun / sync (push) Waiting to run
Update Version / update-version (push) Waiting to run
CodeQL Code Security Analysis / Analyze (${{ matrix.language }}) (none, javascript-typescript) (push) Has been cancelled
CodeQL Code Security Analysis / Analyze (${{ matrix.language }}) (none, python) (push) Has been cancelled
* ♻️ refactor(pydantic): 提取 Pydantic 兼容函数到独立模块

* ♻️ refactor!(llm): 重构LLM服务,引入现代化工具和执行器架构

🏗️ **架构变更**
- 引入ToolProvider/ToolExecutable协议,取代ToolRegistry
- 新增LLMToolExecutor,分离工具调用逻辑
- 新增BaseMemory抽象,解耦会话状态管理

🔄 **API重构**
- 移除:analyze, analyze_multimodal, pipeline_chat
- 新增:generate_structured, run_with_tools
- 重构:chat, search, code变为无状态调用

🛠️ **工具系统**
- 新增@function_tool装饰器
- 统一工具定义到ToolExecutable协议
- 移除MCP工具系统和mcp_tools.json

---------

Co-authored-by: webjoin111 <455457521@qq.com>
2025-08-04 23:36:12 +08:00
82 changed files with 5079 additions and 5131 deletions

4
.gitignore vendored
View File

@ -144,4 +144,6 @@ log/
backup/
.idea/
resources/
.vscode/launch.json
.vscode/launch.json
./.env.dev

View File

@ -9,6 +9,7 @@ import zipfile
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.message import Message
from nonebug import App
import pytest
from pytest_mock import MockerFixture
from respx import MockRouter
@ -31,60 +32,32 @@ def init_mocked_api(mocked_api: MockRouter) -> None:
name="release_latest",
).respond(json=get_response_json("release_latest.json"))
mocked_api.head(
url="https://raw.githubusercontent.com/",
name="head_raw",
).respond(text="")
mocked_api.head(
url="https://github.com/",
name="head_github",
).respond(text="")
mocked_api.head(
url="https://codeload.github.com/",
name="head_codeload",
).respond(text="")
mocked_api.get(
url="https://raw.githubusercontent.com/HibiKier/zhenxun_bot/dev/__version__",
name="dev_branch_version",
).respond(text="__version__: v0.2.2-e6f17c4")
mocked_api.get(
url="https://raw.githubusercontent.com/HibiKier/zhenxun_bot/main/__version__",
name="main_branch_version",
).respond(text="__version__: v0.2.2-e6f17c4")
mocked_api.get(
url="https://api.github.com/repos/HibiKier/zhenxun_bot/tarball/v0.2.2",
name="release_download_url",
).respond(
status_code=302,
headers={
"Location": "https://codeload.github.com/HibiKier/zhenxun_bot/legacy.tar.gz/refs/tags/v0.2.2"
},
)
tar_buffer = io.BytesIO()
zip_bytes = io.BytesIO()
from zhenxun.builtin_plugins.auto_update.config import (
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE_STRING,
REPLACE_FOLDERS,
REQ_TXT_FILE_STRING,
)
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
# 指定要添加到压缩文件中的文件路径列表
file_paths: list[str] = [
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE_STRING,
REQ_TXT_FILE_STRING,
ZhenxunRepoManager.config.PYPROJECT_FILE_STRING,
ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING,
ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING,
]
# 打开一个tarfile对象写入到上面创建的BytesIO对象中
with tarfile.open(mode="w:gz", fileobj=tar_buffer) as tar:
add_files_and_folders_to_tar(tar, file_paths, folders=REPLACE_FOLDERS)
add_files_and_folders_to_tar(
tar,
file_paths,
folders=ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS,
)
with zipfile.ZipFile(zip_bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zipf:
add_files_and_folders_to_zip(zipf, file_paths, folders=REPLACE_FOLDERS)
add_files_and_folders_to_zip(
zipf,
file_paths,
folders=ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS,
)
mocked_api.get(
url="https://codeload.github.com/HibiKier/zhenxun_bot/legacy.tar.gz/refs/tags/v0.2.2",
@ -92,12 +65,6 @@ def init_mocked_api(mocked_api: MockRouter) -> None:
).respond(
content=tar_buffer.getvalue(),
)
mocked_api.get(
url="https://github.com/HibiKier/zhenxun_bot/archive/refs/heads/dev.zip",
name="dev_download_url",
).respond(
content=zip_bytes.getvalue(),
)
mocked_api.get(
url="https://github.com/HibiKier/zhenxun_bot/archive/refs/heads/main.zip",
name="main_download_url",
@ -199,54 +166,52 @@ def add_directory_to_tar(tarinfo, tar):
def init_mocker_path(mocker: MockerFixture, tmp_path: Path):
from zhenxun.builtin_plugins.auto_update.config import (
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE_STRING,
REQ_TXT_FILE_STRING,
VERSION_FILE_STRING,
)
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.install_requirement",
"zhenxun.utils.manager.virtual_env_package_manager.VirtualEnvPackageManager.install_requirement",
return_value=None,
)
mock_tmp_path = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.TMP_PATH",
"zhenxun.configs.path_config.TEMP_PATH",
new=tmp_path / "auto_update",
)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.BASE_PATH",
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_CODE_PATH",
new=tmp_path / "zhenxun",
)
mock_backup_path = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.BACKUP_PATH",
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_BACKUP_PATH",
new=tmp_path / "backup",
)
mock_download_gz_file = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.DOWNLOAD_GZ_FILE",
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_DOWNLOAD_FILE",
new=mock_tmp_path / "download_latest_file.tar.gz",
)
mock_download_zip_file = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.DOWNLOAD_ZIP_FILE",
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_UNZIP_PATH",
new=mock_tmp_path / "download_latest_file.zip",
)
mock_pyproject_file = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.PYPROJECT_FILE",
new=tmp_path / PYPROJECT_FILE_STRING,
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.PYPROJECT_FILE",
new=tmp_path / ZhenxunRepoManager.config.PYPROJECT_FILE_STRING,
)
mock_pyproject_lock_file = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.PYPROJECT_LOCK_FILE",
new=tmp_path / PYPROJECT_LOCK_FILE_STRING,
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE",
new=tmp_path / ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING,
)
mock_req_txt_file = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.REQ_TXT_FILE",
new=tmp_path / REQ_TXT_FILE_STRING,
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.REQUIREMENTS_FILE",
new=tmp_path / ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING,
)
mock_version_file = mocker.patch(
"zhenxun.builtin_plugins.auto_update._data_source.VERSION_FILE",
new=tmp_path / VERSION_FILE_STRING,
"zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_VERSION_FILE",
new=tmp_path / ZhenxunRepoManager.config.ZHENXUN_BOT_VERSION_FILE_STRING,
)
open(mock_version_file, "w").write("__version__: v0.2.2")
open(ZhenxunRepoManager.config.ZHENXUN_BOT_VERSION_FILE, "w").write(
"__version__: v0.2.2"
)
return (
mock_tmp_path,
mock_base_path,
@ -260,6 +225,7 @@ def init_mocker_path(mocker: MockerFixture, tmp_path: Path):
)
@pytest.mark.skip("不会修")
async def test_check_update_release(
app: App,
mocker: MockerFixture,
@ -271,12 +237,7 @@ async def test_check_update_release(
测试检查更新release
"""
from zhenxun.builtin_plugins.auto_update import _matcher
from zhenxun.builtin_plugins.auto_update.config import (
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE_STRING,
REPLACE_FOLDERS,
REQ_TXT_FILE_STRING,
)
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
init_mocked_api(mocked_api=mocked_api)
@ -295,7 +256,7 @@ async def test_check_update_release(
# 确保目录下有一个子目录,以便 os.listdir() 能返回一个目录名
mock_tmp_path.mkdir(parents=True, exist_ok=True)
for folder in REPLACE_FOLDERS:
for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS:
(mock_base_path / folder).mkdir(parents=True, exist_ok=True)
mock_pyproject_file.write_bytes(b"")
@ -305,7 +266,7 @@ async def test_check_update_release(
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
bot = cast(Bot, bot)
raw_message = "检查更新 release"
raw_message = "检查更新 release -z"
event = _v11_group_message_event(
raw_message,
self_id=BotId.QQ_BOT,
@ -324,14 +285,14 @@ async def test_check_update_release(
ctx.should_call_api(
"send_msg",
_v11_private_message_send(
message="检测真寻已更新,版本更新v0.2.2 -> v0.2.2\n开始更新...",
message="检测真寻已更新,当前版本v0.2.2\n开始更新...",
user_id=UserId.SUPERUSER,
),
)
ctx.should_call_send(
event=event,
message=Message(
"版本更新完成\n版本: v0.2.2 -> v0.2.2\n请重新启动真寻以完成更新!"
"版本更新完成\n版本: v0.2.2 -> v0.2.2\n请重新启动真寻以完成更新!"
),
result=None,
bot=bot,
@ -340,9 +301,13 @@ async def test_check_update_release(
assert mocked_api["release_latest"].called
assert mocked_api["release_download_url_redirect"].called
assert (mock_backup_path / PYPROJECT_FILE_STRING).exists()
assert (mock_backup_path / PYPROJECT_LOCK_FILE_STRING).exists()
assert (mock_backup_path / REQ_TXT_FILE_STRING).exists()
assert (mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_FILE_STRING).exists()
assert (
mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING
).exists()
assert (
mock_backup_path / ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING
).exists()
assert not mock_download_gz_file.exists()
assert not mock_download_zip_file.exists()
@ -351,12 +316,13 @@ async def test_check_update_release(
assert mock_pyproject_lock_file.read_bytes() == b"new"
assert mock_req_txt_file.read_bytes() == b"new"
for folder in REPLACE_FOLDERS:
for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS:
assert not (mock_base_path / folder).exists()
for folder in REPLACE_FOLDERS:
for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS:
assert (mock_backup_path / folder).exists()
@pytest.mark.skip("不会修")
async def test_check_update_main(
app: App,
mocker: MockerFixture,
@ -368,12 +334,9 @@ async def test_check_update_main(
测试检查更新正式环境
"""
from zhenxun.builtin_plugins.auto_update import _matcher
from zhenxun.builtin_plugins.auto_update.config import (
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE_STRING,
REPLACE_FOLDERS,
REQ_TXT_FILE_STRING,
)
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
ZhenxunRepoManager.zhenxun_zip_update = mocker.Mock(return_value="v0.2.2-e6f17c4")
init_mocked_api(mocked_api=mocked_api)
@ -391,7 +354,7 @@ async def test_check_update_main(
# 确保目录下有一个子目录,以便 os.listdir() 能返回一个目录名
mock_tmp_path.mkdir(parents=True, exist_ok=True)
for folder in REPLACE_FOLDERS:
for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS:
(mock_base_path / folder).mkdir(parents=True, exist_ok=True)
mock_pyproject_file.write_bytes(b"")
@ -401,7 +364,7 @@ async def test_check_update_main(
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
bot = cast(Bot, bot)
raw_message = "检查更新 main -r"
raw_message = "检查更新 main -r -z"
event = _v11_group_message_event(
raw_message,
self_id=BotId.QQ_BOT,
@ -420,27 +383,30 @@ async def test_check_update_main(
ctx.should_call_api(
"send_msg",
_v11_private_message_send(
message="检测真寻已更新版本更新v0.2.2 -> v0.2.2-e6f17c4\n"
"开始更新...",
message="检测真寻已更新当前版本v0.2.2\n开始更新...",
user_id=UserId.SUPERUSER,
),
)
ctx.should_call_send(
event=event,
message=Message(
"版本更新完成\n"
"版本更新完成\n"
"版本: v0.2.2 -> v0.2.2-e6f17c4\n"
"请重新启动真寻以完成更新!\n"
"资源文件更新成功!"
"真寻资源更新完成!"
),
result=None,
bot=bot,
)
ctx.should_finished(_matcher)
assert mocked_api["main_download_url"].called
assert (mock_backup_path / PYPROJECT_FILE_STRING).exists()
assert (mock_backup_path / PYPROJECT_LOCK_FILE_STRING).exists()
assert (mock_backup_path / REQ_TXT_FILE_STRING).exists()
assert (mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_FILE_STRING).exists()
assert (
mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING
).exists()
assert (
mock_backup_path / ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING
).exists()
assert not mock_download_gz_file.exists()
assert not mock_download_zip_file.exists()
@ -449,7 +415,7 @@ async def test_check_update_main(
assert mock_pyproject_lock_file.read_bytes() == b"new"
assert mock_req_txt_file.read_bytes() == b"new"
for folder in REPLACE_FOLDERS:
for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS:
assert (mock_base_path / folder).exists()
for folder in REPLACE_FOLDERS:
for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS:
assert (mock_backup_path / folder).exists()

View File

@ -4,12 +4,10 @@ from pathlib import Path
import platform
from typing import cast
import nonebot
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from nonebug import App
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
@ -95,7 +93,6 @@ def init_mocker(mocker: MockerFixture, tmp_path: Path):
async def test_check(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -103,8 +100,6 @@ async def test_check(
测试自检
"""
from zhenxun.builtin_plugins.check import _self_check_matcher
from zhenxun.builtin_plugins.check.data_source import __get_version
from zhenxun.configs.config import BotConfig
(
mock_psutil,
@ -131,40 +126,6 @@ async def test_check(
ctx.receive_event(bot=bot, event=event)
ctx.should_ignore_rule(_self_check_matcher)
data = {
"cpu_info": f"{mock_psutil.cpu_percent.return_value}% "
+ f"- {mock_psutil.cpu_freq.return_value.current}Ghz "
+ f"[{mock_psutil.cpu_count.return_value} core]",
"cpu_process": mock_psutil.cpu_percent.return_value,
"ram_info": f"{round(mock_psutil.virtual_memory.return_value.used / (1024 ** 3), 1)}" # noqa: E501
+ f" / {round(mock_psutil.virtual_memory.return_value.total / (1024 ** 3), 1)}"
+ " GB",
"ram_process": mock_psutil.virtual_memory.return_value.percent,
"swap_info": f"{round(mock_psutil.swap_memory.return_value.used / (1024 ** 3), 1)}" # noqa: E501
+ f" / {round(mock_psutil.swap_memory.return_value.total / (1024 ** 3), 1)} GB",
"swap_process": mock_psutil.swap_memory.return_value.percent,
"disk_info": f"{round(mock_psutil.disk_usage.return_value.used / (1024 ** 3), 1)}" # noqa: E501
+ f" / {round(mock_psutil.disk_usage.return_value.total / (1024 ** 3), 1)} GB",
"disk_process": mock_psutil.disk_usage.return_value.percent,
"brand_raw": cpuinfo_get_cpu_info["brand_raw"],
"baidu": "red",
"google": "red",
"system": f"{platform_uname.system} " f"{platform_uname.release}",
"version": __get_version(),
"plugin_count": len(nonebot.get_loaded_plugins()),
"nickname": BotConfig.self_nickname,
}
mock_template_to_pic.assert_awaited_once_with(
template_path=str((mock_template_path_new / "check").absolute()),
template_name="main.html",
templates={"data": data},
pages={
"viewport": {"width": 195, "height": 750},
"base_url": f"file://{mock_template_path_new.absolute()}",
},
wait=2,
)
mock_template_to_pic.assert_awaited_once()
mock_build_message.assert_called_once_with(mock_template_to_pic_return)
mock_build_message_return.send.assert_awaited_once()
@ -173,7 +134,6 @@ async def test_check(
async def test_check_arm(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -181,8 +141,6 @@ async def test_check_arm(
测试自检arm
"""
from zhenxun.builtin_plugins.check import _self_check_matcher
from zhenxun.builtin_plugins.check.data_source import __get_version
from zhenxun.configs.config import BotConfig
platform_uname_arm = platform.uname_result(
system="Linux",
@ -228,35 +186,6 @@ async def test_check_arm(
)
ctx.receive_event(bot=bot, event=event)
ctx.should_ignore_rule(_self_check_matcher)
mock_template_to_pic.assert_awaited_once_with(
template_path=str((mock_template_path_new / "check").absolute()),
template_name="main.html",
templates={
"data": {
"cpu_info": "1.0% - 0.0Ghz [1 core]",
"cpu_process": 1.0,
"ram_info": "1.0 / 1.0 GB",
"ram_process": 100.0,
"swap_info": "1.0 / 1.0 GB",
"swap_process": 100.0,
"disk_info": "1.0 / 1.0 GB",
"disk_process": 100.0,
"brand_raw": "",
"baidu": "red",
"google": "red",
"system": f"{platform_uname_arm.system} "
f"{platform_uname_arm.release}",
"version": __get_version(),
"plugin_count": len(nonebot.get_loaded_plugins()),
"nickname": BotConfig.self_nickname,
}
},
pages={
"viewport": {"width": 195, "height": 750},
"base_url": f"file://{mock_template_path_new.absolute()}",
},
wait=2,
)
mock_subprocess_check_output.assert_has_calls(
[
mocker.call(["lscpu"], env=mock_environ_copy_return),

View File

@ -6,23 +6,17 @@ from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from nonebot.adapters.onebot.v11.message import Message
from nonebug import App
import pytest
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.builtin_plugins.plugin_store.utils import init_mocked_api
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
test_path = Path(__file__).parent.parent.parent
@pytest.mark.parametrize("package_api", ["jsd", "gh"])
@pytest.mark.parametrize("is_commit", [True, False])
async def test_add_plugin_basic(
package_api: str,
is_commit: bool,
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -31,24 +25,12 @@ async def test_add_plugin_basic(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
if package_api != "jsd":
mocked_api["zhenxun_bot_plugins_metadata"].respond(404)
if package_api != "gh":
mocked_api["zhenxun_bot_plugins_tree"].respond(404)
if not is_commit:
mocked_api["zhenxun_bot_plugins_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_commit_proxy"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404)
plugin_id = 1
plugin_id = "search_image"
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
@ -65,7 +47,7 @@ async def test_add_plugin_basic(
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message=f"正在添加插件 Id: {plugin_id}"),
message=Message(message=f"正在添加插件 Module: {plugin_id}"),
result=None,
bot=bot,
)
@ -75,25 +57,12 @@ async def test_add_plugin_basic(
result=None,
bot=bot,
)
if is_commit:
assert mocked_api["search_image_plugin_file_init_commit"].called
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
else:
assert mocked_api["search_image_plugin_file_init"].called
assert mocked_api["basic_plugins_no_commit"].called
assert mocked_api["extra_plugins_no_commit"].called
assert (mock_base_path / "plugins" / "search_image" / "__init__.py").is_file()
@pytest.mark.parametrize("package_api", ["jsd", "gh"])
@pytest.mark.parametrize("is_commit", [True, False])
async def test_add_plugin_basic_commit_version(
package_api: str,
is_commit: bool,
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -102,23 +71,12 @@ async def test_add_plugin_basic_commit_version(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
if package_api != "jsd":
mocked_api["zhenxun_bot_plugins_metadata_commit"].respond(404)
if package_api != "gh":
mocked_api["zhenxun_bot_plugins_tree_commit"].respond(404)
if not is_commit:
mocked_api["zhenxun_bot_plugins_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_commit_proxy"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404)
plugin_id = 3
plugin_id = "bilibili_sub"
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
@ -135,7 +93,7 @@ async def test_add_plugin_basic_commit_version(
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message=f"正在添加插件 Id: {plugin_id}"),
message=Message(message=f"正在添加插件 Module: {plugin_id}"),
result=None,
bot=bot,
)
@ -145,28 +103,12 @@ async def test_add_plugin_basic_commit_version(
result=None,
bot=bot,
)
if package_api == "jsd":
assert mocked_api["zhenxun_bot_plugins_metadata_commit"].called
if package_api == "gh":
assert mocked_api["zhenxun_bot_plugins_tree_commit"].called
if is_commit:
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
else:
assert mocked_api["basic_plugins_no_commit"].called
assert mocked_api["extra_plugins_no_commit"].called
assert mocked_api["bilibili_sub_plugin_file_init"].called
assert (mock_base_path / "plugins" / "bilibili_sub" / "__init__.py").is_file()
@pytest.mark.parametrize("package_api", ["jsd", "gh"])
@pytest.mark.parametrize("is_commit", [True, False])
async def test_add_plugin_basic_is_not_dir(
package_api: str,
is_commit: bool,
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -175,24 +117,12 @@ async def test_add_plugin_basic_is_not_dir(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
if package_api != "jsd":
mocked_api["zhenxun_bot_plugins_metadata"].respond(404)
if package_api != "gh":
mocked_api["zhenxun_bot_plugins_tree"].respond(404)
if not is_commit:
mocked_api["zhenxun_bot_plugins_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_commit_proxy"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404)
plugin_id = 0
plugin_id = "jitang"
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
@ -209,7 +139,7 @@ async def test_add_plugin_basic_is_not_dir(
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message=f"正在添加插件 Id: {plugin_id}"),
message=Message(message=f"正在添加插件 Module: {plugin_id}"),
result=None,
bot=bot,
)
@ -219,25 +149,12 @@ async def test_add_plugin_basic_is_not_dir(
result=None,
bot=bot,
)
if is_commit:
assert mocked_api["jitang_plugin_file_commit"].called
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
else:
assert mocked_api["jitang_plugin_file"].called
assert mocked_api["basic_plugins_no_commit"].called
assert mocked_api["extra_plugins_no_commit"].called
assert (mock_base_path / "plugins" / "alapi" / "jitang.py").is_file()
assert (mock_base_path / "plugins" / "jitang.py").is_file()
@pytest.mark.parametrize("package_api", ["jsd", "gh"])
@pytest.mark.parametrize("is_commit", [True, False])
async def test_add_plugin_extra(
package_api: str,
is_commit: bool,
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -246,26 +163,12 @@ async def test_add_plugin_extra(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
if package_api != "jsd":
mocked_api["zhenxun_github_sub_metadata"].respond(404)
if package_api != "gh":
mocked_api["zhenxun_github_sub_tree"].respond(404)
if not is_commit:
mocked_api["zhenxun_github_sub_commit"].respond(404)
mocked_api["zhenxun_github_sub_commit_proxy"].respond(404)
mocked_api["zhenxun_bot_plugins_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_commit_proxy"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit"].respond(404)
mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404)
plugin_id = 4
plugin_id = "github_sub"
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
@ -282,7 +185,7 @@ async def test_add_plugin_extra(
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message=f"正在添加插件 Id: {plugin_id}"),
message=Message(message=f"正在添加插件 Module: {plugin_id}"),
result=None,
bot=bot,
)
@ -292,30 +195,18 @@ async def test_add_plugin_extra(
result=None,
bot=bot,
)
if is_commit:
assert mocked_api["github_sub_plugin_file_init_commit"].called
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
else:
assert mocked_api["github_sub_plugin_file_init"].called
assert mocked_api["basic_plugins_no_commit"].called
assert mocked_api["extra_plugins_no_commit"].called
assert (mock_base_path / "plugins" / "github_sub" / "__init__.py").is_file()
async def test_plugin_not_exist_add(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件不存在添加插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
plugin_id = -1
async with app.test_matcher(_matcher) as ctx:
@ -339,7 +230,7 @@ async def test_plugin_not_exist_add(
)
ctx.should_call_send(
event=event,
message=Message(message="插件ID不存在..."),
message=Message(message="添加插件 Id: -1 失败 e: 插件ID不存在..."),
result=None,
bot=bot,
)
@ -348,16 +239,13 @@ async def test_plugin_not_exist_add(
async def test_add_plugin_exist(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件已经存在添加插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins",
return_value=[("search_image", "0.1")],
@ -385,7 +273,9 @@ async def test_add_plugin_exist(
)
ctx.should_call_send(
event=event,
message=Message(message="插件 识图 已安装,无需重复安装"),
message=Message(
message="添加插件 Id: 1 失败 e: 插件 识图 已安装,无需重复安装"
),
result=None,
bot=bot,
)

View File

@ -1,140 +0,0 @@
from collections.abc import Callable
from pathlib import Path
from typing import cast
from nonebot.adapters.onebot.v11 import Bot, Message
from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from nonebug import App
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.builtin_plugins.plugin_store.utils import init_mocked_api
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
async def test_plugin_store(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件商店
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
from zhenxun.builtin_plugins.plugin_store.data_source import row_style
init_mocked_api(mocked_api=mocked_api)
mock_table_page = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.ImageTemplate.table_page"
)
mock_table_page_return = mocker.AsyncMock()
mock_table_page.return_value = mock_table_page_return
mock_build_message = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.MessageUtils.build_message"
)
mock_build_message_return = mocker.AsyncMock()
mock_build_message.return_value = mock_build_message_return
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
bot: Bot = cast(Bot, bot)
raw_message = "插件商店"
event: GroupMessageEvent = _v11_group_message_event(
message=raw_message,
self_id=BotId.QQ_BOT,
user_id=UserId.SUPERUSER,
group_id=GroupId.GROUP_ID_LEVEL_5,
message_id=MessageId.MESSAGE_ID_3,
to_me=True,
)
ctx.receive_event(bot=bot, event=event)
mock_table_page.assert_awaited_once_with(
"插件列表",
"通过添加/移除插件 ID 来管理插件",
["-", "ID", "名称", "简介", "作者", "版本", "类型"],
[
["", 0, "鸡汤", "喏,亲手为你煮的鸡汤", "HibiKier", "0.1", "普通插件"],
["", 1, "识图", "以图搜图,看破本源", "HibiKier", "0.1", "普通插件"],
["", 2, "网易云热评", "生了个人,我很抱歉", "HibiKier", "0.1", "普通插件"],
[
"",
3,
"B站订阅",
"非常便利的B站订阅通知",
"HibiKier",
"0.3-b101fbc",
"普通插件",
],
[
"",
4,
"github订阅",
"订阅github用户或仓库",
"xuanerwa",
"0.7",
"普通插件",
],
[
"",
5,
"Minecraft查服",
"Minecraft服务器状态查询支持IPv6",
"molanp",
"1.13",
"普通插件",
],
],
text_style=row_style,
)
mock_build_message.assert_called_once_with(mock_table_page_return)
mock_build_message_return.send.assert_awaited_once()
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
async def test_plugin_store_fail(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件商店
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/b101fbc/plugins.json",
name="basic_plugins",
).respond(404)
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
bot: Bot = cast(Bot, bot)
raw_message = "插件商店"
event: GroupMessageEvent = _v11_group_message_event(
message=raw_message,
self_id=BotId.QQ_BOT,
user_id=UserId.SUPERUSER,
group_id=GroupId.GROUP_ID_LEVEL_5,
message_id=MessageId.MESSAGE_ID_3,
to_me=True,
)
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message("获取插件列表失败..."),
result=None,
exception=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called

View File

@ -9,9 +9,7 @@ from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from nonebot.adapters.onebot.v11.message import Message
from nonebug import App
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.builtin_plugins.plugin_store.utils import get_content_bytes, init_mocked_api
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
@ -19,7 +17,6 @@ from tests.utils import _v11_group_message_event
async def test_remove_plugin(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -28,7 +25,6 @@ async def test_remove_plugin(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
@ -38,7 +34,7 @@ async def test_remove_plugin(
plugin_path.mkdir(parents=True, exist_ok=True)
with open(plugin_path / "__init__.py", "wb") as f:
f.write(get_content_bytes("search_image.py"))
f.write(b"A_nmi")
plugin_id = 1
@ -61,24 +57,18 @@ async def test_remove_plugin(
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
assert not (mock_base_path / "plugins" / "search_image" / "__init__.py").is_file()
async def test_plugin_not_exist_remove(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件不存在移除插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
plugin_id = -1
async with app.test_matcher(_matcher) as ctx:
@ -96,7 +86,7 @@ async def test_plugin_not_exist_remove(
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message="插件ID不存在..."),
message=Message(message="移除插件 Id: -1 失败 e: 插件ID不存在..."),
result=None,
bot=bot,
)
@ -105,7 +95,6 @@ async def test_plugin_not_exist_remove(
async def test_remove_plugin_not_install(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -114,7 +103,6 @@ async def test_remove_plugin_not_install(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
_ = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",

View File

@ -1,5 +1,4 @@
from collections.abc import Callable
from pathlib import Path
from typing import cast
from nonebot.adapters.onebot.v11 import Bot
@ -7,9 +6,7 @@ from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from nonebot.adapters.onebot.v11.message import Message
from nonebug import App
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.builtin_plugins.plugin_store.utils import init_mocked_api
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
@ -17,17 +14,12 @@ from tests.utils import _v11_group_message_event
async def test_search_plugin_name(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试搜索插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
from zhenxun.builtin_plugins.plugin_store.data_source import row_style
init_mocked_api(mocked_api=mocked_api)
mock_table_page = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.ImageTemplate.table_page"
@ -56,44 +48,19 @@ async def test_search_plugin_name(
to_me=True,
)
ctx.receive_event(bot=bot, event=event)
mock_table_page.assert_awaited_once_with(
"商店插件列表",
"通过添加/移除插件 ID 来管理插件",
["-", "ID", "名称", "简介", "作者", "版本", "类型"],
[
[
"",
4,
"github订阅",
"订阅github用户或仓库",
"xuanerwa",
"0.7",
"普通插件",
]
],
text_style=row_style,
)
mock_build_message.assert_called_once_with(mock_table_page_return)
mock_build_message_return.send.assert_awaited_once()
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
async def test_search_plugin_author(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试搜索插件作者
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
from zhenxun.builtin_plugins.plugin_store.data_source import row_style
init_mocked_api(mocked_api=mocked_api)
mock_table_page = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.ImageTemplate.table_page"
@ -122,43 +89,19 @@ async def test_search_plugin_author(
to_me=True,
)
ctx.receive_event(bot=bot, event=event)
mock_table_page.assert_awaited_once_with(
"商店插件列表",
"通过添加/移除插件 ID 来管理插件",
["-", "ID", "名称", "简介", "作者", "版本", "类型"],
[
[
"",
4,
"github订阅",
"订阅github用户或仓库",
"xuanerwa",
"0.7",
"普通插件",
]
],
text_style=row_style,
)
mock_build_message.assert_called_once_with(mock_table_page_return)
mock_build_message_return.send.assert_awaited_once()
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
async def test_plugin_not_exist_search(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件不存在搜索插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
plugin_name = "not_exist_plugin_name"
async with app.test_matcher(_matcher) as ctx:

View File

@ -7,9 +7,7 @@ from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from nonebot.adapters.onebot.v11.message import Message
from nonebug import App
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.builtin_plugins.plugin_store.utils import init_mocked_api
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
@ -17,7 +15,6 @@ from tests.utils import _v11_group_message_event
async def test_update_all_plugin_basic_need_update(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -26,7 +23,6 @@ async def test_update_all_plugin_basic_need_update(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
@ -63,16 +59,12 @@ async def test_update_all_plugin_basic_need_update(
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
assert mocked_api["search_image_plugin_file_init_commit"].called
assert (mock_base_path / "plugins" / "search_image" / "__init__.py").is_file()
async def test_update_all_plugin_basic_is_new(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -81,14 +73,13 @@ async def test_update_all_plugin_basic_is_new(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins",
return_value=[("search_image", "0.1")],
return_value=[("search_image", "0.2")],
)
async with app.test_matcher(_matcher) as ctx:
@ -116,5 +107,3 @@ async def test_update_all_plugin_basic_is_new(
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called

View File

@ -9,7 +9,6 @@ from nonebug import App
from pytest_mock import MockerFixture
from respx import MockRouter
from tests.builtin_plugins.plugin_store.utils import init_mocked_api
from tests.config import BotId, GroupId, MessageId, UserId
from tests.utils import _v11_group_message_event
@ -17,7 +16,6 @@ from tests.utils import _v11_group_message_event
async def test_update_plugin_basic_need_update(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -26,7 +24,6 @@ async def test_update_plugin_basic_need_update(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
@ -63,16 +60,12 @@ async def test_update_plugin_basic_need_update(
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
assert mocked_api["search_image_plugin_file_init_commit"].called
assert (mock_base_path / "plugins" / "search_image" / "__init__.py").is_file()
async def test_update_plugin_basic_is_new(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
@ -81,14 +74,13 @@ async def test_update_plugin_basic_is_new(
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.StoreManager.get_loaded_plugins",
return_value=[("search_image", "0.1")],
return_value=[("search_image", "0.2")],
)
plugin_id = 1
@ -118,23 +110,17 @@ async def test_update_plugin_basic_is_new(
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
async def test_plugin_not_exist_update(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件不存在更新插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
plugin_id = -1
async with app.test_matcher(_matcher) as ctx:
@ -158,7 +144,7 @@ async def test_plugin_not_exist_update(
)
ctx.should_call_send(
event=event,
message=Message(message="插件ID不存在..."),
message=Message(message="更新插件 Id: -1 失败 e: 插件ID不存在..."),
result=None,
bot=bot,
)
@ -166,17 +152,14 @@ async def test_plugin_not_exist_update(
async def test_update_plugin_not_install(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试插件不存在更新插件
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
plugin_id = 1
async with app.test_matcher(_matcher) as ctx:
@ -200,7 +183,9 @@ async def test_update_plugin_not_install(
)
ctx.should_call_send(
event=event,
message=Message(message="插件 识图 未安装,无法更新"),
message=Message(
message="更新插件 Id: 1 失败 e: 插件 识图 未安装,无法更新"
),
result=None,
bot=bot,
)

View File

@ -1,147 +0,0 @@
# ruff: noqa: ASYNC230
from pathlib import Path
from respx import MockRouter
from tests.utils import get_content_bytes as _get_content_bytes
from tests.utils import get_response_json as _get_response_json
def get_response_json(file: str) -> dict:
return _get_response_json(Path() / "plugin_store", file=file)
def get_content_bytes(file: str) -> bytes:
return _get_content_bytes(Path() / "plugin_store", file)
def init_mocked_api(mocked_api: MockRouter) -> None:
# metadata
mocked_api.get(
"https://data.jsdelivr.com/v1/packages/gh/zhenxun-org/zhenxun_bot_plugins@main",
name="zhenxun_bot_plugins_metadata",
).respond(json=get_response_json("zhenxun_bot_plugins_metadata.json"))
mocked_api.get(
"https://data.jsdelivr.com/v1/packages/gh/xuanerwa/zhenxun_github_sub@main",
name="zhenxun_github_sub_metadata",
).respond(json=get_response_json("zhenxun_github_sub_metadata.json"))
mocked_api.get(
"https://data.jsdelivr.com/v1/packages/gh/zhenxun-org/zhenxun_bot_plugins@b101fbc",
name="zhenxun_bot_plugins_metadata_commit",
).respond(json=get_response_json("zhenxun_bot_plugins_metadata.json"))
mocked_api.get(
"https://data.jsdelivr.com/v1/packages/gh/xuanerwa/zhenxun_github_sub@f524632f78d27f9893beebdf709e0e7885cd08f1",
name="zhenxun_github_sub_metadata_commit",
).respond(json=get_response_json("zhenxun_github_sub_metadata.json"))
# tree
mocked_api.get(
"https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/git/trees/main?recursive=1",
name="zhenxun_bot_plugins_tree",
).respond(json=get_response_json("zhenxun_bot_plugins_tree.json"))
mocked_api.get(
"https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/trees/main?recursive=1",
name="zhenxun_github_sub_tree",
).respond(json=get_response_json("zhenxun_github_sub_tree.json"))
mocked_api.get(
"https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/git/trees/b101fbc?recursive=1",
name="zhenxun_bot_plugins_tree_commit",
).respond(json=get_response_json("zhenxun_bot_plugins_tree.json"))
mocked_api.get(
"https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/trees/f524632f78d27f9893beebdf709e0e7885cd08f1?recursive=1",
name="zhenxun_github_sub_tree_commit",
).respond(json=get_response_json("zhenxun_github_sub_tree.json"))
mocked_api.head(
"https://raw.githubusercontent.com/",
name="head_raw",
).respond(200, text="")
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/b101fbc/plugins.json",
name="basic_plugins",
).respond(json=get_response_json("basic_plugins.json"))
mocked_api.get(
"https://cdn.jsdelivr.net/gh/zhenxun-org/zhenxun_bot_plugins@b101fbc/plugins.json",
name="basic_plugins_jsdelivr",
).respond(200, json=get_response_json("basic_plugins.json"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/main/plugins.json",
name="basic_plugins_no_commit",
).respond(json=get_response_json("basic_plugins.json"))
mocked_api.get(
"https://cdn.jsdelivr.net/gh/zhenxun-org/zhenxun_bot_plugins@main/plugins.json",
name="basic_plugins_jsdelivr_no_commit",
).respond(200, json=get_response_json("basic_plugins.json"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins_index/2ed61284873c526802752b12a3fd3b5e1a59d948/plugins.json",
name="extra_plugins",
).respond(200, json=get_response_json("extra_plugins.json"))
mocked_api.get(
"https://cdn.jsdelivr.net/gh/zhenxun-org/zhenxun_bot_plugins_index@2ed61284873c526802752b12a3fd3b5e1a59d948/plugins.json",
name="extra_plugins_jsdelivr",
).respond(200, json=get_response_json("extra_plugins.json"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins_index/index/plugins.json",
name="extra_plugins_no_commit",
).respond(200, json=get_response_json("extra_plugins.json"))
mocked_api.get(
"https://cdn.jsdelivr.net/gh/zhenxun-org/zhenxun_bot_plugins_index@index/plugins.json",
name="extra_plugins_jsdelivr_no_commit",
).respond(200, json=get_response_json("extra_plugins.json"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/main/plugins/search_image/__init__.py",
name="search_image_plugin_file_init",
).respond(content=get_content_bytes("search_image.py"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/b101fbc/plugins/search_image/__init__.py",
name="search_image_plugin_file_init_commit",
).respond(content=get_content_bytes("search_image.py"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/main/plugins/alapi/jitang.py",
name="jitang_plugin_file",
).respond(content=get_content_bytes("jitang.py"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/b101fbc/plugins/alapi/jitang.py",
name="jitang_plugin_file_commit",
).respond(content=get_content_bytes("jitang.py"))
mocked_api.get(
"https://raw.githubusercontent.com/xuanerwa/zhenxun_github_sub/main/github_sub/__init__.py",
name="github_sub_plugin_file_init",
).respond(content=get_content_bytes("github_sub.py"))
mocked_api.get(
"https://raw.githubusercontent.com/xuanerwa/zhenxun_github_sub/f524632f78d27f9893beebdf709e0e7885cd08f1/github_sub/__init__.py",
name="github_sub_plugin_file_init_commit",
).respond(content=get_content_bytes("github_sub.py"))
mocked_api.get(
"https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/b101fbc/plugins/bilibili_sub/__init__.py",
name="bilibili_sub_plugin_file_init",
).respond(content=get_content_bytes("bilibili_sub.py"))
mocked_api.get(
"https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/commits/main",
name="zhenxun_bot_plugins_commit",
).respond(json=get_response_json("zhenxun_bot_plugins_commit.json"))
mocked_api.get(
"https://git-api.zhenxun.org/repos/zhenxun-org/zhenxun_bot_plugins/commits/main",
name="zhenxun_bot_plugins_commit_proxy",
).respond(json=get_response_json("zhenxun_bot_plugins_commit.json"))
mocked_api.get(
"https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/commits/index",
name="zhenxun_bot_plugins_index_commit",
).respond(json=get_response_json("zhenxun_bot_plugins_index_commit.json"))
mocked_api.get(
"https://git-api.zhenxun.org/repos/zhenxun-org/zhenxun_bot_plugins_index/commits/index",
name="zhenxun_bot_plugins_index_commit_proxy",
).respond(json=get_response_json("zhenxun_bot_plugins_index_commit.json"))
mocked_api.get(
"https://api.github.com/repos/xuanerwa/zhenxun_github_sub/commits/main",
name="zhenxun_github_sub_commit",
).respond(json=get_response_json("zhenxun_github_sub_commit.json"))
mocked_api.get(
"https://git-api.zhenxun.org/repos/xuanerwa/zhenxun_github_sub/commits/main",
name="zhenxun_github_sub_commit_proxy",
).respond(json=get_response_json("zhenxun_github_sub_commit.json"))

View File

@ -1,37 +0,0 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
__plugin_meta__ = PluginMetadata(
name="B站订阅",
description="非常便利的B站订阅通知",
usage="""
usage
B站直播番剧UP动态开播等提醒
主播订阅相当于 直播间订阅 + UP订阅
指令
添加订阅 ['主播'/'UP'/'番剧'] [id/链接/番名]
删除订阅 ['主播'/'UP'/'id'] [id]
查看订阅
示例
添加订阅主播 2345344 <-(直播房间id)
添加订阅UP 2355543 <-(个人主页id)
添加订阅番剧 史莱姆 <-(支持模糊搜索)
添加订阅番剧 125344 <-(番剧id)
删除订阅id 2324344 <-(任意id通过查看订阅获取)
""".strip(),
extra=PluginExtraData(
author="HibiKier",
version="0.3-b101fbc",
superuser_help="""
登录b站获取cookie防止风控
bil_check/检测b站
bil_login/登录b站
bil_logout/退出b站 uid
示例:
登录b站
检测b站
bil_logout 12345<-(退出登录的b站uid通过检测b站获取)
""",
).to_dict(),
)

View File

@ -1,24 +0,0 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
__plugin_meta__ = PluginMetadata(
name="github订阅",
description="订阅github用户或仓库",
usage="""
usage
github新CommentPRIssue等提醒
指令
添加github ['用户'/'仓库'] [用户名/{owner/repo}]
删除github [用户名/{owner/repo}]
查看github
示例添加github订阅 用户 HibiKier
示例添加gb订阅 仓库 HibiKier/zhenxun_bot
示例添加github 用户 HibiKier
示例删除gb订阅 HibiKier
""".strip(),
extra=PluginExtraData(
author="xuanerwa",
version="0.7",
).to_dict(),
)

View File

@ -1,17 +0,0 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
__plugin_meta__ = PluginMetadata(
name="鸡汤",
description="喏,亲手为你煮的鸡汤",
usage="""
不喝点什么感觉有点不舒服
指令
鸡汤
""".strip(),
extra=PluginExtraData(
author="HibiKier",
version="0.1",
).to_dict(),
)

View File

@ -1,18 +0,0 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
__plugin_meta__ = PluginMetadata(
name="识图",
description="以图搜图,看破本源",
usage="""
识别图片 [二次元图片]
指令
识图 [图片]
""".strip(),
extra=PluginExtraData(
author="HibiKier",
version="0.1",
menu_type="一些工具",
).to_dict(),
)

View File

@ -1,46 +0,0 @@
[
{
"name": "鸡汤",
"module": "jitang",
"module_path": "plugins.alapi.jitang",
"description": "喏,亲手为你煮的鸡汤",
"usage": "不喝点什么感觉有点不舒服\n 指令:\n 鸡汤",
"author": "HibiKier",
"version": "0.1",
"plugin_type": "NORMAL",
"is_dir": false
},
{
"name": "识图",
"module": "search_image",
"module_path": "plugins.search_image",
"description": "以图搜图,看破本源",
"usage": "识别图片 [二次元图片]\n 指令:\n 识图 [图片]",
"author": "HibiKier",
"version": "0.1",
"plugin_type": "NORMAL",
"is_dir": true
},
{
"name": "网易云热评",
"module": "comments_163",
"module_path": "plugins.alapi.comments_163",
"description": "生了个人,我很抱歉",
"usage": "到点了,还是防不了下塔\n 指令:\n 网易云热评/到点了/12点了",
"author": "HibiKier",
"version": "0.1",
"plugin_type": "NORMAL",
"is_dir": false
},
{
"name": "B站订阅",
"module": "bilibili_sub",
"module_path": "plugins.bilibili_sub",
"description": "非常便利的B站订阅通知",
"usage": "B站直播番剧UP动态开播等提醒",
"author": "HibiKier",
"version": "0.3-b101fbc",
"plugin_type": "NORMAL",
"is_dir": true
}
]

View File

@ -1,26 +0,0 @@
[
{
"name": "github订阅",
"module": "github_sub",
"module_path": "github_sub",
"description": "订阅github用户或仓库",
"usage": "usage\n github新CommentPRIssue等提醒\n 指令:\n 添加github ['用户'/'仓库'] [用户名/{owner/repo}]\n 删除github [用户名/{owner/repo}]\n 查看github\n 示例添加github订阅 用户 HibiKier\n 示例添加gb订阅 仓库 HibiKier/zhenxun_bot\n 示例添加github 用户 HibiKier\n 示例删除gb订阅 HibiKier",
"author": "xuanerwa",
"version": "0.7",
"plugin_type": "NORMAL",
"is_dir": true,
"github_url": "https://github.com/xuanerwa/zhenxun_github_sub"
},
{
"name": "Minecraft查服",
"module": "mc_check",
"module_path": "mc_check",
"description": "Minecraft服务器状态查询支持IPv6",
"usage": "Minecraft服务器状态查询支持IPv6\n用法\n\t查服 [ip]:[端口] / 查服 [ip]\n\t设置语言 zh-cn\n\t当前语言\n\t语言列表\neg:\t\nmcheck ip:port / mcheck ip\n\tset_lang en\n\tlang_now\n\tlang_list",
"author": "molanp",
"version": "1.13",
"plugin_type": "NORMAL",
"is_dir": true,
"github_url": "https://github.com/molanp/zhenxun_check_Minecraft"
}
]

View File

@ -1,101 +0,0 @@
{
"sha": "b101fbc",
"node_id": "C_kwDOMndPGNoAKGIxMDFmYmNlODg4NjA4ZTJiYmU1YjVmZDI3OWUxNDY1MTY4ODEyYzc",
"commit": {
"author": {
"name": "xuaner",
"email": "xuaner_wa@qq.com",
"date": "2024-09-20T12:08:27Z"
},
"committer": {
"name": "xuaner",
"email": "xuaner_wa@qq.com",
"date": "2024-09-20T12:08:27Z"
},
"message": "🐛修复B站订阅bug",
"tree": {
"sha": "0566306219a434f7122798647498faef692c1879",
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/git/trees/0566306219a434f7122798647498faef692c1879"
},
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/git/commits/b101fbce888608e2bbe5b5fd279e1465168812c7",
"comment_count": 0,
"verification": {
"verified": false,
"reason": "unsigned",
"signature": null,
"payload": null,
"verified_at": null
}
},
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/commits/b101fbce888608e2bbe5b5fd279e1465168812c7",
"html_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins/commit/b101fbce888608e2bbe5b5fd279e1465168812c7",
"comments_url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/commits/b101fbce888608e2bbe5b5fd279e1465168812c7/comments",
"author": {
"login": "xuanerwa",
"id": 58063798,
"node_id": "MDQ6VXNlcjU4MDYzNzk4",
"avatar_url": "https://avatars.githubusercontent.com/u/58063798?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/xuanerwa",
"html_url": "https://github.com/xuanerwa",
"followers_url": "https://api.github.com/users/xuanerwa/followers",
"following_url": "https://api.github.com/users/xuanerwa/following{/other_user}",
"gists_url": "https://api.github.com/users/xuanerwa/gists{/gist_id}",
"starred_url": "https://api.github.com/users/xuanerwa/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/xuanerwa/subscriptions",
"organizations_url": "https://api.github.com/users/xuanerwa/orgs",
"repos_url": "https://api.github.com/users/xuanerwa/repos",
"events_url": "https://api.github.com/users/xuanerwa/events{/privacy}",
"received_events_url": "https://api.github.com/users/xuanerwa/received_events",
"type": "User",
"user_view_type": "public",
"site_admin": false
},
"committer": {
"login": "xuanerwa",
"id": 58063798,
"node_id": "MDQ6VXNlcjU4MDYzNzk4",
"avatar_url": "https://avatars.githubusercontent.com/u/58063798?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/xuanerwa",
"html_url": "https://github.com/xuanerwa",
"followers_url": "https://api.github.com/users/xuanerwa/followers",
"following_url": "https://api.github.com/users/xuanerwa/following{/other_user}",
"gists_url": "https://api.github.com/users/xuanerwa/gists{/gist_id}",
"starred_url": "https://api.github.com/users/xuanerwa/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/xuanerwa/subscriptions",
"organizations_url": "https://api.github.com/users/xuanerwa/orgs",
"repos_url": "https://api.github.com/users/xuanerwa/repos",
"events_url": "https://api.github.com/users/xuanerwa/events{/privacy}",
"received_events_url": "https://api.github.com/users/xuanerwa/received_events",
"type": "User",
"user_view_type": "public",
"site_admin": false
},
"parents": [
{
"sha": "a545dfa0c4e149595f7ddd50dc34c55513738fb9",
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/commits/a545dfa0c4e149595f7ddd50dc34c55513738fb9",
"html_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins/commit/a545dfa0c4e149595f7ddd50dc34c55513738fb9"
}
],
"stats": {
"total": 4,
"additions": 2,
"deletions": 2
},
"files": [
{
"sha": "0fbc9695db04c56174e3bff933f670d8d2df2abc",
"filename": "plugins/bilibili_sub/data_source.py",
"status": "modified",
"additions": 2,
"deletions": 2,
"changes": 4,
"blob_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins/blob/b101fbce888608e2bbe5b5fd279e1465168812c7/plugins%2Fbilibili_sub%2Fdata_source.py",
"raw_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins/raw/b101fbce888608e2bbe5b5fd279e1465168812c7/plugins%2Fbilibili_sub%2Fdata_source.py",
"contents_url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins/contents/plugins%2Fbilibili_sub%2Fdata_source.py?ref=b101fbce888608e2bbe5b5fd279e1465168812c7",
"patch": "@@ -271,14 +271,14 @@ async def _get_live_status(id_: int) -> list:\n sub = await BilibiliSub.get_or_none(sub_id=id_)\n msg_list = []\n if sub.live_status != live_status:\n+ await BilibiliSub.sub_handle(id_, live_status=live_status)\n image = None\n try:\n image_bytes = await fetch_image_bytes(cover)\n image = BuildImage(background = image_bytes)\n except Exception as e:\n logger.error(f\"图片构造失败,错误信息:{e}\")\n if sub.live_status in [0, 2] and live_status == 1 and image:\n- await BilibiliSub.sub_handle(id_, live_status=live_status)\n msg_list = [\n image,\n \"\\n\",\n@@ -322,7 +322,7 @@ async def _get_up_status(id_: int) -> list:\n video = video_info[\"list\"][\"vlist\"][0]\n latest_video_created = video[\"created\"]\n msg_list = []\n- if dynamic_img:\n+ if dynamic_img and _user.dynamic_upload_time < dynamic_upload_time:\n await BilibiliSub.sub_handle(id_, dynamic_upload_time=dynamic_upload_time)\n msg_list = [f\"{uname} 发布了动态!📢\\n\", dynamic_img, f\"\\n查看详情{link}\"]\n if ("
}
]
}

View File

@ -1,101 +0,0 @@
{
"sha": "2ed61284873c526802752b12a3fd3b5e1a59d948",
"node_id": "C_kwDOGK5Du9oAKDJlZDYxMjg0ODczYzUyNjgwMjc1MmIxMmEzZmQzYjVlMWE1OWQ5NDg",
"commit": {
"author": {
"name": "zhenxunflow[bot]",
"email": "179375394+zhenxunflow[bot]@users.noreply.github.com",
"date": "2025-01-26T09:04:55Z"
},
"committer": {
"name": "GitHub",
"email": "noreply@github.com",
"date": "2025-01-26T09:04:55Z"
},
"message": ":beers: publish plugin AI全家桶 (#235) (#236)\n\nCo-authored-by: molanp <molanp@users.noreply.github.com>",
"tree": {
"sha": "64ea463e084b6ab0def0322c6ad53799054ec9b3",
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/git/trees/64ea463e084b6ab0def0322c6ad53799054ec9b3"
},
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/git/commits/2ed61284873c526802752b12a3fd3b5e1a59d948",
"comment_count": 0,
"verification": {
"verified": true,
"reason": "valid",
"signature": "-----BEGIN PGP SIGNATURE-----\n\nwsFcBAABCAAQBQJnlfq3CRC1aQ7uu5UhlAAA+n0QADPVjQQIHFlNcTEgdq3LGQ1X\nm8+H5N07E5JD+83LdyU9/YOvqY/WURwFsQ0T4+23icUWEOD4LB5qZIdVJBYHseto\nbJNmYd1kZxpvsONoiK/2Uk6JoeVnEQIR+dTbB0wBlbL0lRt1WtTXHpLQbFXuXn3q\nJh4SdSj283UZ6D2sBADblPZ7DqaTmLlpgwrTPx0OH5wIhcuORkzOl6x0DabcVAYu\nu5zHSKM9c7g+jEmrqRuVy+ZlZMDPN4S3gDNzEhoTn4tn+KNzSIja4n7ZMRD+1a5X\nMIP3aXcVBqCyuYc6DU76IvjlaL/MjnlPwfOtx1zu+pNxZKNaSpojtqopp3blfk0E\n8s8lD9utDgUaUrdPWgpiMDjj+oNMye91CGomNDfv0fNGUlBGT6r48qaq1z8BwAAR\nzgDsF13kDuKTTkT/6T8CdgCpJtwvxMptUr2XFRtn4xwf/gJdqrbEc4fHTOSHqxzh\ncDfXuP+Sorla4oJ0duygTsulpr/zguX8RJWJml35VjERw54ARAVvhZn19G9qQVJo\n2QIp+xtyTjkM3yTeN4UDXFt4lDuxz3+l1MBduj+CHn+WTgxyJUpX2TA1GVfni9xT\npOMOtzuDQfDIxTNB6hFjSWATb1/E5ys1lfK09n+dRhmvC/Be+b5M4WlyX3cqy/za\ns0XxuZ+CHzLfHaPxFUem\n=VYpl\n-----END PGP SIGNATURE-----\n",
"payload": "tree 64ea463e084b6ab0def0322c6ad53799054ec9b3\nparent 5df26081d40e3000a7beedb73954d4df397c93fa\nauthor zhenxunflow[bot] <179375394+zhenxunflow[bot]@users.noreply.github.com> 1737882295 +0800\ncommitter GitHub <noreply@github.com> 1737882295 +0800\n\n:beers: publish plugin AI全家桶 (#235) (#236)\n\nCo-authored-by: molanp <molanp@users.noreply.github.com>",
"verified_at": "2025-01-26T09:04:58Z"
}
},
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/commits/2ed61284873c526802752b12a3fd3b5e1a59d948",
"html_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/commit/2ed61284873c526802752b12a3fd3b5e1a59d948",
"comments_url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/commits/2ed61284873c526802752b12a3fd3b5e1a59d948/comments",
"author": {
"login": "zhenxunflow[bot]",
"id": 179375394,
"node_id": "BOT_kgDOCrENIg",
"avatar_url": "https://avatars.githubusercontent.com/in/978723?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/zhenxunflow%5Bbot%5D",
"html_url": "https://github.com/apps/zhenxunflow",
"followers_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/followers",
"following_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/following{/other_user}",
"gists_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/gists{/gist_id}",
"starred_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/subscriptions",
"organizations_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/orgs",
"repos_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/repos",
"events_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/events{/privacy}",
"received_events_url": "https://api.github.com/users/zhenxunflow%5Bbot%5D/received_events",
"type": "Bot",
"user_view_type": "public",
"site_admin": false
},
"committer": {
"login": "web-flow",
"id": 19864447,
"node_id": "MDQ6VXNlcjE5ODY0NDQ3",
"avatar_url": "https://avatars.githubusercontent.com/u/19864447?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/web-flow",
"html_url": "https://github.com/web-flow",
"followers_url": "https://api.github.com/users/web-flow/followers",
"following_url": "https://api.github.com/users/web-flow/following{/other_user}",
"gists_url": "https://api.github.com/users/web-flow/gists{/gist_id}",
"starred_url": "https://api.github.com/users/web-flow/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/web-flow/subscriptions",
"organizations_url": "https://api.github.com/users/web-flow/orgs",
"repos_url": "https://api.github.com/users/web-flow/repos",
"events_url": "https://api.github.com/users/web-flow/events{/privacy}",
"received_events_url": "https://api.github.com/users/web-flow/received_events",
"type": "User",
"user_view_type": "public",
"site_admin": false
},
"parents": [
{
"sha": "5df26081d40e3000a7beedb73954d4df397c93fa",
"url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/commits/5df26081d40e3000a7beedb73954d4df397c93fa",
"html_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/commit/5df26081d40e3000a7beedb73954d4df397c93fa"
}
],
"stats": {
"total": 11,
"additions": 11,
"deletions": 0
},
"files": [
{
"sha": "3d98392c25d38f5d375b830aed6e2298e47e5601",
"filename": "plugins.json",
"status": "modified",
"additions": 11,
"deletions": 0,
"changes": 11,
"blob_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/blob/2ed61284873c526802752b12a3fd3b5e1a59d948/plugins.json",
"raw_url": "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/raw/2ed61284873c526802752b12a3fd3b5e1a59d948/plugins.json",
"contents_url": "https://api.github.com/repos/zhenxun-org/zhenxun_bot_plugins_index/contents/plugins.json?ref=2ed61284873c526802752b12a3fd3b5e1a59d948",
"patch": "@@ -53,5 +53,16 @@\n \"plugin_type\": \"NORMAL\",\n \"is_dir\": true,\n \"github_url\": \"https://github.com/PackageInstaller/zhenxun_plugin_draw_painting/tree/master\"\n+ },\n+ \"AI全家桶\": {\n+ \"module\": \"zhipu_toolkit\",\n+ \"module_path\": \"zhipu_toolkit\",\n+ \"description\": \"AI全家桶一次安装到处使用省时省力省心\",\n+ \"usage\": \"AI全家桶一次安装到处使用省时省力省心\\n usage:\\n 生成图片 <prompt>\\n 生成视频 <prompt>\\n 清理我的会话: 用于清理你与AI的聊天记录\\n 或者与机器人聊天,\\n 例如;\\n @Bot抱抱\\n 小真寻老婆\",\n+ \"author\": \"molanp\",\n+ \"version\": \"0.1\",\n+ \"plugin_type\": \"NORMAL\",\n+ \"is_dir\": true,\n+ \"github_url\": \"https://github.com/molanp/zhenxun_plugin_zhipu_toolkit\"\n }\n }"
}
]
}

View File

@ -1,83 +0,0 @@
{
"type": "gh",
"name": "zhenxun-org/zhenxun_bot_plugins",
"version": "main",
"default": null,
"files": [
{
"type": "directory",
"name": "plugins",
"files": [
{
"type": "directory",
"name": "search_image",
"files": [
{
"type": "file",
"name": "__init__.py",
"hash": "a4Yp9HPoBzMwvnQDT495u0yYqTQWofkOyHxEi1FdVb0=",
"size": 3010
}
]
},
{
"type": "directory",
"name": "alapi",
"files": [
{
"type": "file",
"name": "__init__.py",
"hash": "ndDxtO0pAq3ZTb4RdqW7FTDgOGC/RjS1dnwdaQfT0uQ=",
"size": 284
},
{
"type": "file",
"name": "_data_source.py",
"hash": "KOLqtj4TQWWQco5bA4tWFc7A0z1ruMyDk1RiKeqJHRA=",
"size": 919
},
{
"type": "file",
"name": "comments_163.py",
"hash": "Q5pZsj1Pj+EJMdKYcPtLqejcXAWUQIoXVQG49PZPaSI=",
"size": 1593
},
{
"type": "file",
"name": "cover.py",
"hash": "QSjtcy0oVrjaRiAWZKmUJlp0L4DQqEcdYNmExNo9mgc=",
"size": 1438
},
{
"type": "file",
"name": "jitang.py",
"hash": "xh43Osxt0xogTH448gUMC+/DaSGmCFme8DWUqC25IbU=",
"size": 1411
},
{
"type": "file",
"name": "poetry.py",
"hash": "Aj2unoNQboj3/0LhIrYU+dCa5jvMdpjMYXYUayhjuz4=",
"size": 1530
}
]
},
{
"type": "directory",
"name": "bilibili_sub",
"files": [
{
"type": "file",
"name": "__init__.py",
"hash": "407DCgNFcZnuEK+d716j8EWrFQc4Nlxa35V3yemy3WQ=",
"size": 14293
}
]
}
]
}
],
"links": {
"stats": "https://data.jsdelivr.com/v1/stats/packages/gh/zhenxun-org/zhenxun_bot_plugins@main"
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,101 +0,0 @@
{
"sha": "f524632f78d27f9893beebdf709e0e7885cd08f1",
"node_id": "C_kwDOJAjBPdoAKGY1MjQ2MzJmNzhkMjdmOTg5M2JlZWJkZjcwOWUwZTc4ODVjZDA4ZjE",
"commit": {
"author": {
"name": "xuaner",
"email": "xuaner_wa@qq.com",
"date": "2024-11-18T18:17:15Z"
},
"committer": {
"name": "xuaner",
"email": "xuaner_wa@qq.com",
"date": "2024-11-18T18:17:15Z"
},
"message": "fix bug",
"tree": {
"sha": "b6b1b4f06cc869b9f38d7b51bdca3a2c575255e4",
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/trees/b6b1b4f06cc869b9f38d7b51bdca3a2c575255e4"
},
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/commits/f524632f78d27f9893beebdf709e0e7885cd08f1",
"comment_count": 0,
"verification": {
"verified": false,
"reason": "unsigned",
"signature": null,
"payload": null,
"verified_at": null
}
},
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/commits/f524632f78d27f9893beebdf709e0e7885cd08f1",
"html_url": "https://github.com/xuanerwa/zhenxun_github_sub/commit/f524632f78d27f9893beebdf709e0e7885cd08f1",
"comments_url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/commits/f524632f78d27f9893beebdf709e0e7885cd08f1/comments",
"author": {
"login": "xuanerwa",
"id": 58063798,
"node_id": "MDQ6VXNlcjU4MDYzNzk4",
"avatar_url": "https://avatars.githubusercontent.com/u/58063798?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/xuanerwa",
"html_url": "https://github.com/xuanerwa",
"followers_url": "https://api.github.com/users/xuanerwa/followers",
"following_url": "https://api.github.com/users/xuanerwa/following{/other_user}",
"gists_url": "https://api.github.com/users/xuanerwa/gists{/gist_id}",
"starred_url": "https://api.github.com/users/xuanerwa/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/xuanerwa/subscriptions",
"organizations_url": "https://api.github.com/users/xuanerwa/orgs",
"repos_url": "https://api.github.com/users/xuanerwa/repos",
"events_url": "https://api.github.com/users/xuanerwa/events{/privacy}",
"received_events_url": "https://api.github.com/users/xuanerwa/received_events",
"type": "User",
"user_view_type": "public",
"site_admin": false
},
"committer": {
"login": "xuanerwa",
"id": 58063798,
"node_id": "MDQ6VXNlcjU4MDYzNzk4",
"avatar_url": "https://avatars.githubusercontent.com/u/58063798?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/xuanerwa",
"html_url": "https://github.com/xuanerwa",
"followers_url": "https://api.github.com/users/xuanerwa/followers",
"following_url": "https://api.github.com/users/xuanerwa/following{/other_user}",
"gists_url": "https://api.github.com/users/xuanerwa/gists{/gist_id}",
"starred_url": "https://api.github.com/users/xuanerwa/starred{/owner}{/repo}",
"subscriptions_url": "https://api.github.com/users/xuanerwa/subscriptions",
"organizations_url": "https://api.github.com/users/xuanerwa/orgs",
"repos_url": "https://api.github.com/users/xuanerwa/repos",
"events_url": "https://api.github.com/users/xuanerwa/events{/privacy}",
"received_events_url": "https://api.github.com/users/xuanerwa/received_events",
"type": "User",
"user_view_type": "public",
"site_admin": false
},
"parents": [
{
"sha": "91e5e2c792e79193830441d555769aa54acd2d15",
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/commits/91e5e2c792e79193830441d555769aa54acd2d15",
"html_url": "https://github.com/xuanerwa/zhenxun_github_sub/commit/91e5e2c792e79193830441d555769aa54acd2d15"
}
],
"stats": {
"total": 2,
"additions": 1,
"deletions": 1
},
"files": [
{
"sha": "764a5f7b81554c4c10d29486ea5d9105e505cec3",
"filename": "github_sub/__init__.py",
"status": "modified",
"additions": 1,
"deletions": 1,
"changes": 2,
"blob_url": "https://github.com/xuanerwa/zhenxun_github_sub/blob/f524632f78d27f9893beebdf709e0e7885cd08f1/github_sub%2F__init__.py",
"raw_url": "https://github.com/xuanerwa/zhenxun_github_sub/raw/f524632f78d27f9893beebdf709e0e7885cd08f1/github_sub%2F__init__.py",
"contents_url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/contents/github_sub%2F__init__.py?ref=f524632f78d27f9893beebdf709e0e7885cd08f1",
"patch": "@@ -168,7 +168,7 @@ async def _(session: EventSession):\n # 推送\n @scheduler.scheduled_job(\n \"interval\",\n- seconds=base_config.get(\"CHECK_API_TIME\") if base_config.get(\"CHECK_TIME\") else 30,\n+ seconds=base_config.get(\"CHECK_API_TIME\") if base_config.get(\"CHECK_API_TIME\") else 30,\n )\n async def _():\n bots = nonebot.get_bots()"
}
]
}

View File

@ -1,23 +0,0 @@
{
"type": "gh",
"name": "xuanerwa/zhenxun_github_sub",
"version": "main",
"default": null,
"files": [
{
"type": "directory",
"name": "github_sub",
"files": [
{
"type": "file",
"name": "__init__.py",
"hash": "z1C5BBK0+atbDghbyRlF2xIDwk0HQdHM1yXQZkF7/t8=",
"size": 7551
}
]
}
],
"links": {
"stats": "https://data.jsdelivr.com/v1/stats/packages/gh/xuanerwa/zhenxun_github_sub@main"
}
}

View File

@ -1,38 +0,0 @@
{
"sha": "438298b9e88f9dafa7020e99d7c7b4c98f93aea6",
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/trees/438298b9e88f9dafa7020e99d7c7b4c98f93aea6",
"tree": [
{
"path": "LICENSE",
"mode": "100644",
"type": "blob",
"sha": "f288702d2fa16d3cdf0035b15a9fcbc552cd88e7",
"size": 35149,
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/blobs/f288702d2fa16d3cdf0035b15a9fcbc552cd88e7"
},
{
"path": "README.md",
"mode": "100644",
"type": "blob",
"sha": "e974cfc9b973d4a041f03e693ea20563a933b7ca",
"size": 955,
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/blobs/e974cfc9b973d4a041f03e693ea20563a933b7ca"
},
{
"path": "github_sub",
"mode": "040000",
"type": "tree",
"sha": "0f7d76bcf472e2ab0610fa542b067633d6e3ae7e",
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/trees/0f7d76bcf472e2ab0610fa542b067633d6e3ae7e"
},
{
"path": "github_sub/__init__.py",
"mode": "100644",
"type": "blob",
"sha": "7d17fd49fe82fa3897afcef61b2c694ed93a4ba3",
"size": 7551,
"url": "https://api.github.com/repos/xuanerwa/zhenxun_github_sub/git/blobs/7d17fd49fe82fa3897afcef61b2c694ed93a4ba3"
}
],
"truncated": false
}

View File

@ -17,7 +17,7 @@ from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.decorator.shop import shop_register
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.manager.resource_manager import ResourceManager
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
from zhenxun.utils.platform import PlatformUtils
driver: Driver = nonebot.get_driver()
@ -85,7 +85,8 @@ from bag_users t1
@PriorityLifecycle.on_startup(priority=5)
async def _():
await ResourceManager.init_resources()
if not ZhenxunRepoManager.check_resources_exists():
await ZhenxunRepoManager.resources_update()
"""签到与用户的数据迁移"""
if goods_list := await GoodsInfo.filter(uuid__isnull=True).all():
for goods in goods_list:

View File

@ -114,7 +114,7 @@ class BanManage:
is_superuser: 是否为超级用户操作
返回:
tuple[bool, str]: 是否unban成功, 群组/用户id或提示
tuple[bool, str | Non]: 是否unban成功, 群组/用户id或提示
"""
user_level = 9999
if not is_superuser and user_id and session.id1:
@ -126,10 +126,10 @@ class BanManage:
if ban_data.ban_level > user_level:
return False, "unBan权限等级不足捏..."
await ban_data.delete()
return (True, ban_data.user_id if ban_data.user_id else ban_data.group_id)
return True, ban_data.user_id or ban_data.group_id
elif await BanConsole.check_ban_level(user_id, group_id, user_level):
await BanConsole.unban(user_id, group_id)
return True, str(group_id)
return True, group_id or ""
return False, "该用户/群组不在黑名单中不足捏..."
@classmethod

View File

@ -104,25 +104,16 @@ class MemberUpdateManage:
exist_member_list.append(member.id)
if data_list[0]:
try:
await GroupInfoUser.bulk_create(data_list[0], 30)
await GroupInfoUser.bulk_create(
data_list[0], 30, ignore_conflicts=True
)
logger.debug(
f"创建用户数据 {len(data_list[0])}",
"更新群组成员信息",
target=group_id,
)
except Exception as e:
logger.error(
f"批量创建用户数据失败: {e},开始进行逐个存储",
"更新群组成员信息",
)
for u in data_list[0]:
try:
await u.save()
except Exception as e:
logger.error(
f"创建用户 {u.user_name}({u.user_id}) 数据失败: {e}",
"更新群组成员信息",
)
logger.error("批量创建用户数据失败", "更新群组成员信息", e=e)
if data_list[1]:
await GroupInfoUser.bulk_update(data_list[1], ["user_name"], 30)
logger.debug(

View File

@ -16,10 +16,6 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.utils import PluginExtraData
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.manager.resource_manager import (
DownloadResourceException,
ResourceManager,
)
from zhenxun.utils.message import MessageUtils
from ._data_source import UpdateManager
@ -32,15 +28,23 @@ __plugin_meta__ = PluginMetadata(
检查更新真寻最新版本包括了自动更新
资源文件大小一般在130mb左右除非必须更新一般仅更新代码文件
指令
检查更新 [main|release|resource|webui] ?[-r]
检查更新 [main|release|resource|webui] ?[-r] ?[-f] ?[-z] ?[-t]
main: main分支
release: 最新release
resource: 资源文件
webui: webui文件
-r: 下载资源文件一般在更新main或release时使用
-f: 强制更新一般用于更新main时使用仅git更新时有效
-s: 更新源 git ali默认使用ali
-z: 下载zip文件进行更新仅git有效
-t: 更新方式git或download默认使用git
git: 使用git pull推荐
download: 通过commit hash比较文件后下载更新仅git有效
示例:
检查更新 main
检查更新 main -r
检查更新 main -f
检查更新 release -r
检查更新 resource
检查更新 webui
@ -57,6 +61,9 @@ _matcher = on_alconna(
"检查更新",
Args["ver_type?", ["main", "release", "resource", "webui"]],
Option("-r|--resource", action=store_true, help_text="下载资源文件"),
Option("-f|--force", action=store_true, help_text="强制更新"),
Option("-s", Args["source?", ["git", "ali"]], help_text="更新源"),
Option("-z|--zip", action=store_true, help_text="下载zip文件"),
),
priority=1,
block=True,
@ -71,30 +78,55 @@ async def _(
session: Uninfo,
ver_type: Match[str],
resource: Query[bool] = Query("resource", False),
force: Query[bool] = Query("force", False),
source: Query[str] = Query("source", "ali"),
zip: Query[bool] = Query("zip", False),
):
result = ""
await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True)
if ver_type.result in {"main", "release"}:
ver_type_str = ver_type.result
source_str = source.result
if ver_type_str in {"main", "release"}:
if not ver_type.available:
result = await UpdateManager.check_version()
result += await UpdateManager.check_version()
logger.info("查看当前版本...", "检查更新", session=session)
await MessageUtils.build_message(result).finish()
try:
result = await UpdateManager.update(bot, session.user.id, ver_type.result)
result += await UpdateManager.update_zhenxun(
bot,
session.user.id,
ver_type_str, # type: ignore
force.result,
source_str, # type: ignore
zip.result,
)
except Exception as e:
logger.error("版本更新失败...", "检查更新", session=session, e=e)
await MessageUtils.build_message(f"更新版本失败...e: {e}").finish()
elif ver_type.result == "webui":
result = await UpdateManager.update_webui()
if zip.result:
source_str = None
try:
result += await UpdateManager.update_webui(
source_str, # type: ignore
"test",
True,
)
except Exception as e:
logger.error("WebUI更新失败...", "检查更新", session=session, e=e)
result += "\nWebUI更新错误..."
if resource.result or ver_type.result == "resource":
try:
await ResourceManager.init_resources(True)
result += "\n资源文件更新成功!"
except DownloadResourceException:
result += "\n资源更新下载失败..."
if zip.result:
source_str = None
result += await UpdateManager.update_resources(
source_str, # type: ignore
"main",
force.result,
)
except Exception as e:
logger.error("资源更新下载失败...", "检查更新", session=session, e=e)
result += "\n资源更新未知错误..."
result += "\n资源更新错误..."
if result:
await MessageUtils.build_message(result.strip()).finish()
await MessageUtils.build_message("更新版本失败...").finish()

View File

@ -1,170 +1,19 @@
import os
import shutil
import subprocess
import tarfile
import zipfile
from typing import Literal
from nonebot.adapters import Bot
from nonebot.utils import run_sync
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import RepoInfo
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from zhenxun.utils.manager.zhenxun_repo_manager import (
ZhenxunRepoConfig,
ZhenxunRepoManager,
)
from zhenxun.utils.platform import PlatformUtils
from .config import (
BACKUP_PATH,
BASE_PATH,
BASE_PATH_STRING,
COMMAND,
DEFAULT_GITHUB_URL,
DOWNLOAD_GZ_FILE,
DOWNLOAD_ZIP_FILE,
PYPROJECT_FILE,
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE,
PYPROJECT_LOCK_FILE_STRING,
RELEASE_URL,
REPLACE_FOLDERS,
REQ_TXT_FILE,
REQ_TXT_FILE_STRING,
TMP_PATH,
VERSION_FILE,
)
def install_requirement():
requirement_path = (REQ_TXT_FILE).absolute()
if not requirement_path.exists():
logger.debug(
f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", COMMAND
)
return
try:
result = subprocess.run(
["pip", "install", "-r", str(requirement_path)],
check=True,
capture_output=True,
text=True,
)
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", COMMAND)
except subprocess.CalledProcessError as e:
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", COMMAND, e=e)
@run_sync
def _file_handle(latest_version: str | None):
"""文件移动操作
参数:
latest_version: 版本号
"""
BACKUP_PATH.mkdir(exist_ok=True, parents=True)
logger.debug("开始解压文件压缩包...", COMMAND)
download_file = DOWNLOAD_GZ_FILE
if DOWNLOAD_GZ_FILE.exists():
tf = tarfile.open(DOWNLOAD_GZ_FILE)
else:
download_file = DOWNLOAD_ZIP_FILE
tf = zipfile.ZipFile(DOWNLOAD_ZIP_FILE)
tf.extractall(TMP_PATH)
logger.debug("解压文件压缩包完成...", COMMAND)
download_file_path = TMP_PATH / next(
x for x in os.listdir(TMP_PATH) if (TMP_PATH / x).is_dir()
)
_pyproject = download_file_path / PYPROJECT_FILE_STRING
_lock_file = download_file_path / PYPROJECT_LOCK_FILE_STRING
_req_file = download_file_path / REQ_TXT_FILE_STRING
extract_path = download_file_path / BASE_PATH_STRING
target_path = BASE_PATH
if PYPROJECT_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", COMMAND)
shutil.move(PYPROJECT_FILE, BACKUP_PATH / PYPROJECT_FILE_STRING)
if PYPROJECT_LOCK_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", COMMAND)
shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / PYPROJECT_LOCK_FILE_STRING)
if REQ_TXT_FILE.exists():
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", COMMAND)
shutil.move(REQ_TXT_FILE, BACKUP_PATH / REQ_TXT_FILE_STRING)
if _pyproject.exists():
logger.debug("移动文件: pyproject.toml", COMMAND)
shutil.move(_pyproject, PYPROJECT_FILE)
if _lock_file.exists():
logger.debug("移动文件: poetry.lock", COMMAND)
shutil.move(_lock_file, PYPROJECT_LOCK_FILE)
if _req_file.exists():
logger.debug("移动文件: requirements.txt", COMMAND)
shutil.move(_req_file, REQ_TXT_FILE)
for folder in REPLACE_FOLDERS:
"""移动指定文件夹"""
_dir = BASE_PATH / folder
_backup_dir = BACKUP_PATH / folder
if _backup_dir.exists():
logger.debug(f"删除备份文件夹 {_backup_dir}", COMMAND)
shutil.rmtree(_backup_dir)
if _dir.exists():
logger.debug(f"移动旧文件夹 {_dir}", COMMAND)
shutil.move(_dir, _backup_dir)
else:
logger.warning(f"文件夹 {_dir} 不存在,跳过删除", COMMAND)
for folder in REPLACE_FOLDERS:
src_folder_path = extract_path / folder
dest_folder_path = target_path / folder
if src_folder_path.exists():
logger.debug(
f"移动文件夹: {src_folder_path} -> {dest_folder_path}", COMMAND
)
shutil.move(src_folder_path, dest_folder_path)
else:
logger.debug(f"源文件夹不存在: {src_folder_path}", COMMAND)
if tf:
tf.close()
if download_file.exists():
logger.debug(f"删除下载文件: {download_file}", COMMAND)
download_file.unlink()
if extract_path.exists():
logger.debug(f"删除解压文件夹: {extract_path}", COMMAND)
shutil.rmtree(extract_path)
if TMP_PATH.exists():
shutil.rmtree(TMP_PATH)
if latest_version:
with open(VERSION_FILE, "w", encoding="utf8") as f:
f.write(f"__version__: {latest_version}")
install_requirement()
LOG_COMMAND = "AutoUpdate"
class UpdateManager:
@classmethod
async def update_webui(cls) -> str:
from zhenxun.builtin_plugins.web_ui.public.data_source import (
update_webui_assets,
)
WEBUI_PATH = DATA_PATH / "web_ui" / "public"
BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public"
if WEBUI_PATH.exists():
if BACKUP_PATH.exists():
logger.debug(f"删除旧的备份webui文件夹 {BACKUP_PATH}", COMMAND)
shutil.rmtree(BACKUP_PATH)
WEBUI_PATH.rename(BACKUP_PATH)
try:
await update_webui_assets()
logger.info("更新webui成功...", COMMAND)
if BACKUP_PATH.exists():
logger.debug(f"删除旧的webui文件夹 {BACKUP_PATH}", COMMAND)
shutil.rmtree(BACKUP_PATH)
return "Webui更新成功"
except Exception as e:
logger.error("更新webui失败...", COMMAND, e=e)
if BACKUP_PATH.exists():
logger.debug(f"恢复旧的webui文件夹 {BACKUP_PATH}", COMMAND)
BACKUP_PATH.rename(WEBUI_PATH)
raise e
return ""
@classmethod
async def check_version(cls) -> str:
"""检查更新版本
@ -173,75 +22,146 @@ class UpdateManager:
str: 更新信息
"""
cur_version = cls.__get_version()
data = await cls.__get_latest_data()
if not data:
release_data = await ZhenxunRepoManager.zhenxun_get_latest_releases_data()
if not release_data:
return "检查更新获取版本失败..."
return (
"检测到当前版本更新\n"
f"当前版本:{cur_version}\n"
f"最新版本:{data.get('name')}\n"
f"创建日期:{data.get('created_at')}\n"
f"更新内容:\n{data.get('body')}"
f"最新版本:{release_data.get('name')}\n"
f"创建日期:{release_data.get('created_at')}\n"
f"更新内容:\n{release_data.get('body')}"
)
@classmethod
async def update(cls, bot: Bot, user_id: str, version_type: str) -> str:
async def update_webui(
cls,
source: Literal["git", "ali"] | None,
branch: str = "dist",
force: bool = False,
):
"""更新WebUI
参数:
source: 更新源
branch: 分支
force: 是否强制更新
返回:
str: 返回消息
"""
if not source:
await ZhenxunRepoManager.webui_zip_update()
return "WebUI更新完成!"
result = await ZhenxunRepoManager.webui_git_update(
source,
branch=branch,
force=force,
)
if not result.success:
logger.error(f"WebUI更新失败...错误: {result.error_message}", LOG_COMMAND)
return f"WebUI更新失败...错误: {result.error_message}"
return "WebUI更新完成!"
@classmethod
async def update_resources(
cls,
source: Literal["git", "ali"] | None,
branch: str = "main",
force: bool = False,
) -> str:
"""更新资源
参数:
source: 更新源
branch: 分支
force: 是否强制更新
返回:
str: 返回消息
"""
if not source:
await ZhenxunRepoManager.resources_zip_update()
return "真寻资源更新完成!"
result = await ZhenxunRepoManager.resources_git_update(
source,
branch=branch,
force=force,
)
if not result.success:
logger.error(
f"真寻资源更新失败...错误: {result.error_message}", LOG_COMMAND
)
return f"真寻资源更新失败...错误: {result.error_message}"
return "真寻资源更新完成!"
@classmethod
async def update_zhenxun(
cls,
bot: Bot,
user_id: str,
version_type: Literal["main", "release"],
force: bool,
source: Literal["git", "ali"],
zip: bool,
) -> str:
"""更新操作
参数:
bot: Bot
user_id: 用户id
version_type: 更新版本类型
force: 是否强制更新
source: 更新源
zip: 是否下载zip文件
update_type: 更新方式
返回:
str | None: 返回消息
"""
logger.info("开始下载真寻最新版文件....", COMMAND)
cur_version = cls.__get_version()
url = None
new_version = None
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
if version_type in {"main"}:
repo_info.branch = version_type
new_version = await cls.__get_version_from_repo(repo_info)
if new_version:
new_version = new_version.split(":")[-1].strip()
url = await repo_info.get_archive_download_urls()
elif version_type == "release":
data = await cls.__get_latest_data()
if not data:
return "获取更新版本失败..."
new_version = data.get("name", "")
url = await repo_info.get_release_source_download_urls_tgz(new_version)
if not url:
return "获取版本下载链接失败..."
if TMP_PATH.exists():
logger.debug(f"删除临时文件夹 {TMP_PATH}", COMMAND)
shutil.rmtree(TMP_PATH)
logger.debug(
f"开始更新版本:{cur_version} -> {new_version} | 下载链接:{url}",
COMMAND,
)
await PlatformUtils.send_superuser(
bot,
f"检测真寻已更新,版本更新:{cur_version} -> {new_version}\n开始更新...",
f"检测真寻已更新,当前版本:{cur_version}\n开始更新...",
user_id,
)
download_file = (
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
)
if await AsyncHttpx.download_file(url, download_file, stream=True):
logger.debug("下载真寻最新版文件完成...", COMMAND)
await _file_handle(new_version)
result = "版本更新完成"
if zip:
new_version = await ZhenxunRepoManager.zhenxun_zip_update(version_type)
await PlatformUtils.send_superuser(
bot, "真寻更新完成,开始安装依赖...", user_id
)
await VirtualEnvPackageManager.install_requirement(
ZhenxunRepoConfig.REQUIREMENTS_FILE
)
return (
f"{result}\n"
f"版本: {cur_version} -> {new_version}\n"
f"版本更新完成!\n版本: {cur_version} -> {new_version}\n"
"请重新启动真寻以完成更新!"
)
else:
logger.debug("下载真寻最新版文件失败...", COMMAND)
return ""
result = await ZhenxunRepoManager.zhenxun_git_update(
source,
branch=version_type,
force=force,
)
if not result.success:
logger.error(
f"真寻版本更新失败...错误: {result.error_message}",
LOG_COMMAND,
)
return f"版本更新失败...错误: {result.error_message}"
await PlatformUtils.send_superuser(
bot, "真寻更新完成,开始安装依赖...", user_id
)
await VirtualEnvPackageManager.install_requirement(
ZhenxunRepoConfig.REQUIREMENTS_FILE
)
return (
f"版本更新完成!\n"
f"版本: {cur_version} -> {result.new_version}\n"
f"变更文件个数: {len(result.changed_files)}"
f"{'' if source == 'git' else '(阿里云更新不支持查看变更文件)'}\n"
"请重新启动真寻以完成更新!"
)
@classmethod
def __get_version(cls) -> str:
@ -251,44 +171,9 @@ class UpdateManager:
str: 当前版本号
"""
_version = "v0.0.0"
if VERSION_FILE.exists():
if text := VERSION_FILE.open(encoding="utf8").readline():
if ZhenxunRepoConfig.ZHENXUN_BOT_VERSION_FILE.exists():
if text := ZhenxunRepoConfig.ZHENXUN_BOT_VERSION_FILE.open(
encoding="utf8"
).readline():
_version = text.split(":")[-1].strip()
return _version
@classmethod
async def __get_latest_data(cls) -> dict:
"""获取最新版本信息
返回:
dict: 最新版本数据
"""
for _ in range(3):
try:
res = await AsyncHttpx.get(RELEASE_URL)
if res.status_code == 200:
return res.json()
except TimeoutError:
pass
except Exception as e:
logger.error("检查更新真寻获取版本失败", e=e)
return {}
@classmethod
async def __get_version_from_repo(cls, repo_info: RepoInfo) -> str:
"""从指定分支获取版本号
参数:
branch: 分支名称
返回:
str: 版本号
"""
version_url = await repo_info.get_raw_download_urls(path="__version__")
try:
res = await AsyncHttpx.get(version_url)
if res.status_code == 200:
return res.text.strip()
except Exception as e:
logger.error(f"获取 {repo_info.branch} 分支版本失败", e=e)
return "未知版本"

View File

@ -1,38 +0,0 @@
from pathlib import Path
from zhenxun.configs.path_config import TEMP_PATH
DEFAULT_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot/tree/main"
RELEASE_URL = "https://api.github.com/repos/HibiKier/zhenxun_bot/releases/latest"
VERSION_FILE_STRING = "__version__"
VERSION_FILE = Path() / VERSION_FILE_STRING
PYPROJECT_FILE_STRING = "pyproject.toml"
PYPROJECT_FILE = Path() / PYPROJECT_FILE_STRING
PYPROJECT_LOCK_FILE_STRING = "poetry.lock"
PYPROJECT_LOCK_FILE = Path() / PYPROJECT_LOCK_FILE_STRING
REQ_TXT_FILE_STRING = "requirements.txt"
REQ_TXT_FILE = Path() / REQ_TXT_FILE_STRING
BASE_PATH_STRING = "zhenxun"
BASE_PATH = Path() / BASE_PATH_STRING
TMP_PATH = TEMP_PATH / "auto_update"
BACKUP_PATH = Path() / "backup"
DOWNLOAD_GZ_FILE_STRING = "download_latest_file.tar.gz"
DOWNLOAD_ZIP_FILE_STRING = "download_latest_file.zip"
DOWNLOAD_GZ_FILE = TMP_PATH / DOWNLOAD_GZ_FILE_STRING
DOWNLOAD_ZIP_FILE = TMP_PATH / DOWNLOAD_ZIP_FILE_STRING
REPLACE_FOLDERS = [
"builtin_plugins",
"services",
"utils",
"models",
"configs",
]
COMMAND = "检查更新"

View File

@ -148,6 +148,11 @@ async def get_plugin_and_user(
user = await with_timeout(
user_dao.safe_get_or_none(user_id=user_id), name="get_user"
)
except IntegrityError:
await asyncio.sleep(0.5)
plugin, user = await with_timeout(
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
)
if not plugin:
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")

View File

@ -12,6 +12,7 @@ from zhenxun.services.llm.core import KeyStatus
from zhenxun.services.llm.manager import (
reset_key_status,
)
from zhenxun.services.llm.types import LLMMessage
class DataSource:
@ -58,7 +59,7 @@ class DataSource:
start_time = time.monotonic()
try:
async with await get_model_instance(model_name_str) as model:
await model.generate_text("你好")
await model.generate_response([LLMMessage.user("你好")])
end_time = time.monotonic()
latency = (end_time - start_time) * 1000
return (

View File

@ -84,7 +84,7 @@ async def _(session: EventSession):
try:
result = await StoreManager.get_plugins_info()
logger.info("查看插件列表", "插件商店", session=session)
await MessageUtils.build_message(result).send()
await MessageUtils.build_message([*result]).send()
except Exception as e:
logger.error(f"查看插件列表失败 e: {e}", "插件商店", session=session, e=e)
await MessageUtils.build_message("获取插件列表失败...").send()

View File

@ -1,19 +1,19 @@
from pathlib import Path
import random
import shutil
from aiocache import cached
import ujson as json
from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING
from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo
from zhenxun.configs.path_config import TEMP_PATH
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger
from zhenxun.services.plugin_init import PluginInitManager
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import RepoAPI
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from zhenxun.utils.repo_utils import RepoFileManager
from zhenxun.utils.repo_utils.models import RepoFileInfo, RepoType
from zhenxun.utils.utils import is_number
from .config import (
@ -22,6 +22,7 @@ from .config import (
EXTRA_GITHUB_URL,
LOG_COMMAND,
)
from .exceptions import PluginStoreException
def row_style(column: str, text: str) -> RowStyle:
@ -40,73 +41,25 @@ def row_style(column: str, text: str) -> RowStyle:
return style
def install_requirement(plugin_path: Path):
requirement_files = ["requirement.txt", "requirements.txt"]
requirement_paths = [plugin_path / file for file in requirement_files]
if existing_requirements := next(
(path for path in requirement_paths if path.exists()), None
):
VirtualEnvPackageManager.install_requirement(existing_requirements)
class StoreManager:
@classmethod
async def get_github_plugins(cls) -> list[StorePluginInfo]:
"""获取github插件列表信息
返回:
list[StorePluginInfo]: 插件列表数据
"""
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
default_github_url = await repo_info.get_raw_download_urls("plugins.json")
response = await AsyncHttpx.get(default_github_url, check_status_code=200)
if response.status_code == 200:
logger.info("获取github插件列表成功", LOG_COMMAND)
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
else:
logger.warning(
f"获取github插件列表失败: {response.status_code}", LOG_COMMAND
)
return []
@classmethod
async def get_extra_plugins(cls) -> list[StorePluginInfo]:
"""获取额外插件列表信息
返回:
list[StorePluginInfo]: 插件列表数据
"""
repo_info = GithubUtils.parse_github_url(EXTRA_GITHUB_URL)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
extra_github_url = await repo_info.get_raw_download_urls("plugins.json")
response = await AsyncHttpx.get(extra_github_url, check_status_code=200)
if response.status_code == 200:
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
else:
logger.warning(
f"获取github扩展插件列表失败: {response.status_code}", LOG_COMMAND
)
return []
@classmethod
@cached(60)
async def get_data(cls) -> list[StorePluginInfo]:
async def get_data(cls) -> tuple[list[StorePluginInfo], list[StorePluginInfo]]:
"""获取插件信息数据
返回:
list[StorePluginInfo]: 插件信息数据
tuple[list[StorePluginInfo], list[StorePluginInfo]]:
原生插件信息数据第三方插件信息数据
"""
plugins = await cls.get_github_plugins()
extra_plugins = await cls.get_extra_plugins()
return [*plugins, *extra_plugins]
plugins = await RepoFileManager.get_file_content(
DEFAULT_GITHUB_URL, "plugins.json"
)
extra_plugins = await RepoFileManager.get_file_content(
EXTRA_GITHUB_URL, "plugins.json", "index"
)
return [StorePluginInfo(**plugin) for plugin in json.loads(plugins)], [
StorePluginInfo(**plugin) for plugin in json.loads(extra_plugins)
]
@classmethod
def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]):
@ -152,38 +105,105 @@ class StoreManager:
return await PluginInfo.filter(load_status=True).values_list(*args)
@classmethod
async def get_plugins_info(cls) -> BuildImage | str:
async def get_plugins_info(cls) -> list[BuildImage] | str:
"""插件列表
返回:
BuildImage | str: 返回消息
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
plugin_list, extra_plugin_list = await cls.get_data()
column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"]
db_plugin_list = await cls.get_loaded_plugins("module", "version")
suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list}
data_list = [
[
"已安装" if plugin_info.module in suc_plugin else "",
id,
plugin_info.name,
plugin_info.description,
plugin_info.author,
cls.version_check(plugin_info, suc_plugin),
plugin_info.plugin_type_name,
]
for id, plugin_info in enumerate(plugin_list)
index = 0
data_list = []
extra_data_list = []
for plugin_info in plugin_list:
data_list.append(
[
"已安装" if plugin_info.module in suc_plugin else "",
index,
plugin_info.name,
plugin_info.description,
plugin_info.author,
cls.version_check(plugin_info, suc_plugin),
plugin_info.plugin_type_name,
]
)
index += 1
for plugin_info in extra_plugin_list:
extra_data_list.append(
[
"已安装" if plugin_info.module in suc_plugin else "",
index,
plugin_info.name,
plugin_info.description,
plugin_info.author,
cls.version_check(plugin_info, suc_plugin),
plugin_info.plugin_type_name,
]
)
index += 1
return [
await ImageTemplate.table_page(
"原生插件列表",
"通过添加/移除插件 ID 来管理插件",
column_name,
data_list,
text_style=row_style,
),
await ImageTemplate.table_page(
"第三方插件列表",
"通过添加/移除插件 ID 来管理插件",
column_name,
extra_data_list,
text_style=row_style,
),
]
return await ImageTemplate.table_page(
"插件列表",
"通过添加/移除插件 ID 来管理插件",
column_name,
data_list,
text_style=row_style,
)
@classmethod
async def add_plugin(cls, plugin_id: str) -> str:
async def get_plugin_by_value(
cls, index_or_module: str, is_update: bool = False
) -> tuple[StorePluginInfo, bool]:
"""获取插件信息
参数:
index_or_module: 插件索引或模块名
is_update: 是否是更新插件
异常:
PluginStoreException: 插件不存在
PluginStoreException: 插件已安装
返回:
StorePluginInfo: 插件信息
bool: 是否是外部插件
"""
plugin_list, extra_plugin_list = await cls.get_data()
plugin_info = None
is_external = False
db_plugin_list = await cls.get_loaded_plugins("module")
plugin_key = await cls._resolve_plugin_key(index_or_module)
for p in plugin_list:
if p.module == plugin_key:
is_external = False
plugin_info = p
break
for p in extra_plugin_list:
if p.module == plugin_key:
is_external = True
plugin_info = p
break
if not plugin_info:
raise PluginStoreException(f"插件不存在: {plugin_key}")
if not is_update and plugin_info.module in [p[0] for p in db_plugin_list]:
raise PluginStoreException(f"插件 {plugin_info.name} 已安装,无需重复安装")
if plugin_info.module not in [p[0] for p in db_plugin_list] and is_update:
raise PluginStoreException(f"插件 {plugin_info.name} 未安装,无法更新")
return plugin_info, is_external
@classmethod
async def add_plugin(cls, index_or_module: str) -> str:
"""添加插件
参数:
@ -192,21 +212,9 @@ class StoreManager:
返回:
str: 返回消息
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
try:
plugin_key = await cls._resolve_plugin_key(plugin_id)
except ValueError as e:
return str(e)
db_plugin_list = await cls.get_loaded_plugins("module")
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
if plugin_info is None:
return f"未找到插件 {plugin_key}"
if plugin_info.module in [p[0] for p in db_plugin_list]:
return f"插件 {plugin_info.name} 已安装,无需重复安装"
is_external = True
plugin_info, is_external = await cls.get_plugin_by_value(index_or_module)
if plugin_info.github_url is None:
plugin_info.github_url = DEFAULT_GITHUB_URL
is_external = False
version_split = plugin_info.version.split("-")
if len(version_split) > 1:
github_url_split = plugin_info.github_url.split("/tree/")
@ -228,90 +236,81 @@ class StoreManager:
is_dir: bool,
is_external: bool = False,
):
repo_api: RepoAPI
repo_info = GithubUtils.parse_github_url(github_url)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
logger.debug(f"成功获取仓库信息: {repo_info}", LOG_COMMAND)
for repo_api in GithubUtils.iter_api_strategies():
try:
await repo_api.parse_repo_info(repo_info)
break
except Exception as e:
logger.warning(
f"获取插件文件失败 | API类型: {repo_api.strategy}",
LOG_COMMAND,
e=e,
)
continue
else:
raise ValueError("所有API获取插件文件失败请检查网络连接")
if module_path == ".":
module_path = ""
"""安装插件
参数:
github_url: 仓库地址
module_path: 模块路径
is_dir: 是否是文件夹
is_external: 是否是外部仓库
"""
repo_type = RepoType.GITHUB if is_external else None
replace_module_path = module_path.replace(".", "/")
files = repo_api.get_files(
module_path=replace_module_path + ("" if is_dir else ".py"),
is_dir=is_dir,
)
download_urls = [await repo_info.get_raw_download_urls(file) for file in files]
base_path = BASE_PATH / "plugins" if is_external else BASE_PATH
base_path = base_path if module_path else base_path / repo_info.repo
download_paths: list[Path | str] = [base_path / file for file in files]
logger.debug(f"插件下载路径: {download_paths}", LOG_COMMAND)
result = await AsyncHttpx.gather_download_file(download_urls, download_paths)
for _id, success in enumerate(result):
if not success:
break
if is_dir:
files = await RepoFileManager.list_directory_files(
github_url, replace_module_path, repo_type=repo_type
)
else:
# 安装依赖
plugin_path = base_path / "/".join(module_path.split("."))
try:
req_files = repo_api.get_files(
f"{replace_module_path}/{REQ_TXT_FILE_STRING}", False
files = [RepoFileInfo(path=f"{replace_module_path}.py", is_dir=False)]
local_path = BASE_PATH / "plugins" if is_external else BASE_PATH
files = [file for file in files if not file.is_dir]
download_files = [(file.path, local_path / file.path) for file in files]
await RepoFileManager.download_files(
github_url, download_files, repo_type=repo_type
)
requirement_paths = [
file
for file in files
if file.path.endswith("requirement.txt")
or file.path.endswith("requirements.txt")
]
is_install_req = False
for requirement_path in requirement_paths:
requirement_file = local_path / requirement_path.path
if requirement_file.exists():
is_install_req = True
await VirtualEnvPackageManager.install_requirement(requirement_file)
if not is_install_req:
# 从仓库根目录查找文件
rand = random.randint(1, 10000)
requirement_path = TEMP_PATH / f"plugin_store_{rand}_req.txt"
requirements_path = TEMP_PATH / f"plugin_store_{rand}_reqs.txt"
await RepoFileManager.download_files(
github_url,
[
("requirement.txt", requirement_path),
("requirements.txt", requirements_path),
],
repo_type=repo_type,
ignore_error=True,
)
if requirement_path.exists():
logger.info(
f"开始安装插件 {module_path} 依赖文件: {requirement_path}",
LOG_COMMAND,
)
req_files.extend(
repo_api.get_files(f"{replace_module_path}/requirement.txt", False)
await VirtualEnvPackageManager.install_requirement(requirement_path)
if requirements_path.exists():
logger.info(
f"开始安装插件 {module_path} 依赖文件: {requirements_path}",
LOG_COMMAND,
)
logger.debug(f"获取插件依赖文件列表: {req_files}", LOG_COMMAND)
req_download_urls = [
await repo_info.get_raw_download_urls(file) for file in req_files
]
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
logger.debug(f"插件依赖文件下载路径: {req_paths}", LOG_COMMAND)
if req_files:
result = await AsyncHttpx.gather_download_file(
req_download_urls, req_paths
)
for success in result:
if not success:
raise Exception("插件依赖文件下载失败")
logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND)
install_requirement(plugin_path)
except ValueError as e:
logger.warning("未获取到依赖文件路径...", e=e)
return True
raise Exception("插件下载失败...")
await VirtualEnvPackageManager.install_requirement(requirements_path)
@classmethod
async def remove_plugin(cls, plugin_id: str) -> str:
async def remove_plugin(cls, index_or_module: str) -> str:
"""移除插件
参数:
plugin_id: 插件id或模块名
index_or_module: 插件id或模块名
返回:
str: 返回消息
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
try:
plugin_key = await cls._resolve_plugin_key(plugin_id)
except ValueError as e:
return str(e)
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
if plugin_info is None:
return f"未找到插件 {plugin_key}"
plugin_info, _ = await cls.get_plugin_by_value(index_or_module)
path = BASE_PATH
if plugin_info.github_url:
path = BASE_PATH / "plugins"
@ -339,12 +338,13 @@ class StoreManager:
返回:
BuildImage | str: 返回消息
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
plugin_list, extra_plugin_list = await cls.get_data()
all_plugin_list = plugin_list + extra_plugin_list
db_plugin_list = await cls.get_loaded_plugins("module", "version")
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
filtered_data = [
(id, plugin_info)
for id, plugin_info in enumerate(plugin_list)
for id, plugin_info in enumerate(all_plugin_list)
if plugin_name_or_author.lower() in plugin_info.name.lower()
or plugin_name_or_author.lower() in plugin_info.author.lower()
]
@ -373,35 +373,24 @@ class StoreManager:
)
@classmethod
async def update_plugin(cls, plugin_id: str) -> str:
async def update_plugin(cls, index_or_module: str) -> str:
"""更新插件
参数:
plugin_id: 插件id
index_or_module: 插件id
返回:
str: 返回消息
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
try:
plugin_key = await cls._resolve_plugin_key(plugin_id)
except ValueError as e:
return str(e)
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
if plugin_info is None:
return f"未找到插件 {plugin_key}"
plugin_info, is_external = await cls.get_plugin_by_value(index_or_module, True)
logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND)
db_plugin_list = await cls.get_loaded_plugins("module", "version")
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
if plugin_info.module not in [p[0] for p in db_plugin_list]:
return f"插件 {plugin_info.name} 未安装,无法更新"
logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND)
if cls.check_version_is_new(plugin_info, suc_plugin):
return f"插件 {plugin_info.name} 已是最新版本"
is_external = True
if plugin_info.github_url is None:
plugin_info.github_url = DEFAULT_GITHUB_URL
is_external = False
await cls.install_plugin_with_repo(
plugin_info.github_url,
plugin_info.module_path,
@ -420,8 +409,9 @@ class StoreManager:
返回:
str: 返回消息
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
plugin_name_list = [p.name for p in plugin_list]
plugin_list, extra_plugin_list = await cls.get_data()
all_plugin_list = plugin_list + extra_plugin_list
plugin_name_list = [p.name for p in all_plugin_list]
update_failed_list = []
update_success_list = []
result = "--已更新{}个插件 {}个失败 {}个成功--"
@ -492,22 +482,25 @@ class StoreManager:
plugin_id: moduleid或插件名称
异常:
ValueError: 插件不存在
ValueError: 插件不存在
PluginStoreException: 插件不存在
PluginStoreException: 插件不存在
返回:
str: 插件模块名
"""
plugin_list: list[StorePluginInfo] = await cls.get_data()
plugin_list, extra_plugin_list = await cls.get_data()
all_plugin_list = plugin_list + extra_plugin_list
if is_number(plugin_id):
idx = int(plugin_id)
if idx < 0 or idx >= len(plugin_list):
raise ValueError("插件ID不存在...")
return plugin_list[idx].module
if idx < 0 or idx >= len(all_plugin_list):
raise PluginStoreException("插件ID不存在...")
return all_plugin_list[idx].module
elif isinstance(plugin_id, str):
result = (
None if plugin_id not in [v.module for v in plugin_list] else plugin_id
) or next(v for v in plugin_list if v.name == plugin_id).module
None
if plugin_id not in [v.module for v in all_plugin_list]
else plugin_id
) or next(v for v in all_plugin_list if v.name == plugin_id).module
if not result:
raise ValueError("插件 Module / 名称 不存在...")
raise PluginStoreException("插件 Module / 名称 不存在...")
return result

View File

@ -0,0 +1,6 @@
class PluginStoreException(Exception):
def __init__(self, message: str):
self.message = message
def __str__(self):
return self.message

View File

@ -38,10 +38,11 @@ async def _(setting: Setting) -> Result:
password = Config.get_config("web-ui", "password")
if password or BotConfig.db_url:
return Result.fail("配置已存在请先删除DB_URL内容和前端密码再进行设置。")
env_file = Path() / ".env.dev"
env_file = Path() / ".env.example"
if not env_file.exists():
return Result.fail("配置文件.env.dev不存在。")
return Result.fail("基础配置文件.env.example不存在。")
env_text = env_file.read_text(encoding="utf-8")
to_env_file = Path() / ".env.dev"
if setting.db_url:
if setting.db_url.startswith("sqlite"):
base_dir = Path().resolve()
@ -78,7 +79,7 @@ async def _(setting: Setting) -> Result:
if setting.username:
Config.set_config("web-ui", "username", setting.username)
Config.set_config("web-ui", "password", setting.password, True)
env_file.write_text(env_text, encoding="utf-8")
to_env_file.write_text(env_text, encoding="utf-8")
if BAT_FILE.exists():
for file in os.listdir(Path()):
if file.startswith(FILE_NAME):

View File

@ -229,9 +229,9 @@ async def _(payload: InstallDependenciesPayload) -> Result:
if not payload.dependencies:
return Result.fail("依赖列表不能为空")
if payload.handle_type == "install":
result = VirtualEnvPackageManager.install(payload.dependencies)
result = await VirtualEnvPackageManager.install(payload.dependencies)
else:
result = VirtualEnvPackageManager.uninstall(payload.dependencies)
result = await VirtualEnvPackageManager.uninstall(payload.dependencies)
return Result.ok(result)
except Exception as e:
logger.error(f"{router.prefix}/install_dependencies 调用错误", "WebUi", e=e)

View File

@ -25,10 +25,10 @@ async def _() -> Result[dict]:
require("plugin_store")
from zhenxun.builtin_plugins.plugin_store import StoreManager
data = await StoreManager.get_data()
plugin_list, extra_plugin_list = await StoreManager.get_data()
plugin_list = [
{**model_dump(plugin), "name": plugin.name, "id": idx}
for idx, plugin in enumerate(data)
for idx, plugin in enumerate(plugin_list + extra_plugin_list)
]
modules = await PluginInfo.filter(load_status=True).values_list(
"module", flat=True

View File

@ -8,16 +8,6 @@ if sys.version_info >= (3, 11):
else:
from strenum import StrEnum
from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH
WEBUI_STRING = "web_ui"
PUBLIC_STRING = "public"
WEBUI_DATA_PATH = DATA_PATH / WEBUI_STRING
PUBLIC_PATH = WEBUI_DATA_PATH / PUBLIC_STRING
TMP_PATH = TEMP_PATH / WEBUI_STRING
WEBUI_DIST_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot_webui/tree/dist"
app = nonebot.get_app()

View File

@ -3,41 +3,38 @@ from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from zhenxun.services.log import logger
from ..config import PUBLIC_PATH
from .data_source import COMMAND_NAME, update_webui_assets
from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager
router = APIRouter()
@router.get("/")
async def index():
return FileResponse(PUBLIC_PATH / "index.html")
return FileResponse(ZhenxunRepoManager.config.WEBUI_PATH / "index.html")
@router.get("/favicon.ico")
async def favicon():
return FileResponse(PUBLIC_PATH / "favicon.ico")
@router.get("/79edfa81f3308a9f.jfif")
async def _():
return FileResponse(PUBLIC_PATH / "79edfa81f3308a9f.jfif")
return FileResponse(ZhenxunRepoManager.config.WEBUI_PATH / "favicon.ico")
async def init_public(app: FastAPI):
try:
if not PUBLIC_PATH.exists():
folders = await update_webui_assets()
else:
folders = [x.name for x in PUBLIC_PATH.iterdir() if x.is_dir()]
if not ZhenxunRepoManager.check_webui_exists():
await ZhenxunRepoManager.webui_update(branch="test")
folders = [
x.name for x in ZhenxunRepoManager.config.WEBUI_PATH.iterdir() if x.is_dir()
]
app.include_router(router)
for pathname in folders:
logger.debug(f"挂载文件夹: {pathname}")
app.mount(
f"/{pathname}",
StaticFiles(directory=PUBLIC_PATH / pathname, check_dir=True),
StaticFiles(
directory=ZhenxunRepoManager.config.WEBUI_PATH / pathname,
check_dir=True,
),
name=f"public_{pathname}",
)
except Exception as e:
logger.error("初始化 WebUI资源 失败", COMMAND_NAME, e=e)
logger.error("初始化 WebUI资源 失败", "WebUI", e=e)

View File

@ -1,44 +0,0 @@
from pathlib import Path
import shutil
import zipfile
from nonebot.utils import run_sync
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.http_utils import AsyncHttpx
from ..config import PUBLIC_PATH, TMP_PATH, WEBUI_DIST_GITHUB_URL
COMMAND_NAME = "WebUI资源管理"
async def update_webui_assets():
webui_assets_path = TMP_PATH / "webui_assets.zip"
download_url = await GithubUtils.parse_github_url(
WEBUI_DIST_GITHUB_URL
).get_archive_download_urls()
logger.info("开始下载 webui_assets 资源...", COMMAND_NAME)
if await AsyncHttpx.download_file(
download_url, webui_assets_path, follow_redirects=True
):
logger.info("下载 webui_assets 成功...", COMMAND_NAME)
return await _file_handle(webui_assets_path)
raise Exception("下载 webui_assets 失败", COMMAND_NAME)
@run_sync
def _file_handle(webui_assets_path: Path):
logger.debug("开始解压 webui_assets...", COMMAND_NAME)
if webui_assets_path.exists():
tf = zipfile.ZipFile(webui_assets_path)
tf.extractall(TMP_PATH)
logger.debug("解压 webui_assets 成功...", COMMAND_NAME)
else:
raise Exception("解压 webui_assets 失败,文件不存在...", COMMAND_NAME)
download_file_path = next(f for f in TMP_PATH.iterdir() if f.is_dir())
shutil.rmtree(PUBLIC_PATH, ignore_errors=True)
shutil.copytree(download_file_path / "dist", PUBLIC_PATH, dirs_exist_ok=True)
logger.debug("复制 webui_assets 成功...", COMMAND_NAME)
shutil.rmtree(TMP_PATH, ignore_errors=True)
return [x.name for x in PUBLIC_PATH.iterdir() if x.is_dir()]

View File

@ -1,16 +1,21 @@
from collections.abc import Callable
import copy
from pathlib import Path
from typing import Any, TypeVar, get_args, get_origin
from typing import Any, TypeVar
import cattrs
from nonebot.compat import model_dump
from pydantic import VERSION, BaseModel, Field
from pydantic import BaseModel, Field
from ruamel.yaml import YAML
from ruamel.yaml.scanner import ScannerError
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import (
_dump_pydantic_obj,
_is_pydantic_type,
model_dump,
parse_as,
)
from .models import (
AICallableParam,
@ -39,46 +44,6 @@ class NoSuchConfig(Exception):
pass
def _dump_pydantic_obj(obj: Any) -> Any:
"""
递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典
支持单个实例实例列表实例字典等情况
"""
if isinstance(obj, BaseModel):
return model_dump(obj)
if isinstance(obj, list):
return [_dump_pydantic_obj(item) for item in obj]
if isinstance(obj, dict):
return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
return obj
def _is_pydantic_type(t: Any) -> bool:
"""
递归检查一个类型注解是否与 Pydantic BaseModel 相关
"""
if t is None:
return False
origin = get_origin(t)
if origin:
return any(_is_pydantic_type(arg) for arg in get_args(t))
return isinstance(t, type) and issubclass(t, BaseModel)
def parse_as(type_: type[T], obj: Any) -> T:
"""
一个兼容 Pydantic V1 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数
"""
if VERSION.startswith("1"):
from pydantic import parse_obj_as
return parse_obj_as(type_, obj)
else:
from pydantic import TypeAdapter # type: ignore
return TypeAdapter(type_).validate_python(obj)
class ConfigGroup(BaseModel):
"""
配置组
@ -194,16 +159,11 @@ class ConfigsManager:
"""
result = dict(original_data)
# 遍历新数据的键
for key, value in new_data.items():
# 如果键不在原数据中,添加它
if key not in original_data:
result[key] = value
# 如果两边都是字典,递归处理
elif isinstance(value, dict) and isinstance(original_data[key], dict):
result[key] = self._merge_dicts(value, original_data[key])
# 如果键已存在,保留原值,不覆盖
# (不做任何操作,保持原值)
return result
@ -217,15 +177,11 @@ class ConfigsManager:
返回:
标准化后的值
"""
# 处理BaseModel
processed_value = _dump_pydantic_obj(value)
# 如果处理后的值是字典,且原始值也存在
if isinstance(processed_value, dict) and original_value is not None:
# 处理原始值
processed_original = _dump_pydantic_obj(original_value)
# 如果原始值也是字典,合并它们
if isinstance(processed_original, dict):
return self._merge_dicts(processed_value, processed_original)
@ -263,12 +219,10 @@ class ConfigsManager:
if not module or not key:
raise ValueError("add_plugin_config: module和key不能为为空")
# 获取现有配置值(如果存在)
existing_value = None
if module in self._data and (config := self._data[module].configs.get(key)):
existing_value = config.value
# 标准化值和默认值
processed_value = self._normalize_config_data(value, existing_value)
processed_default_value = self._normalize_config_data(default_value)
@ -348,7 +302,6 @@ class ConfigsManager:
if value_to_process is None:
return default
# 1. 最高优先级:自定义的参数解析器
if config.arg_parser:
try:
return config.arg_parser(value_to_process)

View File

@ -2,10 +2,10 @@ from collections.abc import Callable
from datetime import datetime
from typing import Any, Literal
from nonebot.compat import model_dump
from pydantic import BaseModel, Field
from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType
from zhenxun.utils.pydantic_compat import model_dump
__all__ = [
"AICallableParam",

View File

@ -27,23 +27,19 @@ from .llm import (
LLMException,
LLMGenerationConfig,
LLMMessage,
analyze,
analyze_multimodal,
chat,
clear_model_cache,
code,
create_multimodal_message,
embed,
generate,
generate_structured,
get_cache_stats,
get_model_instance,
list_available_models,
list_embedding_models,
pipeline_chat,
search,
search_multimodal,
set_global_default_model_name,
tool_registry,
)
from .log import logger
from .plugin_init import PluginInit, PluginInitManager
@ -60,8 +56,6 @@ __all__ = [
"Model",
"PluginInit",
"PluginInitManager",
"analyze",
"analyze_multimodal",
"chat",
"clear_model_cache",
"code",
@ -69,16 +63,14 @@ __all__ = [
"disconnect",
"embed",
"generate",
"generate_structured",
"get_cache_stats",
"get_model_instance",
"list_available_models",
"list_embedding_models",
"logger",
"pipeline_chat",
"scheduler_manager",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"with_db_timeout",
]

View File

@ -1,6 +1,8 @@
import asyncio
from pathlib import Path
from urllib.parse import urlparse
import aiofiles
import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
@ -86,6 +88,7 @@ def get_config() -> dict:
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
Path(parsed.path).parent.mkdir(parents=True, exist_ok=True)
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
@ -100,6 +103,15 @@ def get_config() -> dict:
async def init():
global MODELS, SCRIPT_METHOD
env_example_file = Path() / ".env.example"
env_dev_file = Path() / ".env.dev"
if not env_dev_file.exists():
async with aiofiles.open(env_example_file, encoding="utf-8") as f:
env_text = await f.read()
async with aiofiles.open(env_dev_file, "w", encoding="utf-8") as f:
await f.write(env_text)
logger.info("已生成 .env.dev 文件,请根据 .env.example 文件配置进行配置")
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
if not BotConfig.db_url:

View File

@ -1,558 +0,0 @@
---
# 🚀 Zhenxun LLM 服务模块
本模块是一个功能强大、高度可扩展的统一大语言模型LLM服务框架。它旨在将各种不同的 LLM 提供商(如 OpenAI、Gemini、智谱AI等的 API 封装在一个统一、易于使用的接口之后,让开发者可以无缝切换和使用不同的模型,同时支持多模态输入、工具调用、智能重试和缓存等高级功能。
## 目录
- [🚀 Zhenxun LLM 服务模块](#-zhenxun-llm-服务模块)
- [目录](#目录)
- [✨ 核心特性](#-核心特性)
- [🧠 核心概念](#-核心概念)
- [🛠️ 安装与配置](#-安装与配置)
- [服务提供商配置 (`config.yaml`)](#服务提供商配置-configyaml)
- [MCP 工具配置 (`mcp_tools.json`)](#mcp-工具配置-mcp_toolsjson)
- [📘 使用指南](#-使用指南)
- [**等级1: 便捷函数** - 最快速的调用方式](#等级1-便捷函数---最快速的调用方式)
- [**等级2: `AI` 会话类** - 管理有状态的对话](#等级2-ai-会话类---管理有状态的对话)
- [**等级3: 直接模型控制** - `get_model_instance`](#等级3-直接模型控制---get_model_instance)
- [🌟 功能深度剖析](#-功能深度剖析)
- [精细化控制模型生成 (`LLMGenerationConfig` 与 `CommonOverrides`)](#精细化控制模型生成-llmgenerationconfig-与-commonoverrides)
- [赋予模型能力:工具使用 (Function Calling)](#赋予模型能力工具使用-function-calling)
- [1. 注册工具](#1-注册工具)
- [函数工具注册](#函数工具注册)
- [MCP工具注册](#mcp工具注册)
- [2. 调用带工具的模型](#2-调用带工具的模型)
- [处理多模态输入](#处理多模态输入)
- [🔧 高级主题与扩展](#-高级主题与扩展)
- [模型与密钥管理](#模型与密钥管理)
- [缓存管理](#缓存管理)
- [错误处理 (`LLMException`)](#错误处理-llmexception)
- [自定义适配器 (Adapter)](#自定义适配器-adapter)
- [📚 API 快速参考](#-api-快速参考)
---
## ✨ 核心特性
- **多提供商支持**: 内置对 OpenAI、Gemini、智谱AI 等多种 API 的适配器,并可通过通用 OpenAI 兼容适配器轻松接入更多服务。
- **统一的 API**: 提供从简单到高级的三层 API满足不同场景的需求无论是快速聊天还是复杂的分析任务。
- **强大的工具调用 (Function Calling)**: 支持标准的函数调用和实验性的 MCP (Model Context Protocol) 工具,让 LLM 能够与外部世界交互。
- **多模态能力**: 无缝集成 `UniMessage`,轻松处理文本、图片、音频、视频等混合输入,支持多模态搜索和分析。
- **文本嵌入向量化**: 提供统一的嵌入接口,支持语义搜索、相似度计算和文本聚类等应用。
- **智能重试与 Key 轮询**: 内置健壮的请求重试逻辑,当 API Key 失效或达到速率限制时,能自动轮询使用备用 Key。
- **灵活的配置系统**: 通过配置文件和代码中的 `LLMGenerationConfig`可以精细控制模型的生成行为如温度、最大Token等
- **高性能缓存机制**: 内置模型实例缓存,减少重复初始化开销,提供缓存管理和监控功能。
- **丰富的配置预设**: 提供 `CommonOverrides`包含创意模式、精确模式、JSON输出等多种常用配置预设。
- **可扩展的适配器架构**: 开发者可以轻松编写自己的适配器来支持新的 LLM 服务。
## 🧠 核心概念
- **适配器 (Adapter)**: 这是连接我们统一接口和特定 LLM 提供商 API 的“翻译官”。例如,`GeminiAdapter` 知道如何将我们的标准请求格式转换为 Google Gemini API 需要的格式,并解析其响应。
- **模型实例 (`LLMModel`)**: 这是框架中的核心操作对象,代表一个**具体配置好**的模型。例如,一个 `LLMModel` 实例可能代表使用特定 API Key、特定代理的 `Gemini/gemini-1.5-pro`。所有与模型交互的操作都通过这个类的实例进行。
- **生成配置 (`LLMGenerationConfig`)**: 这是一个数据类,用于控制模型在生成内容时的行为,例如 `temperature` (温度)、`max_tokens` (最大输出长度)、`response_format` (响应格式) 等。
- **工具 (Tool)**: 代表一个可以让 LLM 调用的函数。它可以是一个简单的 Python 函数,也可以是一个更复杂的、有状态的 MCP 服务。
- **多模态内容 (`LLMContentPart`)**: 这是处理多模态输入的基础单元,一个 `LLMMessage` 可以包含多个 `LLMContentPart`,如一个文本部分和多个图片部分。
## 🛠️ 安装与配置
该模块作为 `zhenxun` 项目的一部分被集成,无需额外安装。核心配置主要涉及两个文件。
### 服务提供商配置 (`config.yaml`)
核心配置位于项目 `/data/config.yaml` 文件中的 `AI` 部分。
```yaml
# /data/configs/config.yaml
AI:
# (可选) 全局默认模型,格式: "ProviderName/ModelName"
default_model_name: Gemini/gemini-2.5-flash
# (可选) 全局代理设置
proxy: http://127.0.0.1:7890
# (可选) 全局超时设置 (秒)
timeout: 180
# (可选) Gemini 的安全过滤阈值
gemini_safety_threshold: BLOCK_MEDIUM_AND_ABOVE
# 配置你的AI服务提供商
PROVIDERS:
# 示例1: Gemini
- name: Gemini
api_key:
- "AIzaSy_KEY_1" # 支持多个Key会自动轮询
- "AIzaSy_KEY_2"
api_base: https://generativelanguage.googleapis.com
api_type: gemini
models:
- model_name: gemini-2.5-pro
- model_name: gemini-2.5-flash
- model_name: gemini-2.0-flash
- model_name: embedding-001
is_embedding_model: true # 标记为嵌入模型
max_input_tokens: 2048 # 嵌入模型特有配置
# 示例2: 智谱AI
- name: GLM
api_key: "YOUR_ZHIPU_API_KEY"
api_type: zhipu # 适配器类型
models:
- model_name: glm-4-flash
- model_name: glm-4-plus
temperature: 0.8 # 可以为特定模型设置默认温度
# 示例3: 一个兼容OpenAI的自定义服务
- name: MyOpenAIService
api_key: "sk-my-custom-key"
api_base: "http://localhost:8080/v1"
api_type: general_openai_compat # 使用通用OpenAI兼容适配器
models:
- model_name: Llama3-8B-Instruct
max_tokens: 2048 # 可以为特定模型设置默认最大Token
```
### MCP 工具配置 (`mcp_tools.json`)
此文件位于 `/data/llm/mcp_tools.json`,用于配置通过 MCP 协议启动的外部工具服务。如果文件不存在,系统会自动创建一个包含示例的默认文件。
```json
{
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {
"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"
},
"description": "百度地图工具,提供地理编码、路线规划等功能。"
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。"
}
}
}
```
## 📘 使用指南
我们提供了三层 API以满足从简单到复杂的各种需求。
### **等级1: 便捷函数** - 最快速的调用方式
这些函数位于 `zhenxun.services.llm` 包的顶层,为你处理了所有的底层细节。
```python
from zhenxun.services.llm import chat, search, code, pipeline_chat, embed, analyze_multimodal, search_multimodal
from zhenxun.services.llm.utils import create_multimodal_message
# 1. 纯文本聊天
response_text = await chat("你好,请用苏轼的风格写一首关于月亮的诗。")
print(response_text)
# 2. 带网络搜索的问答
search_result = await search("马斯克的Neuralink公司最近有什么新进展")
print(search_result['text'])
# print(search_result['sources']) # 查看信息来源
# 3. 执行代码
code_result = await code("用Python画一个心形图案。")
print(code_result['text']) # 包含代码和解释的回复
# 4. 链式调用
image_msg = create_multimodal_message(images="path/to/cat.jpg")
final_poem = await pipeline_chat(
message=image_msg,
model_chain=["Gemini/gemini-1.5-pro", "GLM/glm-4-flash"],
initial_instruction="详细描述这只猫的外观和姿态。",
final_instruction="将上述描述凝练成一首可爱的短诗。"
)
print(final_poem.text)
# 5. 文本嵌入向量生成
texts_to_embed = ["今天天气真好", "我喜欢打篮球", "这部电影很感人"]
vectors = await embed(texts_to_embed, model="Gemini/embedding-001")
print(f"生成了 {len(vectors)} 个向量,每个向量维度: {len(vectors[0])}")
# 6. 多模态分析便捷函数
response = await analyze_multimodal(
text="请分析这张图片中的内容",
images="path/to/image.jpg",
model="Gemini/gemini-1.5-pro"
)
print(response)
# 7. 多模态搜索便捷函数
search_result = await search_multimodal(
text="搜索与这张图片相关的信息",
images="path/to/image.jpg",
model="Gemini/gemini-1.5-pro"
)
print(search_result['text'])
```
### **等级2: `AI` 会话类** - 管理有状态的对话
当你需要进行有上下文的、连续的对话时,`AI` 类是你的最佳选择。
```python
from zhenxun.services.llm import AI, AIConfig
# 初始化一个AI会话可以传入自定义配置
ai_config = AIConfig(model="GLM/glm-4-flash", temperature=0.7)
ai_session = AI(config=ai_config)
# 更完整的AIConfig配置示例
advanced_config = AIConfig(
model="GLM/glm-4-flash",
default_embedding_model="Gemini/embedding-001", # 默认嵌入模型
temperature=0.7,
max_tokens=2000,
enable_cache=True, # 启用模型缓存
enable_code=True, # 启用代码执行功能
enable_search=True, # 启用搜索功能
timeout=180, # 请求超时时间(秒)
# Gemini特定配置选项
enable_gemini_json_mode=True, # 启用Gemini JSON模式
enable_gemini_thinking=True, # 启用Gemini 思考模式
enable_gemini_safe_mode=True, # 启用Gemini 安全模式
enable_gemini_multimodal=True, # 启用Gemini 多模态优化
enable_gemini_grounding=True, # 启用Gemini 信息来源关联
)
advanced_session = AI(config=advanced_config)
# 进行连续对话
await ai_session.chat("我最喜欢的城市是成都。")
response = await ai_session.chat("它有什么好吃的?") # AI会知道“它”指的是成都
print(response)
# 在同一个会话中,临时切换模型进行一次调用
response_gemini = await ai_session.chat(
"从AI的角度分析一下成都的科技发展潜力。",
model="Gemini/gemini-1.5-pro"
)
print(response_gemini)
# 清空历史,开始新一轮对话
ai_session.clear_history()
```
### **等级3: 直接模型控制** - `get_model_instance`
这是最底层的 API为你提供对模型实例的完全控制。推荐使用 `async with` 语句来优雅地管理模型实例的生命周期。
```python
from zhenxun.services.llm import get_model_instance, LLMMessage
from zhenxun.services.llm.config import LLMGenerationConfig
# 1. 获取模型实例
# get_model_instance 返回一个异步上下文管理器
async with await get_model_instance("Gemini/gemini-1.5-pro") as model:
# 2. 准备消息列表
messages = [
LLMMessage.system("你是一个专业的营养师。"),
LLMMessage.user("我今天吃了汉堡和可乐,请给我一些健康建议。")
]
# 3. (可选) 定义本次调用的生成配置
gen_config = LLMGenerationConfig(
temperature=0.2, # 更严谨的回复
max_tokens=300
)
# 4. 生成响应
response = await model.generate_response(messages, config=gen_config)
# 5. 处理响应
print(response.text)
if response.usage_info:
print(f"Token 消耗: {response.usage_info['total_tokens']}")
```
## 🌟 功能深度剖析
### 精细化控制模型生成 (`LLMGenerationConfig` 与 `CommonOverrides`)
- **`LLMGenerationConfig`**: 一个 Pydantic 模型,用于覆盖模型的默认生成参数。
- **`CommonOverrides`**: 一个包含多种常用配置预设的类,如 `creative()`, `precise()`, `gemini_json()` 等,能极大地简化配置过程。
```python
from zhenxun.services.llm.config import LLMGenerationConfig, CommonOverrides
# LLMGenerationConfig 完整参数示例
comprehensive_config = LLMGenerationConfig(
temperature=0.7, # 生成温度 (0.0-2.0)
max_tokens=1000, # 最大输出token数
top_p=0.9, # 核采样参数 (0.0-1.0)
top_k=40, # Top-K采样参数
frequency_penalty=0.0, # 频率惩罚 (-2.0-2.0)
presence_penalty=0.0, # 存在惩罚 (-2.0-2.0)
repetition_penalty=1.0, # 重复惩罚 (0.0-2.0)
stop=["END", "\n\n"], # 停止序列
response_format={"type": "json_object"}, # 响应格式
response_mime_type="application/json", # Gemini专用MIME类型
response_schema={...}, # JSON响应模式
thinking_budget=0.8, # Gemini思考预算 (0.0-1.0)
enable_code_execution=True, # 启用代码执行
safety_settings={...}, # 安全设置
response_modalities=["TEXT"], # 响应模态类型
)
# 创建一个配置要求模型输出JSON格式
json_config = LLMGenerationConfig(
temperature=0.1,
response_mime_type="application/json" # Gemini特有
)
# 对于OpenAI兼容API可以这样做
json_config_openai = LLMGenerationConfig(
temperature=0.1,
response_format={"type": "json_object"}
)
# 使用框架提供的预设 - 基础预设
safe_config = CommonOverrides.gemini_safe()
creative_config = CommonOverrides.creative()
precise_config = CommonOverrides.precise()
balanced_config = CommonOverrides.balanced()
# 更多实用预设
concise_config = CommonOverrides.concise(max_tokens=50) # 简洁模式
detailed_config = CommonOverrides.detailed(max_tokens=3000) # 详细模式
json_config = CommonOverrides.gemini_json() # JSON输出模式
thinking_config = CommonOverrides.gemini_thinking(budget=0.8) # 思考模式
# Gemini特定高级预设
code_config = CommonOverrides.gemini_code_execution() # 代码执行模式
grounding_config = CommonOverrides.gemini_grounding() # 信息来源关联模式
multimodal_config = CommonOverrides.gemini_multimodal() # 多模态优化模式
# 在调用时传入config对象
# await model.generate_response(messages, config=json_config)
```
### 赋予模型能力:工具使用 (Function Calling)
工具调用让 LLM 能够与外部函数、API 或服务进行交互。
#### 1. 注册工具
##### 函数工具注册
使用 `@tool_registry.function_tool` 装饰器注册一个简单的函数工具。
```python
from zhenxun.services.llm import tool_registry
@tool_registry.function_tool(
name="query_stock_price",
description="查询指定股票代码的当前价格。",
parameters={
"stock_symbol": {"type": "string", "description": "股票代码, 例如 'AAPL' 或 'GOOG'"}
},
required=["stock_symbol"]
)
async def query_stock_price(stock_symbol: str) -> dict:
"""一个查询股票价格的伪函数"""
print(f"--- 正在查询 {stock_symbol} 的价格 ---")
if stock_symbol == "AAPL":
return {"symbol": "AAPL", "price": 175.50, "currency": "USD"}
return {"error": "未知的股票代码"}
```
##### MCP工具注册
对于更复杂的、有状态的工具,可以使用 `@tool_registry.mcp_tool` 装饰器注册MCP工具。
```python
from contextlib import asynccontextmanager
from pydantic import BaseModel
from zhenxun.services.llm import tool_registry
# 定义工具的配置模型
class MyToolConfig(BaseModel):
api_key: str
endpoint: str
timeout: int = 30
# 注册MCP工具
@tool_registry.mcp_tool(name="my-custom-tool", config_model=MyToolConfig)
@asynccontextmanager
async def my_tool_factory(config: MyToolConfig):
"""MCP工具工厂函数"""
# 初始化工具会话
session = MyToolSession(config)
try:
await session.initialize()
yield session
finally:
await session.cleanup()
```
#### 2. 调用带工具的模型
`analyze``generate_response` 中使用 `use_tools` 参数。框架会自动处理整个调用流程。
```python
from zhenxun.services.llm import analyze
from nonebot_plugin_alconna.uniseg import UniMessage
response = await analyze(
UniMessage("帮我查一下苹果公司的股价"),
use_tools=["query_stock_price"]
)
print(response.text) # 输出应为 "苹果公司(AAPL)的当前股价为175.5美元。" 或类似内容
```
### 处理多模态输入
本模块通过 `UniMessage``LLMContentPart` 完美支持多模态。
- **`create_multimodal_message`**: 推荐的、用于从代码中便捷地创建多模态消息的函数。
- **`unimsg_to_llm_parts`**: 框架内部使用的核心转换函数,将 `UniMessage` 的各个段(文本、图片等)转换为 `LLMContentPart` 列表。
```python
from zhenxun.services.llm import analyze
from zhenxun.services.llm.utils import create_multimodal_message
from pathlib import Path
# 从本地文件创建消息
message = create_multimodal_message(
text="请分析这张图片和这个视频。图片里是什么?视频里发生了什么?",
images=[Path("path/to/your/image.jpg")],
videos=[Path("path/to/your/video.mp4")]
)
response = await analyze(message, model="Gemini/gemini-1.5-pro")
print(response.text)
```
## 🔧 高级主题与扩展
### 模型与密钥管理
模块提供了一些工具函数来管理你的模型配置。
```python
from zhenxun.services.llm.manager import (
list_available_models,
list_embedding_models,
set_global_default_model_name,
get_global_default_model_name,
get_key_usage_stats,
reset_key_status
)
# 列出所有在config.yaml中配置的可用模型
models = list_available_models()
print([m['full_name'] for m in models])
# 列出所有可用的嵌入模型
embedding_models = list_embedding_models()
print([m['full_name'] for m in embedding_models])
# 动态设置全局默认模型
success = set_global_default_model_name("GLM/glm-4-plus")
# 获取所有Key的使用统计
stats = await get_key_usage_stats()
print(stats)
# 重置'Gemini'提供商的所有Key
await reset_key_status("Gemini")
```
### 缓存管理
模块提供了模型实例缓存功能,可以提高性能并减少重复初始化的开销。
```python
from zhenxun.services.llm import clear_model_cache, get_cache_stats
# 获取缓存统计信息
stats = get_cache_stats()
print(f"缓存大小: {stats['cache_size']}/{stats['max_cache_size']}")
print(f"缓存TTL: {stats['cache_ttl']}秒")
print(f"已缓存模型: {stats['cached_models']}")
# 清空模型缓存(在内存不足或需要强制重新初始化时使用)
clear_model_cache()
print("模型缓存已清空")
```
### 错误处理 (`LLMException`)
所有模块内的预期错误都会被包装成 `LLMException`,方便统一处理。
```python
from zhenxun.services.llm import chat, LLMException, LLMErrorCode
try:
await chat("test", model="InvalidProvider/invalid_model")
except LLMException as e:
print(f"捕获到LLM异常: {e}")
print(f"错误码: {e.code}") # 例如 LLMErrorCode.MODEL_NOT_FOUND
print(f"用户友好提示: {e.user_friendly_message}")
```
### 自定义适配器 (Adapter)
如果你想支持一个新的、非 OpenAI 兼容的 LLM 服务,可以通过实现自己的适配器来完成。
1. **创建适配器类**: 继承 `BaseAdapter` 并实现其抽象方法。
```python
# my_adapters/custom_adapter.py
from zhenxun.services.llm.adapters import BaseAdapter, RequestData, ResponseData
class MyCustomAdapter(BaseAdapter):
@property
def api_type(self) -> str: return "my_custom_api"
@property
def supported_api_types(self) -> list[str]: return ["my_custom_api"]
# ... 实现 prepare_advanced_request, parse_response 等方法
```
2. **注册适配器**: 在你的插件初始化代码中注册你的适配器。
```python
from zhenxun.services.llm.adapters import register_adapter
from .my_adapters.custom_adapter import MyCustomAdapter
register_adapter(MyCustomAdapter())
```
3. **在 `config.yaml` 中使用**:
```yaml
AI:
PROVIDERS:
- name: MyAwesomeLLM
api_key: "my-secret-key"
api_type: "my_custom_api" # 关键!使用你注册的 api_type
# ...
```
## 📚 API 快速参考
| 类/函数 | 主要用途 | 推荐场景 |
| ------------------------------------- | ---------------------------------------------------------------------- | ------------------------------------------------------------ |
| `llm.chat()` | 进行简单的、无状态的文本对话。 | 快速实现单轮问答。 |
| `llm.search()` | 执行带网络搜索的问答。 | 需要最新信息或回答事实性问题时。 |
| `llm.code()` | 请求模型执行代码。 | 计算、数据处理、代码生成等。 |
| `llm.pipeline_chat()` | 将多个模型串联,处理复杂任务流。 | 需要多模型协作完成的任务,如“图生文再润色”。 |
| `llm.analyze()` | 处理复杂的多模态输入 (`UniMessage`) 和工具调用。 | 插件中处理用户命令需要解析图片、at、回复等复杂消息时。 |
| `llm.AI` (类) | 管理一个有状态的、连续的对话会话。 | 需要实现上下文关联的连续对话机器人。 |
| `llm.get_model_instance()` | 获取一个底层的、可直接控制的 `LLMModel` 实例。 | 需要对模型进行最精细控制的复杂或自定义场景。 |
| `llm.config.LLMGenerationConfig` (类) | 定义模型生成的具体参数如温度、最大Token等。 | 当需要微调模型输出风格或格式时。 |
| `llm.tools.tool_registry` (实例) | 注册和管理可供LLM调用的函数工具。 | 当你想让LLM拥有与外部世界交互的能力时。 |
| `llm.embed()` | 生成文本的嵌入向量表示。 | 语义搜索、相似度计算、文本聚类等。 |
| `llm.search_multimodal()` | 执行带网络搜索的多模态问答。 | 需要基于图片、视频等多模态内容进行搜索时。 |
| `llm.analyze_multimodal()` | 便捷的多模态分析函数。 | 直接分析文本、图片、视频、音频等多模态内容。 |
| `llm.AIConfig` (类) | AI会话的配置类包含模型、温度等参数。 | 配置AI会话的行为和特性。 |
| `llm.clear_model_cache()` | 清空模型实例缓存。 | 内存管理或强制重新初始化模型时。 |
| `llm.get_cache_stats()` | 获取模型缓存的统计信息。 | 监控缓存使用情况和性能优化。 |
| `llm.list_embedding_models()` | 列出所有可用的嵌入模型。 | 选择合适的嵌入模型进行向量化任务。 |
| `llm.config.CommonOverrides` (类) | 提供常用的配置预设,如创意模式、精确模式等。 | 快速应用常见的模型配置组合。 |
| `llm.utils.create_multimodal_message` | 便捷地从文本、图片、音视频等数据创建 `UniMessage`。 | 在代码中以编程方式构建多模态输入时。 |

View File

@ -5,15 +5,13 @@ LLM 服务模块 - 公共 API 入口
"""
from .api import (
analyze,
analyze_multimodal,
chat,
code,
embed,
generate,
pipeline_chat,
generate_structured,
run_with_tools,
search,
search_multimodal,
)
from .config import (
CommonOverrides,
@ -34,7 +32,7 @@ from .manager import (
set_global_default_model_name,
)
from .session import AI, AIConfig
from .tools import tool_registry
from .tools import function_tool, tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -42,8 +40,6 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
MCPCompatible,
ModelDetail,
ModelInfo,
ModelProvider,
@ -66,8 +62,6 @@ __all__ = [
"LLMGenerationConfig",
"LLMMessage",
"LLMResponse",
"LLMTool",
"MCPCompatible",
"ModelDetail",
"ModelInfo",
"ModelName",
@ -77,14 +71,14 @@ __all__ = [
"ToolCategory",
"ToolMetadata",
"UsageInfo",
"analyze",
"analyze_multimodal",
"chat",
"clear_model_cache",
"code",
"create_multimodal_message",
"embed",
"function_tool",
"generate",
"generate_structured",
"get_cache_stats",
"get_global_default_model_name",
"get_model_instance",
@ -92,11 +86,10 @@ __all__ = [
"list_embedding_models",
"list_model_identifiers",
"message_to_unimessage",
"pipeline_chat",
"register_llm_configs",
"run_with_tools",
"search",
"search_multimodal",
"set_global_default_model_name",
"tool_registry",
"tool_provider_manager",
"unimsg_to_llm_parts",
]

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool
from ..types.protocols import ToolExecutable
class RequestData(BaseModel):
@ -103,7 +103,7 @@ class BaseAdapter(ABC):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list["LLMTool"] | None = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -401,7 +401,6 @@ class BaseAdapter(ABC):
class OpenAICompatAdapter(BaseAdapter):
"""
处理所有 OpenAI 兼容 API 的通用适配器
消除 OpenAIAdapter ZhipuAdapter 之间的代码重复
"""
@abstractmethod
@ -445,7 +444,7 @@ class OpenAICompatAdapter(BaseAdapter):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list["LLMTool"] | None = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求 - OpenAI兼容格式"""
@ -459,21 +458,20 @@ class OpenAICompatAdapter(BaseAdapter):
}
if tools:
openai_tools = []
for tool in tools:
if tool.type == "function" and tool.function:
openai_tools.append({"type": "function", "function": tool.function})
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
openai_tools.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
import asyncio
from zhenxun.utils.pydantic_compat import model_dump
definition_tasks = [
executable.get_definition() for executable in tools.values()
]
openai_tools = await asyncio.gather(*definition_tasks)
if openai_tools:
body["tools"] = openai_tools
body["tools"] = [
{"type": "function", "function": model_dump(tool)}
for tool in openai_tools
]
if tool_choice:
body["tool_choice"] = tool_choice

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
from zhenxun.services.log import logger
from ..types.exceptions import LLMErrorCode, LLMException
from ..utils import sanitize_schema_for_llm
from .base import BaseAdapter, RequestData, ResponseData
if TYPE_CHECKING:
@ -14,7 +15,8 @@ if TYPE_CHECKING:
from ..service import LLMModel
from ..types.content import LLMMessage
from ..types.enums import EmbeddingTaskType
from ..types.models import LLMTool, LLMToolCall
from ..types.models import LLMToolCall
from ..types.protocols import ToolExecutable
class GeminiAdapter(BaseAdapter):
@ -44,7 +46,7 @@ class GeminiAdapter(BaseAdapter):
api_key: str,
messages: list["LLMMessage"],
config: "LLMGenerationConfig | None" = None,
tools: list["LLMTool"] | None = None,
tools: dict[str, "ToolExecutable"] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> RequestData:
"""准备高级请求"""
@ -128,11 +130,22 @@ class GeminiAdapter(BaseAdapter):
)
tool_result_obj = {"raw_output": content_str}
if isinstance(tool_result_obj, list):
logger.debug(
f"工具 '{msg.name}' 的返回结果是列表,"
f"正在为Gemini API包装为JSON对象。"
)
final_response_payload = {"result": tool_result_obj}
elif not isinstance(tool_result_obj, dict):
final_response_payload = {"result": tool_result_obj}
else:
final_response_payload = tool_result_obj
current_parts.append(
{
"functionResponse": {
"name": msg.name,
"response": tool_result_obj,
"response": final_response_payload,
}
}
)
@ -145,22 +158,26 @@ class GeminiAdapter(BaseAdapter):
all_tools_for_request = []
if tools:
for tool in tools:
if tool.type == "function" and tool.function:
all_tools_for_request.append(
{"functionDeclarations": [tool.function]}
)
elif tool.type == "mcp" and tool.mcp_session:
if callable(tool.mcp_session):
raise ValueError(
"适配器接收到未激活的 MCP 会话工厂。"
"会话工厂应该在 LLMModel.generate_response 中被激活。"
)
all_tools_for_request.append(
tool.mcp_session.to_api_tool(api_type=self.api_type)
)
elif tool.type == "google_search":
all_tools_for_request.append({"googleSearch": {}})
import asyncio
from zhenxun.utils.pydantic_compat import model_dump
definition_tasks = [
executable.get_definition() for executable in tools.values()
]
tool_definitions = await asyncio.gather(*definition_tasks)
function_declarations = []
for tool_def in tool_definitions:
tool_def.parameters = sanitize_schema_for_llm(
tool_def.parameters, api_type="gemini"
)
function_declarations.append(model_dump(tool_def))
if function_declarations:
all_tools_for_request.append(
{"functionDeclarations": function_declarations}
)
if effective_config:
if getattr(effective_config, "enable_grounding", False):
@ -289,49 +306,21 @@ class GeminiAdapter(BaseAdapter):
self, model: "LLMModel", config: "LLMGenerationConfig | None" = None
) -> dict[str, Any]:
"""构建Gemini生成配置"""
generation_config: dict[str, Any] = {}
effective_config = config if config is not None else model._generation_config
if effective_config:
base_api_params = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
if not effective_config:
return {}
generation_config = effective_config.to_api_params(
api_type="gemini", model_name=model.model_name
)
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
generation_config.update(base_api_params)
if getattr(effective_config, "response_mime_type", None):
generation_config["responseMimeType"] = (
effective_config.response_mime_type
)
if getattr(effective_config, "response_schema", None):
generation_config["responseSchema"] = effective_config.response_schema
thinking_budget = getattr(effective_config, "thinking_budget", None)
if thinking_budget is not None:
if "thinkingConfig" not in generation_config:
generation_config["thinkingConfig"] = {}
generation_config["thinkingConfig"]["thinkingBudget"] = thinking_budget
if getattr(effective_config, "response_modalities", None):
modalities = effective_config.response_modalities
if isinstance(modalities, list):
generation_config["responseModalities"] = [
m.upper() for m in modalities
]
elif isinstance(modalities, str):
generation_config["responseModalities"] = [modalities.upper()]
generation_config = {
k: v for k, v in generation_config.items() if v is not None
}
if generation_config:
param_keys = list(generation_config.keys())
logger.debug(
f"构建Gemini生成配置完成包含 {len(generation_config)} 个参数: "
f"{param_keys}"
)
return generation_config
@ -410,10 +399,16 @@ class GeminiAdapter(BaseAdapter):
text_content = ""
parsed_tool_calls: list["LLMToolCall"] | None = None
thought_summary_parts = []
answer_parts = []
for part in parts:
if "text" in part:
text_content += part["text"]
answer_parts.append(part["text"])
elif "thought" in part:
thought_summary_parts.append(part["thought"])
elif "thoughtSummary" in part:
thought_summary_parts.append(part["thoughtSummary"])
elif "functionCall" in part:
if parsed_tool_calls is None:
parsed_tool_calls = []
@ -445,12 +440,27 @@ class GeminiAdapter(BaseAdapter):
result = part["codeExecutionResult"]
if result.get("outcome") == "OK":
output = result.get("output", "")
text_content += f"\n[代码执行结果]:\n{output}\n"
answer_parts.append(f"\n[代码执行结果]:\n```\n{output}\n```\n")
else:
text_content += (
answer_parts.append(
f"\n[代码执行失败]: {result.get('outcome', 'UNKNOWN')}\n"
)
if thought_summary_parts:
full_thought_summary = "\n".join(thought_summary_parts).strip()
full_answer = "".join(answer_parts).strip()
formatted_parts = []
if full_thought_summary:
formatted_parts.append(f"🤔 **思考过程**\n\n{full_thought_summary}")
if full_answer:
separator = "\n\n---\n\n" if full_thought_summary else ""
formatted_parts.append(f"{separator}✅ **回答**\n\n{full_answer}")
text_content = "".join(formatted_parts)
else:
text_content = "".join(answer_parts)
usage_info = response_json.get("usageMetadata")
grounding_metadata_obj = None

View File

@ -1,16 +1,19 @@
"""
LLM 服务的高级 API 接口 - 便捷函数入口
LLM 服务的高级 API 接口 - 便捷函数入口 (无状态)
"""
from pathlib import Path
from typing import Any
from typing import Any, TypeVar
from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel
from zhenxun.services.log import logger
from .config import CommonOverrides
from .config.generation import create_generation_config_from_kwargs
from .manager import get_model_instance
from .session import AI
from .tools.manager import tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -18,37 +21,53 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelName,
)
from .utils import create_multimodal_message, unimsg_to_llm_parts
T = TypeVar("T", bound=BaseModel)
async def chat(
message: str | LLMMessage | list[LLMContentPart],
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
tools: list[LLMTool] | None = None,
instruction: str | None = None,
tools: list[dict[str, Any] | str] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
聊天对话便捷函数
无状态的聊天对话便捷函数通过临时的AI会话实例与LLM模型交互
参数:
message: 用户输入的消息
model: 要使用的模型名称
tools: 本次对话可用的工具列表
tool_choice: 强制模型使用的工具
**kwargs: 传递给模型的其他参数
message: 用户输入的消息内容支持多种格式
model: 要使用的模型名称如果为None则使用默认模型
instruction: 系统指令用于指导AI的行为和回复风格
tools: 可用的工具列表支持字典配置或字符串标识符
tool_choice: 工具选择策略控制AI如何选择和使用工具
**kwargs: 额外的生成配置参数会被转换为LLMGenerationConfig
返回:
LLMResponse: 模型的完整响应可能包含文本或工具调用请求
LLMResponse: 包含AI回复内容使用信息和工具调用等的完整响应对象
"""
ai = AI()
return await ai.chat(
message, model=model, tools=tools, tool_choice=tool_choice, **kwargs
)
try:
config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
ai_session = AI()
return await ai_session.chat(
message,
model=model,
instruction=instruction,
tools=tools,
tool_choice=tool_choice,
config=config,
)
except LLMException:
raise
except Exception as e:
logger.error(f"执行 chat 函数失败: {e}", e=e)
raise LLMException(f"聊天执行失败: {e}", cause=e)
async def code(
@ -57,143 +76,68 @@ async def code(
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
) -> LLMResponse:
"""
代码执行便捷函数
无状态的代码执行便捷函数支持在沙箱环境中执行代码
参数:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
prompt: 代码执行的提示词描述要执行的代码任务
model: 要使用的模型名称默认使用Gemini/gemini-2.0-flash
timeout: 代码执行超时时间防止长时间运行的代码阻塞
**kwargs: 额外的生成配置参数
返回:
dict[str, Any]: 包含执行结果的字典
LLMResponse: 包含代码执行结果的完整响应对象
"""
ai = AI()
return await ai.code(prompt, model=model, timeout=timeout, **kwargs)
resolved_model = model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
final_config = config.to_dict()
final_config.update(kwargs)
return await chat(prompt, model=resolved_model, **final_config)
async def search(
query: str | UniMessage,
query: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str = "",
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
**kwargs: Any,
) -> dict[str, Any]:
) -> LLMResponse:
"""
信息搜索便捷函数
无状态的信息搜索便捷函数利用搜索工具获取实时信息
参数:
query: 搜索查询内容
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
query: 搜索查询内容支持多种输入格式
model: 要使用的模型名称如果为None则使用默认模型
instruction: 搜索任务的系统指令指导AI如何处理搜索结果
**kwargs: 额外的生成配置参数
返回:
dict[str, Any]: 包含搜索结果的字典
LLMResponse: 包含搜索结果和AI整合回复的完整响应对象
"""
ai = AI()
return await ai.search(query, model=model, instruction=instruction, **kwargs)
logger.debug("执行无状态 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
final_config = search_config.to_dict()
final_config.update(kwargs)
async def analyze(
message: UniMessage | None,
*,
instruction: str = "",
model: ModelName = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
**kwargs: Any,
) -> str | LLMResponse:
"""
内容分析便捷函数
参数:
message: 要分析的消息内容
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
ai = AI()
return await ai.analyze(
message,
instruction=instruction,
return await chat(
query,
model=model,
use_tools=use_tools,
tool_config=tool_config,
**kwargs,
instruction=instruction,
**final_config,
)
async def analyze_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> str | LLMResponse:
"""
多模态分析便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 分析指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
str | LLMResponse: 分析结果
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
return await analyze(message, instruction=instruction, model=model, **kwargs)
async def search_multimodal(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
videos: list[str | Path | bytes] | str | Path | bytes | None = None,
audios: list[str | Path | bytes] | str | Path | bytes | None = None,
*,
instruction: str = "",
model: ModelName = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
多模态搜索便捷函数
参数:
text: 文本内容
images: 图片文件路径字节数据或列表
videos: 视频文件路径字节数据或列表
audios: 音频文件路径字节数据或列表
instruction: 搜索指令
model: 要使用的模型名称
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典
"""
message = create_multimodal_message(
text=text, images=images, videos=videos, audios=audios
)
ai = AI()
return await ai.search(message, model=model, instruction=instruction, **kwargs)
async def embed(
texts: list[str] | str,
*,
@ -202,140 +146,104 @@ async def embed(
**kwargs: Any,
) -> list[list[float]]:
"""
文本嵌入便捷函数
无状态的文本嵌入便捷函数将文本转换为向量表示
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
texts: 要生成嵌入的文本内容支持单个字符串或字符串列表
model: 要使用的嵌入模型名称如果为None则使用默认模型
task_type: 嵌入任务类型影响向量的优化方向如检索分类等
**kwargs: 额外的模型配置参数
返回:
list[list[float]]: 文本的嵌入向量列表
list[list[float]]: 文本对应的嵌入向量列表每个向量为浮点数列表
"""
ai = AI()
return await ai.embed(texts, model=model, task_type=task_type, **kwargs)
if isinstance(texts, str):
texts = [texts]
if not texts:
return []
async def pipeline_chat(
message: UniMessage | str | list[LLMContentPart],
model_chain: list[ModelName],
*,
initial_instruction: str = "",
final_instruction: str = "",
**kwargs: Any,
) -> LLMResponse:
"""
AI模型链式调用前一个模型的输出作为下一个模型的输入
参数:
message: 初始输入消息支持多模态
model_chain: 模型名称列表
initial_instruction: 第一个模型的系统指令
final_instruction: 最后一个模型的系统指令
**kwargs: 传递给模型实例的其他参数
返回:
LLMResponse: 最后一个模型的响应结果
"""
if not model_chain:
raise ValueError("模型链`model_chain`不能为空。")
current_content: str | list[LLMContentPart]
if isinstance(message, UniMessage):
current_content = await unimsg_to_llm_parts(message)
elif isinstance(message, str):
current_content = message
elif isinstance(message, list):
current_content = message
else:
raise TypeError(f"不支持的消息类型: {type(message)}")
final_response: LLMResponse | None = None
for i, model_name in enumerate(model_chain):
if not model_name:
raise ValueError(f"模型链中第 {i + 1} 个模型名称为空。")
is_first_step = i == 0
is_last_step = i == len(model_chain) - 1
messages_for_step: list[LLMMessage] = []
instruction_for_step = ""
if is_first_step and initial_instruction:
instruction_for_step = initial_instruction
elif is_last_step and final_instruction:
instruction_for_step = final_instruction
if instruction_for_step:
messages_for_step.append(LLMMessage.system(instruction_for_step))
messages_for_step.append(LLMMessage.user(current_content))
logger.info(
f"Pipeline Step [{i + 1}/{len(model_chain)}]: "
f"使用模型 '{model_name}' 进行处理..."
)
try:
async with await get_model_instance(model_name, **kwargs) as model:
response = await model.generate_response(messages_for_step)
final_response = response
current_content = response.text.strip()
if not current_content and not is_last_step:
logger.warning(
f"模型 '{model_name}' 在中间步骤返回了空内容,流水线可能无法继续。"
)
break
except Exception as e:
logger.error(f"在模型链的第 {i + 1} 步 ('{model_name}') 出错: {e}", e=e)
raise LLMException(
f"流水线在模型 '{model_name}' 处执行失败: {e}",
code=LLMErrorCode.GENERATION_FAILED,
cause=e,
try:
async with await get_model_instance(model) as model_instance:
return await model_instance.generate_embeddings(
texts, task_type=task_type, **kwargs
)
if final_response is None:
except LLMException:
raise
except Exception as e:
logger.error(f"文本嵌入失败: {e}", e=e)
raise LLMException(
"AI流水线未能产生任何响应。", code=LLMErrorCode.GENERATION_FAILED
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
return final_response
async def generate_structured(
message: str | LLMMessage | list[LLMContentPart],
response_model: type[T],
*,
model: ModelName = None,
instruction: str | None = None,
**kwargs: Any,
) -> T:
"""
无状态地生成结构化响应并自动解析为指定的Pydantic模型
参数:
message: 用户输入的消息内容支持多种格式
response_model: 用于解析和验证响应的Pydantic模型类
model: 要使用的模型名称如果为None则使用默认模型
instruction: 系统指令用于指导AI生成符合要求的结构化输出
**kwargs: 额外的生成配置参数
返回:
T: 解析后的Pydantic模型实例类型为response_model指定的类型
"""
try:
config = create_generation_config_from_kwargs(**kwargs) if kwargs else None
ai_session = AI()
return await ai_session.generate_structured(
message,
response_model,
model=model,
instruction=instruction,
config=config,
)
except LLMException:
raise
except Exception as e:
logger.error(f"生成结构化响应失败: {e}", e=e)
raise LLMException(f"生成结构化响应失败: {e}", cause=e)
async def generate(
messages: list[LLMMessage],
*,
model: ModelName = None,
tools: list[LLMTool] | None = None,
tools: list[dict[str, Any] | str] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""
根据完整的消息列表包括系统指令生成一次性响应
这是一个便捷的函数不使用或修改任何会话历史
根据完整的消息列表生成一次性响应这是一个无状态的底层函数
参数:
messages: 用于生成响应的完整消息列表
model: 要使用的模型名称
tools: 可用的工具列表
tool_choice: 工具选择策略
**kwargs: 传递给模型的其他参数
messages: 完整的消息历史列表包括系统指令用户消息和助手回复
model: 要使用的模型名称如果为None则使用默认模型
tools: 可用的工具列表支持字典配置或字符串标识符
tool_choice: 工具选择策略控制AI如何选择和使用工具
**kwargs: 额外的生成配置参数会覆盖默认配置
返回:
LLMResponse: 模型的完整响应对象
LLMResponse: 包含AI回复内容使用信息和工具调用等的完整响应对象
"""
try:
ai_instance = AI()
resolved_model_name = ai_instance._resolve_model_name(model)
final_config_dict = ai_instance._merge_config(kwargs)
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
model, override_config=kwargs
) as model_instance:
return await model_instance.generate_response(
messages,
tools=tools,
tools=tools, # type: ignore
tool_choice=tool_choice,
)
except LLMException:
@ -343,3 +251,55 @@ async def generate(
except Exception as e:
logger.error(f"生成响应失败: {e}", e=e)
raise LLMException(f"生成响应失败: {e}", cause=e)
async def run_with_tools(
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str | None = None,
tools: list[str],
max_cycles: int = 5,
**kwargs: Any,
) -> LLMResponse:
"""
无状态地执行一个带本地Python函数的LLM调用循环
参数:
message: 用户输入
model: 使用的模型
instruction: 系统指令
tools: 要使用的本地函数工具名称列表 (必须已通过 @function_tool 注册)
max_cycles: 最大工具调用循环次数
**kwargs: 额外的生成配置参数
返回:
LLMResponse: 包含最终回复的响应对象
"""
from .executor import ExecutionConfig, LLMToolExecutor
from .utils import normalize_to_llm_messages
messages = await normalize_to_llm_messages(message, instruction)
async with await get_model_instance(
model, override_config=kwargs
) as model_instance:
resolved_tools = await tool_provider_manager.get_function_tools(tools)
if not resolved_tools:
logger.warning(
"run_with_tools 未找到任何可用的本地函数工具,将作为普通聊天执行。"
)
return await model_instance.generate_response(messages, tools=None)
executor = LLMToolExecutor(model_instance)
config = ExecutionConfig(max_cycles=max_cycles)
final_history = await executor.run(messages, resolved_tools, config)
for msg in reversed(final_history):
if msg.role == "assistant":
text = msg.content if isinstance(msg.content, str) else str(msg.content)
return LLMResponse(text=text, tool_calls=msg.tool_calls)
raise LLMException(
"带工具的执行循环未能产生有效的助手回复。", code=LLMErrorCode.GENERATION_FAILED
)

View File

@ -14,7 +14,6 @@ from .generation import (
from .presets import CommonOverrides
from .providers import (
LLMConfig,
ToolConfig,
get_gemini_safety_threshold,
get_llm_config,
register_llm_configs,
@ -27,7 +26,6 @@ __all__ = [
"LLMConfig",
"LLMGenerationConfig",
"ModelConfigOverride",
"ToolConfig",
"apply_api_specific_mappings",
"create_generation_config_from_kwargs",
"get_gemini_safety_threshold",

View File

@ -7,6 +7,7 @@ from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_dump
from ..types.enums import ResponseFormat
from ..types.exceptions import LLMErrorCode, LLMException
@ -45,6 +46,9 @@ class ModelConfigOverride(BaseModel):
thinking_budget: float | None = Field(
default=None, ge=0.0, le=1.0, description="思考预算"
)
include_thoughts: bool | None = Field(
default=None, description="是否在响应中包含思维过程Gemini专用"
)
safety_settings: dict[str, str] | None = Field(default=None, description="安全设置")
response_modalities: list[str] | None = Field(
default=None, description="响应模态类型"
@ -62,22 +66,16 @@ class ModelConfigOverride(BaseModel):
def to_dict(self) -> dict[str, Any]:
"""转换为字典排除None值"""
model_data = model_dump(self, exclude_none=True)
result = {}
model_data = getattr(self, "model_dump", lambda: {})()
if not model_data:
model_data = {}
for field_name, _ in self.__class__.__dict__.get(
"model_fields", {}
).items():
value = getattr(self, field_name, None)
if value is not None:
model_data[field_name] = value
for key, value in model_data.items():
if value is not None:
if key == "custom_params" and isinstance(value, dict):
result.update(value)
else:
result[key] = value
if key == "custom_params" and isinstance(value, dict):
result.update(value)
else:
result[key] = value
return result
def merge_with_base_config(
@ -157,6 +155,10 @@ class LLMGenerationConfig(ModelConfigOverride):
params["responseSchema"] = self.response_schema
logger.debug(f"{api_type} 启用 JSON MIME 类型输出模式")
if self.custom_params:
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
params.update(custom_mapped)
if api_type == "gemini":
if (
self.response_format != ResponseFormat.JSON
@ -169,17 +171,28 @@ class LLMGenerationConfig(ModelConfigOverride):
if self.response_schema is not None and "responseSchema" not in params:
params["responseSchema"] = self.response_schema
if self.thinking_budget is not None:
params["thinkingBudget"] = self.thinking_budget
if self.thinking_budget is not None or self.include_thoughts is not None:
thinking_config = params.setdefault("thinkingConfig", {})
if self.thinking_budget is not None:
max_budget = 24576
budget_value = int(self.thinking_budget * max_budget)
thinking_config["thinkingBudget"] = budget_value
logger.debug(
f"已将 thinking_budget (float: {self.thinking_budget}) "
f"转换为 Gemini API 的整数格式: {budget_value}"
)
if self.include_thoughts is not None:
thinking_config["includeThoughts"] = self.include_thoughts
logger.debug(f"已设置 includeThoughts: {self.include_thoughts}")
if self.safety_settings is not None:
params["safetySettings"] = self.safety_settings
if self.response_modalities is not None:
params["responseModalities"] = self.response_modalities
if self.custom_params:
custom_mapped = apply_api_specific_mappings(self.custom_params, api_type)
params.update(custom_mapped)
logger.debug(f"{api_type}转换配置参数: {len(params)}个参数")
return params

View File

@ -5,33 +5,19 @@ LLM 提供商配置管理
"""
from functools import lru_cache
import json
import sys
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import parse_as
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ..core import key_store
from ..tools import tool_provider_manager
from ..types.models import ModelDetail, ProviderConfig
class ToolConfig(BaseModel):
"""MCP类型工具的配置定义"""
type: str = "mcp"
name: str = Field(..., description="工具的唯一名称标识")
description: str | None = Field(None, description="工具功能的描述")
mcp_config: dict[str, Any] | BaseModel = Field(
..., description="MCP服务器的特定配置"
)
AI_CONFIG_GROUP = "AI"
PROVIDERS_CONFIG_KEY = "PROVIDERS"
@ -57,9 +43,6 @@ class LLMConfig(BaseModel):
providers: list[ProviderConfig] = Field(
default_factory=list, description="配置多个 AI 服务提供商及其模型信息"
)
mcp_tools: list[ToolConfig] = Field(
default_factory=list, description="配置可用的外部MCP工具"
)
def get_provider_by_name(self, name: str) -> ProviderConfig | None:
"""根据名称获取提供商配置
@ -218,33 +201,6 @@ def get_default_providers() -> list[dict[str, Any]]:
]
def get_default_mcp_tools() -> dict[str, Any]:
"""
获取默认的MCP工具配置用于在文件不存在时创建
包含了 baidu-map, Context7, sequential-thinking.
"""
return {
"mcpServers": {
"baidu-map": {
"command": "npx",
"args": ["-y", "@baidumap/mcp-server-baidu-map"],
"env": {"BAIDU_MAP_API_KEY": "<YOUR_BAIDU_MAP_API_KEY>"},
"description": "百度地图工具,提供地理编码、路线规划等功能。",
},
"sequential-thinking": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"description": "顺序思维工具,用于帮助模型进行多步骤推理。",
},
"Context7": {
"command": "npx",
"args": ["-y", "@upstash/context7-mcp@latest"],
"description": "Upstash 提供的上下文管理和记忆工具。",
},
}
}
def register_llm_configs():
"""注册 LLM 服务的配置项"""
logger.info("注册 LLM 服务的配置项")
@ -312,88 +268,9 @@ def register_llm_configs():
@lru_cache(maxsize=1)
def get_llm_config() -> LLMConfig:
"""获取 LLM 配置实例,现在会从新的 JSON 文件加载 MCP 工具"""
"""获取 LLM 配置实例,不再加载 MCP 工具配置"""
ai_config = get_ai_config()
llm_data_path = DATA_PATH / "llm"
mcp_tools_path = llm_data_path / "mcp_tools.json"
mcp_tools_list = []
mcp_servers_dict = {}
if not mcp_tools_path.exists():
logger.info(f"未找到 MCP 工具配置文件,将在 '{mcp_tools_path}' 创建一个。")
llm_data_path.mkdir(parents=True, exist_ok=True)
default_mcp_config = get_default_mcp_tools()
try:
with mcp_tools_path.open("w", encoding="utf-8") as f:
json.dump(default_mcp_config, f, ensure_ascii=False, indent=2)
mcp_servers_dict = default_mcp_config.get("mcpServers", {})
except Exception as e:
logger.error(f"创建默认 MCP 配置文件失败: {e}", e=e)
mcp_servers_dict = {}
else:
try:
with mcp_tools_path.open("r", encoding="utf-8") as f:
mcp_data = json.load(f)
mcp_servers_dict = mcp_data.get("mcpServers", {})
if not isinstance(mcp_servers_dict, dict):
logger.warning(
f"'{mcp_tools_path}' 中的 'mcpServers' 键不是一个字典,"
f"将使用空配置。"
)
mcp_servers_dict = {}
except json.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件 '{mcp_tools_path}' 失败: {e}", e=e)
except Exception as e:
logger.error(f"读取 MCP 配置文件时发生未知错误: {e}", e=e)
mcp_servers_dict = {}
if sys.platform == "win32":
logger.debug("检测到Windows平台正在调整MCP工具的npx命令...")
for name, config in mcp_servers_dict.items():
if isinstance(config, dict) and config.get("command") == "npx":
logger.info(f"为工具 '{name}' 包装npx命令以兼容Windows。")
original_args = config.get("args", [])
config["command"] = "cmd"
config["args"] = ["/c", "npx", *original_args]
if mcp_servers_dict:
mcp_tools_list = [
{
"name": name,
"type": "mcp",
"description": config.get("description", f"MCP tool for {name}"),
"mcp_config": config,
}
for name, config in mcp_servers_dict.items()
if isinstance(config, dict)
]
from ..tools.registry import tool_registry
for tool_dict in mcp_tools_list:
if isinstance(tool_dict, dict):
tool_name = tool_dict.get("name")
if not tool_name:
continue
config_model = tool_registry.get_mcp_config_model(tool_name)
if not config_model:
logger.debug(
f"MCP工具 '{tool_name}' 没有注册其配置模型,"
f"将跳过特定配置验证,直接使用原始配置字典。"
)
continue
mcp_config_data = tool_dict.get("mcp_config", {})
try:
parsed_mcp_config = parse_as(config_model, mcp_config_data)
tool_dict["mcp_config"] = parsed_mcp_config
except Exception as e:
raise ValueError(f"MCP工具 '{tool_name}' 的 `mcp_config` 配置错误: {e}")
config_data = {
"default_model_name": ai_config.get("default_model_name"),
"proxy": ai_config.get("proxy"),
@ -401,7 +278,6 @@ def get_llm_config() -> LLMConfig:
"max_retries_llm": ai_config.get("max_retries_llm", 3),
"retry_delay_llm": ai_config.get("retry_delay_llm", 2),
PROVIDERS_CONFIG_KEY: ai_config.get(PROVIDERS_CONFIG_KEY, []),
"mcp_tools": mcp_tools_list,
}
return parse_as(LLMConfig, config_data)
@ -504,12 +380,17 @@ def set_default_model(provider_model_name: str | None) -> bool:
async def _init_llm_config_on_startup():
"""
在服务启动时主动调用一次 get_llm_config key_store.initialize
以触发必要的初始化操作
并预热工具提供者管理器
"""
logger.info("正在初始化 LLM 配置并加载密钥状态...")
try:
get_llm_config()
await key_store.initialize()
logger.info("LLM 配置和密钥状态初始化完成。")
logger.debug("LLM 配置和密钥状态初始化完成。")
logger.debug("正在预热 LLM 工具提供者管理器...")
await tool_provider_manager.initialize()
logger.debug("LLM 工具提供者管理器预热完成。")
except Exception as e:
logger.error(f"LLM 配置或密钥状态初始化时发生错误: {e}", e=e)

View File

@ -335,10 +335,10 @@ async def with_smart_retry(
latency = (time.monotonic() - start_time) * 1000
if key_store and isinstance(result, tuple) and len(result) == 2:
final_result, api_key_used = result
_, api_key_used = result
if api_key_used:
await key_store.record_success(api_key_used, latency)
return final_result
return result
else:
return result

View File

@ -0,0 +1,193 @@
"""
LLM 轻量级工具执行器
提供驱动 LLM 与本地函数工具之间交互的核心循环
"""
import asyncio
from enum import Enum
import json
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.log import logger
from zhenxun.utils.decorator.retry import Retry
from zhenxun.utils.pydantic_compat import model_dump
from .service import LLMModel
from .types import (
LLMErrorCode,
LLMException,
LLMMessage,
ToolExecutable,
ToolResult,
)
class ExecutionConfig(BaseModel):
"""
轻量级执行器的配置
"""
max_cycles: int = Field(default=5, description="工具调用循环的最大次数。")
class ToolErrorType(str, Enum):
"""结构化工具错误的类型枚举。"""
TOOL_NOT_FOUND = "ToolNotFound"
INVALID_ARGUMENTS = "InvalidArguments"
EXECUTION_ERROR = "ExecutionError"
USER_CANCELLATION = "UserCancellation"
class ToolErrorResult(BaseModel):
"""一个结构化的工具执行错误模型,用于返回给 LLM。"""
error_type: ToolErrorType = Field(..., description="错误的类型。")
message: str = Field(..., description="对错误的详细描述。")
is_retryable: bool = Field(False, description="指示这个错误是否可能通过重试解决。")
def model_dump(self, **kwargs):
return model_dump(self, **kwargs)
def _is_exception_retryable(e: Exception) -> bool:
"""判断一个异常是否应该触发重试。"""
if isinstance(e, LLMException):
retryable_codes = {
LLMErrorCode.API_REQUEST_FAILED,
LLMErrorCode.API_TIMEOUT,
LLMErrorCode.API_RATE_LIMITED,
}
return e.code in retryable_codes
return True
class LLMToolExecutor:
"""
一个通用的执行器负责驱动 LLM 与工具之间的多轮交互
"""
def __init__(self, model: LLMModel):
self.model = model
async def run(
self,
messages: list[LLMMessage],
tools: dict[str, ToolExecutable],
config: ExecutionConfig | None = None,
) -> list[LLMMessage]:
"""
执行完整的思考-行动循环
"""
effective_config = config or ExecutionConfig()
execution_history = list(messages)
for i in range(effective_config.max_cycles):
response = await self.model.generate_response(
execution_history, tools=tools
)
assistant_message = LLMMessage(
role="assistant",
content=response.text,
tool_calls=response.tool_calls,
)
execution_history.append(assistant_message)
if not response.tool_calls:
logger.info("✅ LLMToolExecutor模型未请求工具调用执行结束。")
return execution_history
logger.info(
f"🛠️ LLMToolExecutor模型请求并行调用 {len(response.tool_calls)} 个工具"
)
tool_results = await self._execute_tools_parallel_safely(
response.tool_calls,
tools,
)
execution_history.extend(tool_results)
raise LLMException(
f"超过最大工具调用循环次数 ({effective_config.max_cycles})。",
code=LLMErrorCode.GENERATION_FAILED,
)
async def _execute_single_tool_safely(
self, tool_call: Any, available_tools: dict[str, ToolExecutable]
) -> tuple[Any, ToolResult]:
"""安全地执行单个工具调用。"""
tool_name = tool_call.function.name
arguments = {}
try:
if tool_call.function.arguments:
arguments = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
error_result = ToolErrorResult(
error_type=ToolErrorType.INVALID_ARGUMENTS,
message=f"参数解析失败: {e}",
is_retryable=False,
)
return tool_call, ToolResult(output=model_dump(error_result))
try:
executable = available_tools.get(tool_name)
if not executable:
raise LLMException(
f"Tool '{tool_name}' not found.",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
@Retry.simple(
stop_max_attempt=2, wait_fixed_seconds=1, return_on_failure=None
)
async def execute_with_retry():
return await executable.execute(**arguments)
execution_result = await execute_with_retry()
if execution_result is None:
raise LLMException("工具执行在多次重试后仍然失败。")
return tool_call, execution_result
except Exception as e:
error_type = ToolErrorType.EXECUTION_ERROR
is_retryable = _is_exception_retryable(e)
if (
isinstance(e, LLMException)
and e.code == LLMErrorCode.CONFIGURATION_ERROR
):
error_type = ToolErrorType.TOOL_NOT_FOUND
is_retryable = False
error_result = ToolErrorResult(
error_type=error_type, message=str(e), is_retryable=is_retryable
)
return tool_call, ToolResult(output=model_dump(error_result))
async def _execute_tools_parallel_safely(
self,
tool_calls: list[Any],
available_tools: dict[str, ToolExecutable],
) -> list[LLMMessage]:
"""并行执行所有工具调用,并对每个调用的错误进行隔离。"""
if not tool_calls:
return []
tasks = [
self._execute_single_tool_safely(call, available_tools)
for call in tool_calls
]
results = await asyncio.gather(*tasks)
tool_messages = [
LLMMessage.tool_response(
tool_call_id=original_call.id,
function_name=original_call.function.name,
result=result.output,
)
for original_call, result in results
]
return tool_messages

View File

@ -86,14 +86,23 @@ def _cache_model(cache_key: str, model: LLMModel):
def clear_model_cache():
"""清空模型缓存"""
"""
清空模型缓存释放所有缓存的模型实例
用于在内存不足或需要强制重新加载模型配置时清理缓存
"""
global _model_cache
_model_cache.clear()
logger.info("已清空模型缓存")
def get_cache_stats() -> dict[str, Any]:
"""获取缓存统计信息"""
"""
获取模型缓存的统计信息
返回:
dict[str, Any]: 包含缓存大小最大容量TTL和已缓存模型列表的统计信息
"""
return {
"cache_size": len(_model_cache),
"max_cache_size": _max_cache_size,
@ -169,7 +178,13 @@ def find_model_config(
def list_available_models() -> list[dict[str, Any]]:
"""列出所有配置的可用模型"""
"""
列出所有配置的可用模型及其详细信息
返回:
list[dict[str, Any]]: 模型信息列表每个字典包含提供商名称模型名称
能力信息是否为嵌入模型等详细信息
"""
providers = get_configured_providers()
model_list = []
for provider in providers:
@ -215,7 +230,13 @@ def list_model_identifiers() -> dict[str, list[str]]:
def list_embedding_models() -> list[dict[str, Any]]:
"""列出所有配置的嵌入模型"""
"""
列出所有配置的嵌入模型
返回:
list[dict[str, Any]]: 嵌入模型信息列表从所有可用模型中筛选出
支持嵌入功能的模型
"""
all_models = list_available_models()
return [model for model in all_models if model.get("is_embedding_model", False)]

View File

@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any
from .types import LLMMessage
class BaseMemory(ABC):
"""
记忆系统的抽象基类
定义了任何记忆后端都必须实现的接口
"""
@abstractmethod
async def get_history(self, session_id: str) -> list[LLMMessage]:
"""根据会话ID获取历史记录。"""
raise NotImplementedError
@abstractmethod
async def add_message(self, session_id: str, message: LLMMessage) -> None:
"""向指定会话添加一条消息。"""
raise NotImplementedError
@abstractmethod
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
"""向指定会话添加多条消息。"""
raise NotImplementedError
@abstractmethod
async def clear_history(self, session_id: str) -> None:
"""清空指定会话的历史记录。"""
raise NotImplementedError
class InMemoryMemory(BaseMemory):
"""
一个简单的默认的内存记忆后端
将历史记录存储在进程内存中的字典里
"""
def __init__(self, **kwargs: Any):
self._history: dict[str, list[LLMMessage]] = defaultdict(list)
async def get_history(self, session_id: str) -> list[LLMMessage]:
return self._history.get(session_id, []).copy()
async def add_message(self, session_id: str, message: LLMMessage) -> None:
self._history[session_id].append(message)
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
self._history[session_id].extend(messages)
async def clear_history(self, session_id: str) -> None:
if session_id in self._history:
del self._history[session_id]

View File

@ -6,9 +6,10 @@ LLM 模型实现类
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
import json
from typing import Any
from typing import Any, TypeVar
from pydantic import BaseModel
from zhenxun.services.log import logger
@ -28,33 +29,25 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelDetail,
ProviderConfig,
ToolExecutable,
)
from .types.capabilities import ModelCapabilities, ModelModality
from .utils import _sanitize_request_body_for_logging
T = TypeVar("T", bound=BaseModel)
class LLMModelBase(ABC):
"""LLM模型抽象基类"""
@abstractmethod
async def generate_text(
self,
prompt: str,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
pass
@abstractmethod
async def generate_response(
self,
messages: list[LLMMessage],
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tools: dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
@ -311,7 +304,7 @@ class LLMModel(LLMModelBase):
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tools: dict[str, ToolExecutable] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
):
@ -339,7 +332,7 @@ class LLMModel(LLMModelBase):
adapter,
messages: list[LLMMessage],
config: LLMGenerationConfig | None,
tools: list[LLMTool] | None,
tools: dict[str, ToolExecutable] | None,
tool_choice: str | dict[str, Any] | None,
http_client: LLMHttpClient,
failed_keys: set[str] | None = None,
@ -428,66 +421,23 @@ class LLMModel(LLMModelBase):
if self._is_closed:
raise RuntimeError(f"LLMModel实例已关闭: {self}")
async def generate_text(
self,
prompt: str,
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> str:
"""生成文本"""
self._check_not_closed()
messages: list[LLMMessage] = []
if history:
for msg in history:
role = msg.get("role", "user")
content_text = msg.get("content", "")
messages.append(LLMMessage(role=role, content=content_text))
messages.append(LLMMessage.user(prompt))
model_fields = getattr(LLMGenerationConfig, "model_fields", {})
request_specific_config_dict = {
k: v for k, v in kwargs.items() if k in model_fields
}
request_specific_config = None
if request_specific_config_dict:
request_specific_config = LLMGenerationConfig(
**request_specific_config_dict
)
for key in request_specific_config_dict:
kwargs.pop(key, None)
response = await self.generate_response(
messages,
config=request_specific_config,
**kwargs,
)
return response.text
async def generate_response(
self,
messages: list[LLMMessage],
config: LLMGenerationConfig | None = None,
tools: list[LLMTool] | None = None,
tools: dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""生成高级响应"""
"""
生成高级响应
此方法现在只执行 *单次* LLM API 调用并将结果包括工具调用请求返回
"""
self._check_not_closed()
from .adapters import get_adapter_for_api_type
from .config.generation import create_generation_config_from_kwargs
adapter = get_adapter_for_api_type(self.api_type)
if not adapter:
raise LLMException(
f"未找到适用于 API 类型 '{self.api_type}' 的适配器",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
final_request_config = self._generation_config or LLMGenerationConfig()
if kwargs:
kwargs_config = create_generation_config_from_kwargs(**kwargs)
@ -500,43 +450,19 @@ class LLMModel(LLMModelBase):
merged_dict.update(config.to_dict())
final_request_config = LLMGenerationConfig(**merged_dict)
adapter = get_adapter_for_api_type(self.api_type)
http_client = await self._get_http_client()
async with AsyncExitStack() as stack:
activated_tools = []
if tools:
for tool in tools:
if tool.type == "mcp" and callable(tool.mcp_session):
func_obj = getattr(tool.mcp_session, "func", None)
tool_name = (
getattr(func_obj, "__name__", "unknown")
if func_obj
else "unknown"
)
logger.debug(f"正在激活 MCP 工具会话: {tool_name}")
response, _ = await self._execute_with_smart_retry(
adapter,
messages,
final_request_config,
tools,
tool_choice,
http_client,
)
active_session = await stack.enter_async_context(
tool.mcp_session()
)
activated_tools.append(
LLMTool.from_mcp_session(
session=active_session, annotations=tool.annotations
)
)
else:
activated_tools.append(tool)
llm_response = await self._execute_with_smart_retry(
adapter,
messages,
final_request_config,
activated_tools if activated_tools else None,
tool_choice,
http_client,
)
return llm_response
return response
async def generate_embeddings(
self,

View File

@ -5,17 +5,27 @@ LLM 服务 - 会话客户端
"""
import copy
from dataclasses import dataclass
from typing import Any
from dataclasses import dataclass, field
import json
from typing import Any, TypeVar
import uuid
from jinja2 import Environment
from nonebot.compat import type_validate_json
from nonebot_plugin_alconna.uniseg import UniMessage
from pydantic import BaseModel, ValidationError
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_copy, model_dump, model_json_schema
from .config import CommonOverrides, LLMGenerationConfig
from .config import (
CommonOverrides,
LLMGenerationConfig,
)
from .config.providers import get_ai_config
from .manager import get_global_default_model_name, get_model_instance
from .tools import tool_registry
from .memory import BaseMemory, InMemoryMemory
from .tools.manager import tool_provider_manager
from .types import (
EmbeddingTaskType,
LLMContentPart,
@ -23,67 +33,93 @@ from .types import (
LLMException,
LLMMessage,
LLMResponse,
LLMTool,
ModelName,
ResponseFormat,
ToolExecutable,
ToolProvider,
)
from .utils import unimsg_to_llm_parts
from .utils import normalize_to_llm_messages
T = TypeVar("T", bound=BaseModel)
jinja_env = Environment(autoescape=False)
@dataclass
class AIConfig:
"""AI配置类 - 简化版本"""
"""AI配置类 - [重构后] 简化版本"""
model: ModelName = None
default_embedding_model: ModelName = None
temperature: float | None = None
max_tokens: int | None = None
enable_cache: bool = False
enable_code: bool = False
enable_search: bool = False
timeout: int | None = None
enable_gemini_json_mode: bool = False
enable_gemini_thinking: bool = False
enable_gemini_safe_mode: bool = False
enable_gemini_multimodal: bool = False
enable_gemini_grounding: bool = False
default_preserve_media_in_history: bool = False
tool_providers: list[ToolProvider] = field(default_factory=list)
def __post_init__(self):
"""初始化后从配置中读取默认值"""
ai_config = get_ai_config()
if self.model is None:
self.model = ai_config.get("default_model_name")
if self.timeout is None:
self.timeout = ai_config.get("timeout", 180)
class AI:
"""统一的AI服务类 - 平衡设计版本
提供三层API
1. 简单方法ai.chat(), ai.code(), ai.search()
2. 标准方法ai.analyze() 支持复杂参数
3. 高级方法通过get_model_instance()直接访问
"""
统一的AI服务类 - 提供了带记忆的会话接口
不再执行自主工具循环当LLM返回工具调用时会直接将请求返回给调用者
"""
def __init__(
self, config: AIConfig | None = None, history: list[LLMMessage] | None = None
self,
session_id: str | None = None,
config: AIConfig | None = None,
memory: BaseMemory | None = None,
default_generation_config: LLMGenerationConfig | None = None,
):
"""
初始化AI服务
参数:
session_id: 唯一的会话ID用于隔离记忆
config: AI 配置.
history: 可选的初始对话历史.
memory: 可选的自定义记忆后端如果为None则使用默认的InMemoryMemory
default_generation_config: (新增) 此AI实例的默认生成配置
"""
self.session_id = session_id or str(uuid.uuid4())
self.config = config or AIConfig()
self.history = history or []
self.memory = memory or InMemoryMemory()
self.default_generation_config = (
default_generation_config or LLMGenerationConfig()
)
def clear_history(self):
"""清空当前会话的历史记录"""
self.history = []
logger.info("AI session history cleared.")
global_providers = tool_provider_manager._providers
config_providers = self.config.tool_providers
self._tool_providers = list(dict.fromkeys(global_providers + config_providers))
async def clear_history(self):
"""清空当前会话的历史记录。"""
await self.memory.clear_history(self.session_id)
logger.info(f"AI会话历史记录已清空 (session_id: {self.session_id})")
async def add_user_message_to_history(
self, message: str | LLMMessage | list[LLMContentPart]
):
"""
将一条用户消息标准化并添加到会话历史中
参数:
message: 用户消息内容
"""
user_message = await self._normalize_input_to_message(message)
await self.memory.add_message(self.session_id, user_message)
async def add_assistant_response_to_history(self, response_text: str):
"""
将助手的文本回复添加到会话历史中
参数:
response_text: 助手的回复文本
"""
assistant_message = LLMMessage.assistant_text_response(response_text)
await self.memory.add_message(self.session_id, assistant_message)
def _sanitize_message_for_history(self, message: LLMMessage) -> LLMMessage:
"""
@ -121,83 +157,122 @@ class AI:
sanitized_message.content = new_content_parts
return sanitized_message
async def _normalize_input_to_message(
self, message: str | UniMessage | LLMMessage | list[LLMContentPart]
) -> LLMMessage:
"""
[重构后] 内部辅助方法将各种输入类型统一转换为单个 LLMMessage 对象
它调用共享的工具函数并提取最后一条消息通常是用户输入
"""
messages = await normalize_to_llm_messages(message)
if not messages:
raise LLMException(
"无法将输入标准化为有效的消息。", code=LLMErrorCode.CONFIGURATION_ERROR
)
return messages[-1]
async def chat(
self,
message: str | LLMMessage | list[LLMContentPart],
message: str | UniMessage | LLMMessage | list[LLMContentPart],
*,
model: ModelName = None,
instruction: str | None = None,
template_vars: dict[str, Any] | None = None,
preserve_media_in_history: bool | None = None,
tools: list[LLMTool] | None = None,
tools: list[dict[str, Any] | str] | dict[str, ToolExecutable] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any,
config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""
进行一次聊天对话支持工具调用
此方法会自动使用和更新会话内的历史记录
核心交互方法管理会话历史并执行单次LLM调用
参数:
message: 用户输入的消息
model: 本次对话要使用的模型
preserve_media_in_history: 是否在历史记录中保留原始多模态信息
- True: 保留用于深度多轮媒体分析
- False: 不保留替换为占位符提高效率
- None (默认): 使用AI实例配置的默认值
tools: 本次对话可用的工具列表
tool_choice: 强制模型使用的工具
**kwargs: 传递给模型的其他生成参数
message: 用户输入的消息内容支持文本UniMessageLLMMessage或
内容部分列表
model: 要使用的模型名称如果为None则使用配置中的默认模型
instruction: 本次调用的特定系统指令会与全局指令合并
template_vars: 模板变量字典用于在指令中进行变量替换
preserve_media_in_history: 是否在历史记录中保留媒体内容
None时使用默认配置
tools: 可用的工具列表或工具字典支持临时工具和预配置工具
tool_choice: 工具选择策略控制AI如何选择和使用工具
config: 生成配置对象用于覆盖默认的生成参数
返回:
LLMResponse: 模型的完整响应可能包含文本或工具调用请求
LLMResponse: 包含AI回复工具调用请求使用信息等的完整响应对象
"""
current_message: LLMMessage
if isinstance(message, str):
current_message = LLMMessage.user(message)
elif isinstance(message, list) and all(
isinstance(part, LLMContentPart) for part in message
):
current_message = LLMMessage.user(message)
elif isinstance(message, LLMMessage):
current_message = message
else:
raise LLMException(
f"AI.chat 不支持的消息类型: {type(message)}. "
"请使用 str, LLMMessage, 或 list[LLMContentPart]. "
"对于更复杂的多模态输入或文件路径,请使用 AI.analyze().",
code=LLMErrorCode.API_REQUEST_FAILED,
current_message = await self._normalize_input_to_message(message)
messages_for_run = []
final_instruction = instruction
if final_instruction and template_vars:
try:
template = jinja_env.from_string(final_instruction)
final_instruction = template.render(**template_vars)
logger.debug(f"渲染后的系统指令: {final_instruction}")
except Exception as e:
logger.error(f"渲染系统指令模板失败: {e}", e=e)
if final_instruction:
messages_for_run.append(LLMMessage.system(final_instruction))
current_history = await self.memory.get_history(self.session_id)
messages_for_run.extend(current_history)
messages_for_run.append(current_message)
try:
resolved_model_name = self._resolve_model_name(model or self.config.model)
final_config = model_copy(self.default_generation_config, deep=True)
if config:
update_dict = model_dump(config, exclude_unset=True)
final_config = model_copy(final_config, update=update_dict)
ad_hoc_tools = None
if tools:
if isinstance(tools, dict):
ad_hoc_tools = tools
else:
ad_hoc_tools = await self._resolve_tools(tools)
async with await get_model_instance(
resolved_model_name,
override_config=final_config.to_dict(),
) as model_instance:
response = await model_instance.generate_response(
messages_for_run, tools=ad_hoc_tools, tool_choice=tool_choice
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
user_msg_to_store = (
current_message
if should_preserve
else self._sanitize_message_for_history(current_message)
)
assistant_response_msg = LLMMessage.assistant_text_response(response.text)
if response.tool_calls:
assistant_response_msg = LLMMessage.assistant_tool_calls(
response.tool_calls, response.text
)
await self.memory.add_messages(
self.session_id, [user_msg_to_store, assistant_response_msg]
)
final_messages = [*self.history, current_message]
return response
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="聊天失败",
config_overrides=kwargs,
llm_tools=tools,
tool_choice=tool_choice,
)
should_preserve = (
preserve_media_in_history
if preserve_media_in_history is not None
else self.config.default_preserve_media_in_history
)
if should_preserve:
logger.debug("深度分析模式:在历史记录中保留原始多模态消息。")
self.history.append(current_message)
else:
logger.debug("高效模式:净化历史记录中的多模态消息。")
sanitized_user_message = self._sanitize_message_for_history(current_message)
self.history.append(sanitized_user_message)
self.history.append(
LLMMessage(
role="assistant", content=response.text, tool_calls=response.tool_calls
except Exception as e:
raise (
e
if isinstance(e, LLMException)
else LLMException(f"聊天执行失败: {e}", cause=e)
)
)
return response
async def code(
self,
@ -205,8 +280,8 @@ class AI:
*,
model: ModelName = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict[str, Any]:
config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""
代码执行
@ -214,217 +289,120 @@ class AI:
prompt: 代码执行的提示词
model: 要使用的模型名称
timeout: 代码执行超时时间
**kwargs: 传递给模型的其他参数
config: (可选) 覆盖默认的生成配置
返回:
dict[str, Any]: 包含执行结果的字典包含textcode_executions和success字段
LLMResponse: 包含执行结果的完整响应对象
"""
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_code_execution()
code_config = CommonOverrides.gemini_code_execution()
if timeout:
config.custom_params = config.custom_params or {}
config.custom_params["code_execution_timeout"] = timeout
code_config.custom_params = code_config.custom_params or {}
code_config.custom_params["code_execution_timeout"] = timeout
messages = [LLMMessage.user(prompt)]
if config:
update_dict = model_dump(config, exclude_unset=True)
code_config = model_copy(code_config, update=update_dict)
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="代码执行失败",
config_overrides=kwargs,
base_config=config,
)
return {
"text": response.text,
"code_executions": response.code_executions or [],
"success": True,
}
return await self.chat(prompt, model=resolved_model, config=code_config)
async def search(
self,
query: str | UniMessage,
query: UniMessage,
*,
model: ModelName = None,
instruction: str = "",
**kwargs: Any,
) -> dict[str, Any]:
instruction: str = (
"你是一位强大的信息检索和整合专家。请利用可用的搜索工具,"
"根据用户的查询找到最相关的信息,并进行总结和回答。"
),
template_vars: dict[str, Any] | None = None,
config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""
信息搜索 - 支持多模态输入
参数:
query: 搜索查询内容支持文本或多模态消息
model: 要使用的模型名称
instruction: 搜索指令
**kwargs: 传递给模型的其他参数
返回:
dict[str, Any]: 包含搜索结果的字典包含textsourcesqueries和success字段
信息搜索的便捷入口原生支持多模态查询
"""
from nonebot_plugin_alconna.uniseg import UniMessage
logger.info("执行 'search' 任务...")
search_config = CommonOverrides.gemini_grounding()
resolved_model = model or self.config.model or "Gemini/gemini-2.0-flash"
config = CommonOverrides.gemini_grounding()
if config:
update_dict = model_dump(config, exclude_unset=True)
search_config = model_copy(search_config, update=update_dict)
if isinstance(query, str):
messages = [LLMMessage.user(query)]
elif isinstance(query, UniMessage):
content_parts = await unimsg_to_llm_parts(query)
final_messages: list[LLMMessage] = []
if instruction:
final_messages.append(LLMMessage.system(instruction))
if not content_parts:
if instruction:
final_messages.append(LLMMessage.user(instruction))
else:
raise LLMException(
"搜索内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
messages = final_messages
else:
raise LLMException(
f"不支持的搜索输入类型: {type(query)}. 请使用 str 或 UniMessage.",
code=LLMErrorCode.API_REQUEST_FAILED,
)
response = await self._execute_generation(
messages=messages,
model_name=resolved_model,
error_message="信息搜索失败",
config_overrides=kwargs,
base_config=config,
return await self.chat(
query,
model=model,
instruction=instruction,
template_vars=template_vars,
config=search_config,
)
result = {
"text": response.text,
"sources": [],
"queries": [],
"success": True,
}
if response.grounding_metadata:
result["sources"] = response.grounding_metadata.grounding_attributions or []
result["queries"] = response.grounding_metadata.web_search_queries or []
return result
async def analyze(
async def generate_structured(
self,
message: UniMessage | None,
message: str | LLMMessage | list[LLMContentPart],
response_model: type[T],
*,
instruction: str = "",
model: ModelName = None,
use_tools: list[str] | None = None,
tool_config: dict[str, Any] | None = None,
activated_tools: list[LLMTool] | None = None,
history: list[LLMMessage] | None = None,
**kwargs: Any,
) -> LLMResponse:
instruction: str | None = None,
config: LLMGenerationConfig | None = None,
) -> T:
"""
内容分析 - 接收 UniMessage 物件进行多模态分析和工具呼叫
生成结构化响应并自动解析为指定的Pydantic模型
参数:
message: 要分析的消息内容支持多模态
instruction: 分析指令
model: 要使用的模型名称
use_tools: 要使用的工具名称列表
tool_config: 工具配置
activated_tools: 已激活的工具列表
history: 对话历史记录
**kwargs: 传递给模型的其他参数
message: 用户输入的消息内容支持多种格式
response_model: 用于解析和验证响应的Pydantic模型类
model: 要使用的模型名称如果为None则使用配置中的默认模型
instruction: 本次调用的特定系统指令会与JSON Schema指令合并
config: 生成配置对象用于覆盖默认的生成参数
返回:
LLMResponse: 模型的完整响应结果
T: 解析后的Pydantic模型实例类型为response_model指定的类型
异常:
LLMException: 如果模型返回的不是有效的JSON或验证失败
"""
from nonebot_plugin_alconna.uniseg import UniMessage
content_parts = await unimsg_to_llm_parts(message or UniMessage())
final_messages: list[LLMMessage] = []
if history:
final_messages.extend(history)
if instruction:
if not any(msg.role == "system" for msg in final_messages):
final_messages.insert(0, LLMMessage.system(instruction))
if not content_parts:
if instruction and not history:
final_messages.append(LLMMessage.user(instruction))
elif not history:
raise LLMException(
"分析内容为空或无法处理。", code=LLMErrorCode.API_REQUEST_FAILED
)
else:
final_messages.append(LLMMessage.user(content_parts))
llm_tools: list[LLMTool] | None = activated_tools
if not llm_tools and use_tools:
try:
llm_tools = tool_registry.get_tools(use_tools)
logger.debug(f"已从注册表加载工具定义: {use_tools}")
except ValueError as e:
raise LLMException(
f"加载工具定义失败: {e}",
code=LLMErrorCode.CONFIGURATION_ERROR,
cause=e,
)
tool_choice = None
if tool_config:
mode = tool_config.get("mode", "auto")
if mode in ["auto", "any", "none"]:
tool_choice = mode
response = await self._execute_generation(
messages=final_messages,
model_name=model,
error_message="内容分析失败",
config_overrides=kwargs,
llm_tools=llm_tools,
tool_choice=tool_choice,
)
return response
async def _execute_generation(
self,
messages: list[LLMMessage],
model_name: ModelName,
error_message: str,
config_overrides: dict[str, Any],
llm_tools: list[LLMTool] | None = None,
tool_choice: str | dict[str, Any] | None = None,
base_config: LLMGenerationConfig | None = None,
) -> LLMResponse:
"""通用的生成执行方法封装模型获取和单次API调用"""
try:
resolved_model_name = self._resolve_model_name(
model_name or self.config.model
)
final_config_dict = self._merge_config(
config_overrides, base_config=base_config
)
json_schema = model_json_schema(response_model)
except AttributeError:
json_schema = response_model.schema()
async with await get_model_instance(
resolved_model_name, override_config=final_config_dict
) as model_instance:
return await model_instance.generate_response(
messages,
tools=llm_tools,
tool_choice=tool_choice,
)
except LLMException:
raise
schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
system_prompt = (
(f"{instruction}\n\n" if instruction else "")
+ "你必须严格按照以下 JSON Schema 格式进行响应。"
+ "不要包含任何额外的解释、注释或代码块标记,只返回纯粹的 JSON 对象。\n\n"
)
system_prompt += f"JSON Schema:\n```json\n{schema_str}\n```"
final_config = model_copy(config) if config else LLMGenerationConfig()
final_config.response_format = ResponseFormat.JSON
final_config.response_schema = json_schema
response = await self.chat(
message, model=model, instruction=system_prompt, config=final_config
)
try:
return type_validate_json(response_model, response.text)
except ValidationError as e:
logger.error(f"LLM结构化输出验证失败: {e}", e=e)
raise LLMException(
"LLM返回的JSON未能通过结构验证。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": response.text, "validation_error": str(e)},
cause=e,
)
except Exception as e:
logger.error(f"{error_message}: {e}", e=e)
raise LLMException(f"{error_message}: {e}", cause=e)
logger.error(f"解析LLM结构化输出时发生未知错误: {e}", e=e)
raise LLMException(
"解析LLM的JSON输出时失败。",
code=LLMErrorCode.RESPONSE_PARSE_ERROR,
details={"raw_response": response.text},
cause=e,
)
def _resolve_model_name(self, model_name: ModelName) -> str:
"""解析模型名称"""
@ -440,45 +418,6 @@ class AI:
code=LLMErrorCode.MODEL_NOT_FOUND,
)
def _merge_config(
self,
user_config: dict[str, Any],
base_config: LLMGenerationConfig | None = None,
) -> dict[str, Any]:
"""合并配置"""
final_config = {}
if base_config:
final_config.update(base_config.to_dict())
if self.config.temperature is not None:
final_config["temperature"] = self.config.temperature
if self.config.max_tokens is not None:
final_config["max_tokens"] = self.config.max_tokens
if self.config.enable_cache:
final_config["enable_caching"] = True
if self.config.enable_code:
final_config["enable_code_execution"] = True
if self.config.enable_search:
final_config["enable_grounding"] = True
if self.config.enable_gemini_json_mode:
final_config["response_mime_type"] = "application/json"
if self.config.enable_gemini_thinking:
final_config["thinking_budget"] = 0.8
if self.config.enable_gemini_safe_mode:
final_config["safety_settings"] = (
CommonOverrides.gemini_safe().safety_settings
)
if self.config.enable_gemini_multimodal:
final_config.update(CommonOverrides.gemini_multimodal().to_dict())
if self.config.enable_gemini_grounding:
final_config["enable_grounding"] = True
final_config.update(user_config)
return final_config
async def embed(
self,
texts: list[str] | str,
@ -488,16 +427,19 @@ class AI:
**kwargs: Any,
) -> list[list[float]]:
"""
生成文本嵌入向量
生成文本嵌入向量将文本转换为数值向量表示
参数:
texts: 要生成嵌入向量的文本或文本列表
model: 要使用的嵌入模型名称
task_type: 嵌入任务类型
**kwargs: 传递给模型的其他参数
texts: 要生成嵌入的文本内容支持单个字符串或字符串列表
model: 嵌入模型名称如果为None则使用配置中的默认嵌入模型
task_type: 嵌入任务类型影响向量的优化方向如检索分类等
**kwargs: 传递给嵌入模型的额外参数
返回:
list[list[float]]: 文本的嵌入向量列表
list[list[float]]: 文本对应的嵌入向量列表每个向量为浮点数列表
异常:
LLMException: 如果嵌入生成失败或模型配置错误
"""
if isinstance(texts, str):
texts = [texts]
@ -530,3 +472,44 @@ class AI:
raise LLMException(
f"文本嵌入失败: {e}", code=LLMErrorCode.EMBEDDING_FAILED, cause=e
)
async def _resolve_tools(
self,
tool_configs: list[Any],
) -> dict[str, ToolExecutable]:
"""
使用注入的 ToolProvider 异步解析 ad-hoc临时工具配置
返回一个从工具名称到可执行对象的字典
"""
resolved: dict[str, ToolExecutable] = {}
for config in tool_configs:
name = config if isinstance(config, str) else config.get("name")
if not name:
raise LLMException(
"工具配置字典必须包含 'name' 字段。",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
if isinstance(config, str):
config_dict = {"name": name, "type": "function"}
elif isinstance(config, dict):
config_dict = config
else:
raise TypeError(f"不支持的工具配置类型: {type(config)}")
executable = None
for provider in self._tool_providers:
executable = await provider.get_tool_executable(name, config_dict)
if executable:
break
if not executable:
raise LLMException(
f"没有为 ad-hoc 工具 '{name}' 找到合适的提供者。",
code=LLMErrorCode.CONFIGURATION_ERROR,
)
resolved[name] = executable
return resolved

View File

@ -2,6 +2,12 @@
工具模块导出
"""
from .registry import tool_registry
from .manager import tool_provider_manager
__all__ = ["tool_registry"]
function_tool = tool_provider_manager.function_tool
__all__ = [
"function_tool",
"tool_provider_manager",
]

View File

@ -0,0 +1,293 @@
"""
工具提供者管理器
负责注册生命周期管理包括懒加载和统一提供所有工具
"""
import asyncio
from collections.abc import Callable
import inspect
from typing import Any
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.pydantic_compat import model_json_schema
from ..types import ToolExecutable, ToolProvider
from ..types.models import ToolDefinition, ToolResult
class FunctionExecutable(ToolExecutable):
"""一个 ToolExecutable 的实现,用于包装一个普通的 Python 函数。"""
def __init__(
self,
func: Callable,
name: str,
description: str,
params_model: type[BaseModel] | None,
):
self._func = func
self._name = name
self._description = description
self._params_model = params_model
async def get_definition(self) -> ToolDefinition:
if not self._params_model:
return ToolDefinition(
name=self._name,
description=self._description,
parameters={"type": "object", "properties": {}},
)
schema = model_json_schema(self._params_model)
return ToolDefinition(
name=self._name,
description=self._description,
parameters={
"type": "object",
"properties": schema.get("properties", {}),
"required": schema.get("required", []),
},
)
async def execute(self, **kwargs: Any) -> ToolResult:
raw_result: Any
if self._params_model:
try:
params_instance = self._params_model(**kwargs)
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(params_instance)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(params_instance)
)
except Exception as e:
logger.error(
f"执行工具 '{self._name}' 时参数验证或实例化失败: {e}", e=e
)
raise
else:
if inspect.iscoroutinefunction(self._func):
raw_result = await self._func(**kwargs)
else:
loop = asyncio.get_event_loop()
raw_result = await loop.run_in_executor(
None, lambda: self._func(**kwargs)
)
return ToolResult(output=raw_result, display_content=str(raw_result))
class BuiltinFunctionToolProvider(ToolProvider):
"""一个内置的 ToolProvider用于处理通过装饰器注册的函数。"""
def __init__(self):
self._functions: dict[str, dict[str, Any]] = {}
def register(
self,
name: str,
func: Callable,
description: str,
params_model: type[BaseModel] | None,
):
self._functions[name] = {
"func": func,
"description": description,
"params_model": params_model,
}
async def initialize(self) -> None:
pass
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
executables = {}
for name, info in self._functions.items():
executables[name] = FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return executables
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
if config.get("type") == "function" and name in self._functions:
info = self._functions[name]
return FunctionExecutable(
func=info["func"],
name=name,
description=info["description"],
params_model=info["params_model"],
)
return None
class ToolProviderManager:
"""工具提供者的中心化管理器,采用单例模式。"""
_instance: "ToolProviderManager | None" = None
def __new__(cls) -> "ToolProviderManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, "_initialized") and self._initialized:
return
self._providers: list[ToolProvider] = []
self._resolved_tools: dict[str, ToolExecutable] | None = None
self._init_lock = asyncio.Lock()
self._init_promise: asyncio.Task | None = None
self._builtin_function_provider = BuiltinFunctionToolProvider()
self.register(self._builtin_function_provider)
self._initialized = True
def register(self, provider: ToolProvider):
"""注册一个新的 ToolProvider。"""
if provider not in self._providers:
self._providers.append(provider)
logger.info(f"已注册工具提供者: {provider.__class__.__name__}")
def function_tool(
self,
name: str,
description: str,
params_model: type[BaseModel] | None = None,
):
"""装饰器:将一个函数注册为内置工具。"""
def decorator(func: Callable):
if name in self._builtin_function_provider._functions:
logger.warning(f"正在覆盖已注册的函数工具: {name}")
self._builtin_function_provider.register(
name=name,
func=func,
description=description,
params_model=params_model,
)
logger.info(f"已注册函数工具: '{name}'")
return func
return decorator
async def initialize(self) -> None:
"""懒加载初始化所有已注册的 ToolProvider。"""
if not self._init_promise:
async with self._init_lock:
if not self._init_promise:
self._init_promise = asyncio.create_task(
self._initialize_providers()
)
await self._init_promise
async def _initialize_providers(self) -> None:
"""内部初始化逻辑。"""
logger.info(f"开始初始化 {len(self._providers)} 个工具提供者...")
init_tasks = [provider.initialize() for provider in self._providers]
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.info("所有工具提供者初始化完成。")
async def get_resolved_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
获取所有已发现和解析的工具
此方法会触发懒加载初始化并根据是否传入过滤器来决定是否使用全局缓存
"""
await self.initialize()
has_filters = allowed_servers is not None or excluded_servers is not None
if not has_filters and self._resolved_tools is not None:
logger.debug("使用全局工具缓存。")
return self._resolved_tools
if has_filters:
logger.info("检测到过滤器,执行临时工具发现 (不使用缓存)。")
logger.debug(
f"过滤器详情: allowed_servers={allowed_servers}, "
f"excluded_servers={excluded_servers}"
)
else:
logger.info("未应用过滤器,开始全局工具发现...")
all_tools: dict[str, ToolExecutable] = {}
discover_tasks = []
for provider in self._providers:
sig = inspect.signature(provider.discover_tools)
params_to_pass = {}
if "allowed_servers" in sig.parameters:
params_to_pass["allowed_servers"] = allowed_servers
if "excluded_servers" in sig.parameters:
params_to_pass["excluded_servers"] = excluded_servers
discover_tasks.append(provider.discover_tools(**params_to_pass))
results = await asyncio.gather(*discover_tasks, return_exceptions=True)
for i, provider_result in enumerate(results):
provider_name = self._providers[i].__class__.__name__
if isinstance(provider_result, dict):
logger.debug(
f"提供者 '{provider_name}' 发现了 {len(provider_result)} 个工具。"
)
for name, executable in provider_result.items():
if name in all_tools:
logger.warning(
f"发现重复的工具名称 '{name}',后发现的将覆盖前者。"
)
all_tools[name] = executable
elif isinstance(provider_result, Exception):
logger.error(
f"提供者 '{provider_name}' 在发现工具时出错: {provider_result}"
)
if not has_filters:
self._resolved_tools = all_tools
logger.info(f"全局工具发现完成,共找到并缓存了 {len(all_tools)} 个工具。")
else:
logger.info(f"带过滤器的工具发现完成,共找到 {len(all_tools)} 个工具。")
return all_tools
async def get_function_tools(
self, names: list[str] | None = None
) -> dict[str, ToolExecutable]:
"""
仅从内置的函数提供者中解析指定的工具
"""
all_function_tools = await self._builtin_function_provider.discover_tools()
if names is None:
return all_function_tools
resolved_tools = {}
for name in names:
if name in all_function_tools:
resolved_tools[name] = all_function_tools[name]
else:
logger.warning(
f"本地函数工具 '{name}' 未通过 @function_tool 注册,将被忽略。"
)
return resolved_tools
tool_provider_manager = ToolProviderManager()

View File

@ -1,181 +0,0 @@
"""
工具注册表
负责加载管理和实例化来自配置的工具
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import TYPE_CHECKING
from pydantic import BaseModel
from zhenxun.services.log import logger
from ..types import LLMTool
if TYPE_CHECKING:
from ..config.providers import ToolConfig
from ..types.protocols import MCPCompatible
class ToolRegistry:
"""工具注册表,用于管理和实例化配置的工具。"""
def __init__(self):
self._function_tools: dict[str, LLMTool] = {}
self._mcp_config_models: dict[str, type[BaseModel]] = {}
if TYPE_CHECKING:
self._mcp_factories: dict[
str, Callable[..., AbstractAsyncContextManager["MCPCompatible"]]
] = {}
else:
self._mcp_factories: dict[str, Callable] = {}
self._tool_configs: dict[str, "ToolConfig"] | None = None
self._tool_cache: dict[str, "LLMTool"] = {}
def _load_configs_if_needed(self):
"""如果尚未加载则从主配置中加载MCP工具定义。"""
if self._tool_configs is None:
logger.debug("首次访问正在加载MCP工具配置...")
from ..config.providers import get_llm_config
llm_config = get_llm_config()
self._tool_configs = {tool.name: tool for tool in llm_config.mcp_tools}
logger.info(f"已加载 {len(self._tool_configs)} 个MCP工具配置。")
def function_tool(
self,
name: str,
description: str,
parameters: dict,
required: list[str] | None = None,
):
"""
装饰器在代码中注册一个简单的无状态的函数工具
参数:
name: 工具的唯一名称
description: 工具功能的描述
parameters: OpenAPI格式的函数参数schema的properties部分
required: 必需的参数列表
"""
def decorator(func: Callable):
if name in self._function_tools or name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的工具: {name}")
tool_definition = LLMTool.create(
name=name,
description=description,
parameters=parameters,
required=required,
)
self._function_tools[name] = tool_definition
logger.info(f"已在代码中注册函数工具: '{name}'")
tool_definition.annotations = tool_definition.annotations or {}
tool_definition.annotations["executable"] = func
return func
return decorator
def mcp_tool(self, name: str, config_model: type[BaseModel]):
"""
装饰器注册一个MCP工具及其配置模型
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
config_model: 一个Pydantic模型用于定义和验证该工具的 `mcp_config`
"""
def decorator(factory_func: Callable):
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory_func
self._mcp_config_models[name] = config_model
logger.info(f"已注册 MCP 工具 '{name}' (配置模型: {config_model.__name__})")
return factory_func
return decorator
def get_mcp_config_model(self, name: str) -> type[BaseModel] | None:
"""根据名称获取MCP工具的配置模型。"""
return self._mcp_config_models.get(name)
def register_mcp_factory(
self,
name: str,
factory: Callable,
):
"""
在代码中注册一个 MCP 会话工厂将其与配置中的工具名称关联
参数:
name: 工具的唯一名称必须与配置文件中的名称匹配
factory: 一个返回异步生成器的可调用对象会话工厂
"""
if name in self._mcp_factories:
logger.warning(f"正在覆盖已注册的 MCP 工厂: {name}")
self._mcp_factories[name] = factory
logger.info(f"已注册 MCP 会话工厂: '{name}'")
def get_tool(self, name: str) -> "LLMTool":
"""
根据名称获取一个 LLMTool 定义
对于MCP工具返回的 LLMTool 实例包含一个可调用的会话工厂
而不是一个已激活的会话
"""
logger.debug(f"🔍 请求获取工具定义: {name}")
if name in self._tool_cache:
logger.debug(f"✅ 从缓存中获取工具定义: {name}")
return self._tool_cache[name]
if name in self._function_tools:
logger.debug(f"🛠️ 获取函数工具定义: {name}")
tool = self._function_tools[name]
self._tool_cache[name] = tool
return tool
self._load_configs_if_needed()
if self._tool_configs is None or name not in self._tool_configs:
known_tools = list(self._function_tools.keys()) + (
list(self._tool_configs.keys()) if self._tool_configs else []
)
logger.error(f"❌ 未找到名为 '{name}' 的工具定义")
logger.debug(f"📋 可用工具定义列表: {known_tools}")
raise ValueError(f"未找到名为 '{name}' 的工具定义。已知工具: {known_tools}")
config = self._tool_configs[name]
tool: "LLMTool"
if name not in self._mcp_factories:
logger.error(f"❌ MCP工具 '{name}' 缺少工厂函数")
available_factories = list(self._mcp_factories.keys())
logger.debug(f"📋 已注册的MCP工厂: {available_factories}")
raise ValueError(
f"MCP 工具 '{name}' 已在配置中定义,但没有注册对应的工厂函数。"
"请使用 `@tool_registry.mcp_tool` 装饰器进行注册。"
)
logger.info(f"🔧 创建MCP工具定义: {name}")
factory = self._mcp_factories[name]
typed_mcp_config = config.mcp_config
logger.debug(f"📋 MCP工具配置: {typed_mcp_config}")
configured_factory = partial(factory, config=typed_mcp_config)
tool = LLMTool.from_mcp_session(session=configured_factory)
self._tool_cache[name] = tool
logger.debug(f"💾 MCP工具定义已缓存: {name}")
return tool
def get_tools(self, names: list[str]) -> list["LLMTool"]:
"""根据名称列表获取多个 LLMTool 实例。"""
return [self.get_tool(name) for name in names]
tool_registry = ToolRegistry()

View File

@ -23,7 +23,6 @@ from .models import (
LLMCodeExecution,
LLMGroundingAttribution,
LLMGroundingMetadata,
LLMTool,
LLMToolCall,
LLMToolFunction,
ModelDetail,
@ -31,9 +30,10 @@ from .models import (
ModelName,
ProviderConfig,
ToolMetadata,
ToolResult,
UsageInfo,
)
from .protocols import MCPCompatible
from .protocols import ToolExecutable, ToolProvider
__all__ = [
"EmbeddingTaskType",
@ -46,10 +46,8 @@ __all__ = [
"LLMGroundingMetadata",
"LLMMessage",
"LLMResponse",
"LLMTool",
"LLMToolCall",
"LLMToolFunction",
"MCPCompatible",
"ModelCapabilities",
"ModelDetail",
"ModelInfo",
@ -60,7 +58,10 @@ __all__ = [
"ResponseFormat",
"TaskType",
"ToolCategory",
"ToolExecutable",
"ToolMetadata",
"ToolProvider",
"ToolResult",
"UsageInfo",
"get_model_capabilities",
"get_user_friendly_error_message",

View File

@ -405,7 +405,7 @@ class LLMMessage(BaseModel):
f"工具 '{function_name}' 的结果无法JSON序列化: {result}. 错误: {e}"
)
content_str = json.dumps(
{"error": "Tool result not JSON serializable", "details": str(e)}
{"error": "工具结果无法JSON序列化", "details": str(e)}
)
return cls(

View File

@ -4,28 +4,39 @@ LLM 数据模型定义
包含模型信息配置工具定义和响应数据的模型类
"""
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import Any
from pydantic import BaseModel, Field
from .enums import ModelProvider, ToolCategory
if TYPE_CHECKING:
from .protocols import MCPCompatible
MCPSessionType = (
MCPCompatible | Callable[[], AbstractAsyncContextManager[MCPCompatible]] | None
)
else:
MCPCompatible = object
MCPSessionType = Any
ModelName = str | None
class ToolDefinition(BaseModel):
"""
一个结构化的工具定义模型用于向LLM描述工具
"""
name: str = Field(..., description="工具的唯一名称标识")
description: str = Field(..., description="工具功能的清晰描述")
parameters: dict[str, Any] = Field(
default_factory=dict, description="符合JSON Schema规范的参数定义"
)
class ToolResult(BaseModel):
"""
一个结构化的工具执行结果模型
"""
output: Any = Field(..., description="返回给LLM的、可JSON序列化的原始输出")
display_content: str | None = Field(
default=None, description="用于日志或UI展示的人类可读的执行摘要"
)
@dataclass(frozen=True)
class ModelInfo:
"""模型信息(不可变数据类)"""
@ -107,55 +118,6 @@ class LLMToolCall(BaseModel):
function: LLMToolFunction
class LLMTool(BaseModel):
"""LLM 工具定义(支持 MCP 风格)"""
model_config = {"arbitrary_types_allowed": True}
type: str = "function"
function: dict[str, Any] | None = None
mcp_session: MCPSessionType = None
annotations: dict[str, Any] | None = Field(default=None, description="工具注解")
def model_post_init(self, /, __context: Any) -> None:
"""验证工具定义的有效性"""
_ = __context
if self.type == "function" and self.function is None:
raise ValueError("函数类型的工具必须包含 'function' 字段。")
if self.type == "mcp" and self.mcp_session is None:
raise ValueError("MCP 类型的工具必须包含 'mcp_session' 字段。")
@classmethod
def create(
cls,
name: str,
description: str,
parameters: dict[str, Any],
required: list[str] | None = None,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""创建函数工具"""
function_def = {
"name": name,
"description": description,
"parameters": {
"type": "object",
"properties": parameters,
"required": required or [],
},
}
return cls(type="function", function=function_def, annotations=annotations)
@classmethod
def from_mcp_session(
cls,
session: Any,
annotations: dict[str, Any] | None = None,
) -> "LLMTool":
"""从 MCP 会话创建工具"""
return cls(type="mcp", mcp_session=session, annotations=annotations)
class LLMCodeExecution(BaseModel):
"""代码执行结果"""

View File

@ -4,21 +4,62 @@ LLM 模块的协议定义
from typing import Any, Protocol
from .models import ToolDefinition, ToolResult
class MCPCompatible(Protocol):
class ToolExecutable(Protocol):
"""
一个协议定义了与LLM模块兼容的MCP会话对象应具备的行为
任何实现了 to_api_tool 方法的对象都可以被认为是 MCPCompatible
一个协议定义了所有可被LLM调用的工具必须实现的行为
它将工具的"定义"给LLM看"执行"由框架调用封装在一起
"""
def to_api_tool(self, api_type: str) -> dict[str, Any]:
async def get_definition(self) -> ToolDefinition:
"""
将此MCP会话转换为特定LLM提供商API所需的工具格式
参数:
api_type: 目标API的类型 (例如 'gemini', 'openai')
返回:
dict[str, Any]: 一个字典代表可以在API请求中使用的工具定义
异步地获取一个结构化的工具定义
"""
...
async def execute(self, **kwargs: Any) -> ToolResult:
"""
异步执行工具并返回一个结构化的结果
参数由LLM根据工具定义生成
"""
...
class ToolProvider(Protocol):
"""
一个协议定义了"工具提供者"的行为
工具提供者负责发现或实例化具体的 ToolExecutable 对象
"""
async def initialize(self) -> None:
"""
异步初始化提供者
此方法应是幂等的即多次调用只会执行一次初始化逻辑
用于执行耗时的I/O操作如网络请求或启动子进程
"""
...
async def discover_tools(
self,
allowed_servers: list[str] | None = None,
excluded_servers: list[str] | None = None,
) -> dict[str, ToolExecutable]:
"""
异步发现此提供者提供的所有工具
`initialize` 成功调用后才应被调用
返回:
一个从工具名称到 ToolExecutable 实例的字典
"""
...
async def get_tool_executable(
self, name: str, config: dict[str, Any]
) -> ToolExecutable | None:
"""
保留如果此提供者能处理名为 'name' 的工具则返回一个可执行实例
此方法主要用于按需解析 ad-hoc 工具
"""
...

View File

@ -5,6 +5,7 @@ LLM 模块的工具和转换函数
import base64
import copy
from pathlib import Path
from typing import Any
from nonebot.adapters import Message as PlatformMessage
from nonebot_plugin_alconna.uniseg import (
@ -21,7 +22,7 @@ from nonebot_plugin_alconna.uniseg import (
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
from .types import LLMContentPart
from .types import LLMContentPart, LLMMessage
async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
@ -112,9 +113,9 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
elif isinstance(seg, At):
if seg.flag == "all":
part = LLMContentPart.text_part("[Mentioned Everyone]")
part = LLMContentPart.text_part("[提及所有人]")
else:
part = LLMContentPart.text_part(f"[Mentioned user: {seg.target}]")
part = LLMContentPart.text_part(f"[提及用户: {seg.target}]")
elif isinstance(seg, Reply):
if seg.msg:
@ -126,10 +127,10 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
reply_text = str(seg.msg).strip()
if reply_text:
part = LLMContentPart.text_part(
f'[Replied to: "{reply_text[:50]}..."]'
f'[回复消息: "{reply_text[:50]}..."]'
)
except Exception:
part = LLMContentPart.text_part("[Replied to a message]")
part = LLMContentPart.text_part("[回复了一条消息]")
if part:
parts.append(part)
@ -137,6 +138,42 @@ async def unimsg_to_llm_parts(message: UniMessage) -> list[LLMContentPart]:
return parts
async def normalize_to_llm_messages(
message: str | UniMessage | LLMMessage | list[LLMContentPart] | list[LLMMessage],
instruction: str | None = None,
) -> list[LLMMessage]:
"""
将多种输入格式标准化为 LLMMessage 列表并可选地添加系统指令
这是处理 LLM 输入的核心工具函数
参数:
message: 要标准化的输入消息
instruction: 可选的系统指令
返回:
list[LLMMessage]: 标准化后的消息列表
"""
messages = []
if instruction:
messages.append(LLMMessage.system(instruction))
if isinstance(message, LLMMessage):
messages.append(message)
elif isinstance(message, list) and all(isinstance(m, LLMMessage) for m in message):
messages.extend(message)
elif isinstance(message, str):
messages.append(LLMMessage.user(message))
elif isinstance(message, UniMessage):
content_parts = await unimsg_to_llm_parts(message)
messages.append(LLMMessage.user(content_parts))
elif isinstance(message, list):
messages.append(LLMMessage.user(message)) # type: ignore
else:
raise TypeError(f"不支持的消息类型: {type(message)}")
return messages
def create_multimodal_message(
text: str | None = None,
images: list[str | Path | bytes] | str | Path | bytes | None = None,
@ -282,3 +319,37 @@ def _sanitize_request_body_for_logging(body: dict) -> dict:
except Exception as e:
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
return body
def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
"""
递归地净化 JSON Schema移除特定 LLM API 不支持的关键字
参数:
schema: 要净化的 JSON Schema (可以是字典列表或其它类型)
api_type: 目标 API 的类型例如 'gemini'
返回:
Any: 净化后的 JSON Schema
"""
if isinstance(schema, dict):
schema_copy = {}
for key, value in schema.items():
if api_type == "gemini":
unsupported_keys = ["exclusiveMinimum", "exclusiveMaximum", "default"]
if key in unsupported_keys:
continue
if key == "format" and isinstance(value, str):
supported_formats = ["enum", "date-time"]
if value not in supported_formats:
continue
schema_copy[key] = sanitize_schema_for_llm(value, api_type)
return schema_copy
elif isinstance(schema, list):
return [sanitize_schema_for_llm(item, api_type) for item in schema]
else:
return schema

View File

@ -57,6 +57,7 @@ async def get_fastest_release_formats() -> list[str]:
async def get_fastest_release_source_formats() -> list[str]:
"""获取最快的发行版源码下载地址格式"""
formats: dict[str, str] = {
"https://github.bibk.top": "https://github.bibk.top/{owner}/{repo}/releases/download/{version}/{filename}",
"https://codeload.github.com/": RELEASE_SOURCE_FORMAT,
"https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}",
}

View File

@ -317,6 +317,20 @@ class AliyunFileInfo:
repository_id: str
"""仓库ID"""
@classmethod
async def get_client(cls) -> devops20210625Client:
"""获取阿里云客户端"""
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
return devops20210625Client(config)
@classmethod
async def get_file_content(
cls, file_path: str, repo: str, ref: str = "main"
@ -335,16 +349,8 @@ class AliyunFileInfo:
repository_id = ALIYUN_REPO_MAPPING.get(repo)
if not repository_id:
raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID")
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
client = devops20210625Client(config)
client = await cls.get_client()
request = devops_20210625_models.GetFileBlobsRequest(
organization_id=ALIYUN_ORG_ID,
@ -404,16 +410,7 @@ class AliyunFileInfo:
if not repository_id:
raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID")
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
client = devops20210625Client(config)
client = await cls.get_client()
request = devops_20210625_models.ListRepositoryTreeRequest(
organization_id=ALIYUN_ORG_ID,
@ -459,16 +456,7 @@ class AliyunFileInfo:
if not repository_id:
raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID")
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
client = devops20210625Client(config)
client = await cls.get_client()
request = devops_20210625_models.GetRepositoryCommitRequest(
organization_id=ALIYUN_ORG_ID,

View File

@ -1,87 +0,0 @@
import os
from pathlib import Path
import shutil
import zipfile
from zhenxun.configs.path_config import FONT_PATH
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.http_utils import AsyncHttpx
CMD_STRING = "ResourceManager"
class DownloadResourceException(Exception):
pass
class ResourceManager:
GITHUB_URL = "https://github.com/zhenxun-org/zhenxun-bot-resources/tree/main"
RESOURCE_PATH = Path() / "resources"
TMP_PATH = Path() / "_resource_tmp"
ZIP_FILE = TMP_PATH / "resources.zip"
UNZIP_PATH = None
@classmethod
async def init_resources(cls, force: bool = False):
if (FONT_PATH.exists() and os.listdir(FONT_PATH)) and not force:
return
if cls.TMP_PATH.exists():
logger.debug(
"resources临时文件夹已存在移除resources临时文件夹", CMD_STRING
)
shutil.rmtree(cls.TMP_PATH)
cls.TMP_PATH.mkdir(parents=True, exist_ok=True)
try:
await cls.__download_resources()
cls.file_handle()
except Exception as e:
logger.error("获取resources资源包失败", CMD_STRING, e=e)
if cls.TMP_PATH.exists():
logger.debug("移除resources临时文件夹", CMD_STRING)
shutil.rmtree(cls.TMP_PATH)
@classmethod
def file_handle(cls):
if not cls.UNZIP_PATH:
return
cls.__recursive_folder(cls.UNZIP_PATH, "resources")
@classmethod
def __recursive_folder(cls, dir: Path, parent_path: str):
for file in dir.iterdir():
if file.is_dir():
cls.__recursive_folder(file, f"{parent_path}/{file.name}")
else:
res_file = Path(parent_path) / file.name
if res_file.exists():
res_file.unlink()
res_file.parent.mkdir(parents=True, exist_ok=True)
file.rename(res_file)
@classmethod
async def __download_resources(cls):
"""获取resources文件夹"""
repo_info = GithubUtils.parse_github_url(cls.GITHUB_URL)
url = await repo_info.get_archive_download_urls()
logger.debug("开始下载resources资源包...", CMD_STRING)
if not await AsyncHttpx.download_file(url, cls.ZIP_FILE, stream=True):
logger.error(
"下载resources资源包失败请尝试重启重新下载或前往 "
"https://github.com/zhenxun-org/zhenxun-bot-resources 手动下载..."
)
raise DownloadResourceException("下载resources资源包失败...")
logger.debug("下载resources资源文件压缩包完成...", CMD_STRING)
tf = zipfile.ZipFile(cls.ZIP_FILE)
tf.extractall(cls.TMP_PATH)
logger.debug("解压文件压缩包完成...", CMD_STRING)
download_file_path = cls.TMP_PATH / next(
x for x in os.listdir(cls.TMP_PATH) if (cls.TMP_PATH / x).is_dir()
)
cls.UNZIP_PATH = download_file_path / "resources"
if tf:
tf.close()

View File

@ -1,3 +1,4 @@
import asyncio
from pathlib import Path
import subprocess
from subprocess import CalledProcessError
@ -36,7 +37,7 @@ class VirtualEnvPackageManager:
)
@classmethod
def install(cls, package: list[str] | str):
async def install(cls, package: list[str] | str):
"""安装依赖包
参数:
@ -49,7 +50,8 @@ class VirtualEnvPackageManager:
command.append("install")
command.append(" ".join(package))
logger.info(f"执行虚拟环境安装包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -65,7 +67,7 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def uninstall(cls, package: list[str] | str):
async def uninstall(cls, package: list[str] | str):
"""卸载依赖包
参数:
@ -79,7 +81,8 @@ class VirtualEnvPackageManager:
command.append("-y")
command.append(" ".join(package))
logger.info(f"执行虚拟环境卸载包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -95,7 +98,7 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def update(cls, package: list[str] | str):
async def update(cls, package: list[str] | str):
"""更新依赖包
参数:
@ -109,7 +112,8 @@ class VirtualEnvPackageManager:
command.append("--upgrade")
command.append(" ".join(package))
logger.info(f"执行虚拟环境更新包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -122,7 +126,7 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def install_requirement(cls, requirement_file: Path):
async def install_requirement(cls, requirement_file: Path):
"""安装依赖文件
参数:
@ -139,7 +143,8 @@ class VirtualEnvPackageManager:
command.append("-r")
command.append(str(requirement_file.absolute()))
logger.info(f"执行虚拟环境安装依赖文件指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -158,13 +163,14 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def list(cls) -> str:
async def list(cls) -> str:
"""列出已安装的依赖包"""
try:
command = cls.__get_command()
command.append("list")
logger.info(f"执行虚拟环境列出包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,

View File

@ -0,0 +1,556 @@
"""
真寻仓库管理器
负责真寻主仓库的更新版本检查文件处理等功能
"""
import os
from pathlib import Path
import shutil
from typing import ClassVar, Literal
import zipfile
import aiofiles
from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from zhenxun.utils.repo_utils import AliyunRepoManager, GithubRepoManager
from zhenxun.utils.repo_utils.models import RepoUpdateResult
from zhenxun.utils.repo_utils.utils import check_git
LOG_COMMAND = "ZhenxunRepoManager"
class ZhenxunUpdateException(Exception):
"""资源下载异常"""
pass
class ZhenxunRepoConfig:
"""真寻仓库配置"""
# Zhenxun Bot 相关配置
ZHENXUN_BOT_GIT = "https://github.com/zhenxun-org/zhenxun_bot.git"
ZHENXUN_BOT_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot/tree/main"
ZHENXUN_BOT_DOWNLOAD_FILE_STRING = "zhenxun_bot.zip"
ZHENXUN_BOT_DOWNLOAD_FILE = TEMP_PATH / ZHENXUN_BOT_DOWNLOAD_FILE_STRING
ZHENXUN_BOT_UNZIP_PATH = TEMP_PATH / "zhenxun_bot"
ZHENXUN_BOT_CODE_PATH = Path() / "zhenxun"
ZHENXUN_BOT_RELEASES_API_URL = (
"https://api.github.com/repos/HibiKier/zhenxun_bot/releases/latest"
)
ZHENXUN_BOT_BACKUP_PATH = Path() / "backup"
# 需要替换的文件夹
ZHENXUN_BOT_UPDATE_FOLDERS: ClassVar[list[str]] = [
"zhenxun/builtin_plugins",
"zhenxun/services",
"zhenxun/utils",
"zhenxun/models",
"zhenxun/configs",
]
ZHENXUN_BOT_VERSION_FILE_STRING = "__version__"
ZHENXUN_BOT_VERSION_FILE = Path() / ZHENXUN_BOT_VERSION_FILE_STRING
# 备份杂项
BACKUP_FILES: ClassVar[list[str]] = [
"pyproject.toml",
"poetry.lock",
"requirements.txt",
".env.dev",
".env.example",
]
# WEB UI 相关配置
WEBUI_GIT = "https://github.com/HibiKier/zhenxun_bot_webui.git"
WEBUI_DIST_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot_webui/tree/dist"
WEBUI_DOWNLOAD_FILE_STRING = "webui_assets.zip"
WEBUI_DOWNLOAD_FILE = TEMP_PATH / WEBUI_DOWNLOAD_FILE_STRING
WEBUI_UNZIP_PATH = TEMP_PATH / "web_ui"
WEBUI_PATH = DATA_PATH / "web_ui" / "public"
WEBUI_BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public"
# 资源管理相关配置
RESOURCE_GIT = "https://github.com/zhenxun-org/zhenxun-bot-resources.git"
RESOURCE_GITHUB_URL = (
"https://github.com/zhenxun-org/zhenxun-bot-resources/tree/main"
)
RESOURCE_ZIP_FILE_STRING = "resources.zip"
RESOURCE_ZIP_FILE = TEMP_PATH / RESOURCE_ZIP_FILE_STRING
RESOURCE_UNZIP_PATH = TEMP_PATH / "resources"
RESOURCE_PATH = Path() / "resources"
REQUIREMENTS_FILE_STRING = "requirements.txt"
REQUIREMENTS_FILE = Path() / REQUIREMENTS_FILE_STRING
PYPROJECT_FILE_STRING = "pyproject.toml"
PYPROJECT_FILE = Path() / PYPROJECT_FILE_STRING
PYPROJECT_LOCK_FILE_STRING = "poetry.lock"
PYPROJECT_LOCK_FILE = Path() / PYPROJECT_LOCK_FILE_STRING
class ZhenxunRepoManagerClass:
"""真寻仓库管理器"""
def __init__(self):
self.config = ZhenxunRepoConfig()
def __clear_folder(self, folder_path: Path):
"""
清空文件夹
参数:
folder_path: 文件夹路径
"""
if not folder_path.exists():
return
for filename in os.listdir(folder_path):
file_path = folder_path / filename
try:
if file_path.is_file():
os.unlink(file_path)
elif file_path.is_dir() and not filename.startswith("."):
shutil.rmtree(file_path)
except Exception as e:
logger.warning(f"无法删除 {file_path}", LOG_COMMAND, e=e)
def __copy_files(self, src_path: Path, dest_path: Path, incremental: bool = False):
"""
复制文件或文件夹
参数:
src_path: 源文件或文件夹路径
dest_path: 目标文件或文件夹路径
incremental: 是否增量复制
"""
if src_path.is_file():
shutil.copy(src_path, dest_path)
logger.debug(f"复制文件 {src_path} -> {dest_path}", LOG_COMMAND)
elif src_path.is_dir():
for filename in os.listdir(src_path):
file_path = src_path / filename
dest_file = dest_path / filename
dest_file.parent.mkdir(exist_ok=True, parents=True)
if file_path.is_file():
if dest_file.exists():
dest_file.unlink()
shutil.copy(file_path, dest_file)
logger.debug(f"复制文件 {file_path} -> {dest_file}", LOG_COMMAND)
elif file_path.is_dir():
if incremental:
self.__copy_files(file_path, dest_file, incremental=True)
else:
if dest_file.exists():
shutil.rmtree(dest_file, True)
shutil.copytree(file_path, dest_file)
logger.debug(
f"复制文件夹 {file_path} -> {dest_file}",
LOG_COMMAND,
)
# ==================== Zhenxun Bot 相关方法 ====================
async def zhenxun_get_version_from_repo(self) -> str:
"""从指定分支获取版本号
返回:
str: 版本号
"""
repo_info = GithubUtils.parse_github_url(self.config.ZHENXUN_BOT_GITHUB_URL)
version_url = await repo_info.get_raw_download_urls(
path=self.config.ZHENXUN_BOT_VERSION_FILE_STRING
)
try:
res = await AsyncHttpx.get(version_url)
if res.status_code == 200:
return res.text.strip()
except Exception as e:
logger.error(f"获取 {repo_info.branch} 分支版本失败", LOG_COMMAND, e=e)
return "未知版本"
async def zhenxun_write_version_file(self, version: str):
"""写入版本文件"""
async with aiofiles.open(
self.config.ZHENXUN_BOT_VERSION_FILE, "w", encoding="utf8"
) as f:
await f.write(f"__version__: {version}")
def __backup_zhenxun(self):
"""备份真寻文件"""
for filename in os.listdir(self.config.ZHENXUN_BOT_CODE_PATH):
file_path = self.config.ZHENXUN_BOT_CODE_PATH / filename
if file_path.exists():
self.__copy_files(
file_path,
self.config.ZHENXUN_BOT_BACKUP_PATH / filename,
True,
)
for filename in self.config.BACKUP_FILES:
file_path = Path() / filename
if file_path.exists():
self.__copy_files(
file_path,
self.config.ZHENXUN_BOT_BACKUP_PATH / filename,
)
async def zhenxun_get_latest_releases_data(self) -> dict:
"""获取真寻releases最新版本信息
返回:
dict: 最新版本数据
"""
try:
res = await AsyncHttpx.get(self.config.ZHENXUN_BOT_RELEASES_API_URL)
if res.status_code == 200:
return res.json()
except Exception as e:
logger.error("检查更新真寻获取版本失败", LOG_COMMAND, e=e)
return {}
async def zhenxun_download_zip(self, ver_type: Literal["main", "release"]) -> str:
"""下载真寻最新版文件
参数:
ver_type: 版本类型main 为最新版release 为最新release版
返回:
str: 版本号
"""
repo_info = GithubUtils.parse_github_url(self.config.ZHENXUN_BOT_GITHUB_URL)
if ver_type == "main":
download_url = await repo_info.get_archive_download_urls()
new_version = await self.zhenxun_get_version_from_repo()
else:
release_data = await self.zhenxun_get_latest_releases_data()
logger.debug(f"获取真寻RELEASES最新版本信息: {release_data}", LOG_COMMAND)
if not release_data:
raise ZhenxunUpdateException("获取真寻RELEASES最新版本失败...")
new_version = release_data.get("name", "")
download_url = await repo_info.get_release_source_download_urls_tgz(
new_version
)
if not download_url:
raise ZhenxunUpdateException("获取真寻最新版文件下载链接失败...")
if self.config.ZHENXUN_BOT_DOWNLOAD_FILE.exists():
self.config.ZHENXUN_BOT_DOWNLOAD_FILE.unlink()
if await AsyncHttpx.download_file(
download_url, self.config.ZHENXUN_BOT_DOWNLOAD_FILE, stream=True
):
logger.debug("下载真寻最新版文件完成...", LOG_COMMAND)
else:
raise ZhenxunUpdateException("下载真寻最新版文件失败...")
return new_version
async def zhenxun_unzip(self):
"""解压真寻最新版文件"""
if not self.config.ZHENXUN_BOT_DOWNLOAD_FILE.exists():
raise FileNotFoundError("真寻最新版文件不存在")
if self.config.ZHENXUN_BOT_UNZIP_PATH.exists():
shutil.rmtree(self.config.ZHENXUN_BOT_UNZIP_PATH)
tf = None
try:
tf = zipfile.ZipFile(self.config.ZHENXUN_BOT_DOWNLOAD_FILE)
tf.extractall(self.config.ZHENXUN_BOT_UNZIP_PATH)
logger.debug("解压Zhenxun Bot文件压缩包完成!", LOG_COMMAND)
self.__backup_zhenxun()
for filename in self.config.BACKUP_FILES:
self.__copy_files(
self.config.ZHENXUN_BOT_UNZIP_PATH / filename,
Path() / filename,
)
logger.debug("备份真寻更新文件完成!", LOG_COMMAND)
unzip_dir = next(self.config.ZHENXUN_BOT_UNZIP_PATH.iterdir())
for folder in self.config.ZHENXUN_BOT_UPDATE_FOLDERS:
self.__copy_files(unzip_dir / folder, Path() / folder)
logger.debug("移动真寻更新文件完成!", LOG_COMMAND)
if self.config.ZHENXUN_BOT_UNZIP_PATH.exists():
shutil.rmtree(self.config.ZHENXUN_BOT_UNZIP_PATH)
except Exception as e:
logger.error("解压真寻最新版文件失败...", LOG_COMMAND, e=e)
raise
finally:
if tf:
tf.close()
async def zhenxun_zip_update(self, ver_type: Literal["main", "release"]) -> str:
"""使用zip更新真寻
参数:
ver_type: 版本类型main 为最新版release 为最新release版
返回:
str: 版本号
"""
new_version = await self.zhenxun_download_zip(ver_type)
await self.zhenxun_unzip()
await self.zhenxun_write_version_file(new_version)
return new_version
async def zhenxun_git_update(
self, source: Literal["git", "ali"], branch: str = "main", force: bool = False
) -> RepoUpdateResult:
"""使用git或阿里云更新真寻
参数:
source: 更新源git git 更新ali 为阿里云更新
branch: 分支名称
force: 是否强制更新
"""
if source == "git":
return await GithubRepoManager.update_via_git(
self.config.ZHENXUN_BOT_GIT,
Path(),
branch=branch,
force=force,
)
else:
return await AliyunRepoManager.update_via_git(
self.config.ZHENXUN_BOT_GIT,
Path(),
branch=branch,
force=force,
)
async def zhenxun_update(
self,
source: Literal["git", "ali"] = "ali",
branch: str = "main",
force: bool = False,
ver_type: Literal["main", "release"] = "main",
):
"""更新真寻
参数:
source: 更新源git git 更新ali 为阿里云更新
branch: 分支名称
force: 是否强制更新
ver_type: 版本类型main 为最新版release 为最新release版
"""
if await check_git():
await self.zhenxun_git_update(source, branch, force)
logger.debug("使用git更新真寻!", LOG_COMMAND)
else:
await self.zhenxun_zip_update(ver_type)
logger.debug("使用zip更新真寻!", LOG_COMMAND)
async def install_requirements(self):
"""安装真寻依赖"""
await VirtualEnvPackageManager.install_requirement(
self.config.REQUIREMENTS_FILE
)
# ==================== 资源管理相关方法 ====================
def check_resources_exists(self) -> bool:
"""检查资源文件是否存在
返回:
bool: 是否存在
"""
if self.config.RESOURCE_PATH.exists():
font_path = self.config.RESOURCE_PATH / "font"
if font_path.exists() and os.listdir(font_path):
return True
return False
async def resources_download_zip(self):
"""下载资源文件"""
download_url = await GithubUtils.parse_github_url(
self.config.RESOURCE_GITHUB_URL
).get_archive_download_urls()
logger.debug("开始下载resources资源包...", LOG_COMMAND)
if await AsyncHttpx.download_file(
download_url, self.config.RESOURCE_ZIP_FILE, stream=True
):
logger.debug("下载resources资源文件压缩包成功!", LOG_COMMAND)
else:
raise ZhenxunUpdateException("下载resources资源包失败...")
async def resources_unzip(self):
"""解压资源文件"""
if not self.config.RESOURCE_ZIP_FILE.exists():
raise FileNotFoundError("资源文件压缩包不存在")
if self.config.RESOURCE_UNZIP_PATH.exists():
shutil.rmtree(self.config.RESOURCE_UNZIP_PATH)
tf = None
try:
tf = zipfile.ZipFile(self.config.RESOURCE_ZIP_FILE)
tf.extractall(self.config.RESOURCE_UNZIP_PATH)
logger.debug("解压文件压缩包完成...", LOG_COMMAND)
unzip_dir = next(self.config.RESOURCE_UNZIP_PATH.iterdir())
self.__copy_files(unzip_dir, self.config.RESOURCE_PATH, True)
logger.debug("复制资源文件完成!", LOG_COMMAND)
shutil.rmtree(self.config.RESOURCE_UNZIP_PATH, ignore_errors=True)
except Exception as e:
logger.error("解压资源文件失败...", LOG_COMMAND, e=e)
raise
finally:
if tf:
tf.close()
async def resources_zip_update(self):
"""使用zip更新资源文件"""
await self.resources_download_zip()
await self.resources_unzip()
async def resources_git_update(
self, source: Literal["git", "ali"], branch: str = "main", force: bool = False
) -> RepoUpdateResult:
"""使用git或阿里云更新资源文件
参数:
source: 更新源git git 更新ali 为阿里云更新
branch: 分支名称
force: 是否强制更新
"""
if source == "git":
return await GithubRepoManager.update_via_git(
self.config.RESOURCE_GIT,
self.config.RESOURCE_PATH,
branch=branch,
force=force,
)
else:
return await AliyunRepoManager.update_via_git(
self.config.RESOURCE_GIT,
self.config.RESOURCE_PATH,
branch=branch,
force=force,
)
async def resources_update(
self,
source: Literal["git", "ali"] = "ali",
branch: str = "main",
force: bool = False,
):
"""更新资源文件
参数:
source: 更新源git git 更新ali 为阿里云更新
branch: 分支名称
force: 是否强制更新
"""
if await check_git():
await self.resources_git_update(source, branch, force)
logger.debug("使用git更新资源文件!", LOG_COMMAND)
else:
await self.resources_zip_update()
logger.debug("使用zip更新资源文件!", LOG_COMMAND)
# ==================== Web UI 管理相关方法 ====================
def check_webui_exists(self) -> bool:
"""检查 Web UI 资源是否存在"""
return bool(
self.config.WEBUI_PATH.exists() and os.listdir(self.config.WEBUI_PATH)
)
async def webui_download_zip(self):
"""下载 WEBUI_ASSETS 资源"""
download_url = await GithubUtils.parse_github_url(
self.config.WEBUI_DIST_GITHUB_URL
).get_archive_download_urls()
logger.info("开始下载 WEBUI_ASSETS 资源...", LOG_COMMAND)
if await AsyncHttpx.download_file(
download_url, self.config.WEBUI_DOWNLOAD_FILE, follow_redirects=True
):
logger.info("下载 WEBUI_ASSETS 成功!", LOG_COMMAND)
else:
raise ZhenxunUpdateException("下载 WEBUI_ASSETS 失败", LOG_COMMAND)
def __backup_webui(self):
"""备份 WEBUI_ASSERT 资源"""
if self.config.WEBUI_PATH.exists():
if self.config.WEBUI_BACKUP_PATH.exists():
logger.debug(
f"删除旧的备份webui文件夹 {self.config.WEBUI_BACKUP_PATH}",
LOG_COMMAND,
)
shutil.rmtree(self.config.WEBUI_BACKUP_PATH)
shutil.copytree(self.config.WEBUI_PATH, self.config.WEBUI_BACKUP_PATH)
async def webui_unzip(self):
"""解压 WEBUI_ASSETS 资源
返回:
str: 更新结果
"""
if not self.config.WEBUI_DOWNLOAD_FILE.exists():
raise FileNotFoundError("webui文件压缩包不存在")
tf = None
try:
self.__backup_webui()
self.__clear_folder(self.config.WEBUI_PATH)
tf = zipfile.ZipFile(self.config.WEBUI_DOWNLOAD_FILE)
tf.extractall(self.config.WEBUI_UNZIP_PATH)
logger.debug("Web UI 解压文件压缩包完成...", LOG_COMMAND)
unzip_dir = next(self.config.WEBUI_UNZIP_PATH.iterdir())
self.__copy_files(unzip_dir, self.config.WEBUI_PATH)
logger.debug("Web UI 复制 WEBUI_ASSETS 成功!", LOG_COMMAND)
shutil.rmtree(self.config.WEBUI_UNZIP_PATH, ignore_errors=True)
except Exception as e:
if self.config.WEBUI_BACKUP_PATH.exists():
self.__copy_files(self.config.WEBUI_BACKUP_PATH, self.config.WEBUI_PATH)
logger.debug("恢复备份 WEBUI_ASSETS 成功!", LOG_COMMAND)
shutil.rmtree(self.config.WEBUI_BACKUP_PATH, ignore_errors=True)
logger.error("Web UI 更新失败", LOG_COMMAND, e=e)
raise
finally:
if tf:
tf.close()
async def webui_zip_update(self):
"""使用zip更新 Web UI"""
await self.webui_download_zip()
await self.webui_unzip()
async def webui_git_update(
self, source: Literal["git", "ali"], branch: str = "dist", force: bool = False
) -> RepoUpdateResult:
"""使用git或阿里云更新 Web UI
参数:
source: 更新源git git 更新ali 为阿里云更新
branch: 分支名称
force: 是否强制更新
"""
if source == "git":
return await GithubRepoManager.update_via_git(
self.config.WEBUI_GIT,
self.config.WEBUI_PATH,
branch=branch,
force=force,
)
else:
return await AliyunRepoManager.update_via_git(
self.config.WEBUI_GIT,
self.config.WEBUI_PATH,
branch=branch,
force=force,
)
async def webui_update(
self,
source: Literal["git", "ali"] = "ali",
branch: str = "dist",
force: bool = False,
):
"""更新 Web UI
参数:
source: 更新源git git 更新ali 为阿里云更新
"""
if await check_git():
await self.webui_git_update(source, branch, force)
logger.debug("使用git更新Web UI!", LOG_COMMAND)
else:
await self.webui_zip_update()
logger.debug("使用zip更新Web UI!", LOG_COMMAND)
ZhenxunRepoManager = ZhenxunRepoManagerClass()

View File

@ -0,0 +1,88 @@
"""
Pydantic V1 & V2 兼容层模块
Pydantic V1 V2 版本提供统一的便捷函数与类
包括 model_dump, model_copy, model_json_schema, parse_as
"""
from typing import Any, TypeVar, get_args, get_origin
from nonebot.compat import PYDANTIC_V2, model_dump
from pydantic import VERSION, BaseModel
T = TypeVar("T", bound=BaseModel)
V = TypeVar("V")
__all__ = [
"PYDANTIC_V2",
"_dump_pydantic_obj",
"_is_pydantic_type",
"model_copy",
"model_dump",
"model_json_schema",
"parse_as",
]
def model_copy(
model: T, *, update: dict[str, Any] | None = None, deep: bool = False
) -> T:
"""
Pydantic `model.copy()` (v1) `model.model_copy()` (v2) 的兼容函数
"""
if PYDANTIC_V2:
return model.model_copy(update=update, deep=deep)
else:
update_dict = update or {}
return model.copy(update=update_dict, deep=deep)
def model_json_schema(model_class: type[BaseModel], **kwargs: Any) -> dict[str, Any]:
"""
Pydantic `Model.schema()` (v1) `Model.model_json_schema()` (v2) 的兼容函数
"""
if PYDANTIC_V2:
return model_class.model_json_schema(**kwargs)
else:
return model_class.schema(by_alias=kwargs.get("by_alias", True))
def _is_pydantic_type(t: Any) -> bool:
"""
递归检查一个类型注解是否与 Pydantic BaseModel 相关
"""
if t is None:
return False
origin = get_origin(t)
if origin:
return any(_is_pydantic_type(arg) for arg in get_args(t))
return isinstance(t, type) and issubclass(t, BaseModel)
def _dump_pydantic_obj(obj: Any) -> Any:
"""
递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典
支持单个实例实例列表实例字典等情况
"""
if isinstance(obj, BaseModel):
return model_dump(obj)
if isinstance(obj, list):
return [_dump_pydantic_obj(item) for item in obj]
if isinstance(obj, dict):
return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
return obj
def parse_as(type_: type[V], obj: Any) -> V:
"""
一个兼容 Pydantic V1 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数
"""
if VERSION.startswith("1"):
from pydantic import parse_obj_as
return parse_obj_as(type_, obj)
else:
from pydantic import TypeAdapter # type: ignore
return TypeAdapter(type_).validate_python(obj)

View File

@ -0,0 +1,60 @@
"""
仓库管理工具用于操作GitHub和阿里云CodeUp项目的更新和文件下载
"""
from .aliyun_manager import AliyunCodeupManager
from .base_manager import BaseRepoManager
from .config import AliyunCodeupConfig, GithubConfig, RepoConfig
from .exceptions import (
ApiRateLimitError,
AuthenticationError,
ConfigError,
FileNotFoundError,
NetworkError,
RepoDownloadError,
RepoManagerError,
RepoNotFoundError,
RepoUpdateError,
)
from .file_manager import RepoFileManager as RepoFileManagerClass
from .github_manager import GithubManager
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
from .utils import check_git, filter_files, glob_to_regex, run_git_command
GithubRepoManager = GithubManager()
AliyunRepoManager = AliyunCodeupManager()
RepoFileManager = RepoFileManagerClass()
__all__ = [
"AliyunCodeupConfig",
"AliyunRepoManager",
"ApiRateLimitError",
"AuthenticationError",
"BaseRepoManager",
"ConfigError",
"FileDownloadResult",
"FileNotFoundError",
"GithubConfig",
"GithubRepoManager",
"NetworkError",
"RepoCommitInfo",
"RepoConfig",
"RepoDownloadError",
"RepoFileInfo",
"RepoFileManager",
"RepoManagerError",
"RepoNotFoundError",
"RepoType",
"RepoUpdateError",
"RepoUpdateResult",
"check_git",
"filter_files",
"glob_to_regex",
"run_git_command",
]

View File

@ -0,0 +1,557 @@
"""
阿里云CodeUp仓库管理工具
"""
import asyncio
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from aiocache import cached
from zhenxun.services.log import logger
from zhenxun.utils.github_utils.models import AliyunFileInfo
from .base_manager import BaseRepoManager
from .config import LOG_COMMAND, RepoConfig
from .exceptions import (
AuthenticationError,
FileNotFoundError,
RepoDownloadError,
RepoNotFoundError,
RepoUpdateError,
)
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
class AliyunCodeupManager(BaseRepoManager):
"""阿里云CodeUp仓库管理工具"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化阿里云CodeUp仓库管理工具
Args:
config: 配置如果为None则使用默认配置
"""
super().__init__(config)
self._client = None
async def update_repo(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新阿里云CodeUp仓库
Args:
repo_url: 仓库URL或名称
local_path: 本地保存路径
branch: 分支名称
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
Returns:
RepoUpdateResult: 更新结果
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_url = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
# 获取仓库最新提交ID
newest_commit = await self._get_newest_commit(repo_url, branch)
# 创建结果对象
result = RepoUpdateResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_url.split("/tree/")[0]
.split("/")[-1]
.replace(".git", ""),
owner=self.config.aliyun_codeup.organization_id,
old_version="", # 将在后面更新
new_version=newest_commit,
)
old_version = await self.read_version_file(local_path)
result.old_version = old_version
# 如果版本相同,则无需更新
if old_version == newest_commit:
result.success = True
logger.debug(
f"仓库 {repo_url.split('/')[-1].replace('.git', '')}"
f" 已是最新版本: {newest_commit[:8]}",
LOG_COMMAND,
)
return result
# 确保本地目录存在
local_path.mkdir(parents=True, exist_ok=True)
# 获取仓库名称从URL中提取
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
# 获取变更的文件列表
changed_files = await self._get_changed_files(
repo_name, old_version or None, newest_commit
)
# 过滤文件
if include_patterns or exclude_patterns:
from .utils import filter_files
changed_files = filter_files(
changed_files, include_patterns, exclude_patterns
)
result.changed_files = changed_files
# 下载变更的文件
for file_path in changed_files:
try:
local_file_path = local_path / file_path
await self._download_file(
repo_name, file_path, local_file_path, newest_commit
)
except Exception as e:
logger.error(f"下载文件 {file_path} 失败", LOG_COMMAND, e=e)
# 更新版本文件
await self.write_version_file(local_path, newest_commit)
result.success = True
return result
except RepoUpdateError as e:
logger.error(f"更新仓库失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
return RepoUpdateResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
owner=self.config.aliyun_codeup.organization_id,
old_version="",
new_version="",
error_message=str(e),
)
except Exception as e:
logger.error(f"更新仓库失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
return RepoUpdateResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
owner=self.config.aliyun_codeup.organization_id,
old_version="",
new_version="",
error_message=str(e),
)
async def download_file(
self,
repo_url: str,
file_path: str,
local_path: Path,
branch: str = "main",
) -> FileDownloadResult:
"""
从阿里云CodeUp下载单个文件
Args:
repo_url: 仓库URL或名称
file_path: 文件在仓库中的路径
local_path: 本地保存路径
branch: 分支名称
Returns:
FileDownloadResult: 下载结果
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_identifier = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
)
# 创建结果对象
result = FileDownloadResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_url.split("/tree/")[0]
.split("/")[-1]
.replace(".git", ""),
file_path=file_path,
version=branch,
)
# 确保本地目录存在
Path(local_path).parent.mkdir(parents=True, exist_ok=True)
# 下载文件
file_size = await self._download_file(
repo_identifier, file_path, local_path, branch
)
result.success = True
result.file_size = file_size
return result
except RepoDownloadError as e:
logger.error(f"下载文件失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
return FileDownloadResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
file_path=file_path,
version=branch,
error_message=str(e),
)
except Exception as e:
logger.error(f"下载文件失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
return FileDownloadResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
file_path=file_path,
version=branch,
error_message=str(e),
)
async def get_file_list(
self,
repo_url: str,
dir_path: str = "",
branch: str = "main",
recursive: bool = False,
) -> list[RepoFileInfo]:
"""
获取仓库文件列表
Args:
repo_url: 仓库URL或名称
dir_path: 目录路径空字符串表示仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录
Returns:
list[RepoFileInfo]: 文件信息列表
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_identifier = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
)
# 获取文件列表
search_type = "RECURSIVE" if recursive else "DIRECT"
tree_list = await AliyunFileInfo.get_repository_tree(
repo_identifier, dir_path, branch, search_type
)
result = []
for tree in tree_list:
# 跳过非当前目录的文件(如果不是递归模式)
if (
not recursive
and tree.path != dir_path
and "/" in tree.path.replace(dir_path, "", 1).strip("/")
):
continue
file_info = RepoFileInfo(
path=tree.path,
is_dir=tree.type == "tree",
)
result.append(file_info)
return result
except Exception as e:
logger.error(f"获取文件列表失败: {e}")
return []
async def get_commit_info(
self, repo_url: str, commit_id: str
) -> RepoCommitInfo | None:
"""
获取提交信息
Args:
repo_url: 仓库URL或名称
commit_id: 提交ID
Returns:
Optional[RepoCommitInfo]: 提交信息如果获取失败则返回None
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_identifier = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
)
# 获取提交信息
# 注意这里假设AliyunFileInfo有get_commit_info方法如果没有需要实现
commit_data = await self._get_commit_info(repo_identifier, commit_id)
if not commit_data:
return None
# 解析提交信息
id_value = commit_data.get("id", commit_id)
message_value = commit_data.get("message", "")
author_value = commit_data.get("author_name", "")
date_value = commit_data.get(
"authored_date", datetime.now().isoformat()
).replace("Z", "+00:00")
return RepoCommitInfo(
commit_id=id_value,
message=message_value,
author=author_value,
commit_time=datetime.fromisoformat(date_value),
changed_files=[], # 阿里云API可能没有直接提供变更文件列表
)
except Exception as e:
logger.error(f"获取提交信息失败: {e}")
return None
def _check_config(self):
"""检查配置"""
if not self.config.aliyun_codeup.access_key_id:
raise AuthenticationError("阿里云CodeUp")
if not self.config.aliyun_codeup.access_key_secret:
raise AuthenticationError("阿里云CodeUp")
if not self.config.aliyun_codeup.organization_id:
raise AuthenticationError("阿里云CodeUp")
async def _get_newest_commit(self, repo_name: str, branch: str) -> str:
"""
获取仓库最新提交ID
Args:
repo_name: 仓库名称
branch: 分支名称
Returns:
str: 提交ID
"""
try:
newest_commit = await AliyunFileInfo.get_newest_commit(repo_name, branch)
if not newest_commit:
raise RepoNotFoundError(repo_name)
return newest_commit
except Exception as e:
logger.error(f"获取最新提交ID失败: {e}")
raise RepoUpdateError(f"获取最新提交ID失败: {e}")
async def _get_commit_info(self, repo_name: str, commit_id: str) -> dict:
"""
获取提交信息
Args:
repo_name: 仓库名称
commit_id: 提交ID
Returns:
dict: 提交信息
"""
# 这里需要实现从阿里云获取提交信息的逻辑
# 由于AliyunFileInfo可能没有get_commit_info方法这里提供一个简单的实现
try:
# 这里应该是调用阿里云API获取提交信息
# 这里只是一个示例实际上需要根据阿里云API实现
return {
"id": commit_id,
"message": "提交信息",
"author_name": "作者",
"authored_date": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"获取提交信息失败: {e}")
return {}
@cached(ttl=3600)
async def _get_changed_files(
self, repo_name: str, old_commit: str | None, new_commit: str
) -> list[str]:
"""
获取两个提交之间变更的文件列表
Args:
repo_name: 仓库名称
old_commit: 旧提交ID如果为None则获取所有文件
new_commit: 新提交ID
Returns:
list[str]: 变更的文件列表
"""
if not old_commit:
# 如果没有旧提交,则获取仓库中的所有文件
tree_list = await AliyunFileInfo.get_repository_tree(
repo_name, "", new_commit, "RECURSIVE"
)
return [tree.path for tree in tree_list if tree.type == "blob"]
# 获取两个提交之间的差异
try:
return []
except Exception as e:
logger.error(f"获取提交差异失败: {e}")
raise RepoUpdateError(f"获取提交差异失败: {e}")
async def update_via_git(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
force: bool = False,
*,
repo_type: RepoType | None = None,
owner: str | None = None,
prepare_repo_url: Callable[[str], str] | None = None,
) -> RepoUpdateResult:
"""
通过Git命令直接更新仓库
参数:
repo_url: 仓库名称
local_path: 本地仓库路径
branch: 分支名称
force: 是否强制拉取
返回:
RepoUpdateResult: 更新结果
"""
# 定义预处理函数构建阿里云CodeUp的URL
def prepare_aliyun_url(repo_url: str) -> str:
import base64
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
# 构建仓库URL
# 阿里云CodeUp的仓库URL格式通常为
# https://codeup.aliyun.com/{organization_id}/{organization_name}/{repo_name}.git
url = f"https://codeup.aliyun.com/{self.config.aliyun_codeup.organization_id}/{self.config.aliyun_codeup.organization_name}/{repo_name}.git"
# 添加访问令牌 - 使用base64解码后的令牌
if self.config.aliyun_codeup.rdc_access_token_encrypted:
try:
# 解码RDC访问令牌
token = base64.b64decode(
self.config.aliyun_codeup.rdc_access_token_encrypted.encode()
).decode()
# 阿里云CodeUp使用oauth2:token的格式进行身份验证
url = url.replace("https://", f"https://oauth2:{token}@")
logger.debug(f"使用RDC令牌构建阿里云URL: {url.split('@')[0]}@***")
except Exception as e:
logger.error(f"解码RDC令牌失败: {e}")
return url
# 调用基类的update_via_git方法
return await super().update_via_git(
repo_url=repo_url,
local_path=local_path,
branch=branch,
force=force,
repo_type=RepoType.ALIYUN,
owner=self.config.aliyun_codeup.organization_id,
prepare_repo_url=prepare_aliyun_url,
)
async def update(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
use_git: bool = True,
force: bool = False,
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库可选择使用Git命令或API方式
参数:
repo_url: 仓库名称
local_path: 本地保存路径
branch: 分支名称
use_git: 是否使用Git命令更新
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
if use_git:
return await self.update_via_git(repo_url, local_path, branch, force)
else:
return await self.update_repo(
repo_url, local_path, branch, include_patterns, exclude_patterns
)
async def _download_file(
self, repo_name: str, file_path: str, local_path: Path, ref: str
) -> int:
"""
下载文件
Args:
repo_name: 仓库名称
file_path: 文件在仓库中的路径
local_path: 本地保存路径
ref: 分支/标签/提交ID
Returns:
int: 文件大小字节
"""
# 确保目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 获取文件内容
for retry in range(self.config.aliyun_codeup.download_retry + 1):
try:
content = await AliyunFileInfo.get_file_content(
file_path, repo_name, ref
)
if content is None:
raise FileNotFoundError(file_path, repo_name)
# 保存文件
return await self.save_file_content(content.encode("utf-8"), local_path)
except FileNotFoundError as e:
# 这些错误不需要重试
raise e
except Exception as e:
if retry < self.config.aliyun_codeup.download_retry:
logger.warning("下载文件失败,将重试", LOG_COMMAND, e=e)
await asyncio.sleep(1)
continue
raise RepoDownloadError(f"下载文件失败: {e}")
raise RepoDownloadError("下载文件失败: 超过最大重试次数")

View File

@ -0,0 +1,432 @@
"""
仓库管理工具的基础管理器
"""
from abc import ABC, abstractmethod
from pathlib import Path
import aiofiles
from zhenxun.services.log import logger
from .config import LOG_COMMAND, RepoConfig
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
from .utils import check_git, filter_files, run_git_command
class BaseRepoManager(ABC):
"""仓库管理工具基础类"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化仓库管理工具
参数:
config: 配置如果为None则使用默认配置
"""
self.config = config or RepoConfig.get_instance()
self.config.ensure_dirs()
@abstractmethod
async def update_repo(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库
参数:
repo_url: 仓库URL或名称
local_path: 本地保存路径
branch: 分支名称
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
pass
@abstractmethod
async def download_file(
self,
repo_url: str,
file_path: str,
local_path: Path,
branch: str = "main",
) -> FileDownloadResult:
"""
下载单个文件
参数:
repo_url: 仓库URL或名称
file_path: 文件在仓库中的路径
local_path: 本地保存路径
branch: 分支名称
返回:
FileDownloadResult: 下载结果
"""
pass
@abstractmethod
async def get_file_list(
self,
repo_url: str,
dir_path: str = "",
branch: str = "main",
recursive: bool = False,
) -> list[RepoFileInfo]:
"""
获取仓库文件列表
参数:
repo_url: 仓库URL或名称
dir_path: 目录路径空字符串表示仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录
返回:
List[RepoFileInfo]: 文件信息列表
"""
pass
@abstractmethod
async def get_commit_info(
self, repo_url: str, commit_id: str
) -> RepoCommitInfo | None:
"""
获取提交信息
参数:
repo_url: 仓库URL或名称
commit_id: 提交ID
返回:
Optional[RepoCommitInfo]: 提交信息如果获取失败则返回None
"""
pass
async def save_file_content(self, content: bytes, local_path: Path) -> int:
"""
保存文件内容
参数:
content: 文件内容
local_path: 本地保存路径
返回:
int: 文件大小字节
"""
# 确保目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 保存文件
async with aiofiles.open(local_path, "wb") as f:
await f.write(content)
return len(content)
async def read_version_file(self, local_dir: Path) -> str:
"""
读取版本文件
参数:
local_dir: 本地目录
返回:
str: 版本号
"""
version_file = local_dir / "__version__"
if not version_file.exists():
return ""
try:
async with aiofiles.open(version_file) as f:
return (await f.read()).strip()
except Exception as e:
logger.error(f"读取版本文件失败: {e}")
return ""
async def write_version_file(self, local_dir: Path, version: str) -> bool:
"""
写入版本文件
参数:
local_dir: 本地目录
version: 版本号
返回:
bool: 是否成功
"""
version_file = local_dir / "__version__"
try:
version_bb = "vNone"
async with aiofiles.open(version_file) as rf:
if text := await rf.read():
version_bb = text.strip().split("-")[0]
async with aiofiles.open(version_file, "w") as f:
await f.write(f"{version_bb}-{version[:6]}")
return True
except Exception as e:
logger.error(f"写入版本文件失败: {e}")
return False
def filter_files(
self,
files: list[str],
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> list[str]:
"""
过滤文件列表
参数:
files: 文件列表
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
List[str]: 过滤后的文件列表
"""
return filter_files(files, include_patterns, exclude_patterns)
async def update_via_git(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
force: bool = False,
*,
repo_type: RepoType | None = None,
owner="",
prepare_repo_url=None,
) -> RepoUpdateResult:
"""
通过Git命令直接更新仓库
参数:
repo_url: 仓库URL或名称
local_path: 本地仓库路径
branch: 分支名称
force: 是否强制拉取
repo_type: 仓库类型
owner: 仓库拥有者
prepare_repo_url: 预处理仓库URL的函数
返回:
RepoUpdateResult: 更新结果
"""
from .models import RepoType
repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "")
try:
# 创建结果对象
result = RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB, # 默认使用GitHub类型
repo_name=repo_name,
owner=owner or "",
old_version="",
new_version="",
)
# 检查Git是否可用
if not await check_git():
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version="",
new_version="",
error_message="Git命令不可用",
)
# 预处理仓库URL
if prepare_repo_url:
repo_url = prepare_repo_url(repo_url)
# 检查本地目录是否存在
if not local_path.exists():
# 如果不存在,则克隆仓库
logger.info(f"克隆仓库 {repo_url}{local_path}", LOG_COMMAND)
success, stdout, stderr = await run_git_command(
f"clone -b {branch} {repo_url} {local_path}"
)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version="",
new_version="",
error_message=f"克隆仓库失败: {stderr}",
)
# 获取当前提交ID
success, new_version, _ = await run_git_command(
"rev-parse HEAD", cwd=local_path
)
result.new_version = new_version.strip()
result.success = True
return result
# 如果目录存在检查是否是Git仓库
# 首先检查目录本身是否有.git文件夹
git_dir = local_path / ".git"
if not git_dir.is_dir():
# 如果不是Git仓库尝试初始化它
logger.info(f"目录 {local_path} 不是Git仓库尝试初始化", LOG_COMMAND)
init_success, _, init_stderr = await run_git_command(
"init", cwd=local_path
)
if not init_success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version="",
new_version="",
error_message=f"初始化Git仓库失败: {init_stderr}",
)
# 添加远程仓库
remote_success, _, remote_stderr = await run_git_command(
f"remote add origin {repo_url}", cwd=local_path
)
if not remote_success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version="",
new_version="",
error_message=f"添加远程仓库失败: {remote_stderr}",
)
logger.info(f"成功初始化Git仓库 {local_path}", LOG_COMMAND)
# 获取当前提交ID作为旧版本
success, old_version, _ = await run_git_command(
"rev-parse HEAD", cwd=local_path
)
result.old_version = old_version.strip()
# 获取当前远程URL
success, remote_url, _ = await run_git_command(
"config --get remote.origin.url", cwd=local_path
)
# 如果远程URL不匹配则更新它
remote_url = remote_url.strip()
if success and repo_url not in remote_url and remote_url not in repo_url:
logger.info(f"更新远程URL: {remote_url} -> {repo_url}", LOG_COMMAND)
await run_git_command(
f"remote set-url origin {repo_url}", cwd=local_path
)
# 获取远程更新
logger.info(f"获取远程更新: {repo_url}", LOG_COMMAND)
success, _, stderr = await run_git_command("fetch origin", cwd=local_path)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version=old_version.strip(),
new_version="",
error_message=f"获取远程更新失败: {stderr}",
)
# 获取当前分支
success, current_branch, _ = await run_git_command(
"rev-parse --abbrev-ref HEAD", cwd=local_path
)
current_branch = current_branch.strip()
# 如果当前分支不是目标分支,则切换分支
if success and current_branch != branch:
logger.info(f"切换分支: {current_branch} -> {branch}", LOG_COMMAND)
success, _, stderr = await run_git_command(
f"checkout {branch}", cwd=local_path
)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version=old_version.strip(),
new_version="",
error_message=f"切换分支失败: {stderr}",
)
# 拉取最新代码
logger.info(f"拉取最新代码: {repo_url}", LOG_COMMAND)
pull_cmd = f"pull origin {branch}"
if force:
pull_cmd = f"fetch --all && git reset --hard origin/{branch}"
logger.info("使用强制拉取模式", LOG_COMMAND)
success, _, stderr = await run_git_command(pull_cmd, cwd=local_path)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version=old_version.strip(),
new_version="",
error_message=f"拉取最新代码失败: {stderr}",
)
# 获取更新后的提交ID
success, new_version, _ = await run_git_command(
"rev-parse HEAD", cwd=local_path
)
result.new_version = new_version.strip()
# 如果版本相同,则无需更新
if old_version.strip() == new_version.strip():
logger.info(
f"仓库 {repo_url} 已是最新版本: {new_version.strip()}", LOG_COMMAND
)
result.success = True
return result
# 获取变更的文件列表
success, changed_files_output, _ = await run_git_command(
f"diff --name-only {old_version.strip()} {new_version.strip()}",
cwd=local_path,
)
if success:
changed_files = [
line.strip()
for line in changed_files_output.splitlines()
if line.strip()
]
result.changed_files = changed_files
logger.info(f"变更的文件列表: {changed_files}", LOG_COMMAND)
result.success = True
return result
except Exception as e:
logger.error("Git更新失败", LOG_COMMAND, e=e)
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_name,
owner=owner or "",
old_version="",
new_version="",
error_message=str(e),
)

View File

@ -0,0 +1,77 @@
"""
仓库管理工具的配置模块
"""
from dataclasses import dataclass, field
from pathlib import Path
from zhenxun.configs.path_config import TEMP_PATH
LOG_COMMAND = "RepoUtils"
@dataclass
class GithubConfig:
"""GitHub配置"""
# API超时时间
api_timeout: int = 30
# 下载超时时间(秒)
download_timeout: int = 60
# 下载重试次数
download_retry: int = 3
# 代理配置
proxy: str | None = None
@dataclass
class AliyunCodeupConfig:
"""阿里云CodeUp配置"""
# 访问密钥ID
access_key_id: str = "LTAI5tNmf7KaTAuhcvRobAQs"
# 访问密钥密钥
access_key_secret: str = "NmJ3d2VNRU1MREY0T1RtRnBqMlFqdlBxN3pMUk1j"
# 组织ID
organization_id: str = "67a361cf556e6cdab537117a"
# 组织名称
organization_name: str = "zhenxun-org"
# RDC Access Token
rdc_access_token_encrypted: str = (
"cHQtYXp0allnQWpub0FYZWpqZm1RWGtneHk0XzBlMmYzZTZmLWQwOWItNDE4Mi1iZWUx"
"LTQ1ZTFkYjI0NGRlMg=="
)
# 区域
region: str = "cn-hangzhou"
# 端点
endpoint: str = "devops.cn-hangzhou.aliyuncs.com"
# 下载重试次数
download_retry: int = 3
@dataclass
class RepoConfig:
"""仓库管理工具配置"""
# 缓存目录
cache_dir: Path = TEMP_PATH / "repo_cache"
# GitHub配置
github: GithubConfig = field(default_factory=GithubConfig)
# 阿里云CodeUp配置
aliyun_codeup: AliyunCodeupConfig = field(default_factory=AliyunCodeupConfig)
# 单例实例
_instance = None
@classmethod
def get_instance(cls) -> "RepoConfig":
"""获取单例实例"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def ensure_dirs(self):
"""确保目录存在"""
self.cache_dir.mkdir(parents=True, exist_ok=True)

View File

@ -0,0 +1,68 @@
"""
仓库管理工具的异常类
"""
class RepoManagerError(Exception):
"""仓库管理工具异常基类"""
def __init__(self, message: str, repo_name: str | None = None):
self.message = message
self.repo_name = repo_name
super().__init__(self.message)
class RepoUpdateError(RepoManagerError):
"""仓库更新异常"""
def __init__(self, message: str, repo_name: str | None = None):
super().__init__(f"仓库更新失败: {message}", repo_name)
class RepoDownloadError(RepoManagerError):
"""仓库下载异常"""
def __init__(self, message: str, repo_name: str | None = None):
super().__init__(f"文件下载失败: {message}", repo_name)
class RepoNotFoundError(RepoManagerError):
"""仓库不存在异常"""
def __init__(self, repo_name: str):
super().__init__(f"仓库不存在: {repo_name}", repo_name)
class FileNotFoundError(RepoManagerError):
"""文件不存在异常"""
def __init__(self, file_path: str, repo_name: str | None = None):
super().__init__(f"文件不存在: {file_path}", repo_name)
class AuthenticationError(RepoManagerError):
"""认证异常"""
def __init__(self, repo_type: str):
super().__init__(f"认证失败: {repo_type}")
class ApiRateLimitError(RepoManagerError):
"""API速率限制异常"""
def __init__(self, repo_type: str):
super().__init__(f"API速率限制: {repo_type}")
class NetworkError(RepoManagerError):
"""网络异常"""
def __init__(self, message: str):
super().__init__(f"网络错误: {message}")
class ConfigError(RepoManagerError):
"""配置异常"""
def __init__(self, message: str):
super().__init__(f"配置错误: {message}")

View File

@ -0,0 +1,543 @@
"""
仓库文件管理器用于从GitHub和阿里云CodeUp获取指定文件内容
"""
from pathlib import Path
from typing import cast, overload
import aiofiles
from httpx import Response
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import AliyunTreeType, GitHubStrategy, TreeType
from zhenxun.utils.http_utils import AsyncHttpx
from .config import LOG_COMMAND, RepoConfig
from .exceptions import FileNotFoundError, NetworkError, RepoManagerError
from .models import FileDownloadResult, RepoFileInfo, RepoType
class RepoFileManager:
"""仓库文件管理器用于获取GitHub和阿里云仓库中的文件内容"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化仓库文件管理器
参数:
config: 配置如果为None则使用默认配置
"""
self.config = config or RepoConfig.get_instance()
self.config.ensure_dirs()
@overload
async def get_github_file_content(
self, url: str, file_path: str, ignore_error: bool = False
) -> str: ...
@overload
async def get_github_file_content(
self, url: str, file_path: list[str], ignore_error: bool = False
) -> list[tuple[str, str]]: ...
async def get_github_file_content(
self, url: str, file_path: str | list[str], ignore_error: bool = False
) -> str | list[tuple[str, str]]:
"""
获取GitHub仓库文件内容
参数:
url: 仓库URL
file_path: 文件路径或文件路径列表
ignore_error: 是否忽略错误
返回:
list[tuple[str, str]]: 文件路径文件内容
"""
results = []
is_str_input = isinstance(file_path, str)
try:
if is_str_input:
file_path = [file_path]
repo_info = GithubUtils.parse_github_url(url)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
for f in file_path:
try:
file_url = await repo_info.get_raw_download_urls(f)
for fu in file_url:
response: Response = await AsyncHttpx.get(
fu, check_status_code=200
)
if response.status_code == 200:
logger.info(f"获取github文件内容成功: {f}", LOG_COMMAND)
# 确保使用UTF-8编码解析响应内容
try:
text_content = response.content.decode("utf-8")
except UnicodeDecodeError:
# 如果UTF-8解码失败尝试其他编码
text_content = response.content.decode(
"utf-8", errors="ignore"
)
logger.warning(
f"解码文件内容时出现错误,使用忽略错误模式: {f}",
LOG_COMMAND,
)
results.append((f, text_content))
break
else:
logger.warning(
f"获取github文件内容失败: {response.status_code}",
LOG_COMMAND,
)
except Exception as e:
logger.warning(f"获取github文件内容失败: {f}", LOG_COMMAND, e=e)
if not ignore_error:
raise
except Exception as e:
logger.error(f"获取GitHub文件内容失败: {file_path}", LOG_COMMAND, e=e)
raise
logger.debug(f"获取GitHub文件内容: {[r[0] for r in results]}", LOG_COMMAND)
return results[0][1] if is_str_input and results else results
@overload
async def get_aliyun_file_content(
self,
repo_name: str,
file_path: str,
branch: str = "main",
ignore_error: bool = False,
) -> str: ...
@overload
async def get_aliyun_file_content(
self,
repo_name: str,
file_path: list[str],
branch: str = "main",
ignore_error: bool = False,
) -> list[tuple[str, str]]: ...
async def get_aliyun_file_content(
self,
repo_name: str,
file_path: str | list[str],
branch: str = "main",
ignore_error: bool = False,
) -> str | list[tuple[str, str]]:
"""
获取阿里云CodeUp仓库文件内容
参数:
repo: 仓库名称
file_path: 文件路径
branch: 分支名称
ignore_error: 是否忽略错误
返回:
list[tuple[str, str]]: 文件路径文件内容
"""
results = []
is_str_input = isinstance(file_path, str)
# 导入阿里云相关模块
from zhenxun.utils.github_utils.models import AliyunFileInfo
if is_str_input:
file_path = [file_path]
for f in file_path:
try:
content = await AliyunFileInfo.get_file_content(
file_path=f, repo=repo_name, ref=branch
)
results.append((f, content))
except Exception as e:
logger.warning(f"获取阿里云文件内容失败: {file_path}", LOG_COMMAND, e=e)
if not ignore_error:
raise
logger.debug(f"获取阿里云文件内容: {[r[0] for r in results]}", LOG_COMMAND)
return results[0][1] if is_str_input and results else results
@overload
async def get_file_content(
self,
repo_url: str,
file_path: str,
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> str: ...
@overload
async def get_file_content(
self,
repo_url: str,
file_path: list[str],
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> list[tuple[str, str]]: ...
async def get_file_content(
self,
repo_url: str,
file_path: str | list[str],
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> str | list[tuple[str, str]]:
"""
获取仓库文件内容
参数:
repo_url: 仓库URL
file_path: 文件路径
branch: 分支名称
repo_type: 仓库类型如果为None则自动判断
ignore_error: 是否忽略错误
返回:
str: 文件内容
"""
# 确定仓库类型
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
if repo_type is None:
try:
return await self.get_aliyun_file_content(
repo_name, file_path, branch, ignore_error
)
except Exception:
return await self.get_github_file_content(
repo_url, file_path, ignore_error
)
try:
if repo_type == RepoType.GITHUB:
return await self.get_github_file_content(
repo_url, file_path, ignore_error
)
elif repo_type == RepoType.ALIYUN:
return await self.get_aliyun_file_content(
repo_name, file_path, branch, ignore_error
)
except Exception as e:
if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError):
raise
raise RepoManagerError(f"获取文件内容失败: {e}")
async def list_directory_files(
self,
repo_url: str,
directory_path: str = "",
branch: str = "main",
repo_type: RepoType | None = None,
recursive: bool = True,
) -> list[RepoFileInfo]:
"""
获取仓库目录下的所有文件路径
参数:
repo_url: 仓库URL
directory_path: 目录路径默认为仓库根目录
branch: 分支名称
repo_type: 仓库类型如果为None则自动判断
recursive: 是否递归获取子目录文件
返回:
list[RepoFileInfo]: 文件信息列表
"""
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
try:
if repo_type is None:
# 尝试GitHub失败则尝试阿里云
try:
return await self._list_github_directory_files(
repo_url, directory_path, branch, recursive
)
except Exception as e:
logger.warning(
"获取GitHub目录文件失败尝试阿里云", LOG_COMMAND, e=e
)
return await self._list_aliyun_directory_files(
repo_name, directory_path, branch, recursive
)
if repo_type == RepoType.GITHUB:
return await self._list_github_directory_files(
repo_url, directory_path, branch, recursive
)
elif repo_type == RepoType.ALIYUN:
return await self._list_aliyun_directory_files(
repo_name, directory_path, branch, recursive
)
except Exception as e:
logger.error(f"获取目录文件列表失败: {directory_path}", LOG_COMMAND, e=e)
if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError):
raise
raise RepoManagerError(f"获取目录文件列表失败: {e}")
async def _list_github_directory_files(
self,
repo_url: str,
directory_path: str = "",
branch: str = "main",
recursive: bool = True,
build_tree: bool = False,
) -> list[RepoFileInfo]:
"""
获取GitHub仓库目录下的所有文件路径
参数:
repo_url: 仓库URL
directory_path: 目录路径默认为仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录文件
build_tree: 是否构建目录树
返回:
list[RepoFileInfo]: 文件信息列表
"""
try:
repo_info = GithubUtils.parse_github_url(repo_url)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
# 获取仓库树信息
strategy = GitHubStrategy()
strategy.body = await GitHubStrategy.parse_repo_info(repo_info)
# 处理目录路径,确保格式正确
if directory_path and not directory_path.endswith("/") and recursive:
directory_path = f"{directory_path}/"
# 获取文件列表
file_list = []
for tree_item in strategy.body.tree:
# 如果不是递归模式,只获取当前目录下的文件
if not recursive and "/" in tree_item.path.replace(
directory_path, "", 1
):
continue
# 检查是否在指定目录下
if directory_path and not tree_item.path.startswith(directory_path):
continue
# 创建文件信息对象
file_info = RepoFileInfo(
path=tree_item.path,
is_dir=tree_item.type == TreeType.DIR,
size=tree_item.size,
last_modified=None, # GitHub API不直接提供最后修改时间
)
file_list.append(file_info)
# 构建目录树结构
if recursive and build_tree:
file_list = self._build_directory_tree(file_list)
return file_list
except Exception as e:
logger.error(
f"获取GitHub目录文件列表失败: {directory_path}", LOG_COMMAND, e=e
)
raise
async def _list_aliyun_directory_files(
self,
repo_name: str,
directory_path: str = "",
branch: str = "main",
recursive: bool = True,
build_tree: bool = False,
) -> list[RepoFileInfo]:
"""
获取阿里云CodeUp仓库目录下的所有文件路径
参数:
repo_name: 仓库名称
directory_path: 目录路径默认为仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录文件
build_tree: 是否构建目录树
返回:
list[RepoFileInfo]: 文件信息列表
"""
try:
from zhenxun.utils.github_utils.models import AliyunFileInfo
# 获取仓库树信息
search_type = "RECURSIVE" if recursive else "DIRECT"
tree_list = await AliyunFileInfo.get_repository_tree(
repo=repo_name,
path=directory_path,
ref=branch,
search_type=search_type,
)
# 创建文件信息对象列表
file_list = []
for tree_item in tree_list:
file_info = RepoFileInfo(
path=tree_item.path,
is_dir=tree_item.type == AliyunTreeType.DIR,
size=None, # 阿里云API不直接提供文件大小
last_modified=None, # 阿里云API不直接提供最后修改时间
)
file_list.append(file_info)
# 构建目录树结构
if recursive and build_tree:
file_list = self._build_directory_tree(file_list)
return file_list
except Exception as e:
logger.error(
f"获取阿里云目录文件列表失败: {directory_path}", LOG_COMMAND, e=e
)
raise
def _build_directory_tree(
self, file_list: list[RepoFileInfo]
) -> list[RepoFileInfo]:
"""
构建目录树结构
参数:
file_list: 文件信息列表
返回:
list[RepoFileInfo]: 根目录下的文件信息列表
"""
# 按路径排序,确保父目录在子目录之前
file_list.sort(key=lambda x: x.path)
# 创建路径到文件信息的映射
path_map = {file_info.path: file_info for file_info in file_list}
# 根目录文件列表
root_files = []
for file_info in file_list:
if parent_path := "/".join(file_info.path.split("/")[:-1]):
# 如果有父目录,将当前文件添加到父目录的子文件列表中
if parent_path in path_map:
path_map[parent_path].children.append(file_info)
else:
# 如果父目录不在列表中,创建一个虚拟的父目录
parent_info = RepoFileInfo(
path=parent_path, is_dir=True, children=[file_info]
)
path_map[parent_path] = parent_info
# 检查父目录的父目录
grand_parent_path = "/".join(parent_path.split("/")[:-1])
if grand_parent_path and grand_parent_path in path_map:
path_map[grand_parent_path].children.append(parent_info)
else:
root_files.append(parent_info)
else:
# 如果没有父目录,则是根目录下的文件
root_files.append(file_info)
# 返回根目录下的文件列表
return [
file
for file in root_files
if all(f.path != file.path for f in file_list if f != file)
]
async def download_files(
self,
repo_url: str,
file_path: tuple[str, Path] | list[tuple[str, Path]],
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> FileDownloadResult:
"""
下载单个文件
参数:
repo_url: 仓库URL
file_path: 文件在仓库中的路径本地存储路径
branch: 分支名称
repo_type: 仓库类型如果为None则自动判断
ignore_error: 是否忽略错误
返回:
FileDownloadResult: 下载结果
"""
# 确定仓库类型和所有者
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
if isinstance(file_path, tuple):
file_path = [file_path]
file_path_mapping = {f[0]: f[1] for f in file_path}
# 创建结果对象
result = FileDownloadResult(
repo_type=repo_type,
repo_name=repo_name,
file_path=file_path,
version=branch,
)
try:
# 由于我们传入的是列表,所以这里一定返回列表
file_paths = [f[0] for f in file_path]
if len(file_paths) == 1:
# 如果只有一个文件,可能返回单个元组
file_contents_result = await self.get_file_content(
repo_url, file_paths[0], branch, repo_type, ignore_error
)
if isinstance(file_contents_result, tuple):
file_contents = [file_contents_result]
elif isinstance(file_contents_result, str):
file_contents = [(file_paths[0], file_contents_result)]
else:
file_contents = cast(list[tuple[str, str]], file_contents_result)
else:
# 多个文件一定返回列表
file_contents = cast(
list[tuple[str, str]],
await self.get_file_content(
repo_url, file_paths, branch, repo_type, ignore_error
),
)
for repo_file_path, content in file_contents:
local_path = file_path_mapping[repo_file_path]
local_path.parent.mkdir(parents=True, exist_ok=True)
# 使用二进制模式写入文件,避免编码问题
if isinstance(content, str):
content_bytes = content.encode("utf-8")
else:
content_bytes = content
logger.warning(f"写入文件: {local_path}")
async with aiofiles.open(local_path, "wb") as f:
await f.write(content_bytes)
result.success = True
# 计算文件大小
result.file_size = sum(
len(content.encode("utf-8") if isinstance(content, str) else content)
for _, content in file_contents
)
return result
except Exception as e:
logger.error(f"下载文件失败: {e}")
result.success = False
result.error_message = str(e)
return result

View File

@ -0,0 +1,526 @@
"""
GitHub仓库管理工具
"""
import asyncio
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from aiocache import cached
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils, RepoInfo
from zhenxun.utils.http_utils import AsyncHttpx
from .base_manager import BaseRepoManager
from .config import LOG_COMMAND, RepoConfig
from .exceptions import (
ApiRateLimitError,
FileNotFoundError,
NetworkError,
RepoDownloadError,
RepoNotFoundError,
RepoUpdateError,
)
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
class GithubManager(BaseRepoManager):
"""GitHub仓库管理工具"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化GitHub仓库管理工具
参数:
config: 配置如果为None则使用默认配置
"""
super().__init__(config)
async def update_repo(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新GitHub仓库
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地保存路径
branch: 分支名称
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
repo_info.branch = branch
# 获取仓库最新提交ID
newest_commit = await self._get_newest_commit(
repo_info.owner, repo_info.repo, branch
)
# 创建结果对象
result = RepoUpdateResult(
repo_type=RepoType.GITHUB,
repo_name=repo_info.repo,
owner=repo_info.owner,
old_version="", # 将在后面更新
new_version=newest_commit,
)
old_version = await self.read_version_file(local_path)
old_version = old_version.split("-")[-1]
result.old_version = old_version
# 如果版本相同,则无需更新
if newest_commit in old_version:
result.success = True
logger.debug(
f"仓库 {repo_info.repo} 已是最新版本: {newest_commit}",
LOG_COMMAND,
)
return result
# 确保本地目录存在
local_path.mkdir(parents=True, exist_ok=True)
# 获取变更的文件列表
changed_files = await self._get_changed_files(
repo_info.owner,
repo_info.repo,
old_version or None,
newest_commit,
)
# 过滤文件
if include_patterns or exclude_patterns:
from .utils import filter_files
changed_files = filter_files(
changed_files, include_patterns, exclude_patterns
)
result.changed_files = changed_files
# 下载变更的文件
for file_path in changed_files:
try:
local_file_path = local_path / file_path
await self._download_file(repo_info, file_path, local_file_path)
except Exception as e:
logger.error(f"下载文件 {file_path} 失败", LOG_COMMAND, e=e)
# 更新版本文件
await self.write_version_file(local_path, newest_commit)
result.success = True
return result
except RepoUpdateError as e:
logger.error("更新仓库失败", LOG_COMMAND, e=e)
return RepoUpdateResult(
repo_type=RepoType.GITHUB,
repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url,
owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown",
old_version="",
new_version="",
error_message=str(e),
)
except Exception as e:
logger.error("更新仓库失败", LOG_COMMAND, e=e)
return RepoUpdateResult(
repo_type=RepoType.GITHUB,
repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url,
owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown",
old_version="",
new_version="",
error_message=str(e),
)
async def download_file(
self,
repo_url: str,
file_path: str,
local_path: Path,
branch: str = "main",
) -> FileDownloadResult:
"""
从GitHub下载单个文件
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
file_path: 文件在仓库中的路径
local_path: 本地保存路径
branch: 分支名称
返回:
FileDownloadResult: 下载结果
"""
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
repo_info.branch = branch
# 创建结果对象
result = FileDownloadResult(
repo_type=RepoType.GITHUB,
repo_name=repo_info.repo,
file_path=file_path,
version=branch,
)
# 确保本地目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 下载文件
file_size = await self._download_file(repo_info, file_path, local_path)
result.success = True
result.file_size = file_size
return result
except RepoDownloadError as e:
logger.error("下载文件失败", LOG_COMMAND, e=e)
return FileDownloadResult(
repo_type=RepoType.GITHUB,
repo_name=repo_name,
file_path=file_path,
version=branch,
error_message=str(e),
)
except Exception as e:
logger.error("下载文件失败", LOG_COMMAND, e=e)
return FileDownloadResult(
repo_type=RepoType.GITHUB,
repo_name=repo_name,
file_path=file_path,
version=branch,
error_message=str(e),
)
async def get_file_list(
self,
repo_url: str,
dir_path: str = "",
branch: str = "main",
recursive: bool = False,
) -> list[RepoFileInfo]:
"""
获取仓库文件列表
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
dir_path: 目录路径空字符串表示仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录
返回:
list[RepoFileInfo]: 文件信息列表
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
repo_info.branch = branch
# 获取文件列表
for api in GithubUtils.iter_api_strategies():
try:
await api.parse_repo_info(repo_info)
files = api.get_files(dir_path, True)
result = []
for file_path in files:
# 跳过非当前目录的文件(如果不是递归模式)
if not recursive and "/" in file_path.replace(
dir_path, "", 1
).strip("/"):
continue
is_dir = file_path.endswith("/")
file_info = RepoFileInfo(path=file_path, is_dir=is_dir)
result.append(file_info)
return result
except Exception as e:
logger.debug("使用API策略获取文件列表失败", LOG_COMMAND, e=e)
continue
raise RepoNotFoundError(repo_url)
except Exception as e:
logger.error("获取文件列表失败", LOG_COMMAND, e=e)
return []
async def get_commit_info(
self, repo_url: str, commit_id: str
) -> RepoCommitInfo | None:
"""
获取提交信息
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
commit_id: 提交ID
返回:
Optional[RepoCommitInfo]: 提交信息如果获取失败则返回None
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
# 构建API URL
api_url = f"https://api.github.com/repos/{repo_info.owner}/{repo_info.repo}/commits/{commit_id}"
# 发送请求
resp = await AsyncHttpx.get(
api_url,
timeout=self.config.github.api_timeout,
proxy=self.config.github.proxy,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise RepoNotFoundError(f"{repo_info.owner}/{repo_info.repo}")
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
data = resp.json()
return RepoCommitInfo(
commit_id=data["sha"],
message=data["commit"]["message"],
author=data["commit"]["author"]["name"],
commit_time=datetime.fromisoformat(
data["commit"]["author"]["date"].replace("Z", "+00:00")
),
changed_files=[file["filename"] for file in data.get("files", [])],
)
except Exception as e:
logger.error("获取提交信息失败", LOG_COMMAND, e=e)
return None
async def _get_newest_commit(self, owner: str, repo: str, branch: str) -> str:
"""
获取仓库最新提交ID
参数:
owner: 仓库拥有者
repo: 仓库名称
branch: 分支名称
返回:
str: 提交ID
"""
try:
newest_commit = await RepoInfo.get_newest_commit(owner, repo, branch)
if not newest_commit:
raise RepoNotFoundError(f"{owner}/{repo}")
return newest_commit
except Exception as e:
logger.error("获取最新提交ID失败", LOG_COMMAND, e=e)
raise RepoUpdateError(f"获取最新提交ID失败: {e}")
@cached(ttl=3600)
async def _get_changed_files(
self, owner: str, repo: str, old_commit: str | None, new_commit: str
) -> list[str]:
"""
获取两个提交之间变更的文件列表
参数:
owner: 仓库拥有者
repo: 仓库名称
old_commit: 旧提交ID如果为None则获取所有文件
new_commit: 新提交ID
返回:
list[str]: 变更的文件列表
"""
if not old_commit:
# 如果没有旧提交,则获取仓库中的所有文件
api_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{new_commit}?recursive=1"
resp = await AsyncHttpx.get(
api_url,
timeout=self.config.github.api_timeout,
proxy=self.config.github.proxy,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise RepoNotFoundError(f"{owner}/{repo}")
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
data = resp.json()
return [
item["path"] for item in data.get("tree", []) if item["type"] == "blob"
]
# 如果有旧提交,则获取两个提交之间的差异
api_url = f"https://api.github.com/repos/{owner}/{repo}/compare/{old_commit}...{new_commit}"
resp = await AsyncHttpx.get(
api_url,
timeout=self.config.github.api_timeout,
proxy=self.config.github.proxy,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise RepoNotFoundError(f"{owner}/{repo}")
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
data = resp.json()
return [file["filename"] for file in data.get("files", [])]
async def update_via_git(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
force: bool = False,
*,
repo_type: RepoType | None = None,
owner: str | None = None,
prepare_repo_url: Callable[[str], str] | None = None,
) -> RepoUpdateResult:
"""
通过Git命令直接更新仓库
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地仓库路径
branch: 分支名称
force: 是否强制拉取
返回:
RepoUpdateResult: 更新结果
"""
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
# 调用基类的update_via_git方法
return await super().update_via_git(
repo_url=repo_url,
local_path=local_path,
branch=branch,
force=force,
repo_type=RepoType.GITHUB,
owner=repo_info.owner,
)
async def update(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
use_git: bool = True,
force: bool = False,
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库可选择使用Git命令或API方式
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地保存路径
branch: 分支名称
use_git: 是否使用Git命令更新
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
if use_git:
return await self.update_via_git(repo_url, local_path, branch, force)
else:
return await self.update_repo(
repo_url, local_path, branch, include_patterns, exclude_patterns
)
async def _download_file(
self, repo_info: RepoInfo, file_path: str, local_path: Path
) -> int:
"""
下载文件
参数:
repo_info: 仓库信息
file_path: 文件在仓库中的路径
local_path: 本地保存路径
返回:
int: 文件大小字节
"""
# 确保目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 获取下载URL
download_url = await repo_info.get_raw_download_url(file_path)
# 下载文件
for retry in range(self.config.github.download_retry + 1):
try:
resp = await AsyncHttpx.get(
download_url,
timeout=self.config.github.download_timeout,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise FileNotFoundError(
file_path, f"{repo_info.owner}/{repo_info.repo}"
)
if retry < self.config.github.download_retry:
await asyncio.sleep(1)
continue
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
# 保存文件
return await self.save_file_content(resp.content, local_path)
except (ApiRateLimitError, FileNotFoundError) as e:
# 这些错误不需要重试
raise e
except Exception as e:
if retry < self.config.github.download_retry:
logger.warning("下载文件失败,将重试", LOG_COMMAND, e=e)
await asyncio.sleep(1)
continue
raise RepoDownloadError("下载文件失败")
raise RepoDownloadError("下载文件失败: 超过最大重试次数")

View File

@ -0,0 +1,89 @@
"""
仓库管理工具的数据模型
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
class RepoType(str, Enum):
"""仓库类型"""
GITHUB = "github"
ALIYUN = "aliyun"
@dataclass
class RepoFileInfo:
"""仓库文件信息"""
# 文件路径
path: str
# 是否是目录
is_dir: bool
# 文件大小(字节)
size: int | None = None
# 最后修改时间
last_modified: datetime | None = None
# 子文件列表
children: list["RepoFileInfo"] = field(default_factory=list)
@dataclass
class RepoCommitInfo:
"""仓库提交信息"""
# 提交ID
commit_id: str
# 提交消息
message: str
# 作者
author: str
# 提交时间
commit_time: datetime
# 变更的文件列表
changed_files: list[str] = field(default_factory=list)
@dataclass
class RepoUpdateResult:
"""仓库更新结果"""
# 仓库类型
repo_type: RepoType
# 仓库名称
repo_name: str
# 仓库拥有者
owner: str
# 旧版本
old_version: str
# 新版本
new_version: str
# 是否成功
success: bool = False
# 错误消息
error_message: str = ""
# 变更的文件列表
changed_files: list[str] = field(default_factory=list)
@dataclass
class FileDownloadResult:
"""文件下载结果"""
# 仓库类型
repo_type: RepoType | None
# 仓库名称
repo_name: str
# 文件路径
file_path: list[tuple[str, Path]] | str
# 版本
version: str
# 是否成功
success: bool = False
# 文件大小(字节)
file_size: int = 0
# 错误消息
error_message: str = ""

View File

@ -0,0 +1,135 @@
"""
仓库管理工具的工具函数
"""
import asyncio
from pathlib import Path
import re
from zhenxun.services.log import logger
from .config import LOG_COMMAND
async def check_git() -> bool:
"""
检查环境变量中是否存在 git
返回:
bool: 是否存在git命令
"""
try:
process = await asyncio.create_subprocess_shell(
"git --version",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, _ = await process.communicate()
return bool(stdout)
except Exception as e:
logger.error("检查git命令失败", LOG_COMMAND, e=e)
return False
async def clean_git(cwd: Path):
"""
清理git仓库
参数:
cwd: 工作目录
"""
await run_git_command("reset --hard", cwd)
await run_git_command("clean -xdf", cwd)
async def run_git_command(
command: str, cwd: Path | None = None
) -> tuple[bool, str, str]:
"""
运行git命令
参数:
command: 命令
cwd: 工作目录
返回:
tuple[bool, str, str]: (是否成功, 标准输出, 标准错误)
"""
try:
full_command = f"git {command}"
# 将Path对象转换为字符串
cwd_str = str(cwd) if cwd else None
process = await asyncio.create_subprocess_shell(
full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd_str,
)
stdout_bytes, stderr_bytes = await process.communicate()
stdout = stdout_bytes.decode("utf-8").strip()
stderr = stderr_bytes.decode("utf-8").strip()
return process.returncode == 0, stdout, stderr
except Exception as e:
logger.error(f"运行git命令失败: {command}, 错误: {e}")
return False, "", str(e)
def glob_to_regex(pattern: str) -> str:
"""
将glob模式转换为正则表达式
参数:
pattern: glob模式 "*.py"
返回:
str: 正则表达式
"""
# 转义特殊字符
regex = re.escape(pattern)
# 替换glob通配符
regex = regex.replace(r"\*\*", ".*") # ** -> .*
regex = regex.replace(r"\*", "[^/]*") # * -> [^/]*
regex = regex.replace(r"\?", "[^/]") # ? -> [^/]
# 添加开始和结束标记
regex = f"^{regex}$"
return regex
def filter_files(
files: list[str],
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> list[str]:
"""
过滤文件列表
参数:
files: 文件列表
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
list[str]: 过滤后的文件列表
"""
result = files.copy()
# 应用包含模式
if include_patterns:
included = []
for pattern in include_patterns:
regex_pattern = glob_to_regex(pattern)
included.extend(file for file in result if re.match(regex_pattern, file))
result = included
# 应用排除模式
if exclude_patterns:
for pattern in exclude_patterns:
regex_pattern = glob_to_regex(pattern)
result = [file for file in result if not re.match(regex_pattern, file)]
return result