mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⚡ 插件商店支持aliyun
This commit is contained in:
parent
c9456f292d
commit
6986637ec2
@ -84,7 +84,7 @@ async def _(session: EventSession):
|
||||
try:
|
||||
result = await StoreManager.get_plugins_info()
|
||||
logger.info("查看插件列表", "插件商店", session=session)
|
||||
await MessageUtils.build_message(result).send()
|
||||
await MessageUtils.build_message([*result]).send()
|
||||
except Exception as e:
|
||||
logger.error(f"查看插件列表失败 e: {e}", "插件商店", session=session, e=e)
|
||||
await MessageUtils.build_message("获取插件列表失败...").send()
|
||||
|
||||
@ -1,19 +1,19 @@
|
||||
from pathlib import Path
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from aiocache import cached
|
||||
import ujson as json
|
||||
|
||||
from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING
|
||||
from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo
|
||||
from zhenxun.configs.path_config import TEMP_PATH
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.plugin_init import PluginInitManager
|
||||
from zhenxun.utils.github_utils import GithubUtils
|
||||
from zhenxun.utils.github_utils.models import RepoAPI
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
|
||||
from zhenxun.utils.repo_utils import RepoFileManager
|
||||
from zhenxun.utils.repo_utils.models import RepoFileInfo, RepoType
|
||||
from zhenxun.utils.utils import is_number
|
||||
|
||||
from .config import (
|
||||
@ -22,6 +22,7 @@ from .config import (
|
||||
EXTRA_GITHUB_URL,
|
||||
LOG_COMMAND,
|
||||
)
|
||||
from .exceptions import PluginStoreException
|
||||
|
||||
|
||||
def row_style(column: str, text: str) -> RowStyle:
|
||||
@ -40,73 +41,25 @@ def row_style(column: str, text: str) -> RowStyle:
|
||||
return style
|
||||
|
||||
|
||||
async def install_requirement(plugin_path: Path):
|
||||
requirement_files = ["requirement.txt", "requirements.txt"]
|
||||
requirement_paths = [plugin_path / file for file in requirement_files]
|
||||
|
||||
if existing_requirements := next(
|
||||
(path for path in requirement_paths if path.exists()), None
|
||||
):
|
||||
await VirtualEnvPackageManager.install_requirement(existing_requirements)
|
||||
|
||||
|
||||
class StoreManager:
|
||||
@classmethod
|
||||
async def get_github_plugins(cls) -> list[StorePluginInfo]:
|
||||
"""获取github插件列表信息
|
||||
|
||||
返回:
|
||||
list[StorePluginInfo]: 插件列表数据
|
||||
"""
|
||||
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
default_github_url = await repo_info.get_raw_download_urls("plugins.json")
|
||||
response = await AsyncHttpx.get(default_github_url, check_status_code=200)
|
||||
if response.status_code == 200:
|
||||
logger.info("获取github插件列表成功", LOG_COMMAND)
|
||||
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
|
||||
else:
|
||||
logger.warning(
|
||||
f"获取github插件列表失败: {response.status_code}", LOG_COMMAND
|
||||
)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
async def get_extra_plugins(cls) -> list[StorePluginInfo]:
|
||||
"""获取额外插件列表信息
|
||||
|
||||
返回:
|
||||
list[StorePluginInfo]: 插件列表数据
|
||||
"""
|
||||
repo_info = GithubUtils.parse_github_url(EXTRA_GITHUB_URL)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
extra_github_url = await repo_info.get_raw_download_urls("plugins.json")
|
||||
response = await AsyncHttpx.get(extra_github_url, check_status_code=200)
|
||||
if response.status_code == 200:
|
||||
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
|
||||
else:
|
||||
logger.warning(
|
||||
f"获取github扩展插件列表失败: {response.status_code}", LOG_COMMAND
|
||||
)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
@cached(60)
|
||||
async def get_data(cls) -> list[StorePluginInfo]:
|
||||
async def get_data(cls) -> tuple[list[StorePluginInfo], list[StorePluginInfo]]:
|
||||
"""获取插件信息数据
|
||||
|
||||
返回:
|
||||
list[StorePluginInfo]: 插件信息数据
|
||||
tuple[list[StorePluginInfo], list[StorePluginInfo]]:
|
||||
原生插件信息数据,第三方插件信息数据
|
||||
"""
|
||||
plugins = await cls.get_github_plugins()
|
||||
extra_plugins = await cls.get_extra_plugins()
|
||||
return [*plugins, *extra_plugins]
|
||||
plugins = await RepoFileManager.get_file_content(
|
||||
DEFAULT_GITHUB_URL, "plugins.json"
|
||||
)
|
||||
extra_plugins = await RepoFileManager.get_file_content(
|
||||
EXTRA_GITHUB_URL, "plugins.json", "index"
|
||||
)
|
||||
return [StorePluginInfo(**plugin) for plugin in json.loads(plugins)], [
|
||||
StorePluginInfo(**plugin) for plugin in json.loads(extra_plugins)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]):
|
||||
@ -152,38 +105,94 @@ class StoreManager:
|
||||
return await PluginInfo.filter(load_status=True).values_list(*args)
|
||||
|
||||
@classmethod
|
||||
async def get_plugins_info(cls) -> BuildImage | str:
|
||||
async def get_plugins_info(cls) -> list[BuildImage] | str:
|
||||
"""插件列表
|
||||
|
||||
返回:
|
||||
BuildImage | str: 返回消息
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
plugin_list, extra_plugin_list = await cls.get_data()
|
||||
column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"]
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list}
|
||||
data_list = [
|
||||
[
|
||||
"已安装" if plugin_info.module in suc_plugin else "",
|
||||
id,
|
||||
plugin_info.name,
|
||||
plugin_info.description,
|
||||
plugin_info.author,
|
||||
cls.version_check(plugin_info, suc_plugin),
|
||||
plugin_info.plugin_type_name,
|
||||
]
|
||||
for id, plugin_info in enumerate(plugin_list)
|
||||
index = 0
|
||||
data_list = []
|
||||
extra_data_list = []
|
||||
for plugin_info in plugin_list:
|
||||
data_list.append(
|
||||
[
|
||||
"已安装" if plugin_info.module in suc_plugin else "",
|
||||
index,
|
||||
plugin_info.name,
|
||||
plugin_info.description,
|
||||
plugin_info.author,
|
||||
cls.version_check(plugin_info, suc_plugin),
|
||||
plugin_info.plugin_type_name,
|
||||
]
|
||||
)
|
||||
index += 1
|
||||
for plugin_info in extra_plugin_list:
|
||||
extra_data_list.append(
|
||||
[
|
||||
"已安装" if plugin_info.module in suc_plugin else "",
|
||||
index,
|
||||
plugin_info.name,
|
||||
plugin_info.description,
|
||||
plugin_info.author,
|
||||
cls.version_check(plugin_info, suc_plugin),
|
||||
plugin_info.plugin_type_name,
|
||||
]
|
||||
)
|
||||
index += 1
|
||||
return [
|
||||
await ImageTemplate.table_page(
|
||||
"原生插件列表",
|
||||
"通过添加/移除插件 ID 来管理插件",
|
||||
column_name,
|
||||
data_list,
|
||||
text_style=row_style,
|
||||
),
|
||||
await ImageTemplate.table_page(
|
||||
"第三方插件列表",
|
||||
"通过添加/移除插件 ID 来管理插件",
|
||||
column_name,
|
||||
extra_data_list,
|
||||
text_style=row_style,
|
||||
),
|
||||
]
|
||||
return await ImageTemplate.table_page(
|
||||
"插件列表",
|
||||
"通过添加/移除插件 ID 来管理插件",
|
||||
column_name,
|
||||
data_list,
|
||||
text_style=row_style,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add_plugin(cls, plugin_id: str) -> str:
|
||||
async def get_plugin_by_value(
|
||||
cls, index_or_module: str, is_update: bool = False
|
||||
) -> StorePluginInfo:
|
||||
"""获取插件信息
|
||||
|
||||
参数:
|
||||
index_or_module: 插件索引或模块名
|
||||
is_update: 是否是更新插件
|
||||
|
||||
异常:
|
||||
PluginStoreException: 插件不存在
|
||||
PluginStoreException: 插件已安装
|
||||
|
||||
返回:
|
||||
StorePluginInfo: 插件信息
|
||||
"""
|
||||
plugin_list, extra_plugin_list = await cls.get_data()
|
||||
all_plugin_list = plugin_list + extra_plugin_list
|
||||
db_plugin_list = await cls.get_loaded_plugins("module")
|
||||
plugin_key = await cls._resolve_plugin_key(index_or_module)
|
||||
plugin_info = next((p for p in all_plugin_list if p.module == plugin_key), None)
|
||||
if not plugin_info:
|
||||
raise PluginStoreException(f"插件不存在: {plugin_key}")
|
||||
if not is_update and plugin_info.module in [p[0] for p in db_plugin_list]:
|
||||
raise PluginStoreException(f"插件 {plugin_info.name} 已安装,无需重复安装")
|
||||
if plugin_info.module not in [p[0] for p in db_plugin_list] and is_update:
|
||||
raise PluginStoreException(f"插件 {plugin_info.name} 未安装,无法更新")
|
||||
return plugin_info
|
||||
|
||||
@classmethod
|
||||
async def add_plugin(cls, index_or_module: str) -> str:
|
||||
"""添加插件
|
||||
|
||||
参数:
|
||||
@ -192,17 +201,7 @@ class StoreManager:
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
try:
|
||||
plugin_key = await cls._resolve_plugin_key(plugin_id)
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
db_plugin_list = await cls.get_loaded_plugins("module")
|
||||
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
|
||||
if plugin_info is None:
|
||||
return f"未找到插件 {plugin_key}"
|
||||
if plugin_info.module in [p[0] for p in db_plugin_list]:
|
||||
return f"插件 {plugin_info.name} 已安装,无需重复安装"
|
||||
plugin_info = await cls.get_plugin_by_value(index_or_module)
|
||||
is_external = True
|
||||
if plugin_info.github_url is None:
|
||||
plugin_info.github_url = DEFAULT_GITHUB_URL
|
||||
@ -228,90 +227,81 @@ class StoreManager:
|
||||
is_dir: bool,
|
||||
is_external: bool = False,
|
||||
):
|
||||
repo_api: RepoAPI
|
||||
repo_info = GithubUtils.parse_github_url(github_url)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
logger.debug(f"成功获取仓库信息: {repo_info}", LOG_COMMAND)
|
||||
for repo_api in GithubUtils.iter_api_strategies():
|
||||
try:
|
||||
await repo_api.parse_repo_info(repo_info)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取插件文件失败 | API类型: {repo_api.strategy}",
|
||||
LOG_COMMAND,
|
||||
e=e,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise ValueError("所有API获取插件文件失败,请检查网络连接")
|
||||
if module_path == ".":
|
||||
module_path = ""
|
||||
"""安装插件
|
||||
|
||||
参数:
|
||||
github_url: 仓库地址
|
||||
module_path: 模块路径
|
||||
is_dir: 是否是文件夹
|
||||
is_external: 是否是外部仓库
|
||||
"""
|
||||
repo_type = RepoType.GITHUB if is_external else None
|
||||
replace_module_path = module_path.replace(".", "/")
|
||||
files = repo_api.get_files(
|
||||
module_path=replace_module_path + ("" if is_dir else ".py"),
|
||||
is_dir=is_dir,
|
||||
)
|
||||
download_urls = [await repo_info.get_raw_download_urls(file) for file in files]
|
||||
base_path = BASE_PATH / "plugins" if is_external else BASE_PATH
|
||||
base_path = base_path if module_path else base_path / repo_info.repo
|
||||
download_paths: list[Path | str] = [base_path / file for file in files]
|
||||
logger.debug(f"插件下载路径: {download_paths}", LOG_COMMAND)
|
||||
result = await AsyncHttpx.gather_download_file(download_urls, download_paths)
|
||||
for _id, success in enumerate(result):
|
||||
if not success:
|
||||
break
|
||||
if is_dir:
|
||||
files = await RepoFileManager.list_directory_files(
|
||||
github_url, replace_module_path, repo_type=repo_type
|
||||
)
|
||||
else:
|
||||
# 安装依赖
|
||||
plugin_path = base_path / "/".join(module_path.split("."))
|
||||
try:
|
||||
req_files = repo_api.get_files(
|
||||
f"{replace_module_path}/{REQ_TXT_FILE_STRING}", False
|
||||
files = [RepoFileInfo(path=f"{replace_module_path}.py", is_dir=False)]
|
||||
local_path = BASE_PATH / "plugins" if is_external else BASE_PATH
|
||||
files = [file for file in files if not file.is_dir]
|
||||
download_files = [(file.path, local_path / file.path) for file in files]
|
||||
await RepoFileManager.download_files(
|
||||
github_url, download_files, repo_type=repo_type
|
||||
)
|
||||
|
||||
requirement_paths = [
|
||||
file
|
||||
for file in files
|
||||
if file.path.endswith("requirement.txt")
|
||||
or file.path.endswith("requirements.txt")
|
||||
]
|
||||
|
||||
is_install_req = False
|
||||
for requirement_path in requirement_paths:
|
||||
requirement_file = local_path / requirement_path.path
|
||||
if requirement_file.exists():
|
||||
is_install_req = True
|
||||
await VirtualEnvPackageManager.install_requirement(requirement_file)
|
||||
|
||||
if not is_install_req:
|
||||
# 从仓库根目录查找文件
|
||||
rand = random.randint(1, 10000)
|
||||
requirement_path = TEMP_PATH / f"plugin_store_{rand}_req.txt"
|
||||
requirements_path = TEMP_PATH / f"plugin_store_{rand}_reqs.txt"
|
||||
await RepoFileManager.download_files(
|
||||
github_url,
|
||||
[
|
||||
("requirement.txt", requirement_path),
|
||||
("requirements.txt", requirements_path),
|
||||
],
|
||||
repo_type=repo_type,
|
||||
ignore_error=True,
|
||||
)
|
||||
if requirement_path.exists():
|
||||
logger.info(
|
||||
f"开始安装插件 {module_path} 依赖文件: {requirement_path}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
req_files.extend(
|
||||
repo_api.get_files(f"{replace_module_path}/requirement.txt", False)
|
||||
await VirtualEnvPackageManager.install_requirement(requirement_path)
|
||||
if requirements_path.exists():
|
||||
logger.info(
|
||||
f"开始安装插件 {module_path} 依赖文件: {requirements_path}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
logger.debug(f"获取插件依赖文件列表: {req_files}", LOG_COMMAND)
|
||||
req_download_urls = [
|
||||
await repo_info.get_raw_download_urls(file) for file in req_files
|
||||
]
|
||||
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
|
||||
logger.debug(f"插件依赖文件下载路径: {req_paths}", LOG_COMMAND)
|
||||
if req_files:
|
||||
result = await AsyncHttpx.gather_download_file(
|
||||
req_download_urls, req_paths
|
||||
)
|
||||
for success in result:
|
||||
if not success:
|
||||
raise Exception("插件依赖文件下载失败")
|
||||
logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND)
|
||||
await install_requirement(plugin_path)
|
||||
except ValueError as e:
|
||||
logger.warning("未获取到依赖文件路径...", e=e)
|
||||
return True
|
||||
raise Exception("插件下载失败...")
|
||||
await VirtualEnvPackageManager.install_requirement(requirements_path)
|
||||
|
||||
@classmethod
|
||||
async def remove_plugin(cls, plugin_id: str) -> str:
|
||||
async def remove_plugin(cls, index_or_module: str) -> str:
|
||||
"""移除插件
|
||||
|
||||
参数:
|
||||
plugin_id: 插件id或模块名
|
||||
index_or_module: 插件id或模块名
|
||||
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
try:
|
||||
plugin_key = await cls._resolve_plugin_key(plugin_id)
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
|
||||
if plugin_info is None:
|
||||
return f"未找到插件 {plugin_key}"
|
||||
plugin_info = await cls.get_plugin_by_value(index_or_module)
|
||||
path = BASE_PATH
|
||||
if plugin_info.github_url:
|
||||
path = BASE_PATH / "plugins"
|
||||
@ -339,12 +329,13 @@ class StoreManager:
|
||||
返回:
|
||||
BuildImage | str: 返回消息
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
plugin_list, extra_plugin_list = await cls.get_data()
|
||||
all_plugin_list = plugin_list + extra_plugin_list
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
|
||||
filtered_data = [
|
||||
(id, plugin_info)
|
||||
for id, plugin_info in enumerate(plugin_list)
|
||||
for id, plugin_info in enumerate(all_plugin_list)
|
||||
if plugin_name_or_author.lower() in plugin_info.name.lower()
|
||||
or plugin_name_or_author.lower() in plugin_info.author.lower()
|
||||
]
|
||||
@ -373,28 +364,19 @@ class StoreManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update_plugin(cls, plugin_id: str) -> str:
|
||||
async def update_plugin(cls, index_or_module: str) -> str:
|
||||
"""更新插件
|
||||
|
||||
参数:
|
||||
plugin_id: 插件id
|
||||
index_or_module: 插件id
|
||||
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
try:
|
||||
plugin_key = await cls._resolve_plugin_key(plugin_id)
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
|
||||
if plugin_info is None:
|
||||
return f"未找到插件 {plugin_key}"
|
||||
plugin_info = await cls.get_plugin_by_value(index_or_module, True)
|
||||
logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND)
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
|
||||
if plugin_info.module not in [p[0] for p in db_plugin_list]:
|
||||
return f"插件 {plugin_info.name} 未安装,无法更新"
|
||||
logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND)
|
||||
if cls.check_version_is_new(plugin_info, suc_plugin):
|
||||
return f"插件 {plugin_info.name} 已是最新版本"
|
||||
@ -492,22 +474,25 @@ class StoreManager:
|
||||
plugin_id: module,id或插件名称
|
||||
|
||||
异常:
|
||||
ValueError: 插件不存在
|
||||
ValueError: 插件不存在
|
||||
PluginStoreException: 插件不存在
|
||||
PluginStoreException: 插件不存在
|
||||
|
||||
返回:
|
||||
str: 插件模块名
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
plugin_list, extra_plugin_list = await cls.get_data()
|
||||
all_plugin_list = plugin_list + extra_plugin_list
|
||||
if is_number(plugin_id):
|
||||
idx = int(plugin_id)
|
||||
if idx < 0 or idx >= len(plugin_list):
|
||||
raise ValueError("插件ID不存在...")
|
||||
return plugin_list[idx].module
|
||||
if idx < 0 or idx >= len(all_plugin_list):
|
||||
raise PluginStoreException("插件ID不存在...")
|
||||
return all_plugin_list[idx].module
|
||||
elif isinstance(plugin_id, str):
|
||||
result = (
|
||||
None if plugin_id not in [v.module for v in plugin_list] else plugin_id
|
||||
) or next(v for v in plugin_list if v.name == plugin_id).module
|
||||
None
|
||||
if plugin_id not in [v.module for v in all_plugin_list]
|
||||
else plugin_id
|
||||
) or next(v for v in all_plugin_list if v.name == plugin_id).module
|
||||
if not result:
|
||||
raise ValueError("插件 Module / 名称 不存在...")
|
||||
raise PluginStoreException("插件 Module / 名称 不存在...")
|
||||
return result
|
||||
|
||||
6
zhenxun/builtin_plugins/plugin_store/exceptions.py
Normal file
6
zhenxun/builtin_plugins/plugin_store/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
||||
class PluginStoreException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
@ -16,6 +16,7 @@ from .exceptions import (
|
||||
RepoNotFoundError,
|
||||
RepoUpdateError,
|
||||
)
|
||||
from .file_manager import RepoFileManager as RepoFileManagerClass
|
||||
from .github_manager import GithubManager
|
||||
from .models import (
|
||||
FileDownloadResult,
|
||||
@ -28,6 +29,7 @@ from .utils import check_git, filter_files, glob_to_regex, run_git_command
|
||||
|
||||
GithubRepoManager = GithubManager()
|
||||
AliyunRepoManager = AliyunCodeupManager()
|
||||
RepoFileManager = RepoFileManagerClass()
|
||||
|
||||
__all__ = [
|
||||
"AliyunCodeupConfig",
|
||||
@ -45,6 +47,7 @@ __all__ = [
|
||||
"RepoConfig",
|
||||
"RepoDownloadError",
|
||||
"RepoFileInfo",
|
||||
"RepoFileManager",
|
||||
"RepoManagerError",
|
||||
"RepoNotFoundError",
|
||||
"RepoType",
|
||||
|
||||
542
zhenxun/utils/repo_utils/file_manager.py
Normal file
542
zhenxun/utils/repo_utils/file_manager.py
Normal file
@ -0,0 +1,542 @@
|
||||
"""
|
||||
仓库文件管理器,用于从GitHub和阿里云CodeUp获取指定文件内容
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import cast, overload
|
||||
|
||||
import aiofiles
|
||||
from httpx import Response
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.github_utils import GithubUtils
|
||||
from zhenxun.utils.github_utils.models import AliyunTreeType, GitHubStrategy, TreeType
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
|
||||
from .config import LOG_COMMAND, RepoConfig
|
||||
from .exceptions import FileNotFoundError, NetworkError, RepoManagerError
|
||||
from .models import FileDownloadResult, RepoFileInfo, RepoType
|
||||
|
||||
|
||||
class RepoFileManager:
|
||||
"""仓库文件管理器,用于获取GitHub和阿里云仓库中的文件内容"""
|
||||
|
||||
def __init__(self, config: RepoConfig | None = None):
|
||||
"""
|
||||
初始化仓库文件管理器
|
||||
|
||||
参数:
|
||||
config: 配置,如果为None则使用默认配置
|
||||
"""
|
||||
self.config = config or RepoConfig.get_instance()
|
||||
self.config.ensure_dirs()
|
||||
|
||||
@overload
|
||||
async def get_github_file_content(
|
||||
self, url: str, file_path: str, ignore_error: bool = False
|
||||
) -> str: ...
|
||||
|
||||
@overload
|
||||
async def get_github_file_content(
|
||||
self, url: str, file_path: list[str], ignore_error: bool = False
|
||||
) -> list[tuple[str, str]]: ...
|
||||
|
||||
async def get_github_file_content(
|
||||
self, url: str, file_path: str | list[str], ignore_error: bool = False
|
||||
) -> str | list[tuple[str, str]]:
|
||||
"""
|
||||
获取GitHub仓库文件内容
|
||||
|
||||
参数:
|
||||
url: 仓库URL
|
||||
file_path: 文件路径或文件路径列表
|
||||
ignore_error: 是否忽略错误
|
||||
|
||||
返回:
|
||||
list[tuple[str, str]]: 文件路径,文件内容
|
||||
"""
|
||||
results = []
|
||||
is_str_input = isinstance(file_path, str)
|
||||
try:
|
||||
if is_str_input:
|
||||
file_path = [file_path]
|
||||
repo_info = GithubUtils.parse_github_url(url)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
for f in file_path:
|
||||
try:
|
||||
file_url = await repo_info.get_raw_download_urls(f)
|
||||
for fu in file_url:
|
||||
response: Response = await AsyncHttpx.get(
|
||||
fu, check_status_code=200
|
||||
)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"获取github文件内容成功: {f}", LOG_COMMAND)
|
||||
# 确保使用UTF-8编码解析响应内容
|
||||
try:
|
||||
text_content = response.content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# 如果UTF-8解码失败,尝试其他编码
|
||||
text_content = response.content.decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
logger.warning(
|
||||
f"解码文件内容时出现错误,使用忽略错误模式: {f}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
results.append((f, text_content))
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"获取github文件内容失败: {response.status_code}",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取github文件内容失败: {f}", LOG_COMMAND, e=e)
|
||||
if not ignore_error:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取GitHub文件内容失败: {file_path}", LOG_COMMAND, e=e)
|
||||
raise
|
||||
logger.debug(f"获取GitHub文件内容: {[r[0] for r in results]}", LOG_COMMAND)
|
||||
|
||||
return results[0][1] if is_str_input and results else results
|
||||
|
||||
@overload
|
||||
async def get_aliyun_file_content(
|
||||
self,
|
||||
repo_name: str,
|
||||
file_path: str,
|
||||
branch: str = "main",
|
||||
ignore_error: bool = False,
|
||||
) -> str: ...
|
||||
|
||||
@overload
|
||||
async def get_aliyun_file_content(
|
||||
self,
|
||||
repo_name: str,
|
||||
file_path: list[str],
|
||||
branch: str = "main",
|
||||
ignore_error: bool = False,
|
||||
) -> list[tuple[str, str]]: ...
|
||||
|
||||
async def get_aliyun_file_content(
|
||||
self,
|
||||
repo_name: str,
|
||||
file_path: str | list[str],
|
||||
branch: str = "main",
|
||||
ignore_error: bool = False,
|
||||
) -> str | list[tuple[str, str]]:
|
||||
"""
|
||||
获取阿里云CodeUp仓库文件内容
|
||||
|
||||
参数:
|
||||
repo: 仓库名称
|
||||
file_path: 文件路径
|
||||
branch: 分支名称
|
||||
ignore_error: 是否忽略错误
|
||||
返回:
|
||||
list[tuple[str, str]]: 文件路径,文件内容
|
||||
"""
|
||||
results = []
|
||||
is_str_input = isinstance(file_path, str)
|
||||
# 导入阿里云相关模块
|
||||
from zhenxun.utils.github_utils.models import AliyunFileInfo
|
||||
|
||||
if is_str_input:
|
||||
file_path = [file_path]
|
||||
for f in file_path:
|
||||
try:
|
||||
content = await AliyunFileInfo.get_file_content(
|
||||
file_path=f, repo=repo_name, ref=branch
|
||||
)
|
||||
results.append((f, content))
|
||||
except Exception as e:
|
||||
logger.error(f"获取阿里云文件内容失败: {file_path}", LOG_COMMAND, e=e)
|
||||
if not ignore_error:
|
||||
raise
|
||||
logger.debug(f"获取阿里云文件内容: {[r[0] for r in results]}", LOG_COMMAND)
|
||||
return results[0][1] if is_str_input and results else results
|
||||
|
||||
@overload
|
||||
async def get_file_content(
|
||||
self,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
branch: str = "main",
|
||||
repo_type: RepoType | None = None,
|
||||
ignore_error: bool = False,
|
||||
) -> str: ...
|
||||
|
||||
@overload
|
||||
async def get_file_content(
|
||||
self,
|
||||
repo_url: str,
|
||||
file_path: list[str],
|
||||
branch: str = "main",
|
||||
repo_type: RepoType | None = None,
|
||||
ignore_error: bool = False,
|
||||
) -> list[tuple[str, str]]: ...
|
||||
|
||||
async def get_file_content(
|
||||
self,
|
||||
repo_url: str,
|
||||
file_path: str | list[str],
|
||||
branch: str = "main",
|
||||
repo_type: RepoType | None = None,
|
||||
ignore_error: bool = False,
|
||||
) -> str | list[tuple[str, str]]:
|
||||
"""
|
||||
获取仓库文件内容
|
||||
|
||||
参数:
|
||||
repo_url: 仓库URL
|
||||
file_path: 文件路径
|
||||
branch: 分支名称
|
||||
repo_type: 仓库类型,如果为None则自动判断
|
||||
ignore_error: 是否忽略错误
|
||||
|
||||
返回:
|
||||
str: 文件内容
|
||||
"""
|
||||
# 确定仓库类型
|
||||
repo_name = (
|
||||
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
|
||||
)
|
||||
if repo_type is None:
|
||||
try:
|
||||
return await self.get_aliyun_file_content(
|
||||
repo_name, file_path, branch, ignore_error
|
||||
)
|
||||
except Exception:
|
||||
return await self.get_github_file_content(
|
||||
repo_url, file_path, ignore_error
|
||||
)
|
||||
|
||||
try:
|
||||
if repo_type == RepoType.GITHUB:
|
||||
return await self.get_github_file_content(
|
||||
repo_url, file_path, ignore_error
|
||||
)
|
||||
|
||||
elif repo_type == RepoType.ALIYUN:
|
||||
return await self.get_aliyun_file_content(
|
||||
repo_name, file_path, branch, ignore_error
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError):
|
||||
raise
|
||||
raise RepoManagerError(f"获取文件内容失败: {e}")
|
||||
|
||||
async def list_directory_files(
|
||||
self,
|
||||
repo_url: str,
|
||||
directory_path: str = "",
|
||||
branch: str = "main",
|
||||
repo_type: RepoType | None = None,
|
||||
recursive: bool = True,
|
||||
) -> list[RepoFileInfo]:
|
||||
"""
|
||||
获取仓库目录下的所有文件路径
|
||||
|
||||
参数:
|
||||
repo_url: 仓库URL
|
||||
directory_path: 目录路径,默认为仓库根目录
|
||||
branch: 分支名称
|
||||
repo_type: 仓库类型,如果为None则自动判断
|
||||
recursive: 是否递归获取子目录文件
|
||||
|
||||
返回:
|
||||
list[RepoFileInfo]: 文件信息列表
|
||||
"""
|
||||
repo_name = (
|
||||
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
|
||||
)
|
||||
try:
|
||||
if repo_type is None:
|
||||
# 尝试阿里云,失败则尝试GitHub
|
||||
try:
|
||||
return await self._list_aliyun_directory_files(
|
||||
repo_name, directory_path, branch, recursive
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"获取阿里云目录文件失败,尝试GitHub", LOG_COMMAND, e=e
|
||||
)
|
||||
return await self._list_github_directory_files(
|
||||
repo_url, directory_path, branch, recursive
|
||||
)
|
||||
if repo_type == RepoType.GITHUB:
|
||||
return await self._list_github_directory_files(
|
||||
repo_url, directory_path, branch, recursive
|
||||
)
|
||||
elif repo_type == RepoType.ALIYUN:
|
||||
return await self._list_aliyun_directory_files(
|
||||
repo_name, directory_path, branch, recursive
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取目录文件列表失败: {directory_path}", LOG_COMMAND, e=e)
|
||||
if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError):
|
||||
raise
|
||||
raise RepoManagerError(f"获取目录文件列表失败: {e}")
|
||||
|
||||
async def _list_github_directory_files(
|
||||
self,
|
||||
repo_url: str,
|
||||
directory_path: str = "",
|
||||
branch: str = "main",
|
||||
recursive: bool = True,
|
||||
build_tree: bool = False,
|
||||
) -> list[RepoFileInfo]:
|
||||
"""
|
||||
获取GitHub仓库目录下的所有文件路径
|
||||
|
||||
参数:
|
||||
repo_url: 仓库URL
|
||||
directory_path: 目录路径,默认为仓库根目录
|
||||
branch: 分支名称
|
||||
recursive: 是否递归获取子目录文件
|
||||
build_tree: 是否构建目录树
|
||||
|
||||
返回:
|
||||
list[RepoFileInfo]: 文件信息列表
|
||||
"""
|
||||
try:
|
||||
repo_info = GithubUtils.parse_github_url(repo_url)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
|
||||
# 获取仓库树信息
|
||||
strategy = GitHubStrategy()
|
||||
strategy.body = await GitHubStrategy.parse_repo_info(repo_info)
|
||||
|
||||
# 处理目录路径,确保格式正确
|
||||
if directory_path and not directory_path.endswith("/") and recursive:
|
||||
directory_path = f"{directory_path}/"
|
||||
|
||||
# 获取文件列表
|
||||
file_list = []
|
||||
for tree_item in strategy.body.tree:
|
||||
# 如果不是递归模式,只获取当前目录下的文件
|
||||
if not recursive and "/" in tree_item.path.replace(
|
||||
directory_path, "", 1
|
||||
):
|
||||
continue
|
||||
|
||||
# 检查是否在指定目录下
|
||||
if directory_path and not tree_item.path.startswith(directory_path):
|
||||
continue
|
||||
|
||||
# 创建文件信息对象
|
||||
file_info = RepoFileInfo(
|
||||
path=tree_item.path,
|
||||
is_dir=tree_item.type == TreeType.DIR,
|
||||
size=tree_item.size,
|
||||
last_modified=None, # GitHub API不直接提供最后修改时间
|
||||
)
|
||||
file_list.append(file_info)
|
||||
|
||||
# 构建目录树结构
|
||||
if recursive and build_tree:
|
||||
file_list = self._build_directory_tree(file_list)
|
||||
|
||||
return file_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"获取GitHub目录文件列表失败: {directory_path}", LOG_COMMAND, e=e
|
||||
)
|
||||
raise
|
||||
|
||||
async def _list_aliyun_directory_files(
|
||||
self,
|
||||
repo_name: str,
|
||||
directory_path: str = "",
|
||||
branch: str = "main",
|
||||
recursive: bool = True,
|
||||
build_tree: bool = False,
|
||||
) -> list[RepoFileInfo]:
|
||||
"""
|
||||
获取阿里云CodeUp仓库目录下的所有文件路径
|
||||
|
||||
参数:
|
||||
repo_name: 仓库名称
|
||||
directory_path: 目录路径,默认为仓库根目录
|
||||
branch: 分支名称
|
||||
recursive: 是否递归获取子目录文件
|
||||
build_tree: 是否构建目录树
|
||||
|
||||
返回:
|
||||
list[RepoFileInfo]: 文件信息列表
|
||||
"""
|
||||
try:
|
||||
from zhenxun.utils.github_utils.models import AliyunFileInfo
|
||||
|
||||
# 获取仓库树信息
|
||||
search_type = "RECURSIVE" if recursive else "DIRECT"
|
||||
tree_list = await AliyunFileInfo.get_repository_tree(
|
||||
repo=repo_name,
|
||||
path=directory_path,
|
||||
ref=branch,
|
||||
search_type=search_type,
|
||||
)
|
||||
|
||||
# 创建文件信息对象列表
|
||||
file_list = []
|
||||
for tree_item in tree_list:
|
||||
file_info = RepoFileInfo(
|
||||
path=tree_item.path,
|
||||
is_dir=tree_item.type == AliyunTreeType.DIR,
|
||||
size=None, # 阿里云API不直接提供文件大小
|
||||
last_modified=None, # 阿里云API不直接提供最后修改时间
|
||||
)
|
||||
file_list.append(file_info)
|
||||
|
||||
# 构建目录树结构
|
||||
if recursive and build_tree:
|
||||
file_list = self._build_directory_tree(file_list)
|
||||
|
||||
return file_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"获取阿里云目录文件列表失败: {directory_path}", LOG_COMMAND, e=e
|
||||
)
|
||||
raise
|
||||
|
||||
def _build_directory_tree(
|
||||
self, file_list: list[RepoFileInfo]
|
||||
) -> list[RepoFileInfo]:
|
||||
"""
|
||||
构建目录树结构
|
||||
|
||||
参数:
|
||||
file_list: 文件信息列表
|
||||
|
||||
返回:
|
||||
list[RepoFileInfo]: 根目录下的文件信息列表
|
||||
"""
|
||||
# 按路径排序,确保父目录在子目录之前
|
||||
file_list.sort(key=lambda x: x.path)
|
||||
# 创建路径到文件信息的映射
|
||||
path_map = {file_info.path: file_info for file_info in file_list}
|
||||
# 根目录文件列表
|
||||
root_files = []
|
||||
|
||||
for file_info in file_list:
|
||||
if parent_path := "/".join(file_info.path.split("/")[:-1]):
|
||||
# 如果有父目录,将当前文件添加到父目录的子文件列表中
|
||||
if parent_path in path_map:
|
||||
path_map[parent_path].children.append(file_info)
|
||||
else:
|
||||
# 如果父目录不在列表中,创建一个虚拟的父目录
|
||||
parent_info = RepoFileInfo(
|
||||
path=parent_path, is_dir=True, children=[file_info]
|
||||
)
|
||||
path_map[parent_path] = parent_info
|
||||
# 检查父目录的父目录
|
||||
grand_parent_path = "/".join(parent_path.split("/")[:-1])
|
||||
if grand_parent_path and grand_parent_path in path_map:
|
||||
path_map[grand_parent_path].children.append(parent_info)
|
||||
else:
|
||||
root_files.append(parent_info)
|
||||
else:
|
||||
# 如果没有父目录,则是根目录下的文件
|
||||
root_files.append(file_info)
|
||||
|
||||
# 返回根目录下的文件列表
|
||||
return [
|
||||
file
|
||||
for file in root_files
|
||||
if all(f.path != file.path for f in file_list if f != file)
|
||||
]
|
||||
|
||||
async def download_files(
|
||||
self,
|
||||
repo_url: str,
|
||||
file_path: tuple[str, Path] | list[tuple[str, Path]],
|
||||
branch: str = "main",
|
||||
repo_type: RepoType | None = None,
|
||||
ignore_error: bool = False,
|
||||
) -> FileDownloadResult:
|
||||
"""
|
||||
下载单个文件
|
||||
|
||||
参数:
|
||||
repo_url: 仓库URL
|
||||
file_path: 文件在仓库中的路径,本地存储路径
|
||||
branch: 分支名称
|
||||
repo_type: 仓库类型,如果为None则自动判断
|
||||
ignore_error: 是否忽略错误
|
||||
|
||||
返回:
|
||||
FileDownloadResult: 下载结果
|
||||
"""
|
||||
# 确定仓库类型和所有者
|
||||
repo_name = (
|
||||
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
|
||||
)
|
||||
|
||||
if isinstance(file_path, tuple):
|
||||
file_path = [file_path]
|
||||
|
||||
file_path_mapping = {f[0]: f[1] for f in file_path}
|
||||
|
||||
# 创建结果对象
|
||||
result = FileDownloadResult(
|
||||
repo_type=repo_type,
|
||||
repo_name=repo_name,
|
||||
file_path=file_path,
|
||||
version=branch,
|
||||
)
|
||||
|
||||
try:
|
||||
# 由于我们传入的是列表,所以这里一定返回列表
|
||||
file_paths = [f[0] for f in file_path]
|
||||
if len(file_paths) == 1:
|
||||
# 如果只有一个文件,可能返回单个元组
|
||||
file_contents_result = await self.get_file_content(
|
||||
repo_url, file_paths[0], branch, repo_type, ignore_error
|
||||
)
|
||||
if isinstance(file_contents_result, tuple):
|
||||
file_contents = [file_contents_result]
|
||||
elif isinstance(file_contents_result, str):
|
||||
file_contents = [(file_paths[0], file_contents_result)]
|
||||
else:
|
||||
file_contents = cast(list[tuple[str, str]], file_contents_result)
|
||||
else:
|
||||
# 多个文件一定返回列表
|
||||
file_contents = cast(
|
||||
list[tuple[str, str]],
|
||||
await self.get_file_content(
|
||||
repo_url, file_paths, branch, repo_type, ignore_error
|
||||
),
|
||||
)
|
||||
|
||||
for repo_file_path, content in file_contents:
|
||||
local_path = file_path_mapping[repo_file_path]
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# 使用二进制模式写入文件,避免编码问题
|
||||
if isinstance(content, str):
|
||||
content_bytes = content.encode("utf-8")
|
||||
else:
|
||||
content_bytes = content
|
||||
async with aiofiles.open(local_path, "wb") as f:
|
||||
await f.write(content_bytes)
|
||||
result.success = True
|
||||
# 计算文件大小
|
||||
result.file_size = sum(
|
||||
len(content.encode("utf-8") if isinstance(content, str) else content)
|
||||
for _, content in file_contents
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"下载文件失败: {e}")
|
||||
result.success = False
|
||||
result.error_message = str(e)
|
||||
return result
|
||||
@ -5,6 +5,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class RepoType(str, Enum):
|
||||
@ -26,6 +27,8 @@ class RepoFileInfo:
|
||||
size: int | None = None
|
||||
# 最后修改时间
|
||||
last_modified: datetime | None = None
|
||||
# 子文件列表
|
||||
children: list["RepoFileInfo"] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -71,15 +74,11 @@ class FileDownloadResult:
|
||||
"""文件下载结果"""
|
||||
|
||||
# 仓库类型
|
||||
repo_type: RepoType
|
||||
repo_type: RepoType | None
|
||||
# 仓库名称
|
||||
repo_name: str
|
||||
# 仓库拥有者
|
||||
owner: str
|
||||
# 文件路径
|
||||
file_path: str
|
||||
# 本地路径
|
||||
local_path: str
|
||||
file_path: list[tuple[str, Path]]
|
||||
# 版本
|
||||
version: str
|
||||
# 是否成功
|
||||
|
||||
Loading…
Reference in New Issue
Block a user