数据库添加semaphore锁

This commit is contained in:
HibiKier 2025-02-03 13:56:12 +08:00
parent 4929162dcd
commit f80e541047
9 changed files with 116 additions and 84 deletions

View File

@ -64,44 +64,29 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo):
raise IgnoredException("群黑名单, 群权限-1..")
if user_id:
ban_result = Config.get_config("hook", "BAN_RESULT")
try:
if await is_ban(user_id, group_id):
time = await BanConsole.check_ban_time(user_id, group_id)
if time == -1:
time_str = ""
else:
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
if time != -1 and ban_result and _flmt.check(user_id):
_flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "BanChecker")
raise IgnoredException("用户处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"黑名单数据重复过滤该次hook并移除多余数据...", "BanChecker"
)
if group_id:
ids = await BanConsole.filter(
user_id=user_id, group_id=group_id
).values_list("id", flat=True)
if await is_ban(user_id, group_id):
time = await BanConsole.check_ban_time(user_id, group_id)
if time == -1:
time_str = ""
else:
ids = await BanConsole.filter(
user_id=user_id, group_id__isnull=True
).values_list("id", flat=True)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
time = abs(int(time))
if time < 60:
time_str = f"{time!s}"
else:
minute = int(time / 60)
if minute > 60:
hours = minute // 60
minute %= 60
time_str = f"{hours} 小时 {minute}分钟"
else:
time_str = f"{minute} 分钟"
if time != -1 and ban_result and _flmt.check(user_id):
_flmt.start_cd(user_id)
await MessageUtils.build_message(
[
At(flag="user", target=user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
]
).send()
logger.debug("用户处于黑名单中...", "BanChecker")
raise IgnoredException("用户处于黑名单中...")

View File

@ -103,7 +103,7 @@ group_increase_handle = on_notice(
group_decrease_handle = on_notice(
priority=1,
block=False,
rule=notice_rule([GroupMemberDecreaseEvent, GroupMemberIncreaseEvent]),
rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]),
)
"""群员减少处理"""
add_group = on_request(priority=1, block=False)

View File

@ -55,14 +55,16 @@ class GroupManager:
for plugin in plugin_list:
block_plugin += f"<{plugin.module},"
group_info = await bot.get_group_info(group_id=group_id)
await GroupConsole.create(
await GroupConsole.update_or_create(
group_id=group_info["group_id"],
group_name=group_info["group_name"],
max_member_count=group_info["max_member_count"],
member_count=group_info["member_count"],
group_flag=1,
block_plugin=block_plugin,
platform="qq",
defaults={
"group_name": group_info["group_name"],
"max_member_count": group_info["max_member_count"],
"member_count": group_info["member_count"],
"group_flag": 1,
"block_plugin": block_plugin,
"platform": "qq",
},
)
@classmethod
@ -144,7 +146,7 @@ class GroupManager:
e=e,
)
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
await GroupConsole.filter(group_id=group_id).delete()
# await GroupConsole.filter(group_id=group_id).delete()
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
else:
await cls.__handle_add_group(bot, group_id, group)

View File

@ -1,4 +1,4 @@
from nonebot.message import run_preprocessor
from nonebot import on_message
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.friend_user import FriendUser
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
from zhenxun.utils.platform import PlatformUtils
@run_preprocessor
async def do_something(session: Uninfo):
def rule(session: Uninfo) -> bool:
return PlatformUtils.is_qbot(session)
_matcher = on_message(priority=999, block=False, rule=rule)
@_matcher.handle()
async def _(session: Uninfo):
platform = PlatformUtils.get_platform(session)
if session.group:
if not await GroupConsole.exists(group_id=session.group.id):
await GroupConsole.create(group_id=session.group.id)
logger.info("添加当前群组ID信息" "", session=session)
if not await GroupInfoUser.exists(
user_id=session.user.id, group_id=session.group.id
):
await GroupInfoUser.create(
user_id=session.user.id, group_id=session.group.id, platform=platform
)
logger.info("添加当前用户群组ID信息", "", session=session)
logger.info("添加当前群组ID信息", session=session)
await GroupInfoUser.update_or_create(
user_id=session.user.id,
group_id=session.group.id,
platform=PlatformUtils.get_platform(session),
)
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
try:
await FriendUser.create(user_id=session.user.id, platform=platform)
logger.info("添加当前好友用户信息", "", session=session)
except Exception as e:
logger.error("添加当前好友用户信息失败", session=session, e=e)
await FriendUser.create(
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
)
logger.info("添加当前好友用户信息", "", session=session)

