Merge pull request #1567 from Copaan/dev

代码优化
This commit is contained in:
HibiKier 2024-08-22 20:01:25 +08:00 committed by GitHub
commit 7e7e88557d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 61 additions and 27 deletions

View File

@ -2,6 +2,7 @@ import os
import shutil import shutil
import tarfile import tarfile
import zipfile import zipfile
import subprocess
from pathlib import Path from pathlib import Path
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -20,12 +21,24 @@ from .config import (
MAIN_URL, MAIN_URL,
PYPROJECT_FILE, PYPROJECT_FILE,
PYPROJECT_LOCK_FILE, PYPROJECT_LOCK_FILE,
REQ_TXT_FILE,
RELEASE_URL, RELEASE_URL,
REPLACE_FOLDERS, REPLACE_FOLDERS,
TMP_PATH, TMP_PATH,
VERSION_FILE, VERSION_FILE,
) )
def install_requirement():
requirement_path = (Path() / "requirements.txt").absolute()
if not requirement_path.exists():
logger.debug(f"没有找到zhenxun的requirement.txt,目标路径为{requirement_path}", "插件管理")
return
try:
result = subprocess.run(["pip", "install", "-r", str(requirement_path)], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
logger.debug(f"成功安装真寻依赖,日志:\n{result.stdout}", "插件管理")
except subprocess.CalledProcessError as e:
logger.error(f"安装真寻依赖失败,错误:\n{e.stderr}")
@run_sync @run_sync
def _file_handle(latest_version: str | None): def _file_handle(latest_version: str | None):
@ -49,20 +62,27 @@ def _file_handle(latest_version: str | None):
) )
_pyproject = download_file_path / "pyproject.toml" _pyproject = download_file_path / "pyproject.toml"
_lock_file = download_file_path / "poetry.lock" _lock_file = download_file_path / "poetry.lock"
_req_file = download_file_path / "requirements.txt"
extract_path = download_file_path / "zhenxun" extract_path = download_file_path / "zhenxun"
target_path = BASE_PATH target_path = BASE_PATH
if PYPROJECT_FILE.exists(): if PYPROJECT_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", "检查更新") logger.debug(f"移除备份文件: {PYPROJECT_FILE}", "检查更新")
shutil.move(PYPROJECT_FILE, BACKUP_PATH / "pyproject.toml") shutil.move(PYPROJECT_FILE, BACKUP_PATH / "pyproject.toml")
if PYPROJECT_LOCK_FILE.exists(): if PYPROJECT_LOCK_FILE.exists():
logger.debug(f"移除备份文件: {PYPROJECT_FILE}", "检查更新") logger.debug(f"移除备份文件: {PYPROJECT_LOCK_FILE}", "检查更新")
shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / "poetry.lock") shutil.move(PYPROJECT_LOCK_FILE, BACKUP_PATH / "poetry.lock")
if REQ_TXT_FILE.exists():
logger.debug(f"移除备份文件: {REQ_TXT_FILE}", "检查更新")
shutil.move(REQ_TXT_FILE, BACKUP_PATH / "requirements.txt")
if _pyproject.exists(): if _pyproject.exists():
logger.debug("移动文件: pyproject.toml", "检查更新") logger.debug("移动文件: pyproject.toml", "检查更新")
shutil.move(_pyproject, Path() / "pyproject.toml") shutil.move(_pyproject, Path() / "pyproject.toml")
if _lock_file.exists(): if _lock_file.exists():
logger.debug("移动文件: pyproject.toml", "检查更新") logger.debug("移动文件: poetry.lock", "检查更新")
shutil.move(_lock_file, Path() / "poetry.lock") shutil.move(_lock_file, Path() / "poetry.lock")
if _req_file.exists():
logger.debug("移动文件: requirements.txt", "检查更新")
shutil.move(_req_file, Path() / "requirements.txt")
for folder in REPLACE_FOLDERS: for folder in REPLACE_FOLDERS:
"""移动指定文件夹""" """移动指定文件夹"""
_dir = BASE_PATH / folder _dir = BASE_PATH / folder
@ -98,8 +118,7 @@ def _file_handle(latest_version: str | None):
if latest_version: if latest_version:
with open(VERSION_FILE, "w", encoding="utf8") as f: with open(VERSION_FILE, "w", encoding="utf8") as f:
f.write(f"__version__: {latest_version}") f.write(f"__version__: {latest_version}")
os.system(f"poetry run pip install -r requirements.txt") install_requirement()
class UpdateManage: class UpdateManage:
@ -130,12 +149,17 @@ class UpdateManage:
""" """
logger.info(f"开始下载真寻最新版文件....", "检查更新") logger.info(f"开始下载真寻最新版文件....", "检查更新")
cur_version = cls.__get_version() cur_version = cls.__get_version()
new_version = "main"
url = MAIN_URL
if version_type == "dev": if version_type == "dev":
url = DEV_URL url = DEV_URL
new_version = "dev" new_version = await cls.__get_version_from_branch("dev")
if version_type == "release": if new_version:
new_version = new_version.split(":")[-1].strip()
elif version_type == "main":
url = MAIN_URL
new_version = await cls.__get_version_from_branch("main")
if new_version:
new_version = new_version.split(":")[-1].strip()
elif version_type == "release":
data = await cls.__get_latest_data() data = await cls.__get_latest_data()
if not data: if not data:
return "获取更新版本失败..." return "获取更新版本失败..."
@ -161,8 +185,6 @@ class UpdateManage:
) )
if await AsyncHttpx.download_file(url, download_file): if await AsyncHttpx.download_file(url, download_file):
logger.debug("下载真寻最新版文件完成...", "检查更新") logger.debug("下载真寻最新版文件完成...", "检查更新")
if version_type != "release":
new_version = None
await _file_handle(new_version) await _file_handle(new_version)
return f"版本更新完成\n版本: {cur_version} -> {new_version}\n请重新启动真寻以完成更新!" return f"版本更新完成\n版本: {cur_version} -> {new_version}\n请重新启动真寻以完成更新!"
else: else:
@ -200,3 +222,22 @@ class UpdateManage:
except Exception as e: except Exception as e:
logger.error(f"检查更新真寻获取版本失败", e=e) logger.error(f"检查更新真寻获取版本失败", e=e)
return {} return {}
@classmethod
async def __get_version_from_branch(cls, branch: str) -> str:
"""从指定分支获取版本号
参数:
branch: 分支名称
返回:
str: 版本号
"""
version_url = f"https://raw.githubusercontent.com/HibiKier/zhenxun_bot/{branch}/__version__"
try:
res = await AsyncHttpx.get(version_url)
if res.status_code == 200:
return res.text.strip()
except Exception as e:
logger.error(f"获取 {branch} 分支版本失败", e=e)
return "未知版本"

