feat(env): 支持git更新

This commit is contained in:
HibiKier 2025-07-31 19:04:14 +08:00
parent 59d72c3b3d
commit 0e19de102a
20 changed files with 2136 additions and 100 deletions

89
.env.example Normal file
View File

@ -0,0 +1,89 @@
SUPERUSERS=[""]
COMMAND_START=[""]
SESSION_RUNNING_EXPRESSION="别急呀,小真寻要宕机了!QAQ"
NICKNAME=["真寻", "小真寻", "绪山真寻", "小寻子"]
SESSION_EXPIRE_TIMEOUT=00:00:30
ALCONNA_USE_COMMAND_START=True
# 全局图片统一使用bytes发送当真寻与协议端不在同一服务器上时为True
IMAGE_TO_BYTES = True
# 回复消息时自称
SELF_NICKNAME="小真寻"
# 官bot appid:bot账号
QBOT_ID_DATA = '{
}'
# 数据库配置
# 示例: "postgres://user:password@127.0.0.1:5432/database"
# 示例: "mysql://user:password@127.0.0.1:3306/database"
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
DB_URL = ""
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
CACHE_MODE = NONE
# REDIS配置使用REDIS替换Cache内存缓存
# REDIS地址
# REDIS_HOST = "127.0.0.1"
# REDIS端口
# REDIS_PORT = 6379
# REDIS密码
# REDIS_PASSWORD = ""
# REDIS过期时间
# REDIS_EXPIRE = 600
# 系统代理
# SYSTEM_PROXY = "http://127.0.0.1:7890"
PLATFORM_SUPERUSERS = '
{
"qq": [""],
"dodo": [""]
}
'
DRIVER=~fastapi+~httpx+~websockets
# LOG_LEVEL = DEBUG
# 服务器和端口
HOST = 127.0.0.1
PORT = 8080
# kook adapter toekn
# kaiheila_bots =[{"token": ""}]
# # discode adapter
# DISCORD_BOTS='
# [
# {
# "token": "",
# "intent": {
# "guild_messages": true,
# "direct_messages": true
# },
# "application_commands": {"*": ["*"]}
# }
# ]
# '
# DISCORD_PROXY=''
# # dodo adapter
# DODO_BOTS='
# [
# {
# "client_id": "",
# "token": ""
# }
# ]
# '
# application_commands的{"*": ["*"]}代表将全部应用命令注册为全局应用命令
# {"admin": ["123", "456"]}则代表将admin命令注册为id是123、456服务器的局部命令其余命令不注册

4
.gitignore vendored
View File

@ -144,4 +144,6 @@ log/
backup/
.idea/
resources/
.vscode/launch.json
.vscode/launch.json
./.env.dev

View File

