diff --git a/zhenxun/builtin_plugins/hooks/ban_hook.py b/zhenxun/builtin_plugins/hooks/ban_hook.py index fd3b101b..a77a7300 100644 --- a/zhenxun/builtin_plugins/hooks/ban_hook.py +++ b/zhenxun/builtin_plugins/hooks/ban_hook.py @@ -64,44 +64,29 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo): raise IgnoredException("群黑名单, 群权限-1..") if user_id: ban_result = Config.get_config("hook", "BAN_RESULT") - try: - if await is_ban(user_id, group_id): - time = await BanConsole.check_ban_time(user_id, group_id) - if time == -1: - time_str = "∞" - else: - time = abs(int(time)) - if time < 60: - time_str = f"{time!s} 秒" - else: - minute = int(time / 60) - if minute > 60: - hours = minute // 60 - minute %= 60 - time_str = f"{hours} 小时 {minute}分钟" - else: - time_str = f"{minute} 分钟" - if time != -1 and ban_result and _flmt.check(user_id): - _flmt.start_cd(user_id) - await MessageUtils.build_message( - [ - At(flag="user", target=user_id), - f"{ban_result}\n在..在 {time_str} 后才会理你喔", - ] - ).send() - logger.debug("用户处于黑名单中...", "BanChecker") - raise IgnoredException("用户处于黑名单中...") - except MultipleObjectsReturned: - logger.warning( - "黑名单数据重复,过滤该次hook并移除多余数据...", "BanChecker" - ) - if group_id: - ids = await BanConsole.filter( - user_id=user_id, group_id=group_id - ).values_list("id", flat=True) + if await is_ban(user_id, group_id): + time = await BanConsole.check_ban_time(user_id, group_id) + if time == -1: + time_str = "∞" else: - ids = await BanConsole.filter( - user_id=user_id, group_id__isnull=True - ).values_list("id", flat=True) - await BanConsole.filter(id__in=ids[:-1]).delete() - await cache.reload() + time = abs(int(time)) + if time < 60: + time_str = f"{time!s} 秒" + else: + minute = int(time / 60) + if minute > 60: + hours = minute // 60 + minute %= 60 + time_str = f"{hours} 小时 {minute}分钟" + else: + time_str = f"{minute} 分钟" + if time != -1 and ban_result and _flmt.check(user_id): + _flmt.start_cd(user_id) + await MessageUtils.build_message( + [ + At(flag="user", target=user_id), + f"{ban_result}\n在..在 {time_str} 后才会理你喔", + ] + ).send() + logger.debug("用户处于黑名单中...", "BanChecker") + raise IgnoredException("用户处于黑名单中...") diff --git a/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py b/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py index 0cde0a6c..a5ba1363 100644 --- a/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py +++ b/zhenxun/builtin_plugins/platform/qq/group_handle/__init__.py @@ -103,7 +103,7 @@ group_increase_handle = on_notice( group_decrease_handle = on_notice( priority=1, block=False, - rule=notice_rule([GroupMemberDecreaseEvent, GroupMemberIncreaseEvent]), + rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]), ) """群员减少处理""" add_group = on_request(priority=1, block=False) diff --git a/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py b/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py index fc188eb1..893d9e94 100644 --- a/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py +++ b/zhenxun/builtin_plugins/platform/qq/group_handle/data_source.py @@ -55,14 +55,16 @@ class GroupManager: for plugin in plugin_list: block_plugin += f"<{plugin.module}," group_info = await bot.get_group_info(group_id=group_id) - await GroupConsole.create( + await GroupConsole.update_or_create( group_id=group_info["group_id"], - group_name=group_info["group_name"], - max_member_count=group_info["max_member_count"], - member_count=group_info["member_count"], - group_flag=1, - block_plugin=block_plugin, - platform="qq", + defaults={ + "group_name": group_info["group_name"], + "max_member_count": group_info["max_member_count"], + "member_count": group_info["member_count"], + "group_flag": 1, + "block_plugin": block_plugin, + "platform": "qq", + }, ) @classmethod @@ -144,7 +146,7 @@ class GroupManager: e=e, ) raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e - await GroupConsole.filter(group_id=group_id).delete() + # await GroupConsole.filter(group_id=group_id).delete() raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...") else: await cls.__handle_add_group(bot, group_id, group) diff --git a/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py b/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py index 4e7a708c..4435e880 100644 --- a/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py +++ b/zhenxun/builtin_plugins/platform/qq_api/ug_watch.py @@ -1,4 +1,4 @@ -from nonebot.message import run_preprocessor +from nonebot import on_message from nonebot_plugin_uninfo import Uninfo from zhenxun.models.friend_user import FriendUser @@ -8,24 +8,27 @@ from zhenxun.services.log import logger from zhenxun.utils.platform import PlatformUtils -@run_preprocessor -async def do_something(session: Uninfo): +def rule(session: Uninfo) -> bool: + return PlatformUtils.is_qbot(session) + + +_matcher = on_message(priority=999, block=False, rule=rule) + + +@_matcher.handle() +async def _(session: Uninfo): platform = PlatformUtils.get_platform(session) if session.group: if not await GroupConsole.exists(group_id=session.group.id): await GroupConsole.create(group_id=session.group.id) - logger.info("添加当前群组ID信息" "", session=session) - - if not await GroupInfoUser.exists( - user_id=session.user.id, group_id=session.group.id - ): - await GroupInfoUser.create( - user_id=session.user.id, group_id=session.group.id, platform=platform - ) - logger.info("添加当前用户群组ID信息", "", session=session) + logger.info("添加当前群组ID信息", session=session) + await GroupInfoUser.update_or_create( + user_id=session.user.id, + group_id=session.group.id, + platform=PlatformUtils.get_platform(session), + ) elif not await FriendUser.exists(user_id=session.user.id, platform=platform): - try: - await FriendUser.create(user_id=session.user.id, platform=platform) - logger.info("添加当前好友用户信息", "", session=session) - except Exception as e: - logger.error("添加当前好友用户信息失败", session=session, e=e) + await FriendUser.create( + user_id=session.user.id, platform=PlatformUtils.get_platform(session) + ) + logger.info("添加当前好友用户信息", "", session=session) diff --git a/zhenxun/builtin_plugins/record_request.py b/zhenxun/builtin_plugins/record_request.py index 8577692b..41193aef 100644 --- a/zhenxun/builtin_plugins/record_request.py +++ b/zhenxun/builtin_plugins/record_request.py @@ -104,7 +104,7 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi await PlatformUtils.send_superuser( bot, f"*****一份好友申请*****\n" - f"ID: {f.id}" + f"ID: {f.id}\n" f"昵称:{nickname}({event.user_id})\n" f"自动同意:{'√' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n" f"日期:{str(datetime.now()).split('.')[0]}\n" diff --git a/zhenxun/models/ban_console.py b/zhenxun/models/ban_console.py index 8b5fbd03..55d7422d 100644 --- a/zhenxun/models/ban_console.py +++ b/zhenxun/models/ban_console.py @@ -1,11 +1,12 @@ import time +from typing import ClassVar from typing_extensions import Self from tortoise import fields from zhenxun.services.db_context import Model from zhenxun.services.log import logger -from zhenxun.utils.enum import CacheType +from zhenxun.utils.enum import CacheType, DbLockType from zhenxun.utils.exception import UserAndGroupIsNone @@ -31,6 +32,9 @@ class BanConsole(Model): unique_together = ("user_id", "group_id") cache_type = CacheType.BAN + """缓存类型""" + enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE] + """开启锁""" @classmethod async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None: diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 56accf66..373e5f68 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -1,4 +1,4 @@ -from typing import Any, overload +from typing import Any, ClassVar, overload from typing_extensions import Self from tortoise import fields @@ -8,7 +8,7 @@ from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.task_info import TaskInfo from zhenxun.services.cache import CacheRoot from zhenxun.services.db_context import Model -from zhenxun.utils.enum import CacheType, PluginType +from zhenxun.utils.enum import CacheType, DbLockType, PluginType class GroupConsole(Model): @@ -53,6 +53,9 @@ class GroupConsole(Model): unique_together = ("group_id", "channel_id") cache_type = CacheType.GROUPS + """缓存类型""" + enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE] + """开启锁""" @staticmethod def format(name: str) -> str: diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index 2ac35b18..359fc5f2 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -1,5 +1,6 @@ +from asyncio import Semaphore from collections.abc import Iterable -from typing import Any +from typing import Any, ClassVar from typing_extensions import Self import nonebot @@ -10,6 +11,7 @@ from tortoise.connection import connections from tortoise.models import Model as TortoiseModel from zhenxun.configs.config import BotConfig +from zhenxun.utils.enum import DbLockType from .cache import CacheRoot from .log import logger @@ -31,16 +33,24 @@ def _(): class Model(TortoiseModel): """ 自动添加模块 - - Args: - Model_: Model """ + sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {} + def __init_subclass__(cls, **kwargs): MODELS.append(cls.__module__) if func := getattr(cls, "_run_script", None): SCRIPT_METHOD.append((cls.__module__, func)) + if enable_lock := getattr(cls, "enable_lock", []): + """创建锁""" + cls.sem_data[cls.__module__] = {} + for lock in enable_lock: + cls.sem_data[cls.__module__][lock] = Semaphore(1) + + @classmethod + def get_semaphore(cls, lock_type: DbLockType): + return cls.sem_data.get(cls.__module__, {}).get(lock_type, None) @classmethod def get_cache_type(cls): @@ -50,10 +60,7 @@ class Model(TortoiseModel): async def create( cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any ) -> Self: - result = await super().create(using_db=using_db, **kwargs) - if cache_type := cls.get_cache_type(): - await CacheRoot.reload(cache_type) - return result + return await super().create(using_db=using_db, **kwargs) @classmethod async def get_or_create( @@ -130,12 +137,25 @@ class Model(TortoiseModel): force_create: bool = False, force_update: bool = False, ): - await super().save( - using_db=using_db, - update_fields=update_fields, - force_create=force_create, - force_update=force_update, - ) + if getattr(self, "id", None) is None: + sem = self.get_semaphore(DbLockType.CREATE) + else: + sem = self.get_semaphore(DbLockType.UPDATE) + if sem: + async with sem: + await super().save( + using_db=using_db, + update_fields=update_fields, + force_create=force_create, + force_update=force_update, + ) + else: + await super().save( + using_db=using_db, + update_fields=update_fields, + force_create=force_create, + force_update=force_update, + ) if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)): await CacheRoot.reload(cache_type) diff --git a/zhenxun/utils/enum.py b/zhenxun/utils/enum.py index 001dae6f..c6d31572 100644 --- a/zhenxun/utils/enum.py +++ b/zhenxun/utils/enum.py @@ -20,6 +20,21 @@ class CacheType(StrEnum): """用户权限""" +class DbLockType(StrEnum): + """ + 锁类型 + """ + + CREATE = "CREATE" + """创建""" + DELETE = "DELETE" + """删除""" + UPDATE = "UPDATE" + """更新""" + QUERY = "QUERY" + """查询""" + + class GoldHandle(StrEnum): """ 金币处理