优化插件更新和加载机制,提供测试函数

- 修复了插件更新函数中的条件判断逻辑
This commit is contained in:
molanp 2024-10-01 18:36:02 +08:00
parent f345f412c9
commit 7e0c0125ae
2 changed files with 152 additions and 16 deletions

View File

@ -0,0 +1,120 @@
from typing import cast
from pathlib import Path
from collections.abc import Callable
from nonebug import App
from respx import MockRouter
from pytest_mock import MockerFixture
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.message import Message
from nonebot.adapters.onebot.v11.event import GroupMessageEvent
from tests.utils import _v11_group_message_event
from tests.config import BotId, UserId, GroupId, MessageId
from tests.builtin_plugins.plugin_store.utils import init_mocked_api
async def test_update_all_plugin_basic_need_update(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试更新基础插件插件需要更新
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mock_base_path = mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins",
return_value=[("search_image", "0.0")],
)
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
bot: Bot = cast(Bot, bot)
raw_message = "更新全部插件"
event: GroupMessageEvent = _v11_group_message_event(
message=raw_message,
self_id=BotId.QQ_BOT,
user_id=UserId.SUPERUSER,
group_id=GroupId.GROUP_ID_LEVEL_5,
message_id=MessageId.MESSAGE_ID,
to_me=True,
)
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message="正在更新全部插件"),
result=None,
bot=bot,
)
ctx.should_call_send(
event=event,
message=Message(message="已更新插件 \n- 识图\n共计1个插件! 重启后生效"),
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called
assert mocked_api["search_image_plugin_file_init"].called
assert (mock_base_path / "plugins" / "search_image" / "__init__.py").is_file()
async def test_update_all_plugin_basic_is_new(
app: App,
mocker: MockerFixture,
mocked_api: MockRouter,
create_bot: Callable,
tmp_path: Path,
) -> None:
"""
测试更新基础插件插件是最新版
"""
from zhenxun.builtin_plugins.plugin_store import _matcher
init_mocked_api(mocked_api=mocked_api)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.BASE_PATH",
new=tmp_path / "zhenxun",
)
mocker.patch(
"zhenxun.builtin_plugins.plugin_store.data_source.ShopManage.get_loaded_plugins",
return_value=[("search_image", "0.1")],
)
async with app.test_matcher(_matcher) as ctx:
bot = create_bot(ctx)
bot: Bot = cast(Bot, bot)
raw_message = "更新全部插件"
event: GroupMessageEvent = _v11_group_message_event(
message=raw_message,
self_id=BotId.QQ_BOT,
user_id=UserId.SUPERUSER,
group_id=GroupId.GROUP_ID_LEVEL_5,
message_id=MessageId.MESSAGE_ID,
to_me=True,
)
ctx.receive_event(bot=bot, event=event)
ctx.should_call_send(
event=event,
message=Message(message="正在更新全部插件"),
result=None,
bot=bot,
)
ctx.should_call_send(
event=event,
message=Message(message="全部插件已是最新版本"),
result=None,
bot=bot,
)
assert mocked_api["basic_plugins"].called
assert mocked_api["extra_plugins"].called

View File

@ -90,8 +90,7 @@ class ShopManage:
# 检查请求结果
if res.status_code != 200 or res2.status_code != 200:
raise ValueError(
f"下载错误, code: {res.status_code}, {res2.status_code}")
raise ValueError(f"下载错误, code: {res.status_code}, {res2.status_code}")
# 解析并合并返回的 JSON 数据
data1 = json.loads(res.text)
@ -188,10 +187,15 @@ class ShopManage:
data: dict[str, StorePluginInfo] = await cls.get_data()
if isinstance(plugin_id, int) and (plugin_id < 0 or plugin_id >= len(data)):
return "插件ID不存在..."
elif isinstance(plugin_id, str) and plugin_id not in [v.module for k, v in data.items()]:
elif isinstance(plugin_id, str) and plugin_id not in [
v.module for k, v in data.items()
]:
return "插件Module不存在..."
plugin_key = list(data.keys())[plugin_id] if isinstance(plugin_id, int) else {
v.module: k for k, v in data.items()}[plugin_id]
plugin_key = (
list(data.keys())[plugin_id]
if isinstance(plugin_id, int)
else {v.module: k for k, v in data.items()}[plugin_id]
)
plugin_list = await cls.get_loaded_plugins("module")
plugin_info = data[plugin_key]
if plugin_info.module in [p[0] for p in plugin_list]:
@ -233,8 +237,7 @@ class ShopManage:
else:
raise ValueError("所有API获取插件文件失败请检查网络连接")
files = repo_api.get_files(
module_path=module_path.replace(
".", "/") + ("" if is_dir else ".py"),
module_path=module_path.replace(".", "/") + ("" if is_dir else ".py"),
is_dir=is_dir,
)
download_urls = [await repo_info.get_raw_download_urls(file) for file in files]
@ -254,8 +257,7 @@ class ShopManage:
req_download_urls = [
await repo_info.get_raw_download_urls(file) for file in req_files
]
req_paths: list[Path | str] = [
plugin_path / file for file in req_files]
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理")
if req_files:
result = await AsyncHttpx.gather_download_file(
@ -283,10 +285,15 @@ class ShopManage:
data: dict[str, StorePluginInfo] = await cls.get_data()
if isinstance(plugin_id, int) and (plugin_id < 0 or plugin_id >= len(data)):
return "插件ID不存在..."
elif isinstance(plugin_id, str) and plugin_id not in [v.module for k, v in data.items()]:
elif isinstance(plugin_id, str) and plugin_id not in [
v.module for k, v in data.items()
]:
return "插件Module不存在..."
plugin_key = list(data.keys())[plugin_id] if isinstance(plugin_id, int) else {
v.module: k for k, v in data.items()}[plugin_id]
plugin_key = (
list(data.keys())[plugin_id]
if isinstance(plugin_id, int)
else {v.module: k for k, v in data.items()}[plugin_id]
)
plugin_info = data[plugin_key]
path = BASE_PATH
if plugin_info.github_url:
@ -361,10 +368,15 @@ class ShopManage:
data: dict[str, StorePluginInfo] = await cls.get_data()
if isinstance(plugin_id, int) and (plugin_id < 0 or plugin_id >= len(data)):
return "插件ID不存在..."
elif isinstance(plugin_id, str) and plugin_id not in [v.module for k, v in data.items()]:
elif isinstance(plugin_id, str) and plugin_id not in [
v.module for k, v in data.items()
]:
return "插件Module不存在..."
plugin_key = list(data.keys())[plugin_id] if isinstance(plugin_id, int) else {
v.module: k for k, v in data.items()}[plugin_id]
plugin_key = (
list(data.keys())[plugin_id]
if isinstance(plugin_id, int)
else {v.module: k for k, v in data.items()}[plugin_id]
)
logger.info(f"尝试更新插件 {plugin_key}", "插件管理")
plugin_info = data[plugin_key]
plugin_list = await cls.get_loaded_plugins("module", "version")
@ -422,4 +434,8 @@ class ShopManage:
is_external,
)
update_list.append(plugin_key)
return "已更新插件 {}\n共计{}个插件! 重启后生效".format('\n- '.join(update_list), len(update_list))
if len(update_list) == 0:
return "全部插件已是最新版本"
return "已更新插件 {}\n共计{}个插件! 重启后生效".format(
"\n- ".join(update_list), len(update_list)
)