插件商店支持aliyun

This commit is contained in:
HibiKier 2025-08-01 14:15:28 +08:00
parent c9456f292d
commit 6986637ec2
6 changed files with 732 additions and 197 deletions

View File

@ -84,7 +84,7 @@ async def _(session: EventSession):
try: try:
result = await StoreManager.get_plugins_info() result = await StoreManager.get_plugins_info()
logger.info("查看插件列表", "插件商店", session=session) logger.info("查看插件列表", "插件商店", session=session)
await MessageUtils.build_message(result).send() await MessageUtils.build_message([*result]).send()
except Exception as e: except Exception as e:
logger.error(f"查看插件列表失败 e: {e}", "插件商店", session=session, e=e) logger.error(f"查看插件列表失败 e: {e}", "插件商店", session=session, e=e)
await MessageUtils.build_message("获取插件列表失败...").send() await MessageUtils.build_message("获取插件列表失败...").send()

View File

@ -1,19 +1,19 @@
from pathlib import Path from pathlib import Path
import random
import shutil import shutil
from aiocache import cached from aiocache import cached
import ujson as json import ujson as json
from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING
from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo
from zhenxun.configs.path_config import TEMP_PATH
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.services.plugin_init import PluginInitManager from zhenxun.services.plugin_init import PluginInitManager
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import RepoAPI
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from zhenxun.utils.repo_utils import RepoFileManager
from zhenxun.utils.repo_utils.models import RepoFileInfo, RepoType
from zhenxun.utils.utils import is_number from zhenxun.utils.utils import is_number
from .config import ( from .config import (
@ -22,6 +22,7 @@ from .config import (
EXTRA_GITHUB_URL, EXTRA_GITHUB_URL,
LOG_COMMAND, LOG_COMMAND,
) )
from .exceptions import PluginStoreException
def row_style(column: str, text: str) -> RowStyle: def row_style(column: str, text: str) -> RowStyle:
@ -40,73 +41,25 @@ def row_style(column: str, text: str) -> RowStyle:
return style return style
async def install_requirement(plugin_path: Path):
requirement_files = ["requirement.txt", "requirements.txt"]
requirement_paths = [plugin_path / file for file in requirement_files]
if existing_requirements := next(
(path for path in requirement_paths if path.exists()), None
):
await VirtualEnvPackageManager.install_requirement(existing_requirements)
class StoreManager: class StoreManager:
@classmethod
async def get_github_plugins(cls) -> list[StorePluginInfo]:
"""获取github插件列表信息
返回:
list[StorePluginInfo]: 插件列表数据
"""
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
default_github_url = await repo_info.get_raw_download_urls("plugins.json")
response = await AsyncHttpx.get(default_github_url, check_status_code=200)
if response.status_code == 200:
logger.info("获取github插件列表成功", LOG_COMMAND)
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
else:
logger.warning(
f"获取github插件列表失败: {response.status_code}", LOG_COMMAND
)
return []
@classmethod
async def get_extra_plugins(cls) -> list[StorePluginInfo]:
"""获取额外插件列表信息
返回:
list[StorePluginInfo]: 插件列表数据
"""
repo_info = GithubUtils.parse_github_url(EXTRA_GITHUB_URL)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
extra_github_url = await repo_info.get_raw_download_urls("plugins.json")
response = await AsyncHttpx.get(extra_github_url, check_status_code=200)
if response.status_code == 200:
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
else:
logger.warning(
f"获取github扩展插件列表失败: {response.status_code}", LOG_COMMAND
)
return []
@classmethod @classmethod
@cached(60) @cached(60)
async def get_data(cls) -> list[StorePluginInfo]: async def get_data(cls) -> tuple[list[StorePluginInfo], list[StorePluginInfo]]:
"""获取插件信息数据 """获取插件信息数据
返回: 返回:
list[StorePluginInfo]: 插件信息数据 tuple[list[StorePluginInfo], list[StorePluginInfo]]:
原生插件信息数据第三方插件信息数据
""" """
plugins = await cls.get_github_plugins() plugins = await RepoFileManager.get_file_content(
extra_plugins = await cls.get_extra_plugins() DEFAULT_GITHUB_URL, "plugins.json"
return [*plugins, *extra_plugins] )
extra_plugins = await RepoFileManager.get_file_content(
EXTRA_GITHUB_URL, "plugins.json", "index"
)
return [StorePluginInfo(**plugin) for plugin in json.loads(plugins)], [
StorePluginInfo(**plugin) for plugin in json.loads(extra_plugins)
]
@classmethod @classmethod
def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]): def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]):
@ -152,38 +105,94 @@ class StoreManager:
return await PluginInfo.filter(load_status=True).values_list(*args) return await PluginInfo.filter(load_status=True).values_list(*args)
@classmethod @classmethod
async def get_plugins_info(cls) -> BuildImage | str: async def get_plugins_info(cls) -> list[BuildImage] | str:
"""插件列表 """插件列表
返回: 返回:
BuildImage | str: 返回消息 BuildImage | str: 返回消息
""" """
plugin_list: list[StorePluginInfo] = await cls.get_data() plugin_list, extra_plugin_list = await cls.get_data()
column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"] column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"]
db_plugin_list = await cls.get_loaded_plugins("module", "version") db_plugin_list = await cls.get_loaded_plugins("module", "version")
suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list} suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list}
data_list = [ index = 0
data_list = []
extra_data_list = []
for plugin_info in plugin_list:
data_list.append(
[ [
"已安装" if plugin_info.module in suc_plugin else "", "已安装" if plugin_info.module in suc_plugin else "",
id, index,
plugin_info.name, plugin_info.name,
plugin_info.description, plugin_info.description,
plugin_info.author, plugin_info.author,
cls.version_check(plugin_info, suc_plugin), cls.version_check(plugin_info, suc_plugin),
plugin_info.plugin_type_name, plugin_info.plugin_type_name,
] ]
for id, plugin_info in enumerate(plugin_list) )
index += 1
for plugin_info in extra_plugin_list:
extra_data_list.append(
[
"已安装" if plugin_info.module in suc_plugin else "",
index,
plugin_info.name,
plugin_info.description,
plugin_info.author,
cls.version_check(plugin_info, suc_plugin),
plugin_info.plugin_type_name,
] ]
return await ImageTemplate.table_page( )
"插件列表", index += 1
return [
await ImageTemplate.table_page(
"原生插件列表",
"通过添加/移除插件 ID 来管理插件", "通过添加/移除插件 ID 来管理插件",
column_name, column_name,
data_list, data_list,
text_style=row_style, text_style=row_style,
) ),
await ImageTemplate.table_page(
"第三方插件列表",
"通过添加/移除插件 ID 来管理插件",
column_name,
extra_data_list,
text_style=row_style,
),
]
@classmethod @classmethod
async def add_plugin(cls, plugin_id: str) -> str: async def get_plugin_by_value(
cls, index_or_module: str, is_update: bool = False
) -> StorePluginInfo:
"""获取插件信息
参数:
index_or_module: 插件索引或模块名
is_update: 是否是更新插件
异常:
PluginStoreException: 插件不存在
PluginStoreException: 插件已安装
返回:
StorePluginInfo: 插件信息
"""
plugin_list, extra_plugin_list = await cls.get_data()
all_plugin_list = plugin_list + extra_plugin_list
db_plugin_list = await cls.get_loaded_plugins("module")
plugin_key = await cls._resolve_plugin_key(index_or_module)
plugin_info = next((p for p in all_plugin_list if p.module == plugin_key), None)
if not plugin_info:
raise PluginStoreException(f"插件不存在: {plugin_key}")
if not is_update and plugin_info.module in [p[0] for p in db_plugin_list]:
raise PluginStoreException(f"插件 {plugin_info.name} 已安装,无需重复安装")
if plugin_info.module not in [p[0] for p in db_plugin_list] and is_update:
raise PluginStoreException(f"插件 {plugin_info.name} 未安装,无法更新")
return plugin_info
@classmethod
async def add_plugin(cls, index_or_module: str) -> str:
"""添加插件 """添加插件
参数: 参数:
@ -192,17 +201,7 @@ class StoreManager:
返回: 返回:
str: 返回消息 str: 返回消息
""" """
plugin_list: list[StorePluginInfo] = await cls.get_data() plugin_info = await cls.get_plugin_by_value(index_or_module)
try:
plugin_key = await cls._resolve_plugin_key(plugin_id)
except ValueError as e:
return str(e)
db_plugin_list = await cls.get_loaded_plugins("module")
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
if plugin_info is None:
return f"未找到插件 {plugin_key}"
if plugin_info.module in [p[0] for p in db_plugin_list]:
return f"插件 {plugin_info.name} 已安装,无需重复安装"
is_external = True is_external = True
if plugin_info.github_url is None: if plugin_info.github_url is None:
plugin_info.github_url = DEFAULT_GITHUB_URL plugin_info.github_url = DEFAULT_GITHUB_URL
@ -228,90 +227,81 @@ class StoreManager:
is_dir: bool, is_dir: bool,
is_external: bool = False, is_external: bool = False,
): ):
repo_api: RepoAPI """安装插件
repo_info = GithubUtils.parse_github_url(github_url)
if await repo_info.update_repo_commit(): 参数:
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND) github_url: 仓库地址
else: module_path: 模块路径
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND) is_dir: 是否是文件夹
logger.debug(f"成功获取仓库信息: {repo_info}", LOG_COMMAND) is_external: 是否是外部仓库
for repo_api in GithubUtils.iter_api_strategies(): """
try: repo_type = RepoType.GITHUB if is_external else None
await repo_api.parse_repo_info(repo_info)
break
except Exception as e:
logger.warning(
f"获取插件文件失败 | API类型: {repo_api.strategy}",
LOG_COMMAND,
e=e,
)
continue
else:
raise ValueError("所有API获取插件文件失败请检查网络连接")
if module_path == ".":
module_path = ""
replace_module_path = module_path.replace(".", "/") replace_module_path = module_path.replace(".", "/")
files = repo_api.get_files( if is_dir:
module_path=replace_module_path + ("" if is_dir else ".py"), files = await RepoFileManager.list_directory_files(
is_dir=is_dir, github_url, replace_module_path, repo_type=repo_type
) )
download_urls = [await repo_info.get_raw_download_urls(file) for file in files]
base_path = BASE_PATH / "plugins" if is_external else BASE_PATH
base_path = base_path if module_path else base_path / repo_info.repo
download_paths: list[Path | str] = [base_path / file for file in files]
logger.debug(f"插件下载路径: {download_paths}", LOG_COMMAND)
result = await AsyncHttpx.gather_download_file(download_urls, download_paths)
for _id, success in enumerate(result):
if not success:
break
else: else:
# 安装依赖 files = [RepoFileInfo(path=f"{replace_module_path}.py", is_dir=False)]
plugin_path = base_path / "/".join(module_path.split(".")) local_path = BASE_PATH / "plugins" if is_external else BASE_PATH
try: files = [file for file in files if not file.is_dir]
req_files = repo_api.get_files( download_files = [(file.path, local_path / file.path) for file in files]
f"{replace_module_path}/{REQ_TXT_FILE_STRING}", False await RepoFileManager.download_files(
github_url, download_files, repo_type=repo_type
) )
req_files.extend(
repo_api.get_files(f"{replace_module_path}/requirement.txt", False) requirement_paths = [
) file
logger.debug(f"获取插件依赖文件列表: {req_files}", LOG_COMMAND) for file in files
req_download_urls = [ if file.path.endswith("requirement.txt")
await repo_info.get_raw_download_urls(file) for file in req_files or file.path.endswith("requirements.txt")
] ]
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
logger.debug(f"插件依赖文件下载路径: {req_paths}", LOG_COMMAND) is_install_req = False
if req_files: for requirement_path in requirement_paths:
result = await AsyncHttpx.gather_download_file( requirement_file = local_path / requirement_path.path
req_download_urls, req_paths if requirement_file.exists():
is_install_req = True
await VirtualEnvPackageManager.install_requirement(requirement_file)
if not is_install_req:
# 从仓库根目录查找文件
rand = random.randint(1, 10000)
requirement_path = TEMP_PATH / f"plugin_store_{rand}_req.txt"
requirements_path = TEMP_PATH / f"plugin_store_{rand}_reqs.txt"
await RepoFileManager.download_files(
github_url,
[
("requirement.txt", requirement_path),
("requirements.txt", requirements_path),
],
repo_type=repo_type,
ignore_error=True,
) )
for success in result: if requirement_path.exists():
if not success: logger.info(
raise Exception("插件依赖文件下载失败") f"开始安装插件 {module_path} 依赖文件: {requirement_path}",
logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND) LOG_COMMAND,
await install_requirement(plugin_path) )
except ValueError as e: await VirtualEnvPackageManager.install_requirement(requirement_path)
logger.warning("未获取到依赖文件路径...", e=e) if requirements_path.exists():
return True logger.info(
raise Exception("插件下载失败...") f"开始安装插件 {module_path} 依赖文件: {requirements_path}",
LOG_COMMAND,
)
await VirtualEnvPackageManager.install_requirement(requirements_path)
@classmethod @classmethod
async def remove_plugin(cls, plugin_id: str) -> str: async def remove_plugin(cls, index_or_module: str) -> str:
"""移除插件 """移除插件
参数: 参数:
plugin_id: 插件id或模块名 index_or_module: 插件id或模块名
返回: 返回:
str: 返回消息 str: 返回消息
""" """
plugin_list: list[StorePluginInfo] = await cls.get_data() plugin_info = await cls.get_plugin_by_value(index_or_module)
try:
plugin_key = await cls._resolve_plugin_key(plugin_id)
except ValueError as e:
return str(e)
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
if plugin_info is None:
return f"未找到插件 {plugin_key}"
path = BASE_PATH path = BASE_PATH
if plugin_info.github_url: if plugin_info.github_url:
path = BASE_PATH / "plugins" path = BASE_PATH / "plugins"
@ -339,12 +329,13 @@ class StoreManager:
返回: 返回:
BuildImage | str: 返回消息 BuildImage | str: 返回消息
""" """
plugin_list: list[StorePluginInfo] = await cls.get_data() plugin_list, extra_plugin_list = await cls.get_data()
all_plugin_list = plugin_list + extra_plugin_list
db_plugin_list = await cls.get_loaded_plugins("module", "version") db_plugin_list = await cls.get_loaded_plugins("module", "version")
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
filtered_data = [ filtered_data = [
(id, plugin_info) (id, plugin_info)
for id, plugin_info in enumerate(plugin_list) for id, plugin_info in enumerate(all_plugin_list)
if plugin_name_or_author.lower() in plugin_info.name.lower() if plugin_name_or_author.lower() in plugin_info.name.lower()
or plugin_name_or_author.lower() in plugin_info.author.lower() or plugin_name_or_author.lower() in plugin_info.author.lower()
] ]
@ -373,28 +364,19 @@ class StoreManager:
) )
@classmethod @classmethod
async def update_plugin(cls, plugin_id: str) -> str: async def update_plugin(cls, index_or_module: str) -> str:
"""更新插件 """更新插件
参数: 参数:
plugin_id: 插件id index_or_module: 插件id
返回: 返回:
str: 返回消息 str: 返回消息
""" """
plugin_list: list[StorePluginInfo] = await cls.get_data() plugin_info = await cls.get_plugin_by_value(index_or_module, True)
try:
plugin_key = await cls._resolve_plugin_key(plugin_id)
except ValueError as e:
return str(e)
plugin_info = next((p for p in plugin_list if p.module == plugin_key), None)
if plugin_info is None:
return f"未找到插件 {plugin_key}"
logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND) logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND)
db_plugin_list = await cls.get_loaded_plugins("module", "version") db_plugin_list = await cls.get_loaded_plugins("module", "version")
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list} suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
if plugin_info.module not in [p[0] for p in db_plugin_list]:
return f"插件 {plugin_info.name} 未安装,无法更新"
logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND) logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND)
if cls.check_version_is_new(plugin_info, suc_plugin): if cls.check_version_is_new(plugin_info, suc_plugin):
return f"插件 {plugin_info.name} 已是最新版本" return f"插件 {plugin_info.name} 已是最新版本"
@ -492,22 +474,25 @@ class StoreManager:
plugin_id: moduleid或插件名称 plugin_id: moduleid或插件名称
异常: 异常:
ValueError: 插件不存在 PluginStoreException: 插件不存在
ValueError: 插件不存在 PluginStoreException: 插件不存在
返回: 返回:
str: 插件模块名 str: 插件模块名
""" """
plugin_list: list[StorePluginInfo] = await cls.get_data() plugin_list, extra_plugin_list = await cls.get_data()
all_plugin_list = plugin_list + extra_plugin_list
if is_number(plugin_id): if is_number(plugin_id):
idx = int(plugin_id) idx = int(plugin_id)
if idx < 0 or idx >= len(plugin_list): if idx < 0 or idx >= len(all_plugin_list):
raise ValueError("插件ID不存在...") raise PluginStoreException("插件ID不存在...")
return plugin_list[idx].module return all_plugin_list[idx].module
elif isinstance(plugin_id, str): elif isinstance(plugin_id, str):
result = ( result = (
None if plugin_id not in [v.module for v in plugin_list] else plugin_id None
) or next(v for v in plugin_list if v.name == plugin_id).module if plugin_id not in [v.module for v in all_plugin_list]
else plugin_id
) or next(v for v in all_plugin_list if v.name == plugin_id).module
if not result: if not result:
raise ValueError("插件 Module / 名称 不存在...") raise PluginStoreException("插件 Module / 名称 不存在...")
return result return result

