From f0b05ec5ed2bf06182dd428a28b79a25d02fe55c Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Tue, 30 Jul 2024 22:36:09 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20:=20=E6=96=B0=E5=A2=9E=E6=9B=B4?= =?UTF-8?q?=E5=A4=9A=E7=9A=84=E6=95=B0=E6=8D=AE=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zhenxun/builtin_plugins/__init__.py | 6 +- zhenxun/builtin_plugins/init/init_plugin.py | 259 +++++++++++++++++++- zhenxun/models/plugin_limit.py | 9 +- 3 files changed, 262 insertions(+), 12 deletions(-) diff --git a/zhenxun/builtin_plugins/__init__.py b/zhenxun/builtin_plugins/__init__.py index 4045c9aa..ef7ed6ff 100644 --- a/zhenxun/builtin_plugins/__init__.py +++ b/zhenxun/builtin_plugins/__init__.py @@ -33,7 +33,6 @@ for d in os.listdir(path): driver: Driver = nonebot.get_driver() -flag = True SIGN_SQL = """ select distinct on("user_id") t1.user_id, t1.checkin_count, t1.add_probability, t1.specify_probability, t1.impression @@ -58,15 +57,14 @@ from public.bag_users t1 @driver.on_startup async def _(): - global flag + """签到与用户的数据迁移""" if goods_list := await GoodsInfo.filter(uuid__isnull=True).all(): for goods in goods_list: goods.uuid = uuid.uuid1() # type: ignore await GoodsInfo.bulk_update(goods_list, ["uuid"], 10) await shop_register.load_register() if ( - flag - and not await UserConsole.annotate().count() + not await UserConsole.annotate().count() and not await SignUser.annotate().count() ): try: diff --git a/zhenxun/builtin_plugins/init/init_plugin.py b/zhenxun/builtin_plugins/init/init_plugin.py index 7459d766..005f446e 100644 --- a/zhenxun/builtin_plugins/init/init_plugin.py +++ b/zhenxun/builtin_plugins/init/init_plugin.py @@ -1,15 +1,24 @@ import nonebot +import ujson as json from nonebot import get_loaded_plugins from nonebot.drivers import Driver from nonebot.plugin import Plugin from ruamel.yaml import YAML +from zhenxun.configs.path_config import DATA_PATH from zhenxun.configs.utils import PluginExtraData, PluginSetting +from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_limit import PluginLimit from zhenxun.models.task_info import TaskInfo from zhenxun.services.log import logger -from zhenxun.utils.enum import PluginType +from zhenxun.utils.enum import ( + BlockType, + LimitCheckType, + LimitWatchType, + PluginLimitType, + PluginType, +) _yaml = YAML(pure=True) _yaml.allow_unicode = True @@ -150,3 +159,251 @@ async def _(): ["run_time", "status", "name"], 10, ) + await data_migration() + + +async def data_migration(): + await limit_migration() + await plugin_migration() + await group_migration() + + +async def limit_migration(): + """插件限制迁移""" + cd_file = DATA_PATH / "configs" / "plugins2cd.yaml" + block_file = DATA_PATH / "configs" / "plugins2block.yaml" + count_file = DATA_PATH / "configs" / "plugins2count.yaml" + limit_data: dict[str, list[tuple[str, dict]]] = {} + if cd_file.exists(): + with open(cd_file, encoding="utf8") as f: + if data := _yaml.load(f): + for k in data["PluginCdLimit"]: + limit_data[k] = [("CD", data["PluginCdLimit"][k])] + cd_file.unlink() + if block_file.exists(): + with open(block_file, encoding="utf8") as f: + if data := _yaml.load(f): + for k in data["PluginBlockLimit"]: + if k in limit_data: + limit_data[k].append(("BLOCK", data["PluginBlockLimit"][k])) + else: + limit_data[k] = [("BLOCK", data["PluginBlockLimit"][k])] + block_file.unlink() + if count_file.exists(): + with open(count_file, encoding="utf8") as f: + if data := _yaml.load(f): + for k in data["PluginCountLimit"]: + if k in limit_data: + limit_data[k].append(("COUNT", data["PluginCountLimit"][k])) + else: + limit_data[k] = [("COUNT", data["PluginCountLimit"][k])] + count_file.unlink() + if limit_data: + logger.info("开始迁移插件限制数据...") + update_list = [] + create_list = [] + plugins = await PluginInfo.filter(module__in=limit_data.keys()) + for plugin in plugins: + limits: list[PluginLimit] = await plugin.plugin_limit.all() # type: ignore + exits_limit = [x[0] for x in limit_data[plugin.module]] + _not_create_type = [] + for limit in limits: + if _limit_list := [ + x[1] + for x in limit_data[plugin.module] + if x[0] == str(limit.limit_type) + ]: + """修改""" + _not_create_type.append(str(limit.limit_type)) + _limit = _limit_list[0] + watch_type = LimitWatchType.USER + if _limit.get("watch_type") == "group": + watch_type = LimitWatchType.GROUP + check_type = LimitCheckType.ALL + if _limit.get("check_type") == "private": + check_type = LimitCheckType.PRIVATE + elif _limit.get("check_type") == "group": + check_type = LimitCheckType.GROUP + limit.watch_type = watch_type + limit.result = _limit.get("rst", "") + limit.status = _limit.get("status", True) + if limit.watch_type != PluginLimitType.COUNT: + limit.check_type = check_type + if limit.watch_type == PluginLimitType.CD: + limit.cd = _limit["cd"] + if limit.watch_type == PluginLimitType.COUNT: + limit.max_count = _limit["count"] + await limit.save() + update_list.append(limit) + for s in [e for e in exits_limit if e not in _not_create_type]: + if _limit_list := [ + x[1] for x in limit_data[plugin.module] if s == x[0] + ]: + _limit = _limit_list[0] + limit_type = PluginLimitType.CD + if s == "BLOCK": + limit_type = PluginLimitType.BLOCK + elif s == "COUNT": + limit_type = PluginLimitType.COUNT + watch_type = LimitWatchType.USER + if _limit.get("watch_type") == "group": + watch_type = LimitWatchType.GROUP + check_type = LimitCheckType.ALL + if _limit.get("check_type") == "private": + check_type = LimitCheckType.PRIVATE + elif _limit.get("check_type") == "group": + check_type = LimitCheckType.GROUP + create_list.append( + PluginLimit( + module=plugin.module, + module_path=plugin.module_path, + plugin=plugin, + limit_type=limit_type, + watch_type=watch_type, + status=_limit.get("status", True), + check_type=check_type, + result=_limit.get("rst", ""), + cd=_limit.get("cd"), + max_count=_limit.get("max_count"), + ) + ) + # TODO: 批量错误 tortoise.exceptions.OperationalError: syntax error at or near "ALL" + # if update_list: + # await PluginLimit.bulk_update( + # update_list, + # [ + # "watch_type", + # "status", + # "check_type", + # "result", + # "cd", + # "max_count", + # ], + # 10, + # ) + if create_list: + await PluginLimit.bulk_create(create_list, 10) + logger.info("迁移插件限制数据完成!") + + +async def plugin_migration(): + """迁移插件数据""" + setting_file = DATA_PATH / "configs" / "plugins2settings.yaml" + plugin_file = DATA_PATH / "manager" / "plugins_manager.json" + if setting_file.exists(): + with open(setting_file, encoding="utf8") as f: + if data := _yaml.load(f): + logger.info("开始迁移插件setting数据...") + data = data["PluginSettings"] + plugins = await PluginInfo.filter(module__in=data.keys()) + for plugin in plugins: + if plugin_data_list := [ + data[p] for p in data if p == plugin.module + ]: + plugin_data = plugin_data_list[0] + plugin.default_status = plugin_data.get("default_status", True) + plugin.level = plugin_data.get("level", 5) + plugin.limit_superuser = plugin_data.get( + "limit_superuser", False + ) + plugin.menu_type = plugin_data.get("plugin_type", ["功能"])[0] + plugin.cost_gold = plugin_data.get("cost_gold", 0) + await PluginInfo.bulk_update( + plugins, + [ + "default_status", + "level", + "limit_superuser", + "menu_type", + "cost_gold", + ], + 10, + ) + setting_file.unlink() + logger.info("迁移插件setting数据完成!") + if plugin_file.exists(): + with open(plugin_file, encoding="utf8") as f: + if data := json.load(f): + logger.info("开始迁移插件数据...") + plugins = await PluginInfo.filter(module__in=data.keys()) + for plugin in plugins: + if plugin_data := data.get(plugin.module): + plugin.status = plugin_data.get("status", True) + block_type = None + get_block = plugin_data.get("block_type") + if get_block == "all": + block_type = BlockType.ALL + elif get_block == "private": + block_type = BlockType.PRIVATE + elif get_block == "group": + block_type = BlockType.GROUP + plugin.block_type = block_type + await PluginInfo.bulk_update(plugins, ["status", "block_type"], 10) + plugin_file.unlink() + logger.info("迁移插件数据完成!") + + +async def group_migration(): + """ + 群组数据迁移 + """ + group_file = DATA_PATH / "manager" / "group_manager.json" + if group_file.exists(): + with open(group_file, encoding="utf8") as f: + if data := json.load(f): + logger.info("开始迁移群组数据...") + update_list = [] + create_list = [] + white_group = data["white_group"] + close_task = data["close_task"] + old_group_list: dict = data["group_manager"] + if close_task: + """全局被动关闭""" + await TaskInfo.filter(module__in=close_task).update(status=False) + group_list = await GroupConsole.filter( + group_id__in=old_group_list.keys() + ) + for old_group_id in old_group_list: + old_group = old_group_list[old_group_id] + block_plugin = "" + block_task = "" + status = old_group.get("status", True) + level = old_group.get("level", 5) + if close_plugins := old_group.get("close_plugins"): + block_plugin = ",".join(close_plugins) + "," + if group_task_status := old_group.get("group_task_status"): + close_task = [ + t for t in group_task_status if not group_task_status[t] + ] + block_task = ",".join(close_task) + "," + if group_ := [g for g in group_list if g.group_id == old_group_id]: + group = group_[0] + if group.group_id in white_group: + group.is_super = True + group.status = status + group.block_plugin = block_plugin + group.block_task = block_task + group.level = level + update_list.append(group) + else: + """添加""" + create_list.append( + GroupConsole( + group_id=old_group_id, + status=status, + level=level, + block_plugin=block_plugin, + block_task=block_task, + is_super=old_group_id in white_group, + ) + ) + if update_list: + await GroupConsole.bulk_update( + update_list, + ["is_super", "status", "block_plugin", "block_task"], + 10, + ) + if create_list: + await GroupConsole.bulk_create(create_list, 10) + group_file.unlink() + logger.info("迁移群组数据完成!") diff --git a/zhenxun/models/plugin_limit.py b/zhenxun/models/plugin_limit.py index 96538b64..e6b185e7 100644 --- a/zhenxun/models/plugin_limit.py +++ b/zhenxun/models/plugin_limit.py @@ -1,12 +1,7 @@ from tortoise import fields from zhenxun.services.db_context import Model -from zhenxun.utils.enum import ( - BlockType, - LimitCheckType, - LimitWatchType, - PluginLimitType, -) +from zhenxun.utils.enum import LimitCheckType, LimitWatchType, PluginLimitType class PluginLimit(Model): @@ -30,7 +25,7 @@ class PluginLimit(Model): status = fields.BooleanField(default=True, description="限制的开关状态") """限制的开关状态""" check_type = fields.CharEnumField( - LimitCheckType, default=BlockType.ALL, description="检查类型" + LimitCheckType, default=LimitCheckType.ALL, description="检查类型" ) """检查类型""" result = fields.CharField(max_length=255, null=True, description="返回信息")