mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 06:12:53 +08:00
- 【新功能】 - 新增标签定向广播功能,支持通过 `-t <标签名>` 或 `广播到 <标签名>` 命令向指定标签的群组发送消息 - 引入广播强制发送模式,允许绕过群组的任务阻断设置 - 实现广播并发控制,通过配置限制同时发送任务数量,避免API速率限制 - 优化视频消息处理,支持从URL下载视频内容并作为原始数据发送,提高跨平台兼容性 - 【配置】 - 添加 `DEFAULT_BROADCAST` 配置项,用于设置群组进群时广播功能的默认开关状态 - 添加 `BROADCAST_CONCURRENCY_LIMIT` 配置项,用于控制广播时的最大并发任务数
560 lines
20 KiB
Python
560 lines
20 KiB
Python
import asyncio
|
||
import random
|
||
import traceback
|
||
from typing import ClassVar
|
||
|
||
from nonebot.adapters import Bot
|
||
from nonebot.adapters.onebot.v11 import Bot as V11Bot
|
||
from nonebot.exception import ActionFailed, AdapterException
|
||
from nonebot_plugin_alconna import UniMessage
|
||
from nonebot_plugin_alconna.uniseg import Receipt, Reference
|
||
from nonebot_plugin_session import EventSession
|
||
|
||
from zhenxun.configs.config import Config
|
||
from zhenxun.models.group_console import GroupConsole
|
||
from zhenxun.services.log import logger
|
||
from zhenxun.utils.common_utils import CommonUtils
|
||
from zhenxun.utils.platform import PlatformUtils
|
||
|
||
from .models import BroadcastDetailResult, BroadcastResult
|
||
from .utils import custom_nodes_to_v11_nodes, uni_message_to_v11_list_of_dicts
|
||
|
||
# Lower/upper bounds (seconds) unpacked into random.uniform() to choose the
# pause taken after each successful broadcast send.
BROADCAST_SEND_DELAY_RANGE: tuple[int, int] = (1, 3)
|
||
|
||
|
||
class BroadcastManager:
    """Broadcast manager.

    Sends a message to many groups with bounded concurrency, records the
    message ID returned for each group so the last broadcast can be recalled,
    and offers small logging helpers tagged with the "广播" module name.
    """

    # Group key ("group_id" or "group_id:channel_id") -> message ID of the most
    # recent broadcast recorded for that group. Shared class-level state.
    _last_broadcast_msg_ids: ClassVar[dict[str, int]] = {}

    # Fallback used when a concurrency limit cannot be parsed as an int.
    _DEFAULT_CONCURRENCY_LIMIT: ClassVar[int] = 10

    @staticmethod
    def _get_session_info(session: EventSession | None) -> str:
        """Return a "[platform:session]" log prefix, or "" when there is no session."""
        if not session:
            return ""

        try:
            platform = getattr(session, "platform", "unknown")
            session_id = str(session)
            return f"[{platform}:{session_id}]"
        except Exception:
            # Formatting a log prefix must never raise.
            return "[session-info-error]"

    @staticmethod
    def log_error(
        message: str, error: Exception, session: EventSession | None = None, **kwargs
    ):
        """Log an error with its type, message and traceback.

        The traceback is taken from ``error.__traceback__`` instead of
        ``traceback.format_exc()``, so it is correct even when this is called
        outside the ``except`` block that caught ``error``.
        """
        session_info = BroadcastManager._get_session_info(session)
        error_type = type(error).__name__
        stack_trace = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        error_details = f"\n类型: {error_type}\n信息: {error!s}\n堆栈: {stack_trace}"

        logger.error(
            f"{session_info} {message}{error_details}", "广播", e=error, **kwargs
        )

    @staticmethod
    def log_warning(message: str, session: EventSession | None = None, **kwargs):
        """Log a warning prefixed with the session info."""
        session_info = BroadcastManager._get_session_info(session)
        logger.warning(f"{session_info} {message}", "广播", **kwargs)

    @staticmethod
    def log_info(message: str, session: EventSession | None = None, **kwargs):
        """Log an info message prefixed with the session info."""
        session_info = BroadcastManager._get_session_info(session)
        logger.info(f"{session_info} {message}", "广播", **kwargs)

    @classmethod
    def get_last_broadcast_msg_ids(cls) -> dict[str, int]:
        """Return a copy of the recorded message IDs of the last broadcast."""
        return cls._last_broadcast_msg_ids.copy()

    @classmethod
    def clear_last_broadcast_msg_ids(cls) -> None:
        """Forget all recorded broadcast message IDs."""
        cls._last_broadcast_msg_ids.clear()

    @classmethod
    def _normalize_concurrency_limit(cls, value: int | str | None) -> int:
        """Coerce a concurrency limit to an int >= 1.

        ``asyncio.Semaphore(0)`` would block every send forever and a negative
        value raises ``ValueError``, so values below 1 are clamped to 1. Values
        that cannot be converted with ``int()`` (e.g. ``None``) fall back to 10.
        """
        try:
            limit = int(value)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return cls._DEFAULT_CONCURRENCY_LIMIT
        return max(1, limit)

    @staticmethod
    def _group_key(group: GroupConsole) -> str:
        """Key used for a group in the message-ID record and in log lines."""
        if group.channel_id:
            return f"{group.group_id}:{group.channel_id}"
        return str(group.group_id)

    @classmethod
    async def get_all_groups(cls, bot: Bot) -> tuple[list[GroupConsole], str]:
        """Return the bot's group list via ``PlatformUtils.get_group_list``."""
        return await PlatformUtils.get_group_list(bot)

    @classmethod
    async def send(
        cls, bot: Bot, message: UniMessage, session: EventSession
    ) -> BroadcastResult:
        """Broadcast ``message`` to every group returned by ``get_all_groups``.

        Clears the previously recorded message IDs, reads the
        ``_task.BROADCAST_CONCURRENCY_LIMIT`` config value (default 10) and
        delegates to ``send_to_specific_groups``.
        """
        logger.debug(
            f"开始广播(send - 广播到所有群组),Bot ID: {bot.self_id}",
            "广播",
            session=session,
        )

        logger.debug("清空上一次的广播消息ID记录", "广播", session=session)
        cls.clear_last_broadcast_msg_ids()

        concurrency_limit = cls._normalize_concurrency_limit(
            Config.get_config(
                "_task",
                "BROADCAST_CONCURRENCY_LIMIT",
                10,
            )
        )

        all_groups, _ = await cls.get_all_groups(bot)
        return await cls.send_to_specific_groups(
            bot, message, all_groups, session, concurrency_limit=concurrency_limit
        )

    @staticmethod
    def _build_v11_forward_nodes(
        bot: Bot, message: UniMessage, log_session: EventSession | str
    ) -> list[dict]:
        """Convert a broadcast message into OneBot V11 forward nodes.

        A message consisting of exactly one ``Reference`` with nodes is converted
        node by node. Any other mix of segments is flattened into a single node
        authored by the bot with nickname "广播". Returns ``[]`` if the flattened
        content is empty.
        """
        if (
            len(message) == 1
            and isinstance(message[0], Reference)
            and getattr(message[0], "nodes", None)
        ):
            nodes_list = getattr(message[0], "nodes", [])
            v11_nodes = custom_nodes_to_v11_nodes(nodes_list)
            logger.debug(
                f"从 UniMessage<Reference> 构造转发节点数: {len(v11_nodes)}",
                "广播",
                session=log_session,
            )
            return v11_nodes

        logger.warning(
            "广播消息包含合并转发段和其他段,将尝试打平成一个节点发送",
            "广播",
            session=log_session,
        )
        v11_content_list = uni_message_to_v11_list_of_dicts(message)
        if not v11_content_list:
            return []
        return [
            {
                "type": "node",
                "data": {
                    "user_id": bot.self_id,
                    "nickname": "广播",
                    "content": v11_content_list,
                },
            }
        ]

    @classmethod
    async def send_to_specific_groups(
        cls,
        bot: Bot,
        message: UniMessage,
        target_groups: list[GroupConsole],
        session_info: EventSession | str | None = None,
        force_send: bool = False,
        concurrency_limit: int = 10,
    ) -> BroadcastResult:
        """Broadcast ``message`` to ``target_groups``.

        On the "qq" platform with a OneBot V11 bot, messages containing a
        ``Reference`` with nodes are sent as merged-forward messages. Everything
        else is sent as a normal message.

        Args:
            bot: Bot used to send.
            message: Message to broadcast.
            target_groups: Groups to send to.
            session_info: Session (or label) used for logging; defaults to the
                bot's ``self_id``.
            force_send: When True, skip the per-group "broadcast" task-block check.
            concurrency_limit: Maximum number of concurrent sends; values below
                1 are clamped to 1.

        Returns:
            ``(success_count, error_count)``. Skipped groups are in neither count.
        """
        log_session = session_info or bot.self_id
        concurrency_limit = cls._normalize_concurrency_limit(concurrency_limit)
        target_count = len(target_groups)
        log_message = (
            f"开始广播,目标 {target_count} 个群组 (并发数: {concurrency_limit}),"
            f"Bot ID: {bot.self_id}, ForceSend: {force_send}"
        )
        logger.info(log_message, "广播", session=log_session)

        if not target_groups:
            logger.debug("目标群组列表为空,广播结束", "广播", session=log_session)
            return 0, 0

        platform = PlatformUtils.get_platform(bot)
        is_forward_broadcast = any(
            isinstance(seg, Reference) and getattr(seg, "nodes", None)
            for seg in message
        )

        if platform == "qq" and isinstance(bot, V11Bot) and is_forward_broadcast:
            v11_nodes = cls._build_v11_forward_nodes(bot, message, log_session)
            if not v11_nodes:
                logger.warning(
                    "构造出的 V11 合并转发节点为空,无法发送",
                    "广播",
                    session=log_session,
                )
                return 0, target_count
            success_count, error_count, skip_count = await cls._broadcast_forward(
                bot,
                log_session,
                target_groups,
                v11_nodes,
                force_send,
                concurrency_limit,
            )
        else:
            if is_forward_broadcast:
                logger.warning(
                    f"合并转发消息在适配器 ({platform}) 不支持,将作为普通消息发送",
                    "广播",
                    session=log_session,
                )
            success_count, error_count, skip_count = await cls._broadcast_normal(
                bot,
                log_session,
                target_groups,
                message,
                force_send,
                concurrency_limit,
            )

        stats = f"成功: {success_count}, 失败: {error_count}"
        stats += f", 跳过: {skip_count}, 总计: {target_count}"
        logger.debug(
            f"广播统计 - {stats}",
            "广播",
            session=log_session,
        )

        msg_ids = cls.get_last_broadcast_msg_ids()
        if msg_ids:
            id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids.items()])
            logger.debug(
                f"广播结束,记录了 {len(msg_ids)} 条消息ID: {id_list_str}",
                "广播",
                session=log_session,
            )
        else:
            logger.warning(
                "广播结束,但没有记录任何消息ID",
                "广播",
                session=log_session,
            )

        return success_count, error_count

    @classmethod
    async def _extract_message_id_from_result(
        cls,
        result: dict | Receipt,
        group_key: str,
        session_info: EventSession | str,
        msg_type: str = "普通",
    ) -> None:
        """Record the message ID found in a send result under ``group_key``.

        Handles a dict carrying ``"message_id"`` (the V11 forward API result) and
        a ``Receipt`` whose first ``msg_ids`` entry is either a dict with
        ``"message_id"`` or a raw int/str. IDs that cannot be converted to int are
        logged and not recorded.
        """
        if isinstance(result, dict) and "message_id" in result:
            msg_id = result["message_id"]
            try:
                msg_id_int = int(msg_id)
                cls._last_broadcast_msg_ids[group_key] = msg_id_int
                logger.debug(
                    f"记录群 {group_key} 的{msg_type}消息ID: {msg_id_int}",
                    "广播",
                    session=session_info,
                )
            except (ValueError, TypeError):
                logger.warning(
                    f"{msg_type}结果中的 message_id 不是有效整数: {msg_id}",
                    "广播",
                    session=session_info,
                )
        elif isinstance(result, Receipt) and result.msg_ids:
            try:
                first_id_info = result.msg_ids[0]
                msg_id = None
                if isinstance(first_id_info, dict) and "message_id" in first_id_info:
                    msg_id = first_id_info["message_id"]
                    logger.debug(
                        f"从 Receipt.msg_ids[0] 提取到 ID: {msg_id}",
                        "广播",
                        session=session_info,
                    )
                elif isinstance(first_id_info, int | str):
                    msg_id = first_id_info
                    logger.debug(
                        f"从 Receipt.msg_ids[0] 提取到原始ID: {msg_id}",
                        "广播",
                        session=session_info,
                    )

                if msg_id is not None:
                    try:
                        msg_id_int = int(msg_id)
                        cls._last_broadcast_msg_ids[group_key] = msg_id_int
                        logger.debug(
                            f"记录群 {group_key} 的消息ID: {msg_id_int}",
                            "广播",
                            session=session_info,
                        )
                    except (ValueError, TypeError):
                        logger.warning(
                            f"提取的ID ({msg_id}) 不是有效整数",
                            "广播",
                            session=session_info,
                        )
                else:
                    info_str = str(first_id_info)
                    logger.warning(
                        f"无法从 Receipt.msg_ids[0] 提取ID: {info_str}",
                        "广播",
                        session=session_info,
                    )
            except IndexError:
                logger.warning("Receipt.msg_ids 为空", "广播", session=session_info)
            except Exception as e_extract:
                logger.error(
                    f"从 Receipt 提取 msg_id 时出错: {e_extract}",
                    "广播",
                    session=session_info,
                    e=e_extract,
                )
        else:
            logger.warning(
                f"发送成功但无法从结果获取消息 ID. 结果: {result}",
                "广播",
                session=session_info,
            )

    @classmethod
    async def _check_group_availability(
        cls, bot: Bot, group: GroupConsole, force_send: bool = False
    ) -> bool:
        """Return whether ``group`` should receive the broadcast.

        Groups without a ``group_id`` are always rejected. Otherwise
        ``force_send`` bypasses the "broadcast" task-block check.
        """
        if not group.group_id:
            return False
        if force_send:
            return True
        return not await CommonUtils.task_is_block(bot, "broadcast", group.group_id)

    @classmethod
    async def _gather_sends(
        cls,
        bot: Bot,
        session_info: EventSession | str,
        group_list: list[GroupConsole],
        force_send: bool,
        send_to_group,
    ) -> BroadcastDetailResult:
        """Run ``send_to_group`` for every available group and tally the outcome.

        ``send_to_group`` is an async callable taking a ``GroupConsole`` that
        signals failure by raising. Groups rejected by
        ``_check_group_availability`` are counted as skipped.

        Returns:
            ``(success_count, error_count, skip_count)``.
        """
        tasks: list[asyncio.Task] = []
        skip_count = 0
        for group in group_list:
            if await cls._check_group_availability(bot, group, force_send):
                # Start immediately; the caller's semaphore bounds concurrency.
                tasks.append(asyncio.create_task(send_to_group(group)))
            else:
                skip_count += 1

        if skip_count:
            logger.info(
                f"跳过 {skip_count} 个不符合条件的群组",
                "广播",
                session=session_info,
            )

        if not tasks:
            return 0, 0, skip_count

        results = await asyncio.gather(*tasks, return_exceptions=True)
        # With return_exceptions=True a cancelled task yields CancelledError,
        # which is a BaseException (not an Exception) since Python 3.8, so
        # test against BaseException to count it as a failure.
        success_count = sum(
            1 for result in results if not isinstance(result, BaseException)
        )
        return success_count, len(results) - success_count, skip_count

    @classmethod
    async def _broadcast_forward(
        cls,
        bot: V11Bot,
        session_info: EventSession | str,
        group_list: list[GroupConsole],
        v11_nodes: list[dict],
        force_send: bool = False,
        concurrency_limit: int = 10,
    ) -> BroadcastDetailResult:
        """Send ``v11_nodes`` as a merged-forward message to each group.

        Returns ``(success_count, error_count, skip_count)``.
        """
        semaphore = asyncio.Semaphore(
            cls._normalize_concurrency_limit(concurrency_limit)
        )
        msg_id_lock = asyncio.Lock()

        async def send_to_group(group: GroupConsole) -> GroupConsole:
            group_key = cls._group_key(group)
            async with semaphore:
                try:
                    result = await bot.send_group_forward_msg(
                        group_id=int(group.group_id), messages=v11_nodes
                    )
                    async with msg_id_lock:
                        await cls._extract_message_id_from_result(
                            result, group_key, session_info, "合并转发"
                        )
                    # The pause happens while holding the semaphore slot, which
                    # throttles the overall send rate.
                    await asyncio.sleep(random.uniform(*BROADCAST_SEND_DELAY_RANGE))
                    return group
                except (ActionFailed, AdapterException) as ae:
                    logger.error(
                        f"发送失败(合并转发) to {group_key}: {ae}",
                        "广播",
                        session=session_info,
                        e=ae,
                    )
                    raise
                except Exception as e:
                    logger.error(
                        f"发送失败(合并转发) to {group_key}: {e}",
                        "广播",
                        session=session_info,
                        e=e,
                    )
                    raise

        return await cls._gather_sends(
            bot, session_info, group_list, force_send, send_to_group
        )

    @classmethod
    async def _broadcast_normal(
        cls,
        bot: Bot,
        session_info: EventSession | str,
        group_list: list[GroupConsole],
        message: UniMessage,
        force_send: bool = False,
        concurrency_limit: int = 10,
    ) -> BroadcastDetailResult:
        """Send ``message`` as a normal message to each group.

        Returns ``(success_count, error_count, skip_count)``.
        """
        semaphore = asyncio.Semaphore(
            cls._normalize_concurrency_limit(concurrency_limit)
        )
        msg_id_lock = asyncio.Lock()

        async def send_to_group(group: GroupConsole) -> GroupConsole:
            group_key = cls._group_key(group)
            target = PlatformUtils.get_target(
                group_id=group.group_id, channel_id=group.channel_id
            )
            if not target:
                logger.warning(
                    "target为空",
                    "广播",
                    session=session_info,
                    target=group_key,
                )
                raise ValueError(f"无法为群组 {group_key} 创建发送目标")

            async with semaphore:
                try:
                    receipt: Receipt = await message.send(target, bot=bot)
                    async with msg_id_lock:
                        await cls._extract_message_id_from_result(
                            receipt, group_key, session_info
                        )
                    # The pause happens while holding the semaphore slot, which
                    # throttles the overall send rate.
                    await asyncio.sleep(random.uniform(*BROADCAST_SEND_DELAY_RANGE))
                    return group
                except (ActionFailed, AdapterException) as ae:
                    logger.error(
                        f"发送失败(普通) to {group_key}: {ae}",
                        "广播",
                        session=session_info,
                        e=ae,
                    )
                    raise
                except Exception as e:
                    logger.error(
                        f"发送失败(普通) to {group_key}: {e}",
                        "广播",
                        session=session_info,
                        e=e,
                    )
                    raise

        return await cls._gather_sends(
            bot, session_info, group_list, force_send, send_to_group
        )

    @staticmethod
    def _action_failed_field(error: ActionFailed, name: str) -> object:
        """Read ``name`` from an ActionFailed, as an attribute or from ``info``.

        Some adapters expose fields such as ``retcode``/``wording`` as attributes
        while others keep the raw API response in an ``info`` mapping
        (NOTE(review): the OneBot V11 adapter appears to do the latter — confirm).
        Returns None when the field is found in neither place.
        """
        value = getattr(error, name, None)
        if value is None:
            info = getattr(error, "info", None)
            if isinstance(info, dict):
                value = info.get(name)
        return value

    @classmethod
    async def recall_last_broadcast(
        cls, bot: Bot, session_info: EventSession | str
    ) -> BroadcastResult:
        """Recall every recorded message of the last broadcast via ``delete_msg``.

        Failures that look like "message already gone" are logged as warnings
        and counted in neither total. The record is cleared afterwards.

        Returns:
            ``(success_count, error_count)``.
        """
        msg_ids_to_recall = cls.get_last_broadcast_msg_ids()

        if not msg_ids_to_recall:
            logger.warning(
                "没有找到最近的广播消息ID记录", "广播撤回", session=session_info
            )
            return 0, 0

        id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids_to_recall.items()])
        logger.debug(
            f"找到 {len(msg_ids_to_recall)} 条广播消息ID记录: {id_list_str}",
            "广播撤回",
            session=session_info,
        )

        success_count = 0
        error_count = 0

        logger.info(
            f"准备撤回 {len(msg_ids_to_recall)} 条广播消息",
            "广播撤回",
            session=session_info,
        )

        for group_key, msg_id in msg_ids_to_recall.items():
            try:
                logger.debug(
                    f"尝试撤回消息 (ID: {msg_id}) in {group_key}",
                    "广播撤回",
                    session=session_info,
                )
                await bot.call_api("delete_msg", message_id=msg_id)
                success_count += 1
            except ActionFailed as af_e:
                retcode = cls._action_failed_field(af_e, "retcode")
                # Search both the readable "wording" and the short "msg" code.
                # Skipping empty/None parts also avoids calling .upper() on None,
                # which would raise inside this handler and abort the loop.
                detail = " ".join(
                    str(part)
                    for part in (
                        cls._action_failed_field(af_e, "wording"),
                        cls._action_failed_field(af_e, "msg"),
                    )
                    if part
                )
                if retcode == 100 and "MESSAGE_NOT_FOUND" in detail.upper():
                    logger.warning(
                        f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}",
                        "广播撤回",
                        session=session_info,
                    )
                elif retcode == 300 and "delete message" in detail.lower():
                    logger.warning(
                        f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}",
                        "广播撤回",
                        session=session_info,
                    )
                else:
                    error_count += 1
                    logger.error(
                        f"撤回消息失败 (ID: {msg_id}) in {group_key}: {af_e}",
                        "广播撤回",
                        session=session_info,
                        e=af_e,
                    )
            except Exception as e:
                error_count += 1
                logger.error(
                    f"撤回消息时发生未知错误 (ID: {msg_id}) in {group_key}: {e}",
                    "广播撤回",
                    session=session_info,
                    e=e,
                )
            # Small gap between recall calls.
            await asyncio.sleep(0.2)

        logger.debug("撤回操作完成,清空消息ID记录", "广播撤回", session=session_info)
        cls.clear_last_broadcast_msg_ids()

        return success_count, error_count
|