View File

@ -0,0 +1,6 @@
class PluginStoreException(Exception):
def __init__(self, message: str):
self.message = message
def __str__(self):
return self.message

View File

@ -16,6 +16,7 @@ from .exceptions import (
RepoNotFoundError, RepoNotFoundError,
RepoUpdateError, RepoUpdateError,
) )
from .file_manager import RepoFileManager as RepoFileManagerClass
from .github_manager import GithubManager from .github_manager import GithubManager
from .models import ( from .models import (
FileDownloadResult, FileDownloadResult,
@ -28,6 +29,7 @@ from .utils import check_git, filter_files, glob_to_regex, run_git_command
GithubRepoManager = GithubManager() GithubRepoManager = GithubManager()
AliyunRepoManager = AliyunCodeupManager() AliyunRepoManager = AliyunCodeupManager()
RepoFileManager = RepoFileManagerClass()
__all__ = [ __all__ = [
"AliyunCodeupConfig", "AliyunCodeupConfig",
@ -45,6 +47,7 @@ __all__ = [
"RepoConfig", "RepoConfig",
"RepoDownloadError", "RepoDownloadError",
"RepoFileInfo", "RepoFileInfo",
"RepoFileManager",
"RepoManagerError", "RepoManagerError",
"RepoNotFoundError", "RepoNotFoundError",
"RepoType", "RepoType",

View File

@ -0,0 +1,542 @@
"""
仓库文件管理器用于从GitHub和阿里云CodeUp获取指定文件内容
"""
from pathlib import Path
from typing import cast, overload
import aiofiles
from httpx import Response
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import AliyunTreeType, GitHubStrategy, TreeType
from zhenxun.utils.http_utils import AsyncHttpx
from .config import LOG_COMMAND, RepoConfig
from .exceptions import FileNotFoundError, NetworkError, RepoManagerError
from .models import FileDownloadResult, RepoFileInfo, RepoType
class RepoFileManager:
"""仓库文件管理器用于获取GitHub和阿里云仓库中的文件内容"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化仓库文件管理器
参数:
config: 配置如果为None则使用默认配置
"""
self.config = config or RepoConfig.get_instance()
self.config.ensure_dirs()
@overload
async def get_github_file_content(
self, url: str, file_path: str, ignore_error: bool = False
) -> str: ...
@overload
async def get_github_file_content(
self, url: str, file_path: list[str], ignore_error: bool = False
) -> list[tuple[str, str]]: ...
async def get_github_file_content(
self, url: str, file_path: str | list[str], ignore_error: bool = False
) -> str | list[tuple[str, str]]:
"""
获取GitHub仓库文件内容
参数:
url: 仓库URL
file_path: 文件路径或文件路径列表
ignore_error: 是否忽略错误
返回:
list[tuple[str, str]]: 文件路径文件内容
"""
results = []
is_str_input = isinstance(file_path, str)
try:
if is_str_input:
file_path = [file_path]
repo_info = GithubUtils.parse_github_url(url)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
for f in file_path:
try:
file_url = await repo_info.get_raw_download_urls(f)
for fu in file_url:
response: Response = await AsyncHttpx.get(
fu, check_status_code=200
)
if response.status_code == 200:
logger.info(f"获取github文件内容成功: {f}", LOG_COMMAND)
# 确保使用UTF-8编码解析响应内容
try:
text_content = response.content.decode("utf-8")
except UnicodeDecodeError:
# 如果UTF-8解码失败尝试其他编码
text_content = response.content.decode(
"utf-8", errors="ignore"
)
logger.warning(
f"解码文件内容时出现错误,使用忽略错误模式: {f}",
LOG_COMMAND,
)
results.append((f, text_content))
break
else:
logger.warning(
f"获取github文件内容失败: {response.status_code}",
LOG_COMMAND,
)
except Exception as e:
logger.warning(f"获取github文件内容失败: {f}", LOG_COMMAND, e=e)
if not ignore_error:
raise
except Exception as e:
logger.error(f"获取GitHub文件内容失败: {file_path}", LOG_COMMAND, e=e)
raise
logger.debug(f"获取GitHub文件内容: {[r[0] for r in results]}", LOG_COMMAND)
return results[0][1] if is_str_input and results else results
@overload
async def get_aliyun_file_content(
self,
repo_name: str,
file_path: str,
branch: str = "main",
ignore_error: bool = False,
) -> str: ...
@overload
async def get_aliyun_file_content(
self,
repo_name: str,
file_path: list[str],
branch: str = "main",
ignore_error: bool = False,
) -> list[tuple[str, str]]: ...
async def get_aliyun_file_content(
self,
repo_name: str,
file_path: str | list[str],
branch: str = "main",
ignore_error: bool = False,
) -> str | list[tuple[str, str]]:
"""
获取阿里云CodeUp仓库文件内容
参数:
repo: 仓库名称
file_path: 文件路径
branch: 分支名称
ignore_error: 是否忽略错误
返回:
list[tuple[str, str]]: 文件路径文件内容
"""
results = []
is_str_input = isinstance(file_path, str)
# 导入阿里云相关模块
from zhenxun.utils.github_utils.models import AliyunFileInfo
if is_str_input:
file_path = [file_path]
for f in file_path:
try:
content = await AliyunFileInfo.get_file_content(
file_path=f, repo=repo_name, ref=branch
)
results.append((f, content))
except Exception as e:
logger.error(f"获取阿里云文件内容失败: {file_path}", LOG_COMMAND, e=e)
if not ignore_error:
raise
logger.debug(f"获取阿里云文件内容: {[r[0] for r in results]}", LOG_COMMAND)
return results[0][1] if is_str_input and results else results
@overload
async def get_file_content(
self,
repo_url: str,
file_path: str,
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> str: ...
@overload
async def get_file_content(
self,
repo_url: str,
file_path: list[str],
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> list[tuple[str, str]]: ...
async def get_file_content(
self,
repo_url: str,
file_path: str | list[str],
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> str | list[tuple[str, str]]:
"""
获取仓库文件内容
参数:
repo_url: 仓库URL
file_path: 文件路径
branch: 分支名称
repo_type: 仓库类型如果为None则自动判断
ignore_error: 是否忽略错误
返回:
str: 文件内容
"""
# 确定仓库类型
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
if repo_type is None:
try:
return await self.get_aliyun_file_content(
repo_name, file_path, branch, ignore_error
)
except Exception:
return await self.get_github_file_content(
repo_url, file_path, ignore_error
)
try:
if repo_type == RepoType.GITHUB:
return await self.get_github_file_content(
repo_url, file_path, ignore_error
)
elif repo_type == RepoType.ALIYUN:
return await self.get_aliyun_file_content(
repo_name, file_path, branch, ignore_error
)
except Exception as e:
if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError):
raise
raise RepoManagerError(f"获取文件内容失败: {e}")
async def list_directory_files(
self,
repo_url: str,
directory_path: str = "",
branch: str = "main",
repo_type: RepoType | None = None,
recursive: bool = True,
) -> list[RepoFileInfo]:
"""
获取仓库目录下的所有文件路径
参数:
repo_url: 仓库URL
directory_path: 目录路径默认为仓库根目录
branch: 分支名称
repo_type: 仓库类型如果为None则自动判断
recursive: 是否递归获取子目录文件
返回:
list[RepoFileInfo]: 文件信息列表
"""
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
try:
if repo_type is None:
# 尝试阿里云失败则尝试GitHub
try:
return await self._list_aliyun_directory_files(
repo_name, directory_path, branch, recursive
)
except Exception as e:
logger.warning(
"获取阿里云目录文件失败尝试GitHub", LOG_COMMAND, e=e
)
return await self._list_github_directory_files(
repo_url, directory_path, branch, recursive
)
if repo_type == RepoType.GITHUB:
return await self._list_github_directory_files(
repo_url, directory_path, branch, recursive
)
elif repo_type == RepoType.ALIYUN:
return await self._list_aliyun_directory_files(
repo_name, directory_path, branch, recursive
)
except Exception as e:
logger.error(f"获取目录文件列表失败: {directory_path}", LOG_COMMAND, e=e)
if isinstance(e, FileNotFoundError | NetworkError | RepoManagerError):
raise
raise RepoManagerError(f"获取目录文件列表失败: {e}")
async def _list_github_directory_files(
self,
repo_url: str,
directory_path: str = "",
branch: str = "main",
recursive: bool = True,
build_tree: bool = False,
) -> list[RepoFileInfo]:
"""
获取GitHub仓库目录下的所有文件路径
参数:
repo_url: 仓库URL
directory_path: 目录路径默认为仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录文件
build_tree: 是否构建目录树
返回:
list[RepoFileInfo]: 文件信息列表
"""
try:
repo_info = GithubUtils.parse_github_url(repo_url)
if await repo_info.update_repo_commit():
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
else:
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
# 获取仓库树信息
strategy = GitHubStrategy()
strategy.body = await GitHubStrategy.parse_repo_info(repo_info)
# 处理目录路径,确保格式正确
if directory_path and not directory_path.endswith("/") and recursive:
directory_path = f"{directory_path}/"
# 获取文件列表
file_list = []
for tree_item in strategy.body.tree:
# 如果不是递归模式,只获取当前目录下的文件
if not recursive and "/" in tree_item.path.replace(
directory_path, "", 1
):
continue
# 检查是否在指定目录下
if directory_path and not tree_item.path.startswith(directory_path):
continue
# 创建文件信息对象
file_info = RepoFileInfo(
path=tree_item.path,
is_dir=tree_item.type == TreeType.DIR,
size=tree_item.size,
last_modified=None, # GitHub API不直接提供最后修改时间
)
file_list.append(file_info)
# 构建目录树结构
if recursive and build_tree:
file_list = self._build_directory_tree(file_list)
return file_list
except Exception as e:
logger.error(
f"获取GitHub目录文件列表失败: {directory_path}", LOG_COMMAND, e=e
)
raise
async def _list_aliyun_directory_files(
self,
repo_name: str,
directory_path: str = "",
branch: str = "main",
recursive: bool = True,
build_tree: bool = False,
) -> list[RepoFileInfo]:
"""
获取阿里云CodeUp仓库目录下的所有文件路径
参数:
repo_name: 仓库名称
directory_path: 目录路径默认为仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录文件
build_tree: 是否构建目录树
返回:
list[RepoFileInfo]: 文件信息列表
"""
try:
from zhenxun.utils.github_utils.models import AliyunFileInfo
# 获取仓库树信息
search_type = "RECURSIVE" if recursive else "DIRECT"
tree_list = await AliyunFileInfo.get_repository_tree(
repo=repo_name,
path=directory_path,
ref=branch,
search_type=search_type,
)
# 创建文件信息对象列表
file_list = []
for tree_item in tree_list:
file_info = RepoFileInfo(
path=tree_item.path,
is_dir=tree_item.type == AliyunTreeType.DIR,
size=None, # 阿里云API不直接提供文件大小
last_modified=None, # 阿里云API不直接提供最后修改时间
)
file_list.append(file_info)
# 构建目录树结构
if recursive and build_tree:
file_list = self._build_directory_tree(file_list)
return file_list
except Exception as e:
logger.error(
f"获取阿里云目录文件列表失败: {directory_path}", LOG_COMMAND, e=e
)
raise
def _build_directory_tree(
self, file_list: list[RepoFileInfo]
) -> list[RepoFileInfo]:
"""
构建目录树结构
参数:
file_list: 文件信息列表
返回:
list[RepoFileInfo]: 根目录下的文件信息列表
"""
# 按路径排序,确保父目录在子目录之前
file_list.sort(key=lambda x: x.path)
# 创建路径到文件信息的映射
path_map = {file_info.path: file_info for file_info in file_list}
# 根目录文件列表
root_files = []
for file_info in file_list:
if parent_path := "/".join(file_info.path.split("/")[:-1]):
# 如果有父目录,将当前文件添加到父目录的子文件列表中
if parent_path in path_map:
path_map[parent_path].children.append(file_info)
else:
# 如果父目录不在列表中,创建一个虚拟的父目录
parent_info = RepoFileInfo(
path=parent_path, is_dir=True, children=[file_info]
)
path_map[parent_path] = parent_info
# 检查父目录的父目录
grand_parent_path = "/".join(parent_path.split("/")[:-1])
if grand_parent_path and grand_parent_path in path_map:
path_map[grand_parent_path].children.append(parent_info)
else:
root_files.append(parent_info)
else:
# 如果没有父目录,则是根目录下的文件
root_files.append(file_info)
# 返回根目录下的文件列表
return [
file
for file in root_files
if all(f.path != file.path for f in file_list if f != file)
]
async def download_files(
self,
repo_url: str,
file_path: tuple[str, Path] | list[tuple[str, Path]],
branch: str = "main",
repo_type: RepoType | None = None,
ignore_error: bool = False,
) -> FileDownloadResult:
"""
下载单个文件
参数:
repo_url: 仓库URL
file_path: 文件在仓库中的路径本地存储路径
branch: 分支名称
repo_type: 仓库类型如果为None则自动判断
ignore_error: 是否忽略错误
返回:
FileDownloadResult: 下载结果
"""
# 确定仓库类型和所有者
repo_name = (
repo_url.split("/tree/")[0].split("/")[-1].replace(".git", "").strip()
)
if isinstance(file_path, tuple):
file_path = [file_path]
file_path_mapping = {f[0]: f[1] for f in file_path}
# 创建结果对象
result = FileDownloadResult(
repo_type=repo_type,
repo_name=repo_name,
file_path=file_path,
version=branch,
)
try:
# 由于我们传入的是列表,所以这里一定返回列表
file_paths = [f[0] for f in file_path]
if len(file_paths) == 1:
# 如果只有一个文件,可能返回单个元组
file_contents_result = await self.get_file_content(
repo_url, file_paths[0], branch, repo_type, ignore_error
)
if isinstance(file_contents_result, tuple):
file_contents = [file_contents_result]
elif isinstance(file_contents_result, str):
file_contents = [(file_paths[0], file_contents_result)]
else:
file_contents = cast(list[tuple[str, str]], file_contents_result)
else:
# 多个文件一定返回列表
file_contents = cast(
list[tuple[str, str]],
await self.get_file_content(
repo_url, file_paths, branch, repo_type, ignore_error
),
)
for repo_file_path, content in file_contents:
local_path = file_path_mapping[repo_file_path]
local_path.parent.mkdir(parents=True, exist_ok=True)
# 使用二进制模式写入文件,避免编码问题
if isinstance(content, str):
content_bytes = content.encode("utf-8")
else:
content_bytes = content
async with aiofiles.open(local_path, "wb") as f:
await f.write(content_bytes)
result.success = True
# 计算文件大小
result.file_size = sum(
len(content.encode("utf-8") if isinstance(content, str) else content)
for _, content in file_contents
)
return result
except Exception as e:
logger.error(f"下载文件失败: {e}")
result.success = False
result.error_message = str(e)
return result

View File

@ -5,6 +5,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path
class RepoType(str, Enum): class RepoType(str, Enum):
@ -26,6 +27,8 @@ class RepoFileInfo:
size: int | None = None size: int | None = None
# 最后修改时间 # 最后修改时间
last_modified: datetime | None = None last_modified: datetime | None = None
# 子文件列表
children: list["RepoFileInfo"] = field(default_factory=list)
@dataclass @dataclass
@ -71,15 +74,11 @@ class FileDownloadResult:
"""文件下载结果""" """文件下载结果"""
# 仓库类型 # 仓库类型
repo_type: RepoType repo_type: RepoType | None
# 仓库名称 # 仓库名称
repo_name: str repo_name: str
# 仓库拥有者
owner: str
# 文件路径 # 文件路径
file_path: str file_path: list[tuple[str, Path]]
# 本地路径
local_path: str
# 版本 # 版本
version: str version: str
# 是否成功 # 是否成功