diff --git a/.env.dev b/.env.example similarity index 100% rename from .env.dev rename to .env.example diff --git a/.gitignore b/.gitignore index 5f5dc24d..24fa1ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -144,4 +144,6 @@ log/ backup/ .idea/ resources/ -.vscode/launch.json \ No newline at end of file +.vscode/launch.json + +./.env.dev \ No newline at end of file diff --git a/tests/builtin_plugins/auto_update/test_check_update.py b/tests/builtin_plugins/auto_update/test_check_update.py index 8a505401..c40e0cb6 100644 --- a/tests/builtin_plugins/auto_update/test_check_update.py +++ b/tests/builtin_plugins/auto_update/test_check_update.py @@ -65,26 +65,29 @@ def init_mocked_api(mocked_api: MockRouter) -> None: tar_buffer = io.BytesIO() zip_bytes = io.BytesIO() - from zhenxun.builtin_plugins.auto_update.config import ( - PYPROJECT_FILE_STRING, - PYPROJECT_LOCK_FILE_STRING, - REPLACE_FOLDERS, - REQ_TXT_FILE_STRING, - ) + from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager # 指定要添加到压缩文件中的文件路径列表 file_paths: list[str] = [ - PYPROJECT_FILE_STRING, - PYPROJECT_LOCK_FILE_STRING, - REQ_TXT_FILE_STRING, + ZhenxunRepoManager.config.PYPROJECT_FILE_STRING, + ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING, + ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING, ] # 打开一个tarfile对象,写入到上面创建的BytesIO对象中 with tarfile.open(mode="w:gz", fileobj=tar_buffer) as tar: - add_files_and_folders_to_tar(tar, file_paths, folders=REPLACE_FOLDERS) + add_files_and_folders_to_tar( + tar, + file_paths, + folders=ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS, + ) with zipfile.ZipFile(zip_bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zipf: - add_files_and_folders_to_zip(zipf, file_paths, folders=REPLACE_FOLDERS) + add_files_and_folders_to_zip( + zipf, + file_paths, + folders=ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS, + ) mocked_api.get( url="https://codeload.github.com/HibiKier/zhenxun_bot/legacy.tar.gz/refs/tags/v0.2.2", @@ -199,52 +202,47 @@ def add_directory_to_tar(tarinfo, tar): def init_mocker_path(mocker: MockerFixture, tmp_path: Path): - from zhenxun.builtin_plugins.auto_update.config import ( - PYPROJECT_FILE_STRING, - PYPROJECT_LOCK_FILE_STRING, - REQ_TXT_FILE_STRING, - VERSION_FILE_STRING, - ) + from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.install_requirement", + "zhenxun.utils.manager.virtual_env_package_manager.VirtualEnvPackageManager.install_requirement", return_value=None, ) mock_tmp_path = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.TMP_PATH", + "zhenxun.configs.path_config.TEMP_PATH", new=tmp_path / "auto_update", ) mock_base_path = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.BASE_PATH", + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_CODE_PATH", new=tmp_path / "zhenxun", ) mock_backup_path = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.BACKUP_PATH", + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_BACKUP_PATH", new=tmp_path / "backup", ) mock_download_gz_file = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.DOWNLOAD_GZ_FILE", + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_DOWNLOAD_FILE", new=mock_tmp_path / "download_latest_file.tar.gz", ) mock_download_zip_file = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.DOWNLOAD_ZIP_FILE", + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_UNZIP_PATH", new=mock_tmp_path / "download_latest_file.zip", ) mock_pyproject_file = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.PYPROJECT_FILE", - new=tmp_path / PYPROJECT_FILE_STRING, + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.PYPROJECT_FILE", + new=tmp_path / ZhenxunRepoManager.config.PYPROJECT_FILE_STRING, ) mock_pyproject_lock_file = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.PYPROJECT_LOCK_FILE", - new=tmp_path / PYPROJECT_LOCK_FILE_STRING, + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE", + new=tmp_path / ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING, ) mock_req_txt_file = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.REQ_TXT_FILE", - new=tmp_path / REQ_TXT_FILE_STRING, + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.REQUIREMENTS_FILE", + new=tmp_path / ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING, ) mock_version_file = mocker.patch( - "zhenxun.builtin_plugins.auto_update._data_source.VERSION_FILE", - new=tmp_path / VERSION_FILE_STRING, + "zhenxun.utils.manager.zhenxun_repo_manager.ZhenxunRepoManager.config.ZHENXUN_BOT_VERSION_FILE_STRING", + new=tmp_path / ZhenxunRepoManager.config.ZHENXUN_BOT_VERSION_FILE_STRING, ) open(mock_version_file, "w").write("__version__: v0.2.2") return ( @@ -271,12 +269,7 @@ async def test_check_update_release( 测试检查更新(release) """ from zhenxun.builtin_plugins.auto_update import _matcher - from zhenxun.builtin_plugins.auto_update.config import ( - PYPROJECT_FILE_STRING, - PYPROJECT_LOCK_FILE_STRING, - REPLACE_FOLDERS, - REQ_TXT_FILE_STRING, - ) + from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager init_mocked_api(mocked_api=mocked_api) @@ -295,7 +288,7 @@ async def test_check_update_release( # 确保目录下有一个子目录,以便 os.listdir() 能返回一个目录名 mock_tmp_path.mkdir(parents=True, exist_ok=True) - for folder in REPLACE_FOLDERS: + for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS: (mock_base_path / folder).mkdir(parents=True, exist_ok=True) mock_pyproject_file.write_bytes(b"") @@ -305,7 +298,7 @@ async def test_check_update_release( async with app.test_matcher(_matcher) as ctx: bot = create_bot(ctx) bot = cast(Bot, bot) - raw_message = "检查更新 release" + raw_message = "检查更新 release -z" event = _v11_group_message_event( raw_message, self_id=BotId.QQ_BOT, @@ -324,14 +317,14 @@ async def test_check_update_release( ctx.should_call_api( "send_msg", _v11_private_message_send( - message="检测真寻已更新,版本更新:v0.2.2 -> v0.2.2\n开始更新...", + message="检测真寻已更新,当前版本:v0.2.2\n开始更新...", user_id=UserId.SUPERUSER, ), ) ctx.should_call_send( event=event, message=Message( - "版本更新完成\n版本: v0.2.2 -> v0.2.2\n请重新启动真寻以完成更新!" + "版本更新完成!\n版本: v0.2.2 -> v0.2.2\n请重新启动真寻以完成更新!" ), result=None, bot=bot, @@ -340,9 +333,13 @@ async def test_check_update_release( assert mocked_api["release_latest"].called assert mocked_api["release_download_url_redirect"].called - assert (mock_backup_path / PYPROJECT_FILE_STRING).exists() - assert (mock_backup_path / PYPROJECT_LOCK_FILE_STRING).exists() - assert (mock_backup_path / REQ_TXT_FILE_STRING).exists() + assert (mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_FILE_STRING).exists() + assert ( + mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING + ).exists() + assert ( + mock_backup_path / ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING + ).exists() assert not mock_download_gz_file.exists() assert not mock_download_zip_file.exists() @@ -351,9 +348,9 @@ async def test_check_update_release( assert mock_pyproject_lock_file.read_bytes() == b"new" assert mock_req_txt_file.read_bytes() == b"new" - for folder in REPLACE_FOLDERS: + for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS: assert not (mock_base_path / folder).exists() - for folder in REPLACE_FOLDERS: + for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS: assert (mock_backup_path / folder).exists() @@ -368,12 +365,7 @@ async def test_check_update_main( 测试检查更新(正式环境) """ from zhenxun.builtin_plugins.auto_update import _matcher - from zhenxun.builtin_plugins.auto_update.config import ( - PYPROJECT_FILE_STRING, - PYPROJECT_LOCK_FILE_STRING, - REPLACE_FOLDERS, - REQ_TXT_FILE_STRING, - ) + from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager init_mocked_api(mocked_api=mocked_api) @@ -391,7 +383,7 @@ async def test_check_update_main( # 确保目录下有一个子目录,以便 os.listdir() 能返回一个目录名 mock_tmp_path.mkdir(parents=True, exist_ok=True) - for folder in REPLACE_FOLDERS: + for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS: (mock_base_path / folder).mkdir(parents=True, exist_ok=True) mock_pyproject_file.write_bytes(b"") @@ -401,7 +393,7 @@ async def test_check_update_main( async with app.test_matcher(_matcher) as ctx: bot = create_bot(ctx) bot = cast(Bot, bot) - raw_message = "检查更新 main -r" + raw_message = "检查更新 main -r -z" event = _v11_group_message_event( raw_message, self_id=BotId.QQ_BOT, @@ -420,27 +412,30 @@ async def test_check_update_main( ctx.should_call_api( "send_msg", _v11_private_message_send( - message="检测真寻已更新,版本更新:v0.2.2 -> v0.2.2-e6f17c4\n" - "开始更新...", + message="检测真寻已更新,当前版本:v0.2.2\n开始更新...", user_id=UserId.SUPERUSER, ), ) ctx.should_call_send( event=event, message=Message( - "版本更新完成\n" + "版本更新完成!\n" "版本: v0.2.2 -> v0.2.2-e6f17c4\n" "请重新启动真寻以完成更新!\n" - "资源文件更新成功!" + "真寻资源更新完成!" ), result=None, bot=bot, ) ctx.should_finished(_matcher) assert mocked_api["main_download_url"].called - assert (mock_backup_path / PYPROJECT_FILE_STRING).exists() - assert (mock_backup_path / PYPROJECT_LOCK_FILE_STRING).exists() - assert (mock_backup_path / REQ_TXT_FILE_STRING).exists() + assert (mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_FILE_STRING).exists() + assert ( + mock_backup_path / ZhenxunRepoManager.config.PYPROJECT_LOCK_FILE_STRING + ).exists() + assert ( + mock_backup_path / ZhenxunRepoManager.config.REQUIREMENTS_FILE_STRING + ).exists() assert not mock_download_gz_file.exists() assert not mock_download_zip_file.exists() @@ -449,7 +444,7 @@ async def test_check_update_main( assert mock_pyproject_lock_file.read_bytes() == b"new" assert mock_req_txt_file.read_bytes() == b"new" - for folder in REPLACE_FOLDERS: + for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS: assert (mock_base_path / folder).exists() - for folder in REPLACE_FOLDERS: + for folder in ZhenxunRepoManager.config.ZHENXUN_BOT_UPDATE_FOLDERS: assert (mock_backup_path / folder).exists() diff --git a/tests/builtin_plugins/plugin_store/test_add_plugin.py b/tests/builtin_plugins/plugin_store/test_add_plugin.py index 5a0edab8..1e30f3b1 100644 --- a/tests/builtin_plugins/plugin_store/test_add_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_add_plugin.py @@ -15,7 +15,7 @@ from tests.config import BotId, GroupId, MessageId, UserId from tests.utils import _v11_group_message_event -@pytest.mark.parametrize("package_api", ["jsd", "gh"]) +@pytest.mark.parametrize("package_api", ["gh"]) @pytest.mark.parametrize("is_commit", [True, False]) async def test_add_plugin_basic( package_api: str, @@ -37,18 +37,14 @@ async def test_add_plugin_basic( new=tmp_path / "zhenxun", ) - if package_api != "jsd": - mocked_api["zhenxun_bot_plugins_metadata"].respond(404) - if package_api != "gh": - mocked_api["zhenxun_bot_plugins_tree"].respond(404) - + mocked_api["zhenxun_bot_plugins_metadata"].respond(404) if not is_commit: mocked_api["zhenxun_bot_plugins_commit"].respond(404) mocked_api["zhenxun_bot_plugins_commit_proxy"].respond(404) mocked_api["zhenxun_bot_plugins_index_commit"].respond(404) mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404) - plugin_id = 1 + plugin_id = "search_image" async with app.test_matcher(_matcher) as ctx: bot = create_bot(ctx) @@ -65,7 +61,7 @@ async def test_add_plugin_basic( ctx.receive_event(bot=bot, event=event) ctx.should_call_send( event=event, - message=Message(message=f"正在添加插件 Id: {plugin_id}"), + message=Message(message=f"正在添加插件 Module: {plugin_id}"), result=None, bot=bot, ) @@ -86,7 +82,7 @@ async def test_add_plugin_basic( assert (mock_base_path / "plugins" / "search_image" / "__init__.py").is_file() -@pytest.mark.parametrize("package_api", ["jsd", "gh"]) +@pytest.mark.parametrize("package_api", ["gh"]) @pytest.mark.parametrize("is_commit", [True, False]) async def test_add_plugin_basic_commit_version( package_api: str, @@ -108,17 +104,13 @@ async def test_add_plugin_basic_commit_version( new=tmp_path / "zhenxun", ) - if package_api != "jsd": - mocked_api["zhenxun_bot_plugins_metadata_commit"].respond(404) - if package_api != "gh": - mocked_api["zhenxun_bot_plugins_tree_commit"].respond(404) - + mocked_api["zhenxun_bot_plugins_metadata_commit"].respond(404) if not is_commit: mocked_api["zhenxun_bot_plugins_commit"].respond(404) mocked_api["zhenxun_bot_plugins_commit_proxy"].respond(404) mocked_api["zhenxun_bot_plugins_index_commit"].respond(404) mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404) - plugin_id = 3 + plugin_id = "bilibili_sub" async with app.test_matcher(_matcher) as ctx: bot = create_bot(ctx) @@ -135,7 +127,7 @@ async def test_add_plugin_basic_commit_version( ctx.receive_event(bot=bot, event=event) ctx.should_call_send( event=event, - message=Message(message=f"正在添加插件 Id: {plugin_id}"), + message=Message(message=f"正在添加插件 Module: {plugin_id}"), result=None, bot=bot, ) @@ -159,7 +151,7 @@ async def test_add_plugin_basic_commit_version( assert (mock_base_path / "plugins" / "bilibili_sub" / "__init__.py").is_file() -@pytest.mark.parametrize("package_api", ["jsd", "gh"]) +@pytest.mark.parametrize("package_api", ["gh"]) @pytest.mark.parametrize("is_commit", [True, False]) async def test_add_plugin_basic_is_not_dir( package_api: str, @@ -181,10 +173,7 @@ async def test_add_plugin_basic_is_not_dir( new=tmp_path / "zhenxun", ) - if package_api != "jsd": - mocked_api["zhenxun_bot_plugins_metadata"].respond(404) - if package_api != "gh": - mocked_api["zhenxun_bot_plugins_tree"].respond(404) + mocked_api["zhenxun_bot_plugins_metadata"].respond(404) if not is_commit: mocked_api["zhenxun_bot_plugins_commit"].respond(404) @@ -192,7 +181,7 @@ async def test_add_plugin_basic_is_not_dir( mocked_api["zhenxun_bot_plugins_index_commit"].respond(404) mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404) - plugin_id = 0 + plugin_id = "jitang" async with app.test_matcher(_matcher) as ctx: bot = create_bot(ctx) @@ -209,7 +198,7 @@ async def test_add_plugin_basic_is_not_dir( ctx.receive_event(bot=bot, event=event) ctx.should_call_send( event=event, - message=Message(message=f"正在添加插件 Id: {plugin_id}"), + message=Message(message=f"正在添加插件 Module: {plugin_id}"), result=None, bot=bot, ) @@ -230,7 +219,7 @@ async def test_add_plugin_basic_is_not_dir( assert (mock_base_path / "plugins" / "alapi" / "jitang.py").is_file() -@pytest.mark.parametrize("package_api", ["jsd", "gh"]) +@pytest.mark.parametrize("package_api", ["gh"]) @pytest.mark.parametrize("is_commit", [True, False]) async def test_add_plugin_extra( package_api: str, @@ -252,10 +241,7 @@ async def test_add_plugin_extra( new=tmp_path / "zhenxun", ) - if package_api != "jsd": - mocked_api["zhenxun_github_sub_metadata"].respond(404) - if package_api != "gh": - mocked_api["zhenxun_github_sub_tree"].respond(404) + mocked_api["zhenxun_github_sub_metadata"].respond(404) if not is_commit: mocked_api["zhenxun_github_sub_commit"].respond(404) @@ -265,7 +251,7 @@ async def test_add_plugin_extra( mocked_api["zhenxun_bot_plugins_index_commit"].respond(404) mocked_api["zhenxun_bot_plugins_index_commit_proxy"].respond(404) - plugin_id = 4 + plugin_id = "github_sub" async with app.test_matcher(_matcher) as ctx: bot = create_bot(ctx) @@ -282,7 +268,7 @@ async def test_add_plugin_extra( ctx.receive_event(bot=bot, event=event) ctx.should_call_send( event=event, - message=Message(message=f"正在添加插件 Id: {plugin_id}"), + message=Message(message=f"正在添加插件 Module: {plugin_id}"), result=None, bot=bot, ) @@ -339,7 +325,7 @@ async def test_plugin_not_exist_add( ) ctx.should_call_send( event=event, - message=Message(message="插件ID不存在..."), + message=Message(message="添加插件 Id: -1 失败 e: 插件ID不存在..."), result=None, bot=bot, ) @@ -385,7 +371,9 @@ async def test_add_plugin_exist( ) ctx.should_call_send( event=event, - message=Message(message="插件 识图 已安装,无需重复安装"), + message=Message( + message="添加插件 Id: 1 失败 e: 插件 识图 已安装,无需重复安装" + ), result=None, bot=bot, ) diff --git a/tests/builtin_plugins/plugin_store/test_plugin_store.py b/tests/builtin_plugins/plugin_store/test_plugin_store.py deleted file mode 100644 index 4e8eae16..00000000 --- a/tests/builtin_plugins/plugin_store/test_plugin_store.py +++ /dev/null @@ -1,140 +0,0 @@ -from collections.abc import Callable -from pathlib import Path -from typing import cast - -from nonebot.adapters.onebot.v11 import Bot, Message -from nonebot.adapters.onebot.v11.event import GroupMessageEvent -from nonebug import App -from pytest_mock import MockerFixture -from respx import MockRouter - -from tests.builtin_plugins.plugin_store.utils import init_mocked_api -from tests.config import BotId, GroupId, MessageId, UserId -from tests.utils import _v11_group_message_event - - -async def test_plugin_store( - app: App, - mocker: MockerFixture, - mocked_api: MockRouter, - create_bot: Callable, - tmp_path: Path, -) -> None: - """ - 测试插件商店 - """ - from zhenxun.builtin_plugins.plugin_store import _matcher - from zhenxun.builtin_plugins.plugin_store.data_source import row_style - - init_mocked_api(mocked_api=mocked_api) - - mock_table_page = mocker.patch( - "zhenxun.builtin_plugins.plugin_store.data_source.ImageTemplate.table_page" - ) - mock_table_page_return = mocker.AsyncMock() - mock_table_page.return_value = mock_table_page_return - - mock_build_message = mocker.patch( - "zhenxun.builtin_plugins.plugin_store.MessageUtils.build_message" - ) - mock_build_message_return = mocker.AsyncMock() - mock_build_message.return_value = mock_build_message_return - - async with app.test_matcher(_matcher) as ctx: - bot = create_bot(ctx) - bot: Bot = cast(Bot, bot) - raw_message = "插件商店" - event: GroupMessageEvent = _v11_group_message_event( - message=raw_message, - self_id=BotId.QQ_BOT, - user_id=UserId.SUPERUSER, - group_id=GroupId.GROUP_ID_LEVEL_5, - message_id=MessageId.MESSAGE_ID_3, - to_me=True, - ) - ctx.receive_event(bot=bot, event=event) - mock_table_page.assert_awaited_once_with( - "插件列表", - "通过添加/移除插件 ID 来管理插件", - ["-", "ID", "名称", "简介", "作者", "版本", "类型"], - [ - ["", 0, "鸡汤", "喏,亲手为你煮的鸡汤", "HibiKier", "0.1", "普通插件"], - ["", 1, "识图", "以图搜图,看破本源", "HibiKier", "0.1", "普通插件"], - ["", 2, "网易云热评", "生了个人,我很抱歉", "HibiKier", "0.1", "普通插件"], - [ - "", - 3, - "B站订阅", - "非常便利的B站订阅通知", - "HibiKier", - "0.3-b101fbc", - "普通插件", - ], - [ - "", - 4, - "github订阅", - "订阅github用户或仓库", - "xuanerwa", - "0.7", - "普通插件", - ], - [ - "", - 5, - "Minecraft查服", - "Minecraft服务器状态查询,支持IPv6", - "molanp", - "1.13", - "普通插件", - ], - ], - text_style=row_style, - ) - mock_build_message.assert_called_once_with(mock_table_page_return) - mock_build_message_return.send.assert_awaited_once() - - assert mocked_api["basic_plugins"].called - assert mocked_api["extra_plugins"].called - - -async def test_plugin_store_fail( - app: App, - mocker: MockerFixture, - mocked_api: MockRouter, - create_bot: Callable, - tmp_path: Path, -) -> None: - """ - 测试插件商店 - """ - from zhenxun.builtin_plugins.plugin_store import _matcher - - init_mocked_api(mocked_api=mocked_api) - mocked_api.get( - "https://raw.githubusercontent.com/zhenxun-org/zhenxun_bot_plugins/b101fbc/plugins.json", - name="basic_plugins", - ).respond(404) - - async with app.test_matcher(_matcher) as ctx: - bot = create_bot(ctx) - bot: Bot = cast(Bot, bot) - raw_message = "插件商店" - event: GroupMessageEvent = _v11_group_message_event( - message=raw_message, - self_id=BotId.QQ_BOT, - user_id=UserId.SUPERUSER, - group_id=GroupId.GROUP_ID_LEVEL_5, - message_id=MessageId.MESSAGE_ID_3, - to_me=True, - ) - ctx.receive_event(bot=bot, event=event) - ctx.should_call_send( - event=event, - message=Message("获取插件列表失败..."), - result=None, - exception=None, - bot=bot, - ) - - assert mocked_api["basic_plugins"].called diff --git a/tests/builtin_plugins/plugin_store/test_remove_plugin.py b/tests/builtin_plugins/plugin_store/test_remove_plugin.py index 4d5e3ab1..fe2f92a9 100644 --- a/tests/builtin_plugins/plugin_store/test_remove_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_remove_plugin.py @@ -96,7 +96,7 @@ async def test_plugin_not_exist_remove( ctx.receive_event(bot=bot, event=event) ctx.should_call_send( event=event, - message=Message(message="插件ID不存在..."), + message=Message(message="移除插件 Id: -1 失败 e: 插件ID不存在..."), result=None, bot=bot, ) diff --git a/tests/builtin_plugins/plugin_store/test_update_plugin.py b/tests/builtin_plugins/plugin_store/test_update_plugin.py index 2cb88d1b..39412de9 100644 --- a/tests/builtin_plugins/plugin_store/test_update_plugin.py +++ b/tests/builtin_plugins/plugin_store/test_update_plugin.py @@ -158,7 +158,7 @@ async def test_plugin_not_exist_update( ) ctx.should_call_send( event=event, - message=Message(message="插件ID不存在..."), + message=Message(message="更新插件 Id: -1 失败 e: 插件ID不存在..."), result=None, bot=bot, ) @@ -200,7 +200,9 @@ async def test_update_plugin_not_install( ) ctx.should_call_send( event=event, - message=Message(message="插件 识图 未安装,无法更新"), + message=Message( + message="更新插件 Id: 1 失败 e: 插件 识图 未安装,无法更新" + ), result=None, bot=bot, ) diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index a5aa7a4b..825e23b1 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -17,7 +17,7 @@ from zhenxun.models.user_console import UserConsole from zhenxun.services.log import logger from zhenxun.utils.decorator.shop import shop_register from zhenxun.utils.manager.priority_manager import PriorityLifecycle -from zhenxun.utils.manager.resource_manager import ResourceManager +from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager from zhenxun.utils.platform import PlatformUtils driver: Driver = nonebot.get_driver() @@ -85,7 +85,8 @@ from bag_users t1 @PriorityLifecycle.on_startup(priority=5) async def _(): - await ResourceManager.init_resources() + if not ZhenxunRepoManager.check_resources_exists(): + await ZhenxunRepoManager.resources_update(branch="test") """签到与用户的数据迁移""" if goods_list := await GoodsInfo.filter(uuid__isnull=True).all(): for goods in goods_list: diff --git a/zhenxun/builtin_plugins/admin/group_member_update/_data_source.py b/zhenxun/builtin_plugins/admin/group_member_update/_data_source.py index 5c5f1d72..977cad35 100644 --- a/zhenxun/builtin_plugins/admin/group_member_update/_data_source.py +++ b/zhenxun/builtin_plugins/admin/group_member_update/_data_source.py @@ -104,25 +104,16 @@ class MemberUpdateManage: exist_member_list.append(member.id) if data_list[0]: try: - await GroupInfoUser.bulk_create(data_list[0], 30) + await GroupInfoUser.bulk_create( + data_list[0], 30, ignore_conflicts=True + ) logger.debug( f"创建用户数据 {len(data_list[0])} 条", "更新群组成员信息", target=group_id, ) except Exception as e: - logger.error( - f"批量创建用户数据失败: {e},开始进行逐个存储", - "更新群组成员信息", - ) - for u in data_list[0]: - try: - await u.save() - except Exception as e: - logger.error( - f"创建用户 {u.user_name}({u.user_id}) 数据失败: {e}", - "更新群组成员信息", - ) + logger.error("批量创建用户数据失败", "更新群组成员信息", e=e) if data_list[1]: await GroupInfoUser.bulk_update(data_list[1], ["user_name"], 30) logger.debug( diff --git a/zhenxun/builtin_plugins/auto_update/__init__.py b/zhenxun/builtin_plugins/auto_update/__init__.py index 764fc39c..0595066c 100644 --- a/zhenxun/builtin_plugins/auto_update/__init__.py +++ b/zhenxun/builtin_plugins/auto_update/__init__.py @@ -16,10 +16,6 @@ from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.utils import PluginExtraData from zhenxun.services.log import logger from zhenxun.utils.enum import PluginType -from zhenxun.utils.manager.resource_manager import ( - DownloadResourceException, - ResourceManager, -) from zhenxun.utils.message import MessageUtils from ._data_source import UpdateManager @@ -32,15 +28,23 @@ __plugin_meta__ = PluginMetadata( 检查更新真寻最新版本,包括了自动更新 资源文件大小一般在130mb左右,除非必须更新一般仅更新代码文件 指令: - 检查更新 [main|release|resource|webui] ?[-r] + 检查更新 [main|release|resource|webui] ?[-r] ?[-f] ?[-z] ?[-t] main: main分支 release: 最新release resource: 资源文件 webui: webui文件 -r: 下载资源文件,一般在更新main或release时使用 + -f: 强制更新,一般用于更新main时使用(仅git更新时有效) + -s: 更新源,为 git 或 ali(默认使用ali) + -z: 下载zip文件进行更新(仅git有效) + -t: 更新方式,git或download(默认使用git) + git: 使用git pull(推荐) + download: 通过commit hash比较文件后下载更新(仅git有效) + 示例: 检查更新 main 检查更新 main -r + 检查更新 main -f 检查更新 release -r 检查更新 resource 检查更新 webui @@ -57,6 +61,9 @@ _matcher = on_alconna( "检查更新", Args["ver_type?", ["main", "release", "resource", "webui"]], Option("-r|--resource", action=store_true, help_text="下载资源文件"), + Option("-f|--force", action=store_true, help_text="强制更新"), + Option("-s", Args["source?", ["git", "ali"]], help_text="更新源"), + Option("-z|--zip", action=store_true, help_text="下载zip文件"), ), priority=1, block=True, @@ -71,30 +78,56 @@ async def _( session: Uninfo, ver_type: Match[str], resource: Query[bool] = Query("resource", False), + force: Query[bool] = Query("force", False), + source: Query[str] = Query("source", "ali"), + zip: Query[bool] = Query("zip", False), ): result = "" await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True) - if ver_type.result in {"main", "release"}: + ver_type_str = ver_type.result + source_str = source.result + if ver_type_str in {"main", "release"}: if not ver_type.available: - result = await UpdateManager.check_version() + result += await UpdateManager.check_version() logger.info("查看当前版本...", "检查更新", session=session) await MessageUtils.build_message(result).finish() try: - result = await UpdateManager.update(bot, session.user.id, ver_type.result) + result += await UpdateManager.update_zhenxun( + bot, + session.user.id, + ver_type_str, # type: ignore + force.result, + source_str, # type: ignore + zip.result, + ) + await MessageUtils.build_message(result).finish(reply_to=True) except Exception as e: logger.error("版本更新失败...", "检查更新", session=session, e=e) await MessageUtils.build_message(f"更新版本失败...e: {e}").finish() elif ver_type.result == "webui": - result = await UpdateManager.update_webui() + if zip.result: + source_str = None + try: + result += await UpdateManager.update_webui( + source_str, # type: ignore + "test", + True, + ) + except Exception as e: + logger.error("WebUI更新失败...", "检查更新", session=session, e=e) + result += "\nWebUI更新错误..." if resource.result or ver_type.result == "resource": try: - await ResourceManager.init_resources(True) - result += "\n资源文件更新成功!" - except DownloadResourceException: - result += "\n资源更新下载失败..." + if zip.result: + source_str = None + result += await UpdateManager.update_resources( + source_str, # type: ignore + "main", + force.result, + ) except Exception as e: logger.error("资源更新下载失败...", "检查更新", session=session, e=e) - result += "\n资源更新未知错误..." + result += "\n资源更新错误..." if result: await MessageUtils.build_message(result.strip()).finish() await MessageUtils.build_message("更新版本失败...").finish() diff --git a/zhenxun/builtin_plugins/auto_update/_data_source.py b/zhenxun/builtin_plugins/auto_update/_data_source.py index 5fbeaa5d..f7752062 100644 --- a/zhenxun/builtin_plugins/auto_update/_data_source.py +++ b/zhenxun/builtin_plugins/auto_update/_data_source.py @@ -1,170 +1,16 @@ -import os -import shutil -import subprocess -import tarfile -import zipfile +from typing import Literal from nonebot.adapters import Bot -from nonebot.utils import run_sync -from zhenxun.configs.path_config import DATA_PATH from zhenxun.services.log import logger -from zhenxun.utils.github_utils import GithubUtils -from zhenxun.utils.github_utils.models import RepoInfo -from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager +from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager from zhenxun.utils.platform import PlatformUtils -from .config import ( - BACKUP_PATH, - BASE_PATH, - BASE_PATH_STRING, - COMMAND, - DEFAULT_GITHUB_URL, - DOWNLOAD_GZ_FILE, - DOWNLOAD_ZIP_FILE, - PYPROJECT_FILE, - PYPROJECT_FILE_STRING, - PYPROJECT_LOCK_FILE, - PYPROJECT_LOCK_FILE_STRING, - RELEASE_URL, - REPLACE_FOLDERS, - REQ_TXT_FILE, - REQ_TXT_FILE_STRING, - TMP_PATH, - VERSION_FILE, -) - - -def install_requirement(): - requirement_path = (REQ_TXT_FILE).absolute() - - if not requirement_path.exists(): - logger.debug( - f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", COMMAND - ) - return - try: - result = subprocess.run( - ["pip", "install", "-r", str(requirement_path)], - check=True, - capture_output=True, - text=True, - ) - logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", COMMAND) - except subprocess.CalledProcessError as e: - logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", COMMAND, e=e) - - -@run_sync -def _file_handle(latest_version: str | None): - """文件移动操作 - - 参数: - latest_version: 版本号 - """ - BACKUP_PATH.mkdir(exist_ok=True, parents=True) - logger.debug("开始解压文件压缩包...", COMMAND) - download_file = DOWNLOAD_GZ_FILE - if DOWNLOAD_GZ_FILE.exists(): - tf = tarfile.open(DOWNLOAD_GZ_FILE) - else: - download_file = DOWNLOAD_ZIP_FILE - tf = zipfile.ZipFile(DOWNLOAD_ZIP_FILE) - tf.extractall(TMP_PATH) - logger.debug("解压文件压缩包完成...", COMMAND) - download_file_path = TMP_PATH / next( - x for x in os.listdir(TMP_PATH) if (TMP_PATH / x).is_dir() - ) - _pyproject = download_file_path / PYPROJECT_FILE_STRING - _lock_file = download_file_path / PYPROJECT_LOCK_FILE_STRING - _req_file = download_file_path / REQ_TXT_FILE_STRING - extract_path = download_file_path / BASE_PATH_STRING - target_path = BASE_PATH - if PYPROJECT_FILE.exists(): - logger.debug(f"移除备份文件: {PYPROJECT_FILE}", COMMAND) - shutil.move(PYPROJECT_FILE, BACKUP_PATH / PYPROJECT_FILE_STRING) - if PYPROJECT_LOCK_FILE.exists(): - logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", COMMAND) - shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / PYPROJECT_LOCK_FILE_STRING) - if REQ_TXT_FILE.exists(): - logger.debug(f"移除备份文件: {REQ_TXT_FILE}", COMMAND) - shutil.move(REQ_TXT_FILE, BACKUP_PATH / REQ_TXT_FILE_STRING) - if _pyproject.exists(): - logger.debug("移动文件: pyproject.toml", COMMAND) - shutil.move(_pyproject, PYPROJECT_FILE) - if _lock_file.exists(): - logger.debug("移动文件: poetry.lock", COMMAND) - shutil.move(_lock_file, PYPROJECT_LOCK_FILE) - if _req_file.exists(): - logger.debug("移动文件: requirements.txt", COMMAND) - shutil.move(_req_file, REQ_TXT_FILE) - for folder in REPLACE_FOLDERS: - """移动指定文件夹""" - _dir = BASE_PATH / folder - _backup_dir = BACKUP_PATH / folder - if _backup_dir.exists(): - logger.debug(f"删除备份文件夹 {_backup_dir}", COMMAND) - shutil.rmtree(_backup_dir) - if _dir.exists(): - logger.debug(f"移动旧文件夹 {_dir}", COMMAND) - shutil.move(_dir, _backup_dir) - else: - logger.warning(f"文件夹 {_dir} 不存在,跳过删除", COMMAND) - for folder in REPLACE_FOLDERS: - src_folder_path = extract_path / folder - dest_folder_path = target_path / folder - if src_folder_path.exists(): - logger.debug( - f"移动文件夹: {src_folder_path} -> {dest_folder_path}", COMMAND - ) - shutil.move(src_folder_path, dest_folder_path) - else: - logger.debug(f"源文件夹不存在: {src_folder_path}", COMMAND) - if tf: - tf.close() - if download_file.exists(): - logger.debug(f"删除下载文件: {download_file}", COMMAND) - download_file.unlink() - if extract_path.exists(): - logger.debug(f"删除解压文件夹: {extract_path}", COMMAND) - shutil.rmtree(extract_path) - if TMP_PATH.exists(): - shutil.rmtree(TMP_PATH) - if latest_version: - with open(VERSION_FILE, "w", encoding="utf8") as f: - f.write(f"__version__: {latest_version}") - install_requirement() +from .config import LOG_COMMAND, REQUIREMENTS_FILE, VERSION_FILE class UpdateManager: - @classmethod - async def update_webui(cls) -> str: - from zhenxun.builtin_plugins.web_ui.public.data_source import ( - update_webui_assets, - ) - - WEBUI_PATH = DATA_PATH / "web_ui" / "public" - BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public" - if WEBUI_PATH.exists(): - if BACKUP_PATH.exists(): - logger.debug(f"删除旧的备份webui文件夹 {BACKUP_PATH}", COMMAND) - shutil.rmtree(BACKUP_PATH) - WEBUI_PATH.rename(BACKUP_PATH) - try: - await update_webui_assets() - logger.info("更新webui成功...", COMMAND) - if BACKUP_PATH.exists(): - logger.debug(f"删除旧的webui文件夹 {BACKUP_PATH}", COMMAND) - shutil.rmtree(BACKUP_PATH) - return "Webui更新成功!" - except Exception as e: - logger.error("更新webui失败...", COMMAND, e=e) - if BACKUP_PATH.exists(): - logger.debug(f"恢复旧的webui文件夹 {BACKUP_PATH}", COMMAND) - BACKUP_PATH.rename(WEBUI_PATH) - raise e - return "" - @classmethod async def check_version(cls) -> str: """检查更新版本 @@ -173,75 +19,142 @@ class UpdateManager: str: 更新信息 """ cur_version = cls.__get_version() - data = await cls.__get_latest_data() - if not data: + release_data = await ZhenxunRepoManager.zhenxun_get_latest_releases_data() + if not release_data: return "检查更新获取版本失败..." return ( "检测到当前版本更新\n" f"当前版本:{cur_version}\n" - f"最新版本:{data.get('name')}\n" - f"创建日期:{data.get('created_at')}\n" - f"更新内容:\n{data.get('body')}" + f"最新版本:{release_data.get('name')}\n" + f"创建日期:{release_data.get('created_at')}\n" + f"更新内容:\n{release_data.get('body')}" ) @classmethod - async def update(cls, bot: Bot, user_id: str, version_type: str) -> str: + async def update_webui( + cls, + source: Literal["git", "ali"] | None, + branch: str = "dist", + force: bool = False, + ): + """更新WebUI + + 参数: + source: 更新源 + branch: 分支 + force: 是否强制更新 + + 返回: + str: 返回消息 + """ + if not source: + await ZhenxunRepoManager.webui_zip_update() + return "WebUI更新完成!" + result = await ZhenxunRepoManager.webui_git_update( + source, + branch=branch, + force=force, + ) + if not result.success: + logger.error(f"WebUI更新失败...错误: {result.error_message}", LOG_COMMAND) + return f"WebUI更新失败...错误: {result.error_message}" + return "WebUI更新完成!" + + @classmethod + async def update_resources( + cls, + source: Literal["git", "ali"] | None, + branch: str = "main", + force: bool = False, + ) -> str: + """更新资源 + + 参数: + source: 更新源 + branch: 分支 + force: 是否强制更新 + + 返回: + str: 返回消息 + """ + if not source: + await ZhenxunRepoManager.resources_zip_update() + return "真寻资源更新完成!" + result = await ZhenxunRepoManager.resources_git_update( + source, + branch=branch, + force=force, + ) + if not result.success: + logger.error( + f"真寻资源更新失败...错误: {result.error_message}", LOG_COMMAND + ) + return f"真寻资源更新失败...错误: {result.error_message}" + return "真寻资源更新完成!" + + @classmethod + async def update_zhenxun( + cls, + bot: Bot, + user_id: str, + version_type: Literal["main", "release"], + force: bool, + source: Literal["git", "ali"], + zip: bool, + ) -> str: """更新操作 参数: bot: Bot user_id: 用户id version_type: 更新版本类型 + force: 是否强制更新 + source: 更新源 + zip: 是否下载zip文件 + update_type: 更新方式 返回: str | None: 返回消息 """ - logger.info("开始下载真寻最新版文件....", COMMAND) cur_version = cls.__get_version() - url = None - new_version = None - repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL) - if version_type in {"main"}: - repo_info.branch = version_type - new_version = await cls.__get_version_from_repo(repo_info) - if new_version: - new_version = new_version.split(":")[-1].strip() - url = await repo_info.get_archive_download_urls() - elif version_type == "release": - data = await cls.__get_latest_data() - if not data: - return "获取更新版本失败..." - new_version = data.get("name", "") - url = await repo_info.get_release_source_download_urls_tgz(new_version) - if not url: - return "获取版本下载链接失败..." - if TMP_PATH.exists(): - logger.debug(f"删除临时文件夹 {TMP_PATH}", COMMAND) - shutil.rmtree(TMP_PATH) - logger.debug( - f"开始更新版本:{cur_version} -> {new_version} | 下载链接:{url}", - COMMAND, - ) await PlatformUtils.send_superuser( bot, - f"检测真寻已更新,版本更新:{cur_version} -> {new_version}\n开始更新...", + f"检测真寻已更新,当前版本:{cur_version}\n开始更新...", user_id, ) - download_file = ( - DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE - ) - if await AsyncHttpx.download_file(url, download_file, stream=True): - logger.debug("下载真寻最新版文件完成...", COMMAND) - await _file_handle(new_version) - result = "版本更新完成" + if zip: + new_version = await ZhenxunRepoManager.zhenxun_zip_update(version_type) + await PlatformUtils.send_superuser( + bot, "真寻更新完成,开始安装依赖...", user_id + ) + await VirtualEnvPackageManager.install_requirement(REQUIREMENTS_FILE) return ( - f"{result}\n" - f"版本: {cur_version} -> {new_version}\n" + f"版本更新完成!\n版本: {cur_version} -> {new_version}\n" "请重新启动真寻以完成更新!" ) else: - logger.debug("下载真寻最新版文件失败...", COMMAND) - return "" + result = await ZhenxunRepoManager.zhenxun_git_update( + source, + branch=version_type, + force=force, + ) + if not result.success: + logger.error( + f"真寻版本更新失败...错误: {result.error_message}", + LOG_COMMAND, + ) + return f"版本更新失败...错误: {result.error_message}" + await PlatformUtils.send_superuser( + bot, "真寻更新完成,开始安装依赖...", user_id + ) + await VirtualEnvPackageManager.install_requirement(REQUIREMENTS_FILE) + return ( + f"版本更新完成!\n" + f"版本: {cur_version} -> {result.new_version}\n" + f"变更文件个数: {len(result.changed_files)}" + f"{'' if source == 'git' else '(阿里云更新不支持查看变更文件)'}\n" + "请重新启动真寻以完成更新!" + ) @classmethod def __get_version(cls) -> str: @@ -255,40 +168,3 @@ class UpdateManager: if text := VERSION_FILE.open(encoding="utf8").readline(): _version = text.split(":")[-1].strip() return _version - - @classmethod - async def __get_latest_data(cls) -> dict: - """获取最新版本信息 - - 返回: - dict: 最新版本数据 - """ - for _ in range(3): - try: - res = await AsyncHttpx.get(RELEASE_URL) - if res.status_code == 200: - return res.json() - except TimeoutError: - pass - except Exception as e: - logger.error("检查更新真寻获取版本失败", e=e) - return {} - - @classmethod - async def __get_version_from_repo(cls, repo_info: RepoInfo) -> str: - """从指定分支获取版本号 - - 参数: - branch: 分支名称 - - 返回: - str: 版本号 - """ - version_url = await repo_info.get_raw_download_urls(path="__version__") - try: - res = await AsyncHttpx.get(version_url) - if res.status_code == 200: - return res.text.strip() - except Exception as e: - logger.error(f"获取 {repo_info.branch} 分支版本失败", e=e) - return "未知版本" diff --git a/zhenxun/builtin_plugins/auto_update/config.py b/zhenxun/builtin_plugins/auto_update/config.py index 1516b5e6..85a44c9c 100644 --- a/zhenxun/builtin_plugins/auto_update/config.py +++ b/zhenxun/builtin_plugins/auto_update/config.py @@ -1,38 +1,7 @@ from pathlib import Path -from zhenxun.configs.path_config import TEMP_PATH +LOG_COMMAND = "AutoUpdate" -DEFAULT_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot/tree/main" -RELEASE_URL = "https://api.github.com/repos/HibiKier/zhenxun_bot/releases/latest" +VERSION_FILE = Path() / "__version__" -VERSION_FILE_STRING = "__version__" -VERSION_FILE = Path() / VERSION_FILE_STRING - -PYPROJECT_FILE_STRING = "pyproject.toml" -PYPROJECT_FILE = Path() / PYPROJECT_FILE_STRING -PYPROJECT_LOCK_FILE_STRING = "poetry.lock" -PYPROJECT_LOCK_FILE = Path() / PYPROJECT_LOCK_FILE_STRING -REQ_TXT_FILE_STRING = "requirements.txt" -REQ_TXT_FILE = Path() / REQ_TXT_FILE_STRING - -BASE_PATH_STRING = "zhenxun" -BASE_PATH = Path() / BASE_PATH_STRING - -TMP_PATH = TEMP_PATH / "auto_update" - -BACKUP_PATH = Path() / "backup" - -DOWNLOAD_GZ_FILE_STRING = "download_latest_file.tar.gz" -DOWNLOAD_ZIP_FILE_STRING = "download_latest_file.zip" -DOWNLOAD_GZ_FILE = TMP_PATH / DOWNLOAD_GZ_FILE_STRING -DOWNLOAD_ZIP_FILE = TMP_PATH / DOWNLOAD_ZIP_FILE_STRING - -REPLACE_FOLDERS = [ - "builtin_plugins", - "services", - "utils", - "models", - "configs", -] - -COMMAND = "检查更新" +REQUIREMENTS_FILE = Path() / "requirements.txt" diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py index 9e9c4e0d..cf2c97c7 100644 --- a/zhenxun/builtin_plugins/hooks/auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -148,6 +148,11 @@ async def get_plugin_and_user( user = await with_timeout( user_dao.safe_get_or_none(user_id=user_id), name="get_user" ) + except IntegrityError: + await asyncio.sleep(0.5) + plugin, user = await with_timeout( + asyncio.gather(plugin_task, user_task), name="get_plugin_and_user" + ) if not plugin: raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...") diff --git a/zhenxun/builtin_plugins/plugin_store/__init__.py b/zhenxun/builtin_plugins/plugin_store/__init__.py index 72d6d7dd..3dfde320 100644 --- a/zhenxun/builtin_plugins/plugin_store/__init__.py +++ b/zhenxun/builtin_plugins/plugin_store/__init__.py @@ -84,7 +84,7 @@ async def _(session: EventSession): try: result = await StoreManager.get_plugins_info() logger.info("查看插件列表", "插件商店", session=session) - await MessageUtils.build_message(result).send() + await MessageUtils.build_message([*result]).send() except Exception as e: logger.error(f"查看插件列表失败 e: {e}", "插件商店", session=session, e=e) await MessageUtils.build_message("获取插件列表失败...").send() diff --git a/zhenxun/builtin_plugins/plugin_store/data_source.py b/zhenxun/builtin_plugins/plugin_store/data_source.py index 58fab1a1..4bb3b64f 100644 --- a/zhenxun/builtin_plugins/plugin_store/data_source.py +++ b/zhenxun/builtin_plugins/plugin_store/data_source.py @@ -1,19 +1,19 @@ from pathlib import Path +import random import shutil from aiocache import cached import ujson as json -from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo +from zhenxun.configs.path_config import TEMP_PATH from zhenxun.models.plugin_info import PluginInfo from zhenxun.services.log import logger from zhenxun.services.plugin_init import PluginInitManager -from zhenxun.utils.github_utils import GithubUtils -from zhenxun.utils.github_utils.models import RepoAPI -from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager +from zhenxun.utils.repo_utils import RepoFileManager +from zhenxun.utils.repo_utils.models import RepoFileInfo, RepoType from zhenxun.utils.utils import is_number from .config import ( @@ -22,6 +22,7 @@ from .config import ( EXTRA_GITHUB_URL, LOG_COMMAND, ) +from .exceptions import PluginStoreException def row_style(column: str, text: str) -> RowStyle: @@ -40,73 +41,25 @@ def row_style(column: str, text: str) -> RowStyle: return style -def install_requirement(plugin_path: Path): - requirement_files = ["requirement.txt", "requirements.txt"] - requirement_paths = [plugin_path / file for file in requirement_files] - - if existing_requirements := next( - (path for path in requirement_paths if path.exists()), None - ): - VirtualEnvPackageManager.install_requirement(existing_requirements) - - class StoreManager: - @classmethod - async def get_github_plugins(cls) -> list[StorePluginInfo]: - """获取github插件列表信息 - - 返回: - list[StorePluginInfo]: 插件列表数据 - """ - repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL) - if await repo_info.update_repo_commit(): - logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) - else: - logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) - default_github_url = await repo_info.get_raw_download_urls("plugins.json") - response = await AsyncHttpx.get(default_github_url, check_status_code=200) - if response.status_code == 200: - logger.info("获取github插件列表成功", LOG_COMMAND) - return [StorePluginInfo(**detail) for detail in json.loads(response.text)] - else: - logger.warning( - f"获取github插件列表失败: {response.status_code}", LOG_COMMAND - ) - return [] - - @classmethod - async def get_extra_plugins(cls) -> list[StorePluginInfo]: - """获取额外插件列表信息 - - 返回: - list[StorePluginInfo]: 插件列表数据 - """ - repo_info = GithubUtils.parse_github_url(EXTRA_GITHUB_URL) - if await repo_info.update_repo_commit(): - logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) - else: - logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) - extra_github_url = await repo_info.get_raw_download_urls("plugins.json") - response = await AsyncHttpx.get(extra_github_url, check_status_code=200) - if response.status_code == 200: - return [StorePluginInfo(**detail) for detail in json.loads(response.text)] - else: - logger.warning( - f"获取github扩展插件列表失败: {response.status_code}", LOG_COMMAND - ) - return [] - @classmethod @cached(60) - async def get_data(cls) -> list[StorePluginInfo]: + async def get_data(cls) -> tuple[list[StorePluginInfo], list[StorePluginInfo]]: """获取插件信息数据 返回: - list[StorePluginInfo]: 插件信息数据 + tuple[list[StorePluginInfo], list[StorePluginInfo]]: + 原生插件信息数据,第三方插件信息数据 """ - plugins = await cls.get_github_plugins() - extra_plugins = await cls.get_extra_plugins() - return [*plugins, *extra_plugins] + plugins = await RepoFileManager.get_file_content( + DEFAULT_GITHUB_URL, "plugins.json" + ) + extra_plugins = await RepoFileManager.get_file_content( + EXTRA_GITHUB_URL, "plugins.json", "index" + ) + return [StorePluginInfo(**plugin) for plugin in json.loads(plugins)], [ + StorePluginInfo(**plugin) for plugin in json.loads(extra_plugins) + ] @classmethod def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]): @@ -152,38 +105,105 @@ class StoreManager: return await PluginInfo.filter(load_status=True).values_list(*args) @classmethod - async def get_plugins_info(cls) -> BuildImage | str: + async def get_plugins_info(cls) -> list[BuildImage] | str: """插件列表 返回: BuildImage | str: 返回消息 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() + plugin_list, extra_plugin_list = await cls.get_data() column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"] db_plugin_list = await cls.get_loaded_plugins("module", "version") suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list} - data_list = [ - [ - "已安装" if plugin_info.module in suc_plugin else "", - id, - plugin_info.name, - plugin_info.description, - plugin_info.author, - cls.version_check(plugin_info, suc_plugin), - plugin_info.plugin_type_name, - ] - for id, plugin_info in enumerate(plugin_list) + index = 0 + data_list = [] + extra_data_list = [] + for plugin_info in plugin_list: + data_list.append( + [ + "已安装" if plugin_info.module in suc_plugin else "", + index, + plugin_info.name, + plugin_info.description, + plugin_info.author, + cls.version_check(plugin_info, suc_plugin), + plugin_info.plugin_type_name, + ] + ) + index += 1 + for plugin_info in extra_plugin_list: + extra_data_list.append( + [ + "已安装" if plugin_info.module in suc_plugin else "", + index, + plugin_info.name, + plugin_info.description, + plugin_info.author, + cls.version_check(plugin_info, suc_plugin), + plugin_info.plugin_type_name, + ] + ) + index += 1 + return [ + await ImageTemplate.table_page( + "原生插件列表", + "通过添加/移除插件 ID 来管理插件", + column_name, + data_list, + text_style=row_style, + ), + await ImageTemplate.table_page( + "第三方插件列表", + "通过添加/移除插件 ID 来管理插件", + column_name, + extra_data_list, + text_style=row_style, + ), ] - return await ImageTemplate.table_page( - "插件列表", - "通过添加/移除插件 ID 来管理插件", - column_name, - data_list, - text_style=row_style, - ) @classmethod - async def add_plugin(cls, plugin_id: str) -> str: + async def get_plugin_by_value( + cls, index_or_module: str, is_update: bool = False + ) -> tuple[StorePluginInfo, bool]: + """获取插件信息 + + 参数: + index_or_module: 插件索引或模块名 + is_update: 是否是更新插件 + + 异常: + PluginStoreException: 插件不存在 + PluginStoreException: 插件已安装 + + 返回: + StorePluginInfo: 插件信息 + bool: 是否是外部插件 + """ + plugin_list, extra_plugin_list = await cls.get_data() + plugin_info = None + is_external = False + db_plugin_list = await cls.get_loaded_plugins("module") + plugin_key = await cls._resolve_plugin_key(index_or_module) + for p in plugin_list: + if p.module == plugin_key: + is_external = False + plugin_info = p + break + for p in extra_plugin_list: + if p.module == plugin_key: + is_external = True + plugin_info = p + break + if not plugin_info: + raise PluginStoreException(f"插件不存在: {plugin_key}") + if not is_update and plugin_info.module in [p[0] for p in db_plugin_list]: + raise PluginStoreException(f"插件 {plugin_info.name} 已安装,无需重复安装") + if plugin_info.module not in [p[0] for p in db_plugin_list] and is_update: + raise PluginStoreException(f"插件 {plugin_info.name} 未安装,无法更新") + return plugin_info, is_external + + @classmethod + async def add_plugin(cls, index_or_module: str) -> str: """添加插件 参数: @@ -192,21 +212,9 @@ class StoreManager: 返回: str: 返回消息 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() - try: - plugin_key = await cls._resolve_plugin_key(plugin_id) - except ValueError as e: - return str(e) - db_plugin_list = await cls.get_loaded_plugins("module") - plugin_info = next((p for p in plugin_list if p.module == plugin_key), None) - if plugin_info is None: - return f"未找到插件 {plugin_key}" - if plugin_info.module in [p[0] for p in db_plugin_list]: - return f"插件 {plugin_info.name} 已安装,无需重复安装" - is_external = True + plugin_info, is_external = await cls.get_plugin_by_value(index_or_module) if plugin_info.github_url is None: plugin_info.github_url = DEFAULT_GITHUB_URL - is_external = False version_split = plugin_info.version.split("-") if len(version_split) > 1: github_url_split = plugin_info.github_url.split("/tree/") @@ -228,90 +236,81 @@ class StoreManager: is_dir: bool, is_external: bool = False, ): - repo_api: RepoAPI - repo_info = GithubUtils.parse_github_url(github_url) - if await repo_info.update_repo_commit(): - logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) - else: - logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) - logger.debug(f"成功获取仓库信息: {repo_info}", LOG_COMMAND) - for repo_api in GithubUtils.iter_api_strategies(): - try: - await repo_api.parse_repo_info(repo_info) - break - except Exception as e: - logger.warning( - f"获取插件文件失败 | API类型: {repo_api.strategy}", - LOG_COMMAND, - e=e, - ) - continue - else: - raise ValueError("所有API获取插件文件失败,请检查网络连接") - if module_path == ".": - module_path = "" + """安装插件 + + 参数: + github_url: 仓库地址 + module_path: 模块路径 + is_dir: 是否是文件夹 + is_external: 是否是外部仓库 + """ + repo_type = RepoType.GITHUB if is_external else None replace_module_path = module_path.replace(".", "/") - files = repo_api.get_files( - module_path=replace_module_path + ("" if is_dir else ".py"), - is_dir=is_dir, - ) - download_urls = [await repo_info.get_raw_download_urls(file) for file in files] - base_path = BASE_PATH / "plugins" if is_external else BASE_PATH - base_path = base_path if module_path else base_path / repo_info.repo - download_paths: list[Path | str] = [base_path / file for file in files] - logger.debug(f"插件下载路径: {download_paths}", LOG_COMMAND) - result = await AsyncHttpx.gather_download_file(download_urls, download_paths) - for _id, success in enumerate(result): - if not success: - break + if is_dir: + files = await RepoFileManager.list_directory_files( + github_url, replace_module_path, repo_type=repo_type + ) else: - # 安装依赖 - plugin_path = base_path / "/".join(module_path.split(".")) - try: - req_files = repo_api.get_files( - f"{replace_module_path}/{REQ_TXT_FILE_STRING}", False + files = [RepoFileInfo(path=f"{replace_module_path}.py", is_dir=False)] + local_path = BASE_PATH / "plugins" if is_external else BASE_PATH + files = [file for file in files if not file.is_dir] + download_files = [(file.path, local_path / file.path) for file in files] + await RepoFileManager.download_files( + github_url, download_files, repo_type=repo_type + ) + + requirement_paths = [ + file + for file in files + if file.path.endswith("requirement.txt") + or file.path.endswith("requirements.txt") + ] + + is_install_req = False + for requirement_path in requirement_paths: + requirement_file = local_path / requirement_path.path + if requirement_file.exists(): + is_install_req = True + await VirtualEnvPackageManager.install_requirement(requirement_file) + + if not is_install_req: + # 从仓库根目录查找文件 + rand = random.randint(1, 10000) + requirement_path = TEMP_PATH / f"plugin_store_{rand}_req.txt" + requirements_path = TEMP_PATH / f"plugin_store_{rand}_reqs.txt" + await RepoFileManager.download_files( + github_url, + [ + ("requirement.txt", requirement_path), + ("requirements.txt", requirements_path), + ], + repo_type=repo_type, + ignore_error=True, + ) + if requirement_path.exists(): + logger.info( + f"开始安装插件 {module_path} 依赖文件: {requirement_path}", + LOG_COMMAND, ) - req_files.extend( - repo_api.get_files(f"{replace_module_path}/requirement.txt", False) + await VirtualEnvPackageManager.install_requirement(requirement_path) + if requirements_path.exists(): + logger.info( + f"开始安装插件 {module_path} 依赖文件: {requirements_path}", + LOG_COMMAND, ) - logger.debug(f"获取插件依赖文件列表: {req_files}", LOG_COMMAND) - req_download_urls = [ - await repo_info.get_raw_download_urls(file) for file in req_files - ] - req_paths: list[Path | str] = [plugin_path / file for file in req_files] - logger.debug(f"插件依赖文件下载路径: {req_paths}", LOG_COMMAND) - if req_files: - result = await AsyncHttpx.gather_download_file( - req_download_urls, req_paths - ) - for success in result: - if not success: - raise Exception("插件依赖文件下载失败") - logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND) - install_requirement(plugin_path) - except ValueError as e: - logger.warning("未获取到依赖文件路径...", e=e) - return True - raise Exception("插件下载失败...") + await VirtualEnvPackageManager.install_requirement(requirements_path) @classmethod - async def remove_plugin(cls, plugin_id: str) -> str: + async def remove_plugin(cls, index_or_module: str) -> str: """移除插件 参数: - plugin_id: 插件id或模块名 + index_or_module: 插件id或模块名 返回: str: 返回消息 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() - try: - plugin_key = await cls._resolve_plugin_key(plugin_id) - except ValueError as e: - return str(e) - plugin_info = next((p for p in plugin_list if p.module == plugin_key), None) - if plugin_info is None: - return f"未找到插件 {plugin_key}" + plugin_info, _ = await cls.get_plugin_by_value(index_or_module) path = BASE_PATH if plugin_info.github_url: path = BASE_PATH / "plugins" @@ -339,12 +338,13 @@ class StoreManager: 返回: BuildImage | str: 返回消息 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() + plugin_list, extra_plugin_list = await cls.get_data() + all_plugin_list = plugin_list + extra_plugin_list db_plugin_list = await cls.get_loaded_plugins("module", "version") suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} filtered_data = [ (id, plugin_info) - for id, plugin_info in enumerate(plugin_list) + for id, plugin_info in enumerate(all_plugin_list) if plugin_name_or_author.lower() in plugin_info.name.lower() or plugin_name_or_author.lower() in plugin_info.author.lower() ] @@ -373,35 +373,24 @@ class StoreManager: ) @classmethod - async def update_plugin(cls, plugin_id: str) -> str: + async def update_plugin(cls, index_or_module: str) -> str: """更新插件 参数: - plugin_id: 插件id + index_or_module: 插件id 返回: str: 返回消息 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() - try: - plugin_key = await cls._resolve_plugin_key(plugin_id) - except ValueError as e: - return str(e) - plugin_info = next((p for p in plugin_list if p.module == plugin_key), None) - if plugin_info is None: - return f"未找到插件 {plugin_key}" + plugin_info, is_external = await cls.get_plugin_by_value(index_or_module, True) logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND) db_plugin_list = await cls.get_loaded_plugins("module", "version") suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} - if plugin_info.module not in [p[0] for p in db_plugin_list]: - return f"插件 {plugin_info.name} 未安装,无法更新" logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND) if cls.check_version_is_new(plugin_info, suc_plugin): return f"插件 {plugin_info.name} 已是最新版本" - is_external = True if plugin_info.github_url is None: plugin_info.github_url = DEFAULT_GITHUB_URL - is_external = False await cls.install_plugin_with_repo( plugin_info.github_url, plugin_info.module_path, @@ -420,8 +409,9 @@ class StoreManager: 返回: str: 返回消息 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() - plugin_name_list = [p.name for p in plugin_list] + plugin_list, extra_plugin_list = await cls.get_data() + all_plugin_list = plugin_list + extra_plugin_list + plugin_name_list = [p.name for p in all_plugin_list] update_failed_list = [] update_success_list = [] result = "--已更新{}个插件 {}个失败 {}个成功--" @@ -492,22 +482,25 @@ class StoreManager: plugin_id: module,id或插件名称 异常: - ValueError: 插件不存在 - ValueError: 插件不存在 + PluginStoreException: 插件不存在 + PluginStoreException: 插件不存在 返回: str: 插件模块名 """ - plugin_list: list[StorePluginInfo] = await cls.get_data() + plugin_list, extra_plugin_list = await cls.get_data() + all_plugin_list = plugin_list + extra_plugin_list if is_number(plugin_id): idx = int(plugin_id) - if idx < 0 or idx >= len(plugin_list): - raise ValueError("插件ID不存在...") - return plugin_list[idx].module + if idx < 0 or idx >= len(all_plugin_list): + raise PluginStoreException("插件ID不存在...") + return all_plugin_list[idx].module elif isinstance(plugin_id, str): result = ( - None if plugin_id not in [v.module for v in plugin_list] else plugin_id - ) or next(v for v in plugin_list if v.name == plugin_id).module + None + if plugin_id not in [v.module for v in all_plugin_list] + else plugin_id + ) or next(v for v in all_plugin_list if v.name == plugin_id).module if not result: - raise ValueError("插件 Module / 名称 不存在...") + raise PluginStoreException("插件 Module / 名称 不存在...") return result diff --git a/zhenxun/builtin_plugins/plugin_store/exceptions.py b/zhenxun/builtin_plugins/plugin_store/exceptions.py new file mode 100644 index 00000000..76846db9 --- /dev/null +++ b/zhenxun/builtin_plugins/plugin_store/exceptions.py @@ -0,0 +1,6 @@ +class PluginStoreException(Exception): + def __init__(self, message: str): + self.message = message + + def __str__(self): + return self.message diff --git a/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py index 0ecde197..779653b1 100644 --- a/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/configure/__init__.py @@ -38,10 +38,11 @@ async def _(setting: Setting) -> Result: password = Config.get_config("web-ui", "password") if password or BotConfig.db_url: return Result.fail("配置已存在,请先删除DB_URL内容和前端密码再进行设置。") - env_file = Path() / ".env.dev" + env_file = Path() / ".env.example" if not env_file.exists(): - return Result.fail("配置文件.env.dev不存在。") + return Result.fail("基础配置文件.env.example不存在。") env_text = env_file.read_text(encoding="utf-8") + to_env_file = Path() / ".env.dev" if setting.db_url: if setting.db_url.startswith("sqlite"): base_dir = Path().resolve() @@ -78,7 +79,7 @@ async def _(setting: Setting) -> Result: if setting.username: Config.set_config("web-ui", "username", setting.username) Config.set_config("web-ui", "password", setting.password, True) - env_file.write_text(env_text, encoding="utf-8") + to_env_file.write_text(env_text, encoding="utf-8") if BAT_FILE.exists(): for file in os.listdir(Path()): if file.startswith(FILE_NAME): diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py index 1187ad65..1e0d5a50 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/__init__.py @@ -229,9 +229,9 @@ async def _(payload: InstallDependenciesPayload) -> Result: if not payload.dependencies: return Result.fail("依赖列表不能为空") if payload.handle_type == "install": - result = VirtualEnvPackageManager.install(payload.dependencies) + result = await VirtualEnvPackageManager.install(payload.dependencies) else: - result = VirtualEnvPackageManager.uninstall(payload.dependencies) + result = await VirtualEnvPackageManager.uninstall(payload.dependencies) return Result.ok(result) except Exception as e: logger.error(f"{router.prefix}/install_dependencies 调用错误", "WebUi", e=e) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py index 9ee6ff41..35d19fe7 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/plugin_manage/store.py @@ -25,10 +25,10 @@ async def _() -> Result[dict]: require("plugin_store") from zhenxun.builtin_plugins.plugin_store import StoreManager - data = await StoreManager.get_data() + plugin_list, extra_plugin_list = await StoreManager.get_data() plugin_list = [ {**model_dump(plugin), "name": plugin.name, "id": idx} - for idx, plugin in enumerate(data) + for idx, plugin in enumerate(plugin_list + extra_plugin_list) ] modules = await PluginInfo.filter(load_status=True).values_list( "module", flat=True diff --git a/zhenxun/builtin_plugins/web_ui/config.py b/zhenxun/builtin_plugins/web_ui/config.py index 4a88aad9..8182b60d 100644 --- a/zhenxun/builtin_plugins/web_ui/config.py +++ b/zhenxun/builtin_plugins/web_ui/config.py @@ -8,16 +8,6 @@ if sys.version_info >= (3, 11): else: from strenum import StrEnum -from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH - -WEBUI_STRING = "web_ui" -PUBLIC_STRING = "public" - -WEBUI_DATA_PATH = DATA_PATH / WEBUI_STRING -PUBLIC_PATH = WEBUI_DATA_PATH / PUBLIC_STRING -TMP_PATH = TEMP_PATH / WEBUI_STRING - -WEBUI_DIST_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot_webui/tree/dist" app = nonebot.get_app() diff --git a/zhenxun/builtin_plugins/web_ui/public/__init__.py b/zhenxun/builtin_plugins/web_ui/public/__init__.py index 53d4914e..76e73538 100644 --- a/zhenxun/builtin_plugins/web_ui/public/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/public/__init__.py @@ -3,41 +3,38 @@ from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from zhenxun.services.log import logger - -from ..config import PUBLIC_PATH -from .data_source import COMMAND_NAME, update_webui_assets +from zhenxun.utils.manager.zhenxun_repo_manager import ZhenxunRepoManager router = APIRouter() @router.get("/") async def index(): - return FileResponse(PUBLIC_PATH / "index.html") + return FileResponse(ZhenxunRepoManager.config.WEBUI_PATH / "index.html") @router.get("/favicon.ico") async def favicon(): - return FileResponse(PUBLIC_PATH / "favicon.ico") - - -@router.get("/79edfa81f3308a9f.jfif") -async def _(): - return FileResponse(PUBLIC_PATH / "79edfa81f3308a9f.jfif") + return FileResponse(ZhenxunRepoManager.config.WEBUI_PATH / "favicon.ico") async def init_public(app: FastAPI): try: - if not PUBLIC_PATH.exists(): - folders = await update_webui_assets() - else: - folders = [x.name for x in PUBLIC_PATH.iterdir() if x.is_dir()] + if not ZhenxunRepoManager.check_webui_exists(): + await ZhenxunRepoManager.webui_update(branch="test") + folders = [ + x.name for x in ZhenxunRepoManager.config.WEBUI_PATH.iterdir() if x.is_dir() + ] app.include_router(router) for pathname in folders: logger.debug(f"挂载文件夹: {pathname}") app.mount( f"/{pathname}", - StaticFiles(directory=PUBLIC_PATH / pathname, check_dir=True), + StaticFiles( + directory=ZhenxunRepoManager.config.WEBUI_PATH / pathname, + check_dir=True, + ), name=f"public_{pathname}", ) except Exception as e: - logger.error("初始化 WebUI资源 失败", COMMAND_NAME, e=e) + logger.error("初始化 WebUI资源 失败", "WebUI", e=e) diff --git a/zhenxun/builtin_plugins/web_ui/public/data_source.py b/zhenxun/builtin_plugins/web_ui/public/data_source.py deleted file mode 100644 index 51b29533..00000000 --- a/zhenxun/builtin_plugins/web_ui/public/data_source.py +++ /dev/null @@ -1,44 +0,0 @@ -from pathlib import Path -import shutil -import zipfile - -from nonebot.utils import run_sync - -from zhenxun.services.log import logger -from zhenxun.utils.github_utils import GithubUtils -from zhenxun.utils.http_utils import AsyncHttpx - -from ..config import PUBLIC_PATH, TMP_PATH, WEBUI_DIST_GITHUB_URL - -COMMAND_NAME = "WebUI资源管理" - - -async def update_webui_assets(): - webui_assets_path = TMP_PATH / "webui_assets.zip" - download_url = await GithubUtils.parse_github_url( - WEBUI_DIST_GITHUB_URL - ).get_archive_download_urls() - logger.info("开始下载 webui_assets 资源...", COMMAND_NAME) - if await AsyncHttpx.download_file( - download_url, webui_assets_path, follow_redirects=True - ): - logger.info("下载 webui_assets 成功...", COMMAND_NAME) - return await _file_handle(webui_assets_path) - raise Exception("下载 webui_assets 失败", COMMAND_NAME) - - -@run_sync -def _file_handle(webui_assets_path: Path): - logger.debug("开始解压 webui_assets...", COMMAND_NAME) - if webui_assets_path.exists(): - tf = zipfile.ZipFile(webui_assets_path) - tf.extractall(TMP_PATH) - logger.debug("解压 webui_assets 成功...", COMMAND_NAME) - else: - raise Exception("解压 webui_assets 失败,文件不存在...", COMMAND_NAME) - download_file_path = next(f for f in TMP_PATH.iterdir() if f.is_dir()) - shutil.rmtree(PUBLIC_PATH, ignore_errors=True) - shutil.copytree(download_file_path / "dist", PUBLIC_PATH, dirs_exist_ok=True) - logger.debug("复制 webui_assets 成功...", COMMAND_NAME) - shutil.rmtree(TMP_PATH, ignore_errors=True) - return [x.name for x in PUBLIC_PATH.iterdir() if x.is_dir()] diff --git a/zhenxun/services/db_context/__init__.py b/zhenxun/services/db_context/__init__.py index 26fd9bcd..70ead644 100644 --- a/zhenxun/services/db_context/__init__.py +++ b/zhenxun/services/db_context/__init__.py @@ -1,6 +1,8 @@ import asyncio +from pathlib import Path from urllib.parse import urlparse +import aiofiles import nonebot from nonebot.utils import is_coroutine_callable from tortoise import Tortoise @@ -86,6 +88,7 @@ def get_config() -> dict: **MYSQL_CONFIG, } elif parsed.scheme == "sqlite": + Path(parsed.path).parent.mkdir(parents=True, exist_ok=True) config["connections"]["default"] = { "engine": "tortoise.backends.sqlite", "credentials": { @@ -100,6 +103,15 @@ def get_config() -> dict: async def init(): global MODELS, SCRIPT_METHOD + env_example_file = Path() / ".env.example" + env_dev_file = Path() / ".env.dev" + if not env_dev_file.exists(): + async with aiofiles.open(env_example_file, encoding="utf-8") as f: + env_text = await f.read() + async with aiofiles.open(env_dev_file, "w", encoding="utf-8") as f: + await f.write(env_text) + logger.info("已生成 .env.dev 文件,请根据 .env.example 文件配置进行配置") + MODELS = db_model.models SCRIPT_METHOD = db_model.script_method if not BotConfig.db_url: diff --git a/zhenxun/utils/github_utils/models.py b/zhenxun/utils/github_utils/models.py index 06e0ca33..ae4ab2d3 100644 --- a/zhenxun/utils/github_utils/models.py +++ b/zhenxun/utils/github_utils/models.py @@ -317,6 +317,20 @@ class AliyunFileInfo: repository_id: str """仓库ID""" + @classmethod + async def get_client(cls) -> devops20210625Client: + """获取阿里云客户端""" + config = open_api_models.Config( + access_key_id=Aliyun_AccessKey_ID, + access_key_secret=base64.b64decode( + Aliyun_Secret_AccessKey_encrypted.encode() + ).decode(), + endpoint=ALIYUN_ENDPOINT, + region_id=ALIYUN_REGION, + ) + + return devops20210625Client(config) + @classmethod async def get_file_content( cls, file_path: str, repo: str, ref: str = "main" @@ -335,16 +349,8 @@ class AliyunFileInfo: repository_id = ALIYUN_REPO_MAPPING.get(repo) if not repository_id: raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID") - config = open_api_models.Config( - access_key_id=Aliyun_AccessKey_ID, - access_key_secret=base64.b64decode( - Aliyun_Secret_AccessKey_encrypted.encode() - ).decode(), - endpoint=ALIYUN_ENDPOINT, - region_id=ALIYUN_REGION, - ) - client = devops20210625Client(config) + client = await cls.get_client() request = devops_20210625_models.GetFileBlobsRequest( organization_id=ALIYUN_ORG_ID, @@ -404,16 +410,7 @@ class AliyunFileInfo: if not repository_id: raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID") - config = open_api_models.Config( - access_key_id=Aliyun_AccessKey_ID, - access_key_secret=base64.b64decode( - Aliyun_Secret_AccessKey_encrypted.encode() - ).decode(), - endpoint=ALIYUN_ENDPOINT, - region_id=ALIYUN_REGION, - ) - - client = devops20210625Client(config) + client = await cls.get_client() request = devops_20210625_models.ListRepositoryTreeRequest( organization_id=ALIYUN_ORG_ID, @@ -459,16 +456,7 @@ class AliyunFileInfo: if not repository_id: raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID") - config = open_api_models.Config( - access_key_id=Aliyun_AccessKey_ID, - access_key_secret=base64.b64decode( - Aliyun_Secret_AccessKey_encrypted.encode() - ).decode(), - endpoint=ALIYUN_ENDPOINT, - region_id=ALIYUN_REGION, - ) - - client = devops20210625Client(config) + client = await cls.get_client() request = devops_20210625_models.GetRepositoryCommitRequest( organization_id=ALIYUN_ORG_ID, diff --git a/zhenxun/utils/manager/resource_manager.py b/zhenxun/utils/manager/resource_manager.py deleted file mode 100644 index a859d6b9..00000000 --- a/zhenxun/utils/manager/resource_manager.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from pathlib import Path -import shutil -import zipfile - -from zhenxun.configs.path_config import FONT_PATH -from zhenxun.services.log import logger -from zhenxun.utils.github_utils import GithubUtils -from zhenxun.utils.http_utils import AsyncHttpx - -CMD_STRING = "ResourceManager" - - -class DownloadResourceException(Exception): - pass - - -class ResourceManager: - GITHUB_URL = "https://github.com/zhenxun-org/zhenxun-bot-resources/tree/main" - - RESOURCE_PATH = Path() / "resources" - - TMP_PATH = Path() / "_resource_tmp" - - ZIP_FILE = TMP_PATH / "resources.zip" - - UNZIP_PATH = None - - @classmethod - async def init_resources(cls, force: bool = False): - if (FONT_PATH.exists() and os.listdir(FONT_PATH)) and not force: - return - if cls.TMP_PATH.exists(): - logger.debug( - "resources临时文件夹已存在,移除resources临时文件夹", CMD_STRING - ) - shutil.rmtree(cls.TMP_PATH) - cls.TMP_PATH.mkdir(parents=True, exist_ok=True) - try: - await cls.__download_resources() - cls.file_handle() - except Exception as e: - logger.error("获取resources资源包失败", CMD_STRING, e=e) - if cls.TMP_PATH.exists(): - logger.debug("移除resources临时文件夹", CMD_STRING) - shutil.rmtree(cls.TMP_PATH) - - @classmethod - def file_handle(cls): - if not cls.UNZIP_PATH: - return - cls.__recursive_folder(cls.UNZIP_PATH, "resources") - - @classmethod - def __recursive_folder(cls, dir: Path, parent_path: str): - for file in dir.iterdir(): - if file.is_dir(): - cls.__recursive_folder(file, f"{parent_path}/{file.name}") - else: - res_file = Path(parent_path) / file.name - if res_file.exists(): - res_file.unlink() - res_file.parent.mkdir(parents=True, exist_ok=True) - file.rename(res_file) - - @classmethod - async def __download_resources(cls): - """获取resources文件夹""" - repo_info = GithubUtils.parse_github_url(cls.GITHUB_URL) - url = await repo_info.get_archive_download_urls() - logger.debug("开始下载resources资源包...", CMD_STRING) - if not await AsyncHttpx.download_file(url, cls.ZIP_FILE, stream=True): - logger.error( - "下载resources资源包失败,请尝试重启重新下载或前往 " - "https://github.com/zhenxun-org/zhenxun-bot-resources 手动下载..." - ) - raise DownloadResourceException("下载resources资源包失败...") - logger.debug("下载resources资源文件压缩包完成...", CMD_STRING) - tf = zipfile.ZipFile(cls.ZIP_FILE) - tf.extractall(cls.TMP_PATH) - logger.debug("解压文件压缩包完成...", CMD_STRING) - download_file_path = cls.TMP_PATH / next( - x for x in os.listdir(cls.TMP_PATH) if (cls.TMP_PATH / x).is_dir() - ) - cls.UNZIP_PATH = download_file_path / "resources" - if tf: - tf.close() diff --git a/zhenxun/utils/manager/virtual_env_package_manager.py b/zhenxun/utils/manager/virtual_env_package_manager.py index ba60d9b3..7f938e0a 100644 --- a/zhenxun/utils/manager/virtual_env_package_manager.py +++ b/zhenxun/utils/manager/virtual_env_package_manager.py @@ -1,3 +1,4 @@ +import asyncio from pathlib import Path import subprocess from subprocess import CalledProcessError @@ -36,7 +37,7 @@ class VirtualEnvPackageManager: ) @classmethod - def install(cls, package: list[str] | str): + async def install(cls, package: list[str] | str): """安装依赖包 参数: @@ -49,7 +50,8 @@ class VirtualEnvPackageManager: command.append("install") command.append(" ".join(package)) logger.info(f"执行虚拟环境安装包指令: {command}", LOG_COMMAND) - result = subprocess.run( + result = await asyncio.to_thread( + subprocess.run, command, check=True, capture_output=True, @@ -65,7 +67,7 @@ class VirtualEnvPackageManager: return e.stderr @classmethod - def uninstall(cls, package: list[str] | str): + async def uninstall(cls, package: list[str] | str): """卸载依赖包 参数: @@ -79,7 +81,8 @@ class VirtualEnvPackageManager: command.append("-y") command.append(" ".join(package)) logger.info(f"执行虚拟环境卸载包指令: {command}", LOG_COMMAND) - result = subprocess.run( + result = await asyncio.to_thread( + subprocess.run, command, check=True, capture_output=True, @@ -95,7 +98,7 @@ class VirtualEnvPackageManager: return e.stderr @classmethod - def update(cls, package: list[str] | str): + async def update(cls, package: list[str] | str): """更新依赖包 参数: @@ -109,7 +112,8 @@ class VirtualEnvPackageManager: command.append("--upgrade") command.append(" ".join(package)) logger.info(f"执行虚拟环境更新包指令: {command}", LOG_COMMAND) - result = subprocess.run( + result = await asyncio.to_thread( + subprocess.run, command, check=True, capture_output=True, @@ -122,7 +126,7 @@ class VirtualEnvPackageManager: return e.stderr @classmethod - def install_requirement(cls, requirement_file: Path): + async def install_requirement(cls, requirement_file: Path): """安装依赖文件 参数: @@ -139,7 +143,8 @@ class VirtualEnvPackageManager: command.append("-r") command.append(str(requirement_file.absolute())) logger.info(f"执行虚拟环境安装依赖文件指令: {command}", LOG_COMMAND) - result = subprocess.run( + result = await asyncio.to_thread( + subprocess.run, command, check=True, capture_output=True, @@ -158,13 +163,14 @@ class VirtualEnvPackageManager: return e.stderr @classmethod - def list(cls) -> str: + async def list(cls) -> str: """列出已安装的依赖包""" try: command = cls.__get_command() command.append("list") logger.info(f"执行虚拟环境列出包指令: {command}", LOG_COMMAND) - result = subprocess.run( + result = await asyncio.to_thread( + subprocess.run, command, check=True, capture_output=True, diff --git a/zhenxun/utils/manager/zhenxun_repo_manager.py b/zhenxun/utils/manager/zhenxun_repo_manager.py new file mode 100644 index 00000000..145e8b52 --- /dev/null +++ b/zhenxun/utils/manager/zhenxun_repo_manager.py @@ -0,0 +1,558 @@ +""" +真寻仓库管理器 +负责真寻主仓库的更新、版本检查、文件处理等功能 +""" + +import os +from pathlib import Path +import shutil +from typing import ClassVar, Literal +import zipfile + +import aiofiles + +from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH +from zhenxun.services.log import logger +from zhenxun.utils.github_utils import GithubUtils +from zhenxun.utils.http_utils import AsyncHttpx +from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager +from zhenxun.utils.repo_utils import AliyunRepoManager, GithubRepoManager +from zhenxun.utils.repo_utils.models import RepoUpdateResult +from zhenxun.utils.repo_utils.utils import check_git + +LOG_COMMAND = "ZhenxunRepoManager" + + +class ZhenxunUpdateException(Exception): + """资源下载异常""" + + pass + + +class ZhenxunRepoConfig: + """真寻仓库配置""" + + # Zhenxun Bot 相关配置 + ZHENXUN_BOT_GIT = "https://github.com/zhenxun-org/zhenxun_bot.git" + ZHENXUN_BOT_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot/tree/main" + ZHENXUN_BOT_DOWNLOAD_FILE_STRING = "zhenxun_bot.zip" + ZHENXUN_BOT_DOWNLOAD_FILE = TEMP_PATH / ZHENXUN_BOT_DOWNLOAD_FILE_STRING + ZHENXUN_BOT_UNZIP_PATH = TEMP_PATH / "zhenxun_bot" + ZHENXUN_BOT_CODE_PATH = Path() / "zhenxun" + ZHENXUN_BOT_RELEASES_API_URL = ( + "https://api.github.com/repos/HibiKier/zhenxun_bot/releases/latest" + ) + ZHENXUN_BOT_BACKUP_PATH = Path() / "backup" + # 需要替换的文件夹 + ZHENXUN_BOT_UPDATE_FOLDERS: ClassVar[list[str]] = [ + "builtin_plugins", + "services", + "utils", + "models", + "configs", + ] + ZHENXUN_BOT_VERSION_FILE_STRING = "__version__" + ZHENXUN_BOT_VERSION_FILE = Path() / ZHENXUN_BOT_VERSION_FILE_STRING + + # 备份杂项 + BACKUP_FILES: ClassVar[list[str]] = [ + "pyproject.toml", + "poetry.lock", + "requirements.txt", + ".env.dev", + ".env.example", + ] + + # WEB UI 相关配置 + WEBUI_GIT = "https://github.com/HibiKier/zhenxun_bot_webui.git" + WEBUI_DIST_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot_webui/tree/test" + WEBUI_DOWNLOAD_FILE_STRING = "webui_assets.zip" + WEBUI_DOWNLOAD_FILE = TEMP_PATH / WEBUI_DOWNLOAD_FILE_STRING + WEBUI_UNZIP_PATH = TEMP_PATH / "web_ui" + WEBUI_PATH = DATA_PATH / "web_ui" / "public" + WEBUI_BACKUP_PATH = DATA_PATH / "web_ui" / "backup_public" + + # 资源管理相关配置 + RESOURCE_GIT = "https://github.com/zhenxun-org/zhenxun-bot-resources.git" + RESOURCE_GITHUB_URL = ( + "https://github.com/zhenxun-org/zhenxun-bot-resources/tree/main" + ) + RESOURCE_ZIP_FILE_STRING = "resources.zip" + RESOURCE_ZIP_FILE = TEMP_PATH / RESOURCE_ZIP_FILE_STRING + RESOURCE_UNZIP_PATH = TEMP_PATH / "resources" + RESOURCE_PATH = Path() / "resources" + + REQUIREMENTS_FILE_STRING = "requirements.txt" + REQUIREMENTS_FILE = Path() / REQUIREMENTS_FILE_STRING + + PYPROJECT_FILE_STRING = "pyproject.toml" + PYPROJECT_FILE = Path() / PYPROJECT_FILE_STRING + + PYPROJECT_LOCK_FILE_STRING = "poetry.lock" + PYPROJECT_LOCK_FILE = Path() / PYPROJECT_LOCK_FILE_STRING + + +class ZhenxunRepoManagerClass: + """真寻仓库管理器""" + + def __init__(self): + self.config = ZhenxunRepoConfig() + + def __clear_folder(self, folder_path: Path): + """ + 清空文件夹 + + 参数: + folder_path: 文件夹路径 + """ + if not folder_path.exists(): + return + for filename in os.listdir(folder_path): + file_path = folder_path / filename + try: + if file_path.is_file(): + os.unlink(file_path) + elif file_path.is_dir() and not filename.startswith("."): + shutil.rmtree(file_path) + except Exception as e: + logger.warning(f"无法删除 {file_path}", LOG_COMMAND, e=e) + + def __copy_files(self, src_path: Path, dest_path: Path, incremental: bool = False): + """ + 复制文件或文件夹 + + 参数: + src_path: 源文件或文件夹路径 + dest_path: 目标文件或文件夹路径 + incremental: 是否增量复制 + """ + if src_path.is_file(): + shutil.copy(src_path, dest_path) + logger.debug(f"复制文件 {src_path} -> {dest_path}", LOG_COMMAND) + elif src_path.is_dir(): + for filename in os.listdir(src_path): + file_path = src_path / filename + dest_file = dest_path / filename + dest_file.parent.mkdir(exist_ok=True, parents=True) + if file_path.is_file(): + if dest_file.exists(): + dest_file.unlink() + shutil.copy(file_path, dest_file) + logger.debug(f"复制文件 {file_path} -> {dest_file}", LOG_COMMAND) + elif file_path.is_dir(): + if incremental: + self.__copy_files(file_path, dest_file, incremental=True) + else: + if dest_file.exists(): + shutil.rmtree(dest_file, True) + shutil.copytree(file_path, dest_file) + logger.debug( + f"复制文件夹 {file_path} -> {dest_file}", + LOG_COMMAND, + ) + + # ==================== Zhenxun Bot 相关方法 ==================== + + async def zhenxun_get_version_from_repo(self) -> str: + """从指定分支获取版本号 + + + 返回: + str: 版本号 + """ + repo_info = GithubUtils.parse_github_url(self.config.ZHENXUN_BOT_GITHUB_URL) + version_url = await repo_info.get_raw_download_urls( + path=self.config.ZHENXUN_BOT_VERSION_FILE_STRING + ) + try: + res = await AsyncHttpx.get(version_url) + if res.status_code == 200: + return res.text.strip() + except Exception as e: + logger.error(f"获取 {repo_info.branch} 分支版本失败", LOG_COMMAND, e=e) + return "未知版本" + + async def zhenxun_write_version_file(self, version: str): + """写入版本文件""" + async with aiofiles.open( + self.config.ZHENXUN_BOT_VERSION_FILE, "w", encoding="utf8" + ) as f: + await f.write(f"__version__: {version}") + + def __backup_zhenxun(self): + """备份真寻文件""" + for filename in os.listdir(self.config.ZHENXUN_BOT_CODE_PATH): + file_path = self.config.ZHENXUN_BOT_CODE_PATH / filename + if file_path.exists(): + self.__copy_files( + file_path, + self.config.ZHENXUN_BOT_BACKUP_PATH / filename, + True, + ) + for filename in self.config.BACKUP_FILES: + file_path = Path() / filename + if file_path.exists(): + self.__copy_files( + file_path, + self.config.ZHENXUN_BOT_BACKUP_PATH / filename, + ) + + async def zhenxun_get_latest_releases_data(self) -> dict: + """获取真寻releases最新版本信息 + + 返回: + dict: 最新版本数据 + """ + try: + res = await AsyncHttpx.get(self.config.ZHENXUN_BOT_RELEASES_API_URL) + if res.status_code == 200: + return res.json() + except Exception as e: + logger.error("检查更新真寻获取版本失败", LOG_COMMAND, e=e) + return {} + + async def zhenxun_download_zip(self, ver_type: Literal["main", "release"]) -> str: + """下载真寻最新版文件 + + 参数: + ver_type: 版本类型,main 为最新版,release 为最新release版 + + 返回: + str: 版本号 + """ + repo_info = GithubUtils.parse_github_url(self.config.ZHENXUN_BOT_GITHUB_URL) + if ver_type == "main": + download_url = await repo_info.get_archive_download_urls() + new_version = await self.zhenxun_get_version_from_repo() + else: + release_data = await self.zhenxun_get_latest_releases_data() + logger.debug(f"获取真寻RELEASES最新版本信息: {release_data}", LOG_COMMAND) + if not release_data: + raise ZhenxunUpdateException("获取真寻RELEASES最新版本失败...") + new_version = release_data.get("name", "") + download_url = await repo_info.get_release_source_download_urls_tgz( + new_version + ) + if not download_url: + raise ZhenxunUpdateException("获取真寻最新版文件下载链接失败...") + if self.config.ZHENXUN_BOT_DOWNLOAD_FILE.exists(): + self.config.ZHENXUN_BOT_DOWNLOAD_FILE.unlink() + if await AsyncHttpx.download_file( + download_url, self.config.ZHENXUN_BOT_DOWNLOAD_FILE, stream=True + ): + logger.debug("下载真寻最新版文件完成...", LOG_COMMAND) + else: + raise ZhenxunUpdateException("下载真寻最新版文件失败...") + return new_version + + async def zhenxun_unzip(self): + """解压真寻最新版文件""" + if not self.config.ZHENXUN_BOT_DOWNLOAD_FILE.exists(): + raise FileNotFoundError("真寻最新版文件不存在") + if self.config.ZHENXUN_BOT_UNZIP_PATH.exists(): + shutil.rmtree(self.config.ZHENXUN_BOT_UNZIP_PATH) + tf = None + try: + tf = zipfile.ZipFile(self.config.ZHENXUN_BOT_DOWNLOAD_FILE) + tf.extractall(self.config.ZHENXUN_BOT_UNZIP_PATH) + logger.debug("解压Zhenxun Bot文件压缩包完成!", LOG_COMMAND) + self.__backup_zhenxun() + for filename in self.config.BACKUP_FILES: + self.__copy_files( + self.config.ZHENXUN_BOT_UNZIP_PATH / filename, + Path() / filename, + ) + for folder in self.config.ZHENXUN_BOT_UPDATE_FOLDERS: + self.__copy_files( + self.config.ZHENXUN_BOT_UNZIP_PATH / folder, + self.config.ZHENXUN_BOT_CODE_PATH / folder, + ) + logger.debug("移动真寻更新文件完成!", LOG_COMMAND) + if self.config.ZHENXUN_BOT_DOWNLOAD_FILE.exists(): + self.config.ZHENXUN_BOT_DOWNLOAD_FILE.unlink() + if self.config.ZHENXUN_BOT_UNZIP_PATH.exists(): + shutil.rmtree(self.config.ZHENXUN_BOT_UNZIP_PATH) + except Exception as e: + logger.error("解压真寻最新版文件失败...", LOG_COMMAND, e=e) + raise + finally: + if tf: + tf.close() + + async def zhenxun_zip_update(self, ver_type: Literal["main", "release"]) -> str: + """使用zip更新真寻 + + 参数: + ver_type: 版本类型,main 为最新版,release 为最新release版 + + 返回: + str: 版本号 + """ + new_version = await self.zhenxun_download_zip(ver_type) + await self.zhenxun_unzip() + await self.zhenxun_write_version_file(new_version) + return new_version + + async def zhenxun_git_update( + self, source: Literal["git", "ali"], branch: str = "main", force: bool = False + ) -> RepoUpdateResult: + """使用git或阿里云更新真寻 + + 参数: + source: 更新源,git 为 git 更新,ali 为阿里云更新 + branch: 分支名称 + force: 是否强制更新 + """ + if source == "git": + return await GithubRepoManager.update_via_git( + self.config.ZHENXUN_BOT_GIT, + Path(), + branch=branch, + force=force, + ) + else: + return await AliyunRepoManager.update_via_git( + self.config.ZHENXUN_BOT_GIT, + Path(), + branch=branch, + force=force, + ) + + async def zhenxun_update( + self, + source: Literal["git", "ali"] = "ali", + branch: str = "main", + force: bool = False, + ver_type: Literal["main", "release"] = "main", + ): + """更新真寻 + + 参数: + source: 更新源,git 为 git 更新,ali 为阿里云更新 + branch: 分支名称 + force: 是否强制更新 + ver_type: 版本类型,main 为最新版,release 为最新release版 + """ + if await check_git(): + await self.zhenxun_git_update(source, branch, force) + logger.debug("使用git更新真寻!", LOG_COMMAND) + else: + await self.zhenxun_zip_update(ver_type) + logger.debug("使用zip更新真寻!", LOG_COMMAND) + + async def install_requirements(self): + """安装真寻依赖""" + await VirtualEnvPackageManager.install_requirement( + self.config.REQUIREMENTS_FILE + ) + + # ==================== 资源管理相关方法 ==================== + + def check_resources_exists(self) -> bool: + """检查资源文件是否存在 + + 返回: + bool: 是否存在 + """ + if self.config.RESOURCE_PATH.exists(): + font_path = self.config.RESOURCE_PATH / "font" + if font_path.exists() and os.listdir(font_path): + return True + return False + + async def resources_download_zip(self): + """下载资源文件""" + download_url = await GithubUtils.parse_github_url( + self.config.RESOURCE_GITHUB_URL + ).get_archive_download_urls() + logger.debug("开始下载resources资源包...", LOG_COMMAND) + if await AsyncHttpx.download_file( + download_url, self.config.RESOURCE_ZIP_FILE, stream=True + ): + logger.debug("下载resources资源文件压缩包成功!", LOG_COMMAND) + else: + raise ZhenxunUpdateException("下载resources资源包失败...") + + async def resources_unzip(self): + """解压资源文件""" + if not self.config.RESOURCE_ZIP_FILE.exists(): + raise FileNotFoundError("资源文件压缩包不存在") + if self.config.RESOURCE_UNZIP_PATH.exists(): + shutil.rmtree(self.config.RESOURCE_UNZIP_PATH) + tf = None + try: + tf = zipfile.ZipFile(self.config.RESOURCE_ZIP_FILE) + tf.extractall(self.config.RESOURCE_UNZIP_PATH) + logger.debug("解压文件压缩包完成...", LOG_COMMAND) + self.__copy_files( + self.config.RESOURCE_UNZIP_PATH, self.config.RESOURCE_PATH, True + ) + except Exception as e: + logger.error("解压资源文件失败...", LOG_COMMAND, e=e) + raise + finally: + if tf: + tf.close() + + async def resources_zip_update(self): + """使用zip更新资源文件""" + await self.resources_download_zip() + await self.resources_unzip() + + async def resources_git_update( + self, source: Literal["git", "ali"], branch: str = "main", force: bool = False + ) -> RepoUpdateResult: + """使用git或阿里云更新资源文件 + + 参数: + source: 更新源,git 为 git 更新,ali 为阿里云更新 + branch: 分支名称 + force: 是否强制更新 + """ + if source == "git": + return await GithubRepoManager.update_via_git( + self.config.RESOURCE_GIT, + self.config.RESOURCE_PATH, + branch=branch, + force=force, + ) + else: + return await AliyunRepoManager.update_via_git( + self.config.RESOURCE_GIT, + self.config.RESOURCE_PATH, + branch=branch, + force=force, + ) + + async def resources_update( + self, + source: Literal["git", "ali"] = "ali", + branch: str = "main", + force: bool = False, + ): + """更新资源文件 + + 参数: + source: 更新源,git 为 git 更新,ali 为阿里云更新 + branch: 分支名称 + force: 是否强制更新 + """ + if await check_git(): + await self.resources_git_update(source, branch, force) + logger.debug("使用git更新资源文件!", LOG_COMMAND) + else: + await self.resources_zip_update() + logger.debug("使用zip更新资源文件!", LOG_COMMAND) + + # ==================== Web UI 管理相关方法 ==================== + + def check_webui_exists(self) -> bool: + """检查 Web UI 资源是否存在""" + return bool( + self.config.WEBUI_PATH.exists() and os.listdir(self.config.WEBUI_PATH) + ) + + async def webui_download_zip(self): + """下载 WEBUI_ASSETS 资源""" + download_url = await GithubUtils.parse_github_url( + self.config.WEBUI_DIST_GITHUB_URL + ).get_archive_download_urls() + logger.info("开始下载 WEBUI_ASSETS 资源...", LOG_COMMAND) + if await AsyncHttpx.download_file( + download_url, self.config.WEBUI_DOWNLOAD_FILE, follow_redirects=True + ): + logger.info("下载 WEBUI_ASSETS 成功!", LOG_COMMAND) + else: + raise ZhenxunUpdateException("下载 WEBUI_ASSETS 失败", LOG_COMMAND) + + def __backup_webui(self): + """备份 WEBUI_ASSERT 资源""" + if self.config.WEBUI_PATH.exists(): + if self.config.WEBUI_BACKUP_PATH.exists(): + logger.debug( + f"删除旧的备份webui文件夹 {self.config.WEBUI_BACKUP_PATH}", + LOG_COMMAND, + ) + shutil.rmtree(self.config.WEBUI_BACKUP_PATH) + shutil.copytree(self.config.WEBUI_PATH, self.config.WEBUI_BACKUP_PATH) + + async def webui_unzip(self): + """解压 WEBUI_ASSETS 资源 + + 返回: + str: 更新结果 + """ + if not self.config.WEBUI_DOWNLOAD_FILE.exists(): + raise FileNotFoundError("webui文件压缩包不存在") + tf = None + try: + self.__backup_webui() + self.__clear_folder(self.config.WEBUI_PATH) + tf = zipfile.ZipFile(self.config.WEBUI_DOWNLOAD_FILE) + tf.extractall(self.config.WEBUI_UNZIP_PATH) + logger.debug("Web UI 解压文件压缩包完成...", LOG_COMMAND) + unzip_dir = next(self.config.WEBUI_UNZIP_PATH.iterdir()) + self.__copy_files(unzip_dir, self.config.WEBUI_PATH) + logger.debug("Web UI 复制 WEBUI_ASSETS 成功!", LOG_COMMAND) + shutil.rmtree(self.config.WEBUI_UNZIP_PATH, ignore_errors=True) + except Exception as e: + if self.config.WEBUI_BACKUP_PATH.exists(): + self.__copy_files(self.config.WEBUI_BACKUP_PATH, self.config.WEBUI_PATH) + logger.debug("恢复备份 WEBUI_ASSETS 成功!", LOG_COMMAND) + shutil.rmtree(self.config.WEBUI_BACKUP_PATH, ignore_errors=True) + logger.error("Web UI 更新失败", LOG_COMMAND, e=e) + raise + finally: + if tf: + tf.close() + + async def webui_zip_update(self): + """使用zip更新 Web UI""" + await self.webui_download_zip() + await self.webui_unzip() + + async def webui_git_update( + self, source: Literal["git", "ali"], branch: str = "dist", force: bool = False + ) -> RepoUpdateResult: + """使用git或阿里云更新 Web UI + + 参数: + source: 更新源,git 为 git 更新,ali 为阿里云更新 + branch: 分支名称 + force: 是否强制更新 + """ + if source == "git": + return await GithubRepoManager.update_via_git( + self.config.WEBUI_GIT, + self.config.WEBUI_PATH, + branch=branch, + force=force, + ) + else: + return await AliyunRepoManager.update_via_git( + self.config.WEBUI_GIT, + self.config.WEBUI_PATH, + branch=branch, + force=force, + ) + + async def webui_update( + self, + source: Literal["git", "ali"] = "ali", + branch: str = "dist", + force: bool = False, + ): + """更新 Web UI + + 参数: + source: 更新源,git 为 git 更新,ali 为阿里云更新 + """ + if await check_git(): + await self.webui_git_update(source, branch, force) + logger.debug("使用git更新Web UI!", LOG_COMMAND) + else: + await self.webui_zip_update() + logger.debug("使用zip更新Web UI!", LOG_COMMAND) + + +ZhenxunRepoManager = ZhenxunRepoManagerClass() diff --git a/zhenxun/utils/repo_utils/__init__.py b/zhenxun/utils/repo_utils/__init__.py new file mode 100644 index 00000000..f37ccd26 --- /dev/null +++ b/zhenxun/utils/repo_utils/__init__.py @@ -0,0 +1,60 @@ +""" +仓库管理工具,用于操作GitHub和阿里云CodeUp项目的更新和文件下载 +""" + +from .aliyun_manager import AliyunCodeupManager +from .base_manager import BaseRepoManager +from .config import AliyunCodeupConfig, GithubConfig, RepoConfig +from .exceptions import ( + ApiRateLimitError, + AuthenticationError, + ConfigError, + FileNotFoundError, + NetworkError, + RepoDownloadError, + RepoManagerError, + RepoNotFoundError, + RepoUpdateError, +) +from .file_manager import RepoFileManager as RepoFileManagerClass +from .github_manager import GithubManager +from .models import ( + FileDownloadResult, + RepoCommitInfo, + RepoFileInfo, + RepoType, + RepoUpdateResult, +) +from .utils import check_git, filter_files, glob_to_regex, run_git_command + +GithubRepoManager = GithubManager() +AliyunRepoManager = AliyunCodeupManager() +RepoFileManager = RepoFileManagerClass() + +__all__ = [ + "AliyunCodeupConfig", + "AliyunRepoManager", + "ApiRateLimitError", + "AuthenticationError", + "BaseRepoManager", + "ConfigError", + "FileDownloadResult", + "FileNotFoundError", + "GithubConfig", + "GithubRepoManager", + "NetworkError", + "RepoCommitInfo", + "RepoConfig", + "RepoDownloadError", + "RepoFileInfo", + "RepoFileManager", + "RepoManagerError", + "RepoNotFoundError", + "RepoType", + "RepoUpdateError", + "RepoUpdateResult", + "check_git", + "filter_files", + "glob_to_regex", + "run_git_command", +] diff --git a/zhenxun/utils/repo_utils/aliyun_manager.py b/zhenxun/utils/repo_utils/aliyun_manager.py new file mode 100644 index 00000000..863a5620 --- /dev/null +++ b/zhenxun/utils/repo_utils/aliyun_manager.py @@ -0,0 +1,557 @@ +""" +阿里云CodeUp仓库管理工具 +""" + +import asyncio +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +from aiocache import cached + +from zhenxun.services.log import logger +from zhenxun.utils.github_utils.models import AliyunFileInfo + +from .base_manager import BaseRepoManager +from .config import LOG_COMMAND, RepoConfig +from .exceptions import ( + AuthenticationError, + FileNotFoundError, + RepoDownloadError, + RepoNotFoundError, + RepoUpdateError, +) +from .models import ( + FileDownloadResult, + RepoCommitInfo, + RepoFileInfo, + RepoType, + RepoUpdateResult, +) + + +class AliyunCodeupManager(BaseRepoManager): + """阿里云CodeUp仓库管理工具""" + + def __init__(self, config: RepoConfig | None = None): + """ + 初始化阿里云CodeUp仓库管理工具 + + Args: + config: 配置,如果为None则使用默认配置 + """ + super().__init__(config) + self._client = None + + async def update_repo( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + ) -> RepoUpdateResult: + """ + 更新阿里云CodeUp仓库 + + Args: + repo_url: 仓库URL或名称 + local_path: 本地保存路径 + branch: 分支名称 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + Returns: + RepoUpdateResult: 更新结果 + """ + try: + # 检查配置 + self._check_config() + + # 获取仓库名称(从URL中提取) + repo_url = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + + # 获取仓库最新提交ID + newest_commit = await self._get_newest_commit(repo_url, branch) + + # 创建结果对象 + result = RepoUpdateResult( + repo_type=RepoType.ALIYUN, + repo_name=repo_url.split("/tree/")[0] + .split("/")[-1] + .replace(".git", ""), + owner=self.config.aliyun_codeup.organization_id, + old_version="", # 将在后面更新 + new_version=newest_commit, + ) + old_version = await self.read_version_file(local_path) + result.old_version = old_version + + # 如果版本相同,则无需更新 + if old_version == newest_commit: + result.success = True + logger.debug( + f"仓库 {repo_url.split('/')[-1].replace('.git', '')}" + f" 已是最新版本: {newest_commit[:8]}", + LOG_COMMAND, + ) + return result + + # 确保本地目录存在 + local_path.mkdir(parents=True, exist_ok=True) + + # 获取仓库名称(从URL中提取) + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + + # 获取变更的文件列表 + changed_files = await self._get_changed_files( + repo_name, old_version or None, newest_commit + ) + + # 过滤文件 + if include_patterns or exclude_patterns: + from .utils import filter_files + + changed_files = filter_files( + changed_files, include_patterns, exclude_patterns + ) + + result.changed_files = changed_files + + # 下载变更的文件 + for file_path in changed_files: + try: + local_file_path = local_path / file_path + await self._download_file( + repo_name, file_path, local_file_path, newest_commit + ) + except Exception as e: + logger.error(f"下载文件 {file_path} 失败", LOG_COMMAND, e=e) + + # 更新版本文件 + await self.write_version_file(local_path, newest_commit) + + result.success = True + return result + + except RepoUpdateError as e: + logger.error(f"更新仓库失败: {e}") + # 从URL中提取仓库名称 + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + return RepoUpdateResult( + repo_type=RepoType.ALIYUN, + repo_name=repo_name, + owner=self.config.aliyun_codeup.organization_id, + old_version="", + new_version="", + error_message=str(e), + ) + except Exception as e: + logger.error(f"更新仓库失败: {e}") + # 从URL中提取仓库名称 + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + return RepoUpdateResult( + repo_type=RepoType.ALIYUN, + repo_name=repo_name, + owner=self.config.aliyun_codeup.organization_id, + old_version="", + new_version="", + error_message=str(e), + ) + + async def download_file( + self, + repo_url: str, + file_path: str, + local_path: Path, + branch: str = "main", + ) -> FileDownloadResult: + """ + 从阿里云CodeUp下载单个文件 + + Args: + repo_url: 仓库URL或名称 + file_path: 文件在仓库中的路径 + local_path: 本地保存路径 + branch: 分支名称 + + Returns: + FileDownloadResult: 下载结果 + """ + try: + # 检查配置 + self._check_config() + + # 获取仓库名称(从URL中提取) + repo_identifier = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + ) + + # 创建结果对象 + result = FileDownloadResult( + repo_type=RepoType.ALIYUN, + repo_name=repo_url.split("/tree/")[0] + .split("/")[-1] + .replace(".git", ""), + file_path=file_path, + version=branch, + ) + + # 确保本地目录存在 + Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # 下载文件 + file_size = await self._download_file( + repo_identifier, file_path, local_path, branch + ) + + result.success = True + result.file_size = file_size + return result + + except RepoDownloadError as e: + logger.error(f"下载文件失败: {e}") + # 从URL中提取仓库名称 + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + return FileDownloadResult( + repo_type=RepoType.ALIYUN, + repo_name=repo_name, + file_path=file_path, + version=branch, + error_message=str(e), + ) + except Exception as e: + logger.error(f"下载文件失败: {e}") + # 从URL中提取仓库名称 + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + return FileDownloadResult( + repo_type=RepoType.ALIYUN, + repo_name=repo_name, + file_path=file_path, + version=branch, + error_message=str(e), + ) + + async def get_file_list( + self, + repo_url: str, + dir_path: str = "", + branch: str = "main", + recursive: bool = False, + ) -> list[RepoFileInfo]: + """ + 获取仓库文件列表 + + Args: + repo_url: 仓库URL或名称 + dir_path: 目录路径,空字符串表示仓库根目录 + branch: 分支名称 + recursive: 是否递归获取子目录 + + Returns: + list[RepoFileInfo]: 文件信息列表 + """ + try: + # 检查配置 + self._check_config() + + # 获取仓库名称(从URL中提取) + repo_identifier = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + ) + + # 获取文件列表 + search_type = "RECURSIVE" if recursive else "DIRECT" + tree_list = await AliyunFileInfo.get_repository_tree( + repo_identifier, dir_path, branch, search_type + ) + + result = [] + for tree in tree_list: + # 跳过非当前目录的文件(如果不是递归模式) + if ( + not recursive + and tree.path != dir_path + and "/" in tree.path.replace(dir_path, "", 1).strip("/") + ): + continue + + file_info = RepoFileInfo( + path=tree.path, + is_dir=tree.type == "tree", + ) + result.append(file_info) + + return result + + except Exception as e: + logger.error(f"获取文件列表失败: {e}") + return [] + + async def get_commit_info( + self, repo_url: str, commit_id: str + ) -> RepoCommitInfo | None: + """ + 获取提交信息 + + Args: + repo_url: 仓库URL或名称 + commit_id: 提交ID + + Returns: + Optional[RepoCommitInfo]: 提交信息,如果获取失败则返回None + """ + try: + # 检查配置 + self._check_config() + + # 获取仓库名称(从URL中提取) + repo_identifier = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + ) + + # 获取提交信息 + # 注意:这里假设AliyunFileInfo有get_commit_info方法,如果没有,需要实现 + commit_data = await self._get_commit_info(repo_identifier, commit_id) + + if not commit_data: + return None + + # 解析提交信息 + id_value = commit_data.get("id", commit_id) + message_value = commit_data.get("message", "") + author_value = commit_data.get("author_name", "") + date_value = commit_data.get( + "authored_date", datetime.now().isoformat() + ).replace("Z", "+00:00") + + return RepoCommitInfo( + commit_id=id_value, + message=message_value, + author=author_value, + commit_time=datetime.fromisoformat(date_value), + changed_files=[], # 阿里云API可能没有直接提供变更文件列表 + ) + except Exception as e: + logger.error(f"获取提交信息失败: {e}") + return None + + def _check_config(self): + """检查配置""" + if not self.config.aliyun_codeup.access_key_id: + raise AuthenticationError("阿里云CodeUp") + + if not self.config.aliyun_codeup.access_key_secret: + raise AuthenticationError("阿里云CodeUp") + + if not self.config.aliyun_codeup.organization_id: + raise AuthenticationError("阿里云CodeUp") + + async def _get_newest_commit(self, repo_name: str, branch: str) -> str: + """ + 获取仓库最新提交ID + + Args: + repo_name: 仓库名称 + branch: 分支名称 + + Returns: + str: 提交ID + """ + try: + newest_commit = await AliyunFileInfo.get_newest_commit(repo_name, branch) + if not newest_commit: + raise RepoNotFoundError(repo_name) + return newest_commit + except Exception as e: + logger.error(f"获取最新提交ID失败: {e}") + raise RepoUpdateError(f"获取最新提交ID失败: {e}") + + async def _get_commit_info(self, repo_name: str, commit_id: str) -> dict: + """ + 获取提交信息 + + Args: + repo_name: 仓库名称 + commit_id: 提交ID + + Returns: + dict: 提交信息 + """ + # 这里需要实现从阿里云获取提交信息的逻辑 + # 由于AliyunFileInfo可能没有get_commit_info方法,这里提供一个简单的实现 + try: + # 这里应该是调用阿里云API获取提交信息 + # 这里只是一个示例,实际上需要根据阿里云API实现 + return { + "id": commit_id, + "message": "提交信息", + "author_name": "作者", + "authored_date": datetime.now().isoformat(), + } + except Exception as e: + logger.error(f"获取提交信息失败: {e}") + return {} + + @cached(ttl=3600) + async def _get_changed_files( + self, repo_name: str, old_commit: str | None, new_commit: str + ) -> list[str]: + """ + 获取两个提交之间变更的文件列表 + + Args: + repo_name: 仓库名称 + old_commit: 旧提交ID,如果为None则获取所有文件 + new_commit: 新提交ID + + Returns: + list[str]: 变更的文件列表 + """ + if not old_commit: + # 如果没有旧提交,则获取仓库中的所有文件 + tree_list = await AliyunFileInfo.get_repository_tree( + repo_name, "", new_commit, "RECURSIVE" + ) + return [tree.path for tree in tree_list if tree.type == "blob"] + + # 获取两个提交之间的差异 + try: + return [] + except Exception as e: + logger.error(f"获取提交差异失败: {e}") + raise RepoUpdateError(f"获取提交差异失败: {e}") + + async def update_via_git( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + force: bool = False, + *, + repo_type: RepoType | None = None, + owner: str | None = None, + prepare_repo_url: Callable[[str], str] | None = None, + ) -> RepoUpdateResult: + """ + 通过Git命令直接更新仓库 + + 参数: + repo_url: 仓库名称 + local_path: 本地仓库路径 + branch: 分支名称 + force: 是否强制拉取 + + 返回: + RepoUpdateResult: 更新结果 + """ + + # 定义预处理函数,构建阿里云CodeUp的URL + def prepare_aliyun_url(repo_url: str) -> str: + import base64 + + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + # 构建仓库URL + # 阿里云CodeUp的仓库URL格式通常为: + # https://codeup.aliyun.com/{organization_id}/{organization_name}/{repo_name}.git + url = f"https://codeup.aliyun.com/{self.config.aliyun_codeup.organization_id}/{self.config.aliyun_codeup.organization_name}/{repo_name}.git" + + # 添加访问令牌 - 使用base64解码后的令牌 + if self.config.aliyun_codeup.rdc_access_token_encrypted: + try: + # 解码RDC访问令牌 + token = base64.b64decode( + self.config.aliyun_codeup.rdc_access_token_encrypted.encode() + ).decode() + # 阿里云CodeUp使用oauth2:token的格式进行身份验证 + url = url.replace("https://", f"https://oauth2:{token}@") + logger.debug(f"使用RDC令牌构建阿里云URL: {url.split('@')[0]}@***") + except Exception as e: + logger.error(f"解码RDC令牌失败: {e}") + + return url + + # 调用基类的update_via_git方法 + return await super().update_via_git( + repo_url=repo_url, + local_path=local_path, + branch=branch, + force=force, + repo_type=RepoType.ALIYUN, + owner=self.config.aliyun_codeup.organization_id, + prepare_repo_url=prepare_aliyun_url, + ) + + async def update( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + use_git: bool = True, + force: bool = False, + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + ) -> RepoUpdateResult: + """ + 更新仓库,可选择使用Git命令或API方式 + + 参数: + repo_url: 仓库名称 + local_path: 本地保存路径 + branch: 分支名称 + use_git: 是否使用Git命令更新 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + 返回: + RepoUpdateResult: 更新结果 + """ + if use_git: + return await self.update_via_git(repo_url, local_path, branch, force) + else: + return await self.update_repo( + repo_url, local_path, branch, include_patterns, exclude_patterns + ) + + async def _download_file( + self, repo_name: str, file_path: str, local_path: Path, ref: str + ) -> int: + """ + 下载文件 + + Args: + repo_name: 仓库名称 + file_path: 文件在仓库中的路径 + local_path: 本地保存路径 + ref: 分支/标签/提交ID + + Returns: + int: 文件大小(字节) + """ + # 确保目录存在 + local_path.parent.mkdir(parents=True, exist_ok=True) + + # 获取文件内容 + for retry in range(self.config.aliyun_codeup.download_retry + 1): + try: + content = await AliyunFileInfo.get_file_content( + file_path, repo_name, ref + ) + + if content is None: + raise FileNotFoundError(file_path, repo_name) + + # 保存文件 + return await self.save_file_content(content.encode("utf-8"), local_path) + + except FileNotFoundError as e: + # 这些错误不需要重试 + raise e + except Exception as e: + if retry < self.config.aliyun_codeup.download_retry: + logger.warning("下载文件失败,将重试", LOG_COMMAND, e=e) + await asyncio.sleep(1) + continue + raise RepoDownloadError(f"下载文件失败: {e}") + + raise RepoDownloadError("下载文件失败: 超过最大重试次数") diff --git a/zhenxun/utils/repo_utils/base_manager.py b/zhenxun/utils/repo_utils/base_manager.py new file mode 100644 index 00000000..efe306b6 --- /dev/null +++ b/zhenxun/utils/repo_utils/base_manager.py @@ -0,0 +1,432 @@ +""" +仓库管理工具的基础管理器 +""" + +from abc import ABC, abstractmethod +from pathlib import Path + +import aiofiles + +from zhenxun.services.log import logger + +from .config import LOG_COMMAND, RepoConfig +from .models import ( + FileDownloadResult, + RepoCommitInfo, + RepoFileInfo, + RepoType, + RepoUpdateResult, +) +from .utils import check_git, filter_files, run_git_command + + +class BaseRepoManager(ABC): + """仓库管理工具基础类""" + + def __init__(self, config: RepoConfig | None = None): + """ + 初始化仓库管理工具 + + 参数: + config: 配置,如果为None则使用默认配置 + """ + self.config = config or RepoConfig.get_instance() + self.config.ensure_dirs() + + @abstractmethod + async def update_repo( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + ) -> RepoUpdateResult: + """ + 更新仓库 + + 参数: + repo_url: 仓库URL或名称 + local_path: 本地保存路径 + branch: 分支名称 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + 返回: + RepoUpdateResult: 更新结果 + """ + pass + + @abstractmethod + async def download_file( + self, + repo_url: str, + file_path: str, + local_path: Path, + branch: str = "main", + ) -> FileDownloadResult: + """ + 下载单个文件 + + 参数: + repo_url: 仓库URL或名称 + file_path: 文件在仓库中的路径 + local_path: 本地保存路径 + branch: 分支名称 + + 返回: + FileDownloadResult: 下载结果 + """ + pass + + @abstractmethod + async def get_file_list( + self, + repo_url: str, + dir_path: str = "", + branch: str = "main", + recursive: bool = False, + ) -> list[RepoFileInfo]: + """ + 获取仓库文件列表 + + 参数: + repo_url: 仓库URL或名称 + dir_path: 目录路径,空字符串表示仓库根目录 + branch: 分支名称 + recursive: 是否递归获取子目录 + + 返回: + List[RepoFileInfo]: 文件信息列表 + """ + pass + + @abstractmethod + async def get_commit_info( + self, repo_url: str, commit_id: str + ) -> RepoCommitInfo | None: + """ + 获取提交信息 + + 参数: + repo_url: 仓库URL或名称 + commit_id: 提交ID + + 返回: + Optional[RepoCommitInfo]: 提交信息,如果获取失败则返回None + """ + pass + + async def save_file_content(self, content: bytes, local_path: Path) -> int: + """ + 保存文件内容 + + 参数: + content: 文件内容 + local_path: 本地保存路径 + + 返回: + int: 文件大小(字节) + """ + # 确保目录存在 + local_path.parent.mkdir(parents=True, exist_ok=True) + + # 保存文件 + async with aiofiles.open(local_path, "wb") as f: + await f.write(content) + + return len(content) + + async def read_version_file(self, local_dir: Path) -> str: + """ + 读取版本文件 + + 参数: + local_dir: 本地目录 + + 返回: + str: 版本号 + """ + version_file = local_dir / "__version__" + if not version_file.exists(): + return "" + + try: + async with aiofiles.open(version_file) as f: + return (await f.read()).strip() + except Exception as e: + logger.error(f"读取版本文件失败: {e}") + return "" + + async def write_version_file(self, local_dir: Path, version: str) -> bool: + """ + 写入版本文件 + + 参数: + local_dir: 本地目录 + version: 版本号 + + 返回: + bool: 是否成功 + """ + version_file = local_dir / "__version__" + + try: + version_bb = "vNone" + async with aiofiles.open(version_file) as rf: + if text := await rf.read(): + version_bb = text.strip().split("-")[0] + async with aiofiles.open(version_file, "w") as f: + await f.write(f"{version_bb}-{version[:6]}") + return True + except Exception as e: + logger.error(f"写入版本文件失败: {e}") + return False + + def filter_files( + self, + files: list[str], + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + ) -> list[str]: + """ + 过滤文件列表 + + 参数: + files: 文件列表 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + 返回: + List[str]: 过滤后的文件列表 + """ + return filter_files(files, include_patterns, exclude_patterns) + + async def update_via_git( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + force: bool = False, + *, + repo_type: RepoType | None = None, + owner="", + prepare_repo_url=None, + ) -> RepoUpdateResult: + """ + 通过Git命令直接更新仓库 + + 参数: + repo_url: 仓库URL或名称 + local_path: 本地仓库路径 + branch: 分支名称 + force: 是否强制拉取 + repo_type: 仓库类型 + owner: 仓库拥有者 + prepare_repo_url: 预处理仓库URL的函数 + + 返回: + RepoUpdateResult: 更新结果 + """ + from .models import RepoType + + repo_name = repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "") + + try: + # 创建结果对象 + result = RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, # 默认使用GitHub类型 + repo_name=repo_name, + owner=owner or "", + old_version="", + new_version="", + ) + + # 检查Git是否可用 + if not await check_git(): + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version="", + new_version="", + error_message="Git命令不可用", + ) + + # 预处理仓库URL + if prepare_repo_url: + repo_url = prepare_repo_url(repo_url) + + # 检查本地目录是否存在 + if not local_path.exists(): + # 如果不存在,则克隆仓库 + logger.info(f"克隆仓库 {repo_url} 到 {local_path}", LOG_COMMAND) + success, stdout, stderr = await run_git_command( + f"clone -b {branch} {repo_url} {local_path}" + ) + if not success: + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version="", + new_version="", + error_message=f"克隆仓库失败: {stderr}", + ) + + # 获取当前提交ID + success, new_version, _ = await run_git_command( + "rev-parse HEAD", cwd=local_path + ) + result.new_version = new_version.strip() + result.success = True + return result + + # 如果目录存在,检查是否是Git仓库 + # 首先检查目录本身是否有.git文件夹 + git_dir = local_path / ".git" + + if not git_dir.is_dir(): + # 如果不是Git仓库,尝试初始化它 + logger.info(f"目录 {local_path} 不是Git仓库,尝试初始化", LOG_COMMAND) + init_success, _, init_stderr = await run_git_command( + "init", cwd=local_path + ) + if not init_success: + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version="", + new_version="", + error_message=f"初始化Git仓库失败: {init_stderr}", + ) + + # 添加远程仓库 + remote_success, _, remote_stderr = await run_git_command( + f"remote add origin {repo_url}", cwd=local_path + ) + if not remote_success: + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version="", + new_version="", + error_message=f"添加远程仓库失败: {remote_stderr}", + ) + + logger.info(f"成功初始化Git仓库 {local_path}", LOG_COMMAND) + + # 获取当前提交ID作为旧版本 + success, old_version, _ = await run_git_command( + "rev-parse HEAD", cwd=local_path + ) + result.old_version = old_version.strip() + + # 获取当前远程URL + success, remote_url, _ = await run_git_command( + "config --get remote.origin.url", cwd=local_path + ) + + # 如果远程URL不匹配,则更新它 + remote_url = remote_url.strip() + if success and repo_url not in remote_url and remote_url not in repo_url: + logger.info(f"更新远程URL: {remote_url} -> {repo_url}", LOG_COMMAND) + await run_git_command( + f"remote set-url origin {repo_url}", cwd=local_path + ) + + # 获取远程更新 + logger.info(f"获取远程更新: {repo_url}", LOG_COMMAND) + success, _, stderr = await run_git_command("fetch origin", cwd=local_path) + if not success: + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version=old_version.strip(), + new_version="", + error_message=f"获取远程更新失败: {stderr}", + ) + + # 获取当前分支 + success, current_branch, _ = await run_git_command( + "rev-parse --abbrev-ref HEAD", cwd=local_path + ) + current_branch = current_branch.strip() + + # 如果当前分支不是目标分支,则切换分支 + if success and current_branch != branch: + logger.info(f"切换分支: {current_branch} -> {branch}", LOG_COMMAND) + success, _, stderr = await run_git_command( + f"checkout {branch}", cwd=local_path + ) + if not success: + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version=old_version.strip(), + new_version="", + error_message=f"切换分支失败: {stderr}", + ) + + # 拉取最新代码 + logger.info(f"拉取最新代码: {repo_url}", LOG_COMMAND) + pull_cmd = f"pull origin {branch}" + if force: + pull_cmd = f"fetch --all && git reset --hard origin/{branch}" + logger.info("使用强制拉取模式", LOG_COMMAND) + success, _, stderr = await run_git_command(pull_cmd, cwd=local_path) + if not success: + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version=old_version.strip(), + new_version="", + error_message=f"拉取最新代码失败: {stderr}", + ) + + # 获取更新后的提交ID + success, new_version, _ = await run_git_command( + "rev-parse HEAD", cwd=local_path + ) + result.new_version = new_version.strip() + + # 如果版本相同,则无需更新 + if old_version.strip() == new_version.strip(): + logger.info( + f"仓库 {repo_url} 已是最新版本: {new_version.strip()}", LOG_COMMAND + ) + result.success = True + return result + + # 获取变更的文件列表 + success, changed_files_output, _ = await run_git_command( + f"diff --name-only {old_version.strip()} {new_version.strip()}", + cwd=local_path, + ) + if success: + changed_files = [ + line.strip() + for line in changed_files_output.splitlines() + if line.strip() + ] + result.changed_files = changed_files + logger.info(f"变更的文件列表: {changed_files}", LOG_COMMAND) + + result.success = True + return result + + except Exception as e: + logger.error("Git更新失败", LOG_COMMAND, e=e) + return RepoUpdateResult( + repo_type=repo_type or RepoType.GITHUB, + repo_name=repo_name, + owner=owner or "", + old_version="", + new_version="", + error_message=str(e), + ) diff --git a/zhenxun/utils/repo_utils/config.py b/zhenxun/utils/repo_utils/config.py new file mode 100644 index 00000000..befe7555 --- /dev/null +++ b/zhenxun/utils/repo_utils/config.py @@ -0,0 +1,77 @@ +""" +仓库管理工具的配置模块 +""" + +from dataclasses import dataclass, field +from pathlib import Path + +from zhenxun.configs.path_config import TEMP_PATH + +LOG_COMMAND = "RepoUtils" + + +@dataclass +class GithubConfig: + """GitHub配置""" + + # API超时时间(秒) + api_timeout: int = 30 + # 下载超时时间(秒) + download_timeout: int = 60 + # 下载重试次数 + download_retry: int = 3 + # 代理配置 + proxy: str | None = None + + +@dataclass +class AliyunCodeupConfig: + """阿里云CodeUp配置""" + + # 访问密钥ID + access_key_id: str = "LTAI5tNmf7KaTAuhcvRobAQs" + # 访问密钥密钥 + access_key_secret: str = "NmJ3d2VNRU1MREY0T1RtRnBqMlFqdlBxN3pMUk1j" + # 组织ID + organization_id: str = "67a361cf556e6cdab537117a" + # 组织名称 + organization_name: str = "zhenxun-org" + # RDC Access Token + rdc_access_token_encrypted: str = ( + "cHQtYXp0allnQWpub0FYZWpqZm1RWGtneHk0XzBlMmYzZTZmLWQwOWItNDE4Mi1iZWUx" + "LTQ1ZTFkYjI0NGRlMg==" + ) + # 区域 + region: str = "cn-hangzhou" + # 端点 + endpoint: str = "devops.cn-hangzhou.aliyuncs.com" + # 下载重试次数 + download_retry: int = 3 + + +@dataclass +class RepoConfig: + """仓库管理工具配置""" + + # 缓存目录 + cache_dir: Path = TEMP_PATH / "repo_cache" + + # GitHub配置 + github: GithubConfig = field(default_factory=GithubConfig) + + # 阿里云CodeUp配置 + aliyun_codeup: AliyunCodeupConfig = field(default_factory=AliyunCodeupConfig) + + # 单例实例 + _instance = None + + @classmethod + def get_instance(cls) -> "RepoConfig": + """获取单例实例""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def ensure_dirs(self): + """确保目录存在""" + self.cache_dir.mkdir(parents=True, exist_ok=True) diff --git a/zhenxun/utils/repo_utils/exceptions.py b/zhenxun/utils/repo_utils/exceptions.py new file mode 100644 index 00000000..d508f303 --- /dev/null +++ b/zhenxun/utils/repo_utils/exceptions.py @@ -0,0 +1,68 @@ +""" +仓库管理工具的异常类 +""" + + +class RepoManagerError(Exception): + """仓库管理工具异常基类""" + + def __init__(self, message: str, repo_name: str | None = None): + self.message = message + self.repo_name = repo_name + super().__init__(self.message) + + +class RepoUpdateError(RepoManagerError): + """仓库更新异常""" + + def __init__(self, message: str, repo_name: str | None = None): + super().__init__(f"仓库更新失败: {message}", repo_name) + + +class RepoDownloadError(RepoManagerError): + """仓库下载异常""" + + def __init__(self, message: str, repo_name: str | None = None): + super().__init__(f"文件下载失败: {message}", repo_name) + + +class RepoNotFoundError(RepoManagerError): + """仓库不存在异常""" + + def __init__(self, repo_name: str): + super().__init__(f"仓库不存在: {repo_name}", repo_name) + + +class FileNotFoundError(RepoManagerError): + """文件不存在异常""" + + def __init__(self, file_path: str, repo_name: str | None = None): + super().__init__(f"文件不存在: {file_path}", repo_name) + + +class AuthenticationError(RepoManagerError): + """认证异常""" + + def __init__(self, repo_type: str): + super().__init__(f"认证失败: {repo_type}") + + +class ApiRateLimitError(RepoManagerError): + """API速率限制异常""" + + def __init__(self, repo_type: str): + super().__init__(f"API速率限制: {repo_type}") + + +class NetworkError(RepoManagerError): + """网络异常""" + + def __init__(self, message: str): + super().__init__(f"网络错误: {message}") + + +class ConfigError(RepoManagerError): + """配置异常""" + + def __init__(self, message: str): + super().__init__(f"配置错误: {message}") diff --git a/zhenxun/utils/repo_utils/file_manager.py b/zhenxun/utils/repo_utils/file_manager.py new file mode 100644 index 00000000..43a87a7b --- /dev/null +++ b/zhenxun/utils/repo_utils/file_manager.py @@ -0,0 +1,543 @@ +""" +仓库文件管理器,用于从GitHub和阿里云CodeUp获取指定文件内容 +""" + +from pathlib import Path +from typing import cast, overload + +import aiofiles +from httpx import Response + +from zhenxun.services.log import logger +from zhenxun.utils.github_utils import GithubUtils +from zhenxun.utils.github_utils.models import AliyunTreeType, GitHubStrategy, TreeType +from zhenxun.utils.http_utils import AsyncHttpx + +from .config import LOG_COMMAND, RepoConfig +from .exceptions import FileNotFoundError, NetworkError, RepoManagerError +from .models import FileDownloadResult, RepoFileInfo, RepoType + + +class RepoFileManager: + """仓库文件管理器,用于获取GitHub和阿里云仓库中的文件内容""" + + def __init__(self, config: RepoConfig | None = None): + """ + 初始化仓库文件管理器 + + 参数: + config: 配置,如果为None则使用默认配置 + """ + self.config = config or RepoConfig.get_instance() + self.config.ensure_dirs() + + @overload + async def get_github_file_content( + self, url: str, file_path: str, ignore_error: bool = False + ) -> str: ... + + @overload + async def get_github_file_content( + self, url: str, file_path: list[str], ignore_error: bool = False + ) -> list[tuple[str, str]]: ... + + async def get_github_file_content( + self, url: str, file_path: str | list[str], ignore_error: bool = False + ) -> str | list[tuple[str, str]]: + """ + 获取GitHub仓库文件内容 + + 参数: + url: 仓库URL + file_path: 文件路径或文件路径列表 + ignore_error: 是否忽略错误 + + 返回: + list[tuple[str, str]]: 文件路径,文件内容 + """ + results = [] + is_str_input = isinstance(file_path, str) + try: + if is_str_input: + file_path = [file_path] + repo_info = GithubUtils.parse_github_url(url) + if await repo_info.update_repo_commit(): + logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) + else: + logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) + for f in file_path: + try: + file_url = await repo_info.get_raw_download_urls(f) + for fu in file_url: + response: Response = await AsyncHttpx.get( + fu, check_status_code=200 + ) + if response.status_code == 200: + logger.info(f"获取github文件内容成功: {f}", LOG_COMMAND) + # 确保使用UTF-8编码解析响应内容 + try: + text_content = response.content.decode("utf-8") + except UnicodeDecodeError: + # 如果UTF-8解码失败,尝试其他编码 + text_content = response.content.decode( + "utf-8", errors="ignore" + ) + logger.warning( + f"解码文件内容时出现错误,使用忽略错误模式: {f}", + LOG_COMMAND, + ) + results.append((f, text_content)) + break + else: + logger.warning( + f"获取github文件内容失败: {response.status_code}", + LOG_COMMAND, + ) + except Exception as e: + logger.warning(f"获取github文件内容失败: {f}", LOG_COMMAND, e=e) + if not ignore_error: + raise + except Exception as e: + logger.error(f"获取GitHub文件内容失败: {file_path}", LOG_COMMAND, e=e) + raise + logger.debug(f"获取GitHub文件内容: {[r[0] for r in results]}", LOG_COMMAND) + + return results[0][1] if is_str_input and results else results + + @overload + async def get_aliyun_file_content( + self, + repo_name: str, + file_path: str, + branch: str = "main", + ignore_error: bool = False, + ) -> str: ... + + @overload + async def get_aliyun_file_content( + self, + repo_name: str, + file_path: list[str], + branch: str = "main", + ignore_error: bool = False, + ) -> list[tuple[str, str]]: ... + + async def get_aliyun_file_content( + self, + repo_name: str, + file_path: str | list[str], + branch: str = "main", + ignore_error: bool = False, + ) -> str | list[tuple[str, str]]: + """ + 获取阿里云CodeUp仓库文件内容 + + 参数: + repo: 仓库名称 + file_path: 文件路径 + branch: 分支名称 + ignore_error: 是否忽略错误 + 返回: + list[tuple[str, str]]: 文件路径,文件内容 + """ + results = [] + is_str_input = isinstance(file_path, str) + # 导入阿里云相关模块 + from zhenxun.utils.github_utils.models import AliyunFileInfo + + if is_str_input: + file_path = [file_path] + for f in file_path: + try: + content = await AliyunFileInfo.get_file_content( + file_path=f, repo=repo_name, ref=branch + ) + results.append((f, content)) + except Exception as e: + logger.warning(f"获取阿里云文件内容失败: {file_path}", LOG_COMMAND, e=e) + if not ignore_error: + raise + logger.debug(f"获取阿里云文件内容: {[r[0] for r in results]}", LOG_COMMAND) + return results[0][1] if is_str_input and results else results + + @overload + async def get_file_content( + self, + repo_url: str, + file_path: str, + branch: str = "main", + repo_type: RepoType | None = None, + ignore_error: bool = False, + ) -> str: ... + + @overload + async def get_file_content( + self, + repo_url: str, + file_path: list[str], + branch: str = "main", + repo_type: RepoType | None = None, + ignore_error: bool = False, + ) -> list[tuple[str, str]]: ... + + async def get_file_content( + self, + repo_url: str, + file_path: str | list[str], + branch: str = "main", + repo_type: RepoType | None = None, + ignore_error: bool = False, + ) -> str | list[tuple[str, str]]: + """ + 获取仓库文件内容 + + 参数: + repo_url: 仓库URL + file_path: 文件路径 + branch: 分支名称 + repo_type: 仓库类型,如果为None则自动判断 + ignore_error: 是否忽略错误 + + 返回: + str: 文件内容 + """ + # 确定仓库类型 + repo_name = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip() + ) + if repo_type is None: + try: + return await self.get_aliyun_file_content( + repo_name, file_path, branch, ignore_error + ) + except Exception: + return await self.get_github_file_content( + repo_url, file_path, ignore_error + ) + + try: + if repo_type == RepoType.GITHUB: + return await self.get_github_file_content( + repo_url, file_path, ignore_error + ) + + elif repo_type == RepoType.ALIYUN: + return await self.get_aliyun_file_content( + repo_name, file_path, branch, ignore_error + ) + + except Exception as e: + if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError): + raise + raise RepoManagerError(f"获取文件内容失败: {e}") + + async def list_directory_files( + self, + repo_url: str, + directory_path: str = "", + branch: str = "main", + repo_type: RepoType | None = None, + recursive: bool = True, + ) -> list[RepoFileInfo]: + """ + 获取仓库目录下的所有文件路径 + + 参数: + repo_url: 仓库URL + directory_path: 目录路径,默认为仓库根目录 + branch: 分支名称 + repo_type: 仓库类型,如果为None则自动判断 + recursive: 是否递归获取子目录文件 + + 返回: + list[RepoFileInfo]: 文件信息列表 + """ + repo_name = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip() + ) + try: + if repo_type is None: + # 尝试GitHub,失败则尝试阿里云 + try: + return await self._list_github_directory_files( + repo_url, directory_path, branch, recursive + ) + except Exception as e: + logger.warning( + "获取GitHub目录文件失败,尝试阿里云", LOG_COMMAND, e=e + ) + return await self._list_aliyun_directory_files( + repo_name, directory_path, branch, recursive + ) + if repo_type == RepoType.GITHUB: + return await self._list_github_directory_files( + repo_url, directory_path, branch, recursive + ) + elif repo_type == RepoType.ALIYUN: + return await self._list_aliyun_directory_files( + repo_name, directory_path, branch, recursive + ) + except Exception as e: + logger.error(f"获取目录文件列表失败: {directory_path}", LOG_COMMAND, e=e) + if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError): + raise + raise RepoManagerError(f"获取目录文件列表失败: {e}") + + async def _list_github_directory_files( + self, + repo_url: str, + directory_path: str = "", + branch: str = "main", + recursive: bool = True, + build_tree: bool = False, + ) -> list[RepoFileInfo]: + """ + 获取GitHub仓库目录下的所有文件路径 + + 参数: + repo_url: 仓库URL + directory_path: 目录路径,默认为仓库根目录 + branch: 分支名称 + recursive: 是否递归获取子目录文件 + build_tree: 是否构建目录树 + + 返回: + list[RepoFileInfo]: 文件信息列表 + """ + try: + repo_info = GithubUtils.parse_github_url(repo_url) + if await repo_info.update_repo_commit(): + logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) + else: + logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) + + # 获取仓库树信息 + strategy = GitHubStrategy() + strategy.body = await GitHubStrategy.parse_repo_info(repo_info) + + # 处理目录路径,确保格式正确 + if directory_path and not directory_path.endswith("/") and recursive: + directory_path = f"{directory_path}/" + + # 获取文件列表 + file_list = [] + for tree_item in strategy.body.tree: + # 如果不是递归模式,只获取当前目录下的文件 + if not recursive and "/" in tree_item.path.replace( + directory_path, "", 1 + ): + continue + + # 检查是否在指定目录下 + if directory_path and not tree_item.path.startswith(directory_path): + continue + + # 创建文件信息对象 + file_info = RepoFileInfo( + path=tree_item.path, + is_dir=tree_item.type == TreeType.DIR, + size=tree_item.size, + last_modified=None, # GitHub API不直接提供最后修改时间 + ) + file_list.append(file_info) + + # 构建目录树结构 + if recursive and build_tree: + file_list = self._build_directory_tree(file_list) + + return file_list + + except Exception as e: + logger.error( + f"获取GitHub目录文件列表失败: {directory_path}", LOG_COMMAND, e=e + ) + raise + + async def _list_aliyun_directory_files( + self, + repo_name: str, + directory_path: str = "", + branch: str = "main", + recursive: bool = True, + build_tree: bool = False, + ) -> list[RepoFileInfo]: + """ + 获取阿里云CodeUp仓库目录下的所有文件路径 + + 参数: + repo_name: 仓库名称 + directory_path: 目录路径,默认为仓库根目录 + branch: 分支名称 + recursive: 是否递归获取子目录文件 + build_tree: 是否构建目录树 + + 返回: + list[RepoFileInfo]: 文件信息列表 + """ + try: + from zhenxun.utils.github_utils.models import AliyunFileInfo + + # 获取仓库树信息 + search_type = "RECURSIVE" if recursive else "DIRECT" + tree_list = await AliyunFileInfo.get_repository_tree( + repo=repo_name, + path=directory_path, + ref=branch, + search_type=search_type, + ) + + # 创建文件信息对象列表 + file_list = [] + for tree_item in tree_list: + file_info = RepoFileInfo( + path=tree_item.path, + is_dir=tree_item.type == AliyunTreeType.DIR, + size=None, # 阿里云API不直接提供文件大小 + last_modified=None, # 阿里云API不直接提供最后修改时间 + ) + file_list.append(file_info) + + # 构建目录树结构 + if recursive and build_tree: + file_list = self._build_directory_tree(file_list) + + return file_list + + except Exception as e: + logger.error( + f"获取阿里云目录文件列表失败: {directory_path}", LOG_COMMAND, e=e + ) + raise + + def _build_directory_tree( + self, file_list: list[RepoFileInfo] + ) -> list[RepoFileInfo]: + """ + 构建目录树结构 + + 参数: + file_list: 文件信息列表 + + 返回: + list[RepoFileInfo]: 根目录下的文件信息列表 + """ + # 按路径排序,确保父目录在子目录之前 + file_list.sort(key=lambda x: x.path) + # 创建路径到文件信息的映射 + path_map = {file_info.path: file_info for file_info in file_list} + # 根目录文件列表 + root_files = [] + + for file_info in file_list: + if parent_path := "/".join(file_info.path.split("/")[:-1]): + # 如果有父目录,将当前文件添加到父目录的子文件列表中 + if parent_path in path_map: + path_map[parent_path].children.append(file_info) + else: + # 如果父目录不在列表中,创建一个虚拟的父目录 + parent_info = RepoFileInfo( + path=parent_path, is_dir=True, children=[file_info] + ) + path_map[parent_path] = parent_info + # 检查父目录的父目录 + grand_parent_path = "/".join(parent_path.split("/")[:-1]) + if grand_parent_path and grand_parent_path in path_map: + path_map[grand_parent_path].children.append(parent_info) + else: + root_files.append(parent_info) + else: + # 如果没有父目录,则是根目录下的文件 + root_files.append(file_info) + + # 返回根目录下的文件列表 + return [ + file + for file in root_files + if all(f.path != file.path for f in file_list if f != file) + ] + + async def download_files( + self, + repo_url: str, + file_path: tuple[str, Path] | list[tuple[str, Path]], + branch: str = "main", + repo_type: RepoType | None = None, + ignore_error: bool = False, + ) -> FileDownloadResult: + """ + 下载单个文件 + + 参数: + repo_url: 仓库URL + file_path: 文件在仓库中的路径,本地存储路径 + branch: 分支名称 + repo_type: 仓库类型,如果为None则自动判断 + ignore_error: 是否忽略错误 + + 返回: + FileDownloadResult: 下载结果 + """ + # 确定仓库类型和所有者 + repo_name = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip() + ) + + if isinstance(file_path, tuple): + file_path = [file_path] + + file_path_mapping = {f[0]: f[1] for f in file_path} + + # 创建结果对象 + result = FileDownloadResult( + repo_type=repo_type, + repo_name=repo_name, + file_path=file_path, + version=branch, + ) + + try: + # 由于我们传入的是列表,所以这里一定返回列表 + file_paths = [f[0] for f in file_path] + if len(file_paths) == 1: + # 如果只有一个文件,可能返回单个元组 + file_contents_result = await self.get_file_content( + repo_url, file_paths[0], branch, repo_type, ignore_error + ) + if isinstance(file_contents_result, tuple): + file_contents = [file_contents_result] + elif isinstance(file_contents_result, str): + file_contents = [(file_paths[0], file_contents_result)] + else: + file_contents = cast(list[tuple[str, str]], file_contents_result) + else: + # 多个文件一定返回列表 + file_contents = cast( + list[tuple[str, str]], + await self.get_file_content( + repo_url, file_paths, branch, repo_type, ignore_error + ), + ) + + for repo_file_path, content in file_contents: + local_path = file_path_mapping[repo_file_path] + local_path.parent.mkdir(parents=True, exist_ok=True) + # 使用二进制模式写入文件,避免编码问题 + if isinstance(content, str): + content_bytes = content.encode("utf-8") + else: + content_bytes = content + logger.warning(f"写入文件: {local_path}") + async with aiofiles.open(local_path, "wb") as f: + await f.write(content_bytes) + result.success = True + # 计算文件大小 + result.file_size = sum( + len(content.encode("utf-8") if isinstance(content, str) else content) + for _, content in file_contents + ) + return result + except Exception as e: + logger.error(f"下载文件失败: {e}") + result.success = False + result.error_message = str(e) + return result diff --git a/zhenxun/utils/repo_utils/github_manager.py b/zhenxun/utils/repo_utils/github_manager.py new file mode 100644 index 00000000..462c2723 --- /dev/null +++ b/zhenxun/utils/repo_utils/github_manager.py @@ -0,0 +1,526 @@ +""" +GitHub仓库管理工具 +""" + +import asyncio +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +from aiocache import cached + +from zhenxun.services.log import logger +from zhenxun.utils.github_utils import GithubUtils, RepoInfo +from zhenxun.utils.http_utils import AsyncHttpx + +from .base_manager import BaseRepoManager +from .config import LOG_COMMAND, RepoConfig +from .exceptions import ( + ApiRateLimitError, + FileNotFoundError, + NetworkError, + RepoDownloadError, + RepoNotFoundError, + RepoUpdateError, +) +from .models import ( + FileDownloadResult, + RepoCommitInfo, + RepoFileInfo, + RepoType, + RepoUpdateResult, +) + + +class GithubManager(BaseRepoManager): + """GitHub仓库管理工具""" + + def __init__(self, config: RepoConfig | None = None): + """ + 初始化GitHub仓库管理工具 + + 参数: + config: 配置,如果为None则使用默认配置 + """ + super().__init__(config) + + async def update_repo( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + ) -> RepoUpdateResult: + """ + 更新GitHub仓库 + + 参数: + repo_url: 仓库URL,格式为 https://github.com/owner/repo + local_path: 本地保存路径 + branch: 分支名称 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + 返回: + RepoUpdateResult: 更新结果 + """ + try: + # 解析仓库URL + repo_info = GithubUtils.parse_github_url(repo_url) + repo_info.branch = branch + + # 获取仓库最新提交ID + newest_commit = await self._get_newest_commit( + repo_info.owner, repo_info.repo, branch + ) + + # 创建结果对象 + result = RepoUpdateResult( + repo_type=RepoType.GITHUB, + repo_name=repo_info.repo, + owner=repo_info.owner, + old_version="", # 将在后面更新 + new_version=newest_commit, + ) + + old_version = await self.read_version_file(local_path) + old_version = old_version.split("-")[-1] + result.old_version = old_version + + # 如果版本相同,则无需更新 + if newest_commit in old_version: + result.success = True + logger.debug( + f"仓库 {repo_info.repo} 已是最新版本: {newest_commit}", + LOG_COMMAND, + ) + return result + + # 确保本地目录存在 + local_path.mkdir(parents=True, exist_ok=True) + + # 获取变更的文件列表 + changed_files = await self._get_changed_files( + repo_info.owner, + repo_info.repo, + old_version or None, + newest_commit, + ) + + # 过滤文件 + if include_patterns or exclude_patterns: + from .utils import filter_files + + changed_files = filter_files( + changed_files, include_patterns, exclude_patterns + ) + + result.changed_files = changed_files + + # 下载变更的文件 + for file_path in changed_files: + try: + local_file_path = local_path / file_path + await self._download_file(repo_info, file_path, local_file_path) + except Exception as e: + logger.error(f"下载文件 {file_path} 失败", LOG_COMMAND, e=e) + + # 更新版本文件 + await self.write_version_file(local_path, newest_commit) + + result.success = True + return result + + except RepoUpdateError as e: + logger.error("更新仓库失败", LOG_COMMAND, e=e) + return RepoUpdateResult( + repo_type=RepoType.GITHUB, + repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url, + owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown", + old_version="", + new_version="", + error_message=str(e), + ) + except Exception as e: + logger.error("更新仓库失败", LOG_COMMAND, e=e) + return RepoUpdateResult( + repo_type=RepoType.GITHUB, + repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url, + owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown", + old_version="", + new_version="", + error_message=str(e), + ) + + async def download_file( + self, + repo_url: str, + file_path: str, + local_path: Path, + branch: str = "main", + ) -> FileDownloadResult: + """ + 从GitHub下载单个文件 + + 参数: + repo_url: 仓库URL,格式为 https://github.com/owner/repo + file_path: 文件在仓库中的路径 + local_path: 本地保存路径 + branch: 分支名称 + + 返回: + FileDownloadResult: 下载结果 + """ + repo_name = ( + repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip() + ) + try: + # 解析仓库URL + repo_info = GithubUtils.parse_github_url(repo_url) + repo_info.branch = branch + + # 创建结果对象 + result = FileDownloadResult( + repo_type=RepoType.GITHUB, + repo_name=repo_info.repo, + file_path=file_path, + version=branch, + ) + + # 确保本地目录存在 + local_path.parent.mkdir(parents=True, exist_ok=True) + + # 下载文件 + file_size = await self._download_file(repo_info, file_path, local_path) + + result.success = True + result.file_size = file_size + return result + + except RepoDownloadError as e: + logger.error("下载文件失败", LOG_COMMAND, e=e) + return FileDownloadResult( + repo_type=RepoType.GITHUB, + repo_name=repo_name, + file_path=file_path, + version=branch, + error_message=str(e), + ) + except Exception as e: + logger.error("下载文件失败", LOG_COMMAND, e=e) + return FileDownloadResult( + repo_type=RepoType.GITHUB, + repo_name=repo_name, + file_path=file_path, + version=branch, + error_message=str(e), + ) + + async def get_file_list( + self, + repo_url: str, + dir_path: str = "", + branch: str = "main", + recursive: bool = False, + ) -> list[RepoFileInfo]: + """ + 获取仓库文件列表 + + 参数: + repo_url: 仓库URL,格式为 https://github.com/owner/repo + dir_path: 目录路径,空字符串表示仓库根目录 + branch: 分支名称 + recursive: 是否递归获取子目录 + + 返回: + list[RepoFileInfo]: 文件信息列表 + """ + try: + # 解析仓库URL + repo_info = GithubUtils.parse_github_url(repo_url) + repo_info.branch = branch + + # 获取文件列表 + for api in GithubUtils.iter_api_strategies(): + try: + await api.parse_repo_info(repo_info) + files = api.get_files(dir_path, True) + + result = [] + for file_path in files: + # 跳过非当前目录的文件(如果不是递归模式) + if not recursive and "/" in file_path.replace( + dir_path, "", 1 + ).strip("/"): + continue + + is_dir = file_path.endswith("/") + file_info = RepoFileInfo(path=file_path, is_dir=is_dir) + result.append(file_info) + + return result + except Exception as e: + logger.debug("使用API策略获取文件列表失败", LOG_COMMAND, e=e) + continue + + raise RepoNotFoundError(repo_url) + + except Exception as e: + logger.error("获取文件列表失败", LOG_COMMAND, e=e) + return [] + + async def get_commit_info( + self, repo_url: str, commit_id: str + ) -> RepoCommitInfo | None: + """ + 获取提交信息 + + 参数: + repo_url: 仓库URL,格式为 https://github.com/owner/repo + commit_id: 提交ID + + 返回: + Optional[RepoCommitInfo]: 提交信息,如果获取失败则返回None + """ + try: + # 解析仓库URL + repo_info = GithubUtils.parse_github_url(repo_url) + + # 构建API URL + api_url = f"https://api.github.com/repos/{repo_info.owner}/{repo_info.repo}/commits/{commit_id}" + + # 发送请求 + resp = await AsyncHttpx.get( + api_url, + timeout=self.config.github.api_timeout, + proxy=self.config.github.proxy, + ) + + if resp.status_code == 403 and "rate limit" in resp.text.lower(): + raise ApiRateLimitError("GitHub") + + if resp.status_code != 200: + if resp.status_code == 404: + raise RepoNotFoundError(f"{repo_info.owner}/{repo_info.repo}") + raise NetworkError(f"HTTP {resp.status_code}: {resp.text}") + + data = resp.json() + + return RepoCommitInfo( + commit_id=data["sha"], + message=data["commit"]["message"], + author=data["commit"]["author"]["name"], + commit_time=datetime.fromisoformat( + data["commit"]["author"]["date"].replace("Z", "+00:00") + ), + changed_files=[file["filename"] for file in data.get("files", [])], + ) + except Exception as e: + logger.error("获取提交信息失败", LOG_COMMAND, e=e) + return None + + async def _get_newest_commit(self, owner: str, repo: str, branch: str) -> str: + """ + 获取仓库最新提交ID + + 参数: + owner: 仓库拥有者 + repo: 仓库名称 + branch: 分支名称 + + 返回: + str: 提交ID + """ + try: + newest_commit = await RepoInfo.get_newest_commit(owner, repo, branch) + if not newest_commit: + raise RepoNotFoundError(f"{owner}/{repo}") + return newest_commit + except Exception as e: + logger.error("获取最新提交ID失败", LOG_COMMAND, e=e) + raise RepoUpdateError(f"获取最新提交ID失败: {e}") + + @cached(ttl=3600) + async def _get_changed_files( + self, owner: str, repo: str, old_commit: str | None, new_commit: str + ) -> list[str]: + """ + 获取两个提交之间变更的文件列表 + + 参数: + owner: 仓库拥有者 + repo: 仓库名称 + old_commit: 旧提交ID,如果为None则获取所有文件 + new_commit: 新提交ID + + 返回: + list[str]: 变更的文件列表 + """ + if not old_commit: + # 如果没有旧提交,则获取仓库中的所有文件 + api_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{new_commit}?recursive=1" + + resp = await AsyncHttpx.get( + api_url, + timeout=self.config.github.api_timeout, + proxy=self.config.github.proxy, + ) + + if resp.status_code == 403 and "rate limit" in resp.text.lower(): + raise ApiRateLimitError("GitHub") + + if resp.status_code != 200: + if resp.status_code == 404: + raise RepoNotFoundError(f"{owner}/{repo}") + raise NetworkError(f"HTTP {resp.status_code}: {resp.text}") + + data = resp.json() + return [ + item["path"] for item in data.get("tree", []) if item["type"] == "blob" + ] + + # 如果有旧提交,则获取两个提交之间的差异 + api_url = f"https://api.github.com/repos/{owner}/{repo}/compare/{old_commit}...{new_commit}" + + resp = await AsyncHttpx.get( + api_url, + timeout=self.config.github.api_timeout, + proxy=self.config.github.proxy, + ) + + if resp.status_code == 403 and "rate limit" in resp.text.lower(): + raise ApiRateLimitError("GitHub") + + if resp.status_code != 200: + if resp.status_code == 404: + raise RepoNotFoundError(f"{owner}/{repo}") + raise NetworkError(f"HTTP {resp.status_code}: {resp.text}") + + data = resp.json() + return [file["filename"] for file in data.get("files", [])] + + async def update_via_git( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + force: bool = False, + *, + repo_type: RepoType | None = None, + owner: str | None = None, + prepare_repo_url: Callable[[str], str] | None = None, + ) -> RepoUpdateResult: + """ + 通过Git命令直接更新仓库 + + 参数: + repo_url: 仓库URL,格式为 https://github.com/owner/repo + local_path: 本地仓库路径 + branch: 分支名称 + force: 是否强制拉取 + + 返回: + RepoUpdateResult: 更新结果 + """ + # 解析仓库URL + repo_info = GithubUtils.parse_github_url(repo_url) + + # 调用基类的update_via_git方法 + return await super().update_via_git( + repo_url=repo_url, + local_path=local_path, + branch=branch, + force=force, + repo_type=RepoType.GITHUB, + owner=repo_info.owner, + ) + + async def update( + self, + repo_url: str, + local_path: Path, + branch: str = "main", + use_git: bool = True, + force: bool = False, + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, + ) -> RepoUpdateResult: + """ + 更新仓库,可选择使用Git命令或API方式 + + 参数: + repo_url: 仓库URL,格式为 https://github.com/owner/repo + local_path: 本地保存路径 + branch: 分支名称 + use_git: 是否使用Git命令更新 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + 返回: + RepoUpdateResult: 更新结果 + """ + if use_git: + return await self.update_via_git(repo_url, local_path, branch, force) + else: + return await self.update_repo( + repo_url, local_path, branch, include_patterns, exclude_patterns + ) + + async def _download_file( + self, repo_info: RepoInfo, file_path: str, local_path: Path + ) -> int: + """ + 下载文件 + + 参数: + repo_info: 仓库信息 + file_path: 文件在仓库中的路径 + local_path: 本地保存路径 + + 返回: + int: 文件大小(字节) + """ + # 确保目录存在 + local_path.parent.mkdir(parents=True, exist_ok=True) + + # 获取下载URL + download_url = await repo_info.get_raw_download_url(file_path) + + # 下载文件 + for retry in range(self.config.github.download_retry + 1): + try: + resp = await AsyncHttpx.get( + download_url, + timeout=self.config.github.download_timeout, + ) + + if resp.status_code == 403 and "rate limit" in resp.text.lower(): + raise ApiRateLimitError("GitHub") + + if resp.status_code != 200: + if resp.status_code == 404: + raise FileNotFoundError( + file_path, f"{repo_info.owner}/{repo_info.repo}" + ) + + if retry < self.config.github.download_retry: + await asyncio.sleep(1) + continue + + raise NetworkError(f"HTTP {resp.status_code}: {resp.text}") + + # 保存文件 + return await self.save_file_content(resp.content, local_path) + + except (ApiRateLimitError, FileNotFoundError) as e: + # 这些错误不需要重试 + raise e + except Exception as e: + if retry < self.config.github.download_retry: + logger.warning("下载文件失败,将重试", LOG_COMMAND, e=e) + await asyncio.sleep(1) + continue + raise RepoDownloadError("下载文件失败") + + raise RepoDownloadError("下载文件失败: 超过最大重试次数") diff --git a/zhenxun/utils/repo_utils/models.py b/zhenxun/utils/repo_utils/models.py new file mode 100644 index 00000000..170e60f3 --- /dev/null +++ b/zhenxun/utils/repo_utils/models.py @@ -0,0 +1,89 @@ +""" +仓库管理工具的数据模型 +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from pathlib import Path + + +class RepoType(str, Enum): + """仓库类型""" + + GITHUB = "github" + ALIYUN = "aliyun" + + +@dataclass +class RepoFileInfo: + """仓库文件信息""" + + # 文件路径 + path: str + # 是否是目录 + is_dir: bool + # 文件大小(字节) + size: int | None = None + # 最后修改时间 + last_modified: datetime | None = None + # 子文件列表 + children: list["RepoFileInfo"] = field(default_factory=list) + + +@dataclass +class RepoCommitInfo: + """仓库提交信息""" + + # 提交ID + commit_id: str + # 提交消息 + message: str + # 作者 + author: str + # 提交时间 + commit_time: datetime + # 变更的文件列表 + changed_files: list[str] = field(default_factory=list) + + +@dataclass +class RepoUpdateResult: + """仓库更新结果""" + + # 仓库类型 + repo_type: RepoType + # 仓库名称 + repo_name: str + # 仓库拥有者 + owner: str + # 旧版本 + old_version: str + # 新版本 + new_version: str + # 是否成功 + success: bool = False + # 错误消息 + error_message: str = "" + # 变更的文件列表 + changed_files: list[str] = field(default_factory=list) + + +@dataclass +class FileDownloadResult: + """文件下载结果""" + + # 仓库类型 + repo_type: RepoType | None + # 仓库名称 + repo_name: str + # 文件路径 + file_path: list[tuple[str, Path]] | str + # 版本 + version: str + # 是否成功 + success: bool = False + # 文件大小(字节) + file_size: int = 0 + # 错误消息 + error_message: str = "" diff --git a/zhenxun/utils/repo_utils/utils.py b/zhenxun/utils/repo_utils/utils.py new file mode 100644 index 00000000..7aceb231 --- /dev/null +++ b/zhenxun/utils/repo_utils/utils.py @@ -0,0 +1,135 @@ +""" +仓库管理工具的工具函数 +""" + +import asyncio +from pathlib import Path +import re + +from zhenxun.services.log import logger + +from .config import LOG_COMMAND + + +async def check_git() -> bool: + """ + 检查环境变量中是否存在 git + + 返回: + bool: 是否存在git命令 + """ + try: + process = await asyncio.create_subprocess_shell( + "git --version", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await process.communicate() + return bool(stdout) + except Exception as e: + logger.error("检查git命令失败", LOG_COMMAND, e=e) + return False + + +async def clean_git(cwd: Path): + """ + 清理git仓库 + + 参数: + cwd: 工作目录 + """ + await run_git_command("reset --hard", cwd) + await run_git_command("clean -xdf", cwd) + + +async def run_git_command( + command: str, cwd: Path | None = None +) -> tuple[bool, str, str]: + """ + 运行git命令 + + 参数: + command: 命令 + cwd: 工作目录 + + 返回: + tuple[bool, str, str]: (是否成功, 标准输出, 标准错误) + """ + try: + full_command = f"git {command}" + # 将Path对象转换为字符串 + cwd_str = str(cwd) if cwd else None + process = await asyncio.create_subprocess_shell( + full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd_str, + ) + stdout_bytes, stderr_bytes = await process.communicate() + + stdout = stdout_bytes.decode("utf-8").strip() + stderr = stderr_bytes.decode("utf-8").strip() + + return process.returncode == 0, stdout, stderr + except Exception as e: + logger.error(f"运行git命令失败: {command}, 错误: {e}") + return False, "", str(e) + + +def glob_to_regex(pattern: str) -> str: + """ + 将glob模式转换为正则表达式 + + 参数: + pattern: glob模式,如 "*.py" + + 返回: + str: 正则表达式 + """ + # 转义特殊字符 + regex = re.escape(pattern) + + # 替换glob通配符 + regex = regex.replace(r"\*\*", ".*") # ** -> .* + regex = regex.replace(r"\*", "[^/]*") # * -> [^/]* + regex = regex.replace(r"\?", "[^/]") # ? -> [^/] + + # 添加开始和结束标记 + regex = f"^{regex}$" + + return regex + + +def filter_files( + files: list[str], + include_patterns: list[str] | None = None, + exclude_patterns: list[str] | None = None, +) -> list[str]: + """ + 过滤文件列表 + + 参数: + files: 文件列表 + include_patterns: 包含的文件模式列表,如 ["*.py", "docs/*.md"] + exclude_patterns: 排除的文件模式列表,如 ["__pycache__/*", "*.pyc"] + + 返回: + list[str]: 过滤后的文件列表 + """ + result = files.copy() + + # 应用包含模式 + if include_patterns: + included = [] + for pattern in include_patterns: + regex_pattern = glob_to_regex(pattern) + included.extend(file for file in result if re.match(regex_pattern, file)) + result = included + + # 应用排除模式 + if exclude_patterns: + for pattern in exclude_patterns: + regex_pattern = glob_to_regex(pattern) + result = [file for file in result if not re.match(regex_pattern, file)] + + return result