View File

@ -104,7 +104,7 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi
await PlatformUtils.send_superuser(
bot,
f"*****一份好友申请*****\n"
f"ID: {f.id}"
f"ID: {f.id}\n"
f"昵称:{nickname}({event.user_id})\n"
f"自动同意:{'' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
f"日期:{str(datetime.now()).split('.')[0]}\n"

View File

@ -1,11 +1,12 @@
import time
from typing import ClassVar
from typing_extensions import Self
from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
from zhenxun.utils.enum import CacheType, DbLockType
from zhenxun.utils.exception import UserAndGroupIsNone
@ -31,6 +32,9 @@ class BanConsole(Model):
unique_together = ("user_id", "group_id")
cache_type = CacheType.BAN
"""缓存类型"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""
@classmethod
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:

View File

@ -1,4 +1,4 @@
from typing import Any, overload
from typing import Any, ClassVar, overload
from typing_extensions import Self
from tortoise import fields
@ -8,7 +8,7 @@ from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.cache import CacheRoot
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType, PluginType
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
class GroupConsole(Model):
@ -53,6 +53,9 @@ class GroupConsole(Model):
unique_together = ("group_id", "channel_id")
cache_type = CacheType.GROUPS
"""缓存类型"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""
@staticmethod
def format(name: str) -> str:

View File

@ -1,5 +1,6 @@
from asyncio import Semaphore
from collections.abc import Iterable
from typing import Any
from typing import Any, ClassVar
from typing_extensions import Self
import nonebot
@ -10,6 +11,7 @@ from tortoise.connection import connections
from tortoise.models import Model as TortoiseModel
from zhenxun.configs.config import BotConfig
from zhenxun.utils.enum import DbLockType
from .cache import CacheRoot
from .log import logger
@ -31,16 +33,24 @@ def _():
class Model(TortoiseModel):
"""
自动添加模块
Args:
Model_: Model
"""
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
def __init_subclass__(cls, **kwargs):
MODELS.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func))
if enable_lock := getattr(cls, "enable_lock", []):
"""创建锁"""
cls.sem_data[cls.__module__] = {}
for lock in enable_lock:
cls.sem_data[cls.__module__][lock] = Semaphore(1)
@classmethod
def get_semaphore(cls, lock_type: DbLockType):
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
@classmethod
def get_cache_type(cls):
@ -50,10 +60,7 @@ class Model(TortoiseModel):
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self:
result = await super().create(using_db=using_db, **kwargs)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
return await super().create(using_db=using_db, **kwargs)
@classmethod
async def get_or_create(
@ -130,12 +137,25 @@ class Model(TortoiseModel):
force_create: bool = False,
force_update: bool = False,
):
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
if getattr(self, "id", None) is None:
sem = self.get_semaphore(DbLockType.CREATE)
else:
sem = self.get_semaphore(DbLockType.UPDATE)
if sem:
async with sem:
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
else:
await super().save(
using_db=using_db,
update_fields=update_fields,
force_create=force_create,
force_update=force_update,
)
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
await CacheRoot.reload(cache_type)

View File

@ -20,6 +20,21 @@ class CacheType(StrEnum):
"""用户权限"""
class DbLockType(StrEnum):
"""
锁类型
"""
CREATE = "CREATE"
"""创建"""
DELETE = "DELETE"
"""删除"""
UPDATE = "UPDATE"
"""更新"""
QUERY = "QUERY"
"""查询"""
class GoldHandle(StrEnum):
"""
金币处理