mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 数据库添加semaphore锁
This commit is contained in:
parent
4929162dcd
commit
f80e541047
@ -64,44 +64,29 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo):
|
|||||||
raise IgnoredException("群黑名单, 群权限-1..")
|
raise IgnoredException("群黑名单, 群权限-1..")
|
||||||
if user_id:
|
if user_id:
|
||||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||||
try:
|
if await is_ban(user_id, group_id):
|
||||||
if await is_ban(user_id, group_id):
|
time = await BanConsole.check_ban_time(user_id, group_id)
|
||||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
if time == -1:
|
||||||
if time == -1:
|
time_str = "∞"
|
||||||
time_str = "∞"
|
|
||||||
else:
|
|
||||||
time = abs(int(time))
|
|
||||||
if time < 60:
|
|
||||||
time_str = f"{time!s} 秒"
|
|
||||||
else:
|
|
||||||
minute = int(time / 60)
|
|
||||||
if minute > 60:
|
|
||||||
hours = minute // 60
|
|
||||||
minute %= 60
|
|
||||||
time_str = f"{hours} 小时 {minute}分钟"
|
|
||||||
else:
|
|
||||||
time_str = f"{minute} 分钟"
|
|
||||||
if time != -1 and ban_result and _flmt.check(user_id):
|
|
||||||
_flmt.start_cd(user_id)
|
|
||||||
await MessageUtils.build_message(
|
|
||||||
[
|
|
||||||
At(flag="user", target=user_id),
|
|
||||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
|
||||||
]
|
|
||||||
).send()
|
|
||||||
logger.debug("用户处于黑名单中...", "BanChecker")
|
|
||||||
raise IgnoredException("用户处于黑名单中...")
|
|
||||||
except MultipleObjectsReturned:
|
|
||||||
logger.warning(
|
|
||||||
"黑名单数据重复,过滤该次hook并移除多余数据...", "BanChecker"
|
|
||||||
)
|
|
||||||
if group_id:
|
|
||||||
ids = await BanConsole.filter(
|
|
||||||
user_id=user_id, group_id=group_id
|
|
||||||
).values_list("id", flat=True)
|
|
||||||
else:
|
else:
|
||||||
ids = await BanConsole.filter(
|
time = abs(int(time))
|
||||||
user_id=user_id, group_id__isnull=True
|
if time < 60:
|
||||||
).values_list("id", flat=True)
|
time_str = f"{time!s} 秒"
|
||||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
else:
|
||||||
await cache.reload()
|
minute = int(time / 60)
|
||||||
|
if minute > 60:
|
||||||
|
hours = minute // 60
|
||||||
|
minute %= 60
|
||||||
|
time_str = f"{hours} 小时 {minute}分钟"
|
||||||
|
else:
|
||||||
|
time_str = f"{minute} 分钟"
|
||||||
|
if time != -1 and ban_result and _flmt.check(user_id):
|
||||||
|
_flmt.start_cd(user_id)
|
||||||
|
await MessageUtils.build_message(
|
||||||
|
[
|
||||||
|
At(flag="user", target=user_id),
|
||||||
|
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||||
|
]
|
||||||
|
).send()
|
||||||
|
logger.debug("用户处于黑名单中...", "BanChecker")
|
||||||
|
raise IgnoredException("用户处于黑名单中...")
|
||||||
|
|||||||
@ -103,7 +103,7 @@ group_increase_handle = on_notice(
|
|||||||
group_decrease_handle = on_notice(
|
group_decrease_handle = on_notice(
|
||||||
priority=1,
|
priority=1,
|
||||||
block=False,
|
block=False,
|
||||||
rule=notice_rule([GroupMemberDecreaseEvent, GroupMemberIncreaseEvent]),
|
rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]),
|
||||||
)
|
)
|
||||||
"""群员减少处理"""
|
"""群员减少处理"""
|
||||||
add_group = on_request(priority=1, block=False)
|
add_group = on_request(priority=1, block=False)
|
||||||
|
|||||||
@ -55,14 +55,16 @@ class GroupManager:
|
|||||||
for plugin in plugin_list:
|
for plugin in plugin_list:
|
||||||
block_plugin += f"<{plugin.module},"
|
block_plugin += f"<{plugin.module},"
|
||||||
group_info = await bot.get_group_info(group_id=group_id)
|
group_info = await bot.get_group_info(group_id=group_id)
|
||||||
await GroupConsole.create(
|
await GroupConsole.update_or_create(
|
||||||
group_id=group_info["group_id"],
|
group_id=group_info["group_id"],
|
||||||
group_name=group_info["group_name"],
|
defaults={
|
||||||
max_member_count=group_info["max_member_count"],
|
"group_name": group_info["group_name"],
|
||||||
member_count=group_info["member_count"],
|
"max_member_count": group_info["max_member_count"],
|
||||||
group_flag=1,
|
"member_count": group_info["member_count"],
|
||||||
block_plugin=block_plugin,
|
"group_flag": 1,
|
||||||
platform="qq",
|
"block_plugin": block_plugin,
|
||||||
|
"platform": "qq",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -144,7 +146,7 @@ class GroupManager:
|
|||||||
e=e,
|
e=e,
|
||||||
)
|
)
|
||||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||||
await GroupConsole.filter(group_id=group_id).delete()
|
# await GroupConsole.filter(group_id=group_id).delete()
|
||||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||||
else:
|
else:
|
||||||
await cls.__handle_add_group(bot, group_id, group)
|
await cls.__handle_add_group(bot, group_id, group)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from nonebot.message import run_preprocessor
|
from nonebot import on_message
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
|
|
||||||
from zhenxun.models.friend_user import FriendUser
|
from zhenxun.models.friend_user import FriendUser
|
||||||
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
|
|||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
|
|
||||||
@run_preprocessor
|
def rule(session: Uninfo) -> bool:
|
||||||
async def do_something(session: Uninfo):
|
return PlatformUtils.is_qbot(session)
|
||||||
|
|
||||||
|
|
||||||
|
_matcher = on_message(priority=999, block=False, rule=rule)
|
||||||
|
|
||||||
|
|
||||||
|
@_matcher.handle()
|
||||||
|
async def _(session: Uninfo):
|
||||||
platform = PlatformUtils.get_platform(session)
|
platform = PlatformUtils.get_platform(session)
|
||||||
if session.group:
|
if session.group:
|
||||||
if not await GroupConsole.exists(group_id=session.group.id):
|
if not await GroupConsole.exists(group_id=session.group.id):
|
||||||
await GroupConsole.create(group_id=session.group.id)
|
await GroupConsole.create(group_id=session.group.id)
|
||||||
logger.info("添加当前群组ID信息" "", session=session)
|
logger.info("添加当前群组ID信息", session=session)
|
||||||
|
await GroupInfoUser.update_or_create(
|
||||||
if not await GroupInfoUser.exists(
|
user_id=session.user.id,
|
||||||
user_id=session.user.id, group_id=session.group.id
|
group_id=session.group.id,
|
||||||
):
|
platform=PlatformUtils.get_platform(session),
|
||||||
await GroupInfoUser.create(
|
)
|
||||||
user_id=session.user.id, group_id=session.group.id, platform=platform
|
|
||||||
)
|
|
||||||
logger.info("添加当前用户群组ID信息", "", session=session)
|
|
||||||
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
||||||
try:
|
await FriendUser.create(
|
||||||
await FriendUser.create(user_id=session.user.id, platform=platform)
|
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||||
logger.info("添加当前好友用户信息", "", session=session)
|
)
|
||||||
except Exception as e:
|
logger.info("添加当前好友用户信息", "", session=session)
|
||||||
logger.error("添加当前好友用户信息失败", session=session, e=e)
|
|
||||||
|
|||||||
@ -104,7 +104,7 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi
|
|||||||
await PlatformUtils.send_superuser(
|
await PlatformUtils.send_superuser(
|
||||||
bot,
|
bot,
|
||||||
f"*****一份好友申请*****\n"
|
f"*****一份好友申请*****\n"
|
||||||
f"ID: {f.id}"
|
f"ID: {f.id}\n"
|
||||||
f"昵称:{nickname}({event.user_id})\n"
|
f"昵称:{nickname}({event.user_id})\n"
|
||||||
f"自动同意:{'√' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
|
f"自动同意:{'√' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
|
||||||
f"日期:{str(datetime.now()).split('.')[0]}\n"
|
f"日期:{str(datetime.now()).split('.')[0]}\n"
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
|
from typing import ClassVar
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType, DbLockType
|
||||||
from zhenxun.utils.exception import UserAndGroupIsNone
|
from zhenxun.utils.exception import UserAndGroupIsNone
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +32,9 @@ class BanConsole(Model):
|
|||||||
unique_together = ("user_id", "group_id")
|
unique_together = ("user_id", "group_id")
|
||||||
|
|
||||||
cache_type = CacheType.BAN
|
cache_type = CacheType.BAN
|
||||||
|
"""缓存类型"""
|
||||||
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||||
|
"""开启锁"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, overload
|
from typing import Any, ClassVar, overload
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
@ -8,7 +8,7 @@ from zhenxun.models.plugin_info import PluginInfo
|
|||||||
from zhenxun.models.task_info import TaskInfo
|
from zhenxun.models.task_info import TaskInfo
|
||||||
from zhenxun.services.cache import CacheRoot
|
from zhenxun.services.cache import CacheRoot
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.utils.enum import CacheType, PluginType
|
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||||
|
|
||||||
|
|
||||||
class GroupConsole(Model):
|
class GroupConsole(Model):
|
||||||
@ -53,6 +53,9 @@ class GroupConsole(Model):
|
|||||||
unique_together = ("group_id", "channel_id")
|
unique_together = ("group_id", "channel_id")
|
||||||
|
|
||||||
cache_type = CacheType.GROUPS
|
cache_type = CacheType.GROUPS
|
||||||
|
"""缓存类型"""
|
||||||
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||||
|
"""开启锁"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format(name: str) -> str:
|
def format(name: str) -> str:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
|
from asyncio import Semaphore
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
import nonebot
|
import nonebot
|
||||||
@ -10,6 +11,7 @@ from tortoise.connection import connections
|
|||||||
from tortoise.models import Model as TortoiseModel
|
from tortoise.models import Model as TortoiseModel
|
||||||
|
|
||||||
from zhenxun.configs.config import BotConfig
|
from zhenxun.configs.config import BotConfig
|
||||||
|
from zhenxun.utils.enum import DbLockType
|
||||||
|
|
||||||
from .cache import CacheRoot
|
from .cache import CacheRoot
|
||||||
from .log import logger
|
from .log import logger
|
||||||
@ -31,16 +33,24 @@ def _():
|
|||||||
class Model(TortoiseModel):
|
class Model(TortoiseModel):
|
||||||
"""
|
"""
|
||||||
自动添加模块
|
自动添加模块
|
||||||
|
|
||||||
Args:
|
|
||||||
Model_: Model
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
MODELS.append(cls.__module__)
|
MODELS.append(cls.__module__)
|
||||||
|
|
||||||
if func := getattr(cls, "_run_script", None):
|
if func := getattr(cls, "_run_script", None):
|
||||||
SCRIPT_METHOD.append((cls.__module__, func))
|
SCRIPT_METHOD.append((cls.__module__, func))
|
||||||
|
if enable_lock := getattr(cls, "enable_lock", []):
|
||||||
|
"""创建锁"""
|
||||||
|
cls.sem_data[cls.__module__] = {}
|
||||||
|
for lock in enable_lock:
|
||||||
|
cls.sem_data[cls.__module__][lock] = Semaphore(1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_semaphore(cls, lock_type: DbLockType):
|
||||||
|
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cache_type(cls):
|
def get_cache_type(cls):
|
||||||
@ -50,10 +60,7 @@ class Model(TortoiseModel):
|
|||||||
async def create(
|
async def create(
|
||||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||||
) -> Self:
|
) -> Self:
|
||||||
result = await super().create(using_db=using_db, **kwargs)
|
return await super().create(using_db=using_db, **kwargs)
|
||||||
if cache_type := cls.get_cache_type():
|
|
||||||
await CacheRoot.reload(cache_type)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_or_create(
|
async def get_or_create(
|
||||||
@ -130,12 +137,25 @@ class Model(TortoiseModel):
|
|||||||
force_create: bool = False,
|
force_create: bool = False,
|
||||||
force_update: bool = False,
|
force_update: bool = False,
|
||||||
):
|
):
|
||||||
await super().save(
|
if getattr(self, "id", None) is None:
|
||||||
using_db=using_db,
|
sem = self.get_semaphore(DbLockType.CREATE)
|
||||||
update_fields=update_fields,
|
else:
|
||||||
force_create=force_create,
|
sem = self.get_semaphore(DbLockType.UPDATE)
|
||||||
force_update=force_update,
|
if sem:
|
||||||
)
|
async with sem:
|
||||||
|
await super().save(
|
||||||
|
using_db=using_db,
|
||||||
|
update_fields=update_fields,
|
||||||
|
force_create=force_create,
|
||||||
|
force_update=force_update,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await super().save(
|
||||||
|
using_db=using_db,
|
||||||
|
update_fields=update_fields,
|
||||||
|
force_create=force_create,
|
||||||
|
force_update=force_update,
|
||||||
|
)
|
||||||
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
||||||
await CacheRoot.reload(cache_type)
|
await CacheRoot.reload(cache_type)
|
||||||
|
|
||||||
|
|||||||
@ -20,6 +20,21 @@ class CacheType(StrEnum):
|
|||||||
"""用户权限"""
|
"""用户权限"""
|
||||||
|
|
||||||
|
|
||||||
|
class DbLockType(StrEnum):
|
||||||
|
"""
|
||||||
|
锁类型
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE = "CREATE"
|
||||||
|
"""创建"""
|
||||||
|
DELETE = "DELETE"
|
||||||
|
"""删除"""
|
||||||
|
UPDATE = "UPDATE"
|
||||||
|
"""更新"""
|
||||||
|
QUERY = "QUERY"
|
||||||
|
"""查询"""
|
||||||
|
|
||||||
|
|
||||||
class GoldHandle(StrEnum):
|
class GoldHandle(StrEnum):
|
||||||
"""
|
"""
|
||||||
金币处理
|
金币处理
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user