2025-07-14 22:35:29 +08:00
|
|
|
import asyncio
|
|
|
|
|
import time
|
|
|
|
|
from typing import ClassVar
|
|
|
|
|
|
|
|
|
|
import nonebot
|
|
|
|
|
from nonebot_plugin_uninfo import Uninfo
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
from zhenxun.models.plugin_info import PluginInfo
|
|
|
|
|
from zhenxun.models.plugin_limit import PluginLimit
|
|
|
|
|
from zhenxun.services.db_context import DB_TIMEOUT_SECONDS
|
|
|
|
|
from zhenxun.services.log import logger
|
|
|
|
|
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
2025-07-15 17:13:33 +08:00
|
|
|
from zhenxun.utils.limiters import CountLimiter, FreqLimiter, UserBlockLimiter
|
2025-07-14 22:35:29 +08:00
|
|
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
|
|
|
|
from zhenxun.utils.message import MessageUtils
|
2025-07-15 17:13:33 +08:00
|
|
|
from zhenxun.utils.time_utils import TimeUtils
|
|
|
|
|
from zhenxun.utils.utils import get_entity_ids
|
2025-07-14 22:35:29 +08:00
|
|
|
|
|
|
|
|
from .config import LOGGER_COMMAND, WARNING_THRESHOLD
|
|
|
|
|
from .exception import SkipPluginException
|
|
|
|
|
|
|
|
|
|
driver = nonebot.get_driver()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@PriorityLifecycle.on_startup(priority=5)
|
|
|
|
|
async def _():
|
|
|
|
|
"""初始化限制"""
|
|
|
|
|
await LimitManager.init_limit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Limit(BaseModel):
|
|
|
|
|
limit: PluginLimit
|
|
|
|
|
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LimitManager:
|
|
|
|
|
add_module: ClassVar[list] = []
|
|
|
|
|
last_update_time: ClassVar[float] = 0
|
|
|
|
|
update_interval: ClassVar[float] = 6000 # 1小时更新一次
|
|
|
|
|
is_updating: ClassVar[bool] = False # 防止并发更新
|
|
|
|
|
|
|
|
|
|
cd_limit: ClassVar[dict[str, Limit]] = {}
|
|
|
|
|
block_limit: ClassVar[dict[str, Limit]] = {}
|
|
|
|
|
count_limit: ClassVar[dict[str, Limit]] = {}
|
|
|
|
|
|
|
|
|
|
# 模块限制缓存,避免频繁查询数据库
|
|
|
|
|
module_limit_cache: ClassVar[dict[str, tuple[float, list[PluginLimit]]]] = {}
|
|
|
|
|
module_cache_ttl: ClassVar[float] = 60 # 模块缓存有效期(秒)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def init_limit(cls):
|
|
|
|
|
"""初始化限制"""
|
|
|
|
|
cls.last_update_time = time.time()
|
|
|
|
|
try:
|
|
|
|
|
await asyncio.wait_for(cls.update_limits(), timeout=DB_TIMEOUT_SECONDS * 2)
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
logger.error("初始化限制超时", LOGGER_COMMAND)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def update_limits(cls):
|
|
|
|
|
"""更新限制信息"""
|
|
|
|
|
# 防止并发更新
|
|
|
|
|
if cls.is_updating:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
cls.is_updating = True
|
|
|
|
|
try:
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
try:
|
|
|
|
|
limit_list = await asyncio.wait_for(
|
|
|
|
|
PluginLimit.filter(status=True).all(), timeout=DB_TIMEOUT_SECONDS
|
|
|
|
|
)
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
logger.error("查询限制信息超时", LOGGER_COMMAND)
|
|
|
|
|
cls.is_updating = False
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 清空旧数据
|
|
|
|
|
cls.add_module = []
|
|
|
|
|
cls.cd_limit = {}
|
|
|
|
|
cls.block_limit = {}
|
|
|
|
|
cls.count_limit = {}
|
|
|
|
|
# 添加新数据
|
|
|
|
|
for limit in limit_list:
|
|
|
|
|
cls.add_limit(limit)
|
|
|
|
|
|
|
|
|
|
cls.last_update_time = time.time()
|
|
|
|
|
elapsed = time.time() - start_time
|
|
|
|
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的更新
|
|
|
|
|
logger.warning(f"更新限制信息耗时: {elapsed:.3f}s", LOGGER_COMMAND)
|
|
|
|
|
finally:
|
|
|
|
|
cls.is_updating = False
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def add_limit(cls, limit: PluginLimit):
|
|
|
|
|
"""添加限制
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
limit: PluginLimit
|
|
|
|
|
"""
|
|
|
|
|
if limit.module not in cls.add_module:
|
|
|
|
|
cls.add_module.append(limit.module)
|
|
|
|
|
if limit.limit_type == PluginLimitType.BLOCK:
|
|
|
|
|
cls.block_limit[limit.module] = Limit(
|
|
|
|
|
limit=limit, limiter=UserBlockLimiter()
|
|
|
|
|
)
|
|
|
|
|
elif limit.limit_type == PluginLimitType.CD:
|
|
|
|
|
cls.cd_limit[limit.module] = Limit(
|
|
|
|
|
limit=limit, limiter=FreqLimiter(limit.cd)
|
|
|
|
|
)
|
|
|
|
|
elif limit.limit_type == PluginLimitType.COUNT:
|
|
|
|
|
cls.count_limit[limit.module] = Limit(
|
|
|
|
|
limit=limit, limiter=CountLimiter(limit.max_count)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def unblock(
|
|
|
|
|
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
|
|
|
|
):
|
|
|
|
|
"""解除插件block
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
module: 模块名
|
|
|
|
|
user_id: 用户id
|
|
|
|
|
group_id: 群组id
|
|
|
|
|
channel_id: 频道id
|
|
|
|
|
"""
|
|
|
|
|
if limit_model := cls.block_limit.get(module):
|
|
|
|
|
limit = limit_model.limit
|
|
|
|
|
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
|
|
|
|
key_type = user_id
|
|
|
|
|
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
|
|
|
|
key_type = channel_id or group_id
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"解除对象: {key_type} 的block限制",
|
|
|
|
|
LOGGER_COMMAND,
|
|
|
|
|
session=user_id,
|
|
|
|
|
group_id=group_id,
|
|
|
|
|
)
|
|
|
|
|
limiter.set_false(key_type)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def get_module_limits(cls, module: str) -> list[PluginLimit]:
|
|
|
|
|
"""获取模块的限制信息,使用缓存减少数据库查询
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
module: 模块名
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
list[PluginLimit]: 限制列表
|
|
|
|
|
"""
|
|
|
|
|
current_time = time.time()
|
|
|
|
|
|
|
|
|
|
# 检查缓存
|
|
|
|
|
if module in cls.module_limit_cache:
|
|
|
|
|
cache_time, limits = cls.module_limit_cache[module]
|
|
|
|
|
if current_time - cache_time < cls.module_cache_ttl:
|
|
|
|
|
return limits
|
|
|
|
|
|
|
|
|
|
# 缓存不存在或已过期,从数据库查询
|
|
|
|
|
try:
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
limits = await asyncio.wait_for(
|
|
|
|
|
PluginLimit.filter(module=module, status=True).all(),
|
|
|
|
|
timeout=DB_TIMEOUT_SECONDS,
|
|
|
|
|
)
|
|
|
|
|
elapsed = time.time() - start_time
|
|
|
|
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的查询
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"查询模块限制信息耗时: {elapsed:.3f}s, 模块: {module}",
|
|
|
|
|
LOGGER_COMMAND,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 更新缓存
|
|
|
|
|
cls.module_limit_cache[module] = (current_time, limits)
|
|
|
|
|
return limits
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
logger.error(f"查询模块限制信息超时: {module}", LOGGER_COMMAND)
|
|
|
|
|
# 超时时返回空列表,避免阻塞
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def check(
|
|
|
|
|
cls,
|
|
|
|
|
module: str,
|
|
|
|
|
user_id: str,
|
|
|
|
|
group_id: str | None,
|
|
|
|
|
channel_id: str | None,
|
|
|
|
|
):
|
|
|
|
|
"""检测限制
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
module: 模块名
|
|
|
|
|
user_id: 用户id
|
|
|
|
|
group_id: 群组id
|
|
|
|
|
channel_id: 频道id
|
|
|
|
|
|
|
|
|
|
异常:
|
|
|
|
|
IgnoredException: IgnoredException
|
|
|
|
|
"""
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
|
|
|
|
|
# 定期更新全局限制信息
|
|
|
|
|
if (
|
|
|
|
|
time.time() - cls.last_update_time > cls.update_interval
|
|
|
|
|
and not cls.is_updating
|
|
|
|
|
):
|
|
|
|
|
# 使用异步任务更新,避免阻塞当前请求
|
|
|
|
|
asyncio.create_task(cls.update_limits()) # noqa: RUF006
|
|
|
|
|
|
|
|
|
|
# 如果模块不在已加载列表中,只加载该模块的限制
|
|
|
|
|
if module not in cls.add_module:
|
|
|
|
|
limits = await cls.get_module_limits(module)
|
|
|
|
|
for limit in limits:
|
|
|
|
|
cls.add_limit(limit)
|
|
|
|
|
|
|
|
|
|
# 检查各种限制
|
|
|
|
|
try:
|
|
|
|
|
if limit_model := cls.cd_limit.get(module):
|
|
|
|
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
|
|
|
|
if limit_model := cls.block_limit.get(module):
|
|
|
|
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
|
|
|
|
if limit_model := cls.count_limit.get(module):
|
|
|
|
|
await cls.__check(limit_model, user_id, group_id, channel_id)
|
|
|
|
|
finally:
|
|
|
|
|
# 记录总执行时间
|
|
|
|
|
elapsed = time.time() - start_time
|
|
|
|
|
if elapsed > WARNING_THRESHOLD: # 记录耗时超过500ms的检查
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"限制检查耗时: {elapsed:.3f}s, 模块: {module}",
|
|
|
|
|
LOGGER_COMMAND,
|
|
|
|
|
session=user_id,
|
|
|
|
|
group_id=group_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def __check(
|
|
|
|
|
cls,
|
|
|
|
|
limit_model: Limit | None,
|
|
|
|
|
user_id: str,
|
|
|
|
|
group_id: str | None,
|
|
|
|
|
channel_id: str | None,
|
|
|
|
|
):
|
|
|
|
|
"""检测限制
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
limit_model: Limit
|
|
|
|
|
user_id: 用户id
|
|
|
|
|
group_id: 群组id
|
|
|
|
|
channel_id: 频道id
|
|
|
|
|
|
|
|
|
|
异常:
|
|
|
|
|
IgnoredException: IgnoredException
|
|
|
|
|
"""
|
|
|
|
|
if not limit_model:
|
|
|
|
|
return
|
|
|
|
|
limit = limit_model.limit
|
|
|
|
|
limiter = limit_model.limiter
|
|
|
|
|
is_limit = (
|
|
|
|
|
LimitWatchType.ALL
|
|
|
|
|
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
|
|
|
|
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
|
|
|
|
)
|
|
|
|
|
key_type = user_id
|
|
|
|
|
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
|
|
|
|
key_type = channel_id or group_id
|
|
|
|
|
if is_limit and not limiter.check(key_type):
|
|
|
|
|
if limit.result:
|
2025-07-15 17:13:33 +08:00
|
|
|
format_kwargs = {}
|
|
|
|
|
if isinstance(limiter, FreqLimiter):
|
|
|
|
|
left_time = limiter.left_time(key_type)
|
|
|
|
|
cd_str = TimeUtils.format_duration(left_time)
|
|
|
|
|
format_kwargs = {"cd": cd_str}
|
2025-07-14 22:35:29 +08:00
|
|
|
try:
|
|
|
|
|
await asyncio.wait_for(
|
2025-07-15 17:13:33 +08:00
|
|
|
MessageUtils.build_message(
|
|
|
|
|
limit.result, format_args=format_kwargs
|
|
|
|
|
).send(),
|
2025-07-14 22:35:29 +08:00
|
|
|
timeout=DB_TIMEOUT_SECONDS,
|
|
|
|
|
)
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
logger.error(f"发送限制消息超时: {limit.module}", LOGGER_COMMAND)
|
|
|
|
|
raise SkipPluginException(
|
|
|
|
|
f"{limit.module}({limit.limit_type}) 正在限制中..."
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
|
|
|
|
LOGGER_COMMAND,
|
|
|
|
|
session=user_id,
|
|
|
|
|
group_id=group_id,
|
|
|
|
|
)
|
|
|
|
|
if isinstance(limiter, FreqLimiter):
|
|
|
|
|
limiter.start_cd(key_type)
|
|
|
|
|
if isinstance(limiter, UserBlockLimiter):
|
|
|
|
|
limiter.set_true(key_type)
|
|
|
|
|
if isinstance(limiter, CountLimiter):
|
|
|
|
|
limiter.increase(key_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
|
|
|
|
"""插件限制
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
plugin: PluginInfo
|
|
|
|
|
session: Uninfo
|
|
|
|
|
"""
|
|
|
|
|
entity = get_entity_ids(session)
|
|
|
|
|
try:
|
|
|
|
|
await asyncio.wait_for(
|
|
|
|
|
LimitManager.check(
|
|
|
|
|
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
|
|
|
|
),
|
|
|
|
|
timeout=DB_TIMEOUT_SECONDS * 2, # 给予更长的超时时间
|
|
|
|
|
)
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
logger.error(f"检查插件限制超时: {plugin.module}", LOGGER_COMMAND)
|
|
|
|
|
# 超时时不抛出异常,允许继续执行
|