From 251299c56c397bf8cfc804916406281d3e8dcc99 Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Mon, 7 Jul 2025 11:05:21 +0800 Subject: [PATCH] =?UTF-8?q?:zap:=20=EF=BC=9A=E5=A2=9E=E5=BC=BA=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E7=BE=A4=E7=BB=84=E7=9A=84=E5=AE=89=E5=85=A8=E6=80=A7?= =?UTF-8?q?=E5=92=8C=E5=87=86=E7=A1=AE=E6=80=A7=E3=80=82=E5=90=8C=E6=97=B6?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BA=86=E7=BC=93=E5=AD=98=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E4=B8=AD=E7=9A=84=E7=9B=B8=E5=85=B3=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E7=A1=AE=E4=BF=9D=E7=BC=93=E5=AD=98=E6=93=8D=E4=BD=9C?= =?UTF-8?q?=E7=9A=84=E4=B8=80=E8=87=B4=E6=80=A7=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../builtin_plugins/admin/ban/_data_source.py | 2 +- .../admin/plugin_switch/_data_source.py | 46 ++++---- zhenxun/builtin_plugins/help/_utils.py | 2 +- zhenxun/builtin_plugins/help/normal_help.py | 4 +- .../statistics/_data_source.py | 4 +- .../builtin_plugins/superuser/group_manage.py | 2 +- .../web_ui/api/tabs/manage/data_source.py | 2 +- .../web_ui/api/tabs/system/__init__.py | 6 +- zhenxun/models/ban_console.py | 28 ++++- zhenxun/models/group_console.py | 20 +++- zhenxun/services/cache.py | 30 ++--- zhenxun/services/db_context.py | 105 ++++++++++++++++-- zhenxun/utils/common_utils.py | 4 +- 13 files changed, 182 insertions(+), 73 deletions(-) diff --git a/zhenxun/builtin_plugins/admin/ban/_data_source.py b/zhenxun/builtin_plugins/admin/ban/_data_source.py index ae465bdf..60ee8efa 100644 --- a/zhenxun/builtin_plugins/admin/ban/_data_source.py +++ b/zhenxun/builtin_plugins/admin/ban/_data_source.py @@ -114,7 +114,7 @@ class BanManage: if not is_superuser and user_id and session.id1: user_level = await LevelUser.get_user_level(session.id1, group_id) if idx: - ban_data = await BanConsole.get_or_none(id=idx) + ban_data = await BanConsole.get_ban(id=idx) if not ban_data: return False, "该用户/群组不在黑名单中捏..." if ban_data.ban_level > user_level: diff --git a/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py b/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py index 0f8409cf..a6057104 100644 --- a/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py +++ b/zhenxun/builtin_plugins/admin/plugin_switch/_data_source.py @@ -1,10 +1,12 @@ import os +from typing import cast from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo from zhenxun.services.cache import Cache +from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.enum import BlockType, CacheType, PluginType from zhenxun.utils.exception import GroupInfoNotFound from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle @@ -117,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage: column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"] group = None if group_id: - group = await GroupConsole.get_or_none( - group_id=group_id, channel_id__isnull=True - ) + group = await GroupConsole.get_group(group_id=group_id) if not group: raise GroupInfoNotFound() else: @@ -201,20 +201,19 @@ class PluginManage: ) return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}" if group_id: - if group := await GroupConsole.get_or_none( - group_id=group_id, channel_id__isnull=True - ): - module_list = await PluginInfo.filter( - plugin_type=PluginType.NORMAL - ).values_list("module", flat=True) + if group := await GroupConsole.get_group(group_id=group_id): + module_list = cast( + list[str], + await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list( + "module", flat=True + ), + ) if status: - for module in module_list: - group.block_plugin = group.block_plugin.replace( - f"<{module},", "" - ) + # 开启所有功能 - 清空禁用列表 + group.block_plugin = "" else: - module_list = [f"<{module}" for module in module_list] - group.block_plugin = ",".join(module_list) + "," # type: ignore + # 关闭所有功能 - 将模块列表转换为禁用格式 + group.block_plugin = CommonUtils.convert_module_format(module_list) await group.save(update_fields=["block_plugin"]) return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}" return "获取群组失败..." @@ -235,9 +234,7 @@ class PluginManage: 返回: bool: 是否醒来 """ - if c := await GroupConsole.get_or_none( - group_id=group_id, channel_id__isnull=True - ): + if c := await GroupConsole.get_group(group_id=group_id): return c.status return False @@ -428,17 +425,18 @@ class PluginManage: """ status_str = "关闭" if status else "开启" if is_all: - modules = await TaskInfo.annotate().values_list("module", flat=True) - if modules: + module_list = cast( + list[str], await TaskInfo.annotate().values_list("module", flat=True) + ) + if module_list: group, _ = await GroupConsole.get_or_create( group_id=group_id, channel_id__isnull=True ) - modules = [f"<{module}" for module in modules] if status: - group.block_task = ",".join(modules) + "," # type: ignore + group.block_task = CommonUtils.convert_module_format(module_list) else: - for module in modules: - group.block_task = group.block_task.replace(f"{module},", "") + # 开启所有模块 - 清空禁用列表 + group.block_task = "" await group.save(update_fields=["block_task"]) return f"已成功{status_str}全部被动技能!" elif task := await TaskInfo.get_or_none(name=task_name): diff --git a/zhenxun/builtin_plugins/help/_utils.py b/zhenxun/builtin_plugins/help/_utils.py index 0554fc8d..4d256dc0 100644 --- a/zhenxun/builtin_plugins/help/_utils.py +++ b/zhenxun/builtin_plugins/help/_utils.py @@ -45,7 +45,7 @@ async def classify_plugin( """ sort_data = await sort_type() classify: dict[str, list] = {} - group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None + group = await GroupConsole.get_group(group_id=group_id) if group_id else None bot = await BotConsole.get_or_none(bot_id=session.self_id) for menu, value in sort_data.items(): for plugin in value: diff --git a/zhenxun/builtin_plugins/help/normal_help.py b/zhenxun/builtin_plugins/help/normal_help.py index 0ef9aa89..f381f900 100644 --- a/zhenxun/builtin_plugins/help/normal_help.py +++ b/zhenxun/builtin_plugins/help/normal_help.py @@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag color="black" if idx % 2 else "white", ) curr_h = 10 - group = await GroupConsole.get_or_none(group_id=group_id) + group = await GroupConsole.get_group(group_id=group_id) if group_id else None for _, plugin in enumerate(plugin_list): text_color = (255, 255, 255) if idx % 2 else (0, 0, 0) if group and f"{plugin.module}," in group.block_plugin: @@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag width, height = 10, 10 for s in [ "目前支持的功能列表:", - "可以通过 ‘帮助 [功能名称或功能Id]’ 来获取对应功能的使用方法", + "可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法", ]: text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24) await result.paste(text, (width, height)) diff --git a/zhenxun/builtin_plugins/statistics/_data_source.py b/zhenxun/builtin_plugins/statistics/_data_source.py index d51cb685..134ece87 100644 --- a/zhenxun/builtin_plugins/statistics/_data_source.py +++ b/zhenxun/builtin_plugins/statistics/_data_source.py @@ -45,9 +45,7 @@ class StatisticsManage: title = f"{user.user_name if user else user_id} {day_type}功能调用统计" elif group_id: """查群组""" - group = await GroupConsole.get_or_none( - group_id=group_id, channel_id__isnull=True - ) + group = await GroupConsole.get_group(group_id=group_id) title = f"{group.group_name if group else group_id} {day_type}功能调用统计" else: title = "功能调用统计" diff --git a/zhenxun/builtin_plugins/superuser/group_manage.py b/zhenxun/builtin_plugins/superuser/group_manage.py index fb8c0d2e..ee3c5687 100644 --- a/zhenxun/builtin_plugins/superuser/group_manage.py +++ b/zhenxun/builtin_plugins/superuser/group_manage.py @@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int) @_matcher.assign("super-handle", parameterless=[CheckGroupId()]) async def _(session: EventSession, arparma: Arparma, state: T_State): gid = state["group_id"] - group = await GroupConsole.get_or_none(group_id=gid) + group = await GroupConsole.get_group(group_id=gid) if not group: await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish() s = "删除" if arparma.find("delete") else "添加" diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/manage/data_source.py b/zhenxun/builtin_plugins/web_ui/api/tabs/manage/data_source.py index 39de7736..0b068e17 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/manage/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/manage/data_source.py @@ -250,7 +250,7 @@ class ApiDataSource: 返回: GroupDetail | None: 群组详情数据 """ - group = await GroupConsole.get_or_none(group_id=group_id) + group = await GroupConsole.get_group(group_id=group_id) if not group: return None like_plugin = await cls.__get_group_detail_like_plugin(group_id) diff --git a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py index b8ae2481..778ca846 100644 --- a/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py +++ b/zhenxun/builtin_plugins/web_ui/api/tabs/system/__init__.py @@ -45,10 +45,10 @@ async def _(path: str | None = None) -> Result[list[DirFile]]: mtime=file_path.stat().st_mtime, ) ) - sorted(data_list, key=lambda f: f.name) - return Result.ok(data_list) + data_list.sort(key=lambda f: f.name) + return Result.ok(data_list) except Exception as e: - return Result.fail(f"获取文件列表失败: {e!s}") + return Result.fail(f"获取文件列表失败: {e!s}") @router.get( diff --git a/zhenxun/models/ban_console.py b/zhenxun/models/ban_console.py index 2df33519..6c6f895b 100644 --- a/zhenxun/models/ban_console.py +++ b/zhenxun/models/ban_console.py @@ -54,12 +54,12 @@ class BanConsole(Model): raise UserAndGroupIsNone() if user_id: return ( - await cls.get_or_none(user_id=user_id, group_id=group_id) + await cls.safe_get_or_none(user_id=user_id, group_id=group_id) if group_id - else await cls.get_or_none(user_id=user_id, group_id__isnull=True) + else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True) ) else: - return await cls.get_or_none(user_id="", group_id=group_id) + return await cls.safe_get_or_none(user_id="", group_id=group_id) @classmethod async def check_ban_level( @@ -175,3 +175,25 @@ class BanConsole(Model): await user.delete() return True return False + + @classmethod + async def get_ban( + cls, + *, + id: int | None = None, + user_id: str | None = None, + group_id: str | None = None, + ) -> Self | None: + """安全地获取ban记录 + + 参数: + id: 记录id + user_id: 用户id + group_id: 群组id + + 返回: + Self | None: ban记录 + """ + if id is not None: + return await cls.safe_get_or_none(id=id) + return await cls._get_data(user_id, group_id) diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 123c8411..8e81000b 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -218,20 +218,32 @@ class GroupConsole(Model): @classmethod async def get_group( - cls, group_id: str, channel_id: str | None = None + cls, + group_id: str, + channel_id: str | None = None, + clean_duplicates: bool = True, ) -> Self | None: """获取群组 参数: group_id: 群组id - channel_id: 频道id. + channel_id: 频道id + clean_duplicates: 是否删除重复的记录,仅保留最新的 返回: Self: GroupConsole """ if channel_id: - return await cls.get_or_none(group_id=group_id, channel_id=channel_id) - return await cls.get_or_none(group_id=group_id, channel_id__isnull=True) + return await cls.safe_get_or_none( + group_id=group_id, + channel_id=channel_id, + clean_duplicates=clean_duplicates, + ) + return await cls.safe_get_or_none( + group_id=group_id, + channel_id__isnull=True, + clean_duplicates=clean_duplicates, + ) @classmethod async def is_super_group(cls, group_id: str) -> bool: diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py index 4957029f..af78b892 100644 --- a/zhenxun/services/cache.py +++ b/zhenxun/services/cache.py @@ -121,7 +121,7 @@ class CacheData(BaseModel): lazy_load: bool = True # 默认延迟加载 result_model: type | None = None _keys: set[str] = set() # 存储所有缓存键 - _cache: BaseCache | AioCache + cache: BaseCache | AioCache class Config: arbitrary_types_allowed = True @@ -208,7 +208,7 @@ class CacheData(BaseModel): async def get_data(self) -> Any: """从缓存获取数据""" try: - data = await self._cache.get(self.name) + data = await self.cache.get(self.name) # type: ignore logger.debug(f"获取缓存 {self.name} 数据: {data}") # 如果数据为空,尝试重新加载 @@ -307,11 +307,11 @@ class CacheData(BaseModel): logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}") # 2. 删除旧数据 - await self._cache.delete(self.name) + await self.cache.delete(self.name) # type: ignore logger.debug(f"删除缓存 {self.name} 旧数据") # 3. 设置新数据 - await self._cache.set(self.name, serialized_value, ttl=self.expire) + await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore logger.debug(f"设置缓存 {self.name} 新数据完成") except Exception as e: @@ -321,7 +321,7 @@ class CacheData(BaseModel): async def delete_data(self): """删除缓存数据""" try: - await self._cache.delete(self.name) + await self.cache.delete(self.name) # type: ignore except Exception as e: logger.error(f"删除缓存 {self.name}", e=e) @@ -428,7 +428,7 @@ class CacheData(BaseModel): """ cache_key = self._get_cache_key(key) try: - data = await self._cache.get(cache_key) + data = await self.cache.get(cache_key) # type: ignore logger.debug(f"获取缓存 {cache_key} 数据: {data}") if self.result_model: @@ -467,7 +467,7 @@ class CacheData(BaseModel): for key in self._keys: # 提取原始键名(去掉前缀) original_key = key.split(":", 1)[1] - data = await self._cache.get(key) + data = await self.cache.get(key) # type: ignore if self.result_model: result[original_key] = self._deserialize_value( data, self.result_model @@ -484,7 +484,7 @@ class CacheData(BaseModel): cache_key = self._get_cache_key(key) try: serialized_value = self._serialize_value(value) - await self._cache.set(cache_key, serialized_value, ttl=self.expire) + await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore self._keys.add(cache_key) # 添加到键列表 logger.debug(f"设置缓存 {cache_key} 数据完成") except Exception as e: @@ -495,7 +495,7 @@ class CacheData(BaseModel): """删除指定键的缓存数据""" cache_key = self._get_cache_key(key) try: - await self._cache.delete(cache_key) + await self.cache.delete(cache_key) # type: ignore self._keys.discard(cache_key) # 从键列表中移除 logger.debug(f"删除缓存 {cache_key} 完成") except Exception as e: @@ -505,7 +505,7 @@ class CacheData(BaseModel): """清除所有缓存数据""" try: for key in list(self._keys): # 使用列表复制避免在迭代时修改 - await self._cache.delete(key) + await self.cache.delete(key) # type: ignore self._keys.clear() logger.debug(f"清除缓存 {self.name} 完成") except Exception as e: @@ -524,7 +524,7 @@ class CacheManager: if self._cache_instance is None: if config.redis_host: self._cache_instance = AioCache( - AioCache.REDIS, + AioCache.REDIS, # type: ignore serializer=JsonSerializer(), namespace="zhenxun_cache", timeout=30, # 操作超时时间 @@ -546,12 +546,12 @@ class CacheManager: async def close(self): if self._cache_instance: - await self._cache_instance.close() + await self._cache_instance.close() # type: ignore async def verify_connection(self): """连接测试""" try: - await self._cache.get("__test__") + await self._cache.get("__test__") # type: ignore except Exception as e: logger.error("连接失败", LOG_COMMAND, e=e) raise @@ -560,7 +560,7 @@ class CacheManager: """初始化所有非延迟加载的缓存""" await self.verify_connection() for name, cache in self._data.items(): - cache._cache = self._cache + cache.cache = self._cache if not cache.lazy_load: try: await cache.reload() @@ -587,7 +587,7 @@ class CacheManager: func=func, expire=expire, lazy_load=lazy_load, - _cache=self._cache, + cache=self._cache, ) return func diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index c9203a7d..4e74f600 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -72,12 +72,23 @@ class Model(TortoiseModel): using_db: BaseDBAsyncClient | None = None, **kwargs: Any, ) -> tuple[Self, bool]: - result, is_create = await super().get_or_create( - defaults=defaults, using_db=using_db, **kwargs - ) - if is_create and (cache_type := cls.get_cache_type()): - await CacheRoot.reload(cache_type) - return (result, is_create) + if sem := cls.get_semaphore(DbLockType.CREATE): + async with sem: + # 在锁内执行查询和创建操作 + result, is_create = await super().get_or_create( + defaults=defaults, using_db=using_db, **kwargs + ) + if is_create and (cache_type := cls.get_cache_type()): + await CacheRoot.reload(cache_type) + return (result, is_create) + else: + # 如果没有锁,则执行原来的逻辑 + result, is_create = await super().get_or_create( + defaults=defaults, using_db=using_db, **kwargs + ) + if is_create and (cache_type := cls.get_cache_type()): + await CacheRoot.reload(cache_type) + return (result, is_create) @classmethod async def update_or_create( @@ -86,12 +97,23 @@ class Model(TortoiseModel): using_db: BaseDBAsyncClient | None = None, **kwargs: Any, ) -> tuple[Self, bool]: - result = await super().update_or_create( - defaults=defaults, using_db=using_db, **kwargs - ) - if cache_type := cls.get_cache_type(): - await CacheRoot.reload(cache_type) - return result + if sem := cls.get_semaphore(DbLockType.CREATE): + async with sem: + # 在锁内执行查询和创建操作 + result = await super().update_or_create( + defaults=defaults, using_db=using_db, **kwargs + ) + if cache_type := cls.get_cache_type(): + await CacheRoot.reload(cache_type) + return result + else: + # 如果没有锁,则执行原来的逻辑 + result = await super().update_or_create( + defaults=defaults, using_db=using_db, **kwargs + ) + if cache_type := cls.get_cache_type(): + await CacheRoot.reload(cache_type) + return result @classmethod async def bulk_create( # type: ignore @@ -167,6 +189,65 @@ class Model(TortoiseModel): if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)): await CacheRoot.reload(cache_type) + @classmethod + async def safe_get_or_none( + cls, + *args, + using_db: BaseDBAsyncClient | None = None, + clean_duplicates: bool = True, + **kwargs: Any, + ) -> Self | None: + """安全地获取一条记录或None,处理存在多个记录时返回最新的那个 + 注意,默认会删除重复的记录,仅保留最新的 + + 参数: + *args: 查询参数 + using_db: 数据库连接 + clean_duplicates: 是否删除重复的记录,仅保留最新的 + **kwargs: 查询参数 + + 返回: + Self | None: 查询结果,如果不存在返回None + """ + try: + # 先尝试使用 get_or_none 获取单个记录 + return await cls.get_or_none(*args, using_db=using_db, **kwargs) + except Exception as e: + # 如果出现错误(可能是存在多个记录) + if "Multiple objects" in str(e): + logger.warning( + f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}" + ) + # 查询所有匹配记录 + records = await cls.filter(*args, **kwargs) + + if not records: + return None + + # 如果需要清理重复记录 + if clean_duplicates and hasattr(cls, "id"): + # 按 id 排序 + records = sorted( + records, key=lambda x: getattr(x, "id", 0), reverse=True + ) + for record in records[1:]: + try: + await record.delete() + logger.info( + f"{cls.__name__} 删除重复记录:" + f" id={getattr(record, 'id', None)}" + ) + except Exception as del_e: + logger.error(f"删除重复记录失败: {del_e}") + return records[0] + # 如果不需要清理或没有 id 字段,则返回最新的记录 + if hasattr(cls, "id"): + return await cls.filter(*args, **kwargs).order_by("-id").first() + # 如果没有 id 字段,则返回第一个记录 + return await cls.filter(*args, **kwargs).first() + # 其他类型的错误则继续抛出 + raise + class DbUrlIsNode(HookPriorityException): """ diff --git a/zhenxun/utils/common_utils.py b/zhenxun/utils/common_utils.py index cc143898..cfdabdc5 100644 --- a/zhenxun/utils/common_utils.py +++ b/zhenxun/utils/common_utils.py @@ -53,9 +53,7 @@ class CommonUtils: if await GroupConsole.is_block_task(group_id, module): """群组是否禁用被动""" return True - if g := await GroupConsole.get_or_none( - group_id=group_id, channel_id__isnull=True - ): + if g := await GroupConsole.get_group(group_id=group_id): """群组权限是否小于0""" if g.level < 0: return True