View File

@ -11,6 +11,7 @@ VERSION_FILE = Path() / "__version__"
PYPROJECT_FILE = Path() / "pyproject.toml" PYPROJECT_FILE = Path() / "pyproject.toml"
PYPROJECT_LOCK_FILE = Path() / "poetry.lock" PYPROJECT_LOCK_FILE = Path() / "poetry.lock"
REQ_TXT_FILE = Path() / "requirements.txt"
BASE_PATH = Path() / "zhenxun" BASE_PATH = Path() / "zhenxun"

View File

@ -87,15 +87,14 @@ async def _handle_setting(
) )
if extra_data.tasks: if extra_data.tasks:
for task in extra_data.tasks: for task in extra_data.tasks:
task_list.append( task_list.append((task.create_status,
TaskInfo( TaskInfo(
module=task.module, module=task.module,
name=task.name, name=task.name,
status=task.status, status=task.status,
run_time=task.run_time, run_time=task.run_time,
default_status=task.default_status,
) )
) ))
@driver.on_startup @driver.on_startup
@ -164,15 +163,16 @@ async def _():
} }
create_list = [] create_list = []
update_list = [] update_list = []
for task in task_list: for status, task in task_list:
if task.module not in module_dict: if task.module not in module_dict:
create_list.append(task) create_list.append((status, task))
else: else:
task.id = module_dict[task.module] task.id = module_dict[task.module]
update_list.append(task) update_list.append(task)
if create_list: if create_list:
await TaskInfo.bulk_create(create_list, 10) _create_list = [t[1] for t in create_list]
if block := [t.module for t in create_list if not t.default_status]: await TaskInfo.bulk_create(_create_list, 10)
if block := [t[1].module for t in create_list if not t[0]]:
block_task = ",".join(block) + "," block_task = ",".join(block) + ","
if group_list := await GroupConsole.all(): if group_list := await GroupConsole.all():
for group in group_list: for group in group_list:

View File

@ -137,7 +137,7 @@ class Task(BaseBlock):
"""被动技能名称""" """被动技能名称"""
status: bool = True status: bool = True
"""全局开关状态""" """全局开关状态"""
default_status: bool = True create_status: bool = False
"""初次加载默认开关状态""" """初次加载默认开关状态"""
run_time: str | None = None run_time: str | None = None
"""运行时间""" """运行时间"""

View File

@ -15,8 +15,6 @@ class TaskInfo(Model):
"""被动技能名称""" """被动技能名称"""
status = fields.BooleanField(default=True, description="全局开关状态") status = fields.BooleanField(default=True, description="全局开关状态")
"""全局开关状态""" """全局开关状态"""
default_status = fields.BooleanField(default=True, description="进群默认状态")
"""加载默认状态"""
run_time = fields.CharField(255, null=True, description="运行时间") run_time = fields.CharField(255, null=True, description="运行时间")
"""运行时间""" """运行时间"""
run_count = fields.IntField(default=0, description="运行次数") run_count = fields.IntField(default=0, description="运行次数")
@ -55,9 +53,3 @@ class TaskInfo(Model):
"""群组是否被ban""" """群组是否被ban"""
return True return True
return False return False
@classmethod
def _run_script(cls):
return [
"ALTER TABLE task_info ADD default_status boolean NOT NULL DEFAULT true;",
]

View File

@ -5,7 +5,7 @@ from tortoise import Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.db_context import TestSQL from zhenxun.models.task_info import TaskInfo
from ....base_model import BaseResultModel, QueryModel, Result from ....base_model import BaseResultModel, QueryModel, Result
from ....utils import authentication from ....utils import authentication