数据库添加semaphore锁

This commit is contained in:
HibiKier 2025-02-03 13:56:12 +08:00
parent 4929162dcd
commit f80e541047
9 changed files with 116 additions and 84 deletions

View File

@ -64,44 +64,29 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo):
raise IgnoredException("群黑名单, 群权限-1..") raise IgnoredException("群黑名单, 群权限-1..")
if user_id: if user_id:
ban_result = Config.get_config("hook", "BAN_RESULT") ban_result = Config.get_config("hook", "BAN_RESULT")
try: if await is_ban(user_id, group_id):
if await is_ban(user_id, group_id): time = await BanConsole.check_ban_time(user_id, group_id)
time = await BanConsole.check_ban_time(user_id, group_id) if time == -1:
if time == -1: time_str = ""
time_str = ""
else:
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
if time != -1 and ban_result and _flmt.check(user_id):
_flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "BanChecker")
raise IgnoredException("用户处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"黑名单数据重复过滤该次hook并移除多余数据...", "BanChecker"
)
if group_id:
ids = await BanConsole.filter(
user_id=user_id, group_id=group_id
).values_list("id", flat=True)
else: else:
ids = await BanConsole.filter( time = abs(int(time))
user_id=user_id, group_id__isnull=True if time < 60:
).values_list("id", flat=True) time_str = f"{time!s}"
await BanConsole.filter(id__in=ids[:-1]).delete() else:
await cache.reload() minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
if time != -1 and ban_result and _flmt.check(user_id):
_flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "BanChecker")
raise IgnoredException("用户处于黑名单中...")

View File

@ -103,7 +103,7 @@ group_increase_handle = on_notice(
group_decrease_handle = on_notice( group_decrease_handle = on_notice(
priority=1, priority=1,
block=False, block=False,
rule=notice_rule([GroupMemberDecreaseEvent, GroupMemberIncreaseEvent]), rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]),
) )
"""群员减少处理""" """群员减少处理"""
add_group = on_request(priority=1, block=False) add_group = on_request(priority=1, block=False)

View File

@ -55,14 +55,16 @@ class GroupManager:
for plugin in plugin_list: for plugin in plugin_list:
block_plugin += f"<{plugin.module}," block_plugin += f"<{plugin.module},"
group_info = await bot.get_group_info(group_id=group_id) group_info = await bot.get_group_info(group_id=group_id)
await GroupConsole.create( await GroupConsole.update_or_create(
group_id=group_info["group_id"], group_id=group_info["group_id"],
group_name=group_info["group_name"], defaults={
max_member_count=group_info["max_member_count"], "group_name": group_info["group_name"],
member_count=group_info["member_count"], "max_member_count": group_info["max_member_count"],
group_flag=1, "member_count": group_info["member_count"],
block_plugin=block_plugin, "group_flag": 1,
platform="qq", "block_plugin": block_plugin,
"platform": "qq",
},
) )
@classmethod @classmethod
@ -144,7 +146,7 @@ class GroupManager:
e=e, e=e,
) )
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
await GroupConsole.filter(group_id=group_id).delete() # await GroupConsole.filter(group_id=group_id).delete()
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...") raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
else: else:
await cls.__handle_add_group(bot, group_id, group) await cls.__handle_add_group(bot, group_id, group)

View File

@ -1,4 +1,4 @@
from nonebot.message import run_preprocessor from nonebot import on_message
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.friend_user import FriendUser from zhenxun.models.friend_user import FriendUser
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
@run_preprocessor def rule(session: Uninfo) -> bool:
async def do_something(session: Uninfo): return PlatformUtils.is_qbot(session)
_matcher = on_message(priority=999, block=False, rule=rule)
@_matcher.handle()
async def _(session: Uninfo):
platform = PlatformUtils.get_platform(session) platform = PlatformUtils.get_platform(session)
if session.group: if session.group:
if not await GroupConsole.exists(group_id=session.group.id): if not await GroupConsole.exists(group_id=session.group.id):
await GroupConsole.create(group_id=session.group.id) await GroupConsole.create(group_id=session.group.id)
logger.info("添加当前群组ID信息" "", session=session) logger.info("添加当前群组ID信息", session=session)
await GroupInfoUser.update_or_create(
if not await GroupInfoUser.exists( user_id=session.user.id,
user_id=session.user.id, group_id=session.group.id group_id=session.group.id,
): platform=PlatformUtils.get_platform(session),
await GroupInfoUser.create( )
user_id=session.user.id, group_id=session.group.id, platform=platform
)
logger.info("添加当前用户群组ID信息", "", session=session)
elif not await FriendUser.exists(user_id=session.user.id, platform=platform): elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
try: await FriendUser.create(
await FriendUser.create(user_id=session.user.id, platform=platform) user_id=session.user.id, platform=PlatformUtils.get_platform(session)
logger.info("添加当前好友用户信息", "", session=session) )
except Exception as e: logger.info("添加当前好友用户信息", "", session=session)
logger.error("添加当前好友用户信息失败", session=session, e=e)

