From 5b84ef9a0891f82d444dba6995053ae3a55c3840 Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Fri, 28 Mar 2025 21:51:18 +0800 Subject: [PATCH] =?UTF-8?q?:art:=20=E4=BC=98=E5=8C=96=E7=BE=A4=E7=BB=84?= =?UTF-8?q?=E8=A1=A8=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zhenxun/models/group_console.py | 260 ++++++++++++++++++++------------ 1 file changed, 164 insertions(+), 96 deletions(-) diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 410457c1..0a0693d3 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -10,6 +10,42 @@ from zhenxun.services.db_context import Model from zhenxun.utils.enum import PluginType +def add_disable_marker(name: str) -> str: + """添加模块禁用标记符 + + Args: + name: 模块名称 + + Returns: + 添加了禁用标记的模块名 (前缀'<'和后缀',') + """ + return f"<{name}," + + +@overload +def convert_module_format(data: str) -> list[str]: ... + + +@overload +def convert_module_format(data: list[str]) -> str: ... + + +def convert_module_format(data: str | list[str]) -> str | list[str]: + """ + 在 ` str: - return f"<{name}," - - @overload @classmethod - def convert_module_format(cls, data: str) -> list[str]: ... - - @overload - @classmethod - def convert_module_format(cls, data: list[str]) -> str: ... - - @classmethod - def convert_module_format(cls, data: str | list[str]) -> str | list[str]: - """ - 在 ` list[str]: + """获取默认禁用的任务模块 返回: - str | list[str]: 根据输入类型返回转换后的数据。 + list[str]: 任务模块列表 """ - if isinstance(data, str): - return [item.strip(",") for item in data.split("<") if item] - elif isinstance(data, list): - return "".join(cls.format(item) for item in data) - - @classmethod - async def __set_default_plugin_status(cls, group: Self) -> list[str]: - """设置新群组信息时默认插件关闭状态 - - 参数: - group: GroupConsole对象 - - 返回: - list[str]: 更新字段列表 - """ - task_modules = cast( + return cast( list[str], - await TaskInfo.filter(default_status=False).values_list( + await TaskInfo.filter(default_status=default_status).values_list( "module", flat=True ), ) - plugin_modules = cast( + + @classmethod + async def _get_plugin_modules(cls, *, default_status: bool) -> list[str]: + """获取默认禁用的插件模块 + + 返回: + list[str]: 插件模块列表 + """ + return cast( list[str], await PluginInfo.filter( plugin_type__in=[PluginType.NORMAL, PluginType.DEPENDANT], - default_status=False, - load_status=True, + default_status=default_status, ).values_list("module", flat=True), ) - if not task_modules and not plugin_modules: - return [] - - update_fields = [] - - if task_modules: - group.block_task = cls.convert_module_format(task_modules) - update_fields.append("block_task") - - if plugin_modules: - group.block_plugin = cls.convert_module_format(list(plugin_modules)) - update_fields.append("block_plugin") - - return update_fields - @classmethod async def create( cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any @@ -126,12 +123,43 @@ class GroupConsole(Model): """覆盖create方法""" group = await super().create(using_db=using_db, **kwargs) - update_fields = await cls.__set_default_plugin_status(group) - if update_fields: - await group.save(using_db=using_db, update_fields=update_fields) + task_modules = await cls._get_task_modules(default_status=False) + plugin_modules = await cls._get_plugin_modules(default_status=False) + + if task_modules or plugin_modules: + await cls._update_modules(group, task_modules, plugin_modules, using_db) return group + @classmethod + async def _update_modules( + cls, + group: Self, + task_modules: list[str], + plugin_modules: list[str], + using_db: BaseDBAsyncClient | None = None, + ) -> None: + """更新模块设置 + + 参数: + group: 群组实例 + task_modules: 任务模块列表 + plugin_modules: 插件模块列表 + using_db: 数据库连接 + """ + update_fields = [] + + if task_modules: + group.block_task = convert_module_format(task_modules) + update_fields.append("block_task") + + if plugin_modules: + group.block_plugin = convert_module_format(plugin_modules) + update_fields.append("block_plugin") + + if update_fields: + await group.save(using_db=using_db, update_fields=update_fields) + @classmethod async def get_or_create( cls, @@ -143,11 +171,14 @@ class GroupConsole(Model): group, is_create = await super().get_or_create( defaults=defaults, using_db=using_db, **kwargs ) + if not is_create: + return group, is_create - if is_create: - update_fields = await cls.__set_default_plugin_status(group) - if update_fields: - await group.save(using_db=using_db, update_fields=update_fields) + task_modules = await cls._get_task_modules(default_status=False) + plugin_modules = await cls._get_plugin_modules(default_status=False) + + if task_modules or plugin_modules: + await cls._update_modules(group, task_modules, plugin_modules, using_db) return group, is_create @@ -162,11 +193,14 @@ class GroupConsole(Model): group, is_create = await super().update_or_create( defaults=defaults, using_db=using_db, **kwargs ) + if not is_create: + return group, is_create - if is_create: - update_fields = await cls.__set_default_plugin_status(group) - if update_fields: - await group.save(using_db=using_db, update_fields=update_fields) + task_modules = await cls._get_task_modules(default_status=False) + plugin_modules = await cls._get_plugin_modules(default_status=False) + + if task_modules or plugin_modules: + await cls._update_modules(group, task_modules, plugin_modules, using_db) return group, is_create @@ -212,7 +246,7 @@ class GroupConsole(Model): """ return await cls.exists( group_id=group_id, - superuser_block_plugin__contains=f"<{module},", + superuser_block_plugin__contains=add_disable_marker(module), ) @classmethod @@ -226,10 +260,11 @@ class GroupConsole(Model): 返回: bool: 是否禁用插件 """ + module = add_disable_marker(module) return await cls.exists( - group_id=group_id, block_plugin__contains=f"<{module}," + group_id=group_id, block_plugin__contains=module ) or await cls.exists( - group_id=group_id, superuser_block_plugin__contains=f"<{module}," + group_id=group_id, superuser_block_plugin__contains=module ) @classmethod @@ -251,12 +286,22 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if f"<{module}," not in group.superuser_block_plugin: - group.superuser_block_plugin += f"<{module}," - elif f"<{module}," not in group.block_plugin: - group.block_plugin += f"<{module}," - await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) + superuser_block_plugin = convert_module_format(group.superuser_block_plugin) + if module not in superuser_block_plugin: + superuser_block_plugin.append(module) + group.superuser_block_plugin = convert_module_format( + superuser_block_plugin + ) + update_fields.append("superuser_block_plugin") + elif add_disable_marker(module) not in group.block_plugin: + block_plugin = convert_module_format(group.block_plugin) + block_plugin.append(module) + group.block_plugin = convert_module_format(block_plugin) + update_fields.append("block_plugin") + if update_fields: + await group.save(update_fields=update_fields) @classmethod async def set_unblock_plugin( @@ -277,14 +322,22 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if f"<{module}," in group.superuser_block_plugin: - group.superuser_block_plugin = group.superuser_block_plugin.replace( - f"<{module},", "" + superuser_block_plugin = convert_module_format(group.superuser_block_plugin) + if module in superuser_block_plugin: + superuser_block_plugin.remove(module) + group.superuser_block_plugin = convert_module_format( + superuser_block_plugin ) - elif f"<{module}," in group.block_plugin: - group.block_plugin = group.block_plugin.replace(f"<{module},", "") - await group.save(update_fields=["block_plugin", "superuser_block_plugin"]) + update_fields.append("superuser_block_plugin") + elif add_disable_marker(module) in group.block_plugin: + block_plugin = convert_module_format(group.block_plugin) + block_plugin.remove(module) + group.block_plugin = convert_module_format(block_plugin) + update_fields.append("block_plugin") + if update_fields: + await group.save(update_fields=update_fields) @classmethod async def is_normal_block_plugin( @@ -319,7 +372,7 @@ class GroupConsole(Model): """ return await cls.exists( group_id=group_id, - superuser_block_task__contains=f"<{task},", + superuser_block_task__contains=add_disable_marker(task), ) @classmethod @@ -336,22 +389,23 @@ class GroupConsole(Model): 返回: bool: 是否禁用被动 """ + task = add_disable_marker(task) if not channel_id: return await cls.exists( group_id=group_id, channel_id__isnull=True, - block_task__contains=f"<{task},", + block_task__contains=task, ) or await cls.exists( group_id=group_id, channel_id__isnull=True, - superuser_block_task__contains=f"<{task},", + superuser_block_task__contains=task, ) return await cls.exists( - group_id=group_id, channel_id=channel_id, block_task__contains=f"<{task}," + group_id=group_id, channel_id=channel_id, block_task__contains=task ) or await cls.exists( group_id=group_id, channel_id__isnull=True, - superuser_block_task__contains=f"<{task},", + superuser_block_task__contains=task, ) @classmethod @@ -373,12 +427,20 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if f"<{task}," not in group.superuser_block_task: - group.superuser_block_task += f"<{task}," - elif f"<{task}," not in group.block_task: - group.block_task += f"<{task}," - await group.save(update_fields=["block_task", "superuser_block_task"]) + superuser_block_task = convert_module_format(group.superuser_block_task) + if task not in group.superuser_block_task: + superuser_block_task.append(task) + group.superuser_block_task = convert_module_format(superuser_block_task) + update_fields.append("superuser_block_task") + elif add_disable_marker(task) not in group.block_task: + block_task = convert_module_format(group.block_task) + block_task.append(task) + group.block_task = convert_module_format(block_task) + update_fields.append("block_task") + if update_fields: + await group.save(update_fields=update_fields) @classmethod async def set_unblock_task( @@ -399,14 +461,20 @@ class GroupConsole(Model): group, _ = await cls.get_or_create( group_id=group_id, defaults={"platform": platform} ) + update_fields = [] if is_superuser: - if f"<{task}," in group.superuser_block_task: - group.superuser_block_task = group.superuser_block_task.replace( - f"<{task},", "" - ) - elif f"<{task}," in group.block_task: - group.block_task = group.block_task.replace(f"<{task},", "") - await group.save(update_fields=["block_task", "superuser_block_task"]) + superuser_block_task = convert_module_format(group.superuser_block_task) + if task in superuser_block_task: + superuser_block_task.remove(task) + group.superuser_block_task = convert_module_format(superuser_block_task) + update_fields.append("superuser_block_task") + elif add_disable_marker(task) in group.block_task: + block_task = convert_module_format(group.block_task) + block_task.remove(task) + group.block_task = convert_module_format(block_task) + update_fields.append("block_task") + if update_fields: + await group.save(update_fields=update_fields) @classmethod def _run_script(cls):