🎨 优化广播方法

This commit is contained in:
HibiKier 2025-04-13 01:14:00 +08:00
parent 2e087c797d
commit d0093f8a9b

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
import random import random
from typing import Literal from typing import cast
import httpx import httpx
import nonebot import nonebot
@ -468,14 +468,104 @@ class PlatformUtils:
return target return target
class BroadcastEngine:
def __init__(
self,
message: str | UniMessage,
bot: Bot | list[Bot] | None = None,
bot_id: str | set[str] | None = None,
ignore_group: list[str] | None = None,
check_func: Callable[[Bot, str], Awaitable] | None = None,
log_cmd: str | None = None,
platform: str | None = None,
):
if ignore_group is None:
ignore_group = []
self.message = MessageUtils.build_message(message)
self.ignore_group = ignore_group
self.check_func = check_func
self.log_cmd = log_cmd
self.platform = platform
self.bot_list = []
if bot:
self.bot_list = [bot] if isinstance(bot, Bot) else bot
if isinstance(bot_id, str):
bot_id = set(bot_id)
if bot_id:
for i in bot_id:
try:
self.bot_list.append(nonebot.get_bot(i))
except KeyError:
logger.warning(f"Bot:{i} 对象未连接或不存在")
if not self.bot_list:
raise ValueError("当前没有可用的Bot对象...", log_cmd)
async def call_check(self, bot: Bot, group_id: str) -> bool:
"""运行发送检测函数
参数:
bot: Bot
group_id: 群组id
返回:
bool: 是否发送
"""
if not self.check_func:
return True
if is_coroutine_callable(self.check_func):
is_run = await self.check_func(bot, group_id)
else:
is_run = self.check_func(bot, group_id)
return cast(bool, is_run)
async def __send_message(self, bot: Bot, group: GroupConsole):
key = f"{group.group_id}:{group.channel_id}"
if not self.call_check(bot, group.group_id):
return logger.debug(
"广播方法检测运行方法为 False, 已跳过该群组...",
self.log_cmd,
group_id=group.group_id,
)
if target := PlatformUtils.get_target(
group_id=group.group_id,
channel_id=group.channel_id,
):
self.ignore_group.append(key)
await MessageUtils.build_message(self.message).send(target, bot)
logger.debug("广播消息发送成功...", self.log_cmd, target=key)
else:
logger.warning("广播消息获取Target失败...", self.log_cmd, target=key)
async def broadcast(self):
for bot in self.bot_list:
if self.platform and self.platform != PlatformUtils.get_platform(bot):
continue
group_list, _ = await PlatformUtils.get_group_list(bot)
if not group_list:
return
for group in group_list:
if (
group.group_id in self.ignore_group
or group.channel_id in self.ignore_group
):
continue
try:
await self.__send_message(bot, group)
await asyncio.sleep(random.randint(1, 3))
except Exception as e:
logger.warning(
"广播消息发送失败", self.log_cmd, target=group.group_id, e=e
)
async def broadcast_group( async def broadcast_group(
message: str | UniMessage, message: str | UniMessage,
bot: Bot | list[Bot] | None = None, bot: Bot | list[Bot] | None = None,
bot_id: str | set[str] | None = None, bot_id: str | set[str] | None = None,
ignore_group: set[int] | None = None, ignore_group: list[str] = [],
check_func: Callable[[Bot, str], Awaitable] | None = None, check_func: Callable[[Bot, str], Awaitable] | None = None,
log_cmd: str | None = None, log_cmd: str | None = None,
platform: Literal["qq", "dodo", "kaiheila"] | None = None, platform: str | None = None,
): ):
"""获取所有Bot或指定Bot对象广播群聊 """获取所有Bot或指定Bot对象广播群聊
@ -488,80 +578,14 @@ async def broadcast_group(
log_cmd: 日志标记. log_cmd: 日志标记.
platform: 指定平台 platform: 指定平台
""" """
if platform and platform not in ["qq", "dodo", "kaiheila"]: if not message.strip():
raise ValueError("指定平台不支持")
if not message:
raise ValueError("群聊广播消息不能为空") raise ValueError("群聊广播消息不能为空")
bot_dict = nonebot.get_bots() await BroadcastEngine(
bot_list: list[Bot] = [] message=message,
if bot: bot=bot,
if isinstance(bot, list): bot_id=bot_id,
bot_list = bot ignore_group=ignore_group,
else: check_func=check_func,
bot_list.append(bot) log_cmd=log_cmd,
elif bot_id: platform=platform,
_bot_id_list = bot_id ).broadcast()
if isinstance(bot_id, str):
_bot_id_list = [bot_id]
for id_ in _bot_id_list:
if bot_id in bot_dict:
bot_list.append(bot_dict[bot_id])
else:
logger.warning(f"Bot:{id_} 对象未连接或不存在")
else:
bot_list = list(bot_dict.values())
_used_group = []
for _bot in bot_list:
try:
if platform and platform != PlatformUtils.get_platform(_bot):
continue
group_list, _ = await PlatformUtils.get_group_list(_bot)
if group_list:
for group in group_list:
key = f"{group.group_id}:{group.channel_id}"
try:
if (
ignore_group
and (
group.group_id in ignore_group
or group.channel_id in ignore_group
)
) or key in _used_group:
logger.debug(
"广播方法群组重复, 已跳过...",
log_cmd,
group_id=group.group_id,
)
continue
is_run = False
if check_func:
if is_coroutine_callable(check_func):
is_run = await check_func(_bot, group.group_id)
else:
is_run = check_func(_bot, group.group_id)
if not is_run:
logger.debug(
"广播方法检测运行方法为 False, 已跳过...",
log_cmd,
group_id=group.group_id,
)
continue
target = PlatformUtils.get_target(
user_id=None,
group_id=group.group_id,
channel_id=group.channel_id,
)
if target:
_used_group.append(key)
message_list = message
await MessageUtils.build_message(message_list).send(
target, _bot
)
logger.debug("发送成功", log_cmd, target=key)
await asyncio.sleep(random.randint(1, 3))
else:
logger.warning("target为空", log_cmd, target=key)
except Exception as e:
logger.error("发送失败", log_cmd, target=key, e=e)
except Exception as e:
logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e)