View File

@ -104,7 +104,7 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi
await PlatformUtils.send_superuser( await PlatformUtils.send_superuser(
bot, bot,
f"*****一份好友申请*****\n" f"*****一份好友申请*****\n"
f"ID: {f.id}" f"ID: {f.id}\n"
f"昵称:{nickname}({event.user_id})\n" f"昵称:{nickname}({event.user_id})\n"
f"自动同意:{'' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n" f"自动同意:{'' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
f"日期:{str(datetime.now()).split('.')[0]}\n" f"日期:{str(datetime.now()).split('.')[0]}\n"

View File

@ -1,11 +1,12 @@
import time import time
from typing import ClassVar
from typing_extensions import Self from typing_extensions import Self
from tortoise import fields from tortoise import fields
from zhenxun.services.db_context import Model from zhenxun.services.db_context import Model
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType from zhenxun.utils.enum import CacheType, DbLockType
from zhenxun.utils.exception import UserAndGroupIsNone from zhenxun.utils.exception import UserAndGroupIsNone
@ -31,6 +32,9 @@ class BanConsole(Model):
unique_together = ("user_id", "group_id") unique_together = ("user_id", "group_id")
cache_type = CacheType.BAN cache_type = CacheType.BAN
"""缓存类型"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""
@classmethod @classmethod
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None: async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:

View File

@ -1,4 +1,4 @@
from typing import Any, overload from typing import Any, ClassVar, overload
from typing_extensions import Self from typing_extensions import Self
from tortoise import fields from tortoise import fields
@ -8,7 +8,7 @@ from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo from zhenxun.models.task_info import TaskInfo
from zhenxun.services.cache import CacheRoot from zhenxun.services.cache import CacheRoot
from zhenxun.services.db_context import Model from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType, PluginType from zhenxun.utils.enum import CacheType, DbLockType, PluginType
class GroupConsole(Model): class GroupConsole(Model):
@ -53,6 +53,9 @@ class GroupConsole(Model):
unique_together = ("group_id", "channel_id") unique_together = ("group_id", "channel_id")
cache_type = CacheType.GROUPS cache_type = CacheType.GROUPS
"""缓存类型"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""
@staticmethod @staticmethod
def format(name: str) -> str: def format(name: str) -> str:

View File

@ -1,5 +1,6 @@
from asyncio import Semaphore
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any from typing import Any, ClassVar
from typing_extensions import Self from typing_extensions import Self
import nonebot import nonebot
@ -10,6 +11,7 @@ from tortoise.connection import connections
from tortoise.models import Model as TortoiseModel from tortoise.models import Model as TortoiseModel
from zhenxun.configs.config import BotConfig from zhenxun.configs.config import BotConfig
from zhenxun.utils.enum import DbLockType
from .cache import CacheRoot from .cache import CacheRoot
from .log import logger from .log import logger
@ -31,16 +33,24 @@ def _():
class Model(TortoiseModel): class Model(TortoiseModel):
""" """
自动添加模块 自动添加模块
Args:
Model_: Model
""" """
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
MODELS.append(cls.__module__) MODELS.append(cls.__module__)
if func := getattr(cls, "_run_script", None): if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func)) SCRIPT_METHOD.append((cls.__module__, func))
if enable_lock := getattr(cls, "enable_lock", []):
"""创建锁"""
cls.sem_data[cls.__module__] = {}
for lock in enable_lock:
cls.sem_data[cls.__module__][lock] = Semaphore(1)
@classmethod
def get_semaphore(cls, lock_type: DbLockType):
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
@classmethod @classmethod
def get_cache_type(cls): def get_cache_type(cls):
@ -50,10 +60,7 @@ class Model(TortoiseModel):
async def create( async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self: ) -> Self:
result = await super().create(using_db=using_db, **kwargs) return await super().create(using_db=using_db, **kwargs)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
@classmethod @classmethod
async def get_or_create( async def get_or_create(
@ -130,12 +137,25 @@ class Model(TortoiseModel):
force_create: bool = False, force_create: bool = False,
force_update: bool = False, force_update: bool = False,
): ):
await super().save( if getattr(self, "id", None) is None:
using_db=using_db, sem = self.get_semaphore(DbLockType.CREATE)
update_fields=update_fields, else:
force_create=force_create, sem = self.get_semaphore(DbLockType.UPDATE)
force_update=force_update, if sem:
) async with sem:
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
else:
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)): if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
await CacheRoot.reload(cache_type) await CacheRoot.reload(cache_type)

View File

@ -20,6 +20,21 @@ class CacheType(StrEnum):
"""用户权限""" """用户权限"""
class DbLockType(StrEnum):
"""
锁类型
"""
CREATE = "CREATE"
"""创建"""
DELETE = "DELETE"
"""删除"""
UPDATE = "UPDATE"
"""更新"""
QUERY = "QUERY"
"""查询"""
class GoldHandle(StrEnum): class GoldHandle(StrEnum):
""" """
金币处理 金币处理