From 51c010daa8f688aec2ac9f203fe91d45f80f8ac1 Mon Sep 17 00:00:00 2001 From: AkashiCoin Date: Mon, 16 Sep 2024 20:08:42 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=88=20perf(github=5Futils):=20?= =?UTF-8?q?=E6=94=AF=E6=8C=81github=20url=E4=B8=8B=E8=BD=BD=E9=81=8D?= =?UTF-8?q?=E5=8E=86=20(#1632)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🎈 perf(github_utils): 支持github url下载遍历 * 🐞 fix(http_utils): 修复一些下载问题 * 🦄 refactor(http_utils): 部分重构 * chore(version): Update version to v0.2.2-e6f17c4 --------- Co-authored-by: AkashiCoin --- .vscode/settings.json | 3 +- __version__ | 2 +- .../auto_update/_data_source.py | 12 +- .../plugin_store/data_source.py | 21 +-- .../web_ui/public/data_source.py | 6 +- zhenxun/utils/github_utils/__init__.py | 32 ++-- zhenxun/utils/github_utils/func.py | 20 +- zhenxun/utils/github_utils/models.py | 50 +++-- zhenxun/utils/http_utils.py | 177 +++++++++++------- 9 files changed, 197 insertions(+), 126 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0a666ba9..6ab1cbda 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -23,7 +23,8 @@ "ujson", "unban", "userinfo", - "zhenxun" + "zhenxun", + "jsdelivr" ], "python.analysis.autoImportCompletions": true, "python.testing.pytestArgs": ["tests"], diff --git a/__version__ b/__version__ index 0a27f7d4..383741a4 100644 --- a/__version__ +++ b/__version__ @@ -1 +1 @@ -__version__: v0.2.2 +__version__: v0.2.2-e6f17c4 diff --git a/zhenxun/builtin_plugins/auto_update/_data_source.py b/zhenxun/builtin_plugins/auto_update/_data_source.py index b089fc2e..be5e3b23 100644 --- a/zhenxun/builtin_plugins/auto_update/_data_source.py +++ b/zhenxun/builtin_plugins/auto_update/_data_source.py @@ -10,8 +10,8 @@ from nonebot.utils import run_sync from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.platform import PlatformUtils +from zhenxun.utils.github_utils import GithubUtils from zhenxun.utils.github_utils.models import RepoInfo -from zhenxun.utils.github_utils import parse_github_url from .config import ( TMP_PATH, @@ -170,19 +170,19 @@ class UpdateManage: cur_version = cls.__get_version() url = None new_version = None - repo_info = parse_github_url(DEFAULT_GITHUB_URL) + repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL) if version_type in {"dev", "main"}: repo_info.branch = version_type new_version = await cls.__get_version_from_repo(repo_info) if new_version: new_version = new_version.split(":")[-1].strip() - url = await repo_info.get_archive_download_url() + url = await repo_info.get_archive_download_urls() elif version_type == "release": data = await cls.__get_latest_data() if not data: return "获取更新版本失败..." new_version = data.get("name", "") - url = await repo_info.get_release_source_download_url_tgz(new_version) + url = await repo_info.get_release_source_download_urls_tgz(new_version) if not url: return "获取版本下载链接失败..." if TMP_PATH.exists(): @@ -200,7 +200,7 @@ class UpdateManage: download_file = ( DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE ) - if await AsyncHttpx.download_file(url, download_file): + if await AsyncHttpx.download_file(url, download_file, stream=True): logger.debug("下载真寻最新版文件完成...", "检查更新") await _file_handle(new_version) return ( @@ -253,7 +253,7 @@ class UpdateManage: 返回: str: 版本号 """ - version_url = await repo_info.get_raw_download_url(path="__version__") + version_url = await repo_info.get_raw_download_urls(path="__version__") try: res = await AsyncHttpx.get(version_url) if res.status_code == 200: diff --git a/zhenxun/builtin_plugins/plugin_store/data_source.py b/zhenxun/builtin_plugins/plugin_store/data_source.py index 6c35890c..bc4dc5af 100644 --- a/zhenxun/builtin_plugins/plugin_store/data_source.py +++ b/zhenxun/builtin_plugins/plugin_store/data_source.py @@ -8,8 +8,8 @@ from aiocache import cached from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.models.plugin_info import PluginInfo +from zhenxun.utils.github_utils import GithubUtils from zhenxun.utils.github_utils.models import RepoAPI -from zhenxun.utils.github_utils import api_strategy, parse_github_url from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo from zhenxun.utils.image_utils import RowStyle, BuildImage, ImageTemplate from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING @@ -78,12 +78,12 @@ class ShopManage: 返回: dict: 插件信息数据 """ - default_github_url = await parse_github_url( + default_github_url = await GithubUtils.parse_github_url( DEFAULT_GITHUB_URL - ).get_raw_download_url("plugins.json") - extra_github_url = await parse_github_url( + ).get_raw_download_urls("plugins.json") + extra_github_url = await GithubUtils.parse_github_url( EXTRA_GITHUB_URL - ).get_raw_download_url("plugins.json") + ).get_raw_download_urls("plugins.json") res = await AsyncHttpx.get(default_github_url) res2 = await AsyncHttpx.get(extra_github_url) @@ -210,9 +210,9 @@ class ShopManage: ): files: list[str] repo_api: RepoAPI - repo_info = parse_github_url(github_url) + repo_info = GithubUtils.parse_github_url(github_url) logger.debug(f"成功获取仓库信息: {repo_info}", "插件管理") - for repo_api in api_strategy: + for repo_api in GithubUtils.iter_api_strategies(): try: await repo_api.parse_repo_info(repo_info) break @@ -227,7 +227,7 @@ class ShopManage: module_path=module_path.replace(".", "/") + ("" if is_dir else ".py"), is_dir=is_dir, ) - download_urls = [await repo_info.get_raw_download_url(file) for file in files] + download_urls = [await repo_info.get_raw_download_urls(file) for file in files] base_path = BASE_PATH / "plugins" if is_external else BASE_PATH download_paths: list[Path | str] = [base_path / file for file in files] logger.debug(f"插件下载路径: {download_paths}", "插件管理") @@ -242,7 +242,7 @@ class ShopManage: req_files.extend(repo_api.get_files("requirement.txt", False)) logger.debug(f"获取插件依赖文件列表: {req_files}", "插件管理") req_download_urls = [ - await repo_info.get_raw_download_url(file) for file in req_files + await repo_info.get_raw_download_urls(file) for file in req_files ] req_paths: list[Path | str] = [plugin_path / file for file in req_files] logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理") @@ -252,11 +252,10 @@ class ShopManage: ) for _id, success in enumerate(result): if not success: - break + raise Exception("插件依赖文件下载失败") else: logger.debug(f"插件依赖文件列表: {req_paths}", "插件管理") install_requirement(plugin_path) - raise Exception("插件依赖文件下载失败") return True raise Exception("插件下载失败") diff --git a/zhenxun/builtin_plugins/web_ui/public/data_source.py b/zhenxun/builtin_plugins/web_ui/public/data_source.py index 8134433c..8d094d99 100644 --- a/zhenxun/builtin_plugins/web_ui/public/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/public/data_source.py @@ -6,7 +6,7 @@ from nonebot.utils import run_sync from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx -from zhenxun.utils.github_utils import parse_github_url +from zhenxun.utils.github_utils import GithubUtils from ..config import TMP_PATH, PUBLIC_PATH, WEBUI_DIST_GITHUB_URL @@ -15,9 +15,9 @@ COMMAND_NAME = "WebUI资源管理" async def update_webui_assets(): webui_assets_path = TMP_PATH / "webui_assets.zip" - download_url = await parse_github_url( + download_url = await GithubUtils.parse_github_url( WEBUI_DIST_GITHUB_URL - ).get_archive_download_url() + ).get_archive_download_urls() if await AsyncHttpx.download_file( download_url, webui_assets_path, follow_redirects=True ): diff --git a/zhenxun/utils/github_utils/__init__.py b/zhenxun/utils/github_utils/__init__.py index 56fd1008..89b0a80a 100644 --- a/zhenxun/utils/github_utils/__init__.py +++ b/zhenxun/utils/github_utils/__init__.py @@ -1,23 +1,27 @@ +from collections.abc import Generator + from .consts import GITHUB_REPO_URL_PATTERN -from .func import get_fastest_raw_format, get_fastest_archive_format +from .func import get_fastest_raw_formats, get_fastest_archive_formats from .models import RepoAPI, RepoInfo, GitHubStrategy, JsdelivrStrategy __all__ = [ - "parse_github_url", - "get_fastest_raw_format", - "get_fastest_archive_format", - "api_strategy", + "get_fastest_raw_formats", + "get_fastest_archive_formats", + "GithubUtils", ] -def parse_github_url(github_url: str) -> "RepoInfo": - if matched := GITHUB_REPO_URL_PATTERN.match(github_url): - return RepoInfo(**{k: v for k, v in matched.groupdict().items() if v}) - raise ValueError("github地址格式错误") +class GithubUtils: + # 使用 + jsdelivr_api = RepoAPI(JsdelivrStrategy()) # type: ignore + github_api = RepoAPI(GitHubStrategy()) # type: ignore + @classmethod + def iter_api_strategies(cls) -> Generator[RepoAPI]: + yield from [cls.github_api, cls.jsdelivr_api] -# 使用 -jsdelivr_api = RepoAPI(JsdelivrStrategy()) # type: ignore -github_api = RepoAPI(GitHubStrategy()) # type: ignore - -api_strategy = [github_api, jsdelivr_api] + @classmethod + def parse_github_url(cls, github_url: str) -> "RepoInfo": + if matched := GITHUB_REPO_URL_PATTERN.match(github_url): + return RepoInfo(**{k: v for k, v in matched.groupdict().items() if v}) + raise ValueError("github地址格式错误") diff --git a/zhenxun/utils/github_utils/func.py b/zhenxun/utils/github_utils/func.py index 145dbd96..95d2a3ef 100644 --- a/zhenxun/utils/github_utils/func.py +++ b/zhenxun/utils/github_utils/func.py @@ -9,15 +9,15 @@ from .consts import ( ) -async def __get_fastest_format(formats: dict[str, str]) -> str: +async def __get_fastest_formats(formats: dict[str, str]) -> list[str]: sorted_urls = await AsyncHttpx.get_fastest_mirror(list(formats.keys())) if not sorted_urls: raise Exception("无法获取任意GitHub资源加速地址,请检查网络") - return formats[sorted_urls[0]] + return [formats[url] for url in sorted_urls] @cached() -async def get_fastest_raw_format() -> str: +async def get_fastest_raw_formats() -> list[str]: """获取最快的raw下载地址格式""" formats: dict[str, str] = { "https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT, @@ -26,11 +26,11 @@ async def get_fastest_raw_format() -> str: "https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}", "https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) @cached() -async def get_fastest_archive_format() -> str: +async def get_fastest_archive_formats() -> list[str]: """获取最快的归档下载地址格式""" formats: dict[str, str] = { "https://github.com/": ARCHIVE_URL_FORMAT, @@ -38,11 +38,11 @@ async def get_fastest_archive_format() -> str: "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) @cached() -async def get_fastest_release_format() -> str: +async def get_fastest_release_formats() -> list[str]: """获取最快的发行版资源下载地址格式""" formats: dict[str, str] = { "https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT, @@ -50,14 +50,14 @@ async def get_fastest_release_format() -> str: "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) @cached() -async def get_fastest_release_source_format() -> str: +async def get_fastest_release_source_formats() -> list[str]: """获取最快的发行版源码下载地址格式""" formats: dict[str, str] = { "https://codeload.github.com/": RELEASE_SOURCE_FORMAT, "https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) diff --git a/zhenxun/utils/github_utils/models.py b/zhenxun/utils/github_utils/models.py index 98ce0291..17089281 100644 --- a/zhenxun/utils/github_utils/models.py +++ b/zhenxun/utils/github_utils/models.py @@ -7,9 +7,9 @@ from pydantic import BaseModel from ..http_utils import AsyncHttpx from .consts import CACHED_API_TTL, GIT_API_TREES_FORMAT, JSD_PACKAGE_API_FORMAT from .func import ( - get_fastest_raw_format, - get_fastest_archive_format, - get_fastest_release_source_format, + get_fastest_raw_formats, + get_fastest_archive_formats, + get_fastest_release_source_formats, ) @@ -20,21 +20,41 @@ class RepoInfo(BaseModel): repo: str branch: str = "main" - async def get_raw_download_url(self, path: str): - url_format = await get_fastest_raw_format() - return url_format.format(**self.dict(), path=path) + async def get_raw_download_url(self, path: str) -> str: + return (await self.get_raw_download_urls(path))[0] - async def get_archive_download_url(self): - url_format = await get_fastest_archive_format() - return url_format.format(**self.dict()) + async def get_archive_download_url(self) -> str: + return (await self.get_archive_download_urls())[0] - async def get_release_source_download_url_tgz(self, version: str): - url_format = await get_fastest_release_source_format() - return url_format.format(**self.dict(), version=version, compress="tar.gz") + async def get_release_source_download_url_tgz(self, version: str) -> str: + return (await self.get_release_source_download_urls_tgz(version))[0] - async def get_release_source_download_url_zip(self, version: str): - url_format = await get_fastest_release_source_format() - return url_format.format(**self.dict(), version=version, compress="zip") + async def get_release_source_download_url_zip(self, version: str) -> str: + return (await self.get_release_source_download_urls_zip(version))[0] + + async def get_raw_download_urls(self, path: str) -> list[str]: + url_formats = await get_fastest_raw_formats() + return [ + url_format.format(**self.dict(), path=path) for url_format in url_formats + ] + + async def get_archive_download_urls(self) -> list[str]: + url_formats = await get_fastest_archive_formats() + return [url_format.format(**self.dict()) for url_format in url_formats] + + async def get_release_source_download_urls_tgz(self, version: str) -> list[str]: + url_formats = await get_fastest_release_source_formats() + return [ + url_format.format(**self.dict(), version=version, compress="tar.gz") + for url_format in url_formats + ] + + async def get_release_source_download_urls_zip(self, version: str) -> list[str]: + url_formats = await get_fastest_release_source_formats() + return [ + url_format.format(**self.dict(), version=version, compress="zip") + for url_format in url_formats + ] class APIStrategy(Protocol): diff --git a/zhenxun/utils/http_utils.py b/zhenxun/utils/http_utils.py index 98f2b74e..1891e265 100644 --- a/zhenxun/utils/http_utils.py +++ b/zhenxun/utils/http_utils.py @@ -11,9 +11,9 @@ import httpx import aiofiles from retrying import retry from playwright.async_api import Page -from httpx import Response, ConnectTimeout from nonebot_plugin_alconna import UniMessage from nonebot_plugin_htmlrender import get_browser +from httpx import Response, ConnectTimeout, HTTPStatusError from zhenxun.services.log import logger from zhenxun.configs.config import BotConfig @@ -33,7 +33,7 @@ class AsyncHttpx: @retry(stop_max_attempt_number=3) async def get( cls, - url: str, + url: str | list[str], *, params: dict[str, Any] | None = None, headers: dict[str, str] | None = None, @@ -56,6 +56,49 @@ class AsyncHttpx: proxy: 指定代理 timeout: 超时时间 """ + urls = [url] if isinstance(url, str) else url + return await cls._get_first_successful( + urls, + params=params, + headers=headers, + cookies=cookies, + verify=verify, + use_proxy=use_proxy, + proxy=proxy, + timeout=timeout, + **kwargs, + ) + + @classmethod + async def _get_first_successful( + cls, + urls: list[str], + **kwargs, + ) -> Response: + last_exception = None + for url in urls: + try: + return await cls._get_single(url, **kwargs) + except Exception as e: + last_exception = e + if url != urls[-1]: + logger.warning(f"获取 {url} 失败, 尝试下一个") + raise last_exception or Exception("All URLs failed") + + @classmethod + async def _get_single( + cls, + url: str, + *, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + cookies: dict[str, str] | None = None, + verify: bool = True, + use_proxy: bool = True, + proxy: dict[str, str] | None = None, + timeout: int = 30, + **kwargs, + ) -> Response: if not headers: headers = get_user_agent() _proxy = proxy if proxy else cls.proxy if use_proxy else None @@ -162,7 +205,7 @@ class AsyncHttpx: @classmethod async def download_file( cls, - url: str, + url: str | list[str], path: str | Path, *, params: dict[str, str] | None = None, @@ -195,75 +238,79 @@ class AsyncHttpx: path.parent.mkdir(parents=True, exist_ok=True) try: for _ in range(3): - if not stream: + if not isinstance(url, list): + url = [url] + for u in url: try: - response = await cls.get( - url, - params=params, - headers=headers, - cookies=cookies, - use_proxy=use_proxy, - proxy=proxy, - timeout=timeout, - follow_redirects=follow_redirects, - **kwargs, - ) - response.raise_for_status() - content = response.content - async with aiofiles.open(path, "wb") as wf: - await wf.write(content) - logger.info(f"下载 {url} 成功.. Path:{path.absolute()}") - return True - except (TimeoutError, ConnectTimeout): - pass - else: - if not headers: - headers = get_user_agent() - _proxy = proxy if proxy else cls.proxy if use_proxy else None - try: - async with httpx.AsyncClient( - proxies=_proxy, # type: ignore - verify=verify, - ) as client: - async with client.stream( - "GET", - url, + if not stream: + response = await cls.get( + u, params=params, headers=headers, cookies=cookies, + use_proxy=use_proxy, + proxy=proxy, timeout=timeout, - follow_redirects=True, + follow_redirects=follow_redirects, **kwargs, - ) as response: - response.raise_for_status() - logger.info( - f"开始下载 {path.name}.. Path: {path.absolute()}" - ) - async with aiofiles.open(path, "wb") as wf: - total = int(response.headers["Content-Length"]) - with rich.progress.Progress( # type: ignore - rich.progress.TextColumn(path.name), # type: ignore - "[progress.percentage]{task.percentage:>3.0f}%", # type: ignore - rich.progress.BarColumn(bar_width=None), # type: ignore - rich.progress.DownloadColumn(), # type: ignore - rich.progress.TransferSpeedColumn(), # type: ignore - ) as progress: - download_task = progress.add_task( - "Download", total=total - ) - async for chunk in response.aiter_bytes(): - await wf.write(chunk) - await wf.flush() - progress.update( - download_task, - completed=response.num_bytes_downloaded, - ) + ) + response.raise_for_status() + content = response.content + async with aiofiles.open(path, "wb") as wf: + await wf.write(content) + logger.info(f"下载 {u} 成功.. Path:{path.absolute()}") + return True + else: + if not headers: + headers = get_user_agent() + _proxy = ( + proxy if proxy else cls.proxy if use_proxy else None + ) + async with httpx.AsyncClient( + proxies=_proxy, # type: ignore + verify=verify, + ) as client: + async with client.stream( + "GET", + u, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=True, + **kwargs, + ) as response: + response.raise_for_status() logger.info( - f"下载 {url} 成功.. Path:{path.absolute()}" + f"开始下载 {path.name}.. " + f"Path: {path.absolute()}" ) - return True - except (TimeoutError, ConnectTimeout): - pass + async with aiofiles.open(path, "wb") as wf: + total = int(response.headers["Content-Length"]) + with rich.progress.Progress( # type: ignore + rich.progress.TextColumn(path.name), # type: ignore + "[progress.percentage]{task.percentage:>3.0f}%", # type: ignore + rich.progress.BarColumn(bar_width=None), # type: ignore + rich.progress.DownloadColumn(), # type: ignore + rich.progress.TransferSpeedColumn(), # type: ignore + ) as progress: + download_task = progress.add_task( + "Download", total=total + ) + async for chunk in response.aiter_bytes(): + await wf.write(chunk) + await wf.flush() + progress.update( + download_task, + completed=response.num_bytes_downloaded, + ) + logger.info( + f"下载 {u} 成功.. " + f"Path:{path.absolute()}" + ) + return True + except (TimeoutError, ConnectTimeout, HTTPStatusError): + logger.warning(f"下载 {u} 失败.. 尝试下一个地址..") else: logger.error(f"下载 {url} 下载超时.. Path:{path.absolute()}") except Exception as e: @@ -273,7 +320,7 @@ class AsyncHttpx: @classmethod async def gather_download_file( cls, - url_list: list[str], + url_list: list[str] | list[list[str]], path_list: list[str | Path], *, limit_async_number: int | None = None,