Code optimization

HibiKier 2025-06-16 10:55:53 +08:00 committed by GitHub
parent 3c28184593
commit 890e3564d6

@@ -1,5 +1,4 @@
 import asyncio
-from asyncio.exceptions import TimeoutError
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from pathlib import Path
@@ -7,236 +6,200 @@ import time
 from typing import Any, ClassVar, Literal

 import aiofiles
-from anyio import EndOfStream
 import httpx
-from httpx import ConnectTimeout, HTTPStatusError, Response
+from httpx import AsyncHTTPTransport, HTTPStatusError, Response
 from nonebot_plugin_alconna import UniMessage
 from nonebot_plugin_htmlrender import get_browser
+from packaging.version import parse as parse_version
 from playwright.async_api import Page
 from retrying import retry
-import rich
+from rich.progress import (
+    BarColumn,
+    DownloadColumn,
+    Progress,
+    TextColumn,
+    TransferSpeedColumn,
+)

 from zhenxun.configs.config import BotConfig
 from zhenxun.services.log import logger
 from zhenxun.utils.message import MessageUtils
 from zhenxun.utils.user_agent import get_user_agent

-# from .browser import get_browser
-def get_async_client(proxies: dict[str, str] | None = None, **kwargs):
-    transport = httpx.AsyncHTTPTransport(verify=False)
-    try:
-        return httpx.AsyncClient(proxies=proxies, transport=transport, **kwargs)
-    except TypeError:
-        return httpx.AsyncClient(
-            mounts={
-                k: v
-                for k, v in {
-                    "http://": httpx.AsyncHTTPTransport(proxy=proxies.get("http"))
-                    if proxies
-                    else None,
-                    "https://": httpx.AsyncHTTPTransport(proxy=proxies.get("https"))
-                    if proxies
-                    else None,
-                }.items()
-                if v is not None
-            },
-            transport=transport,
-            **kwargs,
-        )
+def get_async_client(
+    proxies: dict[str, str] | None = None, verify: bool = False, **kwargs
+) -> httpx.AsyncClient:
+    check_httpx_version = parse_version(httpx.__version__) >= parse_version("0.28.0")
+    transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
+    if not check_httpx_version:
+        return httpx.AsyncClient(proxies=proxies, transport=transport, **kwargs)  # type: ignore
+    proxy_str = None
+    if proxies:
+        proxy_str = proxies.get("http://") or proxies.get("https://")
+        if not proxy_str:
+            logger.warning(f"No valid proxy URL in proxy dict {proxies}; proxy ignored.")
+    return httpx.AsyncClient(proxy=proxy_str, transport=transport, **kwargs)  # type: ignore
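
The version gate reflects a real httpx API break: 0.28 removed the `proxies=` argument in favor of a single `proxy=` URL (with `mounts=` for per-scheme routing), which is why the helper degrades the scheme-keyed dict to one URL on newer versions. A minimal sketch of the same branch, with a hypothetical local proxy:

import httpx
from packaging.version import parse as parse_version

proxies = {"http://": "http://127.0.0.1:7890", "https://": "http://127.0.0.1:7890"}  # hypothetical

if parse_version(httpx.__version__) >= parse_version("0.28.0"):
    # 0.28+: a single proxy URL applies to all traffic
    client = httpx.AsyncClient(proxy=proxies.get("http://") or proxies.get("https://"))
else:
    # pre-0.28: a scheme-keyed proxies dict
    client = httpx.AsyncClient(proxies=proxies)  # type: ignore[call-arg]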
 class AsyncHttpx:
-    proxy: ClassVar[dict[str, str | None]] = {
-        "http://": BotConfig.system_proxy,
-        "https://": BotConfig.system_proxy,
-    }
+    default_proxy: ClassVar[dict[str, str] | None] = (
+        {
+            "http://": BotConfig.system_proxy,
+            "https://": BotConfig.system_proxy,
+        }
+        if BotConfig.system_proxy
+        else None
+    )
+    @classmethod
+    @asynccontextmanager
+    async def _create_client(
+        cls,
+        *,
+        use_proxy: bool = True,
+        proxy: dict[str, str] | None = None,
+        headers: dict[str, str] | None = None,
+        verify: bool = False,
+        **kwargs,
+    ) -> AsyncGenerator[httpx.AsyncClient, None]:
+        """Create a private, pre-configured httpx.AsyncClient context manager.
+
+        Notes:
+            Used internally to build clients in one place, centralizing proxy
+            and header handling and reducing duplication.
+
+        Args:
+            use_proxy: whether to use the default proxy defined on the class
+            proxy: an explicitly specified proxy, overriding the default
+            headers: custom headers to merge into the client
+            verify: whether to verify SSL certificates
+            **kwargs: all other arguments forwarded to httpx.AsyncClient
+
+        Returns:
+            AsyncGenerator[httpx.AsyncClient, None]: the client generator
+        """
+        proxies_to_use = proxy or (cls.default_proxy if use_proxy else None)
+        final_headers = get_user_agent()
+        if headers:
+            final_headers.update(headers)
+        async with get_async_client(
+            proxies=proxies_to_use, verify=verify, headers=final_headers, **kwargs
+        ) as client:
+            yield client
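
Proxy resolution in `_create_client` has a fixed precedence: an explicit `proxy` argument wins, then the class-level `default_proxy` when `use_proxy=True`, otherwise no proxy. A condensed illustration of just that rule:

def resolve_proxy(proxy, default_proxy, use_proxy: bool):
    # explicit proxy wins; otherwise fall back to the class default if allowed
    return proxy or (default_proxy if use_proxy else None)

assert resolve_proxy({"http://": "http://a:1"}, {"http://": "http://b:1"}, True) == {"http://": "http://a:1"}
assert resolve_proxy(None, {"http://": "http://b:1"}, True) == {"http://": "http://b:1"}
assert resolve_proxy(None, {"http://": "http://b:1"}, False) is None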
     @classmethod
     @retry(stop_max_attempt_number=3)
     async def get(
         cls,
         url: str | list[str],
         *,
         params: dict[str, Any] | None = None,
         headers: dict[str, str] | None = None,
         cookies: dict[str, str] | None = None,
         verify: bool = True,
         use_proxy: bool = True,
         proxy: dict[str, str] | None = None,
         timeout: int = 30,  # noqa: ASYNC109
         check_status_code: int | None = None,
         **kwargs,
-    ) -> Response:
-        """Get
+    ) -> Response:  # sourcery skip: use-assigned-variable
+        """Send a GET request and return the first successful response.
+
+        Notes:
+            A high-level wrapper around httpx.get that adds multi-URL
+            fallback, automatic retries, and unified proxy management.
+            Given a list of URLs, each is tried in turn until one succeeds.

         Args:
-            url: url
-            params: params
-            headers: request headers
-            cookies: cookies
-            verify: verify
-            use_proxy: use the default proxy
-            proxy: an explicitly specified proxy
-            timeout: timeout
-            check_status_code: expected status code
+            url: a single request URL or a list of URLs
+            check_status_code: (optional) if given, the response status code
+                is checked against it and an exception is raised on mismatch
+            **kwargs: all other arguments forwarded to httpx.get
+                (e.g. `params`, `headers`, `timeout`)

         Returns:
             Response: Response
         """
         urls = [url] if isinstance(url, str) else url
         return await cls._get_first_successful(
             urls,
             params=params,
             headers=headers,
             cookies=cookies,
             verify=verify,
             use_proxy=use_proxy,
             proxy=proxy,
             timeout=timeout,
             check_status_code=check_status_code,
             **kwargs,
         )
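
A usage sketch of the fallback behavior (hypothetical mirror URLs and an assumed import path): the first URL to return a response wins, and `check_status_code` turns an unexpected status into an exception that moves the loop on to the next candidate.

from zhenxun.utils.http_utils import AsyncHttpx  # assumed module path

async def fetch_notes() -> str:
    resp = await AsyncHttpx.get(
        ["https://mirror-a.example.com/notes", "https://mirror-b.example.com/notes"],
        check_status_code=200,
        timeout=15,
    )
    return resp.text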
     @classmethod
     async def _get_first_successful(
         cls,
         urls: list[str],
         check_status_code: int | None = None,
         **kwargs,
     ) -> Response:
         last_exception = None
-        for url in urls:
+        for current_url in urls:
             try:
-                logger.info(f"Fetching {url}..")
-                response = await cls._get_single(url, **kwargs)
+                logger.info(f"Fetching {current_url}..")
+                async with cls._create_client(**kwargs) as client:
+                    # forward only the arguments that client.get() supports
+                    get_kwargs = {
+                        k: v
+                        for k, v in kwargs.items()
+                        if k not in ["use_proxy", "proxy", "verify", "headers"]
+                    }
+                    response = await client.get(current_url, **get_kwargs)
                 if check_status_code and response.status_code != check_status_code:
-                    status_code = response.status_code
-                    raise Exception(f"Unexpected status code: {status_code}!={check_status_code}")
+                    raise HTTPStatusError(
+                        f"Unexpected status code: {response.status_code}!={check_status_code}",
+                        request=response.request,
+                        response=response,
+                    )
                 return response
             except Exception as e:
                 last_exception = e
-                if url != urls[-1]:
-                    logger.warning(f"Failed to fetch {url}, trying the next URL")
+                if current_url != urls[-1]:
+                    logger.warning(f"Failed to fetch {current_url}, trying the next URL", e=e)
         raise last_exception or Exception("All URLs failed")
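
The dict comprehension above implements a simple argument split: client-level keys are consumed by `_create_client`, and everything else is forwarded to the request call. The same pattern in isolation:

CLIENT_KEYS = {"use_proxy", "proxy", "verify", "headers"}

def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
    # client-level options vs. per-request options
    client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEYS}
    request_kwargs = {k: v for k, v in kwargs.items() if k not in CLIENT_KEYS}
    return client_kwargs, request_kwargs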
-    @classmethod
-    async def _get_single(
-        cls,
-        url: str,
-        *,
-        params: dict[str, Any] | None = None,
-        headers: dict[str, str] | None = None,
-        cookies: dict[str, str] | None = None,
-        verify: bool = True,
-        use_proxy: bool = True,
-        proxy: dict[str, str] | None = None,
-        timeout: int = 30,  # noqa: ASYNC109
-        **kwargs,
-    ) -> Response:
-        if not headers:
-            headers = get_user_agent()
-        _proxy = proxy or (cls.proxy if use_proxy else None)
-        async with get_async_client(proxies=_proxy, verify=verify) as client:  # type: ignore
-            return await client.get(
-                url,
-                params=params,
-                headers=headers,
-                cookies=cookies,
-                timeout=timeout,
-                **kwargs,
-            )
     @classmethod
-    async def head(
-        cls,
-        url: str,
-        *,
-        params: dict[str, Any] | None = None,
-        headers: dict[str, str] | None = None,
-        cookies: dict[str, str] | None = None,
-        verify: bool = True,
-        use_proxy: bool = True,
-        proxy: dict[str, str] | None = None,
-        timeout: int = 30,  # noqa: ASYNC109
-        **kwargs,
-    ) -> Response:
-        """Get
-
-        Args:
-            url: url
-            params: params
-            headers: request headers
-            cookies: cookies
-            verify: verify
-            use_proxy: use the default proxy
-            proxy: an explicitly specified proxy
-            timeout: timeout
-        """
-        if not headers:
-            headers = get_user_agent()
-        _proxy = proxy or (cls.proxy if use_proxy else None)
-        async with get_async_client(proxies=_proxy, verify=verify) as client:  # type: ignore
-            return await client.head(
-                url,
-                params=params,
-                headers=headers,
-                cookies=cookies,
-                timeout=timeout,
-                **kwargs,
-            )
+    async def head(cls, url: str, **kwargs) -> Response:
+        """Send a HEAD request.
+
+        Notes:
+            A wrapper around httpx.head, typically used to inspect a
+            resource's metadata (such as size and type).
+
+        Args:
+            url: the URL to request
+            **kwargs: all other arguments forwarded to httpx.head
+                (e.g. `headers`, `timeout`, `allow_redirects`)
+
+        Returns:
+            Response: Response
+        """
+        async with cls._create_client(**kwargs) as client:
+            head_kwargs = {
+                k: v
+                for k, v in kwargs.items()
+                if k not in ["use_proxy", "proxy", "verify"]
+            }
+            return await client.head(url, **head_kwargs)

     @classmethod
-    async def post(
-        cls,
-        url: str,
-        *,
-        data: dict[str, Any] | None = None,
-        content: Any = None,
-        files: Any = None,
-        verify: bool = True,
-        use_proxy: bool = True,
-        proxy: dict[str, str] | None = None,
-        json: dict[str, Any] | None = None,
-        params: dict[str, str] | None = None,
-        headers: dict[str, str] | None = None,
-        cookies: dict[str, str] | None = None,
-        timeout: int = 30,  # noqa: ASYNC109
-        **kwargs,
-    ) -> Response:
-        """
-        Notes:
-            Post
-
-        Args:
-            url: url
-            data: data
-            content: content
-            files: files
-            use_proxy: whether to use the default proxy
-            proxy: an explicitly specified proxy
-            json: json
-            params: params
-            headers: request headers
-            cookies: cookies
-            timeout: timeout
-
-        Returns:
-            Response: Response
-        """
-        if not headers:
-            headers = get_user_agent()
-        _proxy = proxy or (cls.proxy if use_proxy else None)
-        async with get_async_client(proxies=_proxy, verify=verify) as client:  # type: ignore
-            return await client.post(
-                url,
-                content=content,
-                data=data,
-                files=files,
-                json=json,
-                params=params,
-                headers=headers,
-                cookies=cookies,
-                timeout=timeout,
-                **kwargs,
-            )
+    async def post(cls, url: str, **kwargs) -> Response:
+        """Send a POST request.
+
+        Notes:
+            A wrapper around httpx.post that provides unified proxy and
+            client management.
+
+        Args:
+            url: the URL to request
+            **kwargs: all other arguments forwarded to httpx.post
+                (e.g. `data`, `json`, `content`)
+
+        Returns:
+            Response: Response
+        """
+        async with cls._create_client(**kwargs) as client:
+            post_kwargs = {
+                k: v
+                for k, v in kwargs.items()
+                if k not in ["use_proxy", "proxy", "verify"]
+            }
+            return await client.post(url, **post_kwargs)
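
With the explicit parameters collapsed into `**kwargs`, call sites now pass httpx's native arguments straight through; a sketch against a hypothetical endpoint:

async def report_status() -> int:
    resp = await AsyncHttpx.post(
        "https://api.example.com/status",  # hypothetical
        json={"ok": True},
        use_proxy=False,  # consumed by _create_client, stripped before client.post
    )
    return resp.status_code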
     @classmethod
     async def get_content(cls, url: str, **kwargs) -> bytes:
         """Fetch the binary content of the given URL.

         Notes:
             A convenience method, equivalent to calling get() and then
             reading the .content attribute.

         Args:
             url: the URL to request
             **kwargs: all arguments forwarded to the get() method

         Returns:
             bytes: the response body as a raw byte stream
         """
         res = await cls.get(url, **kwargs)
         return res.content
@@ -246,195 +209,132 @@ class AsyncHttpx:
         url: str | list[str],
         path: str | Path,
         *,
-        params: dict[str, str] | None = None,
-        verify: bool = True,
-        use_proxy: bool = True,
-        proxy: dict[str, str] | None = None,
-        headers: dict[str, str] | None = None,
-        cookies: dict[str, str] | None = None,
-        timeout: int = 30,  # noqa: ASYNC109
         stream: bool = False,
-        follow_redirects: bool = True,
         **kwargs,
     ) -> bool:
-        """Download a file
+        """Download a file to the given path.
+
+        Notes:
+            Supports multi-URL fallback and streamed downloads with a
+            progress bar.

         Args:
-            url: url
-            path: storage path
-            params: params
-            verify: verify
-            use_proxy: use the proxy
-            proxy: an explicitly specified proxy
-            headers: request headers
-            cookies: cookies
-            timeout: timeout
-            stream: whether to stream the download (chunked writes plus a
-                progress bar, suited to large files)
+            url: a single file URL or a list of fallback URLs
+            path: the local path to save the file to
+            stream: (optional) whether to stream the download, suited to
+                large files; defaults to False
+            **kwargs: all other arguments forwarded to the get() method or
+                httpx.stream()

         Returns:
             bool: whether the download succeeded
         """
-        if isinstance(path, str):
-            path = Path(path)
+        path = Path(path)
         path.parent.mkdir(parents=True, exist_ok=True)
-        try:
-            for _ in range(3):
-                if not isinstance(url, list):
-                    url = [url]
-                for u in url:
-                    try:
-                        if not stream:
-                            response = await cls.get(
-                                u,
-                                params=params,
-                                headers=headers,
-                                cookies=cookies,
-                                use_proxy=use_proxy,
-                                proxy=proxy,
-                                timeout=timeout,
-                                follow_redirects=follow_redirects,
-                                **kwargs,
-                            )
-                            content = response.content
-                            async with aiofiles.open(path, "wb") as wf:
-                                await wf.write(content)
-                            logger.info(f"Downloaded {u} successfully.. Path: {path.absolute()}")
-                        else:
-                            if not headers:
-                                headers = get_user_agent()
-                            _proxy = proxy or (cls.proxy if use_proxy else None)
-                            async with get_async_client(
-                                proxies=_proxy,  # type: ignore
-                                verify=verify,
-                            ) as client:
-                                async with client.stream(
-                                    "GET",
-                                    u,
-                                    params=params,
-                                    headers=headers,
-                                    cookies=cookies,
-                                    timeout=timeout,
-                                    follow_redirects=True,
-                                    **kwargs,
-                                ) as response:
-                                    response.raise_for_status()
-                                    logger.info(
-                                        f"Starting download of {path.name}.. "
-                                        f"Url: {u}.. "
-                                        f"Path: {path.absolute()}"
-                                    )
-                                    async with aiofiles.open(path, "wb") as wf:
-                                        total = int(
-                                            response.headers.get("Content-Length", 0)
-                                        )
-                                        with rich.progress.Progress(  # type: ignore
-                                            rich.progress.TextColumn(path.name),  # type: ignore
-                                            "[progress.percentage]{task.percentage:>3.0f}%",  # type: ignore
-                                            rich.progress.BarColumn(bar_width=None),  # type: ignore
-                                            rich.progress.DownloadColumn(),  # type: ignore
-                                            rich.progress.TransferSpeedColumn(),  # type: ignore
-                                        ) as progress:
-                                            download_task = progress.add_task(
-                                                "Download",
-                                                total=total or None,
-                                            )
-                                            async for chunk in response.aiter_bytes():
-                                                await wf.write(chunk)
-                                                await wf.flush()
-                                                progress.update(
-                                                    download_task,
-                                                    completed=response.num_bytes_downloaded,
-                                                )
-                                    logger.info(
-                                        f"Downloaded {u} successfully.. Path: {path.absolute()}"
-                                    )
-                        return True
-                    except (TimeoutError, ConnectTimeout, HTTPStatusError):
-                        logger.warning(f"Failed to download {u}.. trying the next address..")
-                    except EndOfStream as e:
-                        logger.warning(
-                            f"EndOfStream while downloading {url} Path: {path.absolute()}", e=e
-                        )
-                        if path.exists():
-                            return True
-            logger.error(f"Download of {url} timed out.. Path: {path.absolute()}")
-        except Exception as e:
-            logger.error(f"Error downloading {url} Path: {path.absolute()}", e=e)
+        urls = [url] if isinstance(url, str) else url
+        for current_url in urls:
+            try:
+                if not stream:
+                    response = await cls.get(current_url, **kwargs)
+                    response.raise_for_status()
+                    async with aiofiles.open(path, "wb") as f:
+                        await f.write(response.content)
+                else:
+                    async with cls._create_client(**kwargs) as client:
+                        stream_kwargs = {
+                            k: v
+                            for k, v in kwargs.items()
+                            if k not in ["use_proxy", "proxy", "verify"]
+                        }
+                        async with client.stream(
+                            "GET", current_url, **stream_kwargs
+                        ) as response:
+                            response.raise_for_status()
+                            total = int(response.headers.get("Content-Length", 0))
+                            with Progress(
+                                TextColumn(path.name),
+                                "[progress.percentage]{task.percentage:>3.0f}%",
+                                BarColumn(bar_width=None),
+                                DownloadColumn(),
+                                TransferSpeedColumn(),
+                            ) as progress:
+                                task_id = progress.add_task("Download", total=total)
+                                async with aiofiles.open(path, "wb") as f:
+                                    async for chunk in response.aiter_bytes():
+                                        await f.write(chunk)
+                                        progress.update(task_id, advance=len(chunk))
+                logger.info(f"Downloaded {current_url} successfully -> {path.absolute()}")
+                return True
+            except Exception as e:
+                logger.warning(f"Failed to download {current_url}, trying the next URL. Error: {e}")
+        logger.error(f"All URLs {urls} failed to download -> {path.absolute()}")
         return False
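
A usage sketch (hypothetical mirrors and path): `stream=True` switches to chunked writes with a rich progress bar, keeping memory flat for large files.

from pathlib import Path

async def fetch_model() -> bool:
    return await AsyncHttpx.download_file(
        ["https://cdn-a.example.com/model.bin", "https://cdn-b.example.com/model.bin"],
        Path("data/model.bin"),
        stream=True,
        timeout=120,
    )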
     @classmethod
     async def gather_download_file(
         cls,
-        url_list: list[str] | list[list[str]],
+        url_list: list[str],
         path_list: list[str | Path],
         *,
-        limit_async_number: int | None = None,
-        params: dict[str, str] | None = None,
-        use_proxy: bool = True,
-        proxy: dict[str, str] | None = None,
-        headers: dict[str, str] | None = None,
-        cookies: dict[str, str] | None = None,
-        timeout: int = 30,  # noqa: ASYNC109
+        limit_async_number: int = 5,
         **kwargs,
     ) -> list[bool]:
-        """Download files in concurrent batches
+        """Download multiple files concurrently.
+
+        Notes:
+            Uses an asyncio.Semaphore to cap the number of concurrent
+            requests.

         Args:
-            url_list: list of URLs
-            path_list: list of storage paths
-            limit_async_number: cap on simultaneous requests
-            params: params
-            use_proxy: use the proxy
-            proxy: an explicitly specified proxy
-            headers: request headers
-            cookies: cookies
-            timeout: timeout
+            url_list: the list of all file URLs
+            path_list: the save paths corresponding to url_list
+            limit_async_number: (optional) maximum concurrent downloads;
+                defaults to 5
+            **kwargs: all other arguments forwarded to download_file()

         Returns:
             bool: whether the download succeeded
         """
-        if n := len(url_list) != len(path_list):
-            raise UrlPathNumberNotEqual(
-                f"URL count and path count differ! URLs: {len(url_list)}, paths: {len(path_list)}"
-            )
-        if limit_async_number and n > limit_async_number:
-            m = float(n) / limit_async_number
-            x = 0
-            j = limit_async_number
-            _split_url_list = []
-            _split_path_list = []
-            for _ in range(int(m)):
-                _split_url_list.append(url_list[x:j])
-                _split_path_list.append(path_list[x:j])
-                x += limit_async_number
-                j += limit_async_number
-            if int(m) < m:
-                _split_url_list.append(url_list[j:])
-                _split_path_list.append(path_list[j:])
-        else:
-            _split_url_list = [url_list]
-            _split_path_list = [path_list]
-        tasks = []
-        result_ = []
-        for x, y in zip(_split_url_list, _split_path_list):
-            tasks.extend(
-                asyncio.create_task(
-                    cls.download_file(
-                        url,
-                        path,
-                        params=params,
-                        headers=headers,
-                        cookies=cookies,
-                        use_proxy=use_proxy,
-                        timeout=timeout,
-                        proxy=proxy,
-                        **kwargs,
-                    )
-                )
-                for url, path in zip(x, y)
-            )
-            _x = await asyncio.gather(*tasks)
-            result_ = result_ + list(_x)
-            tasks.clear()
-        return result_
+        if len(url_list) != len(path_list):
+            raise ValueError("The URL list and path list must be the same length")
+
+        semaphore = asyncio.Semaphore(limit_async_number)
+
+        async def _download_with_semaphore(url: str, path: str | Path):
+            async with semaphore:
+                return await cls.download_file(url, path, **kwargs)
+
+        tasks = [
+            _download_with_semaphore(url, path)
+            for url, path in zip(url_list, path_list)
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        final_results = []
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                logger.error(f"Error during concurrent download of {url_list[i]}: {result}")
+                final_results.append(False)
+            else:
+                final_results.append(result)  # type: ignore
+        return final_results
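
Compared with the old pre-chunking approach, the semaphore simply caps in-flight downloads while `asyncio.gather` preserves result order; a sketch with hypothetical URLs:

from pathlib import Path

async def fetch_gallery() -> list[str]:
    urls = [f"https://img.example.com/{i}.jpg" for i in range(20)]
    paths = [Path(f"images/{i}.jpg") for i in range(20)]
    results = await AsyncHttpx.gather_download_file(urls, paths, limit_async_number=5)
    return [u for u, ok in zip(urls, results) if not ok]  # URLs that failed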
     @classmethod
     async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]:
         """Test mirrors and return them ordered by speed.

         Notes:
             Sends concurrent HEAD requests to probe each URL's availability
             and response time, then sorts by response speed.

         Args:
             url_list: the list of mirror URLs to test

         Returns:
             list[str]: all reachable URLs, ordered from fastest to slowest
         """
         assert url_list

         async def head_mirror(client: type[AsyncHttpx], url: str) -> dict[str, Any]:
@@ -503,7 +403,7 @@ class AsyncPlaywright:
         wait_until: (
             Literal["domcontentloaded", "load", "networkidle"] | None
         ) = "networkidle",
-        timeout: float | None = None,  # noqa: ASYNC109
+        timeout: float | None = None,
         type_: Literal["jpeg", "png"] | None = None,
         user_agent: str | None = None,
         cookies: list[dict[str, Any]] | dict[str, Any] | None = None,