mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
✨ 数据库添加semaphore锁
This commit is contained in:
parent
4929162dcd
commit
f80e541047
@ -64,44 +64,29 @@ async def _(matcher: Matcher, bot: Bot, session: Uninfo):
|
||||
raise IgnoredException("群黑名单, 群权限-1..")
|
||||
if user_id:
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
try:
|
||||
if await is_ban(user_id, group_id):
|
||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
||||
if time == -1:
|
||||
time_str = "∞"
|
||||
else:
|
||||
time = abs(int(time))
|
||||
if time < 60:
|
||||
time_str = f"{time!s} 秒"
|
||||
else:
|
||||
minute = int(time / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
if time != -1 and ban_result and _flmt.check(user_id):
|
||||
_flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
]
|
||||
).send()
|
||||
logger.debug("用户处于黑名单中...", "BanChecker")
|
||||
raise IgnoredException("用户处于黑名单中...")
|
||||
except MultipleObjectsReturned:
|
||||
logger.warning(
|
||||
"黑名单数据重复,过滤该次hook并移除多余数据...", "BanChecker"
|
||||
)
|
||||
if group_id:
|
||||
ids = await BanConsole.filter(
|
||||
user_id=user_id, group_id=group_id
|
||||
).values_list("id", flat=True)
|
||||
if await is_ban(user_id, group_id):
|
||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
||||
if time == -1:
|
||||
time_str = "∞"
|
||||
else:
|
||||
ids = await BanConsole.filter(
|
||||
user_id=user_id, group_id__isnull=True
|
||||
).values_list("id", flat=True)
|
||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
||||
await cache.reload()
|
||||
time = abs(int(time))
|
||||
if time < 60:
|
||||
time_str = f"{time!s} 秒"
|
||||
else:
|
||||
minute = int(time / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
if time != -1 and ban_result and _flmt.check(user_id):
|
||||
_flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
]
|
||||
).send()
|
||||
logger.debug("用户处于黑名单中...", "BanChecker")
|
||||
raise IgnoredException("用户处于黑名单中...")
|
||||
|
||||
@ -103,7 +103,7 @@ group_increase_handle = on_notice(
|
||||
group_decrease_handle = on_notice(
|
||||
priority=1,
|
||||
block=False,
|
||||
rule=notice_rule([GroupMemberDecreaseEvent, GroupMemberIncreaseEvent]),
|
||||
rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]),
|
||||
)
|
||||
"""群员减少处理"""
|
||||
add_group = on_request(priority=1, block=False)
|
||||
|
||||
@ -55,14 +55,16 @@ class GroupManager:
|
||||
for plugin in plugin_list:
|
||||
block_plugin += f"<{plugin.module},"
|
||||
group_info = await bot.get_group_info(group_id=group_id)
|
||||
await GroupConsole.create(
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=group_info["group_id"],
|
||||
group_name=group_info["group_name"],
|
||||
max_member_count=group_info["max_member_count"],
|
||||
member_count=group_info["member_count"],
|
||||
group_flag=1,
|
||||
block_plugin=block_plugin,
|
||||
platform="qq",
|
||||
defaults={
|
||||
"group_name": group_info["group_name"],
|
||||
"max_member_count": group_info["max_member_count"],
|
||||
"member_count": group_info["member_count"],
|
||||
"group_flag": 1,
|
||||
"block_plugin": block_plugin,
|
||||
"platform": "qq",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -144,7 +146,7 @@ class GroupManager:
|
||||
e=e,
|
||||
)
|
||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||
await GroupConsole.filter(group_id=group_id).delete()
|
||||
# await GroupConsole.filter(group_id=group_id).delete()
|
||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||
else:
|
||||
await cls.__handle_add_group(bot, group_id, group)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot import on_message
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.friend_user import FriendUser
|
||||
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
@run_preprocessor
|
||||
async def do_something(session: Uninfo):
|
||||
def rule(session: Uninfo) -> bool:
|
||||
return PlatformUtils.is_qbot(session)
|
||||
|
||||
|
||||
_matcher = on_message(priority=999, block=False, rule=rule)
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(session: Uninfo):
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
if session.group:
|
||||
if not await GroupConsole.exists(group_id=session.group.id):
|
||||
await GroupConsole.create(group_id=session.group.id)
|
||||
logger.info("添加当前群组ID信息" "", session=session)
|
||||
|
||||
if not await GroupInfoUser.exists(
|
||||
user_id=session.user.id, group_id=session.group.id
|
||||
):
|
||||
await GroupInfoUser.create(
|
||||
user_id=session.user.id, group_id=session.group.id, platform=platform
|
||||
)
|
||||
logger.info("添加当前用户群组ID信息", "", session=session)
|
||||
logger.info("添加当前群组ID信息", session=session)
|
||||
await GroupInfoUser.update_or_create(
|
||||
user_id=session.user.id,
|
||||
group_id=session.group.id,
|
||||
platform=PlatformUtils.get_platform(session),
|
||||
)
|
||||
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
||||
try:
|
||||
await FriendUser.create(user_id=session.user.id, platform=platform)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
except Exception as e:
|
||||
logger.error("添加当前好友用户信息失败", session=session, e=e)
|
||||
await FriendUser.create(
|
||||
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||
)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
|
||||
@ -104,7 +104,7 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi
|
||||
await PlatformUtils.send_superuser(
|
||||
bot,
|
||||
f"*****一份好友申请*****\n"
|
||||
f"ID: {f.id}"
|
||||
f"ID: {f.id}\n"
|
||||
f"昵称:{nickname}({event.user_id})\n"
|
||||
f"自动同意:{'√' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
|
||||
f"日期:{str(datetime.now()).split('.')[0]}\n"
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import time
|
||||
from typing import ClassVar
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType
|
||||
from zhenxun.utils.enum import CacheType, DbLockType
|
||||
from zhenxun.utils.exception import UserAndGroupIsNone
|
||||
|
||||
|
||||
@ -31,6 +32,9 @@ class BanConsole(Model):
|
||||
unique_together = ("user_id", "group_id")
|
||||
|
||||
cache_type = CacheType.BAN
|
||||
"""缓存类型"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, overload
|
||||
from typing import Any, ClassVar, overload
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
@ -8,7 +8,7 @@ from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType, PluginType
|
||||
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||
|
||||
|
||||
class GroupConsole(Model):
|
||||
@ -53,6 +53,9 @@ class GroupConsole(Model):
|
||||
unique_together = ("group_id", "channel_id")
|
||||
|
||||
cache_type = CacheType.GROUPS
|
||||
"""缓存类型"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||
"""开启锁"""
|
||||
|
||||
@staticmethod
|
||||
def format(name: str) -> str:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from asyncio import Semaphore
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
from typing_extensions import Self
|
||||
|
||||
import nonebot
|
||||
@ -10,6 +11,7 @@ from tortoise.connection import connections
|
||||
from tortoise.models import Model as TortoiseModel
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.utils.enum import DbLockType
|
||||
|
||||
from .cache import CacheRoot
|
||||
from .log import logger
|
||||
@ -31,16 +33,24 @@ def _():
|
||||
class Model(TortoiseModel):
|
||||
"""
|
||||
自动添加模块
|
||||
|
||||
Args:
|
||||
Model_: Model
|
||||
"""
|
||||
|
||||
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
MODELS.append(cls.__module__)
|
||||
|
||||
if func := getattr(cls, "_run_script", None):
|
||||
SCRIPT_METHOD.append((cls.__module__, func))
|
||||
if enable_lock := getattr(cls, "enable_lock", []):
|
||||
"""创建锁"""
|
||||
cls.sem_data[cls.__module__] = {}
|
||||
for lock in enable_lock:
|
||||
cls.sem_data[cls.__module__][lock] = Semaphore(1)
|
||||
|
||||
@classmethod
|
||||
def get_semaphore(cls, lock_type: DbLockType):
|
||||
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
|
||||
|
||||
@classmethod
|
||||
def get_cache_type(cls):
|
||||
@ -50,10 +60,7 @@ class Model(TortoiseModel):
|
||||
async def create(
|
||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||
) -> Self:
|
||||
result = await super().create(using_db=using_db, **kwargs)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.reload(cache_type)
|
||||
return result
|
||||
return await super().create(using_db=using_db, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def get_or_create(
|
||||
@ -130,12 +137,25 @@ class Model(TortoiseModel):
|
||||
force_create: bool = False,
|
||||
force_update: bool = False,
|
||||
):
|
||||
await super().save(
|
||||
using_db=using_db,
|
||||
update_fields=update_fields,
|
||||
force_create=force_create,
|
||||
force_update=force_update,
|
||||
)
|
||||
if getattr(self, "id", None) is None:
|
||||
sem = self.get_semaphore(DbLockType.CREATE)
|
||||
else:
|
||||
sem = self.get_semaphore(DbLockType.UPDATE)
|
||||
if sem:
|
||||
async with sem:
|
||||
await super().save(
|
||||
using_db=using_db,
|
||||
update_fields=update_fields,
|
||||
force_create=force_create,
|
||||
force_update=force_update,
|
||||
)
|
||||
else:
|
||||
await super().save(
|
||||
using_db=using_db,
|
||||
update_fields=update_fields,
|
||||
force_create=force_create,
|
||||
force_update=force_update,
|
||||
)
|
||||
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
||||
await CacheRoot.reload(cache_type)
|
||||
|
||||
|
||||
@ -20,6 +20,21 @@ class CacheType(StrEnum):
|
||||
"""用户权限"""
|
||||
|
||||
|
||||
class DbLockType(StrEnum):
|
||||
"""
|
||||
锁类型
|
||||
"""
|
||||
|
||||
CREATE = "CREATE"
|
||||
"""创建"""
|
||||
DELETE = "DELETE"
|
||||
"""删除"""
|
||||
UPDATE = "UPDATE"
|
||||
"""更新"""
|
||||
QUERY = "QUERY"
|
||||
"""查询"""
|
||||
|
||||
|
||||
class GoldHandle(StrEnum):
|
||||
"""
|
||||
金币处理
|
||||
|
||||
Loading…
Reference in New Issue
Block a user