@ -104,25 +104,16 @@ class MemberUpdateManage:
exist_member_list.append(member.id)
if data_list[0]:
try:
await GroupInfoUser.bulk_create(data_list[0], 30)
await GroupInfoUser.bulk_create(
data_list[0], 30, ignore_conflicts=True
)
logger.debug(
f"创建用户数据 {len(data_list[0])}",
"更新群组成员信息",
target=group_id,
)
except Exception as e:
logger.error(
f"批量创建用户数据失败: {e},开始进行逐个存储",
"更新群组成员信息",
)
for u in data_list[0]:
try:
await u.save()
except Exception as e:
logger.error(
f"创建用户 {u.user_name}({u.user_id}) 数据失败: {e}",
"更新群组成员信息",
)
logger.error("批量创建用户数据失败", "更新群组成员信息", e=e)
if data_list[1]:
await GroupInfoUser.bulk_update(data_list[1], ["user_name"], 30)
logger.debug(

View File

@ -32,15 +32,23 @@ __plugin_meta__ = PluginMetadata(
检查更新真寻最新版本包括了自动更新
资源文件大小一般在130mb左右除非必须更新一般仅更新代码文件
指令
检查更新 [main|release|resource|webui] ?[-r]
检查更新 [main|release|resource|webui] ?[-r] ?[-f] ?[-z] ?[-t]
main: main分支
release: 最新release
resource: 资源文件
webui: webui文件
-r: 下载资源文件一般在更新main或release时使用
-f: 强制更新一般用于更新main时使用仅git更新时有效
-s: 更新源 git ali默认使用ali
-z: 下载zip文件进行更新仅git有效
-t: 更新方式git或download默认使用git
git: 使用git pull推荐
download: 通过commit hash比较文件后下载更新仅git有效
示例:
检查更新 main
检查更新 main -r
检查更新 main -f
检查更新 release -r
检查更新 resource
检查更新 webui
@ -57,6 +65,10 @@ _matcher = on_alconna(
"检查更新",
Args["ver_type?", ["main", "release", "resource", "webui"]],
Option("-r|--resource", action=store_true, help_text="下载资源文件"),
Option("-f|--force", action=store_true, help_text="强制更新"),
Option("-s", Args["source?", ["git", "ali"]], help_text="更新源"),
Option("-z|--zip", action=store_true, help_text="下载zip文件"),
Option("-t", Args["update_type?", ["git", "download"]], help_text="更新方式"),
),
priority=1,
block=True,
@ -71,6 +83,10 @@ async def _(
session: Uninfo,
ver_type: Match[str],
resource: Query[bool] = Query("resource", False),
force: Query[bool] = Query("force", False),
source: Query[str] = Query("source", "ali"),
zip: Query[bool] = Query("zip", False),
update_type: Query[str] = Query("update_type", "git"),
):
result = ""
await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True)
@ -80,7 +96,15 @@ async def _(
logger.info("查看当前版本...", "检查更新", session=session)
await MessageUtils.build_message(result).finish()
try:
result = await UpdateManager.update(bot, session.user.id, ver_type.result)
result = await UpdateManager.update(
bot,
session.user.id,
ver_type.result,
force.result,
source.result,
zip.result,
update_type.result,
)
except Exception as e:
logger.error("版本更新失败...", "检查更新", session=session, e=e)
await MessageUtils.build_message(f"更新版本失败...e: {e}").finish()

View File

@ -1,6 +1,5 @@
import os
import shutil
import subprocess
import tarfile
import zipfile
@ -12,7 +11,9 @@ from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils
from zhenxun.utils.github_utils.models import RepoInfo
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
from zhenxun.utils.platform import PlatformUtils
from zhenxun.utils.repo_utils import AliyunRepoManager, GithubRepoManager
from .config import (
BACKUP_PATH,
@ -22,6 +23,7 @@ from .config import (
DEFAULT_GITHUB_URL,
DOWNLOAD_GZ_FILE,
DOWNLOAD_ZIP_FILE,
GIT_GITHUB_URL,
PYPROJECT_FILE,
PYPROJECT_FILE_STRING,
PYPROJECT_LOCK_FILE,
@ -35,26 +37,6 @@ from .config import (
)
def install_requirement():
requirement_path = (REQ_TXT_FILE).absolute()
if not requirement_path.exists():
logger.debug(
f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", COMMAND
)
return
try:
result = subprocess.run(
["pip", "install", "-r", str(requirement_path)],
check=True,
capture_output=True,
text=True,
)
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", COMMAND)
except subprocess.CalledProcessError as e:
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}", COMMAND, e=e)
@run_sync
def _file_handle(latest_version: str | None):
"""文件移动操作
@ -133,7 +115,6 @@ def _file_handle(latest_version: str | None):
if latest_version:
with open(VERSION_FILE, "w", encoding="utf8") as f:
f.write(f"__version__: {latest_version}")
install_requirement()
class UpdateManager:
@ -185,17 +166,7 @@ class UpdateManager:
)
@classmethod
async def update(cls, bot: Bot, user_id: str, version_type: str) -> str:
"""更新操作
参数:
bot: Bot
user_id: 用户id
version_type: 更新版本类型
返回:
str | None: 返回消息
"""
async def __zip_update(cls, version_type: str):
logger.info("开始下载真寻最新版文件....", COMMAND)
cur_version = cls.__get_version()
url = None
@ -222,11 +193,6 @@ class UpdateManager:
f"开始更新版本:{cur_version} -> {new_version} | 下载链接:{url}",
COMMAND,
)
await PlatformUtils.send_superuser(
bot,
f"检测真寻已更新,版本更新:{cur_version} -> {new_version}\n开始更新...",
user_id,
)
download_file = (
DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE
)
@ -243,6 +209,65 @@ class UpdateManager:
logger.debug("下载真寻最新版文件失败...", COMMAND)
return ""
@classmethod
async def update(
cls,
bot: Bot,
user_id: str,
version_type: str,
force: bool,
source: str,
zip: bool,
update_type: str,
) -> str:
"""更新操作
参数:
bot: Bot
user_id: 用户id
version_type: 更新版本类型
force: 是否强制更新
source: 更新源
zip: 是否下载zip文件
update_type: 更新方式
返回:
str | None: 返回消息
"""
cur_version = cls.__get_version()
await PlatformUtils.send_superuser(
bot,
f"检测真寻已更新,当前版本:{cur_version}\n开始更新...",
user_id,
)
if zip:
return await cls.__zip_update(version_type)
elif source == "git":
result = await GithubRepoManager.update(
GIT_GITHUB_URL,
BASE_PATH,
use_git=update_type == "git",
force=force,
)
else:
result = await AliyunRepoManager.update(
GIT_GITHUB_URL,
BASE_PATH,
force=force,
)
if not result.success:
return f"版本更新失败...错误: {result.error_message}"
await PlatformUtils.send_superuser(
bot, "真寻更新完成,开始安装依赖...", user_id
)
await VirtualEnvPackageManager.install_requirement(REQ_TXT_FILE)
return (
f"版本更新完成!\n"
f"版本: {cur_version} -> {result.new_version}\n"
f"变更文件个数: {len(result.changed_files)}\n"
"请重新启动真寻以完成更新!"
)
@classmethod
def __get_version(cls) -> str:
"""获取当前版本

View File

@ -2,6 +2,8 @@ from pathlib import Path
from zhenxun.configs.path_config import TEMP_PATH
GIT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot.git"
DEFAULT_GITHUB_URL = "https://github.com/HibiKier/zhenxun_bot/tree/main"
RELEASE_URL = "https://api.github.com/repos/HibiKier/zhenxun_bot/releases/latest"

View File

@ -40,14 +40,14 @@ def row_style(column: str, text: str) -> RowStyle:
return style
def install_requirement(plugin_path: Path):
async def install_requirement(plugin_path: Path):
requirement_files = ["requirement.txt", "requirements.txt"]
requirement_paths = [plugin_path / file for file in requirement_files]
if existing_requirements := next(
(path for path in requirement_paths if path.exists()), None
):
VirtualEnvPackageManager.install_requirement(existing_requirements)
await VirtualEnvPackageManager.install_requirement(existing_requirements)
class StoreManager:
@ -288,7 +288,7 @@ class StoreManager:
if not success:
raise Exception("插件依赖文件下载失败")
logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND)
install_requirement(plugin_path)
await install_requirement(plugin_path)
except ValueError as e:
logger.warning("未获取到依赖文件路径...", e=e)
return True

View File

@ -38,10 +38,11 @@ async def _(setting: Setting) -> Result:
password = Config.get_config("web-ui", "password")
if password or BotConfig.db_url:
return Result.fail("配置已存在请先删除DB_URL内容和前端密码再进行设置。")
env_file = Path() / ".env.dev"
env_file = Path() / ".env.example"
if not env_file.exists():
return Result.fail("配置文件.env.dev不存在。")
env_text = env_file.read_text(encoding="utf-8")
to_env_file = Path() / ".env.dev"
if setting.db_url:
if setting.db_url.startswith("sqlite"):
base_dir = Path().resolve()
@ -78,7 +79,7 @@ async def _(setting: Setting) -> Result:
if setting.username:
Config.set_config("web-ui", "username", setting.username)
Config.set_config("web-ui", "password", setting.password, True)
env_file.write_text(env_text, encoding="utf-8")
to_env_file.write_text(env_text, encoding="utf-8")
if BAT_FILE.exists():
for file in os.listdir(Path()):
if file.startswith(FILE_NAME):

View File

@ -229,9 +229,9 @@ async def _(payload: InstallDependenciesPayload) -> Result:
if not payload.dependencies:
return Result.fail("依赖列表不能为空")
if payload.handle_type == "install":
result = VirtualEnvPackageManager.install(payload.dependencies)
result = await VirtualEnvPackageManager.install(payload.dependencies)
else:
result = VirtualEnvPackageManager.uninstall(payload.dependencies)
result = await VirtualEnvPackageManager.uninstall(payload.dependencies)
return Result.ok(result)
except Exception as e:
logger.error(f"{router.prefix}/install_dependencies 调用错误", "WebUi", e=e)

View File

@ -1,6 +1,8 @@
import asyncio
from pathlib import Path
from urllib.parse import urlparse
import aiofiles
import nonebot
from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise
@ -86,6 +88,7 @@ def get_config() -> dict:
**MYSQL_CONFIG,
}
elif parsed.scheme == "sqlite":
Path(parsed.path).parent.mkdir(parents=True, exist_ok=True)
config["connections"]["default"] = {
"engine": "tortoise.backends.sqlite",
"credentials": {
@ -100,6 +103,15 @@ def get_config() -> dict:
async def init():
global MODELS, SCRIPT_METHOD
env_example_file = Path() / ".env.example"
env_dev_file = Path() / ".env.dev"
if not env_dev_file.exists():
async with aiofiles.open(env_example_file, encoding="utf-8") as f:
env_text = await f.read()
async with aiofiles.open(env_dev_file, "w", encoding="utf-8") as f:
await f.write(env_text)
logger.info("已生成 .env.dev 文件,请根据 .env.example 文件配置进行配置")
MODELS = db_model.models
SCRIPT_METHOD = db_model.script_method
if not BotConfig.db_url:

View File

@ -317,6 +317,20 @@ class AliyunFileInfo:
repository_id: str
"""仓库ID"""
@classmethod
async def get_client(cls) -> devops20210625Client:
"""获取阿里云客户端"""
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
return devops20210625Client(config)
@classmethod
async def get_file_content(
cls, file_path: str, repo: str, ref: str = "main"
@ -335,16 +349,8 @@ class AliyunFileInfo:
repository_id = ALIYUN_REPO_MAPPING.get(repo)
if not repository_id:
raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID")
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
client = devops20210625Client(config)
client = await cls.get_client()
request = devops_20210625_models.GetFileBlobsRequest(
organization_id=ALIYUN_ORG_ID,
@ -404,16 +410,7 @@ class AliyunFileInfo:
if not repository_id:
raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID")
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
client = devops20210625Client(config)
client = await cls.get_client()
request = devops_20210625_models.ListRepositoryTreeRequest(
organization_id=ALIYUN_ORG_ID,
@ -459,16 +456,7 @@ class AliyunFileInfo:
if not repository_id:
raise ValueError(f"未找到仓库 {repo} 对应的阿里云仓库ID")
config = open_api_models.Config(
access_key_id=Aliyun_AccessKey_ID,
access_key_secret=base64.b64decode(
Aliyun_Secret_AccessKey_encrypted.encode()
).decode(),
endpoint=ALIYUN_ENDPOINT,
region_id=ALIYUN_REGION,
)
client = devops20210625Client(config)
client = await cls.get_client()
request = devops_20210625_models.GetRepositoryCommitRequest(
organization_id=ALIYUN_ORG_ID,

View File

@ -1,3 +1,4 @@
import asyncio
from pathlib import Path
import subprocess
from subprocess import CalledProcessError
@ -36,7 +37,7 @@ class VirtualEnvPackageManager:
)
@classmethod
def install(cls, package: list[str] | str):
async def install(cls, package: list[str] | str):
"""安装依赖包
参数:
@ -49,7 +50,8 @@ class VirtualEnvPackageManager:
command.append("install")
command.append(" ".join(package))
logger.info(f"执行虚拟环境安装包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -65,7 +67,7 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def uninstall(cls, package: list[str] | str):
async def uninstall(cls, package: list[str] | str):
"""卸载依赖包
参数:
@ -79,7 +81,8 @@ class VirtualEnvPackageManager:
command.append("-y")
command.append(" ".join(package))
logger.info(f"执行虚拟环境卸载包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -95,7 +98,7 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def update(cls, package: list[str] | str):
async def update(cls, package: list[str] | str):
"""更新依赖包
参数:
@ -109,7 +112,8 @@ class VirtualEnvPackageManager:
command.append("--upgrade")
command.append(" ".join(package))
logger.info(f"执行虚拟环境更新包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -122,7 +126,7 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def install_requirement(cls, requirement_file: Path):
async def install_requirement(cls, requirement_file: Path):
"""安装依赖文件
参数:
@ -139,7 +143,8 @@ class VirtualEnvPackageManager:
command.append("-r")
command.append(str(requirement_file.absolute()))
logger.info(f"执行虚拟环境安装依赖文件指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,
@ -158,13 +163,14 @@ class VirtualEnvPackageManager:
return e.stderr
@classmethod
def list(cls) -> str:
async def list(cls) -> str:
"""列出已安装的依赖包"""
try:
command = cls.__get_command()
command.append("list")
logger.info(f"执行虚拟环境列出包指令: {command}", LOG_COMMAND)
result = subprocess.run(
result = await asyncio.to_thread(
subprocess.run,
command,
check=True,
capture_output=True,

View File

@ -0,0 +1,57 @@
"""
仓库管理工具用于操作GitHub和阿里云CodeUp项目的更新和文件下载
"""
from .aliyun_manager import AliyunCodeupManager
from .base_manager import BaseRepoManager
from .config import AliyunCodeupConfig, GithubConfig, RepoConfig
from .exceptions import (
ApiRateLimitError,
AuthenticationError,
ConfigError,
FileNotFoundError,
NetworkError,
RepoDownloadError,
RepoManagerError,
RepoNotFoundError,
RepoUpdateError,
)
from .github_manager import GithubManager
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
from .utils import check_git, filter_files, glob_to_regex, run_git_command
GithubRepoManager = GithubManager()
AliyunRepoManager = AliyunCodeupManager()
__all__ = [
"AliyunCodeupConfig",
"AliyunRepoManager",
"ApiRateLimitError",
"AuthenticationError",
"BaseRepoManager",
"ConfigError",
"FileDownloadResult",
"FileNotFoundError",
"GithubConfig",
"GithubRepoManager",
"NetworkError",
"RepoCommitInfo",
"RepoConfig",
"RepoDownloadError",
"RepoFileInfo",
"RepoManagerError",
"RepoNotFoundError",
"RepoType",
"RepoUpdateError",
"RepoUpdateResult",
"check_git",
"filter_files",
"glob_to_regex",
"run_git_command",
]

View File

@ -0,0 +1,549 @@
"""
阿里云CodeUp仓库管理工具
"""
import asyncio
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from aiocache import cached
from zhenxun.services.log import logger
from zhenxun.utils.github_utils.models import AliyunFileInfo
from .base_manager import BaseRepoManager
from .config import LOG_COMMAND, RepoConfig
from .exceptions import (
AuthenticationError,
FileNotFoundError,
RepoDownloadError,
RepoNotFoundError,
RepoUpdateError,
)
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
class AliyunCodeupManager(BaseRepoManager):
"""阿里云CodeUp仓库管理工具"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化阿里云CodeUp仓库管理工具
Args:
config: 配置如果为None则使用默认配置
"""
super().__init__(config)
self._client = None
async def update_repo(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新阿里云CodeUp仓库
Args:
repo_url: 仓库URL或名称
local_path: 本地保存路径
branch: 分支名称
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
Returns:
RepoUpdateResult: 更新结果
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_url = repo_url.split("/")[-1].replace(".git", "")
# 获取仓库最新提交ID
newest_commit = await self._get_newest_commit(repo_url, branch)
# 创建结果对象
result = RepoUpdateResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=self.config.aliyun_codeup.organization_id,
old_version="", # 将在后面更新
new_version=newest_commit,
)
old_version = await self.read_version_file(local_path)
result.old_version = old_version
# 如果版本相同,则无需更新
if old_version == newest_commit:
result.success = True
logger.debug(
f"仓库 {repo_url.split('/')[-1].replace('.git', '')}"
f" 已是最新版本: {newest_commit[:8]}",
LOG_COMMAND,
)
return result
# 确保本地目录存在
local_path.mkdir(parents=True, exist_ok=True)
# 获取仓库名称从URL中提取
repo_name = repo_url.split("/")[-1].replace(".git", "")
# 获取变更的文件列表
changed_files = await self._get_changed_files(
repo_name, old_version or None, newest_commit
)
# 过滤文件
if include_patterns or exclude_patterns:
from .utils import filter_files
changed_files = filter_files(
changed_files, include_patterns, exclude_patterns
)
result.changed_files = changed_files
# 下载变更的文件
for file_path in changed_files:
try:
local_file_path = local_path / file_path
await self._download_file(
repo_name, file_path, local_file_path, newest_commit
)
except Exception as e:
logger.error(f"下载文件 {file_path} 失败", LOG_COMMAND, e=e)
# 更新版本文件
await self.write_version_file(local_path, newest_commit)
result.success = True
return result
except RepoUpdateError as e:
logger.error(f"更新仓库失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/")[-1].replace(".git", "")
return RepoUpdateResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
owner=self.config.aliyun_codeup.organization_id,
old_version="",
new_version="",
error_message=str(e),
)
except Exception as e:
logger.error(f"更新仓库失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/")[-1].replace(".git", "")
return RepoUpdateResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
owner=self.config.aliyun_codeup.organization_id,
old_version="",
new_version="",
error_message=str(e),
)
async def download_file(
self,
repo_url: str,
file_path: str,
local_path: Path,
branch: str = "main",
) -> FileDownloadResult:
"""
从阿里云CodeUp下载单个文件
Args:
repo_url: 仓库URL或名称
file_path: 文件在仓库中的路径
local_path: 本地保存路径
branch: 分支名称
Returns:
FileDownloadResult: 下载结果
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_identifier = repo_url.split("/")[-1].replace(".git", "")
# 创建结果对象
result = FileDownloadResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=self.config.aliyun_codeup.organization_id,
file_path=file_path,
local_path=str(local_path),
version=branch,
)
# 确保本地目录存在
Path(local_path).parent.mkdir(parents=True, exist_ok=True)
# 下载文件
file_size = await self._download_file(
repo_identifier, file_path, local_path, branch
)
result.success = True
result.file_size = file_size
return result
except RepoDownloadError as e:
logger.error(f"下载文件失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/")[-1].replace(".git", "")
return FileDownloadResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
owner=self.config.aliyun_codeup.organization_id,
file_path=file_path,
local_path=str(local_path),
version=branch,
error_message=str(e),
)
except Exception as e:
logger.error(f"下载文件失败: {e}")
# 从URL中提取仓库名称
repo_name = repo_url.split("/")[-1].replace(".git", "")
return FileDownloadResult(
repo_type=RepoType.ALIYUN,
repo_name=repo_name,
owner=self.config.aliyun_codeup.organization_id,
file_path=file_path,
local_path=str(local_path),
version=branch,
error_message=str(e),
)
async def get_file_list(
self,
repo_url: str,
dir_path: str = "",
branch: str = "main",
recursive: bool = False,
) -> list[RepoFileInfo]:
"""
获取仓库文件列表
Args:
repo_url: 仓库URL或名称
dir_path: 目录路径空字符串表示仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录
Returns:
list[RepoFileInfo]: 文件信息列表
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_identifier = repo_url.split("/")[-1].replace(".git", "")
# 获取文件列表
search_type = "RECURSIVE" if recursive else "DIRECT"
tree_list = await AliyunFileInfo.get_repository_tree(
repo_identifier, dir_path, branch, search_type
)
result = []
for tree in tree_list:
# 跳过非当前目录的文件(如果不是递归模式)
if (
not recursive
and tree.path != dir_path
and "/" in tree.path.replace(dir_path, "", 1).strip("/")
):
continue
file_info = RepoFileInfo(
path=tree.path,
is_dir=tree.type == "tree",
)
result.append(file_info)
return result
except Exception as e:
logger.error(f"获取文件列表失败: {e}")
return []
async def get_commit_info(
self, repo_url: str, commit_id: str
) -> RepoCommitInfo | None:
"""
获取提交信息
Args:
repo_url: 仓库URL或名称
commit_id: 提交ID
Returns:
Optional[RepoCommitInfo]: 提交信息如果获取失败则返回None
"""
try:
# 检查配置
self._check_config()
# 获取仓库名称从URL中提取
repo_identifier = repo_url.split("/")[-1].replace(".git", "")
# 获取提交信息
# 注意这里假设AliyunFileInfo有get_commit_info方法如果没有需要实现
commit_data = await self._get_commit_info(repo_identifier, commit_id)
if not commit_data:
return None
# 解析提交信息
id_value = commit_data.get("id", commit_id)
message_value = commit_data.get("message", "")
author_value = commit_data.get("author_name", "")
date_value = commit_data.get(
"authored_date", datetime.now().isoformat()
).replace("Z", "+00:00")
commit_info = RepoCommitInfo(
commit_id=id_value,
message=message_value,
author=author_value,
commit_time=datetime.fromisoformat(date_value),
changed_files=[], # 阿里云API可能没有直接提供变更文件列表
)
return commit_info
except Exception as e:
logger.error(f"获取提交信息失败: {e}")
return None
def _check_config(self):
"""检查配置"""
if not self.config.aliyun_codeup.access_key_id:
raise AuthenticationError("阿里云CodeUp")
if not self.config.aliyun_codeup.access_key_secret:
raise AuthenticationError("阿里云CodeUp")
if not self.config.aliyun_codeup.organization_id:
raise AuthenticationError("阿里云CodeUp")
async def _get_newest_commit(self, repo_name: str, branch: str) -> str:
"""
获取仓库最新提交ID
Args:
repo_name: 仓库名称
branch: 分支名称
Returns:
str: 提交ID
"""
try:
newest_commit = await AliyunFileInfo.get_newest_commit(repo_name, branch)
if not newest_commit:
raise RepoNotFoundError(repo_name)
return newest_commit
except Exception as e:
logger.error(f"获取最新提交ID失败: {e}")
raise RepoUpdateError(f"获取最新提交ID失败: {e}")
async def _get_commit_info(self, repo_name: str, commit_id: str) -> dict:
"""
获取提交信息
Args:
repo_name: 仓库名称
commit_id: 提交ID
Returns:
dict: 提交信息
"""
# 这里需要实现从阿里云获取提交信息的逻辑
# 由于AliyunFileInfo可能没有get_commit_info方法这里提供一个简单的实现
try:
# 这里应该是调用阿里云API获取提交信息
# 这里只是一个示例实际上需要根据阿里云API实现
return {
"id": commit_id,
"message": "提交信息",
"author_name": "作者",
"authored_date": datetime.now().isoformat(),
}
except Exception as e:
logger.error(f"获取提交信息失败: {e}")
return {}
@cached(ttl=3600)
async def _get_changed_files(
self, repo_name: str, old_commit: str | None, new_commit: str
) -> list[str]:
"""
获取两个提交之间变更的文件列表
Args:
repo_name: 仓库名称
old_commit: 旧提交ID如果为None则获取所有文件
new_commit: 新提交ID
Returns:
list[str]: 变更的文件列表
"""
if not old_commit:
# 如果没有旧提交,则获取仓库中的所有文件
tree_list = await AliyunFileInfo.get_repository_tree(
repo_name, "", new_commit, "RECURSIVE"
)
return [tree.path for tree in tree_list if tree.type == "blob"]
# 获取两个提交之间的差异
try:
# 这里需要实现从阿里云获取提交差异的逻辑
# 由于AliyunFileInfo可能没有get_commit_diff_files方法 这里提供一个简单的实现
# 实际上应该调用阿里云API获取提交差异
files = [] # 这里应该是从阿里云API获取的文件列表
return files
except Exception as e:
logger.error(f"获取提交差异失败: {e}")
raise RepoUpdateError(f"获取提交差异失败: {e}")
async def update_via_git(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
force: bool = False,
*,
repo_type: RepoType | None = None,
owner: str | None = None,
prepare_repo_url: Callable[[str], str] | None = None,
) -> RepoUpdateResult:
"""
通过Git命令直接更新仓库
参数:
repo_url: 仓库名称
local_path: 本地仓库路径
branch: 分支名称
force: 是否强制拉取
返回:
RepoUpdateResult: 更新结果
"""
# 定义预处理函数构建阿里云CodeUp的URL
def prepare_aliyun_url(repo_name: str) -> str:
# 构建仓库URL
# 阿里云CodeUp的仓库URL格式通常为
# https://codeup.aliyun.com/{organization_id}/{repo_name}.git
url = f"https://codeup.aliyun.com/{self.config.aliyun_codeup.organization_id}/{repo_name}.git"
# 添加访问令牌
if self.config.aliyun_codeup.rdc_access_token_encrypted:
token = self.config.aliyun_codeup.rdc_access_token_encrypted
url = url.replace("https://", f"https://oauth2:{token}@")
return url
# 调用基类的update_via_git方法
return await super().update_via_git(
repo_url=repo_url,
local_path=local_path,
branch=branch,
force=force,
repo_type=RepoType.ALIYUN,
owner=self.config.aliyun_codeup.organization_id,
prepare_repo_url=prepare_aliyun_url,
)
async def update(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
use_git: bool = True,
force: bool = False,
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库可选择使用Git命令或API方式
参数:
repo_url: 仓库名称
local_path: 本地保存路径
branch: 分支名称
use_git: 是否使用Git命令更新
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
if use_git:
return await self.update_via_git(repo_url, local_path, branch, force)
else:
return await self.update_repo(
repo_url, local_path, branch, include_patterns, exclude_patterns
)
async def _download_file(
self, repo_name: str, file_path: str, local_path: Path, ref: str
) -> int:
"""
下载文件
Args:
repo_name: 仓库名称
file_path: 文件在仓库中的路径
local_path: 本地保存路径
ref: 分支/标签/提交ID
Returns:
int: 文件大小字节
"""
# 确保目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 获取文件内容
for retry in range(self.config.aliyun_codeup.download_retry + 1):
try:
content = await AliyunFileInfo.get_file_content(
file_path, repo_name, ref
)
if content is None:
raise FileNotFoundError(file_path, repo_name)
# 保存文件
return await self.save_file_content(content.encode("utf-8"), local_path)
except FileNotFoundError as e:
# 这些错误不需要重试
raise e
except Exception as e:
if retry < self.config.aliyun_codeup.download_retry:
logger.warning(f"下载文件失败,将重试: {e}")
await asyncio.sleep(1)
continue
raise RepoDownloadError(f"下载文件失败: {e}")
raise RepoDownloadError("下载文件失败: 超过最大重试次数")

View File

@ -0,0 +1,406 @@
"""
仓库管理工具的基础管理器
"""
from abc import ABC, abstractmethod
from pathlib import Path
import aiofiles
from zhenxun.services.log import logger
from .config import RepoConfig
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
from .utils import check_git, filter_files, run_git_command
class BaseRepoManager(ABC):
"""仓库管理工具基础类"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化仓库管理工具
参数:
config: 配置如果为None则使用默认配置
"""
self.config = config or RepoConfig.get_instance()
self.config.ensure_dirs()
@abstractmethod
async def update_repo(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库
参数:
repo_url: 仓库URL或名称
local_path: 本地保存路径
branch: 分支名称
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
pass
@abstractmethod
async def download_file(
self,
repo_url: str,
file_path: str,
local_path: Path,
branch: str = "main",
) -> FileDownloadResult:
"""
下载单个文件
参数:
repo_url: 仓库URL或名称
file_path: 文件在仓库中的路径
local_path: 本地保存路径
branch: 分支名称
返回:
FileDownloadResult: 下载结果
"""
pass
@abstractmethod
async def get_file_list(
self,
repo_url: str,
dir_path: str = "",
branch: str = "main",
recursive: bool = False,
) -> list[RepoFileInfo]:
"""
获取仓库文件列表
参数:
repo_url: 仓库URL或名称
dir_path: 目录路径空字符串表示仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录
返回:
List[RepoFileInfo]: 文件信息列表
"""
pass
@abstractmethod
async def get_commit_info(
self, repo_url: str, commit_id: str
) -> RepoCommitInfo | None:
"""
获取提交信息
参数:
repo_url: 仓库URL或名称
commit_id: 提交ID
返回:
Optional[RepoCommitInfo]: 提交信息如果获取失败则返回None
"""
pass
async def save_file_content(self, content: bytes, local_path: Path) -> int:
"""
保存文件内容
参数:
content: 文件内容
local_path: 本地保存路径
返回:
int: 文件大小字节
"""
# 确保目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 保存文件
async with aiofiles.open(local_path, "wb") as f:
await f.write(content)
return len(content)
async def read_version_file(self, local_dir: Path) -> str:
"""
读取版本文件
参数:
local_dir: 本地目录
返回:
str: 版本号
"""
version_file = local_dir / "__version__"
if not version_file.exists():
return ""
try:
async with aiofiles.open(version_file) as f:
return (await f.read()).strip()
except Exception as e:
logger.error(f"读取版本文件失败: {e}")
return ""
async def write_version_file(self, local_dir: Path, version: str) -> bool:
"""
写入版本文件
参数:
local_dir: 本地目录
version: 版本号
返回:
bool: 是否成功
"""
version_file = local_dir / "__version__"
try:
version_bb = "vNone"
async with aiofiles.open(version_file) as rf:
if text := await rf.read():
version_bb = text.strip().split("-")[0]
async with aiofiles.open(version_file, "w") as f:
await f.write(f"{version_bb}-{version[:6]}")
return True
except Exception as e:
logger.error(f"写入版本文件失败: {e}")
return False
def filter_files(
self,
files: list[str],
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> list[str]:
"""
过滤文件列表
参数:
files: 文件列表
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
List[str]: 过滤后的文件列表
"""
return filter_files(files, include_patterns, exclude_patterns)
async def update_via_git(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
force: bool = False,
*,
repo_type: RepoType | None = None,
owner="",
prepare_repo_url=None,
) -> RepoUpdateResult:
"""
通过Git命令直接更新仓库
参数:
repo_url: 仓库URL或名称
local_path: 本地仓库路径
branch: 分支名称
force: 是否强制拉取
repo_type: 仓库类型
owner: 仓库拥有者
prepare_repo_url: 预处理仓库URL的函数
返回:
RepoUpdateResult: 更新结果
"""
from .models import RepoType
try:
# 创建结果对象
result = RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB, # 默认使用GitHub类型
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version="",
new_version="",
)
# 检查Git是否可用
if not await check_git():
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version="",
new_version="",
error_message="Git命令不可用",
)
# 预处理仓库URL
if prepare_repo_url:
repo_url = prepare_repo_url(repo_url)
# 检查本地目录是否存在
if not local_path.exists():
# 如果不存在,则克隆仓库
logger.info(f"克隆仓库 {repo_url}{local_path}")
success, stdout, stderr = await run_git_command(
f"clone -b {branch} {repo_url} {local_path}"
)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version="",
new_version="",
error_message=f"克隆仓库失败: {stderr}",
)
# 获取当前提交ID
success, new_version, _ = await run_git_command(
"rev-parse HEAD", cwd=local_path
)
result.new_version = new_version.strip()
result.success = True
return result
# 如果目录存在检查是否是Git仓库
success, _, _ = await run_git_command(
"rev-parse --is-inside-work-tree", cwd=local_path
)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version="",
new_version="",
error_message=f"{local_path} 不是一个Git仓库",
)
# 获取当前提交ID作为旧版本
success, old_version, _ = await run_git_command(
"rev-parse HEAD", cwd=local_path
)
result.old_version = old_version.strip()
# 获取当前远程URL
success, remote_url, _ = await run_git_command(
"config --get remote.origin.url", cwd=local_path
)
# 如果远程URL不匹配则更新它
remote_url = remote_url.strip()
if success and repo_url not in remote_url and remote_url not in repo_url:
logger.info(f"更新远程URL: {remote_url} -> {repo_url}")
await run_git_command(
f"remote set-url origin {repo_url}", cwd=local_path
)
# 获取远程更新
logger.info("获取远程更新")
success, _, stderr = await run_git_command("fetch origin", cwd=local_path)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version=old_version.strip(),
new_version="",
error_message=f"获取远程更新失败: {stderr}",
)
# 获取当前分支
success, current_branch, _ = await run_git_command(
"rev-parse --abbrev-ref HEAD", cwd=local_path
)
current_branch = current_branch.strip()
# 如果当前分支不是目标分支,则切换分支
if success and current_branch != branch:
logger.info(f"切换分支: {current_branch} -> {branch}")
success, _, stderr = await run_git_command(
f"checkout {branch}", cwd=local_path
)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version=old_version.strip(),
new_version="",
error_message=f"切换分支失败: {stderr}",
)
# 拉取最新代码
logger.info("拉取最新代码")
pull_cmd = f"pull origin {branch}"
if force:
pull_cmd = f"pull --force origin {branch}"
logger.info("使用强制拉取模式")
success, _, stderr = await run_git_command(pull_cmd, cwd=local_path)
if not success:
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version=old_version.strip(),
new_version="",
error_message=f"拉取最新代码失败: {stderr}",
)
# 获取更新后的提交ID
success, new_version, _ = await run_git_command(
"rev-parse HEAD", cwd=local_path
)
result.new_version = new_version.strip()
# 如果版本相同,则无需更新
if old_version.strip() == new_version.strip():
logger.info(f"仓库 {repo_url} 已是最新版本: {new_version.strip()}")
result.success = True
return result
# 获取变更的文件列表
success, changed_files_output, _ = await run_git_command(
f"diff --name-only {old_version.strip()} {new_version.strip()}",
cwd=local_path,
)
if success:
changed_files = [
line.strip()
for line in changed_files_output.splitlines()
if line.strip()
]
result.changed_files = changed_files
logger.info(f"变更的文件列表: {changed_files}")
result.success = True
return result
except Exception as e:
logger.error(f"Git更新失败: {e}")
return RepoUpdateResult(
repo_type=repo_type or RepoType.GITHUB,
repo_name=repo_url.split("/")[-1].replace(".git", ""),
owner=owner or "",
old_version="",
new_version="",
error_message=str(e),
)

View File

@ -0,0 +1,75 @@
"""
仓库管理工具的配置模块
"""
from dataclasses import dataclass, field
from pathlib import Path
from zhenxun.configs.path_config import TEMP_PATH
LOG_COMMAND = "RepoUtils"
@dataclass
class GithubConfig:
"""GitHub配置"""
# API超时时间
api_timeout: int = 30
# 下载超时时间(秒)
download_timeout: int = 60
# 下载重试次数
download_retry: int = 3
# 代理配置
proxy: str | None = None
@dataclass
class AliyunCodeupConfig:
"""阿里云CodeUp配置"""
# 访问密钥ID
access_key_id: str = "LTAI5tNmf7KaTAuhcvRobAQs"
# 访问密钥密钥
access_key_secret: str = "NmJ3d2VNRU1MREY0T1RtRnBqMlFqdlBxN3pMUk1j"
# 组织ID
organization_id: str = "67a361cf556e6cdab537117a"
# RDC Access Token
rdc_access_token_encrypted: str = (
"cHQtYXp0allnQWpub0FYZWpqZm1RWGtneHk0XzBlMmYzZTZmLWQwOWItNDE4Mi1iZWUx"
"LTQ1ZTFkYjI0NGRlMg=="
)
# 区域
region: str = "cn-hangzhou"
# 端点
endpoint: str = "devops.cn-hangzhou.aliyuncs.com"
# 下载重试次数
download_retry: int = 3
@dataclass
class RepoConfig:
"""仓库管理工具配置"""
# 缓存目录
cache_dir: Path = TEMP_PATH / "repo_cache"
# GitHub配置
github: GithubConfig = field(default_factory=GithubConfig)
# 阿里云CodeUp配置
aliyun_codeup: AliyunCodeupConfig = field(default_factory=AliyunCodeupConfig)
# 单例实例
_instance = None
@classmethod
def get_instance(cls) -> "RepoConfig":
"""获取单例实例"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def ensure_dirs(self):
"""确保目录存在"""
self.cache_dir.mkdir(parents=True, exist_ok=True)

View File

@ -0,0 +1,68 @@
"""
仓库管理工具的异常类
"""
class RepoManagerError(Exception):
"""仓库管理工具异常基类"""
def __init__(self, message: str, repo_name: str | None = None):
self.message = message
self.repo_name = repo_name
super().__init__(self.message)
class RepoUpdateError(RepoManagerError):
"""仓库更新异常"""
def __init__(self, message: str, repo_name: str | None = None):
super().__init__(f"仓库更新失败: {message}", repo_name)
class RepoDownloadError(RepoManagerError):
"""仓库下载异常"""
def __init__(self, message: str, repo_name: str | None = None):
super().__init__(f"文件下载失败: {message}", repo_name)
class RepoNotFoundError(RepoManagerError):
"""仓库不存在异常"""
def __init__(self, repo_name: str):
super().__init__(f"仓库不存在: {repo_name}", repo_name)
class FileNotFoundError(RepoManagerError):
"""文件不存在异常"""
def __init__(self, file_path: str, repo_name: str | None = None):
super().__init__(f"文件不存在: {file_path}", repo_name)
class AuthenticationError(RepoManagerError):
"""认证异常"""
def __init__(self, repo_type: str):
super().__init__(f"认证失败: {repo_type}")
class ApiRateLimitError(RepoManagerError):
"""API速率限制异常"""
def __init__(self, repo_type: str):
super().__init__(f"API速率限制: {repo_type}")
class NetworkError(RepoManagerError):
"""网络异常"""
def __init__(self, message: str):
super().__init__(f"网络错误: {message}")
class ConfigError(RepoManagerError):
"""配置异常"""
def __init__(self, message: str):
super().__init__(f"配置错误: {message}")

View File

@ -0,0 +1,529 @@
"""
GitHub仓库管理工具
"""
import asyncio
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from aiocache import cached
from zhenxun.services.log import logger
from zhenxun.utils.github_utils import GithubUtils, RepoInfo
from zhenxun.utils.http_utils import AsyncHttpx
from .base_manager import BaseRepoManager
from .config import LOG_COMMAND, RepoConfig
from .exceptions import (
ApiRateLimitError,
FileNotFoundError,
NetworkError,
RepoDownloadError,
RepoNotFoundError,
RepoUpdateError,
)
from .models import (
FileDownloadResult,
RepoCommitInfo,
RepoFileInfo,
RepoType,
RepoUpdateResult,
)
class GithubManager(BaseRepoManager):
"""GitHub仓库管理工具"""
def __init__(self, config: RepoConfig | None = None):
"""
初始化GitHub仓库管理工具
参数:
config: 配置如果为None则使用默认配置
"""
super().__init__(config)
async def update_repo(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新GitHub仓库
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地保存路径
branch: 分支名称
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
repo_info.branch = branch
# 获取仓库最新提交ID
newest_commit = await self._get_newest_commit(
repo_info.owner, repo_info.repo, branch
)
# 创建结果对象
result = RepoUpdateResult(
repo_type=RepoType.GITHUB,
repo_name=repo_info.repo,
owner=repo_info.owner,
old_version="", # 将在后面更新
new_version=newest_commit,
)
old_version = await self.read_version_file(local_path)
old_version = old_version.split("-")[-1]
result.old_version = old_version
# 如果版本相同,则无需更新
if newest_commit in old_version:
result.success = True
logger.debug(
f"仓库 {repo_info.repo} 已是最新版本: {newest_commit}",
LOG_COMMAND,
)
return result
# 确保本地目录存在
local_path.mkdir(parents=True, exist_ok=True)
# 获取变更的文件列表
changed_files = await self._get_changed_files(
repo_info.owner,
repo_info.repo,
old_version or None,
newest_commit,
)
# 过滤文件
if include_patterns or exclude_patterns:
from .utils import filter_files
changed_files = filter_files(
changed_files, include_patterns, exclude_patterns
)
result.changed_files = changed_files
# 下载变更的文件
for file_path in changed_files:
try:
local_file_path = local_path / file_path
await self._download_file(repo_info, file_path, local_file_path)
except Exception as e:
logger.error(f"下载文件 {file_path} 失败", LOG_COMMAND, e=e)
# 更新版本文件
await self.write_version_file(local_path, newest_commit)
result.success = True
return result
except RepoUpdateError as e:
logger.error("更新仓库失败", LOG_COMMAND, e=e)
return RepoUpdateResult(
repo_type=RepoType.GITHUB,
repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url,
owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown",
old_version="",
new_version="",
error_message=str(e),
)
except Exception as e:
logger.error("更新仓库失败", LOG_COMMAND, e=e)
return RepoUpdateResult(
repo_type=RepoType.GITHUB,
repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url,
owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown",
old_version="",
new_version="",
error_message=str(e),
)
async def download_file(
self,
repo_url: str,
file_path: str,
local_path: Path,
branch: str = "main",
) -> FileDownloadResult:
"""
从GitHub下载单个文件
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
file_path: 文件在仓库中的路径
local_path: 本地保存路径
branch: 分支名称
返回:
FileDownloadResult: 下载结果
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
repo_info.branch = branch
# 创建结果对象
result = FileDownloadResult(
repo_type=RepoType.GITHUB,
repo_name=repo_info.repo,
owner=repo_info.owner,
file_path=file_path,
local_path=str(local_path),
version=branch,
)
# 确保本地目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 下载文件
file_size = await self._download_file(repo_info, file_path, local_path)
result.success = True
result.file_size = file_size
return result
except RepoDownloadError as e:
logger.error("下载文件失败", LOG_COMMAND, e=e)
return FileDownloadResult(
repo_type=RepoType.GITHUB,
repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url,
owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown",
file_path=file_path,
local_path=str(local_path),
version=branch,
error_message=str(e),
)
except Exception as e:
logger.error("下载文件失败", LOG_COMMAND, e=e)
return FileDownloadResult(
repo_type=RepoType.GITHUB,
repo_name=repo_url.split("/")[-1] if "/" in repo_url else repo_url,
owner=repo_url.split("/")[-2] if "/" in repo_url else "unknown",
file_path=file_path,
local_path=str(local_path),
version=branch,
error_message=str(e),
)
async def get_file_list(
self,
repo_url: str,
dir_path: str = "",
branch: str = "main",
recursive: bool = False,
) -> list[RepoFileInfo]:
"""
获取仓库文件列表
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
dir_path: 目录路径空字符串表示仓库根目录
branch: 分支名称
recursive: 是否递归获取子目录
返回:
list[RepoFileInfo]: 文件信息列表
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
repo_info.branch = branch
# 获取文件列表
for api in GithubUtils.iter_api_strategies():
try:
await api.parse_repo_info(repo_info)
files = api.get_files(dir_path, True)
result = []
for file_path in files:
# 跳过非当前目录的文件(如果不是递归模式)
if not recursive and "/" in file_path.replace(
dir_path, "", 1
).strip("/"):
continue
is_dir = file_path.endswith("/")
file_info = RepoFileInfo(path=file_path, is_dir=is_dir)
result.append(file_info)
return result
except Exception as e:
logger.debug("使用API策略获取文件列表失败", LOG_COMMAND, e=e)
continue
raise RepoNotFoundError(repo_url)
except Exception as e:
logger.error("获取文件列表失败", LOG_COMMAND, e=e)
return []
async def get_commit_info(
self, repo_url: str, commit_id: str
) -> RepoCommitInfo | None:
"""
获取提交信息
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
commit_id: 提交ID
返回:
Optional[RepoCommitInfo]: 提交信息如果获取失败则返回None
"""
try:
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
# 构建API URL
api_url = f"https://api.github.com/repos/{repo_info.owner}/{repo_info.repo}/commits/{commit_id}"
# 发送请求
resp = await AsyncHttpx.get(
api_url,
timeout=self.config.github.api_timeout,
proxy=self.config.github.proxy,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise RepoNotFoundError(f"{repo_info.owner}/{repo_info.repo}")
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
data = resp.json()
return RepoCommitInfo(
commit_id=data["sha"],
message=data["commit"]["message"],
author=data["commit"]["author"]["name"],
commit_time=datetime.fromisoformat(
data["commit"]["author"]["date"].replace("Z", "+00:00")
),
changed_files=[file["filename"] for file in data.get("files", [])],
)
except Exception as e:
logger.error("获取提交信息失败", LOG_COMMAND, e=e)
return None
async def _get_newest_commit(self, owner: str, repo: str, branch: str) -> str:
"""
获取仓库最新提交ID
参数:
owner: 仓库拥有者
repo: 仓库名称
branch: 分支名称
返回:
str: 提交ID
"""
try:
newest_commit = await RepoInfo.get_newest_commit(owner, repo, branch)
if not newest_commit:
raise RepoNotFoundError(f"{owner}/{repo}")
return newest_commit
except Exception as e:
logger.error("获取最新提交ID失败", LOG_COMMAND, e=e)
raise RepoUpdateError(f"获取最新提交ID失败: {e}")
@cached(ttl=3600)
async def _get_changed_files(
self, owner: str, repo: str, old_commit: str | None, new_commit: str
) -> list[str]:
"""
获取两个提交之间变更的文件列表
参数:
owner: 仓库拥有者
repo: 仓库名称
old_commit: 旧提交ID如果为None则获取所有文件
new_commit: 新提交ID
返回:
list[str]: 变更的文件列表
"""
if not old_commit:
# 如果没有旧提交,则获取仓库中的所有文件
api_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{new_commit}?recursive=1"
resp = await AsyncHttpx.get(
api_url,
timeout=self.config.github.api_timeout,
proxy=self.config.github.proxy,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise RepoNotFoundError(f"{owner}/{repo}")
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
data = resp.json()
return [
item["path"] for item in data.get("tree", []) if item["type"] == "blob"
]
# 如果有旧提交,则获取两个提交之间的差异
api_url = f"https://api.github.com/repos/{owner}/{repo}/compare/{old_commit}...{new_commit}"
resp = await AsyncHttpx.get(
api_url,
timeout=self.config.github.api_timeout,
proxy=self.config.github.proxy,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise RepoNotFoundError(f"{owner}/{repo}")
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
data = resp.json()
return [file["filename"] for file in data.get("files", [])]
async def update_via_git(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
force: bool = False,
*,
repo_type: RepoType | None = None,
owner: str | None = None,
prepare_repo_url: Callable[[str], str] | None = None,
) -> RepoUpdateResult:
"""
通过Git命令直接更新仓库
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地仓库路径
branch: 分支名称
force: 是否强制拉取
返回:
RepoUpdateResult: 更新结果
"""
# 解析仓库URL
repo_info = GithubUtils.parse_github_url(repo_url)
# 调用基类的update_via_git方法
return await super().update_via_git(
repo_url=repo_url,
local_path=local_path,
branch=branch,
force=force,
repo_type=RepoType.GITHUB,
owner=repo_info.owner,
)
async def update(
self,
repo_url: str,
local_path: Path,
branch: str = "main",
use_git: bool = True,
force: bool = False,
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> RepoUpdateResult:
"""
更新仓库可选择使用Git命令或API方式
参数:
repo_url: 仓库URL格式为 https://github.com/owner/repo
local_path: 本地保存路径
branch: 分支名称
use_git: 是否使用Git命令更新
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
RepoUpdateResult: 更新结果
"""
if use_git:
return await self.update_via_git(repo_url, local_path, branch, force)
else:
return await self.update_repo(
repo_url, local_path, branch, include_patterns, exclude_patterns
)
async def _download_file(
self, repo_info: RepoInfo, file_path: str, local_path: Path
) -> int:
"""
下载文件
参数:
repo_info: 仓库信息
file_path: 文件在仓库中的路径
local_path: 本地保存路径
返回:
int: 文件大小字节
"""
# 确保目录存在
local_path.parent.mkdir(parents=True, exist_ok=True)
# 获取下载URL
download_url = await repo_info.get_raw_download_url(file_path)
# 下载文件
for retry in range(self.config.github.download_retry + 1):
try:
resp = await AsyncHttpx.get(
download_url,
timeout=self.config.github.download_timeout,
)
if resp.status_code == 403 and "rate limit" in resp.text.lower():
raise ApiRateLimitError("GitHub")
if resp.status_code != 200:
if resp.status_code == 404:
raise FileNotFoundError(
file_path, f"{repo_info.owner}/{repo_info.repo}"
)
if retry < self.config.github.download_retry:
await asyncio.sleep(1)
continue
raise NetworkError(f"HTTP {resp.status_code}: {resp.text}")
# 保存文件
return await self.save_file_content(resp.content, local_path)
except (ApiRateLimitError, FileNotFoundError) as e:
# 这些错误不需要重试
raise e
except Exception as e:
if retry < self.config.github.download_retry:
logger.warning("下载文件失败,将重试", LOG_COMMAND, e=e)
await asyncio.sleep(1)
continue
raise RepoDownloadError("下载文件失败")
raise RepoDownloadError("下载文件失败: 超过最大重试次数")

View File

@ -0,0 +1,90 @@
"""
仓库管理工具的数据模型
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
class RepoType(str, Enum):
"""仓库类型"""
GITHUB = "github"
ALIYUN = "aliyun"
@dataclass
class RepoFileInfo:
"""仓库文件信息"""
# 文件路径
path: str
# 是否是目录
is_dir: bool
# 文件大小(字节)
size: int | None = None
# 最后修改时间
last_modified: datetime | None = None
@dataclass
class RepoCommitInfo:
"""仓库提交信息"""
# 提交ID
commit_id: str
# 提交消息
message: str
# 作者
author: str
# 提交时间
commit_time: datetime
# 变更的文件列表
changed_files: list[str] = field(default_factory=list)
@dataclass
class RepoUpdateResult:
"""仓库更新结果"""
# 仓库类型
repo_type: RepoType
# 仓库名称
repo_name: str
# 仓库拥有者
owner: str
# 旧版本
old_version: str
# 新版本
new_version: str
# 是否成功
success: bool = False
# 错误消息
error_message: str = ""
# 变更的文件列表
changed_files: list[str] = field(default_factory=list)
@dataclass
class FileDownloadResult:
"""文件下载结果"""
# 仓库类型
repo_type: RepoType
# 仓库名称
repo_name: str
# 仓库拥有者
owner: str
# 文件路径
file_path: str
# 本地路径
local_path: str
# 版本
version: str
# 是否成功
success: bool = False
# 文件大小(字节)
file_size: int = 0
# 错误消息
error_message: str = ""

View File

@ -0,0 +1,122 @@
"""
仓库管理工具的工具函数
"""
import asyncio
from pathlib import Path
import re
from zhenxun.services.log import logger
from .config import LOG_COMMAND
async def check_git() -> bool:
"""
检查环境变量中是否存在 git
返回:
bool: 是否存在git命令
"""
try:
process = await asyncio.create_subprocess_shell(
"git --version",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, _ = await process.communicate()
return bool(stdout)
except Exception as e:
logger.error("检查git命令失败", LOG_COMMAND, e=e)
return False
async def run_git_command(
command: str, cwd: Path | None = None
) -> tuple[bool, str, str]:
"""
运行git命令
参数:
command: 命令
cwd: 工作目录
返回:
tuple[bool, str, str]: (是否成功, 标准输出, 标准错误)
"""
try:
full_command = f"git {command}"
process = await asyncio.create_subprocess_shell(
full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
)
stdout_bytes, stderr_bytes = await process.communicate()
stdout = stdout_bytes.decode("utf-8").strip()
stderr = stderr_bytes.decode("utf-8").strip()
return process.returncode == 0, stdout, stderr
except Exception as e:
logger.error(f"运行git命令失败: {command}, 错误: {e}")
return False, "", str(e)
def glob_to_regex(pattern: str) -> str:
"""
将glob模式转换为正则表达式
参数:
pattern: glob模式 "*.py"
返回:
str: 正则表达式
"""
# 转义特殊字符
regex = re.escape(pattern)
# 替换glob通配符
regex = regex.replace(r"\*\*", ".*") # ** -> .*
regex = regex.replace(r"\*", "[^/]*") # * -> [^/]*
regex = regex.replace(r"\?", "[^/]") # ? -> [^/]
# 添加开始和结束标记
regex = f"^{regex}$"
return regex
def filter_files(
files: list[str],
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> list[str]:
"""
过滤文件列表
参数:
files: 文件列表
include_patterns: 包含的文件模式列表 ["*.py", "docs/*.md"]
exclude_patterns: 排除的文件模式列表 ["__pycache__/*", "*.pyc"]
返回:
list[str]: 过滤后的文件列表
"""
result = files.copy()
# 应用包含模式
if include_patterns:
included = []
for pattern in include_patterns:
regex_pattern = glob_to_regex(pattern)
included.extend(file for file in result if re.match(regex_pattern, file))
result = included
# 应用排除模式
if exclude_patterns:
for pattern in exclude_patterns:
regex_pattern = glob_to_regex(pattern)
result = [file for file in result if not re.match(regex_pattern, file)]
return result