mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
🎈 perf(github_utils): 支持github url下载遍历 (#1632)
* 🎈 perf(github_utils): 支持github url下载遍历
* 🐞 fix(http_utils): 修复一些下载问题
* 🦄 refactor(http_utils): 部分重构
* chore(version): Update version to v0.2.2-e6f17c4

---------

Co-authored-by: AkashiCoin <AkashiCoin@users.noreply.github.com>
This commit is contained in:
parent
cd88d805ce
commit
51c010daa8
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -23,7 +23,8 @@
|
|||||||
"ujson",
|
"ujson",
|
||||||
"unban",
|
"unban",
|
||||||
"userinfo",
|
"userinfo",
|
||||||
"zhenxun"
|
"zhenxun",
|
||||||
|
"jsdelivr"
|
||||||
],
|
],
|
||||||
"python.analysis.autoImportCompletions": true,
|
"python.analysis.autoImportCompletions": true,
|
||||||
"python.testing.pytestArgs": ["tests"],
|
"python.testing.pytestArgs": ["tests"],
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
__version__: v0.2.2
|
__version__: v0.2.2-e6f17c4
|
||||||
|
|||||||
@ -10,8 +10,8 @@ from nonebot.utils import run_sync
|
|||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.http_utils import AsyncHttpx
|
from zhenxun.utils.http_utils import AsyncHttpx
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
from zhenxun.utils.github_utils import GithubUtils
|
||||||
from zhenxun.utils.github_utils.models import RepoInfo
|
from zhenxun.utils.github_utils.models import RepoInfo
|
||||||
from zhenxun.utils.github_utils import parse_github_url
|
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
TMP_PATH,
|
TMP_PATH,
|
||||||
@ -170,19 +170,19 @@ class UpdateManage:
|
|||||||
cur_version = cls.__get_version()
|
cur_version = cls.__get_version()
|
||||||
url = None
|
url = None
|
||||||
new_version = None
|
new_version = None
|
||||||
repo_info = parse_github_url(DEFAULT_GITHUB_URL)
|
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
|
||||||
if version_type in {"dev", "main"}:
|
if version_type in {"dev", "main"}:
|
||||||
repo_info.branch = version_type
|
repo_info.branch = version_type
|
||||||
new_version = await cls.__get_version_from_repo(repo_info)
|
new_version = await cls.__get_version_from_repo(repo_info)
|
||||||
if new_version:
|
if new_version:
|
||||||
new_version = new_version.split(":")[-1].strip()
|
new_version = new_version.split(":")[-1].strip()
|
||||||
url = await repo_info.get_archive_download_url()
|
url = await repo_info.get_archive_download_urls()
|
||||||
elif version_type == "release":
|
elif version_type == "release":
|
||||||
data = await cls.__get_latest_data()
|
data = await cls.__get_latest_data()
|
||||||
if not data:
|
if not data:
|
||||||
return "获取更新版本失败..."
|
return "获取更新版本失败..."
|
||||||
new_version = data.get("name", "")
|
new_version = data.get("name", "")
|
||||||
url = await repo_info.get_release_source_download_url_tgz(new_version)
|
url = await repo_info.get_release_source_download_urls_tgz(new_version)
|
||||||
if not url:
|
if not url:
|
||||||
return "获取版本下载链接失败..."
|
return "获取版本下载链接失败..."
|
||||||
if TMP_PATH.exists():
|
if TMP_PATH.exists():
|
||||||
@ -200,7 +200,7 @@ class UpdateManage:
|
|||||||
download_file = (
|
download_file = (
|
||||||
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
|
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
|
||||||
)
|
)
|
||||||
if await AsyncHttpx.download_file(url, download_file):
|
if await AsyncHttpx.download_file(url, download_file, stream=True):
|
||||||
logger.debug("下载真寻最新版文件完成...", "检查更新")
|
logger.debug("下载真寻最新版文件完成...", "检查更新")
|
||||||
await _file_handle(new_version)
|
await _file_handle(new_version)
|
||||||
return (
|
return (
|
||||||
@ -253,7 +253,7 @@ class UpdateManage:
|
|||||||
返回:
|
返回:
|
||||||
str: 版本号
|
str: 版本号
|
||||||
"""
|
"""
|
||||||
version_url = await repo_info.get_raw_download_url(path="__version__")
|
version_url = await repo_info.get_raw_download_urls(path="__version__")
|
||||||
try:
|
try:
|
||||||
res = await AsyncHttpx.get(version_url)
|
res = await AsyncHttpx.get(version_url)
|
||||||
if res.status_code == 200:
|
if res.status_code == 200:
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from aiocache import cached
|
|||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.http_utils import AsyncHttpx
|
from zhenxun.utils.http_utils import AsyncHttpx
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
|
from zhenxun.utils.github_utils import GithubUtils
|
||||||
from zhenxun.utils.github_utils.models import RepoAPI
|
from zhenxun.utils.github_utils.models import RepoAPI
|
||||||
from zhenxun.utils.github_utils import api_strategy, parse_github_url
|
|
||||||
from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo
|
from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo
|
||||||
from zhenxun.utils.image_utils import RowStyle, BuildImage, ImageTemplate
|
from zhenxun.utils.image_utils import RowStyle, BuildImage, ImageTemplate
|
||||||
from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING
|
from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING
|
||||||
@ -78,12 +78,12 @@ class ShopManage:
|
|||||||
返回:
|
返回:
|
||||||
dict: 插件信息数据
|
dict: 插件信息数据
|
||||||
"""
|
"""
|
||||||
default_github_url = await parse_github_url(
|
default_github_url = await GithubUtils.parse_github_url(
|
||||||
DEFAULT_GITHUB_URL
|
DEFAULT_GITHUB_URL
|
||||||
).get_raw_download_url("plugins.json")
|
).get_raw_download_urls("plugins.json")
|
||||||
extra_github_url = await parse_github_url(
|
extra_github_url = await GithubUtils.parse_github_url(
|
||||||
EXTRA_GITHUB_URL
|
EXTRA_GITHUB_URL
|
||||||
).get_raw_download_url("plugins.json")
|
).get_raw_download_urls("plugins.json")
|
||||||
res = await AsyncHttpx.get(default_github_url)
|
res = await AsyncHttpx.get(default_github_url)
|
||||||
res2 = await AsyncHttpx.get(extra_github_url)
|
res2 = await AsyncHttpx.get(extra_github_url)
|
||||||
|
|
||||||
@ -210,9 +210,9 @@ class ShopManage:
|
|||||||
):
|
):
|
||||||
files: list[str]
|
files: list[str]
|
||||||
repo_api: RepoAPI
|
repo_api: RepoAPI
|
||||||
repo_info = parse_github_url(github_url)
|
repo_info = GithubUtils.parse_github_url(github_url)
|
||||||
logger.debug(f"成功获取仓库信息: {repo_info}", "插件管理")
|
logger.debug(f"成功获取仓库信息: {repo_info}", "插件管理")
|
||||||
for repo_api in api_strategy:
|
for repo_api in GithubUtils.iter_api_strategies():
|
||||||
try:
|
try:
|
||||||
await repo_api.parse_repo_info(repo_info)
|
await repo_api.parse_repo_info(repo_info)
|
||||||
break
|
break
|
||||||
@ -227,7 +227,7 @@ class ShopManage:
|
|||||||
module_path=module_path.replace(".", "/") + ("" if is_dir else ".py"),
|
module_path=module_path.replace(".", "/") + ("" if is_dir else ".py"),
|
||||||
is_dir=is_dir,
|
is_dir=is_dir,
|
||||||
)
|
)
|
||||||
download_urls = [await repo_info.get_raw_download_url(file) for file in files]
|
download_urls = [await repo_info.get_raw_download_urls(file) for file in files]
|
||||||
base_path = BASE_PATH / "plugins" if is_external else BASE_PATH
|
base_path = BASE_PATH / "plugins" if is_external else BASE_PATH
|
||||||
download_paths: list[Path | str] = [base_path / file for file in files]
|
download_paths: list[Path | str] = [base_path / file for file in files]
|
||||||
logger.debug(f"插件下载路径: {download_paths}", "插件管理")
|
logger.debug(f"插件下载路径: {download_paths}", "插件管理")
|
||||||
@ -242,7 +242,7 @@ class ShopManage:
|
|||||||
req_files.extend(repo_api.get_files("requirement.txt", False))
|
req_files.extend(repo_api.get_files("requirement.txt", False))
|
||||||
logger.debug(f"获取插件依赖文件列表: {req_files}", "插件管理")
|
logger.debug(f"获取插件依赖文件列表: {req_files}", "插件管理")
|
||||||
req_download_urls = [
|
req_download_urls = [
|
||||||
await repo_info.get_raw_download_url(file) for file in req_files
|
await repo_info.get_raw_download_urls(file) for file in req_files
|
||||||
]
|
]
|
||||||
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
|
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
|
||||||
logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理")
|
logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理")
|
||||||
@ -252,11 +252,10 @@ class ShopManage:
|
|||||||
)
|
)
|
||||||
for _id, success in enumerate(result):
|
for _id, success in enumerate(result):
|
||||||
if not success:
|
if not success:
|
||||||
break
|
raise Exception("插件依赖文件下载失败")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"插件依赖文件列表: {req_paths}", "插件管理")
|
logger.debug(f"插件依赖文件列表: {req_paths}", "插件管理")
|
||||||
install_requirement(plugin_path)
|
install_requirement(plugin_path)
|
||||||
raise Exception("插件依赖文件下载失败")
|
|
||||||
return True
|
return True
|
||||||
raise Exception("插件下载失败")
|
raise Exception("插件下载失败")
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from nonebot.utils import run_sync
|
|||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.http_utils import AsyncHttpx
|
from zhenxun.utils.http_utils import AsyncHttpx
|
||||||
from zhenxun.utils.github_utils import parse_github_url
|
from zhenxun.utils.github_utils import GithubUtils
|
||||||
|
|
||||||
from ..config import TMP_PATH, PUBLIC_PATH, WEBUI_DIST_GITHUB_URL
|
from ..config import TMP_PATH, PUBLIC_PATH, WEBUI_DIST_GITHUB_URL
|
||||||
|
|
||||||
@ -15,9 +15,9 @@ COMMAND_NAME = "WebUI资源管理"
|
|||||||
|
|
||||||
async def update_webui_assets():
|
async def update_webui_assets():
|
||||||
webui_assets_path = TMP_PATH / "webui_assets.zip"
|
webui_assets_path = TMP_PATH / "webui_assets.zip"
|
||||||
download_url = await parse_github_url(
|
download_url = await GithubUtils.parse_github_url(
|
||||||
WEBUI_DIST_GITHUB_URL
|
WEBUI_DIST_GITHUB_URL
|
||||||
).get_archive_download_url()
|
).get_archive_download_urls()
|
||||||
if await AsyncHttpx.download_file(
|
if await AsyncHttpx.download_file(
|
||||||
download_url, webui_assets_path, follow_redirects=True
|
download_url, webui_assets_path, follow_redirects=True
|
||||||
):
|
):
|
||||||
|
|||||||
@ -1,23 +1,27 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
from .consts import GITHUB_REPO_URL_PATTERN
|
from .consts import GITHUB_REPO_URL_PATTERN
|
||||||
from .func import get_fastest_raw_format, get_fastest_archive_format
|
from .func import get_fastest_raw_formats, get_fastest_archive_formats
|
||||||
from .models import RepoAPI, RepoInfo, GitHubStrategy, JsdelivrStrategy
|
from .models import RepoAPI, RepoInfo, GitHubStrategy, JsdelivrStrategy
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"parse_github_url",
|
"get_fastest_raw_formats",
|
||||||
"get_fastest_raw_format",
|
"get_fastest_archive_formats",
|
||||||
"get_fastest_archive_format",
|
"GithubUtils",
|
||||||
"api_strategy",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def parse_github_url(github_url: str) -> "RepoInfo":
|
class GithubUtils:
|
||||||
if matched := GITHUB_REPO_URL_PATTERN.match(github_url):
|
# 使用
|
||||||
return RepoInfo(**{k: v for k, v in matched.groupdict().items() if v})
|
jsdelivr_api = RepoAPI(JsdelivrStrategy()) # type: ignore
|
||||||
raise ValueError("github地址格式错误")
|
github_api = RepoAPI(GitHubStrategy()) # type: ignore
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def iter_api_strategies(cls) -> Generator[RepoAPI]:
|
||||||
|
yield from [cls.github_api, cls.jsdelivr_api]
|
||||||
|
|
||||||
# 使用
|
@classmethod
|
||||||
jsdelivr_api = RepoAPI(JsdelivrStrategy()) # type: ignore
|
def parse_github_url(cls, github_url: str) -> "RepoInfo":
|
||||||
github_api = RepoAPI(GitHubStrategy()) # type: ignore
|
if matched := GITHUB_REPO_URL_PATTERN.match(github_url):
|
||||||
|
return RepoInfo(**{k: v for k, v in matched.groupdict().items() if v})
|
||||||
api_strategy = [github_api, jsdelivr_api]
|
raise ValueError("github地址格式错误")
|
||||||
|
|||||||
@ -9,15 +9,15 @@ from .consts import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def __get_fastest_format(formats: dict[str, str]) -> str:
|
async def __get_fastest_formats(formats: dict[str, str]) -> list[str]:
|
||||||
sorted_urls = await AsyncHttpx.get_fastest_mirror(list(formats.keys()))
|
sorted_urls = await AsyncHttpx.get_fastest_mirror(list(formats.keys()))
|
||||||
if not sorted_urls:
|
if not sorted_urls:
|
||||||
raise Exception("无法获取任意GitHub资源加速地址,请检查网络")
|
raise Exception("无法获取任意GitHub资源加速地址,请检查网络")
|
||||||
return formats[sorted_urls[0]]
|
return [formats[url] for url in sorted_urls]
|
||||||
|
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_fastest_raw_format() -> str:
|
async def get_fastest_raw_formats() -> list[str]:
|
||||||
"""获取最快的raw下载地址格式"""
|
"""获取最快的raw下载地址格式"""
|
||||||
formats: dict[str, str] = {
|
formats: dict[str, str] = {
|
||||||
"https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT,
|
"https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT,
|
||||||
@ -26,11 +26,11 @@ async def get_fastest_raw_format() -> str:
|
|||||||
"https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}",
|
"https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}",
|
||||||
"https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}",
|
"https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}",
|
||||||
}
|
}
|
||||||
return await __get_fastest_format(formats)
|
return await __get_fastest_formats(formats)
|
||||||
|
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_fastest_archive_format() -> str:
|
async def get_fastest_archive_formats() -> list[str]:
|
||||||
"""获取最快的归档下载地址格式"""
|
"""获取最快的归档下载地址格式"""
|
||||||
formats: dict[str, str] = {
|
formats: dict[str, str] = {
|
||||||
"https://github.com/": ARCHIVE_URL_FORMAT,
|
"https://github.com/": ARCHIVE_URL_FORMAT,
|
||||||
@ -38,11 +38,11 @@ async def get_fastest_archive_format() -> str:
|
|||||||
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}",
|
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}",
|
||||||
"https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}",
|
"https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}",
|
||||||
}
|
}
|
||||||
return await __get_fastest_format(formats)
|
return await __get_fastest_formats(formats)
|
||||||
|
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_fastest_release_format() -> str:
|
async def get_fastest_release_formats() -> list[str]:
|
||||||
"""获取最快的发行版资源下载地址格式"""
|
"""获取最快的发行版资源下载地址格式"""
|
||||||
formats: dict[str, str] = {
|
formats: dict[str, str] = {
|
||||||
"https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT,
|
"https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT,
|
||||||
@ -50,14 +50,14 @@ async def get_fastest_release_format() -> str:
|
|||||||
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}",
|
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}",
|
||||||
"https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}",
|
"https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}",
|
||||||
}
|
}
|
||||||
return await __get_fastest_format(formats)
|
return await __get_fastest_formats(formats)
|
||||||
|
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_fastest_release_source_format() -> str:
|
async def get_fastest_release_source_formats() -> list[str]:
|
||||||
"""获取最快的发行版源码下载地址格式"""
|
"""获取最快的发行版源码下载地址格式"""
|
||||||
formats: dict[str, str] = {
|
formats: dict[str, str] = {
|
||||||
"https://codeload.github.com/": RELEASE_SOURCE_FORMAT,
|
"https://codeload.github.com/": RELEASE_SOURCE_FORMAT,
|
||||||
"https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}",
|
"https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}",
|
||||||
}
|
}
|
||||||
return await __get_fastest_format(formats)
|
return await __get_fastest_formats(formats)
|
||||||
|
|||||||
@ -7,9 +7,9 @@ from pydantic import BaseModel
|
|||||||
from ..http_utils import AsyncHttpx
|
from ..http_utils import AsyncHttpx
|
||||||
from .consts import CACHED_API_TTL, GIT_API_TREES_FORMAT, JSD_PACKAGE_API_FORMAT
|
from .consts import CACHED_API_TTL, GIT_API_TREES_FORMAT, JSD_PACKAGE_API_FORMAT
|
||||||
from .func import (
|
from .func import (
|
||||||
get_fastest_raw_format,
|
get_fastest_raw_formats,
|
||||||
get_fastest_archive_format,
|
get_fastest_archive_formats,
|
||||||
get_fastest_release_source_format,
|
get_fastest_release_source_formats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -20,21 +20,41 @@ class RepoInfo(BaseModel):
|
|||||||
repo: str
|
repo: str
|
||||||
branch: str = "main"
|
branch: str = "main"
|
||||||
|
|
||||||
async def get_raw_download_url(self, path: str):
|
async def get_raw_download_url(self, path: str) -> str:
|
||||||
url_format = await get_fastest_raw_format()
|
return (await self.get_raw_download_urls(path))[0]
|
||||||
return url_format.format(**self.dict(), path=path)
|
|
||||||
|
|
||||||
async def get_archive_download_url(self):
|
async def get_archive_download_url(self) -> str:
|
||||||
url_format = await get_fastest_archive_format()
|
return (await self.get_archive_download_urls())[0]
|
||||||
return url_format.format(**self.dict())
|
|
||||||
|
|
||||||
async def get_release_source_download_url_tgz(self, version: str):
|
async def get_release_source_download_url_tgz(self, version: str) -> str:
|
||||||
url_format = await get_fastest_release_source_format()
|
return (await self.get_release_source_download_urls_tgz(version))[0]
|
||||||
return url_format.format(**self.dict(), version=version, compress="tar.gz")
|
|
||||||
|
|
||||||
async def get_release_source_download_url_zip(self, version: str):
|
async def get_release_source_download_url_zip(self, version: str) -> str:
|
||||||
url_format = await get_fastest_release_source_format()
|
return (await self.get_release_source_download_urls_zip(version))[0]
|
||||||
return url_format.format(**self.dict(), version=version, compress="zip")
|
|
||||||
|
async def get_raw_download_urls(self, path: str) -> list[str]:
|
||||||
|
url_formats = await get_fastest_raw_formats()
|
||||||
|
return [
|
||||||
|
url_format.format(**self.dict(), path=path) for url_format in url_formats
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_archive_download_urls(self) -> list[str]:
|
||||||
|
url_formats = await get_fastest_archive_formats()
|
||||||
|
return [url_format.format(**self.dict()) for url_format in url_formats]
|
||||||
|
|
||||||
|
async def get_release_source_download_urls_tgz(self, version: str) -> list[str]:
|
||||||
|
url_formats = await get_fastest_release_source_formats()
|
||||||
|
return [
|
||||||
|
url_format.format(**self.dict(), version=version, compress="tar.gz")
|
||||||
|
for url_format in url_formats
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_release_source_download_urls_zip(self, version: str) -> list[str]:
|
||||||
|
url_formats = await get_fastest_release_source_formats()
|
||||||
|
return [
|
||||||
|
url_format.format(**self.dict(), version=version, compress="zip")
|
||||||
|
for url_format in url_formats
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class APIStrategy(Protocol):
|
class APIStrategy(Protocol):
|
||||||
|
|||||||
@ -11,9 +11,9 @@ import httpx
|
|||||||
import aiofiles
|
import aiofiles
|
||||||
from retrying import retry
|
from retrying import retry
|
||||||
from playwright.async_api import Page
|
from playwright.async_api import Page
|
||||||
from httpx import Response, ConnectTimeout
|
|
||||||
from nonebot_plugin_alconna import UniMessage
|
from nonebot_plugin_alconna import UniMessage
|
||||||
from nonebot_plugin_htmlrender import get_browser
|
from nonebot_plugin_htmlrender import get_browser
|
||||||
|
from httpx import Response, ConnectTimeout, HTTPStatusError
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.configs.config import BotConfig
|
from zhenxun.configs.config import BotConfig
|
||||||
@ -33,7 +33,7 @@ class AsyncHttpx:
|
|||||||
@retry(stop_max_attempt_number=3)
|
@retry(stop_max_attempt_number=3)
|
||||||
async def get(
|
async def get(
|
||||||
cls,
|
cls,
|
||||||
url: str,
|
url: str | list[str],
|
||||||
*,
|
*,
|
||||||
params: dict[str, Any] | None = None,
|
params: dict[str, Any] | None = None,
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
@ -56,6 +56,49 @@ class AsyncHttpx:
|
|||||||
proxy: 指定代理
|
proxy: 指定代理
|
||||||
timeout: 超时时间
|
timeout: 超时时间
|
||||||
"""
|
"""
|
||||||
|
urls = [url] if isinstance(url, str) else url
|
||||||
|
return await cls._get_first_successful(
|
||||||
|
urls,
|
||||||
|
params=params,
|
||||||
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
|
verify=verify,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _get_first_successful(
|
||||||
|
cls,
|
||||||
|
urls: list[str],
|
||||||
|
**kwargs,
|
||||||
|
) -> Response:
|
||||||
|
last_exception = None
|
||||||
|
for url in urls:
|
||||||
|
try:
|
||||||
|
return await cls._get_single(url, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
if url != urls[-1]:
|
||||||
|
logger.warning(f"获取 {url} 失败, 尝试下一个")
|
||||||
|
raise last_exception or Exception("All URLs failed")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _get_single(
|
||||||
|
cls,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
cookies: dict[str, str] | None = None,
|
||||||
|
verify: bool = True,
|
||||||
|
use_proxy: bool = True,
|
||||||
|
proxy: dict[str, str] | None = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
**kwargs,
|
||||||
|
) -> Response:
|
||||||
if not headers:
|
if not headers:
|
||||||
headers = get_user_agent()
|
headers = get_user_agent()
|
||||||
_proxy = proxy if proxy else cls.proxy if use_proxy else None
|
_proxy = proxy if proxy else cls.proxy if use_proxy else None
|
||||||
@ -162,7 +205,7 @@ class AsyncHttpx:
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def download_file(
|
async def download_file(
|
||||||
cls,
|
cls,
|
||||||
url: str,
|
url: str | list[str],
|
||||||
path: str | Path,
|
path: str | Path,
|
||||||
*,
|
*,
|
||||||
params: dict[str, str] | None = None,
|
params: dict[str, str] | None = None,
|
||||||
@ -195,75 +238,79 @@ class AsyncHttpx:
|
|||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
if not stream:
|
if not isinstance(url, list):
|
||||||
|
url = [url]
|
||||||
|
for u in url:
|
||||||
try:
|
try:
|
||||||
response = await cls.get(
|
if not stream:
|
||||||
url,
|
response = await cls.get(
|
||||||
params=params,
|
u,
|
||||||
headers=headers,
|
|
||||||
cookies=cookies,
|
|
||||||
use_proxy=use_proxy,
|
|
||||||
proxy=proxy,
|
|
||||||
timeout=timeout,
|
|
||||||
follow_redirects=follow_redirects,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
content = response.content
|
|
||||||
async with aiofiles.open(path, "wb") as wf:
|
|
||||||
await wf.write(content)
|
|
||||||
logger.info(f"下载 {url} 成功.. Path:{path.absolute()}")
|
|
||||||
return True
|
|
||||||
except (TimeoutError, ConnectTimeout):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if not headers:
|
|
||||||
headers = get_user_agent()
|
|
||||||
_proxy = proxy if proxy else cls.proxy if use_proxy else None
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(
|
|
||||||
proxies=_proxy, # type: ignore
|
|
||||||
verify=verify,
|
|
||||||
) as client:
|
|
||||||
async with client.stream(
|
|
||||||
"GET",
|
|
||||||
url,
|
|
||||||
params=params,
|
params=params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy=proxy,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
follow_redirects=True,
|
follow_redirects=follow_redirects,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) as response:
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
logger.info(
|
content = response.content
|
||||||
f"开始下载 {path.name}.. Path: {path.absolute()}"
|
async with aiofiles.open(path, "wb") as wf:
|
||||||
)
|
await wf.write(content)
|
||||||
async with aiofiles.open(path, "wb") as wf:
|
logger.info(f"下载 {u} 成功.. Path:{path.absolute()}")
|
||||||
total = int(response.headers["Content-Length"])
|
return True
|
||||||
with rich.progress.Progress( # type: ignore
|
else:
|
||||||
rich.progress.TextColumn(path.name), # type: ignore
|
if not headers:
|
||||||
"[progress.percentage]{task.percentage:>3.0f}%", # type: ignore
|
headers = get_user_agent()
|
||||||
rich.progress.BarColumn(bar_width=None), # type: ignore
|
_proxy = (
|
||||||
rich.progress.DownloadColumn(), # type: ignore
|
proxy if proxy else cls.proxy if use_proxy else None
|
||||||
rich.progress.TransferSpeedColumn(), # type: ignore
|
)
|
||||||
) as progress:
|
async with httpx.AsyncClient(
|
||||||
download_task = progress.add_task(
|
proxies=_proxy, # type: ignore
|
||||||
"Download", total=total
|
verify=verify,
|
||||||
)
|
) as client:
|
||||||
async for chunk in response.aiter_bytes():
|
async with client.stream(
|
||||||
await wf.write(chunk)
|
"GET",
|
||||||
await wf.flush()
|
u,
|
||||||
progress.update(
|
params=params,
|
||||||
download_task,
|
headers=headers,
|
||||||
completed=response.num_bytes_downloaded,
|
cookies=cookies,
|
||||||
)
|
timeout=timeout,
|
||||||
|
follow_redirects=True,
|
||||||
|
**kwargs,
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"下载 {url} 成功.. Path:{path.absolute()}"
|
f"开始下载 {path.name}.. "
|
||||||
|
f"Path: {path.absolute()}"
|
||||||
)
|
)
|
||||||
return True
|
async with aiofiles.open(path, "wb") as wf:
|
||||||
except (TimeoutError, ConnectTimeout):
|
total = int(response.headers["Content-Length"])
|
||||||
pass
|
with rich.progress.Progress( # type: ignore
|
||||||
|
rich.progress.TextColumn(path.name), # type: ignore
|
||||||
|
"[progress.percentage]{task.percentage:>3.0f}%", # type: ignore
|
||||||
|
rich.progress.BarColumn(bar_width=None), # type: ignore
|
||||||
|
rich.progress.DownloadColumn(), # type: ignore
|
||||||
|
rich.progress.TransferSpeedColumn(), # type: ignore
|
||||||
|
) as progress:
|
||||||
|
download_task = progress.add_task(
|
||||||
|
"Download", total=total
|
||||||
|
)
|
||||||
|
async for chunk in response.aiter_bytes():
|
||||||
|
await wf.write(chunk)
|
||||||
|
await wf.flush()
|
||||||
|
progress.update(
|
||||||
|
download_task,
|
||||||
|
completed=response.num_bytes_downloaded,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"下载 {u} 成功.. "
|
||||||
|
f"Path:{path.absolute()}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except (TimeoutError, ConnectTimeout, HTTPStatusError):
|
||||||
|
logger.warning(f"下载 {u} 失败.. 尝试下一个地址..")
|
||||||
else:
|
else:
|
||||||
logger.error(f"下载 {url} 下载超时.. Path:{path.absolute()}")
|
logger.error(f"下载 {url} 下载超时.. Path:{path.absolute()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -273,7 +320,7 @@ class AsyncHttpx:
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def gather_download_file(
|
async def gather_download_file(
|
||||||
cls,
|
cls,
|
||||||
url_list: list[str],
|
url_list: list[str] | list[list[str]],
|
||||||
path_list: list[str | Path],
|
path_list: list[str | Path],
|
||||||
*,
|
*,
|
||||||
limit_async_number: int | None = None,
|
limit_async_number: int | None